
Example 6 with ComputationGraphUpdater

Use of org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater in project deeplearning4j by deeplearning4j.

From class ModelSerializer, method restoreComputationGraph:

/**
 * Load a computation graph from a file
 *
 * @param file        the file to get the computation graph from
 * @param loadUpdater whether to also restore the updater state, if it is present in the file
 * @return the loaded computation graph
 * @throws IOException if the file cannot be read
 */
public static ComputationGraph restoreComputationGraph(@NonNull File file, boolean loadUpdater) throws IOException {
    ZipFile zipFile = new ZipFile(file);
    boolean gotConfig = false;
    boolean gotCoefficients = false;
    boolean gotOldUpdater = false;
    boolean gotUpdaterState = false;
    boolean gotPreProcessor = false;
    String json = "";
    INDArray params = null;
    ComputationGraphUpdater updater = null;
    INDArray updaterState = null;
    DataSetPreProcessor preProcessor = null;
    ZipEntry config = zipFile.getEntry("configuration.json");
    if (config != null) {
        //restoring configuration
        InputStream stream = zipFile.getInputStream(config);
        BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
        String line = "";
        StringBuilder js = new StringBuilder();
        while ((line = reader.readLine()) != null) {
            js.append(line).append("\n");
        }
        json = js.toString();
        reader.close();
        stream.close();
        gotConfig = true;
    }
    ZipEntry coefficients = zipFile.getEntry("coefficients.bin");
    if (coefficients != null) {
        InputStream stream = zipFile.getInputStream(coefficients);
        DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
        params = Nd4j.read(dis);
        dis.close();
        gotCoefficients = true;
    }
    if (loadUpdater) {
        ZipEntry oldUpdaters = zipFile.getEntry(OLD_UPDATER_BIN);
        if (oldUpdaters != null) {
            InputStream stream = zipFile.getInputStream(oldUpdaters);
            ObjectInputStream ois = new ObjectInputStream(stream);
            try {
                updater = (ComputationGraphUpdater) ois.readObject();
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
            gotOldUpdater = true;
        }
        ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN);
        if (updaterStateEntry != null) {
            InputStream stream = zipFile.getInputStream(updaterStateEntry);
            DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
            updaterState = Nd4j.read(dis);
            dis.close();
            gotUpdaterState = true;
        }
    }
    ZipEntry prep = zipFile.getEntry("preprocessor.bin");
    if (prep != null) {
        InputStream stream = zipFile.getInputStream(prep);
        ObjectInputStream ois = new ObjectInputStream(stream);
        try {
            preProcessor = (DataSetPreProcessor) ois.readObject();
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
        gotPreProcessor = true;
    }
    zipFile.close();
    if (gotConfig && gotCoefficients) {
        ComputationGraphConfiguration confFromJson = ComputationGraphConfiguration.fromJson(json);
        ComputationGraph cg = new ComputationGraph(confFromJson);
        cg.init(params, false);
        if (gotUpdaterState && updaterState != null) {
            cg.getUpdater().setStateViewArray(updaterState);
        } else if (gotOldUpdater && updater != null) {
            cg.setUpdater(updater);
        }
        return cg;
    } else
        throw new IllegalStateException("Model wasn't found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
}
Also used: ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater), ZipEntry (java.util.zip.ZipEntry), DataSetPreProcessor (org.nd4j.linalg.dataset.api.DataSetPreProcessor), ZipFile (java.util.zip.ZipFile), INDArray (org.nd4j.linalg.api.ndarray.INDArray), ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration), ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)
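
For orientation, here is a minimal usage sketch for this method. The zip path below is hypothetical; it should point at a graph previously saved with ModelSerializer.writeModel(net, file, true).

import java.io.File;
import java.io.IOException;

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.util.ModelSerializer;

public class RestoreGraphSketch {
    public static void main(String[] args) throws IOException {
        // Hypothetical path: a zip previously written with ModelSerializer.writeModel(net, file, true)
        File saved = new File("trained-graph.zip");
        // loadUpdater = true also restores updaterState.bin (or the legacy updater.bin), so training can resume
        ComputationGraph net = ModelSerializer.restoreComputationGraph(saved, true);
        System.out.println("Restored graph with " + net.numParams() + " parameters");
    }
}

Passing false for loadUpdater skips both updater entries and restores only the configuration and parameters.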

Example 7 with ComputationGraphUpdater

Use of org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater in project deeplearning4j by deeplearning4j.

From class BaseOptimizer, method updateGradientAccordingToParams:

@Override
public void updateGradientAccordingToParams(Gradient gradient, Model model, int batchSize) {
    if (model instanceof ComputationGraph) {
        ComputationGraph graph = (ComputationGraph) model;
        if (computationGraphUpdater == null) {
            computationGraphUpdater = new ComputationGraphUpdater(graph);
        }
        computationGraphUpdater.update(graph, gradient, getIterationCount(model), batchSize);
    } else {
        if (updater == null)
            updater = UpdaterCreator.getUpdater(model);
        Layer layer = (Layer) model;
        updater.update(layer, gradient, getIterationCount(model), batchSize);
    }
}
Also used: ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater), ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph), Layer (org.deeplearning4j.nn.api.Layer)
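
To illustrate the ComputationGraph branch above: the updater is built lazily from the graph itself, and the same constructor can be called directly. The tiny configuration below is invented purely for this sketch and assumes a DL4J release from the era of this listing on the classpath.

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class GraphUpdaterSketch {
    public static void main(String[] args) {
        // Arbitrary one-layer graph, used only so a ComputationGraphUpdater can be built
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .graphBuilder()
                .addInputs("in")
                .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(4).nOut(3).build(), "in")
                .setOutputs("out")
                .build();
        ComputationGraph graph = new ComputationGraph(conf);
        graph.init();

        // Mirrors the lazy initialization in updateGradientAccordingToParams above
        ComputationGraphUpdater updater = new ComputationGraphUpdater(graph);
        System.out.println("Updater created: " + (updater.getStateViewArray() != null
                ? updater.getStateViewArray().length() + " state elements"
                : "no state (e.g. plain SGD)"));
    }
}

In normal training this construction happens inside the optimizer during fit(), so user code rarely touches ComputationGraphUpdater directly.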

Example 8 with ComputationGraphUpdater

Use of org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater in project deeplearning4j by deeplearning4j.

From class ParameterAveragingTrainingWorker, method getInitialModelGraph:

@Override
public ComputationGraph getInitialModelGraph() {
    if (configuration.isCollectTrainingStats())
        stats = new ParameterAveragingTrainingWorkerStats.ParameterAveragingTrainingWorkerStatsHelper();
    if (configuration.isCollectTrainingStats())
        stats.logBroadcastGetValueStart();
    NetBroadcastTuple tuple = broadcast.getValue();
    if (configuration.isCollectTrainingStats())
        stats.logBroadcastGetValueEnd();
    //Don't want to have shared configuration object: each may update its iteration count (for LR schedule etc) individually
    ComputationGraph net = new ComputationGraph(tuple.getGraphConfiguration().clone());
    //Can't have shared parameter array across executors for parameter averaging, hence the 'true' for clone parameters array arg
    net.init(tuple.getParameters().unsafeDuplication(), false);
    if (tuple.getUpdaterState() != null) {
        //Again: can't have shared updater state
        net.setUpdater(new ComputationGraphUpdater(net, tuple.getUpdaterState().unsafeDuplication()));
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    configureListeners(net, tuple.getCounter().getAndIncrement());
    if (configuration.isCollectTrainingStats())
        stats.logInitEnd();
    return net;
}
Also used: GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner), ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater), NetBroadcastTuple (org.deeplearning4j.spark.api.worker.NetBroadcastTuple), ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)

Example 9 with ComputationGraphUpdater

Use of org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater in project deeplearning4j by deeplearning4j.

From class ParameterAveragingTrainingWorker, method getFinalResult:

@Override
public ParameterAveragingTrainingResult getFinalResult(ComputationGraph network) {
    INDArray updaterState = null;
    if (saveUpdater) {
        ComputationGraphUpdater u = network.getUpdater();
        if (u != null)
            updaterState = u.getStateViewArray();
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    Collection<StorageMetaData> storageMetaData = null;
    Collection<Persistable> listenerStaticInfo = null;
    Collection<Persistable> listenerUpdates = null;
    if (listenerRouterProvider != null) {
        StatsStorageRouter r = listenerRouterProvider.getRouter();
        if (r instanceof VanillaStatsStorageRouter) {
            //TODO this is ugly... need to find a better solution
            VanillaStatsStorageRouter ssr = (VanillaStatsStorageRouter) r;
            storageMetaData = ssr.getStorageMetaData();
            listenerStaticInfo = ssr.getStaticInfo();
            listenerUpdates = ssr.getUpdates();
        }
    }
    return new ParameterAveragingTrainingResult(network.params(), updaterState, network.score(), storageMetaData, listenerStaticInfo, listenerUpdates);
}
Also used: StorageMetaData (org.deeplearning4j.api.storage.StorageMetaData), GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner), Persistable (org.deeplearning4j.api.storage.Persistable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), VanillaStatsStorageRouter (org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter), ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater), StatsStorageRouter (org.deeplearning4j.api.storage.StatsStorageRouter)
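
Read together, examples 8 and 9 round-trip the updater state: getFinalResult() exports it via getStateViewArray(), and getInitialModelGraph() imports a private copy through the ComputationGraphUpdater(net, state) constructor. Below is a minimal non-Spark sketch of that idiom, reusing the same kind of throwaway configuration as the earlier sketch; the Adam choice is arbitrary, just so there is typically some state to copy, and the enum-style updater() call matches the DL4J version this listing comes from.

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class UpdaterStateRoundTripSketch {
    public static void main(String[] args) {
        // Illustrative configuration only; any graph whose updater keeps state (e.g. Adam) would do
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .updater(Updater.ADAM)
                .graphBuilder()
                .addInputs("in")
                .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(4).nOut(3).build(), "in")
                .setOutputs("out")
                .build();

        ComputationGraph source = new ComputationGraph(conf);
        source.init();

        // "Export" side (as in getFinalResult): grab the updater state, if any
        INDArray state = source.getUpdater() == null ? null : source.getUpdater().getStateViewArray();

        // "Import" side (as in getInitialModelGraph): fresh graph, duplicated params and state
        ComputationGraph copy = new ComputationGraph(conf.clone());
        copy.init(source.params().unsafeDuplication(), false);
        if (state != null) {
            copy.setUpdater(new ComputationGraphUpdater(copy, state.unsafeDuplication()));
        }
        System.out.println("Copied " + copy.numParams() + " parameters"
                + (state == null ? " (no updater state present yet)" : " plus updater state"));
    }
}

unsafeDuplication() is used exactly as the worker does, so neither the parameter array nor the updater state is shared between the two graphs.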

Aggregations

ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater): 9 uses
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 6 uses
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 4 uses
ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 2 uses
GraphVertex (org.deeplearning4j.nn.conf.graph.GraphVertex): 2 uses
LayerVertex (org.deeplearning4j.nn.conf.graph.LayerVertex): 2 uses
GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner): 2 uses
Field (java.lang.reflect.Field): 1 use
ArrayList (java.util.ArrayList): 1 use
HashMap (java.util.HashMap): 1 use
LinkedHashMap (java.util.LinkedHashMap): 1 use
Map (java.util.Map): 1 use
ZipEntry (java.util.zip.ZipEntry): 1 use
ZipFile (java.util.zip.ZipFile): 1 use
Persistable (org.deeplearning4j.api.storage.Persistable): 1 use
StatsStorageRouter (org.deeplearning4j.api.storage.StatsStorageRouter): 1 use
StorageMetaData (org.deeplearning4j.api.storage.StorageMetaData): 1 use
Layer (org.deeplearning4j.nn.api.Layer): 1 use
Updater (org.deeplearning4j.nn.conf.Updater): 1 use
Gradient (org.deeplearning4j.nn.gradient.Gradient): 1 use