Use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.
The class InMemoryStatsStorage, method putUpdate.
@Override
public void putUpdate(Persistable update) {
    List<StatsStorageEvent> sses = checkStorageEvents(update);
    Map<Long, Persistable> updateMap = getUpdateMap(update.getSessionID(), update.getTypeID(), update.getWorkerID(), true);
    updateMap.put(update.getTimeStamp(), update);

    //Create the event once; it is only non-null when there is at least one listener to notify
    StatsStorageEvent sse = null;
    if (!listeners.isEmpty())
        sse = new StatsStorageEvent(this, StatsStorageListener.EventType.PostUpdate, update.getSessionID(),
                        update.getTypeID(), update.getWorkerID(), update.getTimeStamp());
    for (StatsStorageListener l : listeners) {
        l.notify(sse);
    }
    notifyListeners(sses);
}
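A minimal usage sketch follows, assuming a pre-existing Persistable update record (for example, one produced by a StatsListener) and that StatsStorageListener is a single-method listener interface, as the loop above suggests:

//Sketch only: register a listener, then route an update through putUpdate.
StatsStorage storage = new InMemoryStatsStorage();
storage.registerStatsStorageListener(event ->       //assumes a single-method listener interface
                System.out.println(event.getEventType() + " for session " + event.getSessionID()));
storage.putUpdate(update);  //`update` is the pre-existing Persistable (assumption); listener gets a PostUpdate event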
Use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.
The class ParameterAveragingTrainingWorker, method getFinalResult (ComputationGraph overload).
@Override
public ParameterAveragingTrainingResult getFinalResult(ComputationGraph network) {
    INDArray updaterState = null;
    if (saveUpdater) {
        ComputationGraphUpdater u = network.getUpdater();
        if (u != null)
            updaterState = u.getStateViewArray();
    }

    //Ensure any queued async ops have completed before extracting network state
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();

    Collection<StorageMetaData> storageMetaData = null;
    Collection<Persistable> listenerStaticInfo = null;
    Collection<Persistable> listenerUpdates = null;
    if (listenerRouterProvider != null) {
        StatsStorageRouter r = listenerRouterProvider.getRouter();
        if (r instanceof VanillaStatsStorageRouter) {
            //TODO this is ugly... need to find a better solution
            VanillaStatsStorageRouter ssr = (VanillaStatsStorageRouter) r;
            storageMetaData = ssr.getStorageMetaData();
            listenerStaticInfo = ssr.getStaticInfo();
            listenerUpdates = ssr.getUpdates();
        }
    }

    return new ParameterAveragingTrainingResult(network.params(), updaterState, network.score(), storageMetaData,
                    listenerStaticInfo, listenerUpdates);
}
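For context, a hypothetical driver sketch consuming this result; `worker` and `network` are assumed to exist, and the getters are the ones the aggregation code further down this page relies on:

//Hypothetical driver code, not from the project:
ParameterAveragingTrainingResult result = worker.getFinalResult(network);
INDArray params = result.getParameters();         //this worker's parameter contribution
INDArray updaterState = result.getUpdaterState(); //null when saveUpdater was false
double score = result.getScore();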
Use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.
The class ParameterAveragingTrainingWorker, method getFinalResult (MultiLayerNetwork overload).
@Override
public ParameterAveragingTrainingResult getFinalResult(MultiLayerNetwork network) {
    INDArray updaterState = null;
    if (saveUpdater) {
        Updater u = network.getUpdater();
        if (u != null)
            updaterState = u.getStateViewArray();
    }

    //Ensure any queued async ops have completed before extracting network state
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();

    Collection<StorageMetaData> storageMetaData = null;
    Collection<Persistable> listenerStaticInfo = null;
    Collection<Persistable> listenerUpdates = null;
    if (listenerRouterProvider != null) {
        StatsStorageRouter r = listenerRouterProvider.getRouter();
        if (r instanceof VanillaStatsStorageRouter) {
            //TODO this is ugly... need to find a better solution
            VanillaStatsStorageRouter ssr = (VanillaStatsStorageRouter) r;
            storageMetaData = ssr.getStorageMetaData();
            listenerStaticInfo = ssr.getStaticInfo();
            listenerUpdates = ssr.getUpdates();
        }
    }

    return new ParameterAveragingTrainingResult(network.params(), updaterState, network.score(), storageMetaData,
                    listenerStaticInfo, listenerUpdates);
}
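Aside from the updater type (Updater here versus ComputationGraphUpdater above), the two overloads are identical; the duplication suggests that, at this point in the codebase, the MultiLayerNetwork and ComputationGraph updaters did not share a common interface exposing getStateViewArray().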
Use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.
The class ParameterAveragingElementAddFunction, method call.
@Override
public ParameterAveragingAggregationTuple call(ParameterAveragingAggregationTuple tuple,
                ParameterAveragingTrainingResult result) throws Exception {
    if (tuple == null) {
        //First element seen: start the running sums from this result
        return ParameterAveragingAggregationTuple.builder().parametersSum(result.getParameters())
                        .updaterStateSum(result.getUpdaterState()).scoreSum(result.getScore())
                        .aggregationsCount(1).sparkTrainingStats(result.getSparkTrainingStats())
                        .listenerMetaData(result.getListenerMetaData())
                        .listenerStaticInfo(result.getListenerStaticInfo())
                        .listenerUpdates(result.getListenerUpdates()).build();
    }

    INDArray params = tuple.getParametersSum().addi(result.getParameters());
    INDArray updaterStateSum;
    if (tuple.getUpdaterStateSum() == null) {
        updaterStateSum = result.getUpdaterState();
    } else {
        updaterStateSum = tuple.getUpdaterStateSum();
        if (result.getUpdaterState() != null)
            updaterStateSum.addi(result.getUpdaterState());
    }

    double scoreSum = tuple.getScoreSum() + result.getScore();
    SparkTrainingStats stats = tuple.getSparkTrainingStats();
    if (result.getSparkTrainingStats() != null) {
        if (stats == null)
            stats = result.getSparkTrainingStats();
        else
            stats.addOtherTrainingStats(result.getSparkTrainingStats());
    }

    //Ensure the in-place addi ops above have completed before the tuple is handed on
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();

    Collection<StorageMetaData> listenerMetaData = tuple.getListenerMetaData();
    if (listenerMetaData == null)
        listenerMetaData = result.getListenerMetaData();
    else {
        Collection<StorageMetaData> newMeta = result.getListenerMetaData();
        if (newMeta != null)
            listenerMetaData.addAll(newMeta);
    }

    Collection<Persistable> listenerStaticInfo = tuple.getListenerStaticInfo();
    if (listenerStaticInfo == null)
        listenerStaticInfo = result.getListenerStaticInfo();
    else {
        Collection<Persistable> newStatic = result.getListenerStaticInfo();
        if (newStatic != null)
            listenerStaticInfo.addAll(newStatic);
    }

    Collection<Persistable> listenerUpdates = tuple.getListenerUpdates();
    if (listenerUpdates == null)
        listenerUpdates = result.getListenerUpdates();
    else {
        Collection<Persistable> newUpdates = result.getListenerUpdates();
        if (newUpdates != null)
            listenerUpdates.addAll(newUpdates);
    }

    return new ParameterAveragingAggregationTuple(params, updaterStateSum, scoreSum,
                    tuple.getAggregationsCount() + 1, stats, listenerMetaData, listenerStaticInfo, listenerUpdates);
}
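A hedged sketch of how this add function typically participates in a Spark aggregation; `resultsRdd` is an assumed JavaRDD<ParameterAveragingTrainingResult>, and ParameterAveragingElementCombineFunction is taken to be the sibling combiner class in the same package:

//Sketch, not the authoritative driver: fold per-worker results into one running-sum tuple.
ParameterAveragingAggregationTuple aggregated = resultsRdd.aggregate(null, //null zero value; call() handles it
                new ParameterAveragingElementAddFunction(),      //adds one result into the tuple
                new ParameterAveragingElementCombineFunction()); //merges partial tuples across partitions
//The average is then sum / count:
INDArray averagedParams = aggregated.getParametersSum().divi(aggregated.getAggregationsCount());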
Use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.
The class TestListeners, method testStatsCollection.
@Test
public void testStatsCollection() {
    JavaSparkContext sc = getContext();
    int nExecutors = numExecutors();
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(100).weightInit(WeightInit.XAVIER)
                                    .activation(Activation.RELU).build())
                    .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .nIn(100).nOut(3).activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER).build())
                    .pretrain(false).backprop(true).build();
    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();

    TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(1).batchSizePerWorker(5)
                    .averagingFrequency(6).build();
    SparkDl4jMultiLayer net = new SparkDl4jMultiLayer(sc, conf, tm);

    //In-memory stats storage
    StatsStorage ss = new MapDBStatsStorage();
    net.setListeners(ss, Collections.singletonList(new StatsListener(null)));

    List<DataSet> list = new IrisDataSetIterator(120, 150).next().asList();
    //120 examples, 4 executors, 30 examples per executor -> 6 updates of size 5 per executor
    JavaRDD<DataSet> rdd = sc.parallelize(list);
    net.fit(rdd);

    List<String> sessions = ss.listSessionIDs();
    System.out.println("Sessions: " + sessions);
    assertEquals(1, sessions.size());
    String sid = sessions.get(0);

    List<String> typeIDs = ss.listTypeIDsForSession(sid);
    List<String> workers = ss.listWorkerIDsForSession(sid);
    System.out.println(sid + "\t" + typeIDs + "\t" + workers);

    List<Persistable> lastUpdates = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID);
    System.out.println(lastUpdates);

    System.out.println("Static info:");
    for (String wid : workers) {
        Persistable staticInfo = ss.getStaticInfo(sid, StatsListener.TYPE_ID, wid);
        System.out.println(sid + "\t" + wid + "\t" + staticInfo);
    }

    assertEquals(1, typeIDs.size());
    assertEquals(nExecutors, workers.size());

    //Worker IDs should be identical except for a trailing counter digit
    String firstWorker = workers.get(0);
    String firstWorkerSubstring = firstWorker.substring(0, firstWorker.length() - 1);
    for (String wid : workers) {
        String widSubstring = wid.substring(0, wid.length() - 1);
        assertEquals(firstWorkerSubstring, widSubstring);
        String counterVal = wid.substring(wid.length() - 1);
        int cv = Integer.parseInt(counterVal);
        assertTrue(0 <= cv && cv < nExecutors);
    }
}
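Note that the trailing loop encodes the assumption that worker IDs differ only in a final one-digit counter: only the last character is parsed, so the check would break with ten or more executors and is therefore tied to small test cluster sizes.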