use of org.deeplearning4j.optimize.Solver in project deeplearning4j by deeplearning4j.
the class ComputationGraph method fit.
/**
 * Fit the ComputationGraph using a MultiDataSetIterator.
 * <p>
 * When the supplied iterator reports async support it is wrapped in an
 * {@link AsyncMultiDataSetIterator} (constructed with the argument 2 - presumably the
 * prefetch queue size; confirm against that class). Pretraining runs first if the
 * configuration enables it, followed by backprop training if that is enabled.
 *
 * @param multi iterator supplying the MultiDataSets to train on
 */
public void fit(MultiDataSetIterator multi) {
    // The flattened gradient view must exist before any optimisation step.
    if (flattenedGradients == null) {
        initGradientsView();
    }

    final MultiDataSetIterator dataSource =
            multi.asyncSupported() ? new AsyncMultiDataSetIterator(multi, 2) : multi;

    if (configuration.isPretrain()) {
        pretrain(dataSource);
    }

    if (!configuration.isBackprop()) {
        return;
    }

    while (dataSource.hasNext()) {
        final MultiDataSet batch = dataSource.next();

        // A batch without features or without labels ends training altogether.
        if (next_isIncomplete(batch.getFeatures() == null, batch.getLabels() == null)) {
            break;
        }

        if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
            doTruncatedBPTT(batch.getFeatures(), batch.getLabels(),
                    batch.getFeaturesMaskArrays(), batch.getLabelsMaskArrays());
        } else {
            final boolean masked = batch.hasMaskArrays();
            if (masked) {
                setLayerMaskArrays(batch.getFeaturesMaskArrays(), batch.getLabelsMaskArrays());
            }

            setInputs(batch.getFeatures());
            setLabels(batch.getLabels());

            // The solver is created lazily on the first standard-backprop batch and then reused.
            if (solver == null) {
                solver = new Solver.Builder()
                        .configure(defaultConfiguration)
                        .listeners(listeners)
                        .model(this)
                        .build();
            }
            solver.optimize();

            // Masks apply to this batch only; remove them before the next one.
            if (masked) {
                clearLayerMaskArrays();
            }
        }

        Nd4j.getMemoryManager().invokeGcOccasionally();
    }
}

/** Returns true when either the features or the labels of a batch are missing. */
private static boolean next_isIncomplete(boolean featuresMissing, boolean labelsMissing) {
    return featuresMissing || labelsMissing;
}
use of org.deeplearning4j.optimize.Solver in project deeplearning4j by deeplearning4j.
the class ComputationGraph method doTruncatedBPTT.
/**
 * Fit the network using truncated BPTT.
 * <p>
 * Rank-3 inputs and labels are split along dimension 2 into consecutive windows of at most
 * {@code configuration.getTbpttFwdLength()} steps; arrays of any other rank are passed through
 * unchanged for every window. Non-null mask arrays are split along dimension 1 with the same
 * window. The solver is run once per window, after which the RNN state is carried forward via
 * {@code rnnUpdateStateWithTBPTTState()}.
 * <p>
 * If the rank-3 arrays do not all share the same size along dimension 2, a warning is logged
 * and the method returns without training.
 *
 * @param inputs       network inputs
 * @param labels       network labels
 * @param featureMasks feature mask arrays; may be null, and individual entries may be null
 * @param labelMasks   label mask arrays; may be null, and individual entries may be null
 */
protected void doTruncatedBPTT(INDArray[] inputs, INDArray[] labels, INDArray[] featureMasks, INDArray[] labelMasks) {
    if (flattenedGradients == null) {
        initGradientsView();
    }

    // Determine the common time series length from every rank-3 array (inputs first, then labels).
    int timeSeriesLength = -1;
    for (INDArray in : inputs) {
        if (in.rank() == 3) {
            if (timeSeriesLength == -1) {
                timeSeriesLength = in.size(2);
            } else if (timeSeriesLength != in.size(2)) {
                log.warn("Cannot do TBPTT with time series of different lengths");
                return;
            }
        }
    }
    for (INDArray out : labels) {
        if (out.rank() == 3) {
            if (timeSeriesLength == -1) {
                timeSeriesLength = out.size(2);
            } else if (timeSeriesLength != out.size(2)) {
                log.warn("Cannot do TBPTT with time series of different lengths");
                return;
            }
        }
    }

    // Number of windows = ceil(timeSeriesLength / fwdLen) for positive lengths.
    // NOTE(review): when no array is rank 3, timeSeriesLength is still -1 here; the arithmetic
    // below then gives 1 window for fwdLen > 1 but -1 (no training at all) for fwdLen == 1.
    // Confirm whether that difference is intended.
    final int fwdLen = configuration.getTbpttFwdLength();
    final int nSubsets = timeSeriesLength / fwdLen + (timeSeriesLength % fwdLen != 0 ? 1 : 0);

    rnnClearPreviousState();

    final INDArray[] windowInputs = new INDArray[inputs.length];
    final INDArray[] windowLabels = new INDArray[labels.length];
    final INDArray[] windowFeatureMasks = (featureMasks == null ? null : new INDArray[featureMasks.length]);
    final INDArray[] windowLabelMasks = (labelMasks == null ? null : new INDArray[labelMasks.length]);

    for (int window = 0; window < nSubsets; window++) {
        final int from = window * fwdLen;
        // The last window may be shorter than fwdLen.
        final int to = Math.min(from + fwdLen, timeSeriesLength);

        for (int k = 0; k < inputs.length; k++) {
            windowInputs[k] = (inputs[k].rank() == 3)
                    ? inputs[k].get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(from, to))
                    : inputs[k];
        }

        for (int k = 0; k < labels.length; k++) {
            windowLabels[k] = (labels[k].rank() == 3)
                    ? labels[k].get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(from, to))
                    : labels[k];
        }

        if (featureMasks != null) {
            for (int k = 0; k < featureMasks.length; k++) {
                // Entries without a mask keep their null slot.
                if (featureMasks[k] != null) {
                    windowFeatureMasks[k] = featureMasks[k].get(NDArrayIndex.all(), NDArrayIndex.interval(from, to));
                }
            }
        }

        if (labelMasks != null) {
            for (int k = 0; k < labelMasks.length; k++) {
                if (labelMasks[k] != null) {
                    windowLabelMasks[k] = labelMasks[k].get(NDArrayIndex.all(), NDArrayIndex.interval(from, to));
                }
            }
        }

        setInputs(windowInputs);
        setLabels(windowLabels);
        setLayerMaskArrays(windowFeatureMasks, windowLabelMasks);

        // The solver is created lazily on the first window and then reused.
        if (solver == null) {
            solver = new Solver.Builder()
                    .configure(conf())
                    .listeners(getListeners())
                    .model(this)
                    .build();
        }
        solver.optimize();

        // Finally, update the state of the RNN layers so the next window continues from here.
        rnnUpdateStateWithTBPTTState();
    }

    rnnClearPreviousState();
    if (featureMasks != null || labelMasks != null) {
        clearLayerMaskArrays();
    }
}
use of org.deeplearning4j.optimize.Solver in project deeplearning4j by deeplearning4j.
the class VariationalAutoencoder method fit.
/**
 * Fit this layer on the input that has already been set.
 * <p>
 * On first use a Solver is built for this layer and, if the updater needs any state, an
 * uninitialized {@code [1, stateSize]} array is attached as its state view.
 *
 * @throws IllegalStateException if the layer input has not been set
 */
