use of org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData in project ignite by apache.
the class MeanAbsValueConvergenceCheckerTest method testConvergenceChecking.
/**
 * Exercises a checker created from {@code MeanAbsValueConvergenceCheckerFactory}: the error for a single
 * point, the convergence verdict for both fixture models, and the mean error over a locally built dataset.
 */
@Test
public void testConvergenceChecking() {
    // Tolerance shared by every floating point comparison in this test.
    final double delta = 0.01;

    LocalDatasetBuilder<Integer, LabeledVector<Double>> dsBuilder = new LocalDatasetBuilder<>(data, 1);
    ConvergenceChecker<Integer, LabeledVector<Double>> convChecker =
        createChecker(new MeanAbsValueConvergenceCheckerFactory(0.1), dsBuilder);

    // Error of the not-converged model on one hand-made point.
    double pointErr = convChecker.computeError(VectorUtils.of(1, 2), 4.0, notConvergedMdl);
    LearningEnvironmentBuilder envBuilder = TestUtils.testEnvBuilder();
    Assert.assertEquals(1.9, pointErr, delta);

    // Convergence verdicts for the two fixture models.
    Assert.assertFalse(convChecker.isConverged(envBuilder, dsBuilder, notConvergedMdl));
    Assert.assertTrue(convChecker.isConverged(envBuilder, dsBuilder, convergedMdl));

    // Mean error of the not-converged model over the whole dataset.
    try (LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> locDs = dsBuilder.build(
        envBuilder,
        new EmptyContextBuilder<>(),
        new FeatureMatrixWithLabelsOnHeapDataBuilder<>(vectorizer),
        envBuilder.buildForTrainer())) {
        double meanErr = convChecker.computeMeanErrorOnDataset(locDs, notConvergedMdl);

        Assert.assertEquals(1.55, meanErr, delta);
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}
use of org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData in project ignite by apache.
the class Evaluator method initEvaluationContexts.
/**
 * Inits evaluation contexts for metrics.
 *
 * <p>Each metric supplies an aggregator, and contexts are keyed by the aggregator class. If no freshly
 * initialized context reports {@code needToCompute()}, the initialized contexts are returned as is and the
 * dataset is not touched. Otherwise every context aggregates all labeled vectors of each partition and the
 * per-partition results are merged key by key.
 *
 * @param dataset Dataset.
 * @param metrics Metrics.
 * @return Computed contexts.
 */
@SuppressWarnings("unchecked")
private static Map<Class, EvaluationContext> initEvaluationContexts(Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset, Metric... metrics) {
    long nonEmptyCtxsCnt = Arrays.stream(metrics)
        .map(x -> x.makeAggregator().createInitializedContext())
        .filter(x -> ((EvaluationContext) x).needToCompute())
        .count();

    if (nonEmptyCtxsCnt == 0) {
        HashMap<Class, EvaluationContext> res = new HashMap<>();

        for (Metric m : metrics) {
            MetricStatsAggregator<Double, ?, ?> aggregator = m.makeAggregator();

            // Reuse the aggregator created above instead of building a second one for the same metric.
            res.put(aggregator.getClass(), (EvaluationContext) aggregator.createInitializedContext());
        }

        // Return only after ALL metrics are registered; returning from inside the loop
        // would leave every aggregator class except the first one without a context.
        return res;
    }

    return dataset.compute(data -> {
        // One aggregator per aggregator class: metrics sharing a class share one context.
        Map<Class, MetricStatsAggregator> aggrs = new HashMap<>();
        for (Metric m : metrics) {
            MetricStatsAggregator<Double, ?, ?> aggregator = m.makeAggregator();
            if (!aggrs.containsKey(aggregator.getClass()))
                aggrs.put(aggregator.getClass(), aggregator);
        }

        Map<Class, EvaluationContext> aggrToEvCtx = new HashMap<>();
        aggrs.forEach((clazz, aggr) -> aggrToEvCtx.put(clazz, (EvaluationContext) aggr.createInitializedContext()));

        // Feed every (features, label) row of the partition into every context.
        for (int i = 0; i < data.getLabels().length; i++) {
            LabeledVector<Double> vector = VectorUtils.of(data.getFeatures()[i]).labeled(data.getLabels()[i]);
            aggrToEvCtx.values().forEach(ctx -> ctx.aggregate(vector));
        }

        return aggrToEvCtx;
    }, (left, right) -> {
        // Either side of the reduction may be null; treat a null side as "nothing to merge".
        if (left == null && right == null)
            return new HashMap<>();

        if (left == null)
            return right;

        if (right == null)
            return left;

        HashMap<Class, EvaluationContext> res = new HashMap<>();
        for (Class key : left.keySet()) {
            EvaluationContext ctx1 = left.get(key);
            EvaluationContext ctx2 = right.get(key);

            // Both sides must hold a context for every key found in the left map.
            A.ensure(ctx1 != null && ctx2 != null, "ctx1 != null && ctx2 != null");
            res.put(key, ctx1.mergeWith(ctx2));
        }

        return res;
    });
}
use of org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData in project ignite by apache.
the class ConvergenceChecker method isConverged.
/**
 * Checks convergence of the current model on a dataset built from the given dataset builder.
 *
 * <p>A trainer learning environment is built first and its deploying context is initialized with this
 * checker's preprocessor; the dataset is then built, checked and closed. Any exception raised while
 * building, checking or closing the dataset is rethrown wrapped in a {@link RuntimeException}.
 *
 * @param envBuilder Learning environment builder.
 * @param datasetBuilder Builder of the dataset the check runs on.
 * @param currMdl Current model.
 * @return True if GDB is converged.
 */
public boolean isConverged(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder, ModelsComposition currMdl) {
    LearningEnvironment trainerEnv = envBuilder.buildForTrainer();
    trainerEnv.initDeployingContext(preprocessor);

    try (Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> ds = datasetBuilder.build(
        envBuilder,
        new EmptyContextBuilder<>(),
        new FeatureMatrixWithLabelsOnHeapDataBuilder<>(preprocessor),
        trainerEnv)) {
        // Delegate to the dataset-based overload.
        return isConverged(ds, currMdl);
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}
use of org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData in project ignite by apache.
the class MedianOfMedianConvergenceCheckerTest method testConvergenceChecking.
/**
 * Exercises a checker created from {@code MedianOfMedianConvergenceCheckerFactory}: the error for a single
 * point, the convergence verdict for both fixture models, and the mean error over a locally built dataset
 * that contains one extra point with a very large label.
 */
@Test
public void testConvergenceChecking() {
    // Tolerance shared by every floating point comparison in this test.
    final double delta = 0.01;

    // Extra point with a huge label - presumably an outlier the median-based checker should tolerate.
    data.put(666, VectorUtils.of(10, 11).labeled(100000.0));

    LocalDatasetBuilder<Integer, LabeledVector<Double>> dsBuilder = new LocalDatasetBuilder<>(data, 1);
    ConvergenceChecker<Integer, LabeledVector<Double>> convChecker =
        createChecker(new MedianOfMedianConvergenceCheckerFactory(0.1), dsBuilder);

    // Error of the not-converged model on one hand-made point.
    double pointErr = convChecker.computeError(VectorUtils.of(1, 2), 4.0, notConvergedMdl);
    Assert.assertEquals(1.9, pointErr, delta);

    // Convergence verdicts for the two fixture models.
    LearningEnvironmentBuilder envBuilder = TestUtils.testEnvBuilder();
    Assert.assertFalse(convChecker.isConverged(envBuilder, dsBuilder, notConvergedMdl));
    Assert.assertTrue(convChecker.isConverged(envBuilder, dsBuilder, convergedMdl));

    // Mean error of the not-converged model; the trainer environment comes from a fresh test env builder.
    try (LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> locDs = dsBuilder.build(
        envBuilder,
        new EmptyContextBuilder<>(),
        new FeatureMatrixWithLabelsOnHeapDataBuilder<>(vectorizer),
        TestUtils.testEnvBuilder().buildForTrainer())) {
        double meanErr = convChecker.computeMeanErrorOnDataset(locDs, notConvergedMdl);

        Assert.assertEquals(1.6, meanErr, delta);
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Aggregations