Search in sources:

Example 1 with WeightedPredictionsAggregator

Use of org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator in the Apache Ignite project.

From the class GDBTrainerTest, method testClassifier:

/**
 * Trains a GDB binary classifier on a synthetic step-function dataset (label flips
 * between -1 and +1 every ten points) and checks that the resulting composition
 * classifies the whole training set correctly, is built from decision trees, and
 * stops early under the mean-abs-value convergence check but not under the stub.
 *
 * @param fitter Fits the given trainer on the given learning sample and returns the model.
 */
private void testClassifier(BiFunction<GDBTrainer, Map<Integer, double[]>, IgniteModel<Vector, Double>> fitter) {
    int sampleSize = 100;

    // Build the learning sample directly: feature x = i, label alternates every 10 points.
    Map<Integer, double[]> learningSample = new HashMap<>();
    for (int i = 0; i < sampleSize; i++) {
        double x = i;
        double y = ((int) (x / 10.0) % 2) == 0 ? -1.0 : 1.0;
        learningSample.put(i, new double[] { x, y });
    }

    GDBTrainer trainer = new GDBBinaryClassifierOnTreesTrainer(0.3, 500, 3, 0.0)
        .withUsingIdx(true)
        .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.3));

    IgniteModel<Vector, Double> mdl = fitter.apply(trainer, learningSample);

    // The trained ensemble must reproduce every training label exactly.
    int mispredicted = 0;
    for (double[] sample : learningSample.values()) {
        if (mdl.predict(VectorUtils.of(sample[0])) != sample[1])
            mispredicted++;
    }
    assertEquals(0, mispredicted);

    assertTrue(mdl instanceof ModelsComposition);
    ModelsComposition composition = (ModelsComposition) mdl;
    composition.getModels().forEach(m -> assertTrue(m instanceof DecisionTreeModel));

    // Convergence check should have stopped training before all 500 iterations.
    assertTrue(composition.getModels().size() < 500);
    assertTrue(composition.getPredictionsAggregator() instanceof WeightedPredictionsAggregator);

    // With the stub checker (never converges) the full iteration budget is used.
    trainer = trainer.withCheckConvergenceStgyFactory(new ConvergenceCheckerStubFactory());
    assertEquals(500, ((ModelsComposition) fitter.apply(trainer, learningSample)).getModels().size());
}
Also used : HashMap(java.util.HashMap) DecisionTreeModel(org.apache.ignite.ml.tree.DecisionTreeModel) WeightedPredictionsAggregator(org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator) ModelsComposition(org.apache.ignite.ml.composition.ModelsComposition) GDBBinaryClassifierOnTreesTrainer(org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer) MeanAbsValueConvergenceCheckerFactory(org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory) ConvergenceCheckerStubFactory(org.apache.ignite.ml.composition.boosting.convergence.simple.ConvergenceCheckerStubFactory) Vector(org.apache.ignite.ml.math.primitives.vector.Vector)

Example 2 with WeightedPredictionsAggregator

Use of org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator in the Apache Ignite project.

From the class GDBLearningStrategy, method update:

/**
 * Takes the state of the model given in the arguments and compares it with the training parameters
 * of this trainer; if they fit, the trainer updates the model according to the new data and returns
 * the updated model. Otherwise a new model is trained from scratch.
 *
 * @param mdlToUpdate Learned model.
 * @param datasetBuilder Dataset builder.
 * @param preprocessor Upstream preprocessor.
 * @param <K> Type of a key in {@code upstream} data.
 * @param <V> Type of a value in {@code upstream} data.
 * @return Updated models list.
 */
