Use of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in project deeplearning4j by deeplearning4j.
From the class TestRenders, method testHistogramComputationGraph.
/**
 * Trains a small two-branch convolutional graph on MNIST with a histogram listener attached.
 *
 * <p>Both convolution branches read the same "input" vertex; their outputs feed a single
 * max-pooling layer, whose result is flattened into the output layer.
 */
@Test
public void testHistogramComputationGraph() throws Exception {
    ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .graphBuilder()
            .addInputs("input")
            // Two convolution branches over the same input
            .addLayer("cnn1",
                    new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(3).build(),
                    "input")
            .addLayer("cnn2",
                    new ConvolutionLayer.Builder(4, 4).stride(2, 2).padding(1, 1).nIn(1).nOut(3).build(),
                    "input")
            // Pooling layer takes both branches as inputs
            .addLayer("max1",
                    new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).build(),
                    "cnn1", "cnn2")
            .addLayer("output",
                    new OutputLayer.Builder().nIn(7 * 7 * 6).nOut(10).build(),
                    "max1")
            .setOutputs("output")
            .inputPreProcessor("cnn1", new FeedForwardToCnnPreProcessor(28, 28, 1))
            .inputPreProcessor("cnn2", new FeedForwardToCnnPreProcessor(28, 28, 1))
            .inputPreProcessor("output", new CnnToFeedForwardPreProcessor(7, 7, 6))
            .pretrain(false)
            .backprop(true)
            .build();

    ComputationGraph network = new ComputationGraph(graphConf);
    network.init();
    network.setListeners(new HistogramIterationListener(1), new ScoreIterationListener(1));

    DataSetIterator mnistIter = new MnistDataSetIterator(32, 640, false, true, false, 12345);
    network.fit(mnistIter);
}
Use of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in project deeplearning4j by deeplearning4j.
From the class TrainModule, method getModelData.
/**
 * Assembles the per-layer model data for the current session and worker and returns it as JSON.
 *
 * <p>The response map always contains {@code "updateTimestamp"}. When no configuration or no graph
 * info is available, only that entry is returned. Otherwise the map additionally holds
 * {@code "layerInfo"}, {@code "meanMag"}, {@code "activations"}, {@code "learningRates"},
 * {@code "paramHist"} and {@code "updateHist"}.
 *
 * @param str layer index as a decimal string; it is parsed with {@link Integer#parseInt(String)},
 *            so a non-numeric value results in a {@link NumberFormatException} (no validation yet)
 * @return an {@code ok} result wrapping the JSON form of the response map
 */
