Use of org.deeplearning4j.optimize.Solver in the project deeplearning4j (by deeplearning4j).
From the class MultiLayerNetwork, method fit:
/**
 * Fit (train) the network using the given DataSetIterator. Depending on the layer-wise
 * configuration, this performs unsupervised layer-wise pretraining, backprop training,
 * or both (pretraining followed by backprop).
 *
 * @param iterator training data; wrapped in an AsyncDataSetIterator for background
 *                 prefetch when the underlying iterator supports async operation
 */
@Override
public void fit(DataSetIterator iterator) {
    DataSetIterator iter;
    // We're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
    if (iterator.asyncSupported()) {
        iter = new AsyncDataSetIterator(iterator, 2);
    } else {
        iter = iterator;
    }
    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochStart(this);
        }
    }
    if (layerWiseConfigurations.isPretrain()) {
        pretrain(iter);
        // Rewind so the backprop pass below (if enabled) sees the data from the start.
        if (iter.resetSupported()) {
            iter.reset();
        }
    }
    if (layerWiseConfigurations.isBackprop()) {
        update(TaskUtils.buildTask(iter));
        // If the iterator is already exhausted and supports reset, rewind it before training.
        if (!iter.hasNext() && iter.resetSupported()) {
            iter.reset();
        }
        while (iter.hasNext()) {
            DataSet next = iter.next();
            // Stop on an incomplete minibatch (no features or no labels).
            if (next.getFeatureMatrix() == null || next.getLabels() == null)
                break;
            boolean hasMaskArrays = next.hasMaskArrays();
            if (layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
                doTruncatedBPTT(next.getFeatureMatrix(), next.getLabels(), next.getFeaturesMaskArray(),
                        next.getLabelsMaskArray());
            } else {
                if (hasMaskArrays)
                    setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray());
                setInput(next.getFeatureMatrix());
                setLabels(next.getLabels());
                // Lazily create the solver on first use; reused across minibatches.
                if (solver == null) {
                    solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                }
                solver.optimize();
            }
            if (hasMaskArrays)
                clearLayerMaskArrays();
            Nd4j.getMemoryManager().invokeGcOccasionally();
        }
    } else if (layerWiseConfigurations.isPretrain()) {
        // Pretraining ran but backprop is disabled, so no supervised fine-tuning happens.
        log.warn("Warning: finetune is not applied.");
    }
    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochEnd(this);
        }
    }
}
Use of org.deeplearning4j.optimize.Solver in the project deeplearning4j (by deeplearning4j).
From the class ComputationGraph, method fit:
/**
 * Fit the ComputationGraph using the specified inputs and labels (and mask arrays).
 * Performs a single fitting step: optional pretraining on this one (multi-)data set,
 * then one backprop (or truncated-BPTT) optimization step if backprop is enabled.
 *
 * @param inputs The network inputs (features)
 * @param labels The network labels
 * @param featureMaskArrays Mask arrays for inputs/features. Typically used for RNN training. May be null.
 * @param labelMaskArrays Mask arrays for the labels/outputs. Typically used for RNN training. May be null.
 */
public void fit(INDArray[] inputs, INDArray[] labels, INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
// Ensure the flattened gradient view exists before any training step.
if (flattenedGradients == null)
initGradientsView();
setInputs(inputs);
setLabels(labels);
setLayerMaskArrays(featureMaskArrays, labelMaskArrays);
update(TaskUtils.buildTask(inputs, labels));
if (configuration.isPretrain()) {
// Wrap the single example set in an iterator, as pretrain() expects one.
MultiDataSetIterator iter = new SingletonMultiDataSetIterator(new org.nd4j.linalg.dataset.MultiDataSet(inputs, labels, featureMaskArrays, labelMaskArrays));
pretrain(iter);
}
if (configuration.isBackprop()) {
if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
doTruncatedBPTT(inputs, labels, featureMaskArrays, labelMaskArrays);
} else {
// Lazily create the solver on first use.
if (solver == null) {
solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
}
solver.optimize();
}
}
// Clear per-call mask state so it does not leak into subsequent fit/output calls.
if (featureMaskArrays != null || labelMaskArrays != null) {
clearLayerMaskArrays();
}
}
Use of org.deeplearning4j.optimize.Solver in the project deeplearning4j (by deeplearning4j).
From the class ComputationGraph, method getUpdater:
/**
 * Get the ComputationGraphUpdater for the network, lazily creating the solver
 * (and attaching a new updater to its optimizer) on first access.
 */
