Search in sources:

Example 1 with VanillaStatsStorageRouter

Use of org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter in the deeplearning4j project (by deeplearning4j).

From the class ParameterAveragingTrainingWorker, method getFinalResult (ComputationGraph overload).

/**
 * Packages this worker's final state for the ComputationGraph case.
 *
 * <p>The result holds the network parameters, the network score, the updater state view
 * (only when {@code saveUpdater} is set and the graph has an updater), and any listener
 * metadata, static info and updates collected by a {@link VanillaStatsStorageRouter}.
 *
 * @param network the trained computation graph
 * @return the training result to send back; listener collections are null when no
 *     VanillaStatsStorageRouter is available, and the updater state is null when it was
 *     not saved
 */
@Override
public ParameterAveragingTrainingResult getFinalResult(ComputationGraph network) {
    // The updater is only queried when the worker is configured to save its state.
    ComputationGraphUpdater graphUpdater = saveUpdater ? network.getUpdater() : null;
    INDArray stateView = null;
    if (graphUpdater != null) {
        stateView = graphUpdater.getStateViewArray();
    }

    // With a GridExecutioner, call flushQueueBlocking() before params/score are read below.
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    StatsStorageRouter router = (listenerRouterProvider == null) ? null : listenerRouterProvider.getRouter();

    Collection<StorageMetaData> metaData = null;
    Collection<Persistable> staticInfo = null;
    Collection<Persistable> updates = null;
    // instanceof is false for a null router, so this also covers the "no provider" case.
    // TODO: special-casing the vanilla router is ugly; find a cleaner mechanism.
    if (router instanceof VanillaStatsStorageRouter) {
        VanillaStatsStorageRouter vanillaRouter = (VanillaStatsStorageRouter) router;
        metaData = vanillaRouter.getStorageMetaData();
        staticInfo = vanillaRouter.getStaticInfo();
        updates = vanillaRouter.getUpdates();
    }

    return new ParameterAveragingTrainingResult(network.params(), stateView, network.score(), metaData, staticInfo, updates);
}
Also used: StorageMetaData (org.deeplearning4j.api.storage.StorageMetaData), GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner), Persistable (org.deeplearning4j.api.storage.Persistable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), VanillaStatsStorageRouter (org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter), ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater), StatsStorageRouter (org.deeplearning4j.api.storage.StatsStorageRouter)

Example 2 with VanillaStatsStorageRouter

Use of org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter in the deeplearning4j project (by deeplearning4j).

From the class ParameterAveragingTrainingWorker, method getFinalResult (MultiLayerNetwork overload).

/**
 * Packages this worker's final state for the MultiLayerNetwork case.
 *
 * <p>The result holds the network parameters, the network score, the updater state view
 * (only when {@code saveUpdater} is set and the network has an updater), and any listener
 * metadata, static info and updates collected by a {@link VanillaStatsStorageRouter}.
 *
 * @param network the trained multi-layer network
 * @return the training result to send back; listener collections are null when no
 *     VanillaStatsStorageRouter is available, and the updater state is null when it was
 *     not saved
 */
@Override
public ParameterAveragingTrainingResult getFinalResult(MultiLayerNetwork network) {
    // The updater is only queried when the worker is configured to save its state.
    Updater netUpdater = saveUpdater ? network.getUpdater() : null;
    INDArray stateView = null;
    if (netUpdater != null) {
        stateView = netUpdater.getStateViewArray();
    }

    // With a GridExecutioner, call flushQueueBlocking() before params/score are read below.
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    StatsStorageRouter router = (listenerRouterProvider == null) ? null : listenerRouterProvider.getRouter();

    Collection<StorageMetaData> metaData = null;
    Collection<Persistable> staticInfo = null;
    Collection<Persistable> updates = null;
    // instanceof is false for a null router, so this also covers the "no provider" case.
    // TODO: special-casing the vanilla router is ugly; find a cleaner mechanism.
    if (router instanceof VanillaStatsStorageRouter) {
        VanillaStatsStorageRouter vanillaRouter = (VanillaStatsStorageRouter) router;
        metaData = vanillaRouter.getStorageMetaData();
        staticInfo = vanillaRouter.getStaticInfo();
        updates = vanillaRouter.getUpdates();
    }

    return new ParameterAveragingTrainingResult(network.params(), stateView, network.score(), metaData, staticInfo, updates);
}
Also used: StorageMetaData (org.deeplearning4j.api.storage.StorageMetaData), GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner), Persistable (org.deeplearning4j.api.storage.Persistable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), VanillaStatsStorageRouter (org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter), ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater), Updater (org.deeplearning4j.nn.api.Updater), MultiLayerUpdater (org.deeplearning4j.nn.updater.MultiLayerUpdater), StatsStorageRouter (org.deeplearning4j.api.storage.StatsStorageRouter)

Aggregations

Persistable (org.deeplearning4j.api.storage.Persistable): 2, StatsStorageRouter (org.deeplearning4j.api.storage.StatsStorageRouter): 2, StorageMetaData (org.deeplearning4j.api.storage.StorageMetaData): 2, ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater): 2, VanillaStatsStorageRouter (org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter): 2, INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2, GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner): 2, Updater (org.deeplearning4j.nn.api.Updater): 1, MultiLayerUpdater (org.deeplearning4j.nn.updater.MultiLayerUpdater): 1