@Override
public void fit() {
    if (input == null) {
        throw new IllegalStateException("Cannot fit layer: layer input is null (not set)");
    }

    if (solver == null) {
        solver = new Solver.Builder()
                .model(this)
                .configure(conf())
                .listeners(getListeners())
                .build();

        // Set the updater state view array here. For MLN and CG, this is done by
        // MultiLayerUpdater and ComputationGraphUpdater respectively.
        final Updater layerUpdater = solver.getOptimizer().getUpdater();
        final int stateSize = layerUpdater.stateSizeForLayer(this);
        if (stateSize > 0) {
            final int[] stateShape = new int[] { 1, stateSize };
            layerUpdater.setStateViewArray(this, Nd4j.createUninitialized(stateShape, Nd4j.order()), true);
        }
    }

    this.optimizer = solver.getOptimizer();
    solver.optimize();
}
use of org.deeplearning4j.optimize.Solver in project deeplearning4j by deeplearning4j.
the class BaseOutputLayer method fit.
/**
 * Fit the model.
 * <p>
 * Sets the input and labels, applies dropout if necessary, then runs the solver. On first use
 * a Solver is built for this layer and, if the updater needs any state, an uninitialized
 * {@code [1, stateSize]} array is attached as its state view.
 *
 * @param input  the examples to classify (one example in each row)
 * @param labels the example labels (a binary outcome matrix)
 */
@Override
public void fit(INDArray input, INDArray labels) {
    setInput(input);
    setLabels(labels);
    applyDropOutIfNecessary(true);

    if (solver == null) {
        solver = new Solver.Builder()
                .configure(conf())
                .listeners(getListeners())
                .model(this)
                .build();

        // Set the updater state view array here. For MLN and CG, this is done by
        // MultiLayerUpdater and ComputationGraphUpdater respectively.
        final Updater layerUpdater = solver.getOptimizer().getUpdater();
        final int stateSize = layerUpdater.stateSizeForLayer(this);
        if (stateSize > 0) {
            final INDArray stateView = Nd4j.createUninitialized(new int[] { 1, stateSize }, Nd4j.order());
            layerUpdater.setStateViewArray(this, stateView, true);
        }
    }

    solver.optimize();
}
use of org.deeplearning4j.optimize.Solver in project deeplearning4j by deeplearning4j.
the class BaseLayer method fit.
/**
 * Fit this layer on the given input.
 * <p>
 * A non-null input is duplicated before being set, and dropout is then applied if necessary;
 * a null input leaves the layer's current input untouched. On first use a Solver is built for
 * this layer and, if the updater needs any state, an uninitialized {@code [1, stateSize]}
 * array is attached as its state view.
 *
 * @param input the data to fit on, or null to skip setting the input
 */
@Override
public void fit(INDArray input) {
    if (input != null) {
        setInput(input.dup());
        applyDropOutIfNecessary(true);
    }

    if (solver == null) {
        solver = new Solver.Builder()
                .model(this)
                .configure(conf())
                .listeners(getListeners())
                .build();

        // Set the updater state view array here. For MLN and CG, this is done by
        // MultiLayerUpdater and ComputationGraphUpdater respectively.
        final Updater layerUpdater = solver.getOptimizer().getUpdater();
        final int stateSize = layerUpdater.stateSizeForLayer(this);
        if (stateSize > 0) {
            final INDArray stateView = Nd4j.createUninitialized(new int[] { 1, stateSize }, Nd4j.order());
            layerUpdater.setStateViewArray(this, stateView, true);
        }
    }

    this.optimizer = solver.getOptimizer();
    solver.optimize();
}
Aggregations