Search in sources :

Example 6 with IgniteModel

use of org.apache.ignite.ml.IgniteModel in project ignite by apache.

Source: class SparkModelParser, method parseTreesForRandomForestAlgorithm.

/**
 * Parses trees from a Parquet file for a common Random Forest ensemble.
 *
 * @param pathToMdl Path to model.
 * @param learningEnvironment Learning environment used for logging.
 * @return List of decision tree models, one per tree id in ascending order, or {@code null}
 *         if the Parquet file could not be read (the failure is logged).
 */
private static List<IgniteModel<Vector, Double>> parseTreesForRandomForestAlgorithm(String pathToMdl, LearningEnvironment learningEnvironment) {
    try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
        PageReadStore pages;
        final MessageType schema = r.getFooter().getFileMetaData().getSchema();
        final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
        // TreeMap keeps trees sorted by tree id so the resulting model list order is deterministic.
        final Map<Integer, TreeMap<Integer, NodeData>> nodesByTreeId = new TreeMap<>();
        while (null != (pages = r.readNextRowGroup())) {
            final long rows = pages.getRowCount();
            final RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));
            for (int i = 0; i < rows; i++) {
                final SimpleGroup g = (SimpleGroup) recordReader.read();
                // Column 0: tree id; column 1: nested group holding the node payload.
                final int treeID = g.getInteger(0, 0);
                final SimpleGroup nodeDataGroup = (SimpleGroup) g.getGroup(1, 0);
                NodeData nodeData = extractNodeDataFromParquetRow(nodeDataGroup);
                // Single map lookup instead of the containsKey/get/put triple.
                nodesByTreeId.computeIfAbsent(treeID, id -> new TreeMap<>()).put(nodeData.id, nodeData);
            }
        }
        List<IgniteModel<Vector, Double>> models = new ArrayList<>();
        nodesByTreeId.forEach((key, nodes) -> models.add(buildDecisionTreeModel(nodes)));
        return models;
    } catch (IOException e) {
        // Report through the ML logger; do not dump the stack trace to stderr.
        String msg = "Error reading parquet file: " + e.getMessage();
        learningEnvironment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
    }
    return null;
}
Also used : Path(org.apache.hadoop.fs.Path) GroupRecordConverter(org.apache.parquet.example.data.simple.convert.GroupRecordConverter) Configuration(org.apache.hadoop.conf.Configuration) ParquetFileReader(org.apache.parquet.hadoop.ParquetFileReader) RecordReader(org.apache.parquet.io.RecordReader) ArrayList(java.util.ArrayList) SimpleGroup(org.apache.parquet.example.data.simple.SimpleGroup) IOException(java.io.IOException) TreeMap(java.util.TreeMap) MessageColumnIO(org.apache.parquet.io.MessageColumnIO) ColumnIOFactory(org.apache.parquet.io.ColumnIOFactory) NodeData(org.apache.ignite.ml.tree.NodeData) PageReadStore(org.apache.parquet.column.page.PageReadStore) IgniteModel(org.apache.ignite.ml.IgniteModel) MessageType(org.apache.parquet.schema.MessageType)

Example 7 with IgniteModel

use of org.apache.ignite.ml.IgniteModel in project ignite by apache.

Source: class SparkModelParser, method parseAndBuildGDBModel.

/**
 * Parses and builds a common GDB model with the custom label mapper.
 *
 * @param pathToMdl Path to model.
 * @param pathToMdlMetaData Path to model meta data (per-tree weights).
 * @param lbMapper Label mapper.
 * @param learningEnvironment Learning environment used for logging.
 * @return The assembled GDB model, or {@code null} if either Parquet file could not be read
 *         (the failure is logged).
 */
