Search in sources :

Example 36 with Updater

use of org.deeplearning4j.nn.api.Updater in project deeplearning4j by deeplearning4j.

the class ModelSerializer method restoreMultiLayerNetwork.

/**
 * Load a multi layer network from a file.
 *
 * <p>The file is read as a zip archive. The entries {@code configuration.json} and
 * {@code coefficients.bin} are required; the updater entries ({@code OLD_UPDATER_BIN},
 * {@code UPDATER_BIN}) are only read when {@code loadUpdater} is true. When both updater
 * entries are present, the updater state array takes precedence over the legacy
 * serialized updater object.
 *
 * @param file        the file to load from
 * @param loadUpdater whether the updater (legacy serialized object and/or state array)
 *                    should be read from the archive and applied to the restored network
 * @return the loaded multi layer network
 * @throws IOException           if the archive or any of its entries cannot be read
 * @throws IllegalStateException if the configuration or the coefficients entry is missing
 */
public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull File file, boolean loadUpdater) throws IOException {
    boolean gotConfig = false;
    boolean gotCoefficients = false;
    boolean gotOldUpdater = false;
    boolean gotUpdaterState = false;
    String json = "";
    INDArray params = null;
    Updater updater = null;
    INDArray updaterState = null;

    // try-with-resources guarantees the archive is released even when reading an entry fails;
    // previously the ZipFile was only closed on the success path.
    try (ZipFile zipFile = new ZipFile(file)) {
        ZipEntry config = zipFile.getEntry("configuration.json");
        if (config != null) {
            //restoring configuration
            // NOTE(review): decoded with the platform default charset, as before - confirm this
            // matches the charset used when the archive was written.
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(zipFile.getInputStream(config)))) {
                String line;
                StringBuilder js = new StringBuilder();
                while ((line = reader.readLine()) != null) {
                    js.append(line).append("\n");
                }
                json = js.toString();
            }
            gotConfig = true;
        }

        ZipEntry coefficients = zipFile.getEntry("coefficients.bin");
        if (coefficients != null) {
            try (DataInputStream dis = new DataInputStream(new BufferedInputStream(zipFile.getInputStream(coefficients)))) {
                params = Nd4j.read(dis);
            }
            gotCoefficients = true;
        }

        if (loadUpdater) {
            //This can be removed a few releases after 0.4.1...
            ZipEntry oldUpdaters = zipFile.getEntry(OLD_UPDATER_BIN);
            if (oldUpdaters != null) {
                try (ObjectInputStream ois = new ObjectInputStream(zipFile.getInputStream(oldUpdaters))) {
                    updater = (Updater) ois.readObject();
                } catch (ClassNotFoundException e) {
                    throw new RuntimeException(e);
                }
                gotOldUpdater = true;
            }

            ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN);
            if (updaterStateEntry != null) {
                try (DataInputStream dis = new DataInputStream(new BufferedInputStream(zipFile.getInputStream(updaterStateEntry)))) {
                    updaterState = Nd4j.read(dis);
                }
                gotUpdaterState = true;
            }
        }

        ZipEntry prep = zipFile.getEntry("preprocessor.bin");
        if (prep != null) {
            // The preprocessor is deserialized (so a broken or unknown entry still fails the load,
            // exactly as before) but it is not attached to the returned network by this method.
            try (ObjectInputStream ois = new ObjectInputStream(zipFile.getInputStream(prep))) {
                DataSetPreProcessor preProcessor = (DataSetPreProcessor) ois.readObject();
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
        }
    }

    if (gotConfig && gotCoefficients) {
        MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);
        MultiLayerNetwork network = new MultiLayerNetwork(confFromJson);
        network.init(params, false);
        // Prefer the updater state array; fall back to the legacy serialized updater object.
        if (gotUpdaterState && updaterState != null) {
            network.getUpdater().setStateViewArray(network, updaterState, false);
        } else if (gotOldUpdater && updater != null) {
            network.setUpdater(updater);
        }
        return network;
    } else {
        throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
    }
}
Also used : ZipEntry(java.util.zip.ZipEntry) DataSetPreProcessor(org.nd4j.linalg.dataset.api.DataSetPreProcessor) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) ZipFile(java.util.zip.ZipFile) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ComputationGraphUpdater(org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater) Updater(org.deeplearning4j.nn.api.Updater) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)

Example 37 with Updater

use of org.deeplearning4j.nn.api.Updater in project deeplearning4j by deeplearning4j.

the class ParameterAveragingTrainingWorker method getFinalResult.

/**
 * Builds the training result for the given network.
 *
 * <p>The result carries the network parameters and score, the updater state view array
 * (only when {@code saveUpdater} is set and the network has an updater), and - when the
 * configured router is a {@link VanillaStatsStorageRouter} - the storage metadata, static
 * info and updates held by that router. Any of these optional parts is {@code null} when
 * it is not available.
 *
 * @param network the network to take parameters, score and updater state from
 * @return the assembled training result
 */
@Override
public ParameterAveragingTrainingResult getFinalResult(MultiLayerNetwork network) {
    INDArray savedUpdaterState = null;
    if (saveUpdater) {
        Updater networkUpdater = network.getUpdater();
        savedUpdaterState = (networkUpdater == null) ? null : networkUpdater.getStateViewArray();
    }

    // Flush any queued operations before the parameters and score are read below.
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    Collection<StorageMetaData> metaData = null;
    Collection<Persistable> staticInfo = null;
    Collection<Persistable> updates = null;

    // instanceof is false for null, so a missing provider or router leaves all three as null.
    StatsStorageRouter router = (listenerRouterProvider == null) ? null : listenerRouterProvider.getRouter();
    if (router instanceof VanillaStatsStorageRouter) {
        //TODO this is ugly... need to find a better solution
        VanillaStatsStorageRouter vanillaRouter = (VanillaStatsStorageRouter) router;
        metaData = vanillaRouter.getStorageMetaData();
        staticInfo = vanillaRouter.getStaticInfo();
        updates = vanillaRouter.getUpdates();
    }

    return new ParameterAveragingTrainingResult(network.params(), savedUpdaterState, network.score(), metaData, staticInfo, updates);
}
Also used : StorageMetaData(org.deeplearning4j.api.storage.StorageMetaData) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) Persistable(org.deeplearning4j.api.storage.Persistable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VanillaStatsStorageRouter(org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter) ComputationGraphUpdater(org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater) Updater(org.deeplearning4j.nn.api.Updater) MultiLayerUpdater(org.deeplearning4j.nn.updater.MultiLayerUpdater) StatsStorageRouter(org.deeplearning4j.api.storage.StatsStorageRouter) VanillaStatsStorageRouter(org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter)

Aggregations

Updater (org.deeplearning4j.nn.api.Updater)37 Test (org.junit.Test)28 DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)27 INDArray (org.nd4j.linalg.api.ndarray.INDArray)27 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)25 Gradient (org.deeplearning4j.nn.gradient.Gradient)25 DefaultGradient (org.deeplearning4j.nn.gradient.DefaultGradient)23 Layer (org.deeplearning4j.nn.api.Layer)21 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)18 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)9 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)8 ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater)5 HashMap (java.util.HashMap)4 Solver (org.deeplearning4j.optimize.Solver)4 ArrayList (java.util.ArrayList)2 Field (java.lang.reflect.Field)1 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)1 ZipEntry (java.util.zip.ZipEntry)1 ZipFile (java.util.zip.ZipFile)1 Persistable (org.deeplearning4j.api.storage.Persistable)1