Java Deep Learning Tutorial

So bauen Sie ein neuronales Netzwerk auf

14.01.2024 von Matthew Tyson
Um neuronale Netze wirklich zu verstehen, erstellen Sie am besten selbst eines. Dieses Tutorial zeigt Ihnen, wie das mit Java funktioniert.
Künstliche neuronale Netze bilden eine Säule moderner, künstlicher Intelligenz.
Foto: Sergey Tarasov - shutterstock.com

Künstliche neuronale Netze sind eine Form des Deep Learning. Der beste Weg, um ihre Funktionsweise vollständig zu durchdringen, besteht darin, sich selbst die Hände "schmutzig" zu machen. Dieser Artikel liefert Ihnen dafür die Grundlage und demonstriert, wie Sie ein neuronales Netzwerk in Java aufbauen und trainieren. Unser Beispiel für diesen Artikel ist dabei keineswegs ein produktionsreifes System - vielmehr gibt es in verständlicher Form Aufschluss über alle Hauptkomponenten.

Ein grundlegendes, neuronales Netz

Ein neuronales Netz ist ein Graph, bestehend aus Knoten (Nodes), die Neuronen genannt werden. Das Neuron ist die Grundeinheit der Berechnung: Es empfängt Eingaben und verarbeitet diese mithilfe:

Nachfolgende Abbildung zeigt ein Neuron mit zwei Inputs:

Ein neuronales Netz bestehend aus zwei Input-Neuronen.
Foto: Foundry / Matthew Tyson

Dieses Modell ist sehr variabel - im Folgenden verwenden wir diese Konfiguration.

Unser erster Schritt besteht darin, eine Neuron-Class zu modellieren, die diese Werte enthalten soll. Eine erste Version der Class sehen Sie im folgenden Listing 1 - diese wird sich im weiteren Verlauf verändern, wenn weitere Funktionen hinzukommen.

Listing 1: Eine einfache Neuron-Class

class Neuron {

Random random = new Random();

private Double bias = random.nextDouble(-1, 1);

public Double weight1 = random.nextDouble(-1, 1);

private Double weight2 = random.nextDouble(-1, 1);

public double compute(double input1, double input2){

double preActivation = (this.weight1 * input1) + (this.weight2 * input2) + this.bias;

double output = Util.sigmoid(preActivation);

return output;

}

}

Wie Sie sehen, ist die Neuron-Class recht simpel und weist drei Mitglieder auf: bias, weight1 und weight2. Jedes dieser Mitglieder wird mit einem zufälligen Double zwischen -1 und 1 initialisiert.

Geht es darum, den Output des Neurons zu berechnen, folgen wir dem in Abbildung 1 gezeigten Algorithmus: Wir multiplizieren jede Eingabe mit ihrer Gewichtung plus dem Bias: input1 * weight1 + input2 * weight2 + bias. So erhalten wir die unverarbeitete Berechnung (preActivation), die wir durch die Aktivierungsfunktion laufen lassen. In diesem Fall verwenden wir die Sigmoid-Aktivierungsfunktion, die Werte in einem Bereich von -1 bis 1 komprimiert. Im Folgenden die statische Util.sigmoid()-Methode.

Listing 2: Sigmoid-Aktivierungsfunktion

public class Util {

public static double sigmoid(double in){

return 1 / (1 + Math.exp(-in));

}

}

Nachdem wir nun die Funktionsweise von Neuronen beleuchtet haben, gilt es, einige Neuronen in ein Netzwerk einfügen. Dazu nutzen wir eine Network-Class mit einer Liste von Neuronen.

Listing 3: Die Neural Network Class

class Network {

List<Neuron> neurons = Arrays.asList(

new Neuron(), new Neuron(), new Neuron(), /* input nodes */

new Neuron(), new Neuron(), /* hidden nodes */

new Neuron()); /* output node */

}

}

Obwohl die Liste der Neuronen eindimensional ist, werden wir sie während der Nutzung zu einem Netzwerk verbinden. Die ersten drei Neuronen sind Inputs, die folgenden beiden versteckt und das letzte der Output-Knoten.

Eine Prediction anstoßen

Nun soll es darum gehen, ein Netzwerk zu Prediction-Zwecken einzusetzen. Dazu verwenden wir einen einfachen Datensatz mit zwei ganzzahligen Inputs und einem Antwortformat von 0 bis 1. In unserem Beispiel wird eine Kombination aus Gewicht und Größe verwendet, um das Geschlecht einer Person zu erraten.

Dabei wird davon ausgegangen, dass mehr Gewicht und Größe auf eine männliche Person hindeuten. Dieselbe Formel ließe sich für jede beliebige Wahrscheinlichkeitsrechnung mit zwei Faktoren und einem Output nutzen. Den Input könnte man auch als Vektor betrachten - und somit die Gesamtfunktion der Neuronen als Umwandlung eines Vektors in einem Skalarwert. Die Prediction-Phase des Netzes gestaltet sich wie folgt.

Listing 4: Network prediction

