Use of org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator in the project deeplearning4j by deeplearning4j.
In the class ComputationGraph, method fit:
/**
 * Fit the ComputationGraph using the specified inputs and labels (and mask arrays).
 * <p>
 * If the configuration enables pretraining, layerwise pretraining is run on this single batch first.
 * If backprop is enabled, training then proceeds with either truncated BPTT or the standard
 * solver, depending on the configured backprop type. Mask arrays supplied to this call are
 * cleared afterwards, even when training throws, so they cannot affect subsequent calls on this network.
 *
 * @param inputs            The network inputs (features)
 * @param labels            The network labels
 * @param featureMaskArrays Mask arrays for inputs/features. Typically used for RNN training. May be null.
 * @param labelMaskArrays   Mask arrays for the labels/outputs. Typically used for RNN training. May be null.
 */
public void fit(INDArray[] inputs, INDArray[] labels, INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
    if (flattenedGradients == null) {
        // Allocate the flattened gradients view if it has not been initialized yet
        initGradientsView();
    }
    setInputs(inputs);
    setLabels(labels);
    setLayerMaskArrays(featureMaskArrays, labelMaskArrays);

    try {
        update(TaskUtils.buildTask(inputs, labels));
        if (configuration.isPretrain()) {
            // pretrain() consumes an iterator, so wrap this single batch as a one-element iterator
            MultiDataSetIterator iter = new SingletonMultiDataSetIterator(
                    new org.nd4j.linalg.dataset.MultiDataSet(inputs, labels, featureMaskArrays, labelMaskArrays));
            pretrain(iter);
        }
        if (configuration.isBackprop()) {
            if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                doTruncatedBPTT(inputs, labels, featureMaskArrays, labelMaskArrays);
            } else {
                // Build the solver lazily and reuse it across fit() calls
                if (solver == null) {
                    solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                }
                solver.optimize();
            }
        }
    } finally {
        // Clear masks even if training failed, so stale masks do not leak into later calls
        if (featureMaskArrays != null || labelMaskArrays != null) {
            clearLayerMaskArrays();
        }
    }
}
Aggregations