Search in sources:

Example 1 with Solver

Use of org.deeplearning4j.optimize.Solver in the project deeplearning4j by deeplearning4j.

In the class ComputationGraph, the method fit:

/**
     * Fit the ComputationGraph using a MultiDataSetIterator
     */
/**
 * Fit the ComputationGraph using a MultiDataSetIterator.
 *
 * <p>Pretraining is run first when the configuration requests it; backprop
 * training then consumes the iterator until it is exhausted (or until a
 * MultiDataSet with null features/labels is encountered).</p>
 *
 * @param multi iterator providing the training MultiDataSets
 */
public void fit(MultiDataSetIterator multi) {
    if (flattenedGradients == null) {
        initGradientsView();
    }

    // Use an async (prefetching) wrapper when the underlying iterator supports it
    MultiDataSetIterator iter = multi.asyncSupported() ? new AsyncMultiDataSetIterator(multi, 2) : multi;

    if (configuration.isPretrain()) {
        pretrain(iter);
    }

    if (!configuration.isBackprop()) {
        return;
    }

    while (iter.hasNext()) {
        MultiDataSet dataSet = iter.next();
        // Stop on a degenerate data set rather than attempting to fit on it
        if (dataSet.getFeatures() == null || dataSet.getLabels() == null) {
            break;
        }

        if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
            doTruncatedBPTT(dataSet.getFeatures(), dataSet.getLabels(), dataSet.getFeaturesMaskArrays(),
                            dataSet.getLabelsMaskArrays());
        } else {
            boolean masksPresent = dataSet.hasMaskArrays();
            if (masksPresent) {
                setLayerMaskArrays(dataSet.getFeaturesMaskArrays(), dataSet.getLabelsMaskArrays());
            }
            setInputs(dataSet.getFeatures());
            setLabels(dataSet.getLabels());
            if (solver == null) {
                solver = new Solver.Builder().configure(defaultConfiguration).listeners(listeners).model(this).build();
            }
            solver.optimize();
            if (masksPresent) {
                clearLayerMaskArrays();
            }
        }

        // Give the memory manager an opportunity to trigger a GC pass between minibatches
        Nd4j.getMemoryManager().invokeGcOccasionally();
    }
}
Also used: SingletonMultiDataSetIterator (org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator), MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator), AsyncMultiDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator), Solver (org.deeplearning4j.optimize.Solver), MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet)

Example 2 with Solver

Use of org.deeplearning4j.optimize.Solver in the project deeplearning4j by deeplearning4j.

In the class ComputationGraph, the method doTruncatedBPTT:

/**
     * Fit the network using truncated BPTT
     */
/**
 * Fit the network using truncated BPTT.
 *
 * <p>Rank-3 (time series) inputs/labels are split along the time dimension into
 * subsets of at most {@code tbpttFwdLength} steps; rank-2 arrays are passed
 * through unmodified to every subset. All rank-3 arrays must share the same
 * time series length.</p>
 *
 * @param inputs       network input arrays (rank 3 arrays are split over time)
 * @param labels       network label arrays (rank 3 arrays are split over time)
 * @param featureMasks feature mask arrays, or null if not used
 * @param labelMasks   label mask arrays, or null if not used
 */
protected void doTruncatedBPTT(INDArray[] inputs, INDArray[] labels, INDArray[] featureMasks, INDArray[] labelMasks) {
    if (flattenedGradients == null)
        initGradientsView();
    //Approach used here to implement truncated BPTT: if input is 3d, split it. Otherwise: input is unmodified
    int timeSeriesLength = -1;
    for (INDArray in : inputs) {
        if (in.rank() != 3)
            continue;
        if (timeSeriesLength == -1)
            timeSeriesLength = in.size(2);
        else if (timeSeriesLength != in.size(2)) {
            log.warn("Cannot do TBPTT with time series of different lengths");
            return;
        }
    }
    for (INDArray out : labels) {
        if (out.rank() != 3)
            continue;
        if (timeSeriesLength == -1)
            timeSeriesLength = out.size(2);
        else if (timeSeriesLength != out.size(2)) {
            log.warn("Cannot do TBPTT with time series of different lengths");
            return;
        }
    }
    //Fix: without this guard, timeSeriesLength == -1 (no rank-3 arrays at all) falls through to
    //nSubsets == 1 and endTimeIdx == -1, producing an invalid NDArrayIndex.interval(0, -1) below
    if (timeSeriesLength == -1) {
        log.warn("Cannot do TBPTT: no rank 3 (time series) features or labels were provided");
        return;
    }
    int fwdLen = configuration.getTbpttFwdLength();
    //Number of TBPTT segments: round up so a partial final segment is still trained on
    int nSubsets = timeSeriesLength / fwdLen;
    if (timeSeriesLength % fwdLen != 0)
        nSubsets++;
    rnnClearPreviousState();
    INDArray[] newInputs = new INDArray[inputs.length];
    INDArray[] newLabels = new INDArray[labels.length];
    INDArray[] newFeatureMasks = (featureMasks != null ? new INDArray[featureMasks.length] : null);
    INDArray[] newLabelMasks = (labelMasks != null ? new INDArray[labelMasks.length] : null);
    for (int i = 0; i < nSubsets; i++) {
        int startTimeIdx = i * fwdLen;
        int endTimeIdx = startTimeIdx + fwdLen;
        //Clamp the last segment to the actual series length
        if (endTimeIdx > timeSeriesLength)
            endTimeIdx = timeSeriesLength;
        for (int j = 0; j < inputs.length; j++) {
            if (inputs[j].rank() != 3)
                newInputs[j] = inputs[j];
            else {
                newInputs[j] = inputs[j].get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx));
            }
        }
        for (int j = 0; j < labels.length; j++) {
            if (labels[j].rank() != 3)
                newLabels[j] = labels[j];
            else {
                newLabels[j] = labels[j].get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx));
            }
        }
        if (featureMasks != null) {
            for (int j = 0; j < featureMasks.length; j++) {
                if (featureMasks[j] == null)
                    continue;
                newFeatureMasks[j] = featureMasks[j].get(NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx));
            }
        }
        if (labelMasks != null) {
            for (int j = 0; j < labelMasks.length; j++) {
                if (labelMasks[j] == null)
                    continue;
                newLabelMasks[j] = labelMasks[j].get(NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx));
            }
        }
        setInputs(newInputs);
        setLabels(newLabels);
        setLayerMaskArrays(newFeatureMasks, newLabelMasks);
        if (solver == null) {
            solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
        }
        solver.optimize();
        //Finally, update the state of the RNN layers:
        rnnUpdateStateWithTBPTTState();
    }
    rnnClearPreviousState();
    if (featureMasks != null || labelMasks != null) {
        clearLayerMaskArrays();
    }
}
Also used : Solver(org.deeplearning4j.optimize.Solver) INDArray(org.nd4j.linalg.api.ndarray.INDArray)