public Double predict(Integer input1, Integer input2){

return neurons.get(5).compute(

neurons.get(4).compute(

neurons.get(2).compute(input1, input2),

neurons.get(1).compute(input1, input2)

),

neurons.get(3).compute(

neurons.get(1).compute(input1, input2),

neurons.get(0).compute(input1, input2)

)

);

}

Listing 4 zeigt, dass die beiden Inputs an die ersten drei Neuronen fließen. Deren Outputs werden an die Neuronen 4 und 5 weitergeleitet wird, die wiederum in das Output-Neuron einspeisen. Dieser Prozess wird als Feedforward bezeichnet. Nun könnten wir das Netz zu einer Prediction auffordern.

Listing 5: Prediction

Network network = new Network();

Double prediction = network.predict(Arrays.asList(115, 66));

System.out.println("prediction: " + prediction);

Das würde sicher zu Ergebnissen führen - die allerdings nur auf Zufallswerten und Bias basieren. Für eine echte Prediction ist es nötig, das Netzwerk zuvor zu trainieren.

Das Netzwerk trainieren

Das Training eines neuronalen Netzwerks folgt einem Prozess, der als Backpropagation bekannt ist. Der beinhaltet im Grunde, Änderungen rückwärts durch das Netzwerk zu "schieben", damit sich der Output in Richtung eines gewünschten Zielwerts bewegt. Backpropagation lässt sich mit Hilfe von Funktionsdifferenzierung durchführen - für unser Beispiel werden wir allerdings einen anderen Weg gehen und jedem Neuron die Fähigkeit verleihen, zu "mutieren".

In jeder Trainingsrunde (auch Epoch genannt) wählen wir ein anderes Neuron aus, um eine kleine, zufällige Anpassung an einer seiner Eigenschaften (weight1, weight2 oder bias) vorzunehmen und dann zu prüfen, ob sich die Ergebnisse verbessern. Ist das der Fall, behalten wir diese Änderung mit einer remember()-Methode bei. Wenn sich die Ergebnisse verschlechtert haben, machen wir sie mit einer forget()-Methode rückgängig.

Um die Änderungen zu tracken, fügen wir Class-Mitglieder hinzu (old*-Versionen von weights und bias). Im Folgenden betrachten wir die Methoden mutate(), remember() und forget().

Listing 6: mutate(), remember(), forget()

public class Neuron() {

private Double oldBias = random.nextDouble(-1, 1), bias = random.nextDouble(-1, 1);

public Double oldWeight1 = random.nextDouble(-1, 1), weight1 = random.nextDouble(-1, 1);

private Double oldWeight2 = random.nextDouble(-1, 1), weight2 = random.nextDouble(-1, 1);

public void mutate(){

int propertyToChange = random.nextInt(0, 3);

Double changeFactor = random.nextDouble(-1, 1);

if (propertyToChange == 0){

this.bias += changeFactor;

} else if (propertyToChange == 1){

this.weight1 += changeFactor;

} else {

this.weight2 += changeFactor;

};

}

public void forget(){

bias = oldBias;

weight1 = oldWeight1;

weight2 = oldWeight2;

}

public void remember(){

oldBias = bias;

oldWeight1 = weight1;

oldWeight2 = weight2;

}

}

Zusammengefasst:

Um nun die neuen Fähigkeiten unseres Neurons zu nutzen, fügen wir Network eine train()-Methode hinzu.

Listing 7: Die Network.train()-Methode

