Example 1 with SimpleMLPLocalBatchTrainerInput

Use of org.apache.ignite.ml.nn.SimpleMLPLocalBatchTrainerInput in project ignite by apache.

Class MnistLocal, method tstMNISTLocal.

/**
 * Run nn classifier on MNIST using a local batch trainer; the dataset is held in local matrices.
 * To run this test rename this method so it starts from 'test'.
 *
 * @throws IOException In case of loading MNIST dataset errors.
 */
@Test
public void tstMNISTLocal() throws IOException {
    int samplesCnt = 60_000;
    int featCnt = 28 * 28;
    int hiddenNeuronsCnt = 100;
    IgniteBiTuple<Stream<DenseLocalOnHeapVector>, Stream<DenseLocalOnHeapVector>> trainingAndTest = loadMnist(samplesCnt);
    Stream<DenseLocalOnHeapVector> trainingMnistStream = trainingAndTest.get1();
    Stream<DenseLocalOnHeapVector> testMnistStream = trainingAndTest.get2();
    IgniteBiTuple<Matrix, Matrix> ds = createDataset(trainingMnistStream, samplesCnt, featCnt);
    IgniteBiTuple<Matrix, Matrix> testDs = createDataset(testMnistStream, 10000, featCnt);
    MLPArchitecture conf = new MLPArchitecture(featCnt)
        .withAddedLayer(hiddenNeuronsCnt, true, Activators.SIGMOID)
        .withAddedLayer(10, false, Activators.SIGMOID);
    SimpleMLPLocalBatchTrainerInput input = new SimpleMLPLocalBatchTrainerInput(conf, new Random(), ds.get1(), ds.get2(), 2000);
    // Start the timer before training so that the reported duration covers the train() call.
    X.println("Training started");
    long before = System.currentTimeMillis();
    MultilayerPerceptron mdl = new MLPLocalBatchTrainer<>(LossFunctions.MSE, () -> new RPropUpdateCalculator(0.1, 1.2, 0.5), 1E-7, 200).train(input);
    X.println("Training finished in " + (System.currentTimeMillis() - before) + " ms");
    Vector predicted = mdl.apply(testDs.get1()).foldColumns(VectorUtils::vec2Num);
    Vector truth = testDs.get2().foldColumns(VectorUtils::vec2Num);
    Tracer.showAscii(truth);
    Tracer.showAscii(predicted);
    X.println("Accuracy: " + VectorUtils.zipWith(predicted, truth, (x, y) -> x.equals(y) ? 1.0 : 0.0).sum() / truth.size() * 100 + "%.");
}
Also used: SimpleMLPLocalBatchTrainerInput(org.apache.ignite.ml.nn.SimpleMLPLocalBatchTrainerInput) VectorUtils(org.apache.ignite.ml.math.VectorUtils) MLPArchitecture(org.apache.ignite.ml.nn.architecture.MLPArchitecture) RPropUpdateCalculator(org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator) MultilayerPerceptron(org.apache.ignite.ml.nn.MultilayerPerceptron) Matrix(org.apache.ignite.ml.math.Matrix) Random(java.util.Random) Stream(java.util.stream.Stream) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) Vector(org.apache.ignite.ml.math.Vector) Test(org.junit.Test)
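
The accuracy figure at the end of the test is obtained by folding each column of the prediction and ground-truth matrices to a single class number and counting the matches. The sketch below reproduces that arithmetic in plain Java without any Ignite types. It assumes that VectorUtils::vec2Num returns the index of the largest entry in a column, which is an assumption about the library rather than something shown above; the class AccuracySketch and its methods argMax and accuracyPercent are hypothetical names introduced only for illustration.

```java
public class AccuracySketch {
    /** Index of the largest entry of a column (assumed equivalent of vec2Num). */
    static int argMax(double[] col) {
        int best = 0;
        for (int i = 1; i < col.length; i++) {
            if (col[i] > col[best])
                best = i;
        }
        return best;
    }

    /** Percentage of samples whose predicted class equals the true class. */
    static double accuracyPercent(double[][] predictedCols, double[][] truthCols) {
        if (predictedCols.length != truthCols.length)
            throw new IllegalArgumentException("Sample counts differ");
        if (truthCols.length == 0)
            throw new IllegalArgumentException("No samples");

        int hits = 0;
        for (int i = 0; i < truthCols.length; i++) {
            if (argMax(predictedCols[i]) == argMax(truthCols[i]))
                hits++;
        }
        return 100.0 * hits / truthCols.length;
    }

    public static void main(String[] args) {
        // Four samples, three classes; each inner array is one output column.
        double[][] predicted = {
            {0.9, 0.05, 0.05},
            {0.2, 0.7, 0.1},
            {0.3, 0.3, 0.4},
            {0.6, 0.3, 0.1}
        };
        // One-hot ground truth; the last sample is misclassified above.
        double[][] truth = {
            {1, 0, 0},
            {0, 1, 0},
            {0, 0, 1},
            {0, 1, 0}
        };
        System.out.println("Accuracy: " + accuracyPercent(predicted, truth) + "%.");
    }
}
```

Comparing class indices as integers avoids any reliance on equality of boxed floating-point values, which the original lambda depends on through x.equals(y).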