Example 3 with Solver

Use of org.deeplearning4j.optimize.Solver in the project deeplearning4j by deeplearning4j.

In the class VariationalAutoencoder, the method fit:

/**
 * Fit this layer on the input that has already been set via setInput.
 *
 * @throws IllegalStateException if no input has been set
 */
@Override
public void fit() {
    if (input == null) {
        throw new IllegalStateException("Cannot fit layer: layer input is null (not set)");
    }
    if (solver == null) {
        solver = new Solver.Builder().model(this).configure(conf()).listeners(getListeners()).build();
        //Set the updater state view array. For MLN and CG, this is done by MultiLayerUpdater and ComputationGraphUpdater respectively
        Updater u = solver.getOptimizer().getUpdater();
        int stateSize = u.stateSizeForLayer(this);
        if (stateSize > 0) {
            u.setStateViewArray(this, Nd4j.createUninitialized(new int[] { 1, stateSize }, Nd4j.order()), true);
        }
    }
    this.optimizer = solver.getOptimizer();
    solver.optimize();
}
Also used : Solver(org.deeplearning4j.optimize.Solver) Updater(org.deeplearning4j.nn.api.Updater)

Example 4 with Solver

Use of org.deeplearning4j.optimize.Solver in the project deeplearning4j by deeplearning4j.

In the class BaseOutputLayer, the method fit:

/**
     * Fit the model
     *
     * @param input the examples to classify (one example in each row)
     * @param labels   the example labels(a binary outcome matrix)
     */
/**
 * Fit the model.
 *
 * @param input  the examples to classify (one example in each row)
 * @param labels the example labels (a binary outcome matrix)
 */
@Override
public void fit(INDArray input, INDArray labels) {
    setInput(input);
    setLabels(labels);
    applyDropOutIfNecessary(true);
    if (solver == null) {
        solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
        //Set the updater state view array. For MLN and CG, this is done by MultiLayerUpdater and ComputationGraphUpdater respectively
        Updater u = solver.getOptimizer().getUpdater();
        int stateSize = u.stateSizeForLayer(this);
        if (stateSize > 0) {
            u.setStateViewArray(this, Nd4j.createUninitialized(new int[] { 1, stateSize }, Nd4j.order()), true);
        }
    }
    solver.optimize();
}
Also used : Solver(org.deeplearning4j.optimize.Solver) Updater(org.deeplearning4j.nn.api.Updater)

Example 5 with Solver

Use of org.deeplearning4j.optimize.Solver in the project deeplearning4j by deeplearning4j.

In the class BaseLayer, the method fit:

/**
 * Fit this layer on the given input (duplicated before use so the caller's
 * array is not modified by dropout). A null input leaves any previously set
 * input in place.
 *
 * @param input the input to fit on, or null to use the current input
 */
@Override
public void fit(INDArray input) {
    if (input != null) {
        setInput(input.dup());
        applyDropOutIfNecessary(true);
    }
    if (solver == null) {
        solver = new Solver.Builder().model(this).configure(conf()).listeners(getListeners()).build();
        //Set the updater state view array. For MLN and CG, this is done by MultiLayerUpdater and ComputationGraphUpdater respectively
        Updater u = solver.getOptimizer().getUpdater();
        int stateSize = u.stateSizeForLayer(this);
        if (stateSize > 0) {
            u.setStateViewArray(this, Nd4j.createUninitialized(new int[] { 1, stateSize }, Nd4j.order()), true);
        }
    }
    this.optimizer = solver.getOptimizer();
    solver.optimize();
}
Also used : Solver(org.deeplearning4j.optimize.Solver) Updater(org.deeplearning4j.nn.api.Updater)

Aggregations

Solver (org.deeplearning4j.optimize.Solver)12 Updater (org.deeplearning4j.nn.api.Updater)4 AsyncMultiDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator)3 SingletonMultiDataSetIterator (org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator)3 INDArray (org.nd4j.linalg.api.ndarray.INDArray)3 MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator)3 AsyncDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncDataSetIterator)2 TrainingListener (org.deeplearning4j.optimize.api.TrainingListener)2 MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet)2 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)2 ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater)1 DataSet (org.nd4j.linalg.dataset.DataSet)1 DataSet (org.nd4j.linalg.dataset.api.DataSet)1