Use of org.deeplearning4j.spark.impl.common.Add in project deeplearning4j by deeplearning4j.
From the class AddTest, method testAdd:
@Test
public void testAdd() {
    List<INDArray> list = new ArrayList<>();
    for (int i = 0; i < 5; i++) list.add(Nd4j.ones(5));
    JavaRDD<INDArray> rdd = sc.parallelize(list);
    INDArray sum = rdd.fold(Nd4j.zeros(5), new Add());
    assertEquals(25, sum.sum(Integer.MAX_VALUE).getDouble(0), 1e-1);
}
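The Add function itself is not listed on this page. As a point of reference, a combiner usable with JavaRDD#fold in this way would look roughly like the sketch below; this is an assumption for illustration, not the project's actual source, and the real org.deeplearning4j.spark.impl.common.Add may differ in detail.

import org.apache.spark.api.java.function.Function2;
import org.nd4j.linalg.api.ndarray.INDArray;

// Hypothetical sketch: an element-wise accumulator compatible with JavaRDD#fold.
public class Add implements Function2<INDArray, INDArray, INDArray> {
    @Override
    public INDArray call(INDArray a, INDArray b) {
        // fold combines the zero value and per-partition results with this element-wise sum
        return a.add(b);
    }
}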
Use of org.deeplearning4j.spark.impl.common.Add in project deeplearning4j by deeplearning4j.
From the class SparkDl4jLayer, method fitDataSet:
/**
 * Fit a JavaRDD of DataSet objects.
 * @param rdd the RDD to fit
 * @return the fitted layer
 */
public Layer fitDataSet(JavaRDD<DataSet> rdd) {
    int iterations = conf.getNumIterations();
    long count = rdd.count();
    log.info("Running distributed training averaging each iteration " + averageEachIteration + " and "
            + rdd.partitions().size() + " partitions");
    if (!averageEachIteration) {
        int numParams = conf.getLayer().initializer().numParams(conf);
        final INDArray params = Nd4j.create(1, numParams);
        Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true);
        layer.setBackpropGradientsViewArray(Nd4j.create(1, numParams));
        // final INDArray params = layer.params();
        this.params = sc.broadcast(params);
        log.info("Broadcasting initial parameters of length " + params.length());
        int paramsLength = layer.numParams();
        if (params.length() != paramsLength)
            throw new IllegalStateException("Number of params " + paramsLength + " was not equal to " + params.length());
        JavaRDD<INDArray> results = rdd.sample(true, 0.4).mapPartitions(new IterativeReduceFlatMap(conf.toJson(), this.params));
        log.debug("Ran iterative reduce...averaging results now.");
        INDArray newParams = results.fold(Nd4j.zeros(results.first().shape()), new Add());
        newParams.divi(rdd.partitions().size());
        layer.setParams(newParams);
        this.layer = layer;
    } else {
        conf.setNumIterations(1);
        int numParams = conf.getLayer().initializer().numParams(conf);
        final INDArray params = Nd4j.create(1, numParams);
        Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true);
        layer.setBackpropGradientsViewArray(Nd4j.create(1, numParams));
        // final INDArray params = layer.params();
        this.params = sc.broadcast(params);
        for (int i = 0; i < iterations; i++) {
            JavaRDD<INDArray> results = rdd.sample(true, 0.3).mapPartitions(new IterativeReduceFlatMap(conf.toJson(), this.params));
            int paramsLength = layer.numParams();
            if (params.length() != paramsLength)
                throw new IllegalStateException("Number of params " + paramsLength + " was not equal to " + params.length());
            INDArray newParams = results.fold(Nd4j.zeros(results.first().shape()), new Add());
            newParams.divi(rdd.partitions().size());
        }
        layer.setParams(this.params.value().dup());
        this.layer = layer;
    }
    return layer;
}
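For context, the parameter-averaging step above (fold with Add across partitions, then divide by the partition count) can be exercised on its own. The sketch below is a minimal, hypothetical illustration assuming an existing JavaSparkContext named sc and the same Add combiner shown with the test; it is not part of the project source.

import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public INDArray averageAcrossPartitions(JavaSparkContext sc) {
    // Two stand-in "per-partition" parameter vectors of length 3.
    JavaRDD<INDArray> perPartitionParams =
            sc.parallelize(Arrays.asList(Nd4j.ones(3), Nd4j.ones(3).mul(3)));
    // Element-wise sum across partitions, mirroring results.fold(Nd4j.zeros(...), new Add()) above.
    INDArray summed = perPartitionParams.fold(Nd4j.zeros(3), new Add());
    // In-place divide by the partition count, mirroring newParams.divi(rdd.partitions().size()).
    return summed.divi(perPartitionParams.partitions().size());
}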