@Nullable
private static Model parseAndBuildGDBModel(String pathToMdl, String pathToMdlMetaData, IgniteFunction<Double, Double> lbMapper, LearningEnvironment learningEnvironment) {
    final Map<Integer, Double> treeWeightsByTreeID = new HashMap<>();
    // First pass: read the per-tree weights from the meta data file.
    try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdlMetaData), new Configuration()))) {
        PageReadStore pagesMetaData;
        final MessageType schema = r.getFooter().getFileMetaData().getSchema();
        final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
        while (null != (pagesMetaData = r.readNextRowGroup())) {
            final long rows = pagesMetaData.getRowCount();
            final RecordReader recordReader = colIO.getRecordReader(pagesMetaData, new GroupRecordConverter(schema));
            for (int i = 0; i < rows; i++) {
                final SimpleGroup g = (SimpleGroup) recordReader.read();
                // Column 0: tree id; column 2: tree weight.
                int treeId = g.getInteger(0, 0);
                double treeWeight = g.getDouble(2, 0);
                treeWeightsByTreeID.put(treeId, treeWeight);
            }
        }
    } catch (IOException e) {
        String msg = "Error reading parquet file with MetaData by the path: " + pathToMdlMetaData;
        learningEnvironment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
        // Without weights the ensemble cannot be assembled correctly: fail fast instead of
        // building a GDBModel whose weights array does not match its models list.
        return null;
    }
    double[] treeWeights = new double[treeWeightsByTreeID.size()];
    for (int i = 0; i < treeWeights.length; i++) {
        // Tree ids are expected to be contiguous and zero-based; fail with a clear message
        // rather than an opaque auto-unboxing NPE if that assumption is violated.
        Double treeWeight = treeWeightsByTreeID.get(i);
        if (treeWeight == null)
            throw new IllegalStateException("Missing weight for tree with id " + i);
        treeWeights[i] = treeWeight;
    }
    // Second pass: read tree nodes and group them by tree id (sorted for determinism).
    try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
        PageReadStore pages;
        final MessageType schema = r.getFooter().getFileMetaData().getSchema();
        final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
        final Map<Integer, TreeMap<Integer, NodeData>> nodesByTreeId = new TreeMap<>();
        while (null != (pages = r.readNextRowGroup())) {
            final long rows = pages.getRowCount();
            final RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));
            for (int i = 0; i < rows; i++) {
                final SimpleGroup g = (SimpleGroup) recordReader.read();
                final int treeID = g.getInteger(0, 0);
                final SimpleGroup nodeDataGroup = (SimpleGroup) g.getGroup(1, 0);
                NodeData nodeData = extractNodeDataFromParquetRow(nodeDataGroup);
                // Single map lookup instead of the containsKey/get/put triple.
                nodesByTreeId.computeIfAbsent(treeID, id -> new TreeMap<>()).put(nodeData.id, nodeData);
            }
        }
        final List<IgniteModel<Vector, Double>> models = new ArrayList<>();
        nodesByTreeId.forEach((key, nodes) -> models.add(buildDecisionTreeModel(nodes)));
        return new GDBModel(models, new WeightedPredictionsAggregator(treeWeights), lbMapper);
    } catch (IOException e) {
        // Report through the ML logger; do not dump the stack trace to stderr.
        String msg = "Error reading parquet file: " + e.getMessage();
        learningEnvironment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
    }
    return null;
}
Also used : Configuration(org.apache.hadoop.conf.Configuration) HashMap(java.util.HashMap) RecordReader(org.apache.parquet.io.RecordReader) ArrayList(java.util.ArrayList) GDBModel(org.apache.ignite.ml.composition.boosting.GDBModel) SimpleGroup(org.apache.parquet.example.data.simple.SimpleGroup) MessageColumnIO(org.apache.parquet.io.MessageColumnIO) PageReadStore(org.apache.parquet.column.page.PageReadStore) MessageType(org.apache.parquet.schema.MessageType) Path(org.apache.hadoop.fs.Path) GroupRecordConverter(org.apache.parquet.example.data.simple.convert.GroupRecordConverter) ParquetFileReader(org.apache.parquet.hadoop.ParquetFileReader) WeightedPredictionsAggregator(org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator) IOException(java.io.IOException) TreeMap(java.util.TreeMap) ColumnIOFactory(org.apache.parquet.io.ColumnIOFactory) NodeData(org.apache.ignite.ml.tree.NodeData) IgniteModel(org.apache.ignite.ml.IgniteModel) Nullable(org.jetbrains.annotations.Nullable)

Example 8 with IgniteModel

use of org.apache.ignite.ml.IgniteModel in project ignite by apache.

Source: class GDBLearningStrategy, method initLearningState.

/**
 * Restores the state of an already-learned model, if one is supplied, and sizes the
 * composition weights accordingly.
 *
 * @param mdlToUpdate Model to update; may be {@code null} when training from scratch.
 * @return List of previously learned models (empty when {@code mdlToUpdate} is {@code null}).
 */
