Search in sources:

Example 1 with IgniteModel

use of org.apache.ignite.ml.IgniteModel in project ignite by apache.

Class GDBLearningStrategy, method update.

/**
 * Gets the state of the model passed in the arguments and compares it with the training parameters of the trainer;
 * if they fit, the trainer updates the model according to the new data and returns the new model. Otherwise a new
 * model is trained.
 *
 * <p>On every iteration the models trained so far are combined into a weighted composition. If the convergence
 * checker reports convergence the loop stops early; otherwise one more base model is fitted on labels replaced by
 * the negative loss gradient of the current composition.
 *
 * @param mdlToUpdate Learned model.
 * @param datasetBuilder Dataset builder.
 * @param preprocessor Upstream preprocessor.
 * @param <K> Type of a key in {@code upstream} data.
 * @param <V> Type of a value in {@code upstream} data.
 * @return Updated models list.
 * @throws IllegalStateException If the trainer learning environment is not set.
 */
public <K, V> List<IgniteModel<Vector, Double>> update(GDBModel mdlToUpdate, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
    if (trainerEnvironment == null)
        throw new IllegalStateException("Learning environment builder is not set.");
    List<IgniteModel<Vector, Double>> models = initLearningState(mdlToUpdate);
    ConvergenceChecker<K, V> convCheck = checkConvergenceStgyFactory.create(sampleSize, externalLbToInternalMapping, loss, datasetBuilder, preprocessor);
    DatasetTrainer<? extends IgniteModel<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get();
    for (int i = 0; i < cntOfIterations; i++) {
        // One weight per model trained so far (Arrays.copyOf truncates or zero-pads compositionWeights).
        double[] weights = Arrays.copyOf(compositionWeights, models.size());
        WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLbVal);
        ModelsComposition currComposition = new ModelsComposition(models, aggregator);
        if (convCheck.isConverged(envBuilder, datasetBuilder, currComposition))
            break;
        // Re-labels every upstream entry with the negative loss gradient of the current composition,
        // so that the next base model is fitted to it.
        Vectorizer<K, V, Serializable, Double> extractor = new Vectorizer.VectorizerAdapter<K, V, Serializable, Double>() {

            /**
             * {@inheritDoc}
             */
            @Override
            public LabeledVector<Double> extract(K k, V v) {
                LabeledVector<Double> labeledVector = preprocessor.apply(k, v);
                Vector features = labeledVector.features();
                Double realAnswer = externalLbToInternalMapping.apply(labeledVector.label());
                Double mdlAnswer = currComposition.predict(features);
                return new LabeledVector<>(features, -loss.gradient(sampleSize, realAnswer, mdlAnswer));
            }
        };
        // nanoTime is monotonic, unlike currentTimeMillis, so wall-clock adjustments cannot skew the measurement.
        long startNs = System.nanoTime();
        models.add(trainer.fit(datasetBuilder, extractor));
        double learningTime = (System.nanoTime() - startNs) / 1e9;
        trainerEnvironment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
    }
    return models;
}
Also used: Serializable(java.io.Serializable) WeightedPredictionsAggregator(org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator) ModelsComposition(org.apache.ignite.ml.composition.ModelsComposition) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) IgniteModel(org.apache.ignite.ml.IgniteModel) Vector(org.apache.ignite.ml.math.primitives.vector.Vector)

Example 2 with IgniteModel

use of org.apache.ignite.ml.IgniteModel in project ignite by apache.

Class BinaryClassificationMetricsTest, method testCalculation.

/**
 * Checks accuracy, precision, recall and F-measure computed by the binary classification evaluator for three
 * hand-written models evaluated against the XOR truth table.
 */
