Use of org.apache.ignite.ml.IgniteModel in project ignite by apache.
The class SparkModelParser, method parseTreesForRandomForestAlgorithm.
/**
 * Parses trees from file for a common Random Forest ensemble.
 *
 * @param pathToMdl Path to model.
 * @param learningEnvironment Learning environment.
 * @return List of parsed tree models, or {@code null} if the Parquet file could not be read.
 */
private static List<IgniteModel<Vector, Double>> parseTreesForRandomForestAlgorithm(String pathToMdl,
    LearningEnvironment learningEnvironment) {
    try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
        PageReadStore pages;

        final MessageType schema = r.getFooter().getFileMetaData().getSchema();
        final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
        final Map<Integer, TreeMap<Integer, NodeData>> nodesByTreeId = new TreeMap<>();

        while (null != (pages = r.readNextRowGroup())) {
            final long rows = pages.getRowCount();
            final RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));

            for (int i = 0; i < rows; i++) {
                final SimpleGroup g = (SimpleGroup)recordReader.read();
                // Field 0 of each row is the tree id, field 1 the nested node struct.
                final int treeID = g.getInteger(0, 0);
                final SimpleGroup nodeDataGroup = (SimpleGroup)g.getGroup(1, 0);

                NodeData nodeData = extractNodeDataFromParquetRow(nodeDataGroup);

                // Group node data by tree id; the inner TreeMap keeps nodes ordered by node id.
                if (nodesByTreeId.containsKey(treeID)) {
                    Map<Integer, NodeData> nodesByNodeId = nodesByTreeId.get(treeID);
                    nodesByNodeId.put(nodeData.id, nodeData);
                }
                else {
                    TreeMap<Integer, NodeData> nodesByNodeId = new TreeMap<>();
                    nodesByNodeId.put(nodeData.id, nodeData);
                    nodesByTreeId.put(treeID, nodesByNodeId);
                }
            }
        }

        List<IgniteModel<Vector, Double>> models = new ArrayList<>();
        nodesByTreeId.forEach((key, nodes) -> models.add(buildDecisionTreeModel(nodes)));
        return models;
    }
    catch (IOException e) {
        String msg = "Error reading parquet file: " + e.getMessage();
        learningEnvironment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
        e.printStackTrace();
    }
    return null;
}
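For orientation, here is a minimal, self-contained sketch that dumps the (tree id, node id) pairs from such a model file using the same Parquet classes as the parser above. The class name is made up, and the position of the node id inside the nested node group is an assumption for illustration, not something guaranteed by any Ignite API.

import java.io.IOException;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.column.page.PageReadStore;
import org.apache.parquet.example.data.simple.SimpleGroup;
import org.apache.parquet.example.data.simple.convert.GroupRecordConverter;
import org.apache.parquet.hadoop.ParquetFileReader;
import org.apache.parquet.hadoop.util.HadoopInputFile;
import org.apache.parquet.io.ColumnIOFactory;
import org.apache.parquet.io.MessageColumnIO;
import org.apache.parquet.io.RecordReader;
import org.apache.parquet.schema.MessageType;

/** Prints the (tree id, node id) pairs stored in a Spark RandomForest model Parquet file. */
public class InspectRandomForestParquet {
    public static void main(String[] args) throws IOException {
        // Path to a data/part-*.parquet file inside the directory Spark's model.save() produced.
        String pathToMdl = args[0];

        try (ParquetFileReader r = ParquetFileReader.open(
            HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
            MessageType schema = r.getFooter().getFileMetaData().getSchema();
            MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);

            PageReadStore pages;
            while (null != (pages = r.readNextRowGroup())) {
                RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));

                for (long i = 0, rows = pages.getRowCount(); i < rows; i++) {
                    SimpleGroup g = (SimpleGroup)recordReader.read();

                    // Assumed layout, mirroring the parser above: field 0 is the tree id,
                    // field 1 the nested node struct whose first field we take as the node id.
                    int treeId = g.getInteger(0, 0);
                    SimpleGroup node = (SimpleGroup)g.getGroup(1, 0);
                    System.out.println("tree=" + treeId + " node=" + node.getInteger(0, 0));
                }
            }
        }
    }
}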
Use of org.apache.ignite.ml.IgniteModel in project ignite by apache.
The class SparkModelParser, method parseAndBuildGDBModel.
/**
 * Parses and builds a common GDB model with the custom label mapper.
 *
 * @param pathToMdl Path to model.
 * @param pathToMdlMetaData Path to model meta data.
 * @param lbMapper Label mapper.
 * @param learningEnvironment Learning environment.
 * @return GDB model, or {@code null} if the model file could not be read.
 */
@Nullable
private static Model parseAndBuildGDBModel(String pathToMdl, String pathToMdlMetaData,
    IgniteFunction<Double, Double> lbMapper, LearningEnvironment learningEnvironment) {
    final Map<Integer, Double> treeWeightsByTreeID = new HashMap<>();

    // First pass: read per-tree weights from the metadata file.
    try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdlMetaData), new Configuration()))) {
        PageReadStore pagesMetaData;

        final MessageType schema = r.getFooter().getFileMetaData().getSchema();
        final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);

        while (null != (pagesMetaData = r.readNextRowGroup())) {
            final long rows = pagesMetaData.getRowCount();
            final RecordReader recordReader = colIO.getRecordReader(pagesMetaData, new GroupRecordConverter(schema));

            for (int i = 0; i < rows; i++) {
                final SimpleGroup g = (SimpleGroup)recordReader.read();
                int treeId = g.getInteger(0, 0);
                double treeWeight = g.getDouble(2, 0);
                treeWeightsByTreeID.put(treeId, treeWeight);
            }
        }
    }
    catch (IOException e) {
        String msg = "Error reading parquet file with MetaData by the path: " + pathToMdlMetaData;
        learningEnvironment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
        e.printStackTrace();
    }

    // Flatten the weight map into an array; this assumes tree ids form a contiguous 0..n-1 range.
    double[] treeWeights = new double[treeWeightsByTreeID.size()];
    for (int i = 0; i < treeWeights.length; i++)
        treeWeights[i] = treeWeightsByTreeID.get(i);

    // Second pass: read the tree nodes themselves and assemble the ensemble.
    try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
        PageReadStore pages;

        final MessageType schema = r.getFooter().getFileMetaData().getSchema();
        final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
        final Map<Integer, TreeMap<Integer, NodeData>> nodesByTreeId = new TreeMap<>();

        while (null != (pages = r.readNextRowGroup())) {
            final long rows = pages.getRowCount();
            final RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));

            for (int i = 0; i < rows; i++) {
                final SimpleGroup g = (SimpleGroup)recordReader.read();
                final int treeID = g.getInteger(0, 0);
                final SimpleGroup nodeDataGroup = (SimpleGroup)g.getGroup(1, 0);

                NodeData nodeData = extractNodeDataFromParquetRow(nodeDataGroup);

                if (nodesByTreeId.containsKey(treeID)) {
                    Map<Integer, NodeData> nodesByNodeId = nodesByTreeId.get(treeID);
                    nodesByNodeId.put(nodeData.id, nodeData);
                }
                else {
                    TreeMap<Integer, NodeData> nodesByNodeId = new TreeMap<>();
                    nodesByNodeId.put(nodeData.id, nodeData);
                    nodesByTreeId.put(treeID, nodesByNodeId);
                }
            }
        }

        final List<IgniteModel<Vector, Double>> models = new ArrayList<>();
        nodesByTreeId.forEach((key, nodes) -> models.add(buildDecisionTreeModel(nodes)));
        return new GDBModel(models, new WeightedPredictionsAggregator(treeWeights), lbMapper);
    }
    catch (IOException e) {
        String msg = "Error reading parquet file: " + e.getMessage();
        learningEnvironment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
        e.printStackTrace();
    }
    return null;
}
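Note that the weight-flattening loop above relies on the tree ids being a contiguous 0..n-1 range; if an id were missing, treeWeightsByTreeID.get(i) would return null and throw a NullPointerException on unboxing. A defensive sketch of that step (not the parser's actual code) might look like this:

import java.util.Map;

/** Converts per-tree weights keyed by tree id into a dense array, failing fast on gaps. */
static double[] toWeightArray(Map<Integer, Double> treeWeightsByTreeId) {
    double[] weights = new double[treeWeightsByTreeId.size()];
    for (int i = 0; i < weights.length; i++) {
        Double w = treeWeightsByTreeId.get(i);
        if (w == null)
            throw new IllegalStateException("Missing weight for tree id " + i);
        weights[i] = w;
    }
    return weights;
}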
Use of org.apache.ignite.ml.IgniteModel in project ignite by apache.
The class GDBLearningStrategy, method initLearningState.
/**
 * Restores the state of an already learned model, if possible, and sets the learning parameters
 * according to that state.
 *
 * @param mdlToUpdate Model to update.
 * @return List of already learned models.
 */
@NotNull
protected List<IgniteModel<Vector, Double>> initLearningState(GDBModel mdlToUpdate) {
    List<IgniteModel<Vector, Double>> models = new ArrayList<>();

    if (mdlToUpdate != null) {
        models.addAll(mdlToUpdate.getModels());

        WeightedPredictionsAggregator aggregator = (WeightedPredictionsAggregator)mdlToUpdate.getPredictionsAggregator();
        meanLbVal = aggregator.getBias();

        // Keep the weights of the already trained trees and reserve slots for the new iterations.
        compositionWeights = new double[models.size() + cntOfIterations];
        System.arraycopy(aggregator.getWeights(), 0, compositionWeights, 0, models.size());
    }
    else
        compositionWeights = new double[cntOfIterations];

    // The newly reserved slots start with the default gradient step size.
    Arrays.fill(compositionWeights, models.size(), compositionWeights.length, defaultGradStepSize);

    return models;
}
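To make the weight layout concrete, here is a small standalone demo with made-up numbers: three previously trained trees keep their weights, and two new boosting iterations get the default step size.

import java.util.Arrays;

/** Illustrates (with hypothetical numbers) the weight layout initLearningState builds. */
public class CompositionWeightsDemo {
    public static void main(String[] args) {
        double[] oldWeights = {0.1, 0.2, 0.3}; // weights of 3 already trained trees
        int cntOfIterations = 2;               // new boosting iterations to run
        double defaultGradStepSize = 0.5;      // default step for the new slots

        double[] compositionWeights = new double[oldWeights.length + cntOfIterations];
        System.arraycopy(oldWeights, 0, compositionWeights, 0, oldWeights.length);
        Arrays.fill(compositionWeights, oldWeights.length, compositionWeights.length, defaultGradStepSize);

        // Prints [0.1, 0.2, 0.3, 0.5, 0.5]
        System.out.println(Arrays.toString(compositionWeights));
    }
}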
Use of org.apache.ignite.ml.IgniteModel in project ignite by apache.
The class GDBTrainer, method updateModel.
/** {@inheritDoc} */
@Override protected <K, V> GDBModel updateModel(GDBModel mdl, DatasetBuilder<K, V> datasetBuilder,
    Preprocessor<K, V> preprocessor) {
    if (!learnLabels(datasetBuilder, preprocessor))
        return getLastTrainedModelOrThrowEmptyDatasetException(mdl);

    IgniteBiTuple<Double, Long> initAndSampleSize = computeInitialValue(envBuilder, datasetBuilder, preprocessor);
    if (initAndSampleSize == null)
        return getLastTrainedModelOrThrowEmptyDatasetException(mdl);

    Double mean = initAndSampleSize.get1();
    Long sampleSize = initAndSampleSize.get2();

    long learningStartTs = System.currentTimeMillis();

    GDBLearningStrategy stgy = getLearningStrategy()
        .withBaseModelTrainerBuilder(this::buildBaseModelTrainer)
        .withExternalLabelToInternal(this::externalLabelToInternal)
        .withCntOfIterations(cntOfIterations)
        .withEnvironmentBuilder(envBuilder)
        .withLossGradient(loss)
        .withSampleSize(sampleSize)
        .withMeanLabelValue(mean)
        .withDefaultGradStepSize(gradientStep)
        .withCheckConvergenceStgyFactory(checkConvergenceStgyFactory);

    List<IgniteModel<Vector, Double>> models;
    if (mdl != null)
        models = stgy.update(mdl, datasetBuilder, preprocessor);
    else
        models = stgy.learnModels(datasetBuilder, preprocessor);

    double learningTime = (double)(System.currentTimeMillis() - learningStartTs) / 1000.0;
    environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs", learningTime);

    WeightedPredictionsAggregator resAggregator = new WeightedPredictionsAggregator(stgy.getCompositionWeights(), stgy.getMeanValue());

    return new GDBModel(models, resAggregator, this::internalLabelToExternal);
}
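Conceptually, the returned GDBModel predicts by adding the mean-value bias to the weighted sum of the base tree outputs, then mapping the internal label back to an external one. A simplified sketch of that aggregation, not Ignite's actual WeightedPredictionsAggregator implementation:

import java.util.List;
import java.util.function.DoubleUnaryOperator;

/** Simplified view of how a GDB ensemble combines its base models. */
final class GdbPredictSketch {
    /** Applies each base tree to the input, weights and sums the results on top of the bias. */
    static double predict(List<DoubleUnaryOperator> trees, double[] weights, double bias,
        DoubleUnaryOperator internalToExternal, double x) {
        double sum = bias; // mean label value computed at the start of boosting
        for (int i = 0; i < trees.size(); i++)
            sum += weights[i] * trees.get(i).applyAsDouble(x);
        return internalToExternal.applyAsDouble(sum); // e.g. a sign/threshold map for classification
    }
}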
Use of org.apache.ignite.ml.IgniteModel in project ignite by apache.
The class LearningEnvironmentTest, method testRandomNumbersGenerator.
/**
 * Tests the random number generator provided by {@link LearningEnvironment}.
 * We check that:
 * 1. The correct random generator is returned for each partition.
 * 2. Its state is kept between compute calls (for this we run several compute iterations).
 */
@Test
public void testRandomNumbersGenerator() {
    // Build an environment whose random number generator's nextInt returns partition index * iteration.
    LearningEnvironmentBuilder envBuilder = TestUtils.testEnvBuilder().withRandomDependency(MockRandom::new);
    int partitions = 10;
    int iterations = 2;

    DatasetTrainer<IgniteModel<Object, Vector>, Void> trainer = new DatasetTrainer<IgniteModel<Object, Vector>, Void>() {
        /** {@inheritDoc} */
        @Override public <K, V> IgniteModel<Object, Vector> fitWithInitializedDeployingContext(
            DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
            Dataset<EmptyContext, TestUtils.DataWrapper<Integer>> ds = datasetBuilder.build(
                envBuilder,
                new EmptyContextBuilder<>(),
                (PartitionDataBuilder<K, V, EmptyContext, TestUtils.DataWrapper<Integer>>)
                    (env, upstreamData, upstreamDataSize, ctx) -> TestUtils.DataWrapper.of(env.partition()),
                envBuilder.buildForTrainer());

            Vector v = null;
            for (int iter = 0; iter < iterations; iter++) {
                // Each partition writes its generator's next value into its own slot of the vector.
                v = ds.compute(
                    (dw, env) -> VectorUtils.fill(-1, partitions).set(env.partition(), env.randomNumbersGenerator().nextInt()),
                    (v1, v2) -> zipOverridingEmpty(v1, v2, -1));
            }
            return constantModel(v);
        }

        /** {@inheritDoc} */
        @Override public boolean isUpdateable(IgniteModel<Object, Vector> mdl) {
            return false;
        }

        /** {@inheritDoc} */
        @Override protected <K, V> IgniteModel<Object, Vector> updateModel(IgniteModel<Object, Vector> mdl,
            DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
            return null;
        }
    };

    trainer.withEnvironmentBuilder(envBuilder);
    IgniteModel<Object, Vector> mdl = trainer.fit(getCacheMock(partitions), partitions, null);

    Vector exp = VectorUtils.zeroes(partitions);
    for (int i = 0; i < partitions; i++)
        exp.set(i, i * iterations);

    Vector res = mdl.predict(null);
    assertEquals(exp, res);
}
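The MockRandom class wired in through withRandomDependency is not shown on this page. A plausible minimal stand-in that would satisfy the test's expectation (partition i observes i * iterations after two compute calls) is sketched below; the constructor parameter and the exact wiring into the environment builder are assumptions:

import java.util.Random;

/**
 * Hypothetical stand-in for the test's MockRandom: given its partition index,
 * each nextInt() call returns partitionIdx * callCount, so after two compute
 * iterations partition i observes the value i * 2.
 */
class MockRandom extends Random {
    private final int partIdx;
    private int callCnt;

    MockRandom(int partIdx) {
        this.partIdx = partIdx;
    }

    @Override public int nextInt() {
        return partIdx * ++callCnt;
    }
}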