Search in sources:

Example 1 with MetricName

Use of org.apache.ignite.ml.selection.scoring.metric.MetricName in the Apache Ignite project.

Method `evaluate` of class `Evaluator`.

/**
 * Computes every requested metric for the given model over the dataset.
 *
 * <p>Each metric is associated with the class of the aggregator it creates. The per-class evaluation
 * contexts and aggregated statistics are then obtained from {@code initEvaluationContexts} and
 * {@code computeStats}. Each metric is initialized from the aggregator matching its class.
 * If several metrics share the same {@link MetricName}, the one appearing last in {@code metrics}
 * determines the reported value.
 *
 * @param mdl          Model to evaluate.
 * @param dataset      Dataset the evaluation runs over.
 * @param cache        Upstream cache the dataset was built from.
 * @param preprocessor Preprocessor applied to upstream entries.
 * @param metrics      Metrics to compute.
 * @param <K>          Type of upstream cache keys.
 * @param <V>          Type of upstream cache values.
 * @return Evaluation result mapping each metric name to its computed value.
 */
@SuppressWarnings("unchecked")
private static <K, V> EvaluationResult evaluate(IgniteModel<Vector, Double> mdl,
    Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset,
    IgniteCache<K, V> cache,
    Preprocessor<K, V> preprocessor,
    Metric[] metrics) {
    // Index metrics and their aggregator classes by name; a later metric with an equal name overwrites
    // an earlier one in both indexes.
    final Map<MetricName, Metric> metricByName = new HashMap<>();
    final Map<MetricName, Class> aggrClsByName = new HashMap<>();
    for (int i = 0; i < metrics.length; i++) {
        final Metric current = metrics[i];
        // The aggregator instance is only needed to learn its runtime class.
        final Class statsAggrCls = current.makeAggregator().getClass();
        final MetricName metricName = current.name();
        aggrClsByName.put(metricName, statsAggrCls);
        metricByName.put(metricName, current);
    }

    // Contexts and statistics are keyed by aggregator class, so metrics sharing an aggregator type
    // share the same computed statistics.
    final Map<Class, EvaluationContext> ctxByAggrCls = initEvaluationContexts(dataset, metrics);
    final Map<Class, MetricStatsAggregator> statsByAggrCls =
        computeStats(mdl, dataset, cache, preprocessor, ctxByAggrCls, metrics);

    final Map<MetricName, Double> valuesByName = new HashMap<>();
    for (Metric current : metrics) {
        final MetricName metricName = current.name();
        final MetricStatsAggregator stats = statsByAggrCls.get(aggrClsByName.get(metricName));
        valuesByName.put(metricName, metricByName.get(metricName).initBy(stats).value());
    }

    return new EvaluationResult(valuesByName);
}
Also used: HashMap (java.util.HashMap), MetricStatsAggregator (org.apache.ignite.ml.selection.scoring.evaluator.aggregator.MetricStatsAggregator), MetricName (org.apache.ignite.ml.selection.scoring.metric.MetricName), Metric (org.apache.ignite.ml.selection.scoring.metric.Metric), EvaluationContext (org.apache.ignite.ml.selection.scoring.evaluator.context.EvaluationContext)

Aggregations

HashMap (java.util.HashMap): 1; MetricStatsAggregator (org.apache.ignite.ml.selection.scoring.evaluator.aggregator.MetricStatsAggregator): 1; EvaluationContext (org.apache.ignite.ml.selection.scoring.evaluator.context.EvaluationContext): 1; Metric (org.apache.ignite.ml.selection.scoring.metric.Metric): 1; MetricName (org.apache.ignite.ml.selection.scoring.metric.MetricName): 1