Search in sources:

Example 1 with MetricStatsAggregator

Use of org.apache.ignite.ml.selection.scoring.evaluator.aggregator.MetricStatsAggregator in the Apache Ignite project.

In class Evaluator, method initEvaluationContexts.

/**
 * Inits evaluation contexts for metrics.
 *
 * <p>Contexts are keyed by aggregator class, so metrics sharing an aggregator class share one context. If none of
 * the metrics' freshly initialized contexts reports {@code needToCompute()}, those initialized contexts are returned
 * directly without a pass over the dataset. Otherwise every partition aggregates each of its rows into one context
 * per distinct aggregator class, and partition results are merged pairwise.</p>
 *
 * @param dataset Dataset.
 * @param metrics Metrics.
 * @return Computed contexts, one per distinct aggregator class of {@code metrics}.
 */
@SuppressWarnings("unchecked")
private static Map<Class, EvaluationContext> initEvaluationContexts(Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset, Metric... metrics) {
    long nonEmptyCtxsCnt = Arrays.stream(metrics)
        .map(x -> x.makeAggregator().createInitializedContext())
        .filter(x -> ((EvaluationContext) x).needToCompute())
        .count();

    if (nonEmptyCtxsCnt == 0) {
        // Nothing needs a pass over the data. Register an initialized context for EVERY aggregator class:
        // returning from inside the loop (as the previous version did) kept only the first metric's context, which
        // made initAggregators() fail its "ctx != null" check for any other aggregator class.
        Map<Class, EvaluationContext> res = new HashMap<>();
        for (Metric m : metrics) {
            MetricStatsAggregator<Double, ?, ?> aggregator = m.makeAggregator();
            // First context per class wins, mirroring the distributed branch below.
            res.putIfAbsent(aggregator.getClass(), (EvaluationContext) aggregator.createInitializedContext());
        }
        return res;
    }

    return dataset.compute(data -> {
        // One aggregator (and hence one context) per distinct aggregator class.
        Map<Class, MetricStatsAggregator> aggrs = new HashMap<>();
        for (Metric m : metrics) {
            MetricStatsAggregator<Double, ?, ?> aggregator = m.makeAggregator();
            if (!aggrs.containsKey(aggregator.getClass()))
                aggrs.put(aggregator.getClass(), aggregator);
        }

        Map<Class, EvaluationContext> aggrToEvCtx = new HashMap<>();
        aggrs.forEach((clazz, aggr) -> aggrToEvCtx.put(clazz, (EvaluationContext) aggr.createInitializedContext()));

        // Feed every row of this partition into every context.
        for (int i = 0; i < data.getLabels().length; i++) {
            LabeledVector<Double> vector = VectorUtils.of(data.getFeatures()[i]).labeled(data.getLabels()[i]);
            aggrToEvCtx.values().forEach(ctx -> ctx.aggregate(vector));
        }

        return aggrToEvCtx;
    }, (left, right) -> {
        if (left == null && right == null)
            return new HashMap<>();
        if (left == null)
            return right;
        if (right == null)
            return left;

        // Partitions are built from the same metrics, so both sides are expected to share the same key set.
        HashMap<Class, EvaluationContext> res = new HashMap<>();
        for (Class key : left.keySet()) {
            EvaluationContext ctx1 = left.get(key);
            EvaluationContext ctx2 = right.get(key);
            A.ensure(ctx1 != null && ctx2 != null, "ctx1 != null && ctx2 != null");
            res.put(key, ctx1.mergeWith(ctx2));
        }
        return res;
    });
}
Also used : FeatureMatrixWithLabelsOnHeapDataBuilder(org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapDataBuilder) Metric(org.apache.ignite.ml.selection.scoring.metric.Metric) Arrays(java.util.Arrays) IgniteBiPredicate(org.apache.ignite.lang.IgniteBiPredicate) EvaluationContext(org.apache.ignite.ml.selection.scoring.evaluator.context.EvaluationContext) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) HashMap(java.util.HashMap) MetricStatsAggregator(org.apache.ignite.ml.selection.scoring.evaluator.aggregator.MetricStatsAggregator) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) MetricName(org.apache.ignite.ml.selection.scoring.metric.MetricName) Map(java.util.Map) Cache(javax.cache.Cache) LocalDatasetBuilder(org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder) EmptyContextBuilder(org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) A(org.apache.ignite.internal.util.typedef.internal.A) FeatureMatrixWithLabelsOnHeapData(org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData) CacheBasedDatasetBuilder(org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder) IgniteModel(org.apache.ignite.ml.IgniteModel) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) KNNModel(org.apache.ignite.ml.knn.KNNModel) IgniteCache(org.apache.ignite.IgniteCache) Ignition(org.apache.ignite.Ignition) VectorUtils(org.apache.ignite.ml.math.primitives.vector.VectorUtils) Dataset(org.apache.ignite.ml.dataset.Dataset) QueryCursor(org.apache.ignite.cache.query.QueryCursor) ScanQuery(org.apache.ignite.cache.query.ScanQuery) HashMap(java.util.HashMap) 
MetricStatsAggregator(org.apache.ignite.ml.selection.scoring.evaluator.aggregator.MetricStatsAggregator) Metric(org.apache.ignite.ml.selection.scoring.metric.Metric) EvaluationContext(org.apache.ignite.ml.selection.scoring.evaluator.context.EvaluationContext)

