Search in sources :

Example 26 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

The method putUpdate of the class InMemoryStatsStorage.

/**
 * Stores {@code update} and tells the registered listeners about it.
 *
 * <p>The update is placed in the map obtained from {@code getUpdateMap} for its
 * session ID, type ID and worker ID, keyed by its timestamp. Every listener is
 * then handed a single {@code PostUpdate} event, after which the events returned
 * by {@code checkStorageEvents} are passed to {@code notifyListeners}.
 *
 * @param update the update record to store
 */
@Override
public void putUpdate(Persistable update) {
    // Gathered before the update is stored; dispatched last.
    final List<StatsStorageEvent> pendingEvents = checkStorageEvents(update);

    final Map<Long, Persistable> updatesByTimestamp =
            getUpdateMap(update.getSessionID(), update.getTypeID(), update.getWorkerID(), true);
    updatesByTimestamp.put(update.getTimeStamp(), update);

    // One event instance is shared by all listeners; it is not built when nobody is listening.
    final StatsStorageEvent postUpdateEvent = listeners.isEmpty()
            ? null
            : new StatsStorageEvent(this, StatsStorageListener.EventType.PostUpdate, update.getSessionID(),
                    update.getTypeID(), update.getWorkerID(), update.getTimeStamp());
    for (StatsStorageListener listener : listeners) {
        listener.notify(postUpdateEvent);
    }

    notifyListeners(pendingEvents);
}
Also used : Persistable(org.deeplearning4j.api.storage.Persistable) StatsStorageEvent(org.deeplearning4j.api.storage.StatsStorageEvent) StatsStorageListener(org.deeplearning4j.api.storage.StatsStorageListener)

Example 27 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

The method getFinalResult (ComputationGraph overload) of the class ParameterAveragingTrainingWorker.

/**
 * Packages the final state of {@code network} into a training result.
 *
 * <p>The result carries the network parameters and score, the updater state view
 * (only when {@code saveUpdater} is set and the network has an updater), and the
 * stats-storage meta data, static info and updates taken from the router when that
 * router is a {@link VanillaStatsStorageRouter}. Anything unavailable is passed as
 * {@code null}.
 *
 * @param network the computation graph whose final state is reported
 * @return the training result built from the network and any collected listener data
 */