public void train(List<List<Integer>> data, List<Double> answers){

Double bestEpochLoss = null;

for (int epoch = 0; epoch < 1000; epoch++){

// adapt neuron

Neuron epochNeuron = neurons.get(epoch % 6);

List<Double> predictions = new ArrayList<Double>();

for (int i = 0; i < data.size(); i++){

predictions.add(i, this.predict(data.get(i).get(0), data.get(i).get(1)));

}

Double thisEpochLoss = Util.meanSquareLoss(answers, predictions);

if (bestEpochLoss == null){

bestEpochLoss = thisEpochLoss;

epochNeuron.remember();

} else {

if (thisEpochLoss < bestEpochLoss){

bestEpochLoss = thisEpochLoss;

epochNeuron.remember();

} else {

epochNeuron.forget();

}

}

}

Die train()-Methode iteriert eintausendmal über die aufgeführten data, answers und lists. Es handelt sich um gleich große Trainingsmengen: data beinhaltet Input-Werte, answers die bekannten, richtigen Antworten. Die Methode ermittelt dann einen Wert darüber, wie nahe das Ergebnis des Netzwerks den bekannten, richtigen Antworten kommt. Dann wird ein zufälliges Neuron verändert (mutiert), wobei die Änderung beibehalten wird, wenn ein neuer Test ergibt, dass sie eine bessere Vorhersage zur Folge hatte.

Die Ergebnisse lassen sich mithilfe der Mean-Squared-Error (MSE) -Formel überprüfen - einer dafür gängigen Methode.

Listing 8: MSE-Funktion

public static Double meanSquareLoss(List<Double> correctAnswers, List<Double> predictedAnswers){

double sumSquare = 0;

for (int i = 0; i < correctAnswers.size(); i++){

double error = correctAnswers.get(i) - predictedAnswers.get(i);

sumSquare += (error * error);

}

return sumSquare / (correctAnswers.size());

}

System feinabstimmen

Nun müssen wir nur noch einige Trainingsdaten in das Netz fließen lassen und es mit weiteren Predictions austesten. Im Folgenden betrachten wir, wie man Trainingsdaten bereitstellt.

Listing 9: Trainingsdaten

List<List<Integer>> data = new ArrayList<List<Integer>>();

data.add(Arrays.asList(115, 66));

data.add(Arrays.asList(175, 78));

data.add(Arrays.asList(205, 72));

data.add(Arrays.asList(120, 67));

List<Double> answers = Arrays.asList(1.0,0.0,0.0,1.0);

Network network = new Network();

network.train(data, answers);

In Listing 9 bestehen unsere Trainingsdaten aus einer Liste von zweidimensionalen Integer-Sets (wir könnten sie uns als Gewicht und Größe vorstellen) und einer Liste von Antworten (wobei 1.0 weiblich und 0.0 männlich ist).

Wenn wir den Trainingsalgorithmus eine Logging-Funktionalität hinzufügen, erhalten wir nachfolgendes Resultat.

Listing 10. Trainingsdaten-Logging

// Logging:

if (epoch % 10 == 0) System.out.println(String.format("Epoch: %s | bestEpochLoss: %.15f | thisEpochLoss: %.15f", epoch, bestEpochLoss, thisEpochLoss));

// output:

Epoch: 910 | bestEpochLoss: 0.034404863820424 | thisEpochLoss: 0.034437939546120

Epoch: 920 | bestEpochLoss: 0.033875954196897 | thisEpochLoss: 0.431451026477016

Epoch: 930 | bestEpochLoss: 0.032509260025490 | thisEpochLoss: 0.032509260025490

Epoch: 940 | bestEpochLoss: 0.003092720117159 | thisEpochLoss: 0.003098025397281

Epoch: 950 | bestEpochLoss: 0.002990128276146 | thisEpochLoss: 0.431062364628853

Epoch: 960 | bestEpochLoss: 0.001651762688346 | thisEpochLoss: 0.001651762688346

Epoch: 970 | bestEpochLoss: 0.001637709485751 | thisEpochLoss: 0.001636810460399

Epoch: 980 | bestEpochLoss: 0.001083365453009 | thisEpochLoss: 0.391527869500699

Epoch: 990 | bestEpochLoss: 0.001078338540452 | thisEpochLoss: 0.001078338540452

Wie in Listing 10 zu sehen, nimmt der "Loss" (also die Fehlerabweichung von "100 Prozent korrekt") langsam ab. Das Modell nähert sich also immer mehr einer genauen Vorhersage an. Nun gilt es zu überprüfen, wie gut unser Modell mit echten Daten funktioniert.

Listing 11: Vorhersagen

System.out.println("");

System.out.println(String.format(" male, 167, 73: %.10f", network.predict(167, 73)));

System.out.println(String.format("female, 105, 67: %.10", network.predict(105, 67)));

System.out.println(String.format("female, 120, 72: %.10f | network1000: %.10f", network.predict(120, 72)));

System.out.println(String.format(" male, 143, 67: %.10f | network1000: %.10f", network.predict(143, 67)));

System.out.println(String.format(" male', 130, 66: %.10f | network: %.10f", network.predict(130, 66)));

Wie in Listing 11 zu sehen, füttern wir unser trainiertes neuronales Netz mit Daten und geben die Vorhersagen aus. Das Resultat sieht in etwa wie folgt aus.

Listing 12: Trainierte Predictions

male, 167, 73: 0.0279697143

female, 105, 67: 0.9075809407

female, 120, 72: 0.9075808235

male, 143, 67: 0.0305401413

male, 130, 66: network: 0.9009811922

Listing 12 zeigt, dass das Netzwerk bei den meisten Wertepaaren (Vektoren) ziemlich gute Arbeit geleistet hat. Es gibt den weiblichen Datensätzen eine Schätzung um 0.907 - was ziemlich nahe an 1 ist. Zwei männliche Datensätze weisen 0.027 und 0.030 auf - und nähern sich damit der 0. Der männliche Ausreißer-Datensatz (130, 67) wird als "wahrscheinlich weiblich" angesehen, bei einem Wert von 0.900 allerdings mit geringerer Zuversicht.

Es gibt eine Reihe von Möglichkeiten, die Einstellungen an diesem System zu verändern: Die Anzahl der Epochs in einem Trainingslauf ist dabei ein wichtiger Faktor. Je mehr Epochs, desto besser wird das Modell auf die Daten abgestimmt. Das kann auch die Genauigkeit von Live-Daten verbessern, die mit den Trainingssätzen übereinstimmen. Allerdings kann es auch in einem "Overtraining" resultieren - also einem Modell, das zuversichtlich die falschen Ergebnisse für Randfälle vorhersagt.

Den vollständigen Code für dieses Tutorial finden Sie in diesem GitHub-Repository - zusammen mit einigen zusätzlichen Funktionen. (fm)

Dieser Beitrag basiert auf einem Artikel unserer US-Schwesterpublikation Infoworld.