Example 2 with MetricStatsAggregator

Use of org.apache.ignite.ml.selection.scoring.evaluator.aggregator.MetricStatsAggregator in the Apache Ignite project.

In class Evaluator, method evaluate.

/**
 * Evaluates a model against a set of metrics.
 *
 * <p>Each metric is indexed by its name together with the class of the aggregator it creates; when two metrics
 * share a name, the later one in {@code metrics} wins. Evaluation contexts are initialized first, statistics are
 * then aggregated, and finally every metric is initialized from the aggregator of its class to obtain its value.</p>
 *
 * @param mdl          Model to evaluate.
 * @param dataset      Dataset.
 * @param cache        Upstream cache.
 * @param preprocessor Preprocessor.
 * @param metrics      Metrics to compute.
 * @return Evaluation result mapping metric names to computed values.
 */
@SuppressWarnings("unchecked")
private static <K, V> EvaluationResult evaluate(IgniteModel<Vector, Double> mdl, Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset, IgniteCache<K, V> cache, Preprocessor<K, V> preprocessor, Metric[] metrics) {
    final Map<MetricName, Metric> metricByName = new HashMap<>();
    final Map<MetricName, Class> aggrClsByName = new HashMap<>();

    for (Metric candidate : metrics) {
        Class aggrCls = candidate.makeAggregator().getClass();
        MetricName metricName = candidate.name();

        aggrClsByName.put(metricName, aggrCls);
        metricByName.put(metricName, candidate);
    }

    final Map<Class, EvaluationContext> contexts = initEvaluationContexts(dataset, metrics);
    final Map<Class, MetricStatsAggregator> aggregators = computeStats(mdl, dataset, cache, preprocessor, contexts, metrics);

    Map<MetricName, Double> values = new HashMap<>();
    for (Metric metric : metrics) {
        MetricName metricName = metric.name();
        MetricStatsAggregator aggr = aggregators.get(aggrClsByName.get(metricName));
        Metric resolved = metricByName.get(metricName);

        values.put(metricName, resolved.initBy(aggr).value());
    }

    return new EvaluationResult(values);
}
Also used : HashMap(java.util.HashMap) MetricStatsAggregator(org.apache.ignite.ml.selection.scoring.evaluator.aggregator.MetricStatsAggregator) MetricName(org.apache.ignite.ml.selection.scoring.metric.MetricName) Metric(org.apache.ignite.ml.selection.scoring.metric.Metric) EvaluationContext(org.apache.ignite.ml.selection.scoring.evaluator.context.EvaluationContext)

Example 3 with MetricStatsAggregator

Use of org.apache.ignite.ml.selection.scoring.evaluator.aggregator.MetricStatsAggregator in the Apache Ignite project.

In class Evaluator, method computeStats.

/**
 * Aggregates statistics for metrics evaluation.
 *
 * <p>When {@code isOnlyLocalEstimation(mdl)} is true and a cache is supplied, the cache is scanned locally and every
 * entry is run through {@code preprocessor} before being fed to the aggregators. Otherwise the statistics are
 * computed over the dataset partitions and the per-partition aggregators are merged pairwise.</p>
 *
 * @param mdl          Model whose predictions are aggregated.
 * @param dataset      Dataset.
 * @param cache        Upstream cache (may be {@code null}).
 * @param preprocessor Preprocessor applied to cache entries in the local-scan path.
 * @param ctxs         Evaluation contexts keyed by aggregator class.
 * @param metrics      Metrics.
 * @return Aggregators keyed by aggregator class.
 */