public <K, V> List<IgniteModel<Vector, Double>> update(GDBModel mdlToUpdate, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
    if (trainerEnvironment == null)
        throw new IllegalStateException("Learning environment builder is not set.");
    // Seed the ensemble with models already learned by mdlToUpdate (empty list if null).
    List<IgniteModel<Vector, Double>> models = initLearningState(mdlToUpdate);
    ConvergenceChecker<K, V> convCheck = checkConvergenceStgyFactory.create(sampleSize, externalLbToInternalMapping, loss, datasetBuilder, preprocessor);
    DatasetTrainer<? extends IgniteModel<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get();
    for (int i = 0; i < cntOfIterations; i++) {
        // Weights array is trimmed to the models trained so far; meanLbVal is the aggregator bias.
        double[] weights = Arrays.copyOf(compositionWeights, models.size());
        WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLbVal);
        ModelsComposition currComposition = new ModelsComposition(models, aggregator);
        // Convergence is checked BEFORE fitting, so the loop can exit without adding a model.
        if (convCheck.isConverged(envBuilder, datasetBuilder, currComposition))
            break;
        // Each sample is relabeled with the negative loss gradient of the current composition's
        // answer (the pseudo-residual); the next base model is fit on these labels.
        // Note: the anonymous adapter captures the effectively-final currComposition of THIS iteration.
        Vectorizer<K, V, Serializable, Double> extractor = new Vectorizer.VectorizerAdapter<K, V, Serializable, Double>() {

            /**
             * {@inheritDoc}
             */
            @Override
            public LabeledVector<Double> extract(K k, V v) {
                LabeledVector<Double> labeledVector = preprocessor.apply(k, v);
                Vector features = labeledVector.features();
                Double realAnswer = externalLbToInternalMapping.apply(labeledVector.label());
                Double mdlAnswer = currComposition.predict(features);
                return new LabeledVector<>(features, -loss.gradient(sampleSize, realAnswer, mdlAnswer));
            }
        };
        long startTs = System.currentTimeMillis();
        models.add(trainer.fit(datasetBuilder, extractor));
        double learningTime = (double) (System.currentTimeMillis() - startTs) / 1000.0;
        trainerEnvironment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
    }
    return models;
}
Also used : Serializable(java.io.Serializable) WeightedPredictionsAggregator(org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator) ModelsComposition(org.apache.ignite.ml.composition.ModelsComposition) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) IgniteModel(org.apache.ignite.ml.IgniteModel) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) LabeledVector(org.apache.ignite.ml.structures.LabeledVector)

Example 3 with WeightedPredictionsAggregator

Use of org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator in the Apache Ignite project.

From the class GDBOnTreesLearningStrategy, method update:

/**
 * {@inheritDoc}
 *
 * Tree-specific GDB update: builds a {@code DecisionTreeData} dataset once and fits
 * every boosting iteration directly on it, mutating the partition labels in place
 * with the loss pseudo-residuals instead of re-extracting vectors each round.
 */
@Override
public <K, V> List<IgniteModel<Vector, Double>> update(GDBModel mdlToUpdate, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> vectorizer) {
    LearningEnvironment environment = envBuilder.buildForTrainer();
    environment.initDeployingContext(vectorizer);
    DatasetTrainer<? extends IgniteModel<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get();
    // This strategy relies on DecisionTreeTrainer's ability to fit on a prebuilt dataset.
    assert trainer instanceof DecisionTreeTrainer;
    DecisionTreeTrainer decisionTreeTrainer = (DecisionTreeTrainer) trainer;
    // Seed the ensemble with models already learned by mdlToUpdate (empty list if null).
    List<IgniteModel<Vector, Double>> models = initLearningState(mdlToUpdate);
    ConvergenceChecker<K, V> convCheck = checkConvergenceStgyFactory.create(sampleSize, externalLbToInternalMapping, loss, datasetBuilder, vectorizer);
    try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(envBuilder, new EmptyContextBuilder<>(), new DecisionTreeDataBuilder<>(vectorizer, useIdx), environment)) {
        for (int i = 0; i < cntOfIterations; i++) {
            // Weights array is trimmed to the models trained so far; meanLbVal is the aggregator bias.
            double[] weights = Arrays.copyOf(compositionWeights, models.size());
            WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLbVal);
            ModelsComposition currComposition = new ModelsComposition(models, aggregator);
            // Convergence is checked BEFORE fitting, so the loop can exit without adding a model.
            if (convCheck.isConverged(dataset, currComposition))
                break;
            dataset.compute(part -> {
                // Preserve the original labels once; subsequent iterations overwrite
                // part.getLabels() with pseudo-residuals, so gradients must be computed
                // against the saved originals, not the mutated labels.
                if (part.getCopiedOriginalLabels() == null)
                    part.setCopiedOriginalLabels(Arrays.copyOf(part.getLabels(), part.getLabels().length));
                for (int j = 0; j < part.getLabels().length; j++) {
                    double mdlAnswer = currComposition.predict(VectorUtils.of(part.getFeatures()[j]));
                    double originalLbVal = externalLbToInternalMapping.apply(part.getCopiedOriginalLabels()[j]);
                    part.getLabels()[j] = -loss.gradient(sampleSize, originalLbVal, mdlAnswer);
                }
            });
            long startTs = System.currentTimeMillis();
            models.add(decisionTreeTrainer.fit(dataset));
            double learningTime = (double) (System.currentTimeMillis() - startTs) / 1000.0;
            trainerEnvironment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
        }
    } catch (Exception e) {
        // Dataset#close is declared to throw Exception; wrap anything from build/compute/close.
        throw new RuntimeException(e);
    }
    // Trim the weight vector to the number of models actually trained (early convergence
    // may have stopped the loop before cntOfIterations).
    compositionWeights = Arrays.copyOf(compositionWeights, models.size());
    return models;
}
Also used : EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) WeightedPredictionsAggregator(org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator) ModelsComposition(org.apache.ignite.ml.composition.ModelsComposition) DecisionTreeTrainer(org.apache.ignite.ml.tree.DecisionTreeTrainer) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) DecisionTreeData(org.apache.ignite.ml.tree.data.DecisionTreeData) IgniteModel(org.apache.ignite.ml.IgniteModel) Vector(org.apache.ignite.ml.math.primitives.vector.Vector)