@Test
public void testCalculation() {
    // XOR truth table: features -> label. Plain puts instead of double-brace initialization, which would create
    // an anonymous HashMap subclass holding a reference to the test instance.
    Map<Vector, Double> xorset = new HashMap<>();
    xorset.put(VectorUtils.of(0., 0.), 0.);
    xorset.put(VectorUtils.of(0., 1.), 1.);
    xorset.put(VectorUtils.of(1., 0.), 1.);
    xorset.put(VectorUtils.of(1., 1.), 0.);

    // Reproduces XOR exactly: predicts 0 when both features are (almost) equal, 1 otherwise.
    IgniteModel<Vector, Double> xorFunction = v -> {
        if (Math.abs(v.get(0) - v.get(1)) < 0.01)
            return 0.;
        else
            return 1.;
    };
    // Predicts 1 only for (1, 1), whose label is 0: TP = 0, FP = 1, TN = 1, FN = 2.
    IgniteModel<Vector, Double> andFunction = v -> {
        if (Math.abs(v.get(0) - v.get(1)) < 0.01 && v.get(0) > 0)
            return 1.;
        else
            return 0.;
    };
    // Predicts 1 for every input except (0, 0): TP = 2, FP = 1, TN = 1, FN = 0.
    IgniteModel<Vector, Double> orFunction = v -> {
        if (v.get(0) > 0 || v.get(1) > 0)
            return 1.;
        else
            return 0.;
    };

    EvaluationResult xorResult = Evaluator.evaluateBinaryClassification(xorset, xorFunction, Vector::labeled);
    assertEquals(1., xorResult.get(MetricName.ACCURACY), 0.01);
    assertEquals(1., xorResult.get(MetricName.PRECISION), 0.01);
    assertEquals(1., xorResult.get(MetricName.RECALL), 0.01);
    assertEquals(1., xorResult.get(MetricName.F_MEASURE), 0.01);

    EvaluationResult andResult = Evaluator.evaluateBinaryClassification(xorset, andFunction, Vector::labeled);
    // Only TN is correct: 1 of 4.
    assertEquals(0.25, andResult.get(MetricName.ACCURACY), 0.01);
    // There is no TP: precision = 0 / (0 + 1).
    assertEquals(0., andResult.get(MetricName.PRECISION), 0.01);
    // There is no TP: recall = 0 / (0 + 2).
    assertEquals(0., andResult.get(MetricName.RECALL), 0.01);
    // Precision + recall == 0, so F-measure is 0 / 0 = NaN (JUnit's double assertEquals treats NaN as equal to NaN).
    assertEquals(Double.NaN, andResult.get(MetricName.F_MEASURE), 0.01);

    EvaluationResult orResult = Evaluator.evaluateBinaryClassification(xorset, orFunction, Vector::labeled);
    // TP + TN = 3 of 4.
    assertEquals(0.75, orResult.get(MetricName.ACCURACY), 0.01);
    // Precision = TP / (TP + FP) = 2 / 3.
    assertEquals(2. / 3., orResult.get(MetricName.PRECISION), 0.01);
    // Recall = TP / (TP + FN) = 2 / 2.
    assertEquals(1., orResult.get(MetricName.RECALL), 0.01);
    // F-measure = 2 * P * R / (P + R) = 2 * (2/3) * 1 / (5/3) = 0.8.
    assertEquals(0.8, orResult.get(MetricName.F_MEASURE), 0.01);
}
Also used: VectorUtils(org.apache.ignite.ml.math.primitives.vector.VectorUtils) Evaluator(org.apache.ignite.ml.selection.scoring.evaluator.Evaluator) MetricName(org.apache.ignite.ml.selection.scoring.metric.MetricName) Map(java.util.Map) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) EvaluationResult(org.apache.ignite.ml.selection.scoring.evaluator.EvaluationResult) HashMap(java.util.HashMap) Test(org.junit.Test) IgniteModel(org.apache.ignite.ml.IgniteModel) Assert.assertEquals(org.junit.Assert.assertEquals)

Example 3 with IgniteModel

use of org.apache.ignite.ml.IgniteModel in project ignite by apache.

Class RegressionMetricsTest, method testCalculation.

/**
 * Checks MAE, MSE, R2, RSS and RMSE computed by the regression evaluator for an exact model and for a model that
 * squares its input, both evaluated against the identity mapping on x = 0..3.
 */
