Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.
From the class TestMasking, method checkMaskArrayClearance:
@Test
public void checkMaskArrayClearance() {
    for (boolean tbptt : new boolean[] {true, false}) {
        //Simple "does it throw an exception" type test...
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(1).seed(12345).list()
                        .layer(0, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                                        .activation(Activation.IDENTITY).nIn(1).nOut(1).build())
                        .backpropType(tbptt ? BackpropType.TruncatedBPTT : BackpropType.Standard)
                        .tBPTTForwardLength(8).tBPTTBackwardLength(8).build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        DataSet data = new DataSet(Nd4j.linspace(1, 10, 10).reshape(1, 1, 10),
                        Nd4j.linspace(2, 20, 10).reshape(1, 1, 10), Nd4j.ones(10), Nd4j.ones(10));

        net.fit(data);
        for (Layer l : net.getLayers()) {
            assertNull(l.getMaskArray());
        }

        net.fit(data.getFeatures(), data.getLabels(), data.getFeaturesMaskArray(), data.getLabelsMaskArray());
        for (Layer l : net.getLayers()) {
            assertNull(l.getMaskArray());
        }

        DataSetIterator iter = new ExistingDataSetIterator(Collections.singletonList(data).iterator());
        net.fit(iter);
        for (Layer l : net.getLayers()) {
            assertNull(l.getMaskArray());
        }
    }
}
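The test exercises three equivalent entry points (fit(DataSet), fit with explicit arrays, fit(DataSetIterator)) and asserts that none of them leaves a mask set on any layer. Mask arrays mark which time steps of a padded sequence are real data, so a mask left behind by one fit call would silently corrupt a later, unmasked pass. A minimal sketch of that failure surface, reusing net from the test above; labels, labelsMask and otherFeatures are hypothetical placeholders:

INDArray features = Nd4j.zeros(new int[] {1, 1, 10});   // one sequence, padded to length 10
INDArray featuresMask = Nd4j.ones(1, 10);                // shape [miniBatch, timeSeriesLength]
featuresMask.putScalar(new int[] {0, 9}, 0.0);           // mark the final step as padding
net.fit(features, labels, featuresMask, labelsMask);     // sets masks on the layers internally...
INDArray out = net.output(otherFeatures);                // ...which must be cleared before this call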
Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.
From the class TestOptimizers, method testOptimizersBasicMLPBackprop:
@Test
public void testOptimizersBasicMLPBackprop() {
    //Basic tests of the 'does it throw an exception' variety.
    DataSetIterator iter = new IrisDataSetIterator(5, 50);

    OptimizationAlgorithm[] toTest = {OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT,
                    OptimizationAlgorithm.LINE_GRADIENT_DESCENT, OptimizationAlgorithm.CONJUGATE_GRADIENT,
                    OptimizationAlgorithm.LBFGS};

    for (OptimizationAlgorithm oa : toTest) {
        MultiLayerNetwork network = new MultiLayerNetwork(getMLPConfigIris(oa, 1));
        network.init();
        iter.reset();
        network.fit(iter);
    }
}
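The helper getMLPConfigIris(oa, 1) is defined elsewhere in TestOptimizers and is not shown on this page. A hypothetical sketch of the shape such a helper takes with the 0.x builder API used above; the layer sizes and activations here are assumptions for Iris (4 features, 3 classes), not the actual test values:

private static MultiLayerConfiguration getMLPConfigIris(OptimizationAlgorithm oa, int iterations) {
    return new NeuralNetConfiguration.Builder().optimizationAlgo(oa).iterations(iterations).seed(12345)
                    .list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(10).activation(Activation.TANH).build())
                    .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .activation(Activation.SOFTMAX).nIn(10).nOut(3).build())
                    .build();
}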
Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.
From the class Dl4jServingRouteTest, method createRouteBuilder:
@Override
protected RouteBuilder createRouteBuilder() throws Exception {
    DataSetIterator iter = new IrisDataSetIterator(150, 150);
    next = iter.next();
    next.normalizeZeroMeanZeroUnitVariance();
    return new RouteBuilder() {
        @Override
        public void configure() throws Exception {
            final String kafkaUri = String.format("kafka:%s?topic=%s&groupId=dl4j-serving",
                            kafkaCluster.getBrokerList(), topicName);
            from("direct:start").process(new Processor() {
                @Override
                public void process(Exchange exchange) throws Exception {
                    final INDArray arr = next.getFeatureMatrix();
                    ByteArrayOutputStream bos = new ByteArrayOutputStream();
                    DataOutputStream dos = new DataOutputStream(bos);
                    Nd4j.write(arr, dos);
                    byte[] bytes = bos.toByteArray();
                    String base64 = Base64.encodeBase64String(bytes);
                    exchange.getIn().setBody(base64, String.class);
                    exchange.getIn().setHeader(KafkaConstants.KEY, UUID.randomUUID().toString());
                    exchange.getIn().setHeader(KafkaConstants.PARTITION_KEY, "1");
                }
            }).to(kafkaUri);
        }
    };
}
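The route covers only the producing side: the feature matrix is serialized with Nd4j.write(...), Base64-encoded, and published to Kafka. A minimal sketch of the matching consumer step, assuming base64Body holds the message body read back off the topic; Nd4j.read(...) reverses the encoding above:

byte[] bytes = Base64.decodeBase64(base64Body);   // org.apache.commons.codec.binary.Base64, as above
try (DataInputStream dis = new DataInputStream(new ByteArrayInputStream(bytes))) {
    INDArray restored = Nd4j.read(dis);           // same array the processor wrote
}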
Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.
From the class ExecuteWorkerFlatMapAdapter, method call:
@Override
public Iterable<R> call(Iterator<DataSet> dataSetIterator) throws Exception {
    WorkerConfiguration dataConfig = worker.getDataConfiguration();
    final boolean isGraph = dataConfig.isGraphNetwork();

    boolean stats = dataConfig.isCollectTrainingStats();
    StatsCalculationHelper s = (stats ? new StatsCalculationHelper() : null);
    if (stats)
        s.logMethodStartTime();

    if (!dataSetIterator.hasNext()) {
        if (stats) {
            s.logReturnTime();
            Pair<R, SparkTrainingStats> pair = worker.getFinalResultNoDataWithStats();
            pair.getFirst().setStats(s.build(pair.getSecond()));
            return Collections.singletonList(pair.getFirst());
        } else {
            return Collections.singletonList(worker.getFinalResultNoData());
        }
    }

    int batchSize = dataConfig.getBatchSizePerWorker();
    final int prefetchCount = dataConfig.getPrefetchNumBatches();

    DataSetIterator batchedIterator = new IteratorDataSetIterator(dataSetIterator, batchSize);
    if (prefetchCount > 0) {
        batchedIterator = new AsyncDataSetIterator(batchedIterator, prefetchCount);
    }

    try {
        MultiLayerNetwork net = null;
        ComputationGraph graph = null;
        if (stats)
            s.logInitialModelBefore();
        if (isGraph)
            graph = worker.getInitialModelGraph();
        else
            net = worker.getInitialModel();
        if (stats)
            s.logInitialModelAfter();

        int miniBatchCount = 0;
        int maxMinibatches = (dataConfig.getMaxBatchesPerWorker() > 0 ? dataConfig.getMaxBatchesPerWorker()
                        : Integer.MAX_VALUE);

        while (batchedIterator.hasNext() && miniBatchCount++ < maxMinibatches) {
            if (stats)
                s.logNextDataSetBefore();
            DataSet next = batchedIterator.next();
            if (stats)
                s.logNextDataSetAfter(next.numExamples());

            if (stats) {
                s.logProcessMinibatchBefore();
                Pair<R, SparkTrainingStats> result;
                if (isGraph)
                    result = worker.processMinibatchWithStats(next, graph, !batchedIterator.hasNext());
                else
                    result = worker.processMinibatchWithStats(next, net, !batchedIterator.hasNext());
                s.logProcessMinibatchAfter();
                if (result != null) {
                    //Terminate training immediately
                    s.logReturnTime();
                    SparkTrainingStats workerStats = result.getSecond();
                    SparkTrainingStats returnStats = s.build(workerStats);
                    result.getFirst().setStats(returnStats);
                    return Collections.singletonList(result.getFirst());
                }
            } else {
                R result;
                if (isGraph)
                    result = worker.processMinibatch(next, graph, !batchedIterator.hasNext());
                else
                    result = worker.processMinibatch(next, net, !batchedIterator.hasNext());
                if (result != null) {
                    //Terminate training immediately
                    return Collections.singletonList(result);
                }
            }
        }
        //Iterator exhausted (or max minibatches reached) without an early termination result from the worker
        if (stats) {
            s.logReturnTime();
            Pair<R, SparkTrainingStats> pair;
            if (isGraph)
                pair = worker.getFinalResultWithStats(graph);
            else
                pair = worker.getFinalResultWithStats(net);
            pair.getFirst().setStats(s.build(pair.getSecond()));
            return Collections.singletonList(pair.getFirst());
        } else {
            if (isGraph)
                return Collections.singletonList(worker.getFinalResult(graph));
            else
                return Collections.singletonList(worker.getFinalResult(net));
        }
    } finally {
        //Make sure we shut down the async thread properly...
        if (batchedIterator instanceof AsyncDataSetIterator) {
            ((AsyncDataSetIterator) batchedIterator).shutdown();
        }

        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }
}
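Stats logging aside, the iterator handling above reduces to a reusable pattern: rebatch the raw Iterator<DataSet> to the worker's minibatch size, optionally layer background prefetch on top, and guarantee the prefetch thread is shut down even on an exception. A sketch of the pattern in isolation; sourceIterator and both sizes are placeholders:

DataSetIterator batched = new IteratorDataSetIterator(sourceIterator, 32);   // rebatch to size 32
batched = new AsyncDataSetIterator(batched, 2);                              // prefetch 2 batches ahead
try {
    while (batched.hasNext()) {
        DataSet ds = batched.next();
        // ... train on ds ...
    }
} finally {
    ((AsyncDataSetIterator) batched).shutdown();   // never leak the background thread
}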
Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.
From the class MultiLayerNetwork, method fit:
@Override
public void fit(DataSetIterator iterator) {
    DataSetIterator iter;
    // we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
    if (iterator.asyncSupported()) {
        iter = new AsyncDataSetIterator(iterator, 2);
    } else {
        iter = iterator;
    }

    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochStart(this);
        }
    }

    if (layerWiseConfigurations.isPretrain()) {
        pretrain(iter);
        if (iter.resetSupported()) {
            iter.reset();
        }
        //        while (iter.hasNext()) {
        //            DataSet next = iter.next();
        //            if (next.getFeatureMatrix() == null || next.getLabels() == null)
        //                break;
        //            setInput(next.getFeatureMatrix());
        //            setLabels(next.getLabels());
        //            finetune();
        //        }
    }

    if (layerWiseConfigurations.isBackprop()) {
        update(TaskUtils.buildTask(iter));
        if (!iter.hasNext() && iter.resetSupported()) {
            iter.reset();
        }
        while (iter.hasNext()) {
            DataSet next = iter.next();
            if (next.getFeatureMatrix() == null || next.getLabels() == null)
                break;

            boolean hasMaskArrays = next.hasMaskArrays();

            if (layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
                doTruncatedBPTT(next.getFeatureMatrix(), next.getLabels(), next.getFeaturesMaskArray(),
                                next.getLabelsMaskArray());
            } else {
                if (hasMaskArrays)
                    setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray());
                setInput(next.getFeatureMatrix());
                setLabels(next.getLabels());
                if (solver == null) {
                    solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                }
                solver.optimize();
            }

            if (hasMaskArrays)
                clearLayerMaskArrays();

            Nd4j.getMemoryManager().invokeGcOccasionally();
        }
    } else if (layerWiseConfigurations.isPretrain()) {
        log.warn("Warning: finetune is not applied.");
    }

    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochEnd(this);
        }
    }
}
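From the caller's side this is the overload exercised by the tests earlier on this page: hand fit(...) any DataSetIterator and, whenever iterator.asyncSupported() is true, the two-element prefetch queue is added transparently. A minimal usage sketch; conf stands for any MultiLayerConfiguration such as those built above:

DataSetIterator iter = new IrisDataSetIterator(5, 50);
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.fit(iter);   // wraps iter in new AsyncDataSetIterator(iterator, 2) internally when supported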