
Example 1 with CustomLayer

Use of org.deeplearning4j.spark.impl.customlayer.layer.CustomLayer in project deeplearning4j by deeplearning4j.

From the class TestCustomLayer, method testSparkWithCustomLayer.

@Test
public void testSparkWithCustomLayer() {
    //Basic test - checks whether exceptions etc are thrown with custom layers + spark
    //Custom layers are tested more extensively in dl4j core
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .learningRate(0.1)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build())
            .layer(1, new CustomLayer(3.14159))
            .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(10).nOut(10).build())
            .pretrain(false)
            .backprop(true)
            .build();
    ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(1)
            .averagingFrequency(2)
            .batchSizePerWorker(5)
            .saveUpdater(true)
            .workerPrefetchNumBatches(0)
            .build();
    SparkDl4jMultiLayer net = new SparkDl4jMultiLayer(sc, conf, tm);
    List<DataSet> testData = new ArrayList<>();
    Random r = new Random(12345);
    //Build 200 single-example DataSets: random features plus a random one-hot label over 10 classes
    for (int i = 0; i < 200; i++) {
        INDArray f = Nd4j.rand(1, 10);
        INDArray l = Nd4j.zeros(1, 10);
        l.putScalar(0, r.nextInt(10), 1.0);
        testData.add(new DataSet(f, l));
    }
    JavaRDD<DataSet> rdd = sc.parallelize(testData);
    net.fit(rdd);
}
Also used: OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer), CustomLayer (org.deeplearning4j.spark.impl.customlayer.layer.CustomLayer), DataSet (org.nd4j.linalg.dataset.DataSet), ArrayList (java.util.ArrayList), ParameterAveragingTrainingMaster (org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster), MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration), DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer), Random (java.util.Random), INDArray (org.nd4j.linalg.api.ndarray.INDArray), SparkDl4jMultiLayer (org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer), Test (org.junit.Test), BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest)
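
The test above trains through a ParameterAveragingTrainingMaster with averagingFrequency(2), so the workers' parameters are averaged and redistributed every two minibatches. As a rough illustration of what one averaging step amounts to, here is a minimal sketch in plain Java with no DL4J or Spark dependency. The class name ParameterAveragingSketch and the method average are hypothetical and not part of the deeplearning4j API; the real training master also handles updater state (saveUpdater) and runs distributed, which this sketch leaves out.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical illustration only: not part of deeplearning4j.
public class ParameterAveragingSketch {

    /** Returns the element-wise mean of the workers' flattened parameter vectors. */
    static double[] average(List<double[]> workerParams) {
        if (workerParams.isEmpty()) {
            throw new IllegalArgumentException("Need at least one worker");
        }
        int n = workerParams.get(0).length;
        double[] avg = new double[n];
        for (double[] params : workerParams) {
            if (params.length != n) {
                throw new IllegalArgumentException("All workers must have the same number of parameters");
            }
            // Accumulate each worker's parameters into the running sum
            for (int i = 0; i < n; i++) {
                avg[i] += params[i];
            }
        }
        // Divide the sums by the number of workers to get the mean
        for (int i = 0; i < n; i++) {
            avg[i] /= workerParams.size();
        }
        return avg;
    }

    public static void main(String[] args) {
        // Two workers that have diverged after training on different minibatches
        double[] worker0 = {1.0, 2.0, 3.0};
        double[] worker1 = {3.0, 4.0, 5.0};
        System.out.println(Arrays.toString(average(Arrays.asList(worker0, worker1))));
        // prints [2.0, 3.0, 4.0]
    }
}
```

In the actual test, every worker would then continue training from the averaged parameters until the next averaging step.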
