Search in sources :

Example 1 with ParameterAveragingElementAddFunction

Use of org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementAddFunction in the deeplearning4j project (by deeplearning4j).

Method processResults of the class ParameterAveragingTrainingMaster.

/**
 * Aggregates the per-worker training results of one split, averages them, and applies the
 * averaged parameters (and updater state, when present) to the model being trained.
 *
 * <p>Exactly one of {@code network} / {@code graph} is used: when {@code network} is non-null
 * the {@link MultiLayerNetwork} path is taken, otherwise the {@link ComputationGraph} path.
 *
 * <p>If the aggregation yields no tuple at all, or a tuple without parameters, the split is
 * skipped: the model is left untouched and its iteration count is not advanced.
 *
 * @param network     Spark wrapper around a MultiLayerNetwork; may be null, in which case {@code graph} is used
 * @param graph       Spark wrapper around a ComputationGraph; only dereferenced when {@code network} is null
 * @param results     per-worker training results to aggregate
 * @param splitNum    number of the split just processed (used for logging only)
 * @param totalSplits total number of splits (used for logging only)
 */
private void processResults(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<ParameterAveragingTrainingResult> results, int splitNum, int totalSplits) {
    if (collectTrainingStats) {
        stats.logAggregateStartTime();
    }
    // The aggregation is seeded with null, so the result can itself be null when nothing was
    // folded into it (e.g. an empty RDD). NOTE(review): this depends on the add/combine
    // functions passing null through - they are not visible here, hence the defensive guard.
    ParameterAveragingAggregationTuple tuple = results.aggregate(null, new ParameterAveragingElementAddFunction(), new ParameterAveragingElementCombineFunction());
    final boolean hasTuple = tuple != null;
    INDArray params = hasTuple ? tuple.getParametersSum() : null;
    SparkTrainingStats aggregatedStats = hasTuple ? tuple.getSparkTrainingStats() : null;
    if (collectTrainingStats) {
        stats.logAggregationEndTime();
        stats.logProcessParamsUpdaterStart();
    }
    if (params != null) {
        // The tuple holds sums; divide in place by the number of aggregated results to average.
        int aggCount = tuple.getAggregationsCount();
        params.divi(aggCount);
        INDArray updaterState = tuple.getUpdaterStateSum();
        if (updaterState != null) {
            //May be null if all SGD updaters, for example
            updaterState.divi(aggCount);
        }
        if (network != null) {
            MultiLayerNetwork net = network.getNetwork();
            net.setParameters(params);
            if (updaterState != null) {
                net.getUpdater().setStateViewArray(null, updaterState, false);
            }
            network.setScore(tuple.getScoreSum() / aggCount);
        } else {
            ComputationGraph g = graph.getNetwork();
            g.setParams(params);
            if (updaterState != null) {
                g.getUpdater().setStateViewArray(updaterState);
            }
            graph.setScore(tuple.getScoreSum() / aggCount);
        }
    } else {
        log.info("Skipping imbalanced split with no data for all executors");
    }
    if (collectTrainingStats) {
        stats.logProcessParamsUpdaterEnd();
        // Without a tuple there are no worker stats to merge.
        if (hasTuple) {
            stats.addWorkerStats(aggregatedStats);
        }
    }
    if (statsStorage != null && hasTuple) {
        // Forward whatever listener data was collected alongside the training results.
        Collection<StorageMetaData> meta = tuple.getListenerMetaData();
        if (meta != null && !meta.isEmpty()) {
            statsStorage.putStorageMetaData(meta);
        }
        Collection<Persistable> staticInfo = tuple.getListenerStaticInfo();
        if (staticInfo != null && !staticInfo.isEmpty()) {
            statsStorage.putStaticInfo(staticInfo);
        }
        Collection<Persistable> updates = tuple.getListenerUpdates();
        if (updates != null && !updates.isEmpty()) {
            statsStorage.putUpdate(updates);
        }
    }
    // A GridExecutioner gets its queue flushed (blocking) before the split is reported as complete.
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }
    log.info("Completed training of split {} of {}", splitNum, totalSplits);
    if (params != null) {
        //Params may be null for edge case (empty RDD)
        // Advance the stored iteration count by the configured iterations times averagingFrequency.
        if (network != null) {
            MultiLayerConfiguration conf = network.getNetwork().getLayerWiseConfigurations();
            int numUpdates = network.getNetwork().conf().getNumIterations() * averagingFrequency;
            conf.setIterationCount(conf.getIterationCount() + numUpdates);
        } else {
            ComputationGraphConfiguration conf = graph.getNetwork().getConfiguration();
            int numUpdates = graph.getNetwork().conf().getNumIterations() * averagingFrequency;
            conf.setIterationCount(conf.getIterationCount() + numUpdates);
        }
    }
}
Also used : ParameterAveragingElementCombineFunction(org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementCombineFunction) SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats) ParameterAveragingAggregationTuple(org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingAggregationTuple) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ParameterAveragingElementAddFunction(org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementAddFunction) SparkComputationGraph(org.deeplearning4j.spark.impl.graph.SparkComputationGraph) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)

Aggregations

ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)1 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)1 ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)1 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)1 SparkTrainingStats (org.deeplearning4j.spark.api.stats.SparkTrainingStats)1 SparkComputationGraph (org.deeplearning4j.spark.impl.graph.SparkComputationGraph)1 ParameterAveragingAggregationTuple (org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingAggregationTuple)1 ParameterAveragingElementAddFunction (org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementAddFunction)1 ParameterAveragingElementCombineFunction (org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementCombineFunction)1 INDArray (org.nd4j.linalg.api.ndarray.INDArray)1 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)1