public ComputationGraphUpdater getUpdater() {
    if (solver != null) {
        return solver.getOptimizer().getComputationGraphUpdater();
    }
    // First access: build the solver and install a fresh updater on its optimizer.
    solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
    solver.getOptimizer().setUpdaterComputationGraph(new ComputationGraphUpdater(this));
    return solver.getOptimizer().getComputationGraphUpdater();
}
Use of org.deeplearning4j.optimize.Solver in the project deeplearning4j (by deeplearning4j).
From the class BaseLayer, method getOptimizer:
/**
 * Returns this layer's ConvexOptimizer, lazily constructing one via a Solver on first call.
 */
@Override
public ConvexOptimizer getOptimizer() {
    if (optimizer != null) {
        return optimizer;
    }
    // Build a Solver purely to obtain a configured optimizer instance; cache the result.
    this.optimizer = new Solver.Builder().model(this).configure(conf()).build().getOptimizer();
    return optimizer;
}
Use of org.deeplearning4j.optimize.Solver in the project deeplearning4j (by deeplearning4j).
From the class ComputationGraph, method fit:
/**
 * Fit the ComputationGraph using a DataSetIterator.
 * Note that this method can only be used with ComputationGraphs with 1 input and 1 output.
 * Performs optional pretraining, then iterates the data performing backprop
 * (or truncated BPTT) on each minibatch, notifying training listeners at epoch boundaries.
 */
public void fit(DataSetIterator iterator) {
// Ensure the flattened gradient view exists before any training step.
if (flattenedGradients == null)
initGradientsView();
if (numInputArrays != 1 || numOutputArrays != 1)
throw new UnsupportedOperationException("Cannot train ComputationGraph network with " + " multiple inputs or outputs using a DataSetIterator");
DataSetIterator dataSetIterator;
// we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
if (iterator.asyncSupported()) {
dataSetIterator = new AsyncDataSetIterator(iterator, 2);
} else
dataSetIterator = iterator;
if (trainingListeners.size() > 0) {
for (TrainingListener tl : trainingListeners) {
tl.onEpochStart(this);
}
}
if (configuration.isPretrain()) {
pretrain(dataSetIterator);
}
if (configuration.isBackprop()) {
update(TaskUtils.buildTask(dataSetIterator));
while (dataSetIterator.hasNext()) {
DataSet next = dataSetIterator.next();
// Stop on an incomplete minibatch (no features or no labels).
if (next.getFeatures() == null || next.getLabels() == null)
break;
boolean hasMaskArrays = next.hasMaskArrays();
if (hasMaskArrays) {
// DataSet carries at most one feature/label mask; lift each into a 1-element array for the graph API.
INDArray[] fMask = (next.getFeaturesMaskArray() != null ? new INDArray[] { next.getFeaturesMaskArray() } : null);
INDArray[] lMask = (next.getLabelsMaskArray() != null ? new INDArray[] { next.getLabelsMaskArray() } : null);
setLayerMaskArrays(fMask, lMask);
}
if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
doTruncatedBPTT(new INDArray[] { next.getFeatures() }, new INDArray[] { next.getLabels() }, (hasMaskArrays ? new INDArray[] { next.getFeaturesMaskArray() } : null), (hasMaskArrays ? new INDArray[] { next.getLabelsMaskArray() } : null));
} else {
// Single-input/single-output network: index 0 is the only slot.
setInput(0, next.getFeatures());
setLabel(0, next.getLabels());
// Lazily create the solver on first use; reused across minibatches.
if (solver == null) {
solver = //TODO; don't like this
new Solver.Builder().configure(defaultConfiguration).listeners(listeners).model(this).build();
}
solver.optimize();
}
// Clear per-minibatch mask state so it does not leak into the next iteration.
if (hasMaskArrays) {
clearLayerMaskArrays();
}
Nd4j.getMemoryManager().invokeGcOccasionally();
}
}
if (trainingListeners.size() > 0) {
for (TrainingListener tl : trainingListeners) {
tl.onEpochEnd(this);
}
}
}
Aggregations