private Result getModelData(String str) {
    Long lastUpdateTime = lastUpdateForSession.get(currentSessionID);
    if (lastUpdateTime == null) {
        lastUpdateTime = -1L;
    }
    //TODO validation
    int layerIdx = Integer.parseInt(str);
    I18N i18N = I18NProvider.getInstance();

    //Model info for layer
    boolean noData = currentSessionID == null;
    //First pass (optimize later): query all data...
    StatsStorage ss = (noData ? null : knownSessionIDs.get(currentSessionID));
    String wid = getWorkerIdForIndex(currentWorkerIdx);
    if (wid == null) {
        noData = true;
    }

    Map<String, Object> result = new HashMap<>();
    result.put("updateTimestamp", lastUpdateTime);

    Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> conf = getConfig();
    if (conf == null) {
        return ok(Json.toJson(result));
    }
    TrainModuleUtils.GraphInfo gi = getGraphInfo();
    if (gi == null) {
        return ok(Json.toJson(result));
    }

    // Get static layer info
    String[][] layerInfoTable = getLayerInfoTable(layerIdx, gi, i18N, noData, ss, wid);
    result.put("layerInfo", layerInfoTable);

    //First: get all data, and subsample it if necessary, to avoid returning too many points...
    List<Persistable> updates =
            (noData ? null : ss.getAllUpdatesAfter(currentSessionID, StatsListener.TYPE_ID, wid, 0));
    List<Integer> iterationCounts = null;
    boolean needToHandleLegacyIterCounts = false;
    if (updates != null && updates.size() > maxChartPoints) {
        int subsamplingFrequency = updates.size() / maxChartPoints;
        List<Persistable> subsampled = new ArrayList<>();
        iterationCounts = new ArrayList<>();
        int pCount = -1;
        // NOTE(review): this index counts every entry in 'updates', whereas pCount only counts
        // StatsReport instances; if other Persistable types can appear here, the "keep the most
        // recent value" check below may never match - confirm.
        int lastUpdateIdx = updates.size() - 1;
        int lastIterCount = -1;
        for (Persistable p : updates) {
            if (!(p instanceof StatsReport)) {
                continue;
            }
            StatsReport sr = (StatsReport) p;
            pCount++;

            int iterCount = sr.getIterationCount();
            // A non-increasing iteration count marks legacy data that needs cleaning afterwards
            if (iterCount <= lastIterCount) {
                needToHandleLegacyIterCounts = true;
            }
            lastIterCount = iterCount;

            if (pCount > 0 && subsamplingFrequency > 1 && pCount % subsamplingFrequency != 0) {
                //Skip this to subsample the data
                if (pCount != lastUpdateIdx) {
                    //Always keep the most recent value
                    continue;
                }
            }
            subsampled.add(p);
            iterationCounts.add(iterCount);
        }
        updates = subsampled;
    } else if (updates != null) {
        iterationCounts = new ArrayList<>(updates.size());
        int lastIterCount = -1;
        // NOTE(review): non-StatsReport entries are skipped for iterationCounts but stay in
        // 'updates', so the two lists could differ in length if such entries exist - confirm.
        for (Persistable p : updates) {
            if (!(p instanceof StatsReport)) {
                continue;
            }
            StatsReport sr = (StatsReport) p;
            int iterCount = sr.getIterationCount();
            if (iterCount <= lastIterCount) {
                needToHandleLegacyIterCounts = true;
            }
            // Remember the previous count; without this the comparison above is always against -1
            // and legacy (non-increasing) iteration counts are never detected in this branch.
            lastIterCount = iterCount;
            iterationCounts.add(iterCount);
        }
    }

    //Now, it should use the proper iteration counts
    if (needToHandleLegacyIterCounts) {
        cleanLegacyIterationCounts(iterationCounts);
    }

    //Get mean magnitudes line chart
    ModelType mt;
    if (conf.getFirst() != null) {
        mt = ModelType.MLN;
    } else if (conf.getSecond() != null) {
        mt = ModelType.CG;
    } else {
        mt = ModelType.Layer;
    }
    MeanMagnitudes mm = getLayerMeanMagnitudes(layerIdx, gi, updates, iterationCounts, mt);
    Map<String, Object> mmRatioMap = new HashMap<>();
    mmRatioMap.put("layerParamNames", mm.getRatios().keySet());
    mmRatioMap.put("iterCounts", mm.getIterations());
    mmRatioMap.put("ratios", mm.getRatios());
    mmRatioMap.put("paramMM", mm.getParamMM());
    mmRatioMap.put("updateMM", mm.getUpdateMM());
    result.put("meanMag", mmRatioMap);

    //Get activations line chart for layer
    Triple<int[], float[], float[]> activationsData = getLayerActivations(layerIdx, gi, updates, iterationCounts);
    Map<String, Object> activationMap = new HashMap<>();
    activationMap.put("iterCount", activationsData.getFirst());
    activationMap.put("mean", activationsData.getSecond());
    activationMap.put("stdev", activationsData.getThird());
    result.put("activations", activationMap);

    //Get learning rate vs. time chart for layer
    Map<String, Object> lrs = getLayerLearningRates(layerIdx, gi, updates, iterationCounts, mt);
    result.put("learningRates", lrs);

    //Parameters histogram data
    Persistable lastUpdate = (updates != null && updates.size() > 0 ? updates.get(updates.size() - 1) : null);
    Map<String, Object> paramHistograms = getHistograms(layerIdx, gi, StatsType.Parameters, lastUpdate);
    result.put("paramHist", paramHistograms);

    //Updates histogram data
    Map<String, Object> updateHistograms = getHistograms(layerIdx, gi, StatsType.Updates, lastUpdate);
    result.put("updateHist", updateHistograms);

    return ok(Json.toJson(result));
}
Use of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in project deeplearning4j by deeplearning4j.
From the class TestPlayUI, method testUICompGraph.
/**
 * Manual UI check (ignored by default): trains a small dense graph on Iris while a
 * {@code StatsListener} writes into an in-memory storage attached to the UI server, then
 * sleeps so the UI can be inspected.
 */
@Test
@Ignore
public void testUICompGraph() throws Exception {
    StatsStorage statsStorage = new InMemoryStatsStorage();
    UIServer server = UIServer.getInstance();
    server.attach(statsStorage);

    ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in")
            .addLayer("L0",
                    new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build(),
                    "in")
            .addLayer("L1",
                    new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
                            .activation(Activation.SOFTMAX).nIn(4).nOut(3).build(),
                    "L0")
            .pretrain(false)
            .backprop(true)
            .setOutputs("L1")
            .build();

    ComputationGraph graph = new ComputationGraph(graphConf);
    graph.init();
    graph.setListeners(new StatsListener(statsStorage), new ScoreIterationListener(1));

    DataSetIterator irisData = new IrisDataSetIterator(150, 150);
    // 100 fitting passes with a short pause after each one
    int pass = 0;
    while (pass < 100) {
        graph.fit(irisData);
        Thread.sleep(100);
        pass++;
    }

    // Keep the JVM alive after training
    Thread.sleep(100000);
}
Use of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in project deeplearning4j by deeplearning4j.
From the class TestEarlyStoppingCompGraph, method testNoImprovementNEpochsTermination.
/**
 * Checks that early stopping terminates when the score does not improve for 5 consecutive epochs.
 *
 * <p>The lack of improvement is simulated with a learning rate of 0.0, so the test expects
 * 6 total epochs, best model at epoch 0, and an epoch-termination reason whose details match
 * {@code ScoreImprovementEpochTerminationCondition(5)}.
 */
