Use of org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator in project ignite by apache: class GDBTrainerTest, method testClassifier.
/**
 * Trains a GDB binary classifier on a synthetic sample whose label flips sign every ten points,
 * then verifies that the trained composition classifies the whole sample perfectly, is built from
 * decision trees under a weighted aggregator, stops early thanks to the convergence checker, and
 * runs all iterations when the checker is replaced with a stub.
 *
 * @param fitter Fitting strategy: takes a trainer and a learning sample, returns a trained model.
 */
private void testClassifier(BiFunction<GDBTrainer, Map<Integer, double[]>, IgniteModel<Vector, Double>> fitter) {
    final int sampleSize = 100;

    // Learning sample: feature is the point index, label alternates between -1 and 1 in blocks of 10.
    Map<Integer, double[]> learningSample = new HashMap<>();
    for (int i = 0; i < sampleSize; i++) {
        double feature = i;
        double lb = ((int) (feature / 10.0) % 2) == 0 ? -1.0 : 1.0;
        learningSample.put(i, new double[] {feature, lb});
    }

    GDBTrainer trainer = new GDBBinaryClassifierOnTreesTrainer(0.3, 500, 3, 0.0)
        .withUsingIdx(true)
        .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.3));

    IgniteModel<Vector, Double> mdl = fitter.apply(trainer, learningSample);

    // The trained classifier should separate the sample without a single mistake.
    int errorsCnt = 0;
    for (double[] row : learningSample.values()) {
        if (mdl.predict(VectorUtils.of(row[0])) != row[1])
            errorsCnt++;
    }
    assertEquals(0, errorsCnt);

    assertTrue(mdl instanceof ModelsComposition);
    ModelsComposition composition = (ModelsComposition) mdl;
    composition.getModels().forEach(m -> assertTrue(m instanceof DecisionTreeModel));

    // Convergence checker must have stopped training before the 500-iteration budget was exhausted.
    assertTrue(composition.getModels().size() < 500);
    assertTrue(composition.getPredictionsAggregator() instanceof WeightedPredictionsAggregator);

    // With a stub checker that never reports convergence, every one of the 500 iterations is performed.
    trainer = trainer.withCheckConvergenceStgyFactory(new ConvergenceCheckerStubFactory());
    assertEquals(500, ((ModelsComposition) fitter.apply(trainer, learningSample)).getModels().size());
}
Use of org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator in project ignite by apache: class GDBLearningStrategy, method update.
/**
 * Gets state of model in arguments, compare it with training parameters of trainer and if they are fit then trainer
 * updates model in according to new data and return new model. In other case trains new model.
 *
 * @param mdlToUpdate Learned model.
 * @param datasetBuilder Dataset builder.
 * @param preprocessor Upstream preprocessor.
 * @param <K> Type of a key in {@code upstream} data.
 * @param <V> Type of a value in {@code upstream} data.
 * @return Updated models list.
 */
public <K, V> List<IgniteModel<Vector, Double>> update(GDBModel mdlToUpdate, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
    // The trainer environment is needed for the per-iteration logging below; fail fast if it was never set.
    if (trainerEnvironment == null)
        throw new IllegalStateException("Learning environment builder is not set.");

    // Restore already learned base models (and composition weights) from the model being updated, if any.
    List<IgniteModel<Vector, Double>> models = initLearningState(mdlToUpdate);

    // The convergence checker decides whether boosting may stop before cntOfIterations is exhausted.
    ConvergenceChecker<K, V> convCheck = checkConvergenceStgyFactory.create(sampleSize, externalLbToInternalMapping, loss, datasetBuilder, preprocessor);
    DatasetTrainer<? extends IgniteModel<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get();
    for (int i = 0; i < cntOfIterations; i++) {
        // Aggregate all models trained so far into the current ensemble; the weights array is
        // trimmed to the number of models actually present.
        double[] weights = Arrays.copyOf(compositionWeights, models.size());
        WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLbVal);
        ModelsComposition currComposition = new ModelsComposition(models, aggregator);

        // Convergence is checked against the ensemble built so far, before training another model.
        if (convCheck.isConverged(envBuilder, datasetBuilder, currComposition))
            break;

        // Relabel each sample with the negative loss gradient w.r.t. the current ensemble's
        // prediction, so the next base model is fitted to the residuals (standard gradient boosting).
        Vectorizer<K, V, Serializable, Double> extractor = new Vectorizer.VectorizerAdapter<K, V, Serializable, Double>() {
            /**
             * {@inheritDoc}
             */
            @Override
            public LabeledVector<Double> extract(K k, V v) {
                LabeledVector<Double> labeledVector = preprocessor.apply(k, v);
                Vector features = labeledVector.features();
                Double realAnswer = externalLbToInternalMapping.apply(labeledVector.label());
                Double mdlAnswer = currComposition.predict(features);
                return new LabeledVector<>(features, -loss.gradient(sampleSize, realAnswer, mdlAnswer));
            }
        };

        // Train the next base model on the gradient-relabeled data and log how long it took.
        long startTs = System.currentTimeMillis();
        models.add(trainer.fit(datasetBuilder, extractor));
        double learningTime = (double) (System.currentTimeMillis() - startTs) / 1000.0;
        trainerEnvironment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
    }
    return models;
}
Use of org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator in project ignite by apache: class GDBOnTreesLearningStrategy, method update.
/**
 * {@inheritDoc}
 */
