Search in sources:

Example 1 with DataSetPreProcessor

use of org.nd4j.linalg.dataset.api.DataSetPreProcessor in project deeplearning4j by deeplearning4j.

Source: class ModelSerializer, method restoreComputationGraph.

/**
 * Load a computation graph from a file.
 *
 * <p>The file is opened as a zip archive. The entries {@code configuration.json} and
 * {@code coefficients.bin} are required. When {@code loadUpdater} is true, the updater state
 * entry ({@code UPDATER_BIN}) and the legacy serialized updater ({@code OLD_UPDATER_BIN}) are
 * also read if present; the state array takes precedence over the legacy updater. An entry
 * named {@code preprocessor.bin} is deserialized if present, but the result is not attached
 * to the returned graph.
 *
 * <p>The zip file and every entry stream are closed even if reading fails part-way.
 *
 * @param file the file to get the computation graph from
 * @param loadUpdater whether to restore updater state stored in the file, if any
 * @return the loaded computation graph
 * @throws IOException if the file cannot be opened or an entry cannot be read
 * @throws IllegalStateException if the configuration or coefficients entry is missing
 */
public static ComputationGraph restoreComputationGraph(@NonNull File file, boolean loadUpdater) throws IOException {
    boolean gotConfig = false;
    boolean gotCoefficients = false;
    boolean gotOldUpdater = false;
    boolean gotUpdaterState = false;
    String json = "";
    INDArray params = null;
    ComputationGraphUpdater updater = null;
    INDArray updaterState = null;
    // try-with-resources guarantees the archive is closed even when an entry fails to read or
    // deserialize (previously zipFile.close() was only reached on the success path).
    try (ZipFile zipFile = new ZipFile(file)) {
        ZipEntry config = zipFile.getEntry("configuration.json");
        if (config != null) {
            //restoring configuration
            // Platform default charset is kept to match prior behaviour.
            // NOTE(review): confirm the writer uses the same charset.
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(zipFile.getInputStream(config)))) {
                StringBuilder js = new StringBuilder();
                String line;
                while ((line = reader.readLine()) != null) {
                    js.append(line).append("\n");
                }
                json = js.toString();
            }
            gotConfig = true;
        }
        ZipEntry coefficients = zipFile.getEntry("coefficients.bin");
        if (coefficients != null) {
            try (DataInputStream dis = new DataInputStream(new BufferedInputStream(zipFile.getInputStream(coefficients)))) {
                params = Nd4j.read(dis);
            }
            gotCoefficients = true;
        }
        if (loadUpdater) {
            ZipEntry oldUpdaters = zipFile.getEntry(OLD_UPDATER_BIN);
            if (oldUpdaters != null) {
                // SECURITY NOTE: Java native deserialization - only load model files from trusted sources.
                try (ObjectInputStream ois = new ObjectInputStream(zipFile.getInputStream(oldUpdaters))) {
                    updater = (ComputationGraphUpdater) ois.readObject();
                } catch (ClassNotFoundException e) {
                    throw new RuntimeException(e);
                }
                gotOldUpdater = true;
            }
            ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN);
            if (updaterStateEntry != null) {
                try (DataInputStream dis = new DataInputStream(new BufferedInputStream(zipFile.getInputStream(updaterStateEntry)))) {
                    updaterState = Nd4j.read(dis);
                }
                gotUpdaterState = true;
            }
        }
        ZipEntry prep = zipFile.getEntry("preprocessor.bin");
        if (prep != null) {
            // SECURITY NOTE: Java native deserialization - only load model files from trusted sources.
            try (ObjectInputStream ois = new ObjectInputStream(zipFile.getInputStream(prep))) {
                // The preprocessor is validated (cast) but not attached to the returned graph.
                DataSetPreProcessor ignoredPreProcessor = (DataSetPreProcessor) ois.readObject();
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
        }
    }
    if (gotConfig && gotCoefficients) {
        ComputationGraphConfiguration confFromJson = ComputationGraphConfiguration.fromJson(json);
        ComputationGraph cg = new ComputationGraph(confFromJson);
        cg.init(params, false);
        // Prefer the flattened updater state; fall back to the legacy serialized updater.
        if (gotUpdaterState && updaterState != null) {
            cg.getUpdater().setStateViewArray(updaterState);
        } else if (gotOldUpdater && updater != null) {
            cg.setUpdater(updater);
        }
        return cg;
    } else {
        throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
    }
}
Also used: ComputationGraphUpdater(org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater) ZipEntry(java.util.zip.ZipEntry) DataSetPreProcessor(org.nd4j.linalg.dataset.api.DataSetPreProcessor) ZipFile(java.util.zip.ZipFile) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph)

Example 2 with DataSetPreProcessor

use of org.nd4j.linalg.dataset.api.DataSetPreProcessor in project deeplearning4j by deeplearning4j.