@Test
public void testNoImprovementNEpochsTermination() {
    //Idea: terminate training if score (test set loss) does not improve for 5 consecutive epochs
    //Simulate this by setting LR = 0.0
    Nd4j.getRandom().setSeed(12345);
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .updater(Updater.SGD)
            .learningRate(0.0)
            .weightInit(WeightInit.XAVIER)
            .graphBuilder()
            .addInputs("in")
            .addLayer("0",
                    new OutputLayer.Builder().nIn(4).nOut(3)
                            .lossFunction(LossFunctions.LossFunction.MCXENT).build(),
                    "in")
            .setOutputs("0")
            .pretrain(false)
            .backprop(true)
            .build();
    ComputationGraph net = new ComputationGraph(conf);
    net.setListeners(new ScoreIterationListener(1));

    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<ComputationGraph> esConf =
            new EarlyStoppingConfiguration.Builder<ComputationGraph>()
                    .epochTerminationConditions(
                            new MaxEpochsTerminationCondition(100),
                            new ScoreImprovementEpochTerminationCondition(5))
                    .iterationTerminationConditions(
                            new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS),
                            //Initial score is ~2.5
                            new MaxScoreIterationTerminationCondition(7.5))
                    .scoreCalculator(new DataSetLossCalculatorCG(irisIter, true))
                    .modelSaver(saver)
                    .build();

    // Parameterized (not raw) types, matching the declarations used in testEarlyStoppingIris
    IEarlyStoppingTrainer<ComputationGraph> trainer = new EarlyStoppingGraphTrainer(esConf, net, irisIter);
    EarlyStoppingResult<ComputationGraph> result = trainer.fit();

    //Expect no score change due to 0 LR -> terminate after 6 total epochs
    assertEquals(6, result.getTotalEpochs());
    assertEquals(0, result.getBestModelEpoch());
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
    String expDetails = new ScoreImprovementEpochTerminationCondition(5).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
Use of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in project deeplearning4j by deeplearning4j.
From the class TestEarlyStoppingCompGraph, method testEarlyStoppingIris.
/**
 * Runs early stopping on Iris with a 5-epoch cap and checks the result: total epochs, termination
 * reason and details, number of recorded epoch scores, and that the best model's score matches
 * a score computed directly on the first batch of the iterator.
 */
@Test
public void testEarlyStoppingIris() {
    ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .updater(Updater.SGD)
            .weightInit(WeightInit.XAVIER)
            .graphBuilder()
            .addInputs("in")
            .addLayer("0",
                    new OutputLayer.Builder().nIn(4).nOut(3)
                            .lossFunction(LossFunctions.LossFunction.MCXENT).build(),
                    "in")
            .setOutputs("0")
            .pretrain(false)
            .backprop(true)
            .build();
    ComputationGraph graph = new ComputationGraph(graphConf);
    graph.setListeners(new ScoreIterationListener(1));

    DataSetIterator irisData = new IrisDataSetIterator(150, 150);
    EarlyStoppingModelSaver<ComputationGraph> modelSaver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<ComputationGraph> earlyStopConf =
            new EarlyStoppingConfiguration.Builder<ComputationGraph>()
                    .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                    .iterationTerminationConditions(
                            new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                    .scoreCalculator(new DataSetLossCalculatorCG(irisData, true))
                    .modelSaver(modelSaver)
                    .build();

    IEarlyStoppingTrainer<ComputationGraph> graphTrainer =
            new EarlyStoppingGraphTrainer(earlyStopConf, graph, irisData);
    EarlyStoppingResult<ComputationGraph> outcome = graphTrainer.fit();
    System.out.println(outcome);

    // Training must stop because of the epoch cap, after exactly 5 epochs
    assertEquals(5, outcome.getTotalEpochs());
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, outcome.getTerminationReason());
    Map<Integer, Double> scoreByEpoch = outcome.getScoreVsEpoch();
    assertEquals(5, scoreByEpoch.size());

    String expectedDetails = earlyStopConf.getEpochTerminationConditions().get(0).toString();
    assertEquals(expectedDetails, outcome.getTerminationDetails());

    ComputationGraph returnedModel = outcome.getBestModel();
    assertNotNull(returnedModel);

    //Check that best score actually matches (returned model vs. manually calculated score)
    ComputationGraph bestModel = outcome.getBestModel();
    irisData.reset();
    double manualScore = bestModel.score(irisData.next());
    assertEquals(outcome.getBestModelScore(), manualScore, 1e-2);
}
Aggregations