use of org.apache.ignite.ml.IgniteModel in project ignite by apache.
the class GDBLearningStrategy method update.
/**
 * Gets the state of the model passed in the arguments, compares it with the training parameters of the trainer,
 * and if they match, updates the model according to the new data and returns the updated models. Otherwise trains
 * a new model.
 *
 * @param mdlToUpdate Learned model.
 * @param datasetBuilder Dataset builder.
 * @param preprocessor Upstream preprocessor.
 * @param <K> Type of a key in {@code upstream} data.
 * @param <V> Type of a value in {@code upstream} data.
 * @return List of updated models.
 */
public <K, V> List<IgniteModel<Vector, Double>> update(GDBModel mdlToUpdate,
    DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
    if (trainerEnvironment == null)
        throw new IllegalStateException("Learning environment builder is not set.");

    List<IgniteModel<Vector, Double>> models = initLearningState(mdlToUpdate);

    ConvergenceChecker<K, V> convCheck = checkConvergenceStgyFactory.create(sampleSize,
        externalLbToInternalMapping, loss, datasetBuilder, preprocessor);

    DatasetTrainer<? extends IgniteModel<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get();

    for (int i = 0; i < cntOfIterations; i++) {
        double[] weights = Arrays.copyOf(compositionWeights, models.size());

        WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLbVal);
        ModelsComposition currComposition = new ModelsComposition(models, aggregator);

        if (convCheck.isConverged(envBuilder, datasetBuilder, currComposition))
            break;

        // Each sample is re-labeled with the negative gradient of the loss at the current
        // composition's prediction (the pseudo-residual), which the next weak model will fit.
        Vectorizer<K, V, Serializable, Double> extractor = new Vectorizer.VectorizerAdapter<K, V, Serializable, Double>() {
            /** {@inheritDoc} */
            @Override public LabeledVector<Double> extract(K k, V v) {
                LabeledVector<Double> labeledVector = preprocessor.apply(k, v);
                Vector features = labeledVector.features();

                Double realAnswer = externalLbToInternalMapping.apply(labeledVector.label());
                Double mdlAnswer = currComposition.predict(features);

                return new LabeledVector<>(features, -loss.gradient(sampleSize, realAnswer, mdlAnswer));
            }
        };

        long startTs = System.currentTimeMillis();

        models.add(trainer.fit(datasetBuilder, extractor));

        double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0;

        trainerEnvironment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
    }

    return models;
}
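Stripped of the Ignite-specific plumbing, the loop above is the standard gradient-boosting recipe: fit each new weak model to the negative gradient of the loss at the current ensemble's predictions. A minimal, framework-free sketch of that idea, with hypothetical Model, WeakLearner and LossGradient types that are not part of the Ignite API:

import java.util.ArrayList;
import java.util.List;

public final class BoostingSketch {
    /** A weak model: features -> prediction. */
    interface Model { double predict(double[] x); }

    /** Fits one weak learner to (features, targets); supplied by the caller. */
    interface WeakLearner { Model fit(double[][] x, double[] y); }

    /** Gradient of the loss w.r.t. the prediction, d L(y, pred) / d pred. */
    interface LossGradient { double apply(double y, double pred); }

    static List<Model> boost(double[][] x, double[] y, WeakLearner learner, int iterations, LossGradient lossGrad) {
        List<Model> ensemble = new ArrayList<>();

        for (int it = 0; it < iterations; it++) {
            double[] residuals = new double[y.length];

            for (int i = 0; i < y.length; i++) {
                double pred = 0.0;
                for (Model m : ensemble)
                    pred += m.predict(x[i]);

                // The negative gradient becomes the target for the next weak model,
                // just like the re-labeled LabeledVector in the extractor above.
                residuals[i] = -lossGrad.apply(y[i], pred);
            }

            ensemble.add(learner.fit(x, residuals));
        }

        return ensemble;
    }
}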
use of org.apache.ignite.ml.IgniteModel in project ignite by apache.
the class BinaryClassificationMetricsTest method testCalculation.
/**
*/
@Test
public void testCalculation() {
    Map<Vector, Double> xorset = new HashMap<Vector, Double>() {
        {
            put(VectorUtils.of(0., 0.), 0.);
            put(VectorUtils.of(0., 1.), 1.);
            put(VectorUtils.of(1., 0.), 1.);
            put(VectorUtils.of(1., 1.), 0.);
        }
    };

    IgniteModel<Vector, Double> xorFunction = v -> {
        if (Math.abs(v.get(0) - v.get(1)) < 0.01)
            return 0.;
        else
            return 1.;
    };

    IgniteModel<Vector, Double> andFunction = v -> {
        if (Math.abs(v.get(0) - v.get(1)) < 0.01 && v.get(0) > 0)
            return 1.;
        else
            return 0.;
    };

    IgniteModel<Vector, Double> orFunction = v -> {
        if (v.get(0) > 0 || v.get(1) > 0)
            return 1.;
        else
            return 0.;
    };

    // The XOR model classifies every sample correctly.
    EvaluationResult xorResult = Evaluator.evaluateBinaryClassification(xorset, xorFunction, Vector::labeled);
    assertEquals(1., xorResult.get(MetricName.ACCURACY), 0.01);
    assertEquals(1., xorResult.get(MetricName.PRECISION), 0.01);
    assertEquals(1., xorResult.get(MetricName.RECALL), 0.01);
    assertEquals(1., xorResult.get(MetricName.F_MEASURE), 0.01);

    EvaluationResult andResult = Evaluator.evaluateBinaryClassification(xorset, andFunction, Vector::labeled);
    assertEquals(0.25, andResult.get(MetricName.ACCURACY), 0.01);
    // There are no true positives, so precision is 0.
    assertEquals(0., andResult.get(MetricName.PRECISION), 0.01);
    // There are no true positives, so recall is 0.
    assertEquals(0., andResult.get(MetricName.RECALL), 0.01);
    // Precision and recall are both 0, so the F-measure has a zero denominator and is NaN.
    assertEquals(Double.NaN, andResult.get(MetricName.F_MEASURE), 0.01);

    EvaluationResult orResult = Evaluator.evaluateBinaryClassification(xorset, orFunction, Vector::labeled);
    assertEquals(0.75, orResult.get(MetricName.ACCURACY), 0.01);
    // TP = 2, FP = 1, so precision = 2/3.
    assertEquals(0.66, orResult.get(MetricName.PRECISION), 0.01);
    // TP = 2, FN = 0, so recall = 1.
    assertEquals(1., orResult.get(MetricName.RECALL), 0.01);
    // F-measure is the harmonic mean of 2/3 and 1, i.e. 0.8.
    assertEquals(0.8, orResult.get(MetricName.F_MEASURE), 0.01);
}
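The OR-model expectations follow directly from its confusion matrix on the XOR set. A small stand-alone check of the arithmetic (not part of the Ignite test, purely an illustration):

public class OrModelMetricsCheck {
    public static void main(String[] args) {
        // Labels of the XOR set in the order (0,0), (0,1), (1,0), (1,1) and the OR model's answers.
        double[] actual = {0., 1., 1., 0.};
        double[] predicted = {0., 1., 1., 1.};

        int tp = 0, fp = 0, fn = 0, tn = 0;
        for (int i = 0; i < actual.length; i++) {
            if (predicted[i] == 1.) {
                if (actual[i] == 1.) tp++; else fp++;
            }
            else {
                if (actual[i] == 1.) fn++; else tn++;
            }
        }

        double accuracy = (tp + tn) / (double)actual.length;       // 3 / 4 = 0.75
        double precision = tp / (double)(tp + fp);                 // 2 / 3 ≈ 0.66
        double recall = tp / (double)(tp + fn);                    // 2 / 2 = 1.0
        double f1 = 2 * precision * recall / (precision + recall); // 0.8

        System.out.printf("accuracy=%.2f precision=%.2f recall=%.2f f1=%.2f%n", accuracy, precision, recall, f1);
    }
}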
use of org.apache.ignite.ml.IgniteModel in project ignite by apache.
the class RegressionMetricsTest method testCalculation.
/**
*/
@Test
public void testCalculation() {
    Map<Vector, Double> linearSet = new HashMap<Vector, Double>() {
        {
            put(VectorUtils.of(0.), 0.);
            put(VectorUtils.of(1.), 1.);
            put(VectorUtils.of(2.), 2.);
            put(VectorUtils.of(3.), 3.);
        }
    };

    IgniteModel<Vector, Double> linearModel = v -> v.get(0);
    IgniteModel<Vector, Double> squareModel = v -> Math.pow(v.get(0), 2);

    EvaluationResult linearRes = Evaluator.evaluateRegression(linearSet, linearModel, Vector::labeled);
    assertEquals(0., linearRes.get(MetricName.MAE), 0.01);
    assertEquals(0., linearRes.get(MetricName.MSE), 0.01);
    assertEquals(0., linearRes.get(MetricName.R2), 0.01);
    assertEquals(0., linearRes.get(MetricName.RSS), 0.01);
    assertEquals(0., linearRes.get(MetricName.RMSE), 0.01);

    EvaluationResult squareRes = Evaluator.evaluateRegression(linearSet, squareModel, Vector::labeled);
    assertEquals(2., squareRes.get(MetricName.MAE), 0.01);
    assertEquals(10., squareRes.get(MetricName.MSE), 0.01);
    assertEquals(8., squareRes.get(MetricName.R2), 0.01);
    assertEquals(40., squareRes.get(MetricName.RSS), 0.01);
    assertEquals(Math.sqrt(10), squareRes.get(MetricName.RMSE), 0.01);
}
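The squareModel expectations can be checked by hand: on the points x = 0, 1, 2, 3 with label x, the predictions x^2 give absolute errors {0, 0, 2, 6} and squared errors {0, 0, 4, 36}. A tiny stand-alone calculation (not part of the test; the R2 assertion depends on how the Ignite metric defines it, so it is not recomputed here):

public class SquareModelMetricsCheck {
    public static void main(String[] args) {
        double[] labels = {0., 1., 2., 3.};
        double[] predictions = {0., 1., 4., 9.}; // x^2 for x = 0, 1, 2, 3

        double absSum = 0., sqSum = 0.;
        for (int i = 0; i < labels.length; i++) {
            double err = predictions[i] - labels[i];
            absSum += Math.abs(err);
            sqSum += err * err;
        }

        double mae = absSum / labels.length;  // (0 + 0 + 2 + 6) / 4 = 2.0
        double rss = sqSum;                   // 0 + 0 + 4 + 36 = 40.0
        double mse = sqSum / labels.length;   // 40 / 4 = 10.0
        double rmse = Math.sqrt(mse);         // sqrt(10) ≈ 3.16

        System.out.printf("MAE=%.2f MSE=%.2f RSS=%.2f RMSE=%.2f%n", mae, mse, rss, rmse);
    }
}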
use of org.apache.ignite.ml.IgniteModel in project ignite by apache.
the class ModelStorageExample method main.
/**
* Run example.
*/
public static void main(String... args) throws IOException, ClassNotFoundException {
    try (Ignite ignite = Ignition.start("examples/config/example-ignite-ml.xml")) {
        System.out.println(">>> Ignite grid started.");

        ModelStorage storage = new ModelStorageFactory().getModelStorage(ignite);
        ModelDescriptorStorage descStorage = new ModelDescriptorStorageFactory().getModelDescriptorStorage(ignite);

        System.out.println("Saving model into model storage...");
        byte[] mdl = serialize((IgniteModel<byte[], byte[]>)i -> i);

        storage.mkdirs("/");
        storage.putFile("/my_model", mdl);

        System.out.println("Saving model descriptor into model descriptor storage...");
        ModelDescriptor desc = new ModelDescriptor("MyModel", "My Cool Model", new ModelSignature("", "", ""),
            new ModelStorageModelReader("/my_model"), new IgniteModelParser<>());
        descStorage.put("my_model", desc);

        System.out.println("List saved models...");
        for (IgniteBiTuple<String, ModelDescriptor> model : descStorage)
            System.out.println("-> {'" + model.getKey() + "' : " + model.getValue() + "}");

        System.out.println("Load saved model descriptor...");
        desc = descStorage.get("my_model");

        System.out.println("Build inference model...");
        SingleModelBuilder mdlBuilder = new SingleModelBuilder();

        try (Model<byte[], byte[]> infMdl = mdlBuilder.build(desc.getReader(), desc.getParser())) {
            System.out.println("Make inference...");

            for (int i = 0; i < 10; i++) {
                Integer res = deserialize(infMdl.predict(serialize(i)));
                System.out.println(i + " -> " + res);
            }
        }
    }
    finally {
        System.out.flush();
    }
}
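The example calls serialize(...) and deserialize(...) helpers that are not shown in this snippet. A minimal sketch of what such helpers could look like, assuming plain Java serialization is acceptable for the model and its inputs/outputs (the actual example may implement them differently):

// Hypothetical helpers placed alongside main(); the java.io.* imports are required.
private static byte[] serialize(Object obj) throws IOException {
    ByteArrayOutputStream baos = new ByteArrayOutputStream();

    try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
        oos.writeObject(obj);
    }

    return baos.toByteArray();
}

@SuppressWarnings("unchecked")
private static <T> T deserialize(byte[] bytes) throws IOException, ClassNotFoundException {
    try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
        return (T)ois.readObject();
    }
}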
use of org.apache.ignite.ml.IgniteModel in project ignite by apache.
the class GDBOnTreesLearningStrategy method update.
/** {@inheritDoc} */
@Override public <K, V> List<IgniteModel<Vector, Double>> update(GDBModel mdlToUpdate,
    DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> vectorizer) {
    LearningEnvironment environment = envBuilder.buildForTrainer();
    environment.initDeployingContext(vectorizer);

    DatasetTrainer<? extends IgniteModel<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get();
    assert trainer instanceof DecisionTreeTrainer;

    DecisionTreeTrainer decisionTreeTrainer = (DecisionTreeTrainer)trainer;

    List<IgniteModel<Vector, Double>> models = initLearningState(mdlToUpdate);

    ConvergenceChecker<K, V> convCheck = checkConvergenceStgyFactory.create(sampleSize,
        externalLbToInternalMapping, loss, datasetBuilder, vectorizer);

    try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(envBuilder,
        new EmptyContextBuilder<>(), new DecisionTreeDataBuilder<>(vectorizer, useIdx), environment)) {
        for (int i = 0; i < cntOfIterations; i++) {
            double[] weights = Arrays.copyOf(compositionWeights, models.size());

            WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLbVal);
            ModelsComposition currComposition = new ModelsComposition(models, aggregator);

            if (convCheck.isConverged(dataset, currComposition))
                break;

            // Unlike the generic strategy, the dataset is built once and the partition labels are
            // rewritten in place with the pseudo-residuals before each tree is fit.
            dataset.compute(part -> {
                if (part.getCopiedOriginalLabels() == null)
                    part.setCopiedOriginalLabels(Arrays.copyOf(part.getLabels(), part.getLabels().length));

                for (int j = 0; j < part.getLabels().length; j++) {
                    double mdlAnswer = currComposition.predict(VectorUtils.of(part.getFeatures()[j]));
                    double originalLbVal = externalLbToInternalMapping.apply(part.getCopiedOriginalLabels()[j]);
                    part.getLabels()[j] = -loss.gradient(sampleSize, originalLbVal, mdlAnswer);
                }
            });

            long startTs = System.currentTimeMillis();

            models.add(decisionTreeTrainer.fit(dataset));

            double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0;

            trainerEnvironment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
        }
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }

    compositionWeights = Arrays.copyOf(compositionWeights, models.size());

    return models;
}
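Both strategies compute the new training target as -loss.gradient(sampleSize, label, prediction). As an illustration only (this mirrors the call shape above but is not the Ignite Loss interface verbatim), a squared-error loss would make that target proportional to the ordinary residual:

/** Hypothetical stand-in with the same call shape as the loss used above. */
interface LossSketch {
    /** d L(lb, mdlAnswer) / d mdlAnswer. */
    double gradient(long sampleSize, double lb, double mdlAnswer);
}

/** Squared error L = (mdlAnswer - lb)^2 / sampleSize: -gradient = 2 * (lb - mdlAnswer) / sampleSize. */
class SquaredErrorSketch implements LossSketch {
    @Override public double gradient(long sampleSize, double lb, double mdlAnswer) {
        return 2.0 * (mdlAnswer - lb) / sampleSize;
    }
}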