@Override
public <K, V> List<IgniteModel<Vector, Double>> update(GDBModel mdlToUpdate, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> vectorizer) {
    LearningEnvironment environment = envBuilder.buildForTrainer();
    environment.initDeployingContext(vectorizer);

    // This strategy is tree-specific: it builds the dataset once and reuses it across all boosting
    // iterations, which requires DecisionTreeTrainer's ability to fit on a prebuilt dataset.
    DatasetTrainer<? extends IgniteModel<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get();
    assert trainer instanceof DecisionTreeTrainer;
    DecisionTreeTrainer decisionTreeTrainer = (DecisionTreeTrainer) trainer;

    // Restore already learned base models (and composition weights) from the model being updated, if any.
    List<IgniteModel<Vector, Double>> models = initLearningState(mdlToUpdate);

    ConvergenceChecker<K, V> convCheck = checkConvergenceStgyFactory.create(sampleSize, externalLbToInternalMapping, loss, datasetBuilder, vectorizer);
    try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(envBuilder, new EmptyContextBuilder<>(), new DecisionTreeDataBuilder<>(vectorizer, useIdx), environment)) {
        for (int i = 0; i < cntOfIterations; i++) {
            // Aggregate the ensemble trained so far; weights trimmed to the models actually present.
            double[] weights = Arrays.copyOf(compositionWeights, models.size());
            WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLbVal);
            ModelsComposition currComposition = new ModelsComposition(models, aggregator);
            if (convCheck.isConverged(dataset, currComposition))
                break;

            // Overwrite each partition's labels in place with the negative loss gradient w.r.t. the
            // current ensemble. The original labels are copied aside on the first iteration because
            // getLabels() is clobbered with gradients on every pass.
            dataset.compute(part -> {
                if (part.getCopiedOriginalLabels() == null)
                    part.setCopiedOriginalLabels(Arrays.copyOf(part.getLabels(), part.getLabels().length));
                for (int j = 0; j < part.getLabels().length; j++) {
                    double mdlAnswer = currComposition.predict(VectorUtils.of(part.getFeatures()[j]));
                    double originalLbVal = externalLbToInternalMapping.apply(part.getCopiedOriginalLabels()[j]);
                    part.getLabels()[j] = -loss.gradient(sampleSize, originalLbVal, mdlAnswer);
                }
            });

            // Fit the next tree on the gradient-relabeled dataset and log the training time.
            long startTs = System.currentTimeMillis();
            models.add(decisionTreeTrainer.fit(dataset));
            double learningTime = (double) (System.currentTimeMillis() - startTs) / 1000.0;
            trainerEnvironment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
        }
    } catch (Exception e) {
        // Dataset building/closing may throw checked exceptions; rethrow with the cause preserved.
        throw new RuntimeException(e);
    }

    // Trim the weights array in case convergence stopped training before all iterations ran.
    compositionWeights = Arrays.copyOf(compositionWeights, models.size());
    return models;
}
Use of org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator in project ignite by apache: class SparkModelParser, method parseAndBuildGDBModel.
/**
 * Parse and build common GDB model with the custom label mapper.
 *
 * @param pathToMdl Path to model.
 * @param pathToMdlMetaData Path to model meta data.
 * @param lbMapper Label mapper.
 * @param learningEnvironment Learning environment (used for logging).
 * @return GDB model, or {@code null} if the model file could not be read.
 */
