Usage of org.apache.ignite.ml.selection.scoring.metric.MetricName in the Apache Ignite project.
Method evaluate of the class Evaluator.
/**
 * Evaluates a model on a dataset and computes every requested metric.
 *
 * <p>First, each metric is indexed by its {@link MetricName}, together with the class of the
 * statistics aggregator it produces. Next, the evaluation contexts and aggregated statistics
 * are obtained, with one aggregator per aggregator class. Finally, each metric is initialized
 * from the aggregator matching its class, and its value is recorded under its name. If several
 * metrics share a name, the last one in {@code metrics} determines the recorded value.
 *
 * @param mdl The model to evaluate.
 * @param dataset Dataset to evaluate on.
 * @param cache Upstream cache.
 * @param preprocessor Preprocessor.
 * @param metrics Metrics to compute.
 * @return Evaluation result holding one value per metric name.
 */
@SuppressWarnings("unchecked")
private static <K, V> EvaluationResult evaluate(IgniteModel<Vector, Double> mdl, Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset, IgniteCache<K, V> cache, Preprocessor<K, V> preprocessor, Metric[] metrics) {
    // Index each metric, and the aggregator class it requires, by the metric's name.
    final Map<MetricName, Metric> metricsByName = new HashMap<>();
    final Map<MetricName, Class> aggrClassByName = new HashMap<>();
    for (int i = 0; i < metrics.length; i++) {
        final Metric current = metrics[i];
        final Class aggrClass = current.makeAggregator().getClass();
        final MetricName key = current.name();
        aggrClassByName.put(key, aggrClass);
        metricsByName.put(key, current);
    }

    // Obtain the aggregated statistics, keyed by aggregator class.
    final Map<Class, EvaluationContext> contexts = initEvaluationContexts(dataset, metrics);
    final Map<Class, MetricStatsAggregator> aggregators =
        computeStats(mdl, dataset, cache, preprocessor, contexts, metrics);

    // Initialize each metric from the aggregator matching its class, and record its value.
    final Map<MetricName, Double> values = new HashMap<>();
    for (int i = 0; i < metrics.length; i++) {
        final MetricName key = metrics[i].name();
        final MetricStatsAggregator stats = aggregators.get(aggrClassByName.get(key));
        final Metric resolved = metricsByName.get(key);
        values.put(key, resolved.initBy(stats).value());
    }

    return new EvaluationResult(values);
}
Aggregations