
Example 1 with Add

Use of org.deeplearning4j.spark.impl.common.Add in project deeplearning4j by deeplearning4j.

From class AddTest, method testAdd.

@Test
public void testAdd() {
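    // Five length-5 vectors of ones, distributed as an RDD
    List<INDArray> list = new ArrayList<>();
    for (int i = 0; i < 5; i++) list.add(Nd4j.ones(5));
    JavaRDD<INDArray> rdd = sc.parallelize(list);
    // Fold with a zeros vector as the identity and Add as the combining function
    INDArray sum = rdd.fold(Nd4j.zeros(5), new Add());
    // Integer.MAX_VALUE as the dimension sums over all elements (older ND4J convention): 5 vectors * 5 ones = 25
    assertEquals(25, sum.sum(Integer.MAX_VALUE).getDouble(0), 1e-1);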
}
Also used: Add (org.deeplearning4j.spark.impl.common.Add), INDArray (org.nd4j.linalg.api.ndarray.INDArray), ArrayList (java.util.ArrayList), Test (org.junit.Test), BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest)
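The Add class itself is not shown on this page. As a rough, stdlib-only sketch of the fold that testAdd checks, the snippet below uses a hypothetical ADD_VECTORS operator in place of Add and double[] in place of INDArray; element-wise addition is an assumption about what Add does, not its actual source. Spark's fold applies the zero value once per partition and once more when merging the partition results on the driver, so the zero must be an identity for the operator. The hypothetical fold method below simulates that two-level behavior.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.BinaryOperator;

public class FoldSketch {

    // Hypothetical stand-in for Add: element-wise sum of two equal-length vectors.
    // Returns a new array, so neither input is modified.
    static final BinaryOperator<double[]> ADD_VECTORS = (a, b) -> {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Length mismatch: " + a.length + " vs " + b.length);
        }
        double[] out = new double[a.length];
        for (int i = 0; i < a.length; i++) {
            out[i] = a[i] + b[i];
        }
        return out;
    };

    // Simulates Spark-style fold: each partition is folded starting from a copy of zero,
    // then the partition results are folded together starting from another copy of zero.
    static double[] fold(List<List<double[]>> partitions, double[] zero, BinaryOperator<double[]> op) {
        double[] result = zero.clone();
        for (List<double[]> partition : partitions) {
            double[] acc = zero.clone();
            for (double[] v : partition) {
                acc = op.apply(acc, v);
            }
            result = op.apply(result, acc);
        }
        return result;
    }

    // Sum of all elements, analogous to summing an INDArray over every dimension.
    static double total(double[] v) {
        return Arrays.stream(v).sum();
    }

    public static void main(String[] args) {
        // Five length-5 vectors of ones, as in testAdd, split over two "partitions".
        List<double[]> p1 = new ArrayList<>();
        List<double[]> p2 = new ArrayList<>();
        for (int i = 0; i < 5; i++) {
            double[] ones = new double[5];
            Arrays.fill(ones, 1.0);
            if (i < 3) {
                p1.add(ones);
            } else {
                p2.add(ones);
            }
        }
        List<List<double[]>> partitions = Arrays.asList(p1, p2);
        double[] sum = fold(partitions, new double[5], ADD_VECTORS);
        System.out.println(total(sum)); // prints 25.0
    }
}
```

With a single partition, passing a vector of ones as the zero value would add it twice (once inside the partition, once at the merge), giving 35 instead of 25; an identity such as a zeros vector avoids that.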

Example 2 with Add

Use of org.deeplearning4j.spark.impl.common.Add in project deeplearning4j by deeplearning4j.

From class SparkDl4jLayer, method fitDataSet.

/**
 * Fits this layer on a JavaRDD of DataSet objects.
 * @param rdd the RDD of DataSet objects to train on
 * @return the fitted layer
 */
public Layer fitDataSet(JavaRDD<DataSet> rdd) {
    int iterations = conf.getNumIterations();
    int numPartitions = rdd.partitions().size();
    log.info("Running distributed training (averageEachIteration=" + averageEachIteration + ") on " + numPartitions + " partitions");
    if (!averageEachIteration) {
        // Allocate a flat parameter view and instantiate the layer on top of it
        int numParams = conf.getLayer().initializer().numParams(conf);
        final INDArray params = Nd4j.create(1, numParams);
        Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true);
        layer.setBackpropGradientsViewArray(Nd4j.create(1, numParams));
        int paramsLength = layer.numParams();
        if (params.length() != paramsLength)
            throw new IllegalStateException("Number of params " + paramsLength + " was not equal to " + params.length());
        // Ship the initial parameters to every executor once
        this.params = sc.broadcast(params);
        log.info("Broadcasting initial parameters of length " + params.length());
        // Sample roughly 40% of the data (with replacement); IterativeReduceFlatMap trains per partition and emits parameter arrays
        JavaRDD<INDArray> results = rdd.sample(true, 0.4).mapPartitions(new IterativeReduceFlatMap(conf.toJson(), this.params));
        log.debug("Ran iterative reduce...averaging results now.");
        // Sum the per-partition parameters, then divide by the partition count to average them
        INDArray newParams = results.fold(Nd4j.zeros(results.first().shape()), new Add());
        newParams.divi(numPartitions);
        layer.setParams(newParams);
        this.layer = layer;
    } else {
        // Run one local iteration per round and average the parameters after every round
        conf.setNumIterations(1);
        int numParams = conf.getLayer().initializer().numParams(conf);
        final INDArray params = Nd4j.create(1, numParams);
        Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true);
        layer.setBackpropGradientsViewArray(Nd4j.create(1, numParams));
        int paramsLength = layer.numParams();
        if (params.length() != paramsLength)
            throw new IllegalStateException("Number of params " + paramsLength + " was not equal to " + params.length());
        this.params = sc.broadcast(params);
        for (int i = 0; i < iterations; i++) {
            JavaRDD<INDArray> results = rdd.sample(true, 0.3).mapPartitions(new IterativeReduceFlatMap(conf.toJson(), this.params));
            INDArray newParams = results.fold(Nd4j.zeros(results.first().shape()), new Add());
            newParams.divi(numPartitions);
            // Re-broadcast the averaged parameters so the next round starts from them
            this.params = sc.broadcast(newParams);
        }
        layer.setParams(this.params.value().dup());
        this.layer = layer;
    }
    return layer;
}
Also used: Add (org.deeplearning4j.spark.impl.common.Add), INDArray (org.nd4j.linalg.api.ndarray.INDArray), Layer (org.deeplearning4j.nn.api.Layer), FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer), LabeledPoint (org.apache.spark.mllib.regression.LabeledPoint)
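fitDataSet delegates local training to IterativeReduceFlatMap, whose source is not shown here. The stdlib-only sketch below illustrates only the parameter-averaging arithmetic around it, under stated assumptions: each partition returns one parameter vector (double[] in place of INDArray), the hypothetical localStep function stands in for per-partition training, and the average is the element-wise sum divided by the number of partitions, as newParams.divi(...) does. The averageEachIteration path is modeled as repeating that round and feeding each round's average back in as the next round's starting parameters; that feedback is the sketch's assumption about the intended behavior.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.BiFunction;

public class ParameterAveragingSketch {

    // Element-wise mean of per-partition parameter vectors:
    // sum them (like fold with Add), then divide by the number of vectors (like divi).
    static double[] average(List<double[]> perPartition) {
        if (perPartition.isEmpty()) {
            throw new IllegalArgumentException("No partition results to average");
        }
        int n = perPartition.get(0).length;
        double[] sum = new double[n];
        for (double[] p : perPartition) {
            if (p.length != n) {
                throw new IllegalStateException("Number of params " + p.length + " was not equal to " + n);
            }
            for (int i = 0; i < n; i++) {
                sum[i] += p[i];
            }
        }
        for (int i = 0; i < n; i++) {
            sum[i] /= perPartition.size();
        }
        return sum;
    }

    // One averaging round per iteration: every partition starts from the current
    // (conceptually broadcast) parameters, trains locally via the hypothetical localStep,
    // and the averaged result becomes the starting point of the next round.
    static double[] iterativeAverage(double[] initial, int numPartitions, int iterations,
                                     BiFunction<double[], Integer, double[]> localStep) {
        double[] current = initial.clone();
        for (int it = 0; it < iterations; it++) {
            List<double[]> results = new ArrayList<>();
            for (int p = 0; p < numPartitions; p++) {
                // Each partition works on its own copy, like a deserialized broadcast value
                results.add(localStep.apply(current.clone(), p));
            }
            current = average(results);
        }
        return current;
    }

    public static void main(String[] args) {
        // Hypothetical local step: partition p moves every parameter halfway toward the value p.
        BiFunction<double[], Integer, double[]> localStep = (params, p) -> {
            for (int i = 0; i < params.length; i++) {
                params[i] += 0.5 * (p - params[i]);
            }
            return params;
        };
        double[] result = iterativeAverage(new double[3], 2, 10, localStep);
        System.out.println(Arrays.toString(result));
    }
}
```

With two partitions pulling toward 0 and 1, each round maps a parameter x to 0.5x + 0.25, so the averaged parameters converge toward 0.5. Averaging only once, as the non-iterative branch does, corresponds to calling average on a single round of results.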

Aggregations

Add (org.deeplearning4j.spark.impl.common.Add): 2
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2
ArrayList (java.util.ArrayList): 1
LabeledPoint (org.apache.spark.mllib.regression.LabeledPoint): 1
Layer (org.deeplearning4j.nn.api.Layer): 1
FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer): 1
BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest): 1
Test (org.junit.Test): 1