Search in sources :

Example 86 with ComputationGraph

use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.

the class KerasModelConfigurationTest method importKerasMlpModelConfigTest.

/**
 * Builds a {@link ComputationGraph} from the Keras model configuration stored at
 * {@code modelimport/keras/configs/mlp_fapi_config.json} on the test classpath and
 * initializes it. The test passes when import and initialization complete without throwing.
 */
@Test
public void importKerasMlpModelConfigTest() throws Exception {
    // Locate the JSON configuration through this test class's class loader.
    ClassLoader testLoader = KerasModelConfigurationTest.class.getClassLoader();
    ClassPathResource mlpConfigJson = new ClassPathResource("modelimport/keras/configs/mlp_fapi_config.json", testLoader);

    // Import the Keras model with training-configuration enforcement switched on.
    KerasModel importedModel = new KerasModel.ModelBuilder()
                    .modelJsonInputStream(mlpConfigJson.getInputStream())
                    .enforceTrainingConfig(true)
                    .buildModel();
    ComputationGraphConfiguration graphConf = importedModel.getComputationGraphConfiguration();

    // Instantiating and initializing the graph is the actual check.
    ComputationGraph graph = new ComputationGraph(graphConf);
    graph.init();
}
Also used : ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Example 87 with ComputationGraph

use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.

the class ParallelWrapper method averageUpdatersState.

/**
 * Averages the updater state of the workers that took part in the last round into the
 * master {@link ComputationGraph}'s updater, then stores the given score on the master model.
 *
 * <p>Averaging only happens when {@code averageUpdaters} is set and the master updater has a
 * state view array. Only the first {@code min(workers, locker.get())} entries of {@code zoo}
 * contribute.
 *
 * @param locker counter holding how many workers were fed in the current round
 * @param score  score to set on the master model (always applied, even when nothing is averaged)
 */
private void averageUpdatersState(AtomicInteger locker, double score) {
    ComputationGraph master = (ComputationGraph) model;
    if (averageUpdaters) {
        ComputationGraphUpdater updater = master.getUpdater();
        if (updater != null && updater.getStateViewArray() != null) {
            if (!legacyAveraging || Nd4j.getAffinityManager().getNumberOfDevices() == 1) {
                // Collect the worker state views and let Nd4j average them into the master state.
                List<INDArray> updaters = new ArrayList<>();
                for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                    ComputationGraph workerModel = (ComputationGraph) zoo[cnt].getModel();
                    updaters.add(workerModel.getUpdater().getStateViewArray());
                }
                Nd4j.averageAndPropagate(updater.getStateViewArray(), updaters);
            } else {
                // Legacy path: sum the worker states manually, divide by the number of
                // contributing workers and install the result on the master updater.
                INDArray state = Nd4j.zeros(updater.getStateViewArray().shape());
                int cnt = 0;
                for (; cnt < workers && cnt < locker.get(); cnt++) {
                    ComputationGraph workerModel = (ComputationGraph) zoo[cnt].getModel();
                    state.addi(workerModel.getUpdater().getStateViewArray());
                }
                state.divi(cnt);
                updater.setStateViewArray(state);
            }
        }
    }
    master.setScore(score);
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) ComputationGraphUpdater(org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph)

Example 88 with ComputationGraph

use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.

the class ParallelWrapper method fit.

/**
 * Trains the wrapped model on the given {@link MultiDataSetIterator}, handing one
 * {@link MultiDataSet} to each worker per round and periodically averaging the results.
 *
 * <p>The method is {@code synchronized}. Worker {@code Trainer} threads are created and started
 * on the first call (when {@code zoo} is {@code null}); on later calls the existing workers are
 * switched to MultiDataSet mode.
 *
 * @param source iterator providing the training data; it is reset before use and wrapped in an
 *               {@link AsyncMultiDataSetIterator} when {@code prefetchSize > 0} and the source
 *               reports async support
 * @throws ND4JIllegalStateException if the iterator yields a {@code null} MultiDataSet
 * @throws RuntimeException if waiting on a worker fails, or if averaging is due and the model is
 *                          not a {@link ComputationGraph}
 */