Example 4 with WeightedPredictionsAggregator

Use of org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator in the Apache Ignite project.

From the class SparkModelParser, method parseAndBuildGDBModel:

/**
 * Parses and builds a common GDB model with the custom label mapper.
 *
 * @param pathToMdl Path to the model parquet file.
 * @param pathToMdlMetaData Path to the model metadata parquet file (per-tree weights).
 * @param lbMapper Label mapper.
 * @param learningEnvironment Learning environment used for logging.
 * @return Parsed GDB model, or {@code null} if either parquet file could not be read.
 */
@Nullable
private static Model parseAndBuildGDBModel(String pathToMdl, String pathToMdlMetaData, IgniteFunction<Double, Double> lbMapper, LearningEnvironment learningEnvironment) {
    final Map<Integer, Double> treeWeightsByTreeID = new HashMap<>();
    // First pass: read per-tree weights from the metadata file.
    try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdlMetaData), new Configuration()))) {
        PageReadStore pagesMetaData;
        final MessageType schema = r.getFooter().getFileMetaData().getSchema();
        final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
        while (null != (pagesMetaData = r.readNextRowGroup())) {
            final long rows = pagesMetaData.getRowCount();
            final RecordReader recordReader = colIO.getRecordReader(pagesMetaData, new GroupRecordConverter(schema));
            for (int i = 0; i < rows; i++) {
                final SimpleGroup g = (SimpleGroup) recordReader.read();
                // Column 0 is the tree id, column 2 is the tree weight.
                int treeId = g.getInteger(0, 0);
                double treeWeight = g.getDouble(2, 0);
                treeWeightsByTreeID.put(treeId, treeWeight);
            }
        }
    } catch (IOException e) {
        String msg = "Error reading parquet file with MetaData by the path: " + pathToMdlMetaData;
        learningEnvironment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
        e.printStackTrace();
        // FIX: without the tree weights the aggregator cannot be built correctly;
        // abort instead of silently constructing a model with an empty weight vector.
        return null;
    }
    // Flatten the weight map into an array indexed by tree id.
    // NOTE(review): assumes tree ids are contiguous 0..n-1; a gap would cause an
    // unboxing NPE on get(i) — TODO confirm against the Spark export format.
    final double[] treeWeights = new double[treeWeightsByTreeID.size()];
    for (int i = 0; i < treeWeights.length; i++) treeWeights[i] = treeWeightsByTreeID.get(i);
    // Second pass: read the tree nodes themselves and group them by tree id.
    try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
        PageReadStore pages;
        final MessageType schema = r.getFooter().getFileMetaData().getSchema();
        final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
        // TreeMap keeps trees ordered by id so models line up with treeWeights.
        final Map<Integer, TreeMap<Integer, NodeData>> nodesByTreeId = new TreeMap<>();
        while (null != (pages = r.readNextRowGroup())) {
            final long rows = pages.getRowCount();
            final RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));
            for (int i = 0; i < rows; i++) {
                final SimpleGroup g = (SimpleGroup) recordReader.read();
                final int treeID = g.getInteger(0, 0);
                final SimpleGroup nodeDataGroup = (SimpleGroup) g.getGroup(1, 0);
                NodeData nodeData = extractNodeDataFromParquetRow(nodeDataGroup);
                // computeIfAbsent collapses the original containsKey/get/put branching.
                nodesByTreeId.computeIfAbsent(treeID, id -> new TreeMap<>()).put(nodeData.id, nodeData);
            }
        }
        final List<IgniteModel<Vector, Double>> models = new ArrayList<>();
        nodesByTreeId.forEach((key, nodes) -> models.add(buildDecisionTreeModel(nodes)));
        return new GDBModel(models, new WeightedPredictionsAggregator(treeWeights), lbMapper);
    } catch (IOException e) {
        String msg = "Error reading parquet file: " + e.getMessage();
        learningEnvironment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
        e.printStackTrace();
    }
    return null;
}
Also used : Configuration(org.apache.hadoop.conf.Configuration) HashMap(java.util.HashMap) RecordReader(org.apache.parquet.io.RecordReader) ArrayList(java.util.ArrayList) GDBModel(org.apache.ignite.ml.composition.boosting.GDBModel) SimpleGroup(org.apache.parquet.example.data.simple.SimpleGroup) MessageColumnIO(org.apache.parquet.io.MessageColumnIO) PageReadStore(org.apache.parquet.column.page.PageReadStore) MessageType(org.apache.parquet.schema.MessageType) Path(org.apache.hadoop.fs.Path) GroupRecordConverter(org.apache.parquet.example.data.simple.convert.GroupRecordConverter) ParquetFileReader(org.apache.parquet.hadoop.ParquetFileReader) WeightedPredictionsAggregator(org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator) IOException(java.io.IOException) TreeMap(java.util.TreeMap) ColumnIOFactory(org.apache.parquet.io.ColumnIOFactory) NodeData(org.apache.ignite.ml.tree.NodeData) IgniteModel(org.apache.ignite.ml.IgniteModel) Nullable(org.jetbrains.annotations.Nullable)

