Search in sources :

Example 81 with ComputationGraphConfiguration

use of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in project deeplearning4j by deeplearning4j.

the class KerasModelConfigurationTest method importKerasMlpModelMultilossConfigTest.

/**
 * Imports the multi-loss functional-API MLP configuration from the classpath with
 * training-config enforcement enabled, then checks that a ComputationGraph built
 * from it can be initialised without throwing.
 */
@Test
public void importKerasMlpModelMultilossConfigTest() throws Exception {
    ClassLoader loader = KerasModelConfigurationTest.class.getClassLoader();
    ClassPathResource jsonResource =
                    new ClassPathResource("modelimport/keras/configs/mlp_fapi_multiloss_config.json", loader);

    // Build the Keras model from its JSON description and pull out the DL4J graph configuration
    ComputationGraphConfiguration graphConf = new KerasModel.ModelBuilder()
                    .modelJsonInputStream(jsonResource.getInputStream())
                    .enforceTrainingConfig(true)
                    .buildModel()
                    .getComputationGraphConfiguration();

    // The test passes as long as construction and initialisation complete without an exception
    new ComputationGraph(graphConf).init();
}
Also used : ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Example 82 with ComputationGraphConfiguration

use of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in project deeplearning4j by deeplearning4j.

the class KerasModelConfigurationTest method importKerasMlpModelConfigTest.

/**
 * Imports the functional-API MLP configuration from the classpath with
 * training-config enforcement enabled, then checks that a ComputationGraph built
 * from it can be initialised without throwing.
 */
@Test
public void importKerasMlpModelConfigTest() throws Exception {
    ClassLoader loader = KerasModelConfigurationTest.class.getClassLoader();
    ClassPathResource jsonResource =
                    new ClassPathResource("modelimport/keras/configs/mlp_fapi_config.json", loader);

    // Build the Keras model from its JSON description and pull out the DL4J graph configuration
    ComputationGraphConfiguration graphConf = new KerasModel.ModelBuilder()
                    .modelJsonInputStream(jsonResource.getInputStream())
                    .enforceTrainingConfig(true)
                    .buildModel()
                    .getComputationGraphConfiguration();

    // The test passes as long as construction and initialisation complete without an exception
    new ComputationGraph(graphConf).init();
}
Also used : ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Example 83 with ComputationGraphConfiguration

use of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in project deeplearning4j by deeplearning4j.

the class ModelSerializer method restoreComputationGraph.

/**
 * Load a computation graph from a file.
 * <p>
 * The file is opened as a zip archive. The entries {@code configuration.json} and
 * {@code coefficients.bin} are required; the updater entries ({@code UPDATER_BIN},
 * or the legacy {@code OLD_UPDATER_BIN}) are only read when {@code loadUpdater} is set.
 * A {@code preprocessor.bin} entry, if present, is deserialized but not attached to
 * the returned graph.
 *
 * @param file        the file to get the computation graph from
 * @param loadUpdater if true, also restore the updater (or updater state) stored in the file
 * @return the loaded computation graph
 *
 * @throws IOException           if the archive or one of its entries cannot be read
 * @throws IllegalStateException if the configuration or the coefficients entry is missing
 * @throws RuntimeException      wrapping a ClassNotFoundException raised while deserializing
 *                               the legacy updater or the preprocessor
 */
