use of org.apache.ignite.ml.selection.scoring.evaluator.aggregator.MetricStatsAggregator in project ignite by apache.
the class Evaluator method initEvaluationContexts.
/**
 * Inits evaluation contexts for metrics.
 * <p>
 * One context is created per aggregator class. If none of the freshly initialized contexts reports
 * {@code needToCompute()}, these contexts are returned as is and the dataset is not touched (this also
 * covers an empty metrics array, which yields an empty map). Otherwise every context is fed with each
 * labeled vector of the dataset via {@code dataset.compute} and partial results are merged pairwise.
 *
 * @param dataset Dataset.
 * @param metrics Metrics.
 * @return Computed contexts keyed by aggregator class.
 */
@SuppressWarnings("unchecked")
private static Map<Class, EvaluationContext> initEvaluationContexts(
    Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset, Metric... metrics) {
    long nonEmptyCtxsCnt = Arrays.stream(metrics)
        .map(x -> x.makeAggregator().createInitializedContext())
        .filter(x -> ((EvaluationContext)x).needToCompute())
        .count();

    if (nonEmptyCtxsCnt == 0) {
        Map<Class, EvaluationContext> res = new HashMap<>();

        for (Metric m : metrics) {
            MetricStatsAggregator<Double, ?, ?> aggregator = m.makeAggregator();

            res.put(aggregator.getClass(), (EvaluationContext)aggregator.createInitializedContext());
        }

        // Return only after ALL metrics were visited: returning from inside the loop would register
        // a context for the first metric only and leave other aggregator classes without a context.
        return res;
    }

    return dataset.compute(data -> {
        // Keep the first aggregator seen for each aggregator class.
        Map<Class, MetricStatsAggregator> aggrs = new HashMap<>();

        for (Metric m : metrics) {
            MetricStatsAggregator<Double, ?, ?> aggregator = m.makeAggregator();

            if (!aggrs.containsKey(aggregator.getClass()))
                aggrs.put(aggregator.getClass(), aggregator);
        }

        Map<Class, EvaluationContext> aggrToEvCtx = new HashMap<>();

        aggrs.forEach((clazz, aggr) -> aggrToEvCtx.put(clazz, (EvaluationContext)aggr.createInitializedContext()));

        // Feed every context with each labeled vector of this data chunk.
        for (int i = 0; i < data.getLabels().length; i++) {
            LabeledVector<Double> vector = VectorUtils.of(data.getFeatures()[i]).labeled(data.getLabels()[i]);

            aggrToEvCtx.values().forEach(ctx -> ctx.aggregate(vector));
        }

        return aggrToEvCtx;
    }, (left, right) -> {
        // Reducer: a missing side is treated as "nothing to merge".
        if (left == null && right == null)
            return new HashMap<>();

        if (left == null)
            return right;

        if (right == null)
            return left;

        HashMap<Class, EvaluationContext> res = new HashMap<>();

        // Both sides are expected to hold a context for every key of the left map.
        for (Class key : left.keySet()) {
            EvaluationContext ctx1 = left.get(key);
            EvaluationContext ctx2 = right.get(key);

            A.ensure(ctx1 != null && ctx2 != null, "ctx1 != null && ctx2 != null");

            res.put(key, ctx1.mergeWith(ctx2));
        }

        return res;
    });
}
use of org.apache.ignite.ml.selection.scoring.evaluator.aggregator.MetricStatsAggregator in project ignite by apache.
the class Evaluator method evaluate.
/**
 * Evaluates the model and produces a value for each requested metric.
 * <p>
 * Evaluation contexts are initialized first, then statistics aggregators are computed, and finally each
 * metric is initialized by the aggregator registered for its aggregator class.
 *
 * @param mdl The model.
 * @param dataset Dataset.
 * @param cache Upstream cache.
 * @param preprocessor Preprocessor.
 * @param metrics Metrics to compute.
 * @return Evaluation result holding the metric values keyed by metric name.
 */
