Example 1 with ParameterUpdateCalculator

Use of org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator in the project ignite by apache.

Class LocalBatchTrainer, method train:

/**
 * {@inheritDoc}
 */
@Override
public M train(LocalBatchTrainerInput<M> data) {
    int i = 0;
    M mdl = data.mdl();
    double err;

    // Obtain an updater for this training run and initialize its parameters from the model and loss.
    ParameterUpdateCalculator<? super M, P> updater = updaterSupplier.get();
    P updaterParams = updater.init(mdl, loss);

    while (i < maxIterations) {
        // Each batch is a pair (input, ground truth); each column is one sample.
        IgniteBiTuple<Matrix, Matrix> batch = data.batchSupplier().get();
        Matrix input = batch.get1();
        Matrix truth = batch.get2();

        // Compute the update for iteration i from the current model and this batch.
        updaterParams = updater.calculateNewUpdate(mdl, updaterParams, i, input, truth);

        // Update mdl with updater parameters.
        mdl = updater.update(mdl, updaterParams);

        // Mean loss of the updated model on this batch: loss per column (sample), summed, divided by batch size.
        Matrix predicted = mdl.apply(input);
        int batchSize = input.columnSize();
        err = MatrixUtil.zipFoldByColumns(predicted, truth,
            (predCol, truthCol) -> loss.apply(truthCol).apply(predCol)).sum() / batchSize;

        debug("Error: " + err);

        // Stop early once the batch error falls below the threshold.
        if (err < errorThreshold)
            break;

        i++;
    }

    return mdl;
}
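The loop above follows a common pattern: initialize the updater state, then for each batch compute an update, apply it, measure the mean loss of the updated model, and stop early below a threshold. The sketch below reproduces only that control flow under stated assumptions. `LinearModel`, `UpdateCalculator` and `SgdCalculator` are hypothetical stand-ins, not Ignite classes; plain arrays replace `Matrix` (one element per sample instead of one column per sample), squared error replaces the configurable `loss`, and the updater is plain gradient descent rather than any particular Ignite update calculator.

```java
import java.util.Random;
import java.util.function.Supplier;

/** Hypothetical, self-contained sketch of the mini-batch loop; none of these types are Ignite APIs. */
public class BatchLoopSketch {
    /** Hypothetical model y = w * x with a single weight. */
    public record LinearModel(double w) {
        public double apply(double x) { return w * x; }
    }

    /** Hypothetical analogue of an update calculator: init state, compute an update, apply it. */
    public interface UpdateCalculator<M, P> {
        P init(M mdl);
        P calculateNewUpdate(M mdl, P params, int iteration, double[] input, double[] truth);
        M update(M mdl, P params);
    }

    /** Plain gradient descent on mean squared error; the parameter P is the weight delta. */
    public static final class SgdCalculator implements UpdateCalculator<LinearModel, Double> {
        private final double learningRate;

        public SgdCalculator(double learningRate) { this.learningRate = learningRate; }

        public Double init(LinearModel mdl) { return 0.0; }

        public Double calculateNewUpdate(LinearModel mdl, Double params, int iteration,
                                         double[] input, double[] truth) {
            double grad = 0;
            // d/dw of mean (w*x - y)^2 is mean 2 * (w*x - y) * x.
            for (int j = 0; j < input.length; j++)
                grad += 2 * (mdl.apply(input[j]) - truth[j]) * input[j];
            return -learningRate * grad / input.length;
        }

        public LinearModel update(LinearModel mdl, Double delta) { return new LinearModel(mdl.w() + delta); }
    }

    /** Same control flow as train(): update, re-predict, average the loss, stop early below threshold. */
    public static LinearModel train(LinearModel mdl, Supplier<double[][]> batches,
                                    UpdateCalculator<LinearModel, Double> updater,
                                    int maxIterations, double errorThreshold) {
        Double params = updater.init(mdl);
        for (int i = 0; i < maxIterations; i++) {
            double[][] batch = batches.get();
            double[] input = batch[0], truth = batch[1];
            params = updater.calculateNewUpdate(mdl, params, i, input, truth);
            mdl = updater.update(mdl, params);
            double err = 0;
            for (int j = 0; j < input.length; j++) {
                double d = mdl.apply(input[j]) - truth[j];
                err += d * d;
            }
            err /= input.length; // mean loss over the batch
            if (err < errorThreshold)
                break;
        }
        return mdl;
    }

    public static void main(String[] args) {
        Random rnd = new Random(42);
        // Batches of 8 samples drawn from y = 3x.
        Supplier<double[][]> batches = () -> {
            double[] x = new double[8], y = new double[8];
            for (int j = 0; j < 8; j++) { x[j] = rnd.nextDouble(); y[j] = 3 * x[j]; }
            return new double[][] {x, y};
        };
        LinearModel fitted = train(new LinearModel(0), batches, new SgdCalculator(0.5), 1000, 1e-8);
        System.out.println("w = " + fitted.w());
    }
}
```

Running `main` should print a weight close to 3.0. Returning a new model from `update` mirrors the original's `mdl = updater.update(mdl, updaterParams)`, which reassigns the model rather than mutating it in place.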
Also used: Trainer (org.apache.ignite.ml.Trainer), ParameterUpdateCalculator (org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator), Vector (org.apache.ignite.ml.math.Vector), IgniteFunction (org.apache.ignite.ml.math.functions.IgniteFunction), IgniteSupplier (org.apache.ignite.ml.math.functions.IgniteSupplier), Model (org.apache.ignite.ml.Model), Matrix (org.apache.ignite.ml.math.Matrix), IgniteLogger (org.apache.ignite.IgniteLogger), IgniteDifferentiableVectorToDoubleFunction (org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction), MatrixUtil (org.apache.ignite.ml.math.util.MatrixUtil), IgniteBiTuple (org.apache.ignite.lang.IgniteBiTuple)

Aggregations

IgniteLogger (org.apache.ignite.IgniteLogger): 1, IgniteBiTuple (org.apache.ignite.lang.IgniteBiTuple): 1, Model (org.apache.ignite.ml.Model): 1, Trainer (org.apache.ignite.ml.Trainer): 1, Matrix (org.apache.ignite.ml.math.Matrix): 1, Vector (org.apache.ignite.ml.math.Vector): 1, IgniteDifferentiableVectorToDoubleFunction (org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction): 1, IgniteFunction (org.apache.ignite.ml.math.functions.IgniteFunction): 1, IgniteSupplier (org.apache.ignite.ml.math.functions.IgniteSupplier): 1, MatrixUtil (org.apache.ignite.ml.math.util.MatrixUtil): 1, ParameterUpdateCalculator (org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator): 1