Usage example of org.deeplearning4j.optimize.Solver in the deeplearning4j project.
From the class MultiLayerNetwork, method doTruncatedBPTT.
/**
 * Fit the network on a time series using truncated backpropagation through time (TBPTT):
 * the series is split along the time dimension into subsets of at most tbpttFwdLength
 * steps, each subset is fit in turn, and the RNN hidden state is carried forward between
 * subsets.
 *
 * @param input             rank-3 features, shape [miniBatchSize, nIn, timeSeriesLength]
 * @param labels            rank-3 labels with the same time-series length as the input
 * @param featuresMaskArray optional feature mask [miniBatchSize, timeSeriesLength], may be null
 * @param labelsMaskArray   optional label mask [miniBatchSize, timeSeriesLength], may be null
 */
protected void doTruncatedBPTT(INDArray input, INDArray labels, INDArray featuresMaskArray, INDArray labelsMaskArray) {
    // TBPTT is only defined for rank-3 (time series) features and labels; warn and bail out otherwise.
    if (input.rank() != 3 || labels.rank() != 3) {
        log.warn("Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got " + Arrays.toString(input.shape()) + "\tand labels with shape " + Arrays.toString(labels.shape()));
        return;
    }
    if (input.size(2) != labels.size(2)) {
        log.warn("Input and label time series have different lengths: {} input length, {} label length", input.size(2), labels.size(2));
        return;
    }

    int fwdLen = layerWiseConfigurations.getTbpttFwdLength();
    update(TaskUtils.buildTask(input, labels));
    int timeSeriesLength = input.size(2);

    // Ceiling division: a trailing, shorter subset covers any remainder.
    // Example: fwdLen=100 with timeSeriesLength=120 -> 2 subsets (one of length 100, one of length 20)
    int nSubsets = (timeSeriesLength + fwdLen - 1) / fwdLen;

    rnnClearPreviousState();

    for (int i = 0; i < nSubsets; i++) {
        int startTimeIdx = i * fwdLen;
        // Clamp the final subset to the end of the series.
        int endTimeIdx = Math.min(startTimeIdx + fwdLen, timeSeriesLength);

        setInput(input.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx)));
        setLabels(labels.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx)));

        // Slice the masks (when present) to the same time window as the features/labels.
        INDArray featuresMaskSubset = featuresMaskArray == null ? null
                        : featuresMaskArray.get(NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx));
        INDArray labelsMaskSubset = labelsMaskArray == null ? null
                        : labelsMaskArray.get(NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx));
        if (featuresMaskSubset != null || labelsMaskSubset != null)
            setLayerMaskArrays(featuresMaskSubset, labelsMaskSubset);

        // Lazily create the solver on first use, reusing it for all subsequent subsets.
        if (solver == null) {
            solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
        }
        solver.optimize();

        // Finally, update the state of the RNN layers so the next subset continues from here.
        updateRnnStateWithTBPTTState();
    }

    rnnClearPreviousState();
    if (featuresMaskArray != null || labelsMaskArray != null)
        clearLayerMaskArrays();
}
Usage example of org.deeplearning4j.optimize.Solver in the deeplearning4j project.
From the class LossLayer, method fit.
/**
 * Fit the model on the given examples and labels.
 *
 * @param input  the examples to classify (one example in each row)
 * @param labels the example labels (a binary outcome matrix)
 */
@Override
public void fit(INDArray input, INDArray labels) {
    setInput(input);
    setLabels(labels);
    applyDropOutIfNecessary(true);

    if (solver == null) {
        // Lazily build the solver on first fit.
        solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
        // Set the updater state view array. For MLN and CG, this is done by
        // MultiLayerUpdater and ComputationGraphUpdater respectively.
        Updater updater = solver.getOptimizer().getUpdater();
        int stateSize = updater.stateSizeForLayer(this);
        if (stateSize > 0) {
            updater.setStateViewArray(this, Nd4j.createUninitialized(new int[] {1, stateSize}, Nd4j.order()), true);
        }
    }

    solver.optimize();
}
Aggregations