Example 6 with Solver

Use of org.deeplearning4j.optimize.Solver in the project deeplearning4j, by deeplearning4j.

From the class MultiLayerNetwork, method fit(DataSetIterator).

@Override
public void fit(DataSetIterator iterator) {
    DataSetIterator iter;
    // we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
    if (iterator.asyncSupported()) {
        iter = new AsyncDataSetIterator(iterator, 2);
    } else {
        iter = iterator;
    }
    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochStart(this);
        }
    }
    if (layerWiseConfigurations.isPretrain()) {
        pretrain(iter);
        if (iter.resetSupported()) {
            iter.reset();
        }
    //            while (iter.hasNext()) {
    //                DataSet next = iter.next();
    //                if (next.getFeatureMatrix() == null || next.getLabels() == null)
    //                    break;
    //                setInput(next.getFeatureMatrix());
    //                setLabels(next.getLabels());
    //                finetune();
    //            }
    }
    if (layerWiseConfigurations.isBackprop()) {
        update(TaskUtils.buildTask(iter));
        if (!iter.hasNext() && iter.resetSupported()) {
            iter.reset();
        }
        while (iter.hasNext()) {
            DataSet next = iter.next();
            if (next.getFeatureMatrix() == null || next.getLabels() == null)
                break;
            boolean hasMaskArrays = next.hasMaskArrays();
            if (layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
                doTruncatedBPTT(next.getFeatureMatrix(), next.getLabels(), next.getFeaturesMaskArray(), next.getLabelsMaskArray());
            } else {
                if (hasMaskArrays)
                    setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray());
                setInput(next.getFeatureMatrix());
                setLabels(next.getLabels());
                if (solver == null) {
                    solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                }
                solver.optimize();
            }
            if (hasMaskArrays)
                clearLayerMaskArrays();
            Nd4j.getMemoryManager().invokeGcOccasionally();
        }
    } else if (layerWiseConfigurations.isPretrain()) {
        log.warn("Warning: finetune is not applied.");
    }
    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochEnd(this);
        }
    }
}
Also used: Solver (org.deeplearning4j.optimize.Solver), DataSet (org.nd4j.linalg.dataset.DataSet), TrainingListener (org.deeplearning4j.optimize.api.TrainingListener), AsyncDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncDataSetIterator), DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)
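For orientation, here is a minimal sketch of how the fit(DataSetIterator) overload in Example 6 is typically driven from user code; the Solver itself is created lazily inside fit(), so the caller never touches it. The class name, the tiny two-layer configuration, and the IrisDataSetIterator below are illustrative assumptions, not part of the excerpt above.

import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class MultiLayerFitSketch {
    public static void main(String[] args) {
        // Assumed toy configuration: 4 Iris features -> 10 hidden units -> 3 classes.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(0, new DenseLayer.Builder().nIn(4).nOut(10).build())
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(10).nOut(3).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // fit(DataSetIterator) wraps the iterator in an AsyncDataSetIterator where supported
        // and builds the Solver on demand before optimizing each minibatch.
        DataSetIterator iter = new IrisDataSetIterator(50, 150);
        net.fit(iter);
    }
}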

Example 7 with Solver

Use of org.deeplearning4j.optimize.Solver in the project deeplearning4j, by deeplearning4j.

From the class ComputationGraph, method fit(INDArray[], INDArray[], INDArray[], INDArray[]).

/**
     * Fit the ComputationGraph using the specified inputs and labels (and mask arrays)
     *
     * @param inputs            The network inputs (features)
     * @param labels            The network labels
     * @param featureMaskArrays Mask arrays for inputs/features. Typically used for RNN training. May be null.
     * @param labelMaskArrays   Mask arrays for the labels/outputs. Typically used for RNN training. May be null.
     */
public void fit(INDArray[] inputs, INDArray[] labels, INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
    if (flattenedGradients == null)
        initGradientsView();
    setInputs(inputs);
    setLabels(labels);
    setLayerMaskArrays(featureMaskArrays, labelMaskArrays);
    update(TaskUtils.buildTask(inputs, labels));
    if (configuration.isPretrain()) {
        MultiDataSetIterator iter = new SingletonMultiDataSetIterator(new org.nd4j.linalg.dataset.MultiDataSet(inputs, labels, featureMaskArrays, labelMaskArrays));
        pretrain(iter);
    }
    if (configuration.isBackprop()) {
        if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
            doTruncatedBPTT(inputs, labels, featureMaskArrays, labelMaskArrays);
        } else {
            if (solver == null) {
                solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
            }
            solver.optimize();
        }
    }
    if (featureMaskArrays != null || labelMaskArrays != null) {
        clearLayerMaskArrays();
    }
}
Also used: SingletonMultiDataSetIterator (org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator), MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator), AsyncMultiDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator), Solver (org.deeplearning4j.optimize.Solver)
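Below is a minimal sketch of calling this overload directly, assuming a freshly built single-input, single-output graph; the class name, the 4-in/3-out shape, the random arrays, and the MSE loss are illustrative assumptions. The masks are passed as null, which the Javadoc above explicitly allows.

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class GraphArrayFitSketch {
    public static void main(String[] args) {
        // Assumed toy graph: one input ("in") feeding a single output layer ("out").
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .graphBuilder()
                .addInputs("in")
                .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .nIn(4).nOut(3).build(), "in")
                .setOutputs("out")
                .build();

        ComputationGraph graph = new ComputationGraph(conf);
        graph.init();

        // One minibatch of 32 random examples; no mask arrays (null is permitted here).
        INDArray features = Nd4j.rand(32, 4);
        INDArray labels = Nd4j.rand(32, 3);
        graph.fit(new INDArray[] {features}, new INDArray[] {labels}, null, null);
    }
}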

Example 8 with Solver

Use of org.deeplearning4j.optimize.Solver in the project deeplearning4j, by deeplearning4j.

From the class ComputationGraph, method getUpdater().

/**
     * Get the ComputationGraphUpdater for the network
     */
public ComputationGraphUpdater getUpdater() {
    if (solver == null) {
        solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
        solver.getOptimizer().setUpdaterComputationGraph(new ComputationGraphUpdater(this));
    }
    return solver.getOptimizer().getComputationGraphUpdater();
}
Also used: Solver (org.deeplearning4j.optimize.Solver), ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater)
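As a small hedged sketch of how getUpdater() is consumed: the first call builds the Solver and attaches a fresh ComputationGraphUpdater, and later calls go through the same cached Solver, so updater state persists across training calls. The helper class and method name below are illustrative.

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;

public class UpdaterSketch {
    // Returns the graph's updater; the Solver and updater are created lazily on first access.
    static ComputationGraphUpdater updaterOf(ComputationGraph graph) {
        return graph.getUpdater();
    }
}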

Example 9 with Solver

Use of org.deeplearning4j.optimize.Solver in the project deeplearning4j, by deeplearning4j.

From the class BaseLayer, method getOptimizer().

@Override
public ConvexOptimizer getOptimizer() {
    if (optimizer == null) {
        Solver solver = new Solver.Builder().model(this).configure(conf()).build();
        this.optimizer = solver.getOptimizer();
    }
    return optimizer;
}
Also used: Solver (org.deeplearning4j.optimize.Solver)
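A hedged usage sketch for getOptimizer(): asking any layer of an initialized network for its ConvexOptimizer triggers the same lazy Solver construction shown above. The helper below (names are illustrative) assumes the network has already been built and init() called.

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.ConvexOptimizer;

public class OptimizerSketch {
    // Fetches one layer's optimizer; BaseLayer builds a Solver on the first call and caches the result.
    static ConvexOptimizer optimizerOf(MultiLayerNetwork net, int layerIndex) {
        Layer layer = net.getLayer(layerIndex);
        return layer.getOptimizer();
    }
}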

Example 10 with Solver

Use of org.deeplearning4j.optimize.Solver in the project deeplearning4j, by deeplearning4j.

From the class ComputationGraph, method fit(DataSetIterator).

/**
     * Fit the ComputationGraph using a DataSetIterator.
     * Note that this method can only be used with ComputationGraphs with 1 input and 1 output
     */
public void fit(DataSetIterator iterator) {
    if (flattenedGradients == null)
        initGradientsView();
    if (numInputArrays != 1 || numOutputArrays != 1)
        throw new UnsupportedOperationException("Cannot train ComputationGraph network with multiple inputs or outputs using a DataSetIterator");
    DataSetIterator dataSetIterator;
    // we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
    if (iterator.asyncSupported()) {
        dataSetIterator = new AsyncDataSetIterator(iterator, 2);
    } else
        dataSetIterator = iterator;
    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochStart(this);
        }
    }
    if (configuration.isPretrain()) {
        pretrain(dataSetIterator);
    }
    if (configuration.isBackprop()) {
        update(TaskUtils.buildTask(dataSetIterator));
        while (dataSetIterator.hasNext()) {
            DataSet next = dataSetIterator.next();
            if (next.getFeatures() == null || next.getLabels() == null)
                break;
            boolean hasMaskArrays = next.hasMaskArrays();
            if (hasMaskArrays) {
                INDArray[] fMask = (next.getFeaturesMaskArray() != null ? new INDArray[] { next.getFeaturesMaskArray() } : null);
                INDArray[] lMask = (next.getLabelsMaskArray() != null ? new INDArray[] { next.getLabelsMaskArray() } : null);
                setLayerMaskArrays(fMask, lMask);
            }
            if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                doTruncatedBPTT(new INDArray[] { next.getFeatures() }, new INDArray[] { next.getLabels() }, (hasMaskArrays ? new INDArray[] { next.getFeaturesMaskArray() } : null), (hasMaskArrays ? new INDArray[] { next.getLabelsMaskArray() } : null));
            } else {
                setInput(0, next.getFeatures());
                setLabel(0, next.getLabels());
                if (solver == null) {
                    //TODO: don't like this
                    solver = new Solver.Builder().configure(defaultConfiguration).listeners(listeners).model(this).build();
                }
                solver.optimize();
            }
            if (hasMaskArrays) {
                clearLayerMaskArrays();
            }
            Nd4j.getMemoryManager().invokeGcOccasionally();
        }
    }
    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochEnd(this);
        }
    }
}
Also used: Solver (org.deeplearning4j.optimize.Solver), INDArray (org.nd4j.linalg.api.ndarray.INDArray), MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet), DataSet (org.nd4j.linalg.dataset.api.DataSet), TrainingListener (org.deeplearning4j.optimize.api.TrainingListener), AsyncDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncDataSetIterator), DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator), SingletonMultiDataSetIterator (org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator), MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator), AsyncMultiDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator)
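Finally, a minimal sketch of driving this fit(DataSetIterator) overload. It assumes a single-input, single-output graph whose input takes 4 features and whose output has 3 units (for instance the toy graph sketched after Example 7), since the method rejects graphs with multiple inputs or outputs. The class name and the Iris iterator are illustrative stand-ins.

import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class GraphIteratorFitSketch {
    // Assumes `graph` is an initialized 1-input/1-output ComputationGraph matching the Iris shape (4 -> 3).
    static void train(ComputationGraph graph) {
        DataSetIterator iter = new IrisDataSetIterator(50, 150);
        // Internally the iterator is wrapped in an AsyncDataSetIterator where supported,
        // and the Solver is built lazily before the first call to optimize().
        graph.fit(iter);
    }
}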

Aggregations

Solver (org.deeplearning4j.optimize.Solver): 12 usages
Updater (org.deeplearning4j.nn.api.Updater): 4 usages
AsyncMultiDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator): 3 usages
SingletonMultiDataSetIterator (org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator): 3 usages
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 3 usages
MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator): 3 usages
AsyncDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncDataSetIterator): 2 usages
TrainingListener (org.deeplearning4j.optimize.api.TrainingListener): 2 usages
MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet): 2 usages
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 2 usages
ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater): 1 usage
DataSet (org.nd4j.linalg.dataset.DataSet): 1 usage
DataSet (org.nd4j.linalg.dataset.api.DataSet): 1 usage