@SuppressWarnings("unchecked")
private static <K, V> Map<Class, MetricStatsAggregator> computeStats(IgniteModel<Vector, Double> mdl, Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset, IgniteCache<K, V> cache, Preprocessor<K, V> preprocessor, Map<Class, EvaluationContext> ctxs, Metric... metrics) {
    if (!isOnlyLocalEstimation(mdl) || cache == null) {
        return dataset.compute(data -> {
            Map<Class, MetricStatsAggregator> partAggrs = initAggregators(ctxs, metrics);
            int rows = data.getLabels().length;

            for (int row = 0; row < rows; row++) {
                LabeledVector<Double> sample = VectorUtils.of(data.getFeatures()[row]).labeled(data.getLabels()[row]);
                for (MetricStatsAggregator aggr : partAggrs.values())
                    aggr.aggregate(mdl, sample);
            }

            return partAggrs;
        }, (left, right) -> {
            if (left == null) {
                if (right == null)
                    return new HashMap<>();
                return right;
            }
            if (right == null)
                return left;

            Map<Class, MetricStatsAggregator> merged = new HashMap<>();
            for (Map.Entry<Class, MetricStatsAggregator> entry : left.entrySet()) {
                MetricStatsAggregator fromLeft = entry.getValue();
                MetricStatsAggregator fromRight = right.get(entry.getKey());

                A.ensure(fromLeft != null && fromRight != null, "agg1 != null && agg2 != null");
                merged.put(entry.getKey(), fromLeft.mergeWith(fromRight));
            }

            return merged;
        });
    }

    // Local path: scan the whole cache on this node.
    Map<Class, MetricStatsAggregator> localAggrs = initAggregators(ctxs, metrics);

    try (QueryCursor<Cache.Entry<K, V>> cursor = cache.query(new ScanQuery<>())) {
        for (Cache.Entry<K, V> entry : cursor) {
            LabeledVector sample = preprocessor.apply(entry.getKey(), entry.getValue());
            for (MetricStatsAggregator aggr : localAggrs.values())
                aggr.aggregate(mdl, sample);
        }
    }

    return localAggrs;
}
Also used : HashMap(java.util.HashMap) MetricStatsAggregator(org.apache.ignite.ml.selection.scoring.evaluator.aggregator.MetricStatsAggregator) LabeledVector(org.apache.ignite.ml.structures.LabeledVector)

Example 4 with MetricStatsAggregator

Use of org.apache.ignite.ml.selection.scoring.evaluator.aggregator.MetricStatsAggregator in the Apache Ignite project.

In class Evaluator, method initAggregators.

/**
 * Creates one aggregator per distinct aggregator class and initializes it from the matching evaluation context.
 *
 * <p>Every metric's aggregator is checked against {@code ctxs} and initialized, even if an aggregator of the same
 * class was already registered; only the first aggregator per class is kept.</p>
 *
 * @param ctxs    Evaluation contexts keyed by aggregator class; must contain an entry for every metric's class.
 * @param metrics Metrics.
 * @return Aggregators keyed by aggregator class.
 */
private static Map<Class, MetricStatsAggregator> initAggregators(Map<Class, EvaluationContext> ctxs, Metric[] metrics) {
    Map<Class, MetricStatsAggregator> byCls = new HashMap<>();

    for (Metric metric : metrics) {
        MetricStatsAggregator candidate = metric.makeAggregator();
        Class cls = candidate.getClass();

        EvaluationContext ctx = ctxs.get(cls);
        A.ensure(ctx != null, "ctx != null");
        candidate.initByContext(ctx);

        byCls.putIfAbsent(cls, candidate);
    }

    return byCls;
}
Also used : HashMap(java.util.HashMap) MetricStatsAggregator(org.apache.ignite.ml.selection.scoring.evaluator.aggregator.MetricStatsAggregator) Metric(org.apache.ignite.ml.selection.scoring.metric.Metric) EvaluationContext(org.apache.ignite.ml.selection.scoring.evaluator.context.EvaluationContext)

Aggregations

HashMap (java.util.HashMap)4 MetricStatsAggregator (org.apache.ignite.ml.selection.scoring.evaluator.aggregator.MetricStatsAggregator)4 EvaluationContext (org.apache.ignite.ml.selection.scoring.evaluator.context.EvaluationContext)3 Metric (org.apache.ignite.ml.selection.scoring.metric.Metric)3 MetricName (org.apache.ignite.ml.selection.scoring.metric.MetricName)2 LabeledVector (org.apache.ignite.ml.structures.LabeledVector)2 Arrays (java.util.Arrays)1 Map (java.util.Map)1 Cache (javax.cache.Cache)1 IgniteCache (org.apache.ignite.IgniteCache)1 Ignition (org.apache.ignite.Ignition)1 QueryCursor (org.apache.ignite.cache.query.QueryCursor)1 ScanQuery (org.apache.ignite.cache.query.ScanQuery)1 A (org.apache.ignite.internal.util.typedef.internal.A)1 IgniteBiPredicate (org.apache.ignite.lang.IgniteBiPredicate)1 IgniteModel (org.apache.ignite.ml.IgniteModel)1 Dataset (org.apache.ignite.ml.dataset.Dataset)1 DatasetBuilder (org.apache.ignite.ml.dataset.DatasetBuilder)1 CacheBasedDatasetBuilder (org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder)1 LocalDatasetBuilder (org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder)1