Example 5 with WeightedPredictionsAggregator

Use of org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator in the Apache Ignite project.

From the class GDBLearningStrategy, method initLearningState:

/**
 * Restores the state of an already-learned model, if one is given, and sets the
 * trainer's learning parameters (bias and composition weights) according to that state.
 *
 * @param mdlToUpdate Model to update (may be {@code null} for training from scratch).
 * @return List of already learned models (empty when starting fresh).
 */
@NotNull
protected List<IgniteModel<Vector, Double>> initLearningState(GDBModel mdlToUpdate) {
    List<IgniteModel<Vector, Double>> alreadyLearned = new ArrayList<>();

    if (mdlToUpdate == null)
        compositionWeights = new double[cntOfIterations];
    else {
        alreadyLearned.addAll(mdlToUpdate.getModels());

        WeightedPredictionsAggregator aggr = (WeightedPredictionsAggregator) mdlToUpdate.getPredictionsAggregator();
        meanLbVal = aggr.getBias();

        // Carry over the weights of the restored models; leave room for new iterations.
        compositionWeights = new double[alreadyLearned.size() + cntOfIterations];
        System.arraycopy(aggr.getWeights(), 0, compositionWeights, 0, alreadyLearned.size());
    }

    // Slots for models yet to be trained start at the default gradient step size.
    Arrays.fill(compositionWeights, alreadyLearned.size(), compositionWeights.length, defaultGradStepSize);

    return alreadyLearned;
}
Also used : ArrayList(java.util.ArrayList) WeightedPredictionsAggregator(org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator) IgniteModel(org.apache.ignite.ml.IgniteModel) NotNull(org.jetbrains.annotations.NotNull)

Aggregations

WeightedPredictionsAggregator (org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator)7 IgniteModel (org.apache.ignite.ml.IgniteModel)5 ModelsComposition (org.apache.ignite.ml.composition.ModelsComposition)4 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)4 HashMap (java.util.HashMap)3 ArrayList (java.util.ArrayList)2 MeanAbsValueConvergenceCheckerFactory (org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory)2 DecisionTreeModel (org.apache.ignite.ml.tree.DecisionTreeModel)2 IOException (java.io.IOException)1 Serializable (java.io.Serializable)1 TreeMap (java.util.TreeMap)1 Configuration (org.apache.hadoop.conf.Configuration)1 Path (org.apache.hadoop.fs.Path)1 TrainerTest (org.apache.ignite.ml.common.TrainerTest)1 GDBModel (org.apache.ignite.ml.composition.boosting.GDBModel)1 ConvergenceCheckerStubFactory (org.apache.ignite.ml.composition.boosting.convergence.simple.ConvergenceCheckerStubFactory)1 DoubleArrayVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer)1 EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext)1 LearningEnvironment (org.apache.ignite.ml.environment.LearningEnvironment)1 LabeledVector (org.apache.ignite.ml.structures.LabeledVector)1