@Test
public void testCalculation() {
    // Identity mapping x -> x. Plain puts instead of double-brace initialization, which would create an anonymous
    // HashMap subclass holding a reference to the test instance.
    Map<Vector, Double> linearSet = new HashMap<>();
    linearSet.put(VectorUtils.of(0.), 0.);
    linearSet.put(VectorUtils.of(1.), 1.);
    linearSet.put(VectorUtils.of(2.), 2.);
    linearSet.put(VectorUtils.of(3.), 3.);

    IgniteModel<Vector, Double> linearModel = v -> v.get(0);
    IgniteModel<Vector, Double> squareModel = v -> Math.pow(v.get(0), 2);

    // The linear model reproduces every label, so every error metric is zero.
    EvaluationResult linearRes = Evaluator.evaluateRegression(linearSet, linearModel, Vector::labeled);
    assertEquals(0., linearRes.get(MetricName.MAE), 0.01);
    assertEquals(0., linearRes.get(MetricName.MSE), 0.01);
    assertEquals(0., linearRes.get(MetricName.R2), 0.01);
    assertEquals(0., linearRes.get(MetricName.RSS), 0.01);
    assertEquals(0., linearRes.get(MetricName.RMSE), 0.01);

    // Square model residuals for x = 0..3 are 0, 0, 2, 6:
    // MAE = 8 / 4 = 2, RSS = 4 + 36 = 40, MSE = 40 / 4 = 10, RMSE = sqrt(10).
    EvaluationResult squareRes = Evaluator.evaluateRegression(linearSet, squareModel, Vector::labeled);
    assertEquals(2., squareRes.get(MetricName.MAE), 0.01);
    assertEquals(10., squareRes.get(MetricName.MSE), 0.01);
    // NOTE(review): both expected R2 values equal RSS / TSS (TSS = 5 for labels 0..3): 0 / 5 = 0 and 40 / 5 = 8.
    // Under the textbook definition R2 = 1 - RSS / TSS they would be 1 and -7; confirm which definition
    // MetricName.R2 is meant to implement.
    assertEquals(8., squareRes.get(MetricName.R2), 0.01);
    assertEquals(40., squareRes.get(MetricName.RSS), 0.01);
    assertEquals(Math.sqrt(10), squareRes.get(MetricName.RMSE), 0.01);
}
Also used: VectorUtils(org.apache.ignite.ml.math.primitives.vector.VectorUtils) Evaluator(org.apache.ignite.ml.selection.scoring.evaluator.Evaluator) MetricName(org.apache.ignite.ml.selection.scoring.metric.MetricName) Map(java.util.Map) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) EvaluationResult(org.apache.ignite.ml.selection.scoring.evaluator.EvaluationResult) HashMap(java.util.HashMap) Test(org.junit.Test) IgniteModel(org.apache.ignite.ml.IgniteModel) Assert.assertEquals(org.junit.Assert.assertEquals)

Example 4 with IgniteModel

use of org.apache.ignite.ml.IgniteModel in project ignite by apache.

Class ModelStorageExample, method main.

/**
 * Runs the example.
 *
 * <p>Starts an Ignite node, saves an identity model together with its descriptor, lists and reloads the stored
 * descriptors, builds an inference model from the reloaded descriptor and prints its predictions for 0..9.
 *
 * @param args Command line arguments (not used).
 * @throws IOException Declared for the serialization helpers used by the example.
 * @throws ClassNotFoundException Declared for the deserialization helper used by the example.
 */