@Nullable
private static Model parseAndBuildGDBModel(String pathToMdl, String pathToMdlMetaData, IgniteFunction<Double, Double> lbMapper, LearningEnvironment learningEnvironment) {
    final Map<Integer, Double> treeWeightsByTreeID = new HashMap<>();

    // First pass: read per-tree weights from the metadata file (field 0 is the tree id, field 2 its weight).
    try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdlMetaData), new Configuration()))) {
        PageReadStore pagesMetaData;
        final MessageType schema = r.getFooter().getFileMetaData().getSchema();
        final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);

        while (null != (pagesMetaData = r.readNextRowGroup())) {
            final long rows = pagesMetaData.getRowCount();
            final RecordReader recordReader = colIO.getRecordReader(pagesMetaData, new GroupRecordConverter(schema));
            for (int i = 0; i < rows; i++) {
                final SimpleGroup g = (SimpleGroup) recordReader.read();
                int treeId = g.getInteger(0, 0);
                double treeWeight = g.getDouble(2, 0);
                treeWeightsByTreeID.put(treeId, treeWeight);
            }
        }
    } catch (IOException e) {
        // Best-effort: a missing/broken metadata file leaves the weight map empty rather than aborting.
        String msg = "Error reading parquet file with MetaData by the path: " + pathToMdlMetaData;
        learningEnvironment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
        e.printStackTrace();
    }

    // Tree ids are expected to be contiguous 0..n-1. Fail with a clear message instead of a cryptic
    // NullPointerException on auto-unboxing if an id is absent from the metadata.
    final double[] treeWeights = new double[treeWeightsByTreeID.size()];
    for (int i = 0; i < treeWeights.length; i++) {
        Double weight = treeWeightsByTreeID.get(i);
        if (weight == null)
            throw new IllegalStateException("Tree weight for tree id " + i + " is missing in model meta data: " + pathToMdlMetaData);
        treeWeights[i] = weight;
    }

    // Second pass: read the tree nodes themselves, grouped by tree id (TreeMap keeps trees and
    // nodes ordered by id so the models line up with their weights).
    try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
        PageReadStore pages;
        final MessageType schema = r.getFooter().getFileMetaData().getSchema();
        final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
        final Map<Integer, TreeMap<Integer, NodeData>> nodesByTreeId = new TreeMap<>();

        while (null != (pages = r.readNextRowGroup())) {
            final long rows = pages.getRowCount();
            final RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));
            for (int i = 0; i < rows; i++) {
                final SimpleGroup g = (SimpleGroup) recordReader.read();
                final int treeID = g.getInteger(0, 0);
                final SimpleGroup nodeDataGroup = (SimpleGroup) g.getGroup(1, 0);
                NodeData nodeData = extractNodeDataFromParquetRow(nodeDataGroup);

                if (nodesByTreeId.containsKey(treeID)) {
                    Map<Integer, NodeData> nodesByNodeId = nodesByTreeId.get(treeID);
                    nodesByNodeId.put(nodeData.id, nodeData);
                } else {
                    TreeMap<Integer, NodeData> nodesByNodeId = new TreeMap<>();
                    nodesByNodeId.put(nodeData.id, nodeData);
                    nodesByTreeId.put(treeID, nodesByNodeId);
                }
            }
        }

        // Build one decision-tree model per tree (in tree-id order) and combine them with the weights.
        final List<IgniteModel<Vector, Double>> models = new ArrayList<>();
        nodesByTreeId.forEach((key, nodes) -> models.add(buildDecisionTreeModel(nodes)));
        return new GDBModel(models, new WeightedPredictionsAggregator(treeWeights), lbMapper);
    } catch (IOException e) {
        String msg = "Error reading parquet file: " + e.getMessage();
        learningEnvironment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
        e.printStackTrace();
    }
    return null;
}
Use of org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator in project ignite by apache: class GDBLearningStrategy, method initLearningState.
/**
 * Restores the state of an already learned model when one is given and sizes the composition-weight
 * array for the upcoming training iterations; slots for not-yet-trained models get the default
 * gradient step size.
 *
 * @param mdlToUpdate Model to update; {@code null} means training starts from scratch.
 * @return List of already learned models (empty when starting from scratch).
 */
@NotNull
protected List<IgniteModel<Vector, Double>> initLearningState(GDBModel mdlToUpdate) {
    final List<IgniteModel<Vector, Double>> models = new ArrayList<>();

    if (mdlToUpdate == null)
        // Fresh training: one weight slot per planned iteration.
        compositionWeights = new double[cntOfIterations];
    else {
        // Continue training: keep the previously learned models, their bias and their weights.
        models.addAll(mdlToUpdate.getModels());
        WeightedPredictionsAggregator aggr = (WeightedPredictionsAggregator) mdlToUpdate.getPredictionsAggregator();
        meanLbVal = aggr.getBias();

        int learnedCnt = models.size();
        compositionWeights = new double[learnedCnt + cntOfIterations];
        System.arraycopy(aggr.getWeights(), 0, compositionWeights, 0, learnedCnt);
    }

    // Pre-fill the slots of the models that are still to be trained with the default step size.
    Arrays.fill(compositionWeights, models.size(), compositionWeights.length, defaultGradStepSize);
    return models;
}
Aggregations