@NotNull
protected List<IgniteModel<Vector, Double>> initLearningState(GDBModel mdlToUpdate) {
    List<IgniteModel<Vector, Double>> restored = new ArrayList<>();
    if (mdlToUpdate == null)
        compositionWeights = new double[cntOfIterations];
    else {
        restored.addAll(mdlToUpdate.getModels());
        int restoredCnt = restored.size();
        WeightedPredictionsAggregator aggr = (WeightedPredictionsAggregator) mdlToUpdate.getPredictionsAggregator();
        meanLbVal = aggr.getBias();
        // Room for the restored weights plus one weight per upcoming iteration.
        compositionWeights = new double[restoredCnt + cntOfIterations];
        System.arraycopy(aggr.getWeights(), 0, compositionWeights, 0, restoredCnt);
    }
    // Not-yet-trained slots start at the default gradient step size.
    Arrays.fill(compositionWeights, restored.size(), compositionWeights.length, defaultGradStepSize);
    return restored;
}
Also used : ArrayList(java.util.ArrayList) WeightedPredictionsAggregator(org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator) IgniteModel(org.apache.ignite.ml.IgniteModel) NotNull(org.jetbrains.annotations.NotNull)

Example 9 with IgniteModel

use of org.apache.ignite.ml.IgniteModel in project ignite by apache.

Source: class GDBTrainer, method updateModel.

/**
 * {@inheritDoc}
 */
@Override
protected <K, V> GDBModel updateModel(GDBModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
    if (!learnLabels(datasetBuilder, preprocessor))
        return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
    IgniteBiTuple<Double, Long> initAndSampleSize = computeInitialValue(envBuilder, datasetBuilder, preprocessor);
    if (initAndSampleSize == null)
        return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
    Double meanLb = initAndSampleSize.get1();
    Long sampleCnt = initAndSampleSize.get2();
    long startTs = System.currentTimeMillis();
    // Configure the learning strategy with this trainer's hyper-parameters.
    GDBLearningStrategy strategy = getLearningStrategy()
        .withBaseModelTrainerBuilder(this::buildBaseModelTrainer)
        .withExternalLabelToInternal(this::externalLabelToInternal)
        .withCntOfIterations(cntOfIterations)
        .withEnvironmentBuilder(envBuilder)
        .withLossGradient(loss)
        .withSampleSize(sampleCnt)
        .withMeanLabelValue(meanLb)
        .withDefaultGradStepSize(gradientStep)
        .withCheckConvergenceStgyFactory(checkConvergenceStgyFactory);
    // Continue training the existing composition when a previous model is available.
    List<IgniteModel<Vector, Double>> models = mdl != null
        ? strategy.update(mdl, datasetBuilder, preprocessor)
        : strategy.learnModels(datasetBuilder, preprocessor);
    double learningTime = (double) (System.currentTimeMillis() - startTs) / 1000.0;
    environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs", learningTime);
    WeightedPredictionsAggregator resAggregator = new WeightedPredictionsAggregator(strategy.getCompositionWeights(), strategy.getMeanValue());
    return new GDBModel(models, resAggregator, this::internalLabelToExternal);
}
Also used : WeightedPredictionsAggregator(org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator) IgniteModel(org.apache.ignite.ml.IgniteModel)

Example 10 with IgniteModel

use of org.apache.ignite.ml.IgniteModel in project ignite by apache.

Source: class LearningEnvironmentTest, method testRandomNumbersGenerator.

/**
 * Tests the random number generator provided by {@link LearningEnvironment}.
 * We test that:
 * 1. The correct random generator is returned for each partition.
 * 2. Its state is saved between compute calls (for this we do several iterations of compute).
 */