public static void main(String... args) throws IOException, ClassNotFoundException {
    try (Ignite ignite = Ignition.start("examples/config/example-ignite-ml.xml")) {
        System.out.println(">>> Ignite grid started.");

        ModelStorage mdlStorage = new ModelStorageFactory().getModelStorage(ignite);
        ModelDescriptorStorage mdlDescStorage = new ModelDescriptorStorageFactory().getModelDescriptorStorage(ignite);

        // Store the serialized identity model as a file.
        System.out.println("Saving model into model storage...");
        IgniteModel<byte[], byte[]> identityMdl = i -> i;
        byte[] serializedMdl = serialize(identityMdl);
        mdlStorage.mkdirs("/");
        mdlStorage.putFile("/my_model", serializedMdl);

        // Register a descriptor that points at the stored file.
        System.out.println("Saving model descriptor into model descriptor storage...");
        ModelDescriptor savedDesc = new ModelDescriptor(
            "MyModel",
            "My Cool Model",
            new ModelSignature("", "", ""),
            new ModelStorageModelReader("/my_model"),
            new IgniteModelParser<>()
        );
        mdlDescStorage.put("my_model", savedDesc);

        System.out.println("List saved models...");
        for (IgniteBiTuple<String, ModelDescriptor> entry : mdlDescStorage) {
            System.out.println("-> {'" + entry.getKey() + "' : " + entry.getValue() + "}");
        }

        System.out.println("Load saved model descriptor...");
        ModelDescriptor loadedDesc = mdlDescStorage.get("my_model");

        System.out.println("Build inference model...");
        SingleModelBuilder builder = new SingleModelBuilder();
        try (Model<byte[], byte[]> infMdl = builder.build(loadedDesc.getReader(), loadedDesc.getParser())) {
            System.out.println("Make inference...");
            int input = 0;
            while (input < 10) {
                Integer prediction = deserialize(infMdl.predict(serialize(input)));
                System.out.println(input + " -> " + prediction);
                input++;
            }
        }
    } finally {
        System.out.flush();
    }
}
Also used: ModelStorageModelReader(org.apache.ignite.ml.inference.reader.ModelStorageModelReader) ByteArrayOutputStream(java.io.ByteArrayOutputStream) ModelDescriptor(org.apache.ignite.ml.inference.ModelDescriptor) ObjectInputStream(java.io.ObjectInputStream) IOException(java.io.IOException) ModelSignature(org.apache.ignite.ml.inference.ModelSignature) Ignite(org.apache.ignite.Ignite) IgniteModel(org.apache.ignite.ml.IgniteModel) ModelDescriptorStorageFactory(org.apache.ignite.ml.inference.storage.descriptor.ModelDescriptorStorageFactory) Serializable(java.io.Serializable) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) SingleModelBuilder(org.apache.ignite.ml.inference.builder.SingleModelBuilder) Ignition(org.apache.ignite.Ignition) ModelStorage(org.apache.ignite.ml.inference.storage.model.ModelStorage) ByteArrayInputStream(java.io.ByteArrayInputStream) Model(org.apache.ignite.ml.inference.Model) IgniteModelParser(org.apache.ignite.ml.inference.parser.IgniteModelParser) ModelDescriptorStorage(org.apache.ignite.ml.inference.storage.descriptor.ModelDescriptorStorage) ObjectOutputStream(java.io.ObjectOutputStream) ModelStorageFactory(org.apache.ignite.ml.inference.storage.model.ModelStorageFactory)

Example 5 with IgniteModel

use of org.apache.ignite.ml.IgniteModel in project ignite by apache.

Class GDBOnTreesLearningStrategy, method update.

/**
 * {@inheritDoc}
 *
 * <p>This strategy requires the base model trainer to be a {@link DecisionTreeTrainer}. The dataset is built once
 * and reused across iterations. On each iteration the partition labels are overwritten in place with the negative
 * loss gradient of the current composition, and a new tree is fitted on the same dataset. The original labels are
 * copied aside the first time a partition is visited.
 *
 * @throws IllegalStateException If the base model trainer is not a {@link DecisionTreeTrainer}.
 */
