Use of org.deeplearning4j.api.storage.StatsStorage in project deeplearning4j by deeplearning4j.
The class TestParallelEarlyStoppingUI, method testParallelStatsListenerCompatibility.
/**
 * Runs early-stopping parallel training on Iris with a {@link StatsListener} attached to the
 * network and its {@link StatsStorage} attached to the UI server, then checks that training
 * terminated because of the epoch termination condition.
 *
 * <p>Ignored by default: to be run manually.
 *
 * @throws Exception propagated from any step of the set-up or training
 */
@Test
@Ignore
public void testParallelStatsListenerCompatibility() throws Exception {
    UIServer ui = UIServer.getInstance();

    // Two-layer network: 4 inputs -> 3 hidden units -> 3 outputs.
    MultiLayerConfiguration netConf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .updater(Updater.SGD)
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
            .layer(1, new OutputLayer.Builder()
                    .nIn(3)
                    .nOut(3)
                    .lossFunction(LossFunctions.LossFunction.MCXENT)
                    .build())
            .pretrain(false)
            .backprop(true)
            .build();
    MultiLayerNetwork model = new MultiLayerNetwork(netConf);

    // It is important that the UI can report results from parallel training:
    // StatsListener may fail if certain properties aren't set in the model.
    StatsStorage storage = new InMemoryStatsStorage();
    StatsListener uiListener = new StatsListener(storage);
    model.setListeners(uiListener);
    ui.attach(storage);

    // The same iterator serves as training data and as input to the score calculator.
    DataSetIterator iris = new IrisDataSetIterator(50, 500);
    EarlyStoppingModelSaver<MultiLayerNetwork> modelSaver = new InMemoryModelSaver<>();
    MaxEpochsTerminationCondition maxEpochs = new MaxEpochsTerminationCondition(500);
    DataSetLossCalculator lossCalculator = new DataSetLossCalculator(iris, true);
    EarlyStoppingConfiguration<MultiLayerNetwork> earlyStopping =
            new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                    .epochTerminationConditions(maxEpochs)
                    .scoreCalculator(lossCalculator)
                    .evaluateEveryNEpochs(2)
                    .modelSaver(modelSaver)
                    .build();

    IEarlyStoppingTrainer<MultiLayerNetwork> parallelTrainer =
            new EarlyStoppingParallelTrainer<>(earlyStopping, model, iris, null, 3, 6, 2);
    EarlyStoppingResult<MultiLayerNetwork> outcome = parallelTrainer.fit();
    System.out.println(outcome);

    // Only the max-epochs condition is configured, so it must be the reason training stopped.
    assertEquals(
            EarlyStoppingResult.TerminationReason.EpochTerminationCondition,
            outcome.getTerminationReason());
}
Use of org.deeplearning4j.api.storage.StatsStorage in project deeplearning4j by deeplearning4j.
The class TestListeners, method testStatsCollection.
/**
 * Trains a small network on Iris through Spark with a {@link StatsListener} registered against a
 * {@link StatsStorage}, then inspects the storage: exactly one session and one type ID must be
 * present, there must be one worker ID per executor, and all worker IDs must share a common
 * prefix followed by a counter in the range {@code [0, numExecutors())}.
 */
@Test
public void testStatsCollection() {
    JavaSparkContext sc = getContext();
    // Looked up once and reused for every executor-count check below.
    int nExecutors = numExecutors();

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(123)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .list()
            .layer(0, new DenseLayer.Builder()
                    .nIn(4)
                    .nOut(100)
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.RELU)
                    .build())
            .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(100)
                    .nOut(3)
                    .activation(Activation.SOFTMAX)
                    .weightInit(WeightInit.XAVIER)
                    .build())
            .pretrain(false)
            .backprop(true)
            .build();
    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();

    TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(1)
            .batchSizePerWorker(5)
            .averagingFrequency(6)
            .build();
    SparkDl4jMultiLayer net = new SparkDl4jMultiLayer(sc, conf, tm);

    //In-memory
    StatsStorage ss = new MapDBStatsStorage();
    // The listener itself is built with a null argument; the storage is passed to setListeners.
    net.setListeners(ss, Collections.singletonList(new StatsListener(null)));

    List<DataSet> list = new IrisDataSetIterator(120, 150).next().asList();
    //120 examples, 4 executors, 30 examples per executor -> 6 updates of size 5 per executor
    // NOTE(review): the arithmetic above assumes numExecutors() is 4 - confirm.
    JavaRDD<DataSet> rdd = sc.parallelize(list);
    net.fit(rdd);

    List<String> sessions = ss.listSessionIDs();
    System.out.println("Sessions: " + sessions);
    assertEquals(1, sessions.size());

    String sid = sessions.get(0);
    List<String> typeIDs = ss.listTypeIDsForSession(sid);
    List<String> workers = ss.listWorkerIDsForSession(sid);
    System.out.println(sid + "\t" + typeIDs + "\t" + workers);

    List<Persistable> lastUpdates = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID);
    System.out.println(lastUpdates);

    System.out.println("Static info:");
    for (String wid : workers) {
        // The lookup is still performed for every worker; its result is not inspected.
        ss.getStaticInfo(sid, StatsListener.TYPE_ID, wid);
        System.out.println(sid + "\t" + wid);
    }

    assertEquals(1, typeIDs.size());
    assertEquals(nExecutors, workers.size());

    // Every worker ID must be "<common prefix><counter>".
    // NOTE(review): only the last character is parsed as the counter, which presumably
    // relies on there being fewer than 10 executors - confirm.
    String firstWorker = workers.get(0);
    String firstWorkerSubstring = firstWorker.substring(0, firstWorker.length() - 1);
    for (String wid : workers) {
        int counterStart = wid.length() - 1;
        assertEquals(firstWorkerSubstring, wid.substring(0, counterStart));
        int cv = Integer.parseInt(wid.substring(counterStart));
        assertTrue(0 <= cv && cv < nExecutors);
    }
}
Aggregations