@Test
public void testRandomNumbersGenerator() {
    // We make such builders that provide as functions returning partition index * iteration as random number generator nextInt
    LearningEnvironmentBuilder envBuilder = TestUtils.testEnvBuilder().withRandomDependency(MockRandom::new);
    int partitions = 10;
    int iterations = 2;
    // Minimal trainer whose "model" is just the vector of per-partition random draws.
    DatasetTrainer<IgniteModel<Object, Vector>, Void> trainer = new DatasetTrainer<IgniteModel<Object, Vector>, Void>() {

        /**
         * {@inheritDoc}
         */
        @Override
        public <K, V> IgniteModel<Object, Vector> fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
            // Each partition's data is simply its own partition index.
            Dataset<EmptyContext, TestUtils.DataWrapper<Integer>> ds = datasetBuilder.build(envBuilder, new EmptyContextBuilder<>(), (PartitionDataBuilder<K, V, EmptyContext, TestUtils.DataWrapper<Integer>>) (env, upstreamData, upstreamDataSize, ctx) -> TestUtils.DataWrapper.of(env.partition()), envBuilder.buildForTrainer());
            Vector v = null;
            // Several compute iterations: if RNG state were not saved between compute calls,
            // each iteration would draw the same value instead of advancing.
            for (int iter = 0; iter < iterations; iter++) {
                // Each partition writes its next random int into its own slot; -1 marks "empty"
                // slots so the reducer can zip partial vectors together.
                v = ds.compute((dw, env) -> VectorUtils.fill(-1, partitions).set(env.partition(), env.randomNumbersGenerator().nextInt()), (v1, v2) -> zipOverridingEmpty(v1, v2, -1));
            }
            return constantModel(v);
        }

        /**
         * {@inheritDoc}
         */
        @Override
        public boolean isUpdateable(IgniteModel<Object, Vector> mdl) {
            // This trainer is fit-only for the purposes of the test.
            return false;
        }

        /**
         * {@inheritDoc}
         */
        @Override
        protected <K, V> IgniteModel<Object, Vector> updateModel(IgniteModel<Object, Vector> mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
            // Unused: isUpdateable always returns false.
            return null;
        }
    };
    trainer.withEnvironmentBuilder(envBuilder);
    IgniteModel<Object, Vector> mdl = trainer.fit(getCacheMock(partitions), partitions, null);
    // Expected: MockRandom yields partition index * iteration, so after the last
    // iteration partition i holds i * iterations.
    Vector exp = VectorUtils.zeroes(partitions);
    for (int i = 0; i < partitions; i++) exp.set(i, i * iterations);
    Vector res = mdl.predict(null);
    assertEquals(exp, res);
}
Also used : IntStream(java.util.stream.IntStream) TestUtils(org.apache.ignite.ml.TestUtils) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) Random(java.util.Random) DatasetTrainer(org.apache.ignite.ml.trainers.DatasetTrainer) ParallelismStrategy(org.apache.ignite.ml.environment.parallelism.ParallelismStrategy) FeatureMeta(org.apache.ignite.ml.dataset.feature.FeatureMeta) RandomForestRegressionTrainer(org.apache.ignite.ml.tree.randomforest.RandomForestRegressionTrainer) Map(java.util.Map) EmptyContextBuilder(org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder) MLLogger(org.apache.ignite.ml.environment.logging.MLLogger) PartitionDataBuilder(org.apache.ignite.ml.dataset.PartitionDataBuilder) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) ConsoleLogger(org.apache.ignite.ml.environment.logging.ConsoleLogger) Test(org.junit.Test) FeaturesCountSelectionStrategies(org.apache.ignite.ml.tree.randomforest.data.FeaturesCountSelectionStrategies) IgniteModel(org.apache.ignite.ml.IgniteModel) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) Collectors(java.util.stream.Collectors) VectorUtils(org.apache.ignite.ml.math.primitives.vector.VectorUtils) Dataset(org.apache.ignite.ml.dataset.Dataset) DefaultParallelismStrategy(org.apache.ignite.ml.environment.parallelism.DefaultParallelismStrategy) TestUtils.constantModel(org.apache.ignite.ml.TestUtils.constantModel) Assert.assertEquals(org.junit.Assert.assertEquals) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) DatasetTrainer(org.apache.ignite.ml.trainers.DatasetTrainer) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) IgniteModel(org.apache.ignite.ml.IgniteModel) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Test(org.junit.Test)

Aggregations

IgniteModel (org.apache.ignite.ml.IgniteModel)10 WeightedPredictionsAggregator (org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator)5 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)5 IOException (java.io.IOException)3 ArrayList (java.util.ArrayList)3 HashMap (java.util.HashMap)3 Map (java.util.Map)3 Serializable (java.io.Serializable)2 TreeMap (java.util.TreeMap)2 Configuration (org.apache.hadoop.conf.Configuration)2 Path (org.apache.hadoop.fs.Path)2 ModelsComposition (org.apache.ignite.ml.composition.ModelsComposition)2 EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext)2 VectorUtils (org.apache.ignite.ml.math.primitives.vector.VectorUtils)2 EvaluationResult (org.apache.ignite.ml.selection.scoring.evaluator.EvaluationResult)2 Evaluator (org.apache.ignite.ml.selection.scoring.evaluator.Evaluator)2 MetricName (org.apache.ignite.ml.selection.scoring.metric.MetricName)2 NodeData (org.apache.ignite.ml.tree.NodeData)2 PageReadStore (org.apache.parquet.column.page.PageReadStore)2 SimpleGroup (org.apache.parquet.example.data.simple.SimpleGroup)2