@Override
public ParameterAveragingTrainingResult getFinalResult(ComputationGraph network) {
    INDArray updaterView = null;
    if (saveUpdater) {
        ComputationGraphUpdater graphUpdater = network.getUpdater();
        if (graphUpdater != null) {
            updaterView = graphUpdater.getStateViewArray();
        }
    }

    // Flush any queued operations before the result is assembled.
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    Collection<StorageMetaData> metaData = null;
    Collection<Persistable> staticInfo = null;
    Collection<Persistable> updates = null;

    // instanceof is false for null, so a missing provider simply skips the block below.
    StatsStorageRouter router = (listenerRouterProvider == null) ? null : listenerRouterProvider.getRouter();
    if (router instanceof VanillaStatsStorageRouter) {
        //TODO this is ugly... need to find a better solution
        VanillaStatsStorageRouter vanillaRouter = (VanillaStatsStorageRouter) router;
        metaData = vanillaRouter.getStorageMetaData();
        staticInfo = vanillaRouter.getStaticInfo();
        updates = vanillaRouter.getUpdates();
    }

    return new ParameterAveragingTrainingResult(network.params(), updaterView, network.score(),
            metaData, staticInfo, updates);
}
Also used : StorageMetaData(org.deeplearning4j.api.storage.StorageMetaData) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) Persistable(org.deeplearning4j.api.storage.Persistable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VanillaStatsStorageRouter(org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter) ComputationGraphUpdater(org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater) StatsStorageRouter(org.deeplearning4j.api.storage.StatsStorageRouter) VanillaStatsStorageRouter(org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter)

Example 28 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

The method getFinalResult (MultiLayerNetwork overload) of the class ParameterAveragingTrainingWorker.

/**
 * Packages the final state of {@code network} into a training result.
 *
 * <p>The result carries the network parameters and score, the updater state view
 * (only when {@code saveUpdater} is set and the network has an updater), and the
 * stats-storage meta data, static info and updates taken from the router when that
 * router is a {@link VanillaStatsStorageRouter}. Anything unavailable is passed as
 * {@code null}.
 *
 * @param network the multi-layer network whose final state is reported
 * @return the training result built from the network and any collected listener data
 */
@Override
public ParameterAveragingTrainingResult getFinalResult(MultiLayerNetwork network) {
    INDArray updaterView = null;
    if (saveUpdater) {
        Updater netUpdater = network.getUpdater();
        if (netUpdater != null) {
            updaterView = netUpdater.getStateViewArray();
        }
    }

    // Flush any queued operations before the result is assembled.
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    Collection<StorageMetaData> metaData = null;
    Collection<Persistable> staticInfo = null;
    Collection<Persistable> updates = null;

    // instanceof is false for null, so a missing provider simply skips the block below.
    StatsStorageRouter router = (listenerRouterProvider == null) ? null : listenerRouterProvider.getRouter();
    if (router instanceof VanillaStatsStorageRouter) {
        //TODO this is ugly... need to find a better solution
        VanillaStatsStorageRouter vanillaRouter = (VanillaStatsStorageRouter) router;
        metaData = vanillaRouter.getStorageMetaData();
        staticInfo = vanillaRouter.getStaticInfo();
        updates = vanillaRouter.getUpdates();
    }

    return new ParameterAveragingTrainingResult(network.params(), updaterView, network.score(),
            metaData, staticInfo, updates);
}
Also used : StorageMetaData(org.deeplearning4j.api.storage.StorageMetaData) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) Persistable(org.deeplearning4j.api.storage.Persistable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VanillaStatsStorageRouter(org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter) ComputationGraphUpdater(org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater) Updater(org.deeplearning4j.nn.api.Updater) MultiLayerUpdater(org.deeplearning4j.nn.updater.MultiLayerUpdater) StatsStorageRouter(org.deeplearning4j.api.storage.StatsStorageRouter) VanillaStatsStorageRouter(org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter)

Example 29 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

The method call of the class ParameterAveragingElementAddFunction.

/**
 * Folds one training result into the running aggregation tuple.
 *
 * <p>When {@code tuple} is {@code null} a new tuple is built directly from
 * {@code result} with an aggregation count of 1. Otherwise the result's parameters
 * (and updater state, when present) are added to the tuple's sums via {@code addi},
 * the score is added to the score sum, training stats are combined, and the
 * listener meta data, static info and updates of the result are appended to the
 * collections already held by the tuple.
 *
 * <p>Note that the tuple's arrays and collections are modified and reused by the
 * returned tuple rather than copied.
 *
 * @param tuple  the aggregation so far, or {@code null} if nothing has been aggregated yet
 * @param result the training result to add
 * @return a tuple representing the aggregation including {@code result}
 * @throws Exception declared by the overridden method's signature
 */
@Override
public ParameterAveragingAggregationTuple call(ParameterAveragingAggregationTuple tuple, ParameterAveragingTrainingResult result) throws Exception {
    if (tuple == null) {
        return ParameterAveragingAggregationTuple.builder()
                .parametersSum(result.getParameters())
                .updaterStateSum(result.getUpdaterState())
                .scoreSum(result.getScore())
                .aggregationsCount(1)
                .sparkTrainingStats(result.getSparkTrainingStats())
                .listenerMetaData(result.getListenerMetaData())
                .listenerStaticInfo(result.getListenerStaticInfo())
                .listenerUpdates(result.getListenerUpdates())
                .build();
    }

    INDArray params = tuple.getParametersSum().addi(result.getParameters());

    INDArray updaterStateSum;
    if (tuple.getUpdaterStateSum() == null) {
        updaterStateSum = result.getUpdaterState();
    } else {
        updaterStateSum = tuple.getUpdaterStateSum();
        if (result.getUpdaterState() != null) {
            updaterStateSum.addi(result.getUpdaterState());
        }
    }

    double scoreSum = tuple.getScoreSum() + result.getScore();

    SparkTrainingStats stats = tuple.getSparkTrainingStats();
    if (result.getSparkTrainingStats() != null) {
        if (stats == null) {
            stats = result.getSparkTrainingStats();
        } else {
            stats.addOtherTrainingStats(result.getSparkTrainingStats());
        }
    }

    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    Collection<StorageMetaData> listenerMetaData =
            mergeListenerData(tuple.getListenerMetaData(), result.getListenerMetaData());
    // Fix: the incoming static info must come from result. The previous code read it from
    // tuple again, which re-added the accumulated entries to themselves and dropped the
    // result's static info.
    Collection<Persistable> listenerStaticInfo =
            mergeListenerData(tuple.getListenerStaticInfo(), result.getListenerStaticInfo());
    Collection<Persistable> listenerUpdates =
            mergeListenerData(tuple.getListenerUpdates(), result.getListenerUpdates());

    return new ParameterAveragingAggregationTuple(params, updaterStateSum, scoreSum, tuple.getAggregationsCount() + 1, stats, listenerMetaData, listenerStaticInfo, listenerUpdates);
}

/**
 * Appends {@code incoming} to {@code accumulated} in place and returns it; when
 * {@code accumulated} is {@code null}, {@code incoming} is returned as-is (possibly {@code null}).
 */
private static <T> Collection<T> mergeListenerData(Collection<T> accumulated, Collection<T> incoming) {
    if (accumulated == null) {
        return incoming;
    }
    if (incoming != null) {
        accumulated.addAll(incoming);
    }
    return accumulated;
}
Also used : StorageMetaData(org.deeplearning4j.api.storage.StorageMetaData) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) Persistable(org.deeplearning4j.api.storage.Persistable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats)

