Use of org.nd4j.linalg.api.ops.executioner.GridExecutioner in project deeplearning4j by deeplearning4j.
In the class BaseVaeScoreWithKeyFunctionAdapter, method call():
@Override
public Iterable<Tuple2<K, Double>> call(Iterator<Tuple2<K, INDArray>> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyList();
    }
    VariationalAutoencoder vae = getVaeLayer();
    List<Tuple2<K, Double>> ret = new ArrayList<>();
    List<INDArray> collect = new ArrayList<>(batchSize);
    List<K> collectKey = new ArrayList<>(batchSize);
    int totalCount = 0;
    while (iterator.hasNext()) {
        collect.clear();
        collectKey.clear();
        int nExamples = 0;
        while (iterator.hasNext() && nExamples < batchSize) {
            Tuple2<K, INDArray> t2 = iterator.next();
            INDArray features = t2._2();
            int n = features.size(0);
            if (n != 1)
                throw new IllegalStateException("Cannot score examples with one key per data set if "
                                + "data set contains more than 1 example (numExamples: " + n + ")");
            collect.add(features);
            collectKey.add(t2._1());
            nExamples += n;
        }
        totalCount += nExamples;
        INDArray toScore = Nd4j.vstack(collect);
        INDArray scores = computeScore(vae, toScore);
        double[] doubleScores = scores.data().asDouble();
        for (int i = 0; i < doubleScores.length; i++) {
            ret.add(new Tuple2<>(collectKey.get(i), doubleScores[i]));
        }
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples ", totalCount);
    }
    return ret;
}
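Every snippet on this page ends with the same guard before results leave the worker: if the backend's executioner is a GridExecutioner (which batches ops), the queue is flushed synchronously so all pending operations have actually executed. A minimal sketch of that idiom factored into a hypothetical helper (ExecutionerUtils is not part of deeplearning4j; it is shown only for illustration):

import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.factory.Nd4j;

public final class ExecutionerUtils {

    private ExecutionerUtils() {}

    /**
     * Flush any ops still queued by a batching (grid) executioner, blocking until they complete.
     * A no-op for non-grid executioners.
     */
    public static void flushIfGrid() {
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
        }
    }
}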
Use of org.nd4j.linalg.api.ops.executioner.GridExecutioner in project deeplearning4j by deeplearning4j.
In the class GraphFeedForwardWithKeyFunctionAdapter, method call():
@Override
public Iterable<Tuple2<K, INDArray[]>> call(Iterator<Tuple2<K, INDArray[]>> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyList();
    }
    ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonConfig.getValue()));
    network.init();
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException(
                        "Network did not have same number of parameters as the broadcast set parameters");
    network.setParams(val);
    //Issue: for 2d data (MLPs etc) we can just stack the examples.
    //But: for 3d and 4d: in principle the data sizes could be different
    //We could handle that with mask arrays - but it gets messy. The approach used here is simpler but less efficient
    List<INDArray[]> featuresList = new ArrayList<>(batchSize);
    List<K> keyList = new ArrayList<>(batchSize);
    List<Integer> origSizeList = new ArrayList<>();
    int[][] firstShapes = null;
    boolean sizesDiffer = false;
    int tupleCount = 0;
    while (iterator.hasNext()) {
        Tuple2<K, INDArray[]> t2 = iterator.next();
        if (firstShapes == null) {
            firstShapes = new int[t2._2().length][0];
            for (int i = 0; i < firstShapes.length; i++) {
                firstShapes[i] = t2._2()[i].shape();
            }
        } else if (!sizesDiffer) {
            for (int i = 0; i < firstShapes.length; i++) {
                for (int j = 1; j < firstShapes[i].length; j++) {
                    if (firstShapes[i][j] != featuresList.get(tupleCount - 1)[i].size(j)) {
                        sizesDiffer = true;
                        break;
                    }
                }
            }
        }
        featuresList.add(t2._2());
        keyList.add(t2._1());
        origSizeList.add(t2._2()[0].size(0));
        tupleCount++;
    }
    if (tupleCount == 0) {
        return Collections.emptyList();
    }
    List<Tuple2<K, INDArray[]>> output = new ArrayList<>(tupleCount);
    int currentArrayIndex = 0;
    while (currentArrayIndex < featuresList.size()) {
        int firstIdx = currentArrayIndex;
        int nextIdx = currentArrayIndex;
        int examplesInBatch = 0;
        List<INDArray[]> toMerge = new ArrayList<>();
        firstShapes = null;
        while (nextIdx < featuresList.size() && examplesInBatch < batchSize) {
            INDArray[] f = featuresList.get(nextIdx);
            if (firstShapes == null) {
                firstShapes = new int[f.length][0];
                for (int i = 0; i < firstShapes.length; i++) {
                    firstShapes[i] = f[i].shape();
                }
            } else if (sizesDiffer) {
                boolean breakWhile = false;
                for (int i = 0; i < firstShapes.length; i++) {
                    for (int j = 1; j < firstShapes[i].length; j++) {
                        if (firstShapes[i][j] != featuresList.get(nextIdx)[i].size(j)) {
                            //Next example has a different size. So: don't add it to the current batch, just process what we have
                            breakWhile = true;
                            break;
                        }
                    }
                }
                if (breakWhile) {
                    break;
                }
            }
            toMerge.add(f);
            examplesInBatch += f[0].size(0);
            nextIdx++;
        }
        INDArray[] batchFeatures = new INDArray[toMerge.get(0).length];
        for (int i = 0; i < batchFeatures.length; i++) {
            INDArray[] tempArr = new INDArray[toMerge.size()];
            for (int j = 0; j < tempArr.length; j++) {
                tempArr[j] = toMerge.get(j)[i];
            }
            batchFeatures[i] = Nd4j.concat(0, tempArr);
        }
        INDArray[] out = network.output(false, batchFeatures);
        examplesInBatch = 0;
        for (int i = firstIdx; i < nextIdx; i++) {
            int numExamples = origSizeList.get(i);
            INDArray[] outSubset = new INDArray[out.length];
            for (int j = 0; j < out.length; j++) {
                outSubset[j] = getSubset(examplesInBatch, examplesInBatch + numExamples, out[j]);
            }
            examplesInBatch += numExamples;
            output.add(new Tuple2<>(keyList.get(i), outSubset));
        }
        currentArrayIndex += (nextIdx - firstIdx);
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    return output;
}
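The helper getSubset(...) used in the output loop above is a private method of the class and is not shown in this excerpt. A plausible sketch of what it has to do, assuming it slices the example (first) dimension with org.nd4j.linalg.indexing.NDArrayIndex and supports 2d/3d/4d activations (the actual implementation in the class may differ):

// Hypothetical sketch of the unshown getSubset helper: return examples
// [exampleStart, exampleEnd) along dimension 0, keeping all other dimensions.
private INDArray getSubset(int exampleStart, int exampleEnd, INDArray from) {
    switch (from.rank()) {
        case 2:
            return from.get(NDArrayIndex.interval(exampleStart, exampleEnd), NDArrayIndex.all());
        case 3:
            return from.get(NDArrayIndex.interval(exampleStart, exampleEnd), NDArrayIndex.all(), NDArrayIndex.all());
        case 4:
            return from.get(NDArrayIndex.interval(exampleStart, exampleEnd), NDArrayIndex.all(), NDArrayIndex.all(),
                            NDArrayIndex.all());
        default:
            throw new IllegalStateException("Unexpected array rank: " + from.rank());
    }
}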
Use of org.nd4j.linalg.api.ops.executioner.GridExecutioner in project deeplearning4j by deeplearning4j.
In the class ScoreExamplesWithKeyFunctionAdapter, method call():
@Override
public Iterable<Tuple2<K, Double>> call(Iterator<Tuple2<K, DataSet>> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyList();
    }
    MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue()));
    network.init();
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException(
                        "Network did not have same number of parameters as the broadcast set parameters");
    network.setParameters(val);
    List<Tuple2<K, Double>> ret = new ArrayList<>();
    List<DataSet> collect = new ArrayList<>(batchSize);
    List<K> collectKey = new ArrayList<>(batchSize);
    int totalCount = 0;
    while (iterator.hasNext()) {
        collect.clear();
        collectKey.clear();
        int nExamples = 0;
        while (iterator.hasNext() && nExamples < batchSize) {
            Tuple2<K, DataSet> t2 = iterator.next();
            DataSet ds = t2._2();
            int n = ds.numExamples();
            if (n != 1)
                throw new IllegalStateException("Cannot score examples with one key per data set if "
                                + "data set contains more than 1 example (numExamples: " + n + ")");
            collect.add(ds);
            collectKey.add(t2._1());
            nExamples += n;
        }
        totalCount += nExamples;
        DataSet data = DataSet.merge(collect);
        INDArray scores = network.scoreExamples(data, addRegularization);
        double[] doubleScores = scores.data().asDouble();
        for (int i = 0; i < doubleScores.length; i++) {
            ret.add(new Tuple2<>(collectKey.get(i), doubleScores[i]));
        }
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples ", totalCount);
    }
    return ret;
}
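The batching pattern above (collect single-example DataSets up to batchSize, merge them, score the whole batch in one pass) is independent of Spark. A minimal local sketch of the same idea, assuming an already-initialized MultiLayerNetwork and a list of single-example DataSets (class and variable names here are illustrative, not part of deeplearning4j):

import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;

public class LocalBatchScoring {

    // Score single-example DataSets in batches of 'batchSize', preserving input order.
    public static List<Double> scoreInBatches(MultiLayerNetwork network, List<DataSet> examples, int batchSize) {
        List<Double> scores = new ArrayList<>(examples.size());
        List<DataSet> buffer = new ArrayList<>(batchSize);
        for (DataSet ds : examples) {
            buffer.add(ds);
            if (buffer.size() == batchSize) {
                scoreBuffer(network, buffer, scores);
            }
        }
        if (!buffer.isEmpty()) {
            scoreBuffer(network, buffer, scores);
        }
        return scores;
    }

    private static void scoreBuffer(MultiLayerNetwork network, List<DataSet> buffer, List<Double> scores) {
        // Merge the buffered examples into one DataSet and score them in a single forward pass
        INDArray batchScores = network.scoreExamples(DataSet.merge(buffer), false);
        for (double d : batchScores.data().asDouble()) {
            scores.add(d);
        }
        buffer.clear();
    }
}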
Use of org.nd4j.linalg.api.ops.executioner.GridExecutioner in project deeplearning4j by deeplearning4j.
In the class ParameterAveragingTrainingWorker, method getInitialModel():
@Override
public MultiLayerNetwork getInitialModel() {
    if (configuration.isCollectTrainingStats())
        stats = new ParameterAveragingTrainingWorkerStats.ParameterAveragingTrainingWorkerStatsHelper();
    if (configuration.isCollectTrainingStats())
        stats.logBroadcastGetValueStart();
    NetBroadcastTuple tuple = broadcast.getValue();
    if (configuration.isCollectTrainingStats())
        stats.logBroadcastGetValueEnd();
    //Don't want a shared configuration object: each worker may update its iteration count (for LR schedule etc) individually
    MultiLayerNetwork net = new MultiLayerNetwork(tuple.getConfiguration().clone());
    //Can't have a shared parameter array across executors for parameter averaging, so duplicate the broadcast parameters explicitly
    net.init(tuple.getParameters().unsafeDuplication(), false);
    if (tuple.getUpdaterState() != null) {
        //Can't have shared updater state either
        net.setUpdater(new MultiLayerUpdater(net, tuple.getUpdaterState().unsafeDuplication()));
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    configureListeners(net, tuple.getCounter().getAndIncrement());
    if (configuration.isCollectTrainingStats())
        stats.logInitEnd();
    return net;
}
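The repeated use of unsafeDuplication() above is what keeps per-worker state independent of the broadcast value. A small standalone sketch of that behavior (plain Nd4j, nothing dl4j-spark specific; assumes unsafeDuplication() copies the underlying data rather than creating a view):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class DuplicationDemo {
    public static void main(String[] args) {
        INDArray broadcastValue = Nd4j.linspace(1, 4, 4);  // stand-in for the broadcast parameters
        INDArray localCopy = broadcastValue.unsafeDuplication();

        localCopy.addi(100);                               // in-place update on the worker's copy...

        System.out.println(broadcastValue);                // ...leaves the broadcast value untouched
        System.out.println(localCopy);
    }
}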
Use of org.nd4j.linalg.api.ops.executioner.GridExecutioner in project deeplearning4j by deeplearning4j.
In the class ParameterAveragingElementCombineFunction, method call():
@Override
public ParameterAveragingAggregationTuple call(ParameterAveragingAggregationTuple v1,
                ParameterAveragingAggregationTuple v2) throws Exception {
    if (v1 == null)
        return v2;
    else if (v2 == null)
        return v1;
    //Handle edge case of less data than executors: in this case, one (or both) of v1 and v2 might not have any contents...
    if (v1.getParametersSum() == null)
        return v2;
    else if (v2.getParametersSum() == null)
        return v1;
    INDArray newParams = v1.getParametersSum().addi(v2.getParametersSum());
    INDArray updaterStateSum;
    if (v1.getUpdaterStateSum() == null) {
        updaterStateSum = v2.getUpdaterStateSum();
    } else {
        updaterStateSum = v1.getUpdaterStateSum();
        if (v2.getUpdaterStateSum() != null)
            updaterStateSum.addi(v2.getUpdaterStateSum());
    }
    double scoreSum = v1.getScoreSum() + v2.getScoreSum();
    int aggregationCount = v1.getAggregationsCount() + v2.getAggregationsCount();
    SparkTrainingStats stats = v1.getSparkTrainingStats();
    if (v2.getSparkTrainingStats() != null) {
        if (stats == null)
            stats = v2.getSparkTrainingStats();
        else
            stats.addOtherTrainingStats(v2.getSparkTrainingStats());
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    Collection<StorageMetaData> listenerMetaData = v1.getListenerMetaData();
    if (listenerMetaData == null)
        listenerMetaData = v2.getListenerMetaData();
    else {
        Collection<StorageMetaData> newMeta = v2.getListenerMetaData();
        if (newMeta != null)
            listenerMetaData.addAll(newMeta);
    }
    Collection<Persistable> listenerStaticInfo = v1.getListenerStaticInfo();
    if (listenerStaticInfo == null)
        listenerStaticInfo = v2.getListenerStaticInfo();
    else {
        Collection<Persistable> newStatic = v2.getListenerStaticInfo();
        if (newStatic != null)
            listenerStaticInfo.addAll(newStatic);
    }
    Collection<Persistable> listenerUpdates = v1.getListenerUpdates();
    if (listenerUpdates == null)
        listenerUpdates = v2.getListenerUpdates();
    else {
        Collection<Persistable> listenerUpdates2 = v2.getListenerUpdates();
        if (listenerUpdates2 != null)
            listenerUpdates.addAll(listenerUpdates2);
    }
    return new ParameterAveragingAggregationTuple(newParams, updaterStateSum, scoreSum, aggregationCount, stats,
                    listenerMetaData, listenerStaticInfo, listenerUpdates);
}
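This combiner is commutative and associative over the per-partition aggregation tuples, which is what Spark needs to merge partial results into one. A hedged usage sketch, assuming the class implements Spark's Function2 interface (as its call(...) signature suggests) and that you already have a JavaRDD of ParameterAveragingAggregationTuple instances; the real ParameterAveragingTrainingMaster wires this up differently, together with a per-element add function:

import org.apache.spark.api.java.JavaRDD;
// dl4j-spark imports for ParameterAveragingAggregationTuple and
// ParameterAveragingElementCombineFunction omitted; package paths depend on your version.

public class CombineSketch {
    // Merge per-worker aggregation tuples into a single result using the combiner shown above.
    public static ParameterAveragingAggregationTuple combineAll(
                    JavaRDD<ParameterAveragingAggregationTuple> tupleRdd) {
        return tupleRdd.reduce(new ParameterAveragingElementCombineFunction());
    }
}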