Source: class ModelSerializer, method restoreMultiLayerNetwork.

/**
 * Load a multi layer network from a file.
 *
 * <p>The file is opened as a zip archive. The entries {@code configuration.json} and
 * {@code coefficients.bin} are required. When {@code loadUpdater} is true, the updater state
 * entry ({@code UPDATER_BIN}) and the legacy serialized updater ({@code OLD_UPDATER_BIN}) are
 * also read if present; the state array takes precedence over the legacy updater. An entry
 * named {@code preprocessor.bin} is deserialized if present, but the result is not attached
 * to the returned network.
 *
 * <p>The zip file and every entry stream are closed even if reading fails part-way.
 *
 * @param file the file to load from
 * @param loadUpdater whether to restore updater state stored in the file, if any
 * @return the loaded multi layer network
 * @throws IOException if the file cannot be opened or an entry cannot be read
 * @throws IllegalStateException if the configuration or coefficients entry is missing
 */
public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull File file, boolean loadUpdater) throws IOException {
    boolean gotConfig = false;
    boolean gotCoefficients = false;
    boolean gotOldUpdater = false;
    boolean gotUpdaterState = false;
    String json = "";
    INDArray params = null;
    Updater updater = null;
    INDArray updaterState = null;
    // try-with-resources guarantees the archive is closed even when an entry fails to read or
    // deserialize (previously zipFile.close() was only reached on the success path).
    try (ZipFile zipFile = new ZipFile(file)) {
        ZipEntry config = zipFile.getEntry("configuration.json");
        if (config != null) {
            //restoring configuration
            // Platform default charset is kept to match prior behaviour.
            // NOTE(review): confirm the writer uses the same charset.
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(zipFile.getInputStream(config)))) {
                StringBuilder js = new StringBuilder();
                String line;
                while ((line = reader.readLine()) != null) {
                    js.append(line).append("\n");
                }
                json = js.toString();
            }
            gotConfig = true;
        }
        ZipEntry coefficients = zipFile.getEntry("coefficients.bin");
        if (coefficients != null) {
            try (DataInputStream dis = new DataInputStream(new BufferedInputStream(zipFile.getInputStream(coefficients)))) {
                params = Nd4j.read(dis);
            }
            gotCoefficients = true;
        }
        if (loadUpdater) {
            //This can be removed a few releases after 0.4.1...
            ZipEntry oldUpdaters = zipFile.getEntry(OLD_UPDATER_BIN);
            if (oldUpdaters != null) {
                // SECURITY NOTE: Java native deserialization - only load model files from trusted sources.
                try (ObjectInputStream ois = new ObjectInputStream(zipFile.getInputStream(oldUpdaters))) {
                    updater = (Updater) ois.readObject();
                } catch (ClassNotFoundException e) {
                    throw new RuntimeException(e);
                }
                gotOldUpdater = true;
            }
            ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN);
            if (updaterStateEntry != null) {
                try (DataInputStream dis = new DataInputStream(new BufferedInputStream(zipFile.getInputStream(updaterStateEntry)))) {
                    updaterState = Nd4j.read(dis);
                }
                gotUpdaterState = true;
            }
        }
        ZipEntry prep = zipFile.getEntry("preprocessor.bin");
        if (prep != null) {
            // SECURITY NOTE: Java native deserialization - only load model files from trusted sources.
            try (ObjectInputStream ois = new ObjectInputStream(zipFile.getInputStream(prep))) {
                // The preprocessor is validated (cast) but not attached to the returned network.
                DataSetPreProcessor ignoredPreProcessor = (DataSetPreProcessor) ois.readObject();
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
        }
    }
    if (gotConfig && gotCoefficients) {
        MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);
        MultiLayerNetwork network = new MultiLayerNetwork(confFromJson);
        network.init(params, false);
        // Prefer the flattened updater state; fall back to the legacy serialized updater.
        if (gotUpdaterState && updaterState != null) {
            network.getUpdater().setStateViewArray(network, updaterState, false);
        } else if (gotOldUpdater && updater != null) {
            network.setUpdater(updater);
        }
        return network;
    } else {
        throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
    }
}
Also used: ZipEntry(java.util.zip.ZipEntry) DataSetPreProcessor(org.nd4j.linalg.dataset.api.DataSetPreProcessor) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) ZipFile(java.util.zip.ZipFile) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ComputationGraphUpdater(org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater) Updater(org.deeplearning4j.nn.api.Updater) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)

Aggregations

ZipEntry (java.util.zip.ZipEntry): 2, ZipFile (java.util.zip.ZipFile): 2, ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater): 2, INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2, DataSetPreProcessor (org.nd4j.linalg.dataset.api.DataSetPreProcessor): 2, Updater (org.deeplearning4j.nn.api.Updater): 1, ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 1, MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 1, ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 1, MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 1