public static ComputationGraph restoreComputationGraph(@NonNull File file, boolean loadUpdater) throws IOException {
    boolean gotConfig = false;
    boolean gotCoefficients = false;
    boolean gotOldUpdater = false;
    boolean gotUpdaterState = false;
    String json = "";
    INDArray params = null;
    ComputationGraphUpdater updater = null;
    INDArray updaterState = null;

    // try-with-resources guarantees the archive and every entry stream are released,
    // including when reading or deserialization fails part-way through.
    try (ZipFile zipFile = new ZipFile(file)) {
        ZipEntry config = zipFile.getEntry("configuration.json");
        if (config != null) {
            //restoring configuration
            // NOTE(review): the reader uses the platform default charset, as before — confirm
            // this matches how the JSON entry is written before pinning an explicit charset.
            try (InputStream stream = zipFile.getInputStream(config);
                            BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) {
                StringBuilder js = new StringBuilder();
                String line;
                while ((line = reader.readLine()) != null) {
                    js.append(line).append("\n");
                }
                json = js.toString();
            }
            gotConfig = true;
        }

        ZipEntry coefficients = zipFile.getEntry("coefficients.bin");
        if (coefficients != null) {
            try (InputStream stream = zipFile.getInputStream(coefficients);
                            DataInputStream dis = new DataInputStream(new BufferedInputStream(stream))) {
                params = Nd4j.read(dis);
            }
            gotCoefficients = true;
        }

        if (loadUpdater) {
            // Legacy format: the whole updater object was Java-serialized
            ZipEntry oldUpdaters = zipFile.getEntry(OLD_UPDATER_BIN);
            if (oldUpdaters != null) {
                try (InputStream stream = zipFile.getInputStream(oldUpdaters);
                                ObjectInputStream ois = new ObjectInputStream(stream)) {
                    updater = (ComputationGraphUpdater) ois.readObject();
                } catch (ClassNotFoundException e) {
                    throw new RuntimeException(e);
                }
                gotOldUpdater = true;
            }

            // Current format: only the updater state array is stored
            ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN);
            if (updaterStateEntry != null) {
                try (InputStream stream = zipFile.getInputStream(updaterStateEntry);
                                DataInputStream dis = new DataInputStream(new BufferedInputStream(stream))) {
                    updaterState = Nd4j.read(dis);
                }
                gotUpdaterState = true;
            }
        }

        ZipEntry prep = zipFile.getEntry("preprocessor.bin");
        if (prep != null) {
            // The preprocessor is deserialized (so an unreadable entry still fails the load,
            // as it always did) but the result is not used by this method.
            try (InputStream stream = zipFile.getInputStream(prep);
                            ObjectInputStream ois = new ObjectInputStream(stream)) {
                @SuppressWarnings("unused")
                DataSetPreProcessor ignoredPreProcessor = (DataSetPreProcessor) ois.readObject();
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
        }
    }

    if (!gotConfig || !gotCoefficients) {
        throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
    }

    ComputationGraphConfiguration confFromJson = ComputationGraphConfiguration.fromJson(json);
    ComputationGraph cg = new ComputationGraph(confFromJson);
    cg.init(params, false);
    // Prefer the updater state array; fall back to the legacy serialized updater
    if (gotUpdaterState && updaterState != null) {
        cg.getUpdater().setStateViewArray(updaterState);
    } else if (gotOldUpdater && updater != null) {
        cg.setUpdater(updater);
    }
    return cg;
}
Also used : ComputationGraphUpdater(org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater) ZipEntry(java.util.zip.ZipEntry) DataSetPreProcessor(org.nd4j.linalg.dataset.api.DataSetPreProcessor) ZipFile(java.util.zip.ZipFile) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph)

Example 84 with ComputationGraphConfiguration

use of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in project deeplearning4j by deeplearning4j.

the class BaseOptimizer method incrementIterationCount.

/**
 * Adds {@code incrementBy} to the iteration count stored in the model's configuration.
 * <p>
 * For a MultiLayerNetwork the layer-wise configuration is updated, for a ComputationGraph
 * the graph configuration is updated, and for any other model the configuration returned
 * by {@code model.conf()} is updated.
 *
 * @param model       the model whose iteration counter should be advanced
 * @param incrementBy the amount to add to the current iteration count
 */
public static void incrementIterationCount(Model model, int incrementBy) {
    if (model instanceof MultiLayerNetwork) {
        MultiLayerConfiguration mlnConf = ((MultiLayerNetwork) model).getLayerWiseConfigurations();
        int next = mlnConf.getIterationCount() + incrementBy;
        mlnConf.setIterationCount(next);
        return;
    }

    if (model instanceof ComputationGraph) {
        ComputationGraphConfiguration cgConf = ((ComputationGraph) model).getConfiguration();
        int next = cgConf.getIterationCount() + incrementBy;
        cgConf.setIterationCount(next);
        return;
    }

    // Any other model type: use the configuration exposed directly by the model
    model.conf().setIterationCount(model.conf().getIterationCount() + incrementBy);
}
Also used : MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)

Example 85 with ComputationGraphConfiguration

use of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in project deeplearning4j by deeplearning4j.

the class ParameterAveragingTrainingMaster method processResults.

/**
 * Aggregates the per-worker training results of one split and applies them to the
 * driver-side model.
 * <p>
 * The summed parameters (and the summed updater state, when present) are divided by the
 * aggregation count and written to either {@code network} or {@code graph}, whichever is
 * in use ({@code graph} is used when {@code network} is null). Listener data carried by
 * the results is forwarded to {@code statsStorage} when one is configured, and finally the
 * iteration count of the model configuration is advanced.
 *
 * @param network     the Spark multi-layer network to update, or null when training a graph
 * @param graph       the Spark computation graph to update; only used when {@code network} is null
 * @param results     the worker results to aggregate
 * @param splitNum    the number of the split just completed (used for logging only)
 * @param totalSplits the total number of splits (used for logging only)
 */
