Use of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in project deeplearning4j by deeplearning4j.
From class TestEarlyStoppingSpark, method testTimeTermination.
@Test
public void testTimeTermination() {
    //Test termination after max amount of training time
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .iterations(1).updater(Updater.SGD).learningRate(1e-6).weightInit(WeightInit.XAVIER)
                    .list()
                    .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .pretrain(false).backprop(true).build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));

    JavaRDD<DataSet> irisData = getIris();

    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                    new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                                    .epochTerminationConditions(new MaxEpochsTerminationCondition(10000))
                                    .iterationTerminationConditions(
                                                    new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS),
                                                    //Initial score is ~2.5, so this condition should never fire
                                                    new MaxScoreIterationTerminationCondition(7.5))
                                    .scoreCalculator(new SparkDataSetLossCalculator(irisData, true, sc.sc()))
                                    .modelSaver(saver).build();

    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new SparkEarlyStoppingTrainer(getContext().sc(),
                    new ParameterAveragingTrainingMaster(true, 4, 1, 150 / 15, 1, 0), esConf, net, irisData);

    long startTime = System.currentTimeMillis();
    EarlyStoppingResult result = trainer.fit();
    long endTime = System.currentTimeMillis();

    int durationSeconds = (int) ((endTime - startTime) / 1000);
    assertTrue("durationSeconds = " + durationSeconds, durationSeconds >= 3);
    assertTrue("durationSeconds = " + durationSeconds, durationSeconds <= 9);

    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition,
                    result.getTerminationReason());
    String expDetails = new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
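For comparison, here is a hedged sketch of the same time-based termination in a plain single-machine setup (no Spark). It assumes the local EarlyStoppingTrainer, DataSetLossCalculator, and IrisDataSetIterator from deeplearning4j-core, and reuses the network configuration conf from the test above; treat it as an illustration of the early stopping API, not the project's own test code.

DataSetIterator iris = new IrisDataSetIterator(150, 150);
EarlyStoppingConfiguration<MultiLayerNetwork> localEsConf =
                new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                                .epochTerminationConditions(new MaxEpochsTerminationCondition(10000))
                                .iterationTerminationConditions(
                                                new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS))
                                .scoreCalculator(new DataSetLossCalculator(iris, true)) //local (non-Spark) score calculator
                                .modelSaver(new InMemoryModelSaver<MultiLayerNetwork>())
                                .build();
IEarlyStoppingTrainer<MultiLayerNetwork> localTrainer = new EarlyStoppingTrainer(localEsConf, conf, iris);
EarlyStoppingResult<MultiLayerNetwork> localResult = localTrainer.fit();
MultiLayerNetwork bestModel = localResult.getBestModel(); //best model seen before the 3 s cutoff

The termination reason and details can then be inspected exactly as in the assertions above.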
Use of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in project deeplearning4j by deeplearning4j.
From class IEvaluateFlatMapFunctionAdapter, method call.
@Override
public Iterable<T> call(Iterator<DataSet> dataSetIterator) throws Exception {
    if (!dataSetIterator.hasNext()) {
        return Collections.emptyList();
    }

    //Recreate the network locally from the broadcast JSON configuration and broadcast parameters
    MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json.getValue()));
    network.init();
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException(
                        "Network did not have same number of parameters as the broadcast set parameters");
    network.setParameters(val);

    List<DataSet> collect = new ArrayList<>();
    int totalCount = 0;
    while (dataSetIterator.hasNext()) {
        //Collect examples into batches of up to evalBatchSize before each forward pass
        collect.clear();
        int nExamples = 0;
        while (dataSetIterator.hasNext() && nExamples < evalBatchSize) {
            DataSet next = dataSetIterator.next();
            nExamples += next.numExamples();
            collect.add(next);
        }
        totalCount += nExamples;

        DataSet data = DataSet.merge(collect);
        INDArray out;
        if (data.hasMaskArrays()) {
            out = network.output(data.getFeatureMatrix(), false, data.getFeaturesMaskArray(),
                            data.getLabelsMaskArray());
        } else {
            out = network.output(data.getFeatureMatrix(), false);
        }

        if (data.getLabels().rank() == 3) {
            //Rank-3 labels: time series evaluation (with label mask, if present)
            if (data.getLabelsMaskArray() == null) {
                evaluation.evalTimeSeries(data.getLabels(), out);
            } else {
                evaluation.evalTimeSeries(data.getLabels(), out, data.getLabelsMaskArray());
            }
        } else {
            evaluation.eval(data.getLabels(), out);
        }
    }

    if (log.isDebugEnabled()) {
        log.debug("Evaluated {} examples", totalCount);
    }
    return Collections.singletonList(evaluation);
}
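A function like this runs once per partition via mapPartitions, so the driver still has to combine the per-partition Evaluation objects. A hedged sketch of that driver side, assuming this adapter is wrapped in a Spark FlatMapFunction named evalFn (a placeholder name) and using Evaluation.merge(...):

JavaRDD<Evaluation> partials = data.mapPartitions(evalFn); //one Evaluation per non-empty partition
Evaluation combined = partials.reduce(new Function2<Evaluation, Evaluation, Evaluation>() {
    @Override
    public Evaluation call(Evaluation e1, Evaluation e2) {
        e1.merge(e2); //accumulate e2's counts into e1
        return e1;
    }
});
System.out.println(combined.stats());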
Use of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in project deeplearning4j by deeplearning4j.
From class ScoreExamplesWithKeyFunctionAdapter, method call.
@Override
public Iterable<Tuple2<K, Double>> call(Iterator<Tuple2<K, DataSet>> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyList();
    }

    //Recreate the network locally from the broadcast JSON configuration and broadcast parameters
    MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue()));
    network.init();
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException(
                        "Network did not have same number of parameters as the broadcast set parameters");
    network.setParameters(val);

    List<Tuple2<K, Double>> ret = new ArrayList<>();
    List<DataSet> collect = new ArrayList<>(batchSize);
    List<K> collectKey = new ArrayList<>(batchSize);
    int totalCount = 0;
    while (iterator.hasNext()) {
        //Collect single-example DataSets (and their keys, in the same order) into batches of up to batchSize
        collect.clear();
        collectKey.clear();
        int nExamples = 0;
        while (iterator.hasNext() && nExamples < batchSize) {
            Tuple2<K, DataSet> t2 = iterator.next();
            DataSet ds = t2._2();
            int n = ds.numExamples();
            if (n != 1)
                throw new IllegalStateException("Cannot score examples with one key per data set if "
                                + "data set contains more than 1 example (numExamples: " + n + ")");
            collect.add(ds);
            collectKey.add(t2._1());
            nExamples += n;
        }
        totalCount += nExamples;

        //Score the merged batch; scores come back in the same order as the collected keys
        DataSet data = DataSet.merge(collect);
        INDArray scores = network.scoreExamples(data, addRegularization);
        double[] doubleScores = scores.data().asDouble();
        for (int i = 0; i < doubleScores.length; i++) {
            ret.add(new Tuple2<>(collectKey.get(i), doubleScores[i]));
        }
    }

    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();

    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples", totalCount);
    }
    return ret;
}
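In practice this adapter backs the driver-side scoring API. A hedged usage sketch, assuming a SparkDl4jMultiLayer.scoreExamples(JavaPairRDD, boolean, int) overload and using DataSet.asList() to satisfy the one-example-per-DataSet requirement checked above; sparkNet is an already-constructed SparkDl4jMultiLayer:

//Split a 150-example Iris DataSet into single-example DataSets, one key each
List<DataSet> singles = new IrisDataSetIterator(150, 150).next().asList();
List<Tuple2<String, DataSet>> keyed = new ArrayList<>();
for (int i = 0; i < singles.size(); i++) {
    keyed.add(new Tuple2<>("example-" + i, singles.get(i)));
}
JavaPairRDD<String, DataSet> examplesByKey = sc.parallelizePairs(keyed);

//true: include regularization terms in the score; 32: scoring batch size
JavaPairRDD<String, Double> scores = sparkNet.scoreExamples(examplesByKey, true, 32);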
Use of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in project deeplearning4j by deeplearning4j.
From class ParameterAveragingTrainingWorker, method getInitialModel.
@Override
public MultiLayerNetwork getInitialModel() {
    if (configuration.isCollectTrainingStats()) {
        stats = new ParameterAveragingTrainingWorkerStats.ParameterAveragingTrainingWorkerStatsHelper();
        stats.logBroadcastGetValueStart();
    }
    NetBroadcastTuple tuple = broadcast.getValue();
    if (configuration.isCollectTrainingStats())
        stats.logBroadcastGetValueEnd();

    //Don't want a shared configuration object: each worker may update its iteration count (for LR schedule etc) individually
    MultiLayerNetwork net = new MultiLayerNetwork(tuple.getConfiguration().clone());
    //Can't have a shared parameter array across executors for parameter averaging; unsafeDuplication() already copies it, so pass 'false' for the clone-parameters-array arg
    net.init(tuple.getParameters().unsafeDuplication(), false);

    if (tuple.getUpdaterState() != null) {
        //Can't have shared updater state either
        net.setUpdater(new MultiLayerUpdater(net, tuple.getUpdaterState().unsafeDuplication()));
    }

    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();

    configureListeners(net, tuple.getCounter().getAndIncrement());

    if (configuration.isCollectTrainingStats())
        stats.logInitEnd();

    return net;
}
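The repeated unsafeDuplication() calls are the crux here: each worker must update its own copy of the parameters and updater state, never the broadcast buffer. A hedged sketch demonstrating that behavior in isolation (array values chosen arbitrarily):

INDArray broadcastParams = Nd4j.linspace(1, 6, 6);
INDArray workerCopy = broadcastParams.unsafeDuplication(); //independent copy of the data buffer
workerCopy.addi(1.0); //in-place update touches only the copy
System.out.println(broadcastParams); //unchanged: 1.0 through 6.0
System.out.println(workerCopy); //every element incremented by 1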
Use of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in project deeplearning4j by deeplearning4j.
From class SparkDl4jMultiLayer, method initNetwork.
private static MultiLayerNetwork initNetwork(MultiLayerConfiguration conf) {
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    return net;
}
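This private helper backs the SparkDl4jMultiLayer constructor that takes a MultiLayerConfiguration. A hedged end-to-end sketch of that entry point, assuming the ParameterAveragingTrainingMaster builder API and a JavaRDD<DataSet> named trainingData (a placeholder):

TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(1) //1 example per DataSet object in the RDD
                .averagingFrequency(5)
                .batchSizePerWorker(32)
                .build();
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf, tm); //builds and initializes the net from conf
MultiLayerNetwork trained = sparkNet.fit(trainingData);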