public synchronized void fit(@NonNull MultiDataSetIterator source) {
    // clear any earlier stop request so the loop below is allowed to run
    stopFit.set(false);
    if (zoo == null) {
        zoo = new Trainer[workers];
        for (int cnt = 0; cnt < workers; cnt++) {
            // we pass true here, to tell Trainer to use MultiDataSet queue for training
            zoo[cnt] = new Trainer(cnt, model, Nd4j.getAffinityManager().getDeviceForCurrentThread(), true);
            zoo[cnt].setUncaughtExceptionHandler(handler);
            zoo[cnt].start();
        }
    } else {
        // workers already exist: just switch them to the MultiDataSet queue
        for (int cnt = 0; cnt < workers; cnt++) {
            zoo[cnt].useMDS = true;
        }
    }
    source.reset();
    MultiDataSetIterator iterator;
    // prefetch in the background only when requested and supported by the source
    if (prefetchSize > 0 && source.asyncSupported()) {
        iterator = new AsyncMultiDataSetIterator(source, prefetchSize);
    } else
        iterator = source;
    // locker counts how many workers have been fed in the current round
    AtomicInteger locker = new AtomicInteger(0);
    while (iterator.hasNext() && !stopFit.get()) {
        MultiDataSet dataSet = iterator.next();
        if (dataSet == null)
            throw new ND4JIllegalStateException("You can't have NULL as MultiDataSet");
        /*
             now dataSet should be dispatched to next free workers, until all workers are busy. And then we should block till all finished.
            */
        int pos = locker.getAndIncrement();
        zoo[pos].feedMultiDataSet(dataSet);
        /*
                if all workers are dispatched now, join till all are finished
            */
        if (pos + 1 == workers || !iterator.hasNext()) {
            iterationsCounter.incrementAndGet();
            // wait only for the workers that were actually fed in this round
            for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                try {
                    zoo[cnt].waitTillRunning();
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            Nd4j.getMemoryManager().invokeGcOccasionally();
            /*
                    average model, and propagate it to whole
                */
            // averaging is skipped for a partial final round (pos + 1 != workers)
            if (iterationsCounter.get() % averagingFrequency == 0 && pos + 1 == workers) {
                double score = getScore(locker);
                // averaging updaters state
                if (model instanceof ComputationGraph) {
                    averageUpdatersState(locker, score);
                } else
                    throw new RuntimeException("MultiDataSet must only be used with ComputationGraph model");
                // legacy averaging on multiple devices: push the master model to every worker
                if (legacyAveraging && Nd4j.getAffinityManager().getNumberOfDevices() > 1) {
                    for (int cnt = 0; cnt < workers; cnt++) {
                        zoo[cnt].updateModel(model);
                    }
                }
            }
            // start a new round
            locker.set(0);
        }
    }
    // sanity checks, or the dataset may never average
    // NOTE(review): wasAveraged is not assigned anywhere in this method; presumably it is set
    // during averaging elsewhere (e.g. in getScore) - confirm.
    if (!wasAveraged)
        log.warn("Parameters were never averaged on current fit(). Ratios of batch size, num workers, and averaging frequency may be responsible.");
    //            throw new IllegalStateException("Parameters were never averaged. Please check batch size ratios, number of workers, and your averaging frequency.");
    log.debug("Iterations passed: {}", iterationsCounter.get());
//        iterationsCounter.set(0);
}
Also used : MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph)

Example 89 with ComputationGraph

use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.

the class ModelSerializer method writeModel.

/**
 * Write a model to an output stream as a zip archive.
 *
 * <p>The archive contains {@code configuration.json} (the model configuration as JSON),
 * {@code coefficients.bin} (the model parameters) and, when {@code saveUpdater} is set and the
 * updater has a non-empty state, an entry named by {@code UPDATER_BIN}. The caller's stream is
 * wrapped in a {@link CloseShieldOutputStream}, so it is not closed by this method.
 *
 * <p>For a model that is neither a {@link MultiLayerNetwork} nor a {@link ComputationGraph} the
 * configuration entry is written empty and no updater state is saved.
 *
 * @param model the model to save
 * @param stream the output stream to write to
 * @param saveUpdater whether to save the updater for the model or not
 * @throws IOException if writing to the stream fails
 */
public static void writeModel(@NonNull Model model, @NonNull OutputStream stream, boolean saveUpdater) throws IOException {
    // Save configuration as JSON
    String json = "";
    if (model instanceof MultiLayerNetwork) {
        json = ((MultiLayerNetwork) model).getLayerWiseConfigurations().toJson();
    } else if (model instanceof ComputationGraph) {
        json = ((ComputationGraph) model).getConfiguration().toJson();
    }
    // try-with-resources guarantees both streams are released even when a write fails;
    // dos is closed first (flushing its buffer into the zip), then the zip itself.
    try (ZipOutputStream zipfile = new ZipOutputStream(new CloseShieldOutputStream(stream));
                    DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(zipfile))) {
        zipfile.putNextEntry(new ZipEntry("configuration.json"));
        // NOTE(review): platform default charset, kept to stay in step with the reader side;
        // consider moving both sides to UTF-8 together.
        zipfile.write(json.getBytes());

        // Save parameters as binary
        zipfile.putNextEntry(new ZipEntry("coefficients.bin"));
        Nd4j.write(model.params(), dos);
        // the buffered bytes must reach the zip before the next entry is started
        dos.flush();

        if (saveUpdater) {
            INDArray updaterState = null;
            if (model instanceof MultiLayerNetwork) {
                updaterState = ((MultiLayerNetwork) model).getUpdater().getStateViewArray();
            } else if (model instanceof ComputationGraph) {
                updaterState = ((ComputationGraph) model).getUpdater().getStateViewArray();
            }
            if (updaterState != null && updaterState.length() > 0) {
                zipfile.putNextEntry(new ZipEntry(UPDATER_BIN));
                Nd4j.write(updaterState, dos);
                dos.flush();
            }
        }
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) ZipOutputStream(java.util.zip.ZipOutputStream) ZipEntry(java.util.zip.ZipEntry) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) CloseShieldOutputStream(org.apache.commons.io.output.CloseShieldOutputStream)