Example 30 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

The method testStatsCollection of the class TestListeners.

/**
 * Trains a small two-layer network on Spark with a {@link StatsListener} attached
 * and checks what ends up in the stats storage: exactly one session, exactly one
 * type ID, one worker ID per executor, and worker IDs that share a common prefix
 * and differ only in a final single-digit counter in {@code [0, numExecutors())}.
 */
@Test
public void testStatsCollection() {
    JavaSparkContext sc = getContext();
    int nExecutors = numExecutors();

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(123)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .list()
            .layer(0, new DenseLayer.Builder()
                    .nIn(4).nOut(100)
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.RELU)
                    .build())
            .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(100).nOut(3)
                    .activation(Activation.SOFTMAX)
                    .weightInit(WeightInit.XAVIER)
                    .build())
            .pretrain(false)
            .backprop(true)
            .build();

    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();

    TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(1)
            .batchSizePerWorker(5)
            .averagingFrequency(6)
            .build();
    SparkDl4jMultiLayer net = new SparkDl4jMultiLayer(sc, conf, tm);

    //In-memory (no file is passed to the storage constructor)
    StatsStorage ss = new MapDBStatsStorage();
    net.setListeners(ss, Collections.singletonList(new StatsListener(null)));

    //120 examples, 4 executors, 30 examples per executor -> 6 updates of size 5 per executor
    List<DataSet> examples = new IrisDataSetIterator(120, 150).next().asList();
    JavaRDD<DataSet> rdd = sc.parallelize(examples);
    net.fit(rdd);

    List<String> sessions = ss.listSessionIDs();
    System.out.println("Sessions: " + sessions);
    assertEquals(1, sessions.size());

    String sid = sessions.get(0);
    List<String> typeIDs = ss.listTypeIDsForSession(sid);
    List<String> workers = ss.listWorkerIDsForSession(sid);
    System.out.println(sid + "\t" + typeIDs + "\t" + workers);

    List<Persistable> lastUpdates = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID);
    System.out.println(lastUpdates);

    System.out.println("Static info:");
    for (String workerId : workers) {
        // The static info is fetched per worker, but only the session/worker pair is printed.
        Persistable staticInfo = ss.getStaticInfo(sid, StatsListener.TYPE_ID, workerId);
        System.out.println(sid + "\t" + workerId);
    }

    assertEquals(1, typeIDs.size());
    assertEquals(numExecutors(), workers.size());

    // All worker IDs must match the first one up to (but excluding) the last character.
    String firstWorker = workers.get(0);
    String expectedPrefix = firstWorker.substring(0, firstWorker.length() - 1);
    for (String workerId : workers) {
        int lastIdx = workerId.length() - 1;
        assertEquals(expectedPrefix, workerId.substring(0, lastIdx));

        // The trailing character is the per-worker counter.
        int counter = Integer.parseInt(workerId.substring(lastIdx));
        assertTrue(counter >= 0 && counter < numExecutors());
    }
}
Also used : Persistable(org.deeplearning4j.api.storage.Persistable) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) TrainingMaster(org.deeplearning4j.spark.api.TrainingMaster) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) SparkDl4jMultiLayer(org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer) MapDBStatsStorage(org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) MapDBStatsStorage(org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) StatsListener(org.deeplearning4j.ui.stats.StatsListener) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)

Aggregations

Persistable (org.deeplearning4j.api.storage.Persistable): 30; StatsStorage (org.deeplearning4j.api.storage.StatsStorage): 14; StatsReport (org.deeplearning4j.ui.stats.api.StatsReport): 7; StorageMetaData (org.deeplearning4j.api.storage.StorageMetaData): 6; StatsInitializationReport (org.deeplearning4j.ui.stats.api.StatsInitializationReport): 6; MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 5; NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 5; Test (org.junit.Test): 5; AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 4; MapDBStatsStorage (org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage): 4; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 4; IOException (java.io.IOException): 3; ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 3; MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 3; FlowStaticPersistable (org.deeplearning4j.ui.flow.data.FlowStaticPersistable): 3; FlowUpdatePersistable (org.deeplearning4j.ui.flow.data.FlowUpdatePersistable): 3; GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner): 3; BufferedImage (java.awt.image.BufferedImage): 2; File (java.io.File): 2; ArrayList (java.util.ArrayList): 2