Use of org.nd4j.linalg.api.ops.executioner.GridExecutioner in project deeplearning4j by deeplearning4j.
Class ScoreExamplesWithKeyFunctionAdapter, method call:
@Override
public Iterable<Tuple2<K, Double>> call(Iterator<Tuple2<K, MultiDataSet>> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyList();
    }
    // Build and initialize a local network from the broadcast JSON configuration and parameters
    ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonConfig.getValue()));
    network.init();
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException(
                        "Network did not have same number of parameters as the broadcast set parameters");
    network.setParams(val);

    // Score in batches of up to batchSize single-example data sets, one key per example
    List<Tuple2<K, Double>> ret = new ArrayList<>();
    List<MultiDataSet> collect = new ArrayList<>(batchSize);
    List<K> collectKey = new ArrayList<>(batchSize);
    int totalCount = 0;
    while (iterator.hasNext()) {
        collect.clear();
        collectKey.clear();
        int nExamples = 0;
        while (iterator.hasNext() && nExamples < batchSize) {
            Tuple2<K, MultiDataSet> t2 = iterator.next();
            MultiDataSet ds = t2._2();
            int n = ds.getFeatures(0).size(0);
            if (n != 1)
                throw new IllegalStateException("Cannot score examples with one key per data set if "
                                + "data set contains more than 1 example (numExamples: " + n + ")");
            collect.add(ds);
            collectKey.add(t2._1());
            nExamples += n;
        }
        totalCount += nExamples;

        MultiDataSet data = org.nd4j.linalg.dataset.MultiDataSet.merge(collect);
        INDArray scores = network.scoreExamples(data, addRegularization);
        double[] doubleScores = scores.data().asDouble();
        for (int i = 0; i < doubleScores.length; i++) {
            ret.add(new Tuple2<>(collectKey.get(i), doubleScores[i]));
        }
    }

    // If the backend queues ops (e.g. the CUDA GridExecutioner), block until they have all executed
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();

    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples", totalCount);
    }
    return ret;
}
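The flush-before-return guard at the end of this method recurs in every example on this page. As a reference, here is the idiom pulled out into a standalone helper; the ExecutionerUtils class is a hypothetical name for illustration, not part of deeplearning4j:

import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.factory.Nd4j;

public final class ExecutionerUtils {
    private ExecutionerUtils() {}

    /**
     * If the current backend batches operations into a grid queue (e.g. the CUDA
     * backend's GridExecutioner), block until every queued op has executed.
     * On other executioners this is a no-op.
     */
    public static void flushIfGridExecutioner() {
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
        }
    }
}

On a backend whose executioner batches ops this way, results may not have been computed yet when a method returns; flushing before serializing scores or parameters out of a Spark function ensures the arrays hold their final values.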
Use of org.nd4j.linalg.api.ops.executioner.GridExecutioner in project deeplearning4j by deeplearning4j.
Class ParameterAveragingTrainingWorker, method getInitialModelGraph:
@Override
public ComputationGraph getInitialModelGraph() {
    if (configuration.isCollectTrainingStats()) {
        stats = new ParameterAveragingTrainingWorkerStats.ParameterAveragingTrainingWorkerStatsHelper();
        stats.logBroadcastGetValueStart();
    }
    NetBroadcastTuple tuple = broadcast.getValue();
    if (configuration.isCollectTrainingStats())
        stats.logBroadcastGetValueEnd();

    //Don't want a shared configuration object: each worker may update its iteration count (for LR schedules etc.) individually
    ComputationGraph net = new ComputationGraph(tuple.getGraphConfiguration().clone());
    //Can't share the parameter array across executors for parameter averaging: unsafeDuplication() gives this
    //worker its own copy, so no further clone is needed in init (hence 'false' for the clone-parameters arg)
    net.init(tuple.getParameters().unsafeDuplication(), false);
    if (tuple.getUpdaterState() != null) {
        //Again: can't have shared updater state
        net.setUpdater(new ComputationGraphUpdater(net, tuple.getUpdaterState().unsafeDuplication()));
    }

    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();

    configureListeners(net, tuple.getCounter().getAndIncrement());
    if (configuration.isCollectTrainingStats())
        stats.logInitEnd();
    return net;
}
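The worker never trains on the broadcast arrays directly; unsafeDuplication() hands it a private copy. A minimal sketch (not from the source; the class name is hypothetical) showing that mutating the copy leaves the broadcast value intact:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class UnsafeDuplicationDemo {
    public static void main(String[] args) {
        INDArray broadcastParams = Nd4j.linspace(1, 4, 4);
        // unsafeDuplication() returns a copy backed by its own buffer,
        // a raw copy that skips some of dup()'s checks, per its name
        INDArray workerCopy = broadcastParams.unsafeDuplication();
        workerCopy.addi(1.0);                // in-place update of the worker's copy only
        System.out.println(broadcastParams); // [1.00, 2.00, 3.00, 4.00] -- unchanged
        System.out.println(workerCopy);      // [2.00, 3.00, 4.00, 5.00]
    }
}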
Use of org.nd4j.linalg.api.ops.executioner.GridExecutioner in project deeplearning4j by deeplearning4j.
Class ParameterAveragingTrainingWorker, method getFinalResult (ComputationGraph overload):
@Override
public ParameterAveragingTrainingResult getFinalResult(ComputationGraph network) {
    INDArray updaterState = null;
    if (saveUpdater) {
        ComputationGraphUpdater u = network.getUpdater();
        if (u != null)
            updaterState = u.getStateViewArray();
    }

    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();

    Collection<StorageMetaData> storageMetaData = null;
    Collection<Persistable> listenerStaticInfo = null;
    Collection<Persistable> listenerUpdates = null;
    if (listenerRouterProvider != null) {
        StatsStorageRouter r = listenerRouterProvider.getRouter();
        if (r instanceof VanillaStatsStorageRouter) {
            //TODO this is ugly... need to find a better solution
            VanillaStatsStorageRouter ssr = (VanillaStatsStorageRouter) r;
            storageMetaData = ssr.getStorageMetaData();
            listenerStaticInfo = ssr.getStaticInfo();
            listenerUpdates = ssr.getUpdates();
        }
    }

    return new ParameterAveragingTrainingResult(network.params(), updaterState, network.score(), storageMetaData,
                    listenerStaticInfo, listenerUpdates);
}
Use of org.nd4j.linalg.api.ops.executioner.GridExecutioner in project deeplearning4j by deeplearning4j.
Class ParameterAveragingTrainingWorker, method getFinalResult (MultiLayerNetwork overload):
@Override
public ParameterAveragingTrainingResult getFinalResult(MultiLayerNetwork network) {
    INDArray updaterState = null;
    if (saveUpdater) {
        Updater u = network.getUpdater();
        if (u != null)
            updaterState = u.getStateViewArray();
    }

    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();

    Collection<StorageMetaData> storageMetaData = null;
    Collection<Persistable> listenerStaticInfo = null;
    Collection<Persistable> listenerUpdates = null;
    if (listenerRouterProvider != null) {
        StatsStorageRouter r = listenerRouterProvider.getRouter();
        if (r instanceof VanillaStatsStorageRouter) {
            //TODO this is ugly... need to find a better solution
            VanillaStatsStorageRouter ssr = (VanillaStatsStorageRouter) r;
            storageMetaData = ssr.getStorageMetaData();
            listenerStaticInfo = ssr.getStaticInfo();
            listenerUpdates = ssr.getUpdates();
        }
    }

    return new ParameterAveragingTrainingResult(network.params(), updaterState, network.score(), storageMetaData,
                    listenerStaticInfo, listenerUpdates);
}
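Both getFinalResult overloads follow the same ordering: flush the grid queue first, then read network.params() and network.score(). A hedged standalone sketch of that ordering (FinalParamsExample is a hypothetical name, not a deeplearning4j class):

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.factory.Nd4j;

final class FinalParamsExample {
    /** Snapshot of the parameters that is safe to serialize back to the master. */
    static INDArray finalParams(MultiLayerNetwork network) {
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            // Block until in-flight training ops have written their results;
            // reading the parameter view earlier could capture stale values.
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
        }
        return network.params();
    }
}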
Use of org.nd4j.linalg.api.ops.executioner.GridExecutioner in project deeplearning4j by deeplearning4j.
Class ParameterAveragingElementAddFunction, method call:
@Override
public ParameterAveragingAggregationTuple call(ParameterAveragingAggregationTuple tuple,
                ParameterAveragingTrainingResult result) throws Exception {
    // First element seen for this partition: seed the tuple from the result
    if (tuple == null) {
        return ParameterAveragingAggregationTuple.builder()
                        .parametersSum(result.getParameters())
                        .updaterStateSum(result.getUpdaterState())
                        .scoreSum(result.getScore())
                        .aggregationsCount(1)
                        .sparkTrainingStats(result.getSparkTrainingStats())
                        .listenerMetaData(result.getListenerMetaData())
                        .listenerStaticInfo(result.getListenerStaticInfo())
                        .listenerUpdates(result.getListenerUpdates())
                        .build();
    }

    INDArray params = tuple.getParametersSum().addi(result.getParameters());
    INDArray updaterStateSum;
    if (tuple.getUpdaterStateSum() == null) {
        updaterStateSum = result.getUpdaterState();
    } else {
        updaterStateSum = tuple.getUpdaterStateSum();
        if (result.getUpdaterState() != null)
            updaterStateSum.addi(result.getUpdaterState());
    }

    double scoreSum = tuple.getScoreSum() + result.getScore();
    SparkTrainingStats stats = tuple.getSparkTrainingStats();
    if (result.getSparkTrainingStats() != null) {
        if (stats == null)
            stats = result.getSparkTrainingStats();
        else
            stats.addOtherTrainingStats(result.getSparkTrainingStats());
    }

    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();

    Collection<StorageMetaData> listenerMetaData = tuple.getListenerMetaData();
    if (listenerMetaData == null)
        listenerMetaData = result.getListenerMetaData();
    else {
        Collection<StorageMetaData> newMeta = result.getListenerMetaData();
        if (newMeta != null)
            listenerMetaData.addAll(newMeta);
    }

    Collection<Persistable> listenerStaticInfo = tuple.getListenerStaticInfo();
    if (listenerStaticInfo == null)
        listenerStaticInfo = result.getListenerStaticInfo();
    else {
        Collection<Persistable> newStatic = result.getListenerStaticInfo();
        if (newStatic != null)
            listenerStaticInfo.addAll(newStatic);
    }

    Collection<Persistable> listenerUpdates = tuple.getListenerUpdates();
    if (listenerUpdates == null)
        listenerUpdates = result.getListenerUpdates();
    else {
        Collection<Persistable> newUpdates = result.getListenerUpdates();
        if (newUpdates != null)
            listenerUpdates.addAll(newUpdates);
    }

    return new ParameterAveragingAggregationTuple(params, updaterStateSum, scoreSum,
                    tuple.getAggregationsCount() + 1, stats, listenerMetaData, listenerStaticInfo, listenerUpdates);
}
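For context, a sketch of how this seqOp would plausibly be driven from the master via Spark's aggregate, with the averaged parameters recovered by dividing the summed parameters by the aggregation count. The ParameterAveragingElementCombineFunction name and the package paths in the imports are assumptions inferred from the classes on this page, not taken from it:

import org.apache.spark.api.java.JavaRDD;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingResult;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingAggregationTuple;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementAddFunction;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementCombineFunction;
import org.nd4j.linalg.api.ndarray.INDArray;

public class AggregationSketch {
    public static INDArray averagedParams(JavaRDD<ParameterAveragingTrainingResult> results) {
        // Null zero value: handled by the (tuple == null) branch of call(...) above
        ParameterAveragingAggregationTuple tuple = results.aggregate(null,
                        new ParameterAveragingElementAddFunction(),      // fold one result into the tuple
                        new ParameterAveragingElementCombineFunction()); // merge partial tuples (assumed name)
        // Parameter averaging proper: sum of worker parameters / number of contributions
        return tuple.getParametersSum().divi(tuple.getAggregationsCount());
    }
}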