
Example 1 with DistractedSequenceRecall

Use of jcog.learn.lstm.DistractedSequenceRecall in the project narchy by automenta.

Class TestLSTM1, method testLSTM1.

@Test
public void testLSTM1() {
    // System.out.println("Test of SimpleLSTM\n");
    // Seed the generator with a fixed value (1234).
    Random r = new XorShift128PlusRandom(1234);
    DistractedSequenceRecall task = new DistractedSequenceRecall(r, 12, 3, 22, 1000);
    int cell_blocks = 4;
    float learningRate = 0.1f;
    SimpleLSTM slstm = task.lstm(cell_blocks);
    int epochs = 150;
    double error = 0;
    for (int epoch = 0; epoch < epochs; epoch++) {
        // scoreSupervised returns a fit score; the error is reported as 1 - fit.
        double fit = task.scoreSupervised(slstm, learningRate);
        error = 1 - fit;
        if (epoch % 10 == 0) {
            System.out.println("[" + epoch + "] error = " + error);
        }
    }
    // System.out.println("done.");
    // The test passes only if the error after the last epoch is below 0.01.
    assertTrue(error < 0.01f);
}
Also used: XorShift128PlusRandom (jcog.math.random.XorShift128PlusRandom), Random (java.util.Random), SimpleLSTM (jcog.learn.lstm.SimpleLSTM), DistractedSequenceRecall (jcog.learn.lstm.DistractedSequenceRecall), Test (org.junit.jupiter.api.Test)
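
The test above follows a common shape: run a fixed number of epochs, convert the returned fit score into an error (1 - fit), and check the final error against a threshold. The sketch below reproduces only that loop shape so it can be run without jcog on the classpath. ToyTask and EpochLoopSketch are hypothetical stand-ins invented for illustration; they are not part of narchy, and the single-weight model is an assumption that has nothing to do with how SimpleLSTM works internally.

```java
// Hypothetical stand-in for a supervised task; NOT a jcog class.
// It learns a single weight w so that w * x approximates 2 * x.
final class ToyTask {
    private double w = 0.0;
    private final double[] xs = new double[10];

    ToyTask() {
        // Deterministic inputs 0.1, 0.2, ..., 1.0
        for (int i = 0; i < xs.length; i++) {
            xs[i] = (i + 1) / 10.0;
        }
    }

    /** One pass over the samples: update w after each sample, return a fit in [0, 1]. */
    double scoreSupervised(double learningRate) {
        double absErrSum = 0;
        for (double x : xs) {
            double err = 2.0 * x - w * x;   // target minus prediction
            absErrSum += Math.abs(err);
            w += learningRate * err * x;    // delta-rule update
        }
        double meanAbsErr = absErrSum / xs.length;
        return 1.0 - Math.min(1.0, meanAbsErr);
    }
}

final class EpochLoopSketch {
    /** Runs the epoch loop and returns the error (1 - fit) of the last epoch. */
    static double train(int epochs, double learningRate) {
        ToyTask task = new ToyTask();
        double error = 1.0;                 // reported if no epoch is run
        for (int epoch = 0; epoch < epochs; epoch++) {
            error = 1.0 - task.scoreSupervised(learningRate);
        }
        return error;
    }

    public static void main(String[] args) {
        System.out.println("final error = " + train(150, 0.1));
    }
}
```

Returning the final error from train keeps the threshold check separate from the loop, in the same way that testLSTM1 asserts on the error only after the last epoch.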

Example 2 with DistractedSequenceRecall

Use of jcog.learn.lstm.DistractedSequenceRecall in the project narchy by automenta.

Class LSTMView, method main.

public static void main(String[] arg) {
    Random r = new XorShift128PlusRandom(1234);
    DistractedSequenceRecall task = new DistractedSequenceRecall(r, 32, 8, 8, 100);
    int cell_blocks = 16;
    SimpleLSTM lstm = task.lstm(cell_blocks);
    float lr = 0.1f;
    // Run one scoring pass to initialize the network before the view is opened.
    task.scoreSupervised(lstm, lr);
    // Display the LSTMView in a SpaceGraph window.
    SpaceGraph.window(new LSTMView(lstm), 800, 800);
    int epochs = 5000;
    for (int epoch = 0; epoch < epochs; epoch++) {
        double fit = task.scoreSupervised(lstm, lr);
        if (epoch % 10 == 0) {
            System.out.println("[" + epoch + "] error = " + (1 - fit));
        }
        // Short sleep between epochs.
        Util.sleep(1);
    }
    System.out.println("done.");
}
Also used: XorShift128PlusRandom (jcog.math.random.XorShift128PlusRandom), Random (java.util.Random), SimpleLSTM (jcog.learn.lstm.SimpleLSTM), DistractedSequenceRecall (jcog.learn.lstm.DistractedSequenceRecall)

Aggregations

Random (java.util.Random): 2
DistractedSequenceRecall (jcog.learn.lstm.DistractedSequenceRecall): 2
SimpleLSTM (jcog.learn.lstm.SimpleLSTM): 2
XorShift128PlusRandom (jcog.math.random.XorShift128PlusRandom): 2
Test (org.junit.jupiter.api.Test): 1