Use of org.nd4j.linalg.api.ops.executioner.GridExecutioner in project deeplearning4j by deeplearning4j.
Class ScoreExamplesFunctionAdapter, method call:
@Override
public Iterable<Double> call(Iterator<DataSet> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyList();
    }
    MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue()));
    network.init();
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
    network.setParameters(val);
    List<Double> ret = new ArrayList<>();
    List<DataSet> collect = new ArrayList<>(batchSize);
    int totalCount = 0;
    while (iterator.hasNext()) {
        collect.clear();
        int nExamples = 0;
        //Collect DataSets until the minibatch holds at least batchSize examples
        while (iterator.hasNext() && nExamples < batchSize) {
            DataSet ds = iterator.next();
            int n = ds.numExamples();
            collect.add(ds);
            nExamples += n;
        }
        totalCount += nExamples;
        //Merge into a single DataSet and score all examples in one forward pass
        DataSet data = DataSet.merge(collect);
        INDArray scores = network.scoreExamples(data, addRegularization);
        double[] doubleScores = scores.data().asDouble();
        for (double doubleScore : doubleScores) {
            ret.add(doubleScore);
        }
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples ", totalCount);
    }
    return ret;
}
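Every snippet on this page ends with the same guard: if the configured ND4J executioner is the asynchronous GridExecutioner, flushQueueBlocking() is called so that all queued operations have actually executed before results leave the method. Below is a minimal sketch of that guard factored into a reusable helper; the ExecutionerUtils class name is an illustrative assumption, not part of deeplearning4j.

import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.factory.Nd4j;

//Hypothetical helper (not in deeplearning4j): blocks until all queued ops have run when the
//asynchronous GridExecutioner backend is active; a no-op for synchronous executioners.
public class ExecutionerUtils {

    public static void flushIfGridExecutioner() {
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
        }
    }
}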
Use of org.nd4j.linalg.api.ops.executioner.GridExecutioner in project deeplearning4j by deeplearning4j.
Class ScoreFlatMapFunctionAdapter, method call:
@Override
public Iterable<Tuple2<Integer, Double>> call(Iterator<DataSet> dataSetIterator) throws Exception {
    if (!dataSetIterator.hasNext()) {
        return Collections.singletonList(new Tuple2<>(0, 0.0));
    }
    //Does batching where appropriate
    DataSetIterator iter = new IteratorDataSetIterator(dataSetIterator, minibatchSize);
    MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json));
    network.init();
    //.value() object will be shared by all executors on each machine -> OK, as params are not modified by score function
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
    network.setParameters(val);
    List<Tuple2<Integer, Double>> out = new ArrayList<>();
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        double score = network.score(ds, false);
        int numExamples = ds.getFeatureMatrix().size(0);
        out.add(new Tuple2<>(numExamples, score * numExamples));
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    return out;
}
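Each tuple emitted above carries the example count and the minibatch score already multiplied by that count, so the driver can combine results from all partitions into an example-weighted average. A minimal sketch of that aggregation, assuming the partial results have been collected into a list; the ScoreAggregation class is illustrative, not deeplearning4j API.

import java.util.List;

import scala.Tuple2;

public class ScoreAggregation {

    //Combines (exampleCount, exampleCount * score) partials into one example-weighted average score
    public static double weightedAverage(List<Tuple2<Integer, Double>> partials) {
        long totalExamples = 0;
        double weightedScoreSum = 0.0;
        for (Tuple2<Integer, Double> t : partials) {
            totalExamples += t._1();     //number of examples in this minibatch
            weightedScoreSum += t._2();  //score * numExamples for this minibatch
        }
        return totalExamples == 0 ? 0.0 : weightedScoreSum / totalExamples;
    }
}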
Use of org.nd4j.linalg.api.ops.executioner.GridExecutioner in project deeplearning4j by deeplearning4j.
Class ParameterAveragingTrainingMaster, method processResults:
private void processResults(SparkDl4jMultiLayer network, SparkComputationGraph graph,
                JavaRDD<ParameterAveragingTrainingResult> results, int splitNum, int totalSplits) {
    if (collectTrainingStats)
        stats.logAggregateStartTime();
    ParameterAveragingAggregationTuple tuple = results.aggregate(null, new ParameterAveragingElementAddFunction(),
                    new ParameterAveragingElementCombineFunction());
    INDArray params = tuple.getParametersSum();
    int aggCount = tuple.getAggregationsCount();
    SparkTrainingStats aggregatedStats = tuple.getSparkTrainingStats();
    if (collectTrainingStats)
        stats.logAggregationEndTime();
    if (collectTrainingStats)
        stats.logProcessParamsUpdaterStart();
    if (params != null) {
        params.divi(aggCount);
        INDArray updaterState = tuple.getUpdaterStateSum();
        if (updaterState != null)
            updaterState.divi(aggCount); //May be null if all SGD updaters, for example
        if (network != null) {
            MultiLayerNetwork net = network.getNetwork();
            net.setParameters(params);
            if (updaterState != null)
                net.getUpdater().setStateViewArray(null, updaterState, false);
            network.setScore(tuple.getScoreSum() / tuple.getAggregationsCount());
        } else {
            ComputationGraph g = graph.getNetwork();
            g.setParams(params);
            if (updaterState != null)
                g.getUpdater().setStateViewArray(updaterState);
            graph.setScore(tuple.getScoreSum() / tuple.getAggregationsCount());
        }
    } else {
        log.info("Skipping imbalanced split with no data for all executors");
    }
    if (collectTrainingStats) {
        stats.logProcessParamsUpdaterEnd();
        stats.addWorkerStats(aggregatedStats);
    }
    if (statsStorage != null) {
        Collection<StorageMetaData> meta = tuple.getListenerMetaData();
        if (meta != null && meta.size() > 0) {
            statsStorage.putStorageMetaData(meta);
        }
        Collection<Persistable> staticInfo = tuple.getListenerStaticInfo();
        if (staticInfo != null && staticInfo.size() > 0) {
            statsStorage.putStaticInfo(staticInfo);
        }
        Collection<Persistable> updates = tuple.getListenerUpdates();
        if (updates != null && updates.size() > 0) {
            statsStorage.putUpdate(updates);
        }
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    log.info("Completed training of split {} of {}", splitNum, totalSplits);
    if (params != null) {
        //Params may be null for edge case (empty RDD)
        if (network != null) {
            MultiLayerConfiguration conf = network.getNetwork().getLayerWiseConfigurations();
            int numUpdates = network.getNetwork().conf().getNumIterations() * averagingFrequency;
            conf.setIterationCount(conf.getIterationCount() + numUpdates);
        } else {
            ComputationGraphConfiguration conf = graph.getNetwork().getConfiguration();
            int numUpdates = graph.getNetwork().conf().getNumIterations() * averagingFrequency;
            conf.setIterationCount(conf.getIterationCount() + numUpdates);
        }
    }
}
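The heart of processResults is ordinary parameter averaging: the element-wise sum of the per-worker parameter vectors (tuple.getParametersSum()) is divided in place by the number of contributions (aggCount), and the updater state is averaged the same way. A minimal standalone sketch of that arithmetic, assuming the per-worker vectors are available as a list; the ParameterAveraging class is illustrative only, not part of deeplearning4j.

import java.util.List;

import org.nd4j.linalg.api.ndarray.INDArray;

public class ParameterAveraging {

    //Element-wise average of per-worker parameter vectors, mirroring params.divi(aggCount) above
    public static INDArray average(List<INDArray> workerParams) {
        INDArray sum = workerParams.get(0).dup();   //copy so the first worker's array is not modified
        for (int i = 1; i < workerParams.size(); i++) {
            sum.addi(workerParams.get(i));          //in-place accumulation, like getParametersSum()
        }
        return sum.divi(workerParams.size());       //in-place divide by the number of contributions
    }
}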
Use of org.nd4j.linalg.api.ops.executioner.GridExecutioner in project deeplearning4j by deeplearning4j.
Class ScoreExamplesWithKeyFunctionAdapter, method call:
@Override
public Iterable<Tuple2<K, Double>> call(Iterator<Tuple2<K, MultiDataSet>> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyList();
    }
    ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonConfig.getValue()));
    network.init();
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
    network.setParams(val);
    List<Tuple2<K, Double>> ret = new ArrayList<>();
    List<MultiDataSet> collect = new ArrayList<>(batchSize);
    List<K> collectKey = new ArrayList<>(batchSize);
    int totalCount = 0;
    while (iterator.hasNext()) {
        collect.clear();
        collectKey.clear();
        int nExamples = 0;
        //Collect up to batchSize single-example MultiDataSets, keeping the keys in the same order
        while (iterator.hasNext() && nExamples < batchSize) {
            Tuple2<K, MultiDataSet> t2 = iterator.next();
            MultiDataSet ds = t2._2();
            int n = ds.getFeatures(0).size(0);
            if (n != 1)
                throw new IllegalStateException("Cannot score examples with one key per data set if "
                                + "data set contains more than 1 example (numExamples: " + n + ")");
            collect.add(ds);
            collectKey.add(t2._1());
            nExamples += n;
        }
        totalCount += nExamples;
        MultiDataSet data = org.nd4j.linalg.dataset.MultiDataSet.merge(collect);
        INDArray scores = network.scoreExamples(data, addRegularization);
        double[] doubleScores = scores.data().asDouble();
        //scoreExamples returns one score per example in merge order, so index i matches collectKey.get(i)
        for (int i = 0; i < doubleScores.length; i++) {
            ret.add(new Tuple2<>(collectKey.get(i), doubleScores[i]));
        }
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples ", totalCount);
    }
    return ret;
}
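Callers of this function must supply one MultiDataSet per key, each containing exactly one example, or the n != 1 check above throws. A minimal sketch of preparing such keyed input from row-per-example feature and label matrices; the KeyedExampleBuilder class, the String key type, and the single-input/single-output shapes are illustrative assumptions, not deeplearning4j API.

import java.util.ArrayList;
import java.util.List;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;

import scala.Tuple2;

public class KeyedExampleBuilder {

    //Builds one (key, single-example MultiDataSet) pair per row of the feature/label matrices
    public static List<Tuple2<String, MultiDataSet>> build(List<String> keys, INDArray features, INDArray labels) {
        List<Tuple2<String, MultiDataSet>> out = new ArrayList<>(keys.size());
        for (int i = 0; i < keys.size(); i++) {
            //Slice out row i as a 1 x nIn features array and 1 x nOut labels array (copied, not a view)
            MultiDataSet mds = new MultiDataSet(features.getRow(i).dup(), labels.getRow(i).dup());
            out.add(new Tuple2<>(keys.get(i), mds));
        }
        return out;
    }
}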
Use of org.nd4j.linalg.api.ops.executioner.GridExecutioner in project deeplearning4j by deeplearning4j.
Class ParameterAveragingTrainingWorker, method getInitialModelGraph:
@Override
public ComputationGraph getInitialModelGraph() {
    if (configuration.isCollectTrainingStats())
        stats = new ParameterAveragingTrainingWorkerStats.ParameterAveragingTrainingWorkerStatsHelper();
    if (configuration.isCollectTrainingStats())
        stats.logBroadcastGetValueStart();
    NetBroadcastTuple tuple = broadcast.getValue();
    if (configuration.isCollectTrainingStats())
        stats.logBroadcastGetValueEnd();
    //Don't want to have shared configuration object: each may update its iteration count (for LR schedule etc) individually
    ComputationGraph net = new ComputationGraph(tuple.getGraphConfiguration().clone());
    //Can't have a shared parameter array across executors for parameter averaging: unsafeDuplication()
    //already creates a local copy, so the clone-parameters-array argument to init() can be false
    net.init(tuple.getParameters().unsafeDuplication(), false);
    if (tuple.getUpdaterState() != null) {
        //Again: can't have shared updater state
        net.setUpdater(new ComputationGraphUpdater(net, tuple.getUpdaterState().unsafeDuplication()));
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    configureListeners(net, tuple.getCounter().getAndIncrement());
    if (configuration.isCollectTrainingStats())
        stats.logInitEnd();
    return net;
}
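The worker never trains on the broadcast arrays directly: both the parameters and the updater state are copied with unsafeDuplication() so that in-place updates on one executor cannot corrupt the value shared by every task on that machine. A minimal standalone sketch of that behaviour follows; the BroadcastCopyDemo class and the toy 4-element vector are illustrative assumptions.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class BroadcastCopyDemo {

    public static void main(String[] args) {
        INDArray shared = Nd4j.linspace(1, 4, 4);         //stands in for the broadcast parameter vector
        INDArray workerCopy = shared.unsafeDuplication(); //independent copy, as in getInitialModelGraph()
        workerCopy.addi(1.0);                             //simulated in-place parameter update on the worker
        System.out.println("shared:      " + shared);     //unchanged
        System.out.println("worker copy: " + workerCopy); //updated independently
    }
}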