private void processResults(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<ParameterAveragingTrainingResult> results, int splitNum, int totalSplits) {
    if (collectTrainingStats) {
        stats.logAggregateStartTime();
    }

    // Fold every worker result into a single tuple of sums
    ParameterAveragingAggregationTuple aggregated = results.aggregate(null,
                    new ParameterAveragingElementAddFunction(), new ParameterAveragingElementCombineFunction());
    INDArray paramSum = aggregated.getParametersSum();
    int aggCount = aggregated.getAggregationsCount();
    SparkTrainingStats workerStats = aggregated.getSparkTrainingStats();

    if (collectTrainingStats) {
        stats.logAggregationEndTime();
        stats.logProcessParamsUpdaterStart();
    }

    if (paramSum == null) {
        log.info("Skipping imbalanced split with no data for all executors");
    } else {
        // Turn the sums into averages
        paramSum.divi(aggCount);
        INDArray updaterStateSum = aggregated.getUpdaterStateSum();
        if (updaterStateSum != null) {
            //May be null if all SGD updaters, for example
            updaterStateSum.divi(aggCount);
        }

        if (network == null) {
            ComputationGraph cg = graph.getNetwork();
            cg.setParams(paramSum);
            if (updaterStateSum != null) {
                cg.getUpdater().setStateViewArray(updaterStateSum);
            }
            graph.setScore(aggregated.getScoreSum() / aggregated.getAggregationsCount());
        } else {
            MultiLayerNetwork mln = network.getNetwork();
            mln.setParameters(paramSum);
            if (updaterStateSum != null) {
                mln.getUpdater().setStateViewArray(null, updaterStateSum, false);
            }
            network.setScore(aggregated.getScoreSum() / aggregated.getAggregationsCount());
        }
    }

    if (collectTrainingStats) {
        stats.logProcessParamsUpdaterEnd();
        stats.addWorkerStats(workerStats);
    }

    // Forward listener data collected on the workers to the stats storage, if one is configured
    if (statsStorage != null) {
        Collection<StorageMetaData> metaData = aggregated.getListenerMetaData();
        if (metaData != null && !metaData.isEmpty()) {
            statsStorage.putStorageMetaData(metaData);
        }

        Collection<Persistable> staticInfos = aggregated.getListenerStaticInfo();
        if (staticInfos != null && !staticInfos.isEmpty()) {
            statsStorage.putStaticInfo(staticInfos);
        }

        Collection<Persistable> listenerUpdates = aggregated.getListenerUpdates();
        if (listenerUpdates != null && !listenerUpdates.isEmpty()) {
            statsStorage.putUpdate(listenerUpdates);
        }
    }

    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    log.info("Completed training of split {} of {}", splitNum, totalSplits);

    //Params may be null for edge case (empty RDD)
    if (paramSum != null) {
        if (network == null) {
            ComputationGraphConfiguration cgConf = graph.getNetwork().getConfiguration();
            int numUpdates = graph.getNetwork().conf().getNumIterations() * averagingFrequency;
            cgConf.setIterationCount(cgConf.getIterationCount() + numUpdates);
        } else {
            MultiLayerConfiguration mlnConf = network.getNetwork().getLayerWiseConfigurations();
            int numUpdates = network.getNetwork().conf().getNumIterations() * averagingFrequency;
            mlnConf.setIterationCount(mlnConf.getIterationCount() + numUpdates);
        }
    }
}
Also used : ParameterAveragingElementCombineFunction(org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementCombineFunction) SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats) ParameterAveragingAggregationTuple(org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingAggregationTuple) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ParameterAveragingElementAddFunction(org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementAddFunction) SparkComputationGraph(org.deeplearning4j.spark.impl.graph.SparkComputationGraph) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)

Aggregations

ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)96 Test (org.junit.Test)84 ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)63 INDArray (org.nd4j.linalg.api.ndarray.INDArray)51 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)50 NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution)28 DataSet (org.nd4j.linalg.dataset.DataSet)22 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)20 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)17 ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener)17 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)17 DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)14 Random (java.util.Random)13 ParameterAveragingTrainingMaster (org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster)11 InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver)10 MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition)10 MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet)10 MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition)9 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)9 RnnOutputLayer (org.deeplearning4j.nn.conf.layers.RnnOutputLayer)9