
Example 1 with SparkTrainingStats

Use of org.deeplearning4j.spark.api.stats.SparkTrainingStats in the deeplearning4j project.

From the class ParameterAveragingTrainingWorker, method processMinibatchWithStats:

@Override
public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, MultiLayerNetwork network, boolean isLast) {
    ParameterAveragingTrainingResult result = processMinibatch(dataSet, network, isLast);
    if (result == null)
        return null;
    // Stats collection is optional: the 'stats' builder is null unless collection was enabled
    SparkTrainingStats statsToReturn = (stats != null ? stats.build() : null);
    return new Pair<>(result, statsToReturn);
}
Also used: SparkTrainingStats (org.deeplearning4j.spark.api.stats.SparkTrainingStats), Pair (org.deeplearning4j.berkeley.Pair)
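
The Pair is null until the worker signals termination (for example, on the last minibatch), and even then the stats half may be null when collection is disabled, so callers guard both levels. A minimal sketch of consuming the result, with worker, ds, and net assumed to be in scope:

Pair<ParameterAveragingTrainingResult, SparkTrainingStats> p =
        worker.processMinibatchWithStats(ds, net, /* isLast */ true);
if (p != null) {
    // A non-null pair means training on this worker has finished
    ParameterAveragingTrainingResult trainingResult = p.getFirst();
    SparkTrainingStats trainingStats = p.getSecond(); // null if stats collection was disabled
}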

Example 2 with SparkTrainingStats

Use of org.deeplearning4j.spark.api.stats.SparkTrainingStats in the deeplearning4j project.

From the class ParameterAveragingTrainingWorker, method getFinalResultWithStats:

@Override
public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> getFinalResultWithStats(ComputationGraph graph) {
    ParameterAveragingTrainingResult result = getFinalResult(graph);
    if (result == null)
        return null;
    SparkTrainingStats statsToReturn = (stats != null ? stats.build() : null);
    return new Pair<>(result, statsToReturn);
}
Also used: SparkTrainingStats (org.deeplearning4j.spark.api.stats.SparkTrainingStats), Pair (org.deeplearning4j.berkeley.Pair)
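
Both WithStats variants are reached only when stats collection has been enabled on the training master. A hedged end-to-end sketch of turning it on from the user-facing API (assuming an existing JavaSparkContext sc, a MultiLayerConfiguration conf, and a JavaRDD<DataSet> trainData; the builder values are illustrative, not recommendations):

ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(32)
        .batchSizePerWorker(32)
        .averagingFrequency(5)
        .build();
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf, tm);
sparkNet.setCollectTrainingStats(true); // workers then take the ...WithStats code paths above
sparkNet.fit(trainData);
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();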

Example 3 with SparkTrainingStats

Use of org.deeplearning4j.spark.api.stats.SparkTrainingStats in the deeplearning4j project.

From the class ParameterAveragingElementCombineFunction, method call:

@Override
public ParameterAveragingAggregationTuple call(ParameterAveragingAggregationTuple v1, ParameterAveragingAggregationTuple v2) throws Exception {
    if (v1 == null)
        return v2;
    else if (v2 == null)
        return v1;
    //Handle the edge case of fewer data elements than executors: one (or both) of v1 and v2 might not have any contents
    if (v1.getParametersSum() == null)
        return v2;
    else if (v2.getParametersSum() == null)
        return v1;
    INDArray newParams = v1.getParametersSum().addi(v2.getParametersSum());
    INDArray updaterStateSum;
    if (v1.getUpdaterStateSum() == null) {
        updaterStateSum = v2.getUpdaterStateSum();
    } else {
        updaterStateSum = v1.getUpdaterStateSum();
        if (v2.getUpdaterStateSum() != null)
            updaterStateSum.addi(v2.getUpdaterStateSum());
    }
    double scoreSum = v1.getScoreSum() + v2.getScoreSum();
    int aggregationCount = v1.getAggregationsCount() + v2.getAggregationsCount();
    SparkTrainingStats stats = v1.getSparkTrainingStats();
    if (v2.getSparkTrainingStats() != null) {
        if (stats == null)
            stats = v2.getSparkTrainingStats();
        else
            stats.addOtherTrainingStats(v2.getSparkTrainingStats());
    }
    // With a grid executioner (e.g. the CUDA backend), block until all queued ops on the summed arrays complete
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    Collection<StorageMetaData> listenerMetaData = v1.getListenerMetaData();
    if (listenerMetaData == null)
        listenerMetaData = v2.getListenerMetaData();
    else {
        Collection<StorageMetaData> newMeta = v2.getListenerMetaData();
        if (newMeta != null)
            listenerMetaData.addAll(newMeta);
    }
    Collection<Persistable> listenerStaticInfo = v1.getListenerStaticInfo();
    if (listenerStaticInfo == null)
        listenerStaticInfo = v2.getListenerStaticInfo();
    else {
        Collection<Persistable> newStatic = v2.getListenerStaticInfo();
        if (newStatic != null)
            listenerStaticInfo.addAll(newStatic);
    }
    Collection<Persistable> listenerUpdates = v1.getListenerUpdates();
    if (listenerUpdates == null)
        listenerUpdates = v2.getListenerUpdates();
    else {
        Collection<Persistable> listenerUpdates2 = v2.getListenerUpdates();
        if (listenerUpdates2 != null)
            listenerUpdates.addAll(listenerUpdates2);
    }
    return new ParameterAveragingAggregationTuple(newParams, updaterStateSum, scoreSum, aggregationCount, stats, listenerMetaData, listenerStaticInfo, listenerUpdates);
}
Also used: StorageMetaData (org.deeplearning4j.api.storage.StorageMetaData), GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner), Persistable (org.deeplearning4j.api.storage.Persistable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), SparkTrainingStats (org.deeplearning4j.spark.api.stats.SparkTrainingStats)
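
Because this merge is a pairwise, associative combine (parameter sums, scores, and counts add; stats merge via addOtherTrainingStats), it has exactly the shape Spark expects from a reduce-style combiner over per-partition aggregation tuples. A hedged fragment showing a direct invocation, with tupleA and tupleB as hypothetical partial aggregates already in scope (note that call(...) declares throws Exception):

// Merge two partial aggregates exactly as Spark would during reduction
ParameterAveragingAggregationTuple merged =
        new ParameterAveragingElementCombineFunction().call(tupleA, tupleB);
// merged.getAggregationsCount() == tupleA.getAggregationsCount() + tupleB.getAggregationsCount()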

Example 4 with SparkTrainingStats

Use of org.deeplearning4j.spark.api.stats.SparkTrainingStats in the deeplearning4j project.

From the class ExecuteWorkerFlatMapAdapter, method call:

@Override
public Iterable<R> call(Iterator<DataSet> dataSetIterator) throws Exception {
    WorkerConfiguration dataConfig = worker.getDataConfiguration();
    final boolean isGraph = dataConfig.isGraphNetwork();
    boolean stats = dataConfig.isCollectTrainingStats();
    StatsCalculationHelper s = (stats ? new StatsCalculationHelper() : null);
    if (stats)
        s.logMethodStartTime();
    if (!dataSetIterator.hasNext()) {
        if (stats) {
            s.logReturnTime();
            Pair<R, SparkTrainingStats> pair = worker.getFinalResultNoDataWithStats();
            pair.getFirst().setStats(s.build(pair.getSecond()));
            return Collections.singletonList(pair.getFirst());
        } else {
            return Collections.singletonList(worker.getFinalResultNoData());
        }
    }
    int batchSize = dataConfig.getBatchSizePerWorker();
    final int prefetchCount = dataConfig.getPrefetchNumBatches();
    DataSetIterator batchedIterator = new IteratorDataSetIterator(dataSetIterator, batchSize);
    if (prefetchCount > 0) {
        batchedIterator = new AsyncDataSetIterator(batchedIterator, prefetchCount);
    }
    try {
        MultiLayerNetwork net = null;
        ComputationGraph graph = null;
        if (stats)
            s.logInitialModelBefore();
        if (isGraph)
            graph = worker.getInitialModelGraph();
        else
            net = worker.getInitialModel();
        if (stats)
            s.logInitialModelAfter();
        int miniBatchCount = 0;
        int maxMinibatches = (dataConfig.getMaxBatchesPerWorker() > 0 ? dataConfig.getMaxBatchesPerWorker() : Integer.MAX_VALUE);
        while (batchedIterator.hasNext() && miniBatchCount++ < maxMinibatches) {
            if (stats)
                s.logNextDataSetBefore();
            DataSet next = batchedIterator.next();
            if (stats)
                s.logNextDataSetAfter(next.numExamples());
            if (stats) {
                s.logProcessMinibatchBefore();
                Pair<R, SparkTrainingStats> result;
                if (isGraph)
                    result = worker.processMinibatchWithStats(next, graph, !batchedIterator.hasNext());
                else
                    result = worker.processMinibatchWithStats(next, net, !batchedIterator.hasNext());
                s.logProcessMinibatchAfter();
                if (result != null) {
                    //Terminate training immediately
                    s.logReturnTime();
                    SparkTrainingStats workerStats = result.getSecond();
                    SparkTrainingStats returnStats = s.build(workerStats);
                    result.getFirst().setStats(returnStats);
                    return Collections.singletonList(result.getFirst());
                }
            } else {
                R result;
                if (isGraph)
                    result = worker.processMinibatch(next, graph, !batchedIterator.hasNext());
                else
                    result = worker.processMinibatch(next, net, !batchedIterator.hasNext());
                if (result != null) {
                    //Terminate training immediately
                    return Collections.singletonList(result);
                }
            }
        }
    //Fallback: the data was exhausted without the worker signalling termination; normally this shouldn't happen
        if (stats) {
            s.logReturnTime();
            Pair<R, SparkTrainingStats> pair;
            if (isGraph)
                pair = worker.getFinalResultWithStats(graph);
            else
                pair = worker.getFinalResultWithStats(net);
            pair.getFirst().setStats(s.build(pair.getSecond()));
            return Collections.singletonList(pair.getFirst());
        } else {
            if (isGraph)
                return Collections.singletonList(worker.getFinalResult(graph));
            else
                return Collections.singletonList(worker.getFinalResult(net));
        }
    } finally {
        //Make sure we shut down the async thread properly...
        if (batchedIterator instanceof AsyncDataSetIterator) {
            ((AsyncDataSetIterator) batchedIterator).shutdown();
        }
        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }
}
Also used: DataSet (org.nd4j.linalg.dataset.DataSet), SparkTrainingStats (org.deeplearning4j.spark.api.stats.SparkTrainingStats), GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner), WorkerConfiguration (org.deeplearning4j.spark.api.WorkerConfiguration), StatsCalculationHelper (org.deeplearning4j.spark.api.stats.StatsCalculationHelper), AsyncDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncDataSetIterator), ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph), IteratorDataSetIterator (org.deeplearning4j.datasets.iterator.IteratorDataSetIterator), MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork), DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)
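
After training completes through this path, the collected stats surface on the user-facing Spark network wrappers. A hedged sketch of retrieving and inspecting them (assuming sparkNet was fitted with setCollectTrainingStats(true) as in the sketch after Example 2, sc is the JavaSparkContext, and StatsUtils is org.deeplearning4j.spark.stats.StatsUtils):

SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
System.out.println(stats.statsAsString());                  // plain-text summary of timing stats
StatsUtils.exportStatsAsHtml(stats, "SparkStats.html", sc); // interactive HTML timeline/charts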

Example 5 with SparkTrainingStats

Use of org.deeplearning4j.spark.api.stats.SparkTrainingStats in the deeplearning4j project.

From the class ExecuteWorkerMultiDataSetFlatMapAdapter, method call:

@Override
public Iterable<R> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
    WorkerConfiguration dataConfig = worker.getDataConfiguration();
    boolean stats = dataConfig.isCollectTrainingStats();
    StatsCalculationHelper s = (stats ? new StatsCalculationHelper() : null);
    if (stats)
        s.logMethodStartTime();
    if (!dataSetIterator.hasNext()) {
        if (stats)
            s.logReturnTime();
        //No data for this partition; nothing to train on
        return Collections.emptyList();
    }
    int batchSize = dataConfig.getBatchSizePerWorker();
    final int prefetchCount = dataConfig.getPrefetchNumBatches();
    MultiDataSetIterator batchedIterator = new IteratorMultiDataSetIterator(dataSetIterator, batchSize);
    if (prefetchCount > 0) {
        batchedIterator = new AsyncMultiDataSetIterator(batchedIterator, prefetchCount);
    }
    try {
        if (stats)
            s.logInitialModelBefore();
        ComputationGraph net = worker.getInitialModelGraph();
        if (stats)
            s.logInitialModelAfter();
        int miniBatchCount = 0;
        int maxMinibatches = (dataConfig.getMaxBatchesPerWorker() > 0 ? dataConfig.getMaxBatchesPerWorker() : Integer.MAX_VALUE);
        while (batchedIterator.hasNext() && miniBatchCount++ < maxMinibatches) {
            if (stats)
                s.logNextDataSetBefore();
            MultiDataSet next = batchedIterator.next();
            if (stats)
                s.logNextDataSetAfter(next.getFeatures(0).size(0));
            if (stats) {
                s.logProcessMinibatchBefore();
                Pair<R, SparkTrainingStats> result = worker.processMinibatchWithStats(next, net, !batchedIterator.hasNext());
                s.logProcessMinibatchAfter();
                if (result != null) {
                    //Terminate training immediately
                    s.logReturnTime();
                    SparkTrainingStats workerStats = result.getSecond();
                    SparkTrainingStats returnStats = s.build(workerStats);
                    result.getFirst().setStats(returnStats);
                    return Collections.singletonList(result.getFirst());
                }
            } else {
                R result = worker.processMinibatch(next, net, !batchedIterator.hasNext());
                if (result != null) {
                    //Terminate training immediately
                    return Collections.singletonList(result);
                }
            }
        }
        //Fallback: the data was exhausted without the worker signalling termination; normally this shouldn't happen
        if (stats) {
            s.logReturnTime();
            Pair<R, SparkTrainingStats> pair = worker.getFinalResultWithStats(net);
            pair.getFirst().setStats(s.build(pair.getSecond()));
            return Collections.singletonList(pair.getFirst());
        } else {
            return Collections.singletonList(worker.getFinalResult(net));
        }
    } finally {
        //Make sure we shut down the async thread properly...
        if (batchedIterator instanceof AsyncMultiDataSetIterator) {
            ((AsyncMultiDataSetIterator) batchedIterator).shutdown();
        }
        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }
}
Also used: IteratorMultiDataSetIterator (org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator), AsyncMultiDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator), SparkTrainingStats (org.deeplearning4j.spark.api.stats.SparkTrainingStats), MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator), GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner), WorkerConfiguration (org.deeplearning4j.spark.api.WorkerConfiguration), StatsCalculationHelper (org.deeplearning4j.spark.api.stats.StatsCalculationHelper), MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet), ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)
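
This MultiDataSet variant backs multi-input/multi-output ComputationGraph training. A hedged sketch of the corresponding user-facing call (assuming sc, a ComputationGraphConfiguration graphConf, a TrainingMaster tm, and a JavaRDD<MultiDataSet> trainData already exist):

SparkComputationGraph sparkGraph = new SparkComputationGraph(sc, graphConf, tm);
sparkGraph.setCollectTrainingStats(true);  // workers then take the WithStats branch above
sparkGraph.fitMultiDataSet(trainData);     // dispatched through this MultiDataSet adapter
SparkTrainingStats stats = sparkGraph.getSparkTrainingStats();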

Aggregations

SparkTrainingStats (org.deeplearning4j.spark.api.stats.SparkTrainingStats): 17
Test (org.junit.Test): 8
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 8
IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 7
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 7
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 7
BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest): 7
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 7
DataSet (org.nd4j.linalg.dataset.DataSet): 7
File (java.io.File): 6
ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 6
SparkDl4jMultiLayer (org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer): 5
GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner): 5
MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet): 5
LabeledPoint (org.apache.spark.mllib.regression.LabeledPoint): 4
Pair (org.deeplearning4j.berkeley.Pair): 4
MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator): 4
SparkComputationGraph (org.deeplearning4j.spark.impl.graph.SparkComputationGraph): 4
ParameterAveragingTrainingMaster (org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster): 4
Path (java.nio.file.Path): 3