Example use of org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator in the Apache Ignite project.
Shown below: the train method of the class LocalBatchTrainer.
/**
 * {@inheritDoc}
 * <p>
 * Runs at most {@code maxIterations} update steps starting from {@code data.mdl()}.
 * Each step pulls a fresh (input, truth) batch from {@code data.batchSupplier()},
 * asks the update calculator for new parameters, applies them to the model and then
 * evaluates the updated model on the same batch. The per-column loss values produced
 * through {@code MatrixUtil.zipFoldByColumns} are summed and divided by the input's
 * column count; when that value drops below {@code errorThreshold} the loop stops early.
 *
 * @param data Supplies the initial model and the batch supplier.
 * @return The model as it stands after the last applied update, or the initial model
 *      if no iteration was executed.
 */
@Override
public M train(LocalBatchTrainerInput<M> data) {
    M model = data.mdl();

    ParameterUpdateCalculator<? super M, P> calculator = updaterSupplier.get();
    P params = calculator.init(model, loss);

    for (int iteration = 0; iteration < maxIterations; iteration++) {
        IgniteBiTuple<Matrix, Matrix> sample = data.batchSupplier().get();
        Matrix features = sample.get1();
        Matrix groundTruth = sample.get2();

        // Compute the next update from the current model, then apply it.
        params = calculator.calculateNewUpdate(model, params, iteration, features, groundTruth);
        model = calculator.update(model, params);

        // Evaluate the freshly updated model on the batch it was just trained on.
        Matrix output = model.apply(features);
        int columnCount = features.columnSize();

        double err = MatrixUtil.zipFoldByColumns(
            output,
            groundTruth,
            (outCol, expectedCol) -> loss.apply(expectedCol).apply(outCol)
        ).sum() / columnCount;

        debug("Error: " + err);

        // Good enough: stop before exhausting the iteration budget.
        if (err < errorThreshold) {
            break;
        }
    }

    return model;
}
Aggregations