Example 11 with Solver

Use of org.deeplearning4j.optimize.Solver in project deeplearning4j by deeplearning4j.

In class MultiLayerNetwork, method doTruncatedBPTT.

protected void doTruncatedBPTT(INDArray input, INDArray labels, INDArray featuresMaskArray, INDArray labelsMaskArray) {
    if (input.rank() != 3 || labels.rank() != 3) {
        log.warn("Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got " + Arrays.toString(input.shape()) + "\tand labels with shape " + Arrays.toString(labels.shape()));
        return;
    }
    if (input.size(2) != labels.size(2)) {
        log.warn("Input and label time series have different lengths: {} input length, {} label length", input.size(2), labels.size(2));
        return;
    }
    int fwdLen = layerWiseConfigurations.getTbpttFwdLength();
    update(TaskUtils.buildTask(input, labels));
    int timeSeriesLength = input.size(2);
    int nSubsets = timeSeriesLength / fwdLen;
    if (timeSeriesLength % fwdLen != 0)
        //Example: fwdLen=100 with timeSeriesLength=120 -> want 2 subsets (1 of size 100, 1 of size 20)
        nSubsets++;
    rnnClearPreviousState();
    for (int i = 0; i < nSubsets; i++) {
        int startTimeIdx = i * fwdLen;
        int endTimeIdx = startTimeIdx + fwdLen;
        if (endTimeIdx > timeSeriesLength)
            endTimeIdx = timeSeriesLength;
        INDArray inputSubset = input.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx));
        INDArray labelSubset = labels.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx));
        setInput(inputSubset);
        setLabels(labelSubset);
        INDArray featuresMaskSubset = null;
        INDArray labelsMaskSubset = null;
        if (featuresMaskArray != null) {
            featuresMaskSubset = featuresMaskArray.get(NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx));
        }
        if (labelsMaskArray != null) {
            labelsMaskSubset = labelsMaskArray.get(NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx));
        }
        if (featuresMaskSubset != null || labelsMaskSubset != null)
            setLayerMaskArrays(featuresMaskSubset, labelsMaskSubset);
        if (solver == null) {
            solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
        }
        solver.optimize();
        //Finally, update the state of the RNN layers:
        updateRnnStateWithTBPTTState();
    }
    rnnClearPreviousState();
    if (featuresMaskArray != null || labelsMaskArray != null)
        clearLayerMaskArrays();
}
Also used : Solver(org.deeplearning4j.optimize.Solver) INDArray(org.nd4j.linalg.api.ndarray.INDArray)
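The window arithmetic in doTruncatedBPTT can be checked in isolation. The following is a minimal sketch in plain Java with no DL4J dependency; the class name TbpttWindows and the method windows are hypothetical names introduced only for illustration. It reproduces the nSubsets rounding and the endTimeIdx clamping from the method above and returns each segment as a half-open [start, end) pair.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper, not part of DL4J: mirrors the index arithmetic of doTruncatedBPTT.
class TbpttWindows {

    /** Returns the [start, end) time-step ranges that a series is split into. */
    static List<int[]> windows(int timeSeriesLength, int fwdLen) {
        if (fwdLen <= 0 || timeSeriesLength < 0) {
            throw new IllegalArgumentException("fwdLen must be > 0 and timeSeriesLength >= 0");
        }
        // Integer division rounds down, so add one subset for any remainder.
        int nSubsets = timeSeriesLength / fwdLen;
        if (timeSeriesLength % fwdLen != 0) {
            nSubsets++;
        }
        List<int[]> result = new ArrayList<>();
        for (int i = 0; i < nSubsets; i++) {
            int startTimeIdx = i * fwdLen;
            // Clamp the last window so it does not run past the end of the series.
            int endTimeIdx = Math.min(startTimeIdx + fwdLen, timeSeriesLength);
            result.add(new int[] { startTimeIdx, endTimeIdx });
        }
        return result;
    }

    public static void main(String[] args) {
        // fwdLen=100, timeSeriesLength=120: one window of 100 steps, one of 20.
        for (int[] w : windows(120, 100)) {
            System.out.println(w[0] + " to " + w[1]);
        }
    }
}
```

With timeSeriesLength=120 and fwdLen=100 this yields the two ranges 0 to 100 and 100 to 120, matching the example in the code comment. When the length is an exact multiple of fwdLen, no extra subset is added.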

Example 12 with Solver

Use of org.deeplearning4j.optimize.Solver in project deeplearning4j by deeplearning4j.

In class LossLayer, method fit.

/**
 * Fit the model
 *
 * @param input  the examples to classify (one example in each row)
 * @param labels the example labels (a binary outcome matrix)
 */
@Override
public void fit(INDArray input, INDArray labels) {
    setInput(input);
    setLabels(labels);
    applyDropOutIfNecessary(true);
    if (solver == null) {
        solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
        //Set the updater state view array. For MLN and CG, this is done by MultiLayerUpdater and ComputationGraphUpdater respectively
        Updater updater = solver.getOptimizer().getUpdater();
        int updaterStateSize = updater.stateSizeForLayer(this);
        if (updaterStateSize > 0)
            updater.setStateViewArray(this, Nd4j.createUninitialized(new int[] { 1, updaterStateSize }, Nd4j.order()), true);
    }
    solver.optimize();
}
Also used : Solver(org.deeplearning4j.optimize.Solver) Updater(org.deeplearning4j.nn.api.Updater)
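Both examples build the Solver lazily: it is created on the first call and reused on every later call, so one-time setup (such as the updater state array in LossLayer.fit) runs only once. The sketch below shows that pattern in plain Java. StubSolver and LazySolverDemo are hypothetical stand-ins written for illustration; they are not DL4J classes and do not model what the real Solver does.

```java
// Hypothetical demo, not DL4J code: lazy construction of a solver-like object.
class LazySolverDemo {

    /** Stand-in for an optimizer; only counts how often it is asked to optimize. */
    static class StubSolver {
        int optimizeCalls = 0;

        void optimize() {
            optimizeCalls++;
        }
    }

    private StubSolver solver;   // stays null until the first fit() call
    int buildCount = 0;          // how many times the solver was constructed

    void fit() {
        if (solver == null) {
            // One-time setup belongs inside this branch.
            solver = new StubSolver();
            buildCount++;
        }
        solver.optimize();
    }

    int optimizeCalls() {
        return solver == null ? 0 : solver.optimizeCalls;
    }

    public static void main(String[] args) {
        LazySolverDemo model = new LazySolverDemo();
        model.fit();
        model.fit();
        model.fit();
        System.out.println(model.buildCount + " build, " + model.optimizeCalls() + " optimize calls");
    }
}
```

Note that this simple null check is not thread-safe; the sketch assumes fit is called from a single thread.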

Aggregations

Solver (org.deeplearning4j.optimize.Solver) 12
Updater (org.deeplearning4j.nn.api.Updater) 4
AsyncMultiDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator) 3
SingletonMultiDataSetIterator (org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator) 3
INDArray (org.nd4j.linalg.api.ndarray.INDArray) 3
MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) 3
AsyncDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncDataSetIterator) 2
TrainingListener (org.deeplearning4j.optimize.api.TrainingListener) 2
MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet) 2
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator) 2
ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater) 1
DataSet (org.nd4j.linalg.dataset.DataSet) 1
DataSet (org.nd4j.linalg.dataset.api.DataSet) 1