Search in sources :

Example 16 with SparkTrainingStats

use of org.deeplearning4j.spark.api.stats.SparkTrainingStats in project deeplearning4j by deeplearning4j.

In the class ParameterAveragingTrainingWorker, method processMinibatchWithStats.

/**
 * Processes one minibatch via {@code processMinibatch} and pairs the outcome with training stats.
 *
 * @param dataSet minibatch forwarded unchanged to {@code processMinibatch}
 * @param graph   network forwarded unchanged to {@code processMinibatch}
 * @param isLast  flag forwarded unchanged to {@code processMinibatch}
 * @return {@code null} when {@code processMinibatch} returns {@code null}; otherwise a pair of the
 *         training result and the built stats (the second element is {@code null} when the
 *         {@code stats} field is {@code null})
 */
@Override
public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> processMinibatchWithStats(MultiDataSet dataSet, ComputationGraph graph, boolean isLast) {
    final ParameterAveragingTrainingResult trainingResult = processMinibatch(dataSet, graph, isLast);
    if (trainingResult == null) {
        // Nothing was produced for this minibatch, so there is nothing to pair with stats.
        return null;
    }

    // Stats are only built after a non-null result, and only when a stats holder exists.
    SparkTrainingStats builtStats = null;
    if (stats != null) {
        builtStats = stats.build();
    }
    return new Pair<>(trainingResult, builtStats);
}
Also used : SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats) Pair(org.deeplearning4j.berkeley.Pair)

Example 17 with SparkTrainingStats

use of org.deeplearning4j.spark.api.stats.SparkTrainingStats in project deeplearning4j by deeplearning4j.

In the class ParameterAveragingElementAddFunction, method call.

/**
 * Folds one training result into the running aggregation tuple.
 *
 * <p>When {@code tuple} is {@code null} a new tuple is built directly from {@code result} with an
 * aggregation count of 1. Otherwise parameters (and updater state, when present) are summed in
 * place with {@code addi}, scores are added, training stats are combined, and the three listener
 * collections are merged.
 *
 * <p>Fix: the static-info merge previously read the "incoming" collection from {@code tuple}
 * rather than {@code result}, which discarded the result's static info and appended the
 * accumulated collection to itself. All three listener collections now go through the same
 * helper and take their incoming values from {@code result}.
 *
 * @param tuple  aggregation so far, or {@code null} if this is the first element
 * @param result training result to add to the aggregation
 * @return a tuple containing the combined values, with the aggregation count incremented by one
 *         (or set to 1 when {@code tuple} was {@code null})
 * @throws Exception declared by the overridden signature
 */
@Override
public ParameterAveragingAggregationTuple call(ParameterAveragingAggregationTuple tuple, ParameterAveragingTrainingResult result) throws Exception {
    if (tuple == null) {
        // First element: seed the aggregation straight from this result.
        return ParameterAveragingAggregationTuple.builder().parametersSum(result.getParameters()).updaterStateSum(result.getUpdaterState()).scoreSum(result.getScore()).aggregationsCount(1).sparkTrainingStats(result.getSparkTrainingStats()).listenerMetaData(result.getListenerMetaData()).listenerStaticInfo(result.getListenerStaticInfo()).listenerUpdates(result.getListenerUpdates()).build();
    }

    // In-place add: the tuple's parameter array is reused as the running sum.
    INDArray params = tuple.getParametersSum().addi(result.getParameters());

    INDArray updaterStateSum;
    if (tuple.getUpdaterStateSum() == null) {
        updaterStateSum = result.getUpdaterState();
    } else {
        updaterStateSum = tuple.getUpdaterStateSum();
        if (result.getUpdaterState() != null) {
            updaterStateSum.addi(result.getUpdaterState());
        }
    }

    double scoreSum = tuple.getScoreSum() + result.getScore();

    SparkTrainingStats stats = tuple.getSparkTrainingStats();
    if (result.getSparkTrainingStats() != null) {
        if (stats == null) {
            stats = result.getSparkTrainingStats();
        } else {
            stats.addOtherTrainingStats(result.getSparkTrainingStats());
        }
    }

    // Flush any queued ops on grid executioners before the tuple is handed back.
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    Collection<StorageMetaData> listenerMetaData = mergeListenerData(tuple.getListenerMetaData(), result.getListenerMetaData());
    // Incoming static info must come from result (it was mistakenly read from tuple before).
    Collection<Persistable> listenerStaticInfo = mergeListenerData(tuple.getListenerStaticInfo(), result.getListenerStaticInfo());
    Collection<Persistable> listenerUpdates = mergeListenerData(tuple.getListenerUpdates(), result.getListenerUpdates());

    return new ParameterAveragingAggregationTuple(params, updaterStateSum, scoreSum, tuple.getAggregationsCount() + 1, stats, listenerMetaData, listenerStaticInfo, listenerUpdates);
}

/**
 * Merges an incoming listener collection into the accumulated one, tolerating {@code null} on either side.
 *
 * @param accumulated collection gathered so far; mutated in place via {@code addAll} when non-null
 * @param incoming    collection from the new result; may be {@code null}
 * @return {@code incoming} when {@code accumulated} is {@code null}, otherwise {@code accumulated}
 *         (with {@code incoming}'s elements appended when {@code incoming} is non-null)
 */
private static <T> Collection<T> mergeListenerData(Collection<T> accumulated, Collection<T> incoming) {
    if (accumulated == null) {
        return incoming;
    }
    if (incoming != null) {
        accumulated.addAll(incoming);
    }
    return accumulated;
}
Also used : StorageMetaData(org.deeplearning4j.api.storage.StorageMetaData) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) Persistable(org.deeplearning4j.api.storage.Persistable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats)

Aggregations

SparkTrainingStats (org.deeplearning4j.spark.api.stats.SparkTrainingStats)17 Test (org.junit.Test)8 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)8 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)7 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)7 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)7 BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest)7 INDArray (org.nd4j.linalg.api.ndarray.INDArray)7 DataSet (org.nd4j.linalg.dataset.DataSet)7 File (java.io.File)6 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)6 SparkDl4jMultiLayer (org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer)5 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)5 MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet)5 LabeledPoint (org.apache.spark.mllib.regression.LabeledPoint)4 Pair (org.deeplearning4j.berkeley.Pair)4 MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator)4 SparkComputationGraph (org.deeplearning4j.spark.impl.graph.SparkComputationGraph)4 ParameterAveragingTrainingMaster (org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster)4 Path (java.nio.file.Path)3