Example 90 with ComputationGraph

use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.

the class ModelSerializer method restoreComputationGraph.

/**
 * Load a computation graph from a file
 *
 * <p>The file is read as a zip archive. {@code configuration.json} and {@code coefficients.bin}
 * are required; updater entries ({@code UPDATER_BIN}, or the legacy {@code OLD_UPDATER_BIN}) are
 * read only when {@code loadUpdater} is set. When both updater entries are present the state
 * array from {@code UPDATER_BIN} takes precedence. A {@code preprocessor.bin} entry is not read
 * here: this method never returned or applied its content.
 *
 * @param file the file to get the computation graph from
 * @param loadUpdater whether to restore the updater (state) stored in the file, if any
 * @return the loaded computation graph
 *
 * @throws IOException if the file cannot be opened or read
 * @throws IllegalStateException if the configuration or the coefficients entry is missing
 */
public static ComputationGraph restoreComputationGraph(@NonNull File file, boolean loadUpdater) throws IOException {
    boolean gotConfig = false;
    boolean gotCoefficients = false;
    boolean gotUpdaterState = false;
    String json = "";
    INDArray params = null;
    ComputationGraphUpdater updater = null;
    INDArray updaterState = null;

    // try-with-resources: the zip file and every entry stream are closed even if a read fails
    try (ZipFile zipFile = new ZipFile(file)) {
        ZipEntry config = zipFile.getEntry("configuration.json");
        if (config != null) {
            //restoring configuration
            StringBuilder js = new StringBuilder();
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(zipFile.getInputStream(config)))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    js.append(line).append("\n");
                }
            }
            json = js.toString();
            gotConfig = true;
        }

        ZipEntry coefficients = zipFile.getEntry("coefficients.bin");
        if (coefficients != null) {
            try (DataInputStream dis = new DataInputStream(new BufferedInputStream(zipFile.getInputStream(coefficients)))) {
                params = Nd4j.read(dis);
            }
            gotCoefficients = true;
        }

        if (loadUpdater) {
            ZipEntry oldUpdaters = zipFile.getEntry(OLD_UPDATER_BIN);
            if (oldUpdaters != null) {
                // NOTE(review): Java native deserialization of file content; only load model
                // files from trusted sources.
                try (ObjectInputStream ois = new ObjectInputStream(zipFile.getInputStream(oldUpdaters))) {
                    updater = (ComputationGraphUpdater) ois.readObject();
                } catch (ClassNotFoundException e) {
                    throw new RuntimeException(e);
                }
            }

            ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN);
            if (updaterStateEntry != null) {
                try (DataInputStream dis = new DataInputStream(new BufferedInputStream(zipFile.getInputStream(updaterStateEntry)))) {
                    updaterState = Nd4j.read(dis);
                }
                gotUpdaterState = true;
            }
        }
    }

    if (!gotConfig || !gotCoefficients) {
        throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
    }

    ComputationGraphConfiguration confFromJson = ComputationGraphConfiguration.fromJson(json);
    ComputationGraph cg = new ComputationGraph(confFromJson);
    cg.init(params, false);
    if (updaterState != null) {
        // new-format updater state wins over a legacy serialized updater
        cg.getUpdater().setStateViewArray(updaterState);
    } else if (updater != null) {
        cg.setUpdater(updater);
    }
    return cg;
}
Also used : ComputationGraphUpdater(org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater) ZipEntry(java.util.zip.ZipEntry) DataSetPreProcessor(org.nd4j.linalg.dataset.api.DataSetPreProcessor) ZipFile(java.util.zip.ZipFile) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph)

Aggregations

ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)109 Test (org.junit.Test)73 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)63 INDArray (org.nd4j.linalg.api.ndarray.INDArray)62 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)36 DataSet (org.nd4j.linalg.dataset.DataSet)25 NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution)22 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)21 DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)19 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)19 ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener)17 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)17 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)14 Layer (org.deeplearning4j.nn.api.Layer)14 Random (java.util.Random)11 InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver)10 MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition)10 TrainingMaster (org.deeplearning4j.spark.api.TrainingMaster)10 MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition)9 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)9