@SuppressWarnings("unchecked")
private static <K, V> EvaluationResult evaluate(IgniteModel<Vector, Double> mdl,
    Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset, IgniteCache<K, V> cache,
    Preprocessor<K, V> preprocessor, Metric[] metrics) {
    final Map<MetricName, Class> aggrClsByName = new HashMap<>();
    final Map<MetricName, Metric> metricByName = new HashMap<>();

    // Index the metrics by name and remember which aggregator class serves each of them.
    for (Metric m : metrics) {
        Class aggrCls = m.makeAggregator().getClass();
        MetricName metricName = m.name();

        aggrClsByName.put(metricName, aggrCls);
        metricByName.put(metricName, m);
    }

    final Map<Class, EvaluationContext> ctxByAggrCls = initEvaluationContexts(dataset, metrics);
    final Map<Class, MetricStatsAggregator> aggrByCls =
        computeStats(mdl, dataset, cache, preprocessor, ctxByAggrCls, metrics);

    Map<MetricName, Double> values = new HashMap<>();

    for (Metric m : metrics) {
        MetricName metricName = m.name();
        MetricStatsAggregator stats = aggrByCls.get(aggrClsByName.get(metricName));

        // The metric is looked up by name, so the instance registered under that name is the one used.
        Metric registered = metricByName.get(metricName);

        values.put(metricName, registered.initBy(stats).value());
    }

    return new EvaluationResult(values);
}
use of org.apache.ignite.ml.selection.scoring.evaluator.aggregator.MetricStatsAggregator in project ignite by apache.
the class Evaluator method computeStats.
/**
 * Aggregates statistics for metrics evaluation.
 * <p>
 * When {@code isOnlyLocalEstimation(mdl)} holds and a cache is supplied, the cache is scanned with a
 * {@code ScanQuery} and every entry is converted to a labeled vector by the preprocessor. Otherwise the
 * statistics are computed through {@code dataset.compute} and the partial aggregator maps are merged pairwise.
 *
 * @param mdl The model.
 * @param dataset Dataset.
 * @param cache Upstream cache, may be {@code null}.
 * @param preprocessor Preprocessor applied to cache entries.
 * @param ctxs Evaluation contexts keyed by aggregator class.
 * @param metrics Metrics.
 * @return Aggregated statistics keyed by aggregator class.
 */
@SuppressWarnings("unchecked")
private static <K, V> Map<Class, MetricStatsAggregator> computeStats(IgniteModel<Vector, Double> mdl,
    Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset, IgniteCache<K, V> cache,
    Preprocessor<K, V> preprocessor, Map<Class, EvaluationContext> ctxs, Metric... metrics) {
    if (!isOnlyLocalEstimation(mdl) || cache == null) {
        return dataset.compute(data -> {
            Map<Class, MetricStatsAggregator> partAggrs = initAggregators(ctxs, metrics);

            for (int row = 0; row < data.getLabels().length; row++) {
                LabeledVector<Double> sample =
                    VectorUtils.of(data.getFeatures()[row]).labeled(data.getLabels()[row]);

                for (MetricStatsAggregator aggr : partAggrs.values())
                    aggr.aggregate(mdl, sample);
            }

            return partAggrs;
        }, (left, right) -> {
            // A missing side contributes nothing to the merge.
            if (left == null) {
                if (right == null)
                    return new HashMap<>();

                return right;
            }

            if (right == null)
                return left;

            HashMap<Class, MetricStatsAggregator> merged = new HashMap<>();

            for (Map.Entry<Class, MetricStatsAggregator> e : left.entrySet()) {
                MetricStatsAggregator agg1 = e.getValue();
                MetricStatsAggregator agg2 = right.get(e.getKey());

                A.ensure(agg1 != null && agg2 != null, "agg1 != null && agg2 != null");

                merged.put(e.getKey(), agg1.mergeWith(agg2));
            }

            return merged;
        });
    }

    Map<Class, MetricStatsAggregator> locAggrs = initAggregators(ctxs, metrics);

    // The cursor is closed automatically once all entries have been consumed.
    try (QueryCursor<Cache.Entry<K, V>> cursor = cache.query(new ScanQuery<>())) {
        cursor.iterator().forEachRemaining(entry -> {
            LabeledVector sample = preprocessor.apply(entry.getKey(), entry.getValue());

            for (MetricStatsAggregator aggr : locAggrs.values())
                aggr.aggregate(mdl, sample);
        });
    }

    return locAggrs;
}
use of org.apache.ignite.ml.selection.scoring.evaluator.aggregator.MetricStatsAggregator in project ignite by apache.
the class Evaluator method initAggregators.
/**
 * Inits aggregators.
 * <p>
 * A new aggregator is created for every metric and initialized by the context registered for its class;
 * only the first aggregator of each class is kept in the resulting map.
 *
 * @param ctxs Evaluation contexts keyed by aggregator class; must contain a context for every aggregator class.
 * @param metrics Metrics.
 * @return Aggregators map keyed by aggregator class.
 */
private static Map<Class, MetricStatsAggregator> initAggregators(Map<Class, EvaluationContext> ctxs,
    Metric[] metrics) {
    Map<Class, MetricStatsAggregator> byCls = new HashMap<>();

    for (Metric metric : metrics) {
        MetricStatsAggregator candidate = metric.makeAggregator();
        Class aggrCls = candidate.getClass();

        EvaluationContext ctx = ctxs.get(aggrCls);

        A.ensure(ctx != null, "ctx != null");

        candidate.initByContext(ctx);

        // Keeps an already registered aggregator of the same class untouched.
        byCls.putIfAbsent(aggrCls, candidate);
    }

    return byCls;
}
Aggregations