@Override
public <K, V> List<IgniteModel<Vector, Double>> update(GDBModel mdlToUpdate, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> vectorizer) {
    LearningEnvironment environment = envBuilder.buildForTrainer();
    environment.initDeployingContext(vectorizer);
    DatasetTrainer<? extends IgniteModel<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get();
    // Explicit check instead of `assert`: assertions are usually disabled in production, where a wrong trainer
    // type would otherwise fail with a bare ClassCastException (or, for null, a NullPointerException later on).
    if (!(trainer instanceof DecisionTreeTrainer)) {
        throw new IllegalStateException("GDB on trees requires a DecisionTreeTrainer as the base model trainer, got: "
            + (trainer == null ? "null" : trainer.getClass().getName()));
    }
    DecisionTreeTrainer decisionTreeTrainer = (DecisionTreeTrainer) trainer;
    List<IgniteModel<Vector, Double>> models = initLearningState(mdlToUpdate);
    ConvergenceChecker<K, V> convCheck = checkConvergenceStgyFactory.create(sampleSize, externalLbToInternalMapping, loss, datasetBuilder, vectorizer);
    try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(envBuilder, new EmptyContextBuilder<>(), new DecisionTreeDataBuilder<>(vectorizer, useIdx), environment)) {
        for (int i = 0; i < cntOfIterations; i++) {
            // One weight per model trained so far (Arrays.copyOf truncates or zero-pads compositionWeights).
            double[] weights = Arrays.copyOf(compositionWeights, models.size());
            WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLbVal);
            ModelsComposition currComposition = new ModelsComposition(models, aggregator);
            if (convCheck.isConverged(dataset, currComposition))
                break;
            dataset.compute(part -> {
                // Keep the original labels, because the loop below overwrites part.getLabels() in place.
                if (part.getCopiedOriginalLabels() == null)
                    part.setCopiedOriginalLabels(Arrays.copyOf(part.getLabels(), part.getLabels().length));
                for (int j = 0; j < part.getLabels().length; j++) {
                    double mdlAnswer = currComposition.predict(VectorUtils.of(part.getFeatures()[j]));
                    double originalLbVal = externalLbToInternalMapping.apply(part.getCopiedOriginalLabels()[j]);
                    part.getLabels()[j] = -loss.gradient(sampleSize, originalLbVal, mdlAnswer);
                }
            });
            // nanoTime is monotonic, unlike currentTimeMillis, so wall-clock adjustments cannot skew the measurement.
            long startNs = System.nanoTime();
            models.add(decisionTreeTrainer.fit(dataset));
            double learningTime = (System.nanoTime() - startNs) / 1e9;
            trainerEnvironment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
        }
    } catch (RuntimeException e) {
        // Unchecked failures (e.g. from training) propagate as-is instead of being wrapped a second time.
        throw e;
    } catch (Exception e) {
        // Checked exceptions (presumably from Dataset#close, declared by AutoCloseable) are wrapped as before.
        throw new RuntimeException(e);
    }
    compositionWeights = Arrays.copyOf(compositionWeights, models.size());
    return models;
}
Also used : EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) WeightedPredictionsAggregator(org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator) ModelsComposition(org.apache.ignite.ml.composition.ModelsComposition) DecisionTreeTrainer(org.apache.ignite.ml.tree.DecisionTreeTrainer) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) DecisionTreeData(org.apache.ignite.ml.tree.data.DecisionTreeData) IgniteModel(org.apache.ignite.ml.IgniteModel) Vector(org.apache.ignite.ml.math.primitives.vector.Vector)

Aggregations

IgniteModel (org.apache.ignite.ml.IgniteModel)10 WeightedPredictionsAggregator (org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator)5 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)5 IOException (java.io.IOException)3 ArrayList (java.util.ArrayList)3 HashMap (java.util.HashMap)3 Map (java.util.Map)3 Serializable (java.io.Serializable)2 TreeMap (java.util.TreeMap)2 Configuration (org.apache.hadoop.conf.Configuration)2 Path (org.apache.hadoop.fs.Path)2 ModelsComposition (org.apache.ignite.ml.composition.ModelsComposition)2 EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext)2 VectorUtils (org.apache.ignite.ml.math.primitives.vector.VectorUtils)2 EvaluationResult (org.apache.ignite.ml.selection.scoring.evaluator.EvaluationResult)2 Evaluator (org.apache.ignite.ml.selection.scoring.evaluator.Evaluator)2 MetricName (org.apache.ignite.ml.selection.scoring.metric.MetricName)2 NodeData (org.apache.ignite.ml.tree.NodeData)2 PageReadStore (org.apache.parquet.column.page.PageReadStore)2 SimpleGroup (org.apache.parquet.example.data.simple.SimpleGroup)2