Search in sources :

Example 1 with ClusterMetricsSummary

Use of com.alibaba.alink.operator.common.evaluation.ClusterMetricsSummary in project Alink by Alibaba.

The method linkFrom of the class EvalClusterBatchOp.

/**
 * Assembles the batch job that evaluates a clustering result and publishes it as an output
 * table with the single string column {@code EVAL_RESULT}, whose one row holds the merged
 * metrics serialized as JSON.
 *
 * <p>If a label column is configured, label/prediction metrics are extracted from a
 * {@code LongMatrix} summed over all partitions. If a vector column is configured, distance
 * based metrics plus a summed silhouette term are produced; otherwise only
 * {@code BasicClusterParams} are computed from the prediction column.
 *
 * @param inputs upstream operators; the one returned by {@code checkAndGetFirst} is evaluated
 * @return this operator, after its output table has been set
 */
@Override
public EvalClusterBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator source = checkAndGetFirst(inputs);
    String labelColName = this.getLabelCol();
    String predResultColName = this.getPredictionCol();
    String vectorColName = this.getVectorCol();
    FastDistance fastDistance = getDistanceType().getFastDistance();

    // Placeholder Params data set; it is used as the label metrics when no label column is set.
    DataSet<Params> emptyParams = MLEnvironmentFactory
        .get(getMLEnvironmentId())
        .getExecutionEnvironment()
        .fromElements(new Params());

    DataSet<Params> labelMetrics = emptyParams;
    if (labelColName != null) {
        DataSet<Row> labelAndPrediction = source
            .select(new String[] { labelColName, predResultColName })
            .getDataSet();

        // Field 0 (label) of every row that passes the not-null check.
        DataSet<Object> labelValues = labelAndPrediction.flatMap(new FlatMapFunction<Row, Object>() {

            private static final long serialVersionUID = 6181506719667975996L;

            @Override
            public void flatMap(Row record, Collector<Object> sink) {
                if (EvaluationUtil.checkRowFieldNotNull(record)) {
                    sink.collect(record.getField(0));
                }
            }
        });
        DataSet<Tuple1<Map<Object, Integer>>> labelIndex = labelValues
            .reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(false, null, null, false))
            .project(0);

        // Field 1 (prediction) of every row that passes the not-null check.
        DataSet<Object> predictionValues = labelAndPrediction.flatMap(new FlatMapFunction<Row, Object>() {

            private static final long serialVersionUID = 619373417169823128L;

            @Override
            public void flatMap(Row record, Collector<Object> sink) {
                if (EvaluationUtil.checkRowFieldNotNull(record)) {
                    sink.collect(record.getField(1));
                }
            }
        });
        DataSet<Tuple1<Map<Object, Integer>>> predictionIndex = predictionValues
            .reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(false, null, null, false))
            .project(0);

        // Per-partition matrices are added together into a single matrix.
        DataSet<LongMatrix> summedMatrix = labelAndPrediction
            .rebalance()
            .mapPartition(new CalLocalPredResult())
            .withBroadcastSet(labelIndex, LABELS)
            .withBroadcastSet(predictionIndex, PREDICTIONS)
            .reduce(new ReduceFunction<LongMatrix>() {

                private static final long serialVersionUID = 3340266128816528106L;

                @Override
                public LongMatrix reduce(LongMatrix accumulated, LongMatrix next) {
                    accumulated.plusEqual(next);
                    return accumulated;
                }
            });

        labelMetrics = summedMatrix.map(new RichMapFunction<LongMatrix, Params>() {

            private static final long serialVersionUID = -4218363116865487327L;

            @Override
            public Params map(LongMatrix matrix) {
                List<Tuple1<Map<Object, Integer>>> labelMaps =
                    getRuntimeContext().getBroadcastVariable(LABELS);
                List<Tuple1<Map<Object, Integer>>> predictionMaps =
                    getRuntimeContext().getBroadcastVariable(PREDICTIONS);
                Map<Object, Integer> labelMap = labelMaps.get(0).f0;
                Map<Object, Integer> predictionMap = predictionMaps.get(0).f0;
                return ClusterEvaluationUtil.extractParamsFromConfusionMatrix(matrix, labelMap, predictionMap);
            }
        }).withBroadcastSet(labelIndex, LABELS).withBroadcastSet(predictionIndex, PREDICTIONS);
    }

    DataSet<Params> vectorMetrics;
    if (vectorColName == null) {
        // No vectors available: only the prediction column is summarized.
        vectorMetrics = source.select(predResultColName).getDataSet().reduceGroup(new BasicClusterParams());
    } else {
        Tuple2<DataSet<Tuple2<Vector, Row>>, DataSet<BaseVectorSummary>> statistics =
            StatisticsHelper.summaryHelper(inputs[0], null, vectorColName, new String[] { predResultColName });
        DataSet<BaseVectorSummary> vectorSummary = statistics.f1;

        DataSet<Tuple2<Vector, String>> nonEmpty = statistics.f0
            .flatMap(new FilterEmptyRow())
            .withBroadcastSet(vectorSummary, VECTOR_SIZE);

        // Grouped by tuple field 1 of the filtered records.
        DataSet<Tuple3<String, DenseVector, DenseVector>> meanAndSum = nonEmpty
            .groupBy(1)
            .reduceGroup(new CalcMeanAndSum(fastDistance))
            .withBroadcastSet(vectorSummary, VECTOR_SIZE);

        DataSet<BaseMetricsSummary> metricsSummary = nonEmpty
            .coGroup(meanAndSum)
            .where(1)
            .equalTo(0)
            .with(new CalcClusterMetricsSummary(fastDistance))
            .withBroadcastSet(meanAndSum, MEAN_AND_SUM)
            .reduce(new EvaluationUtil.ReduceBaseMetrics());

        // One silhouette value per record, computed against the broadcast summary.
        DataSet<Tuple1<Double>> perRecordSilhouette = nonEmpty
            .map(new RichMapFunction<Tuple2<Vector, String>, Tuple1<Double>>() {

                private static final long serialVersionUID = 116926378586242272L;

                @Override
                public Tuple1<Double> map(Tuple2<Vector, String> sample) {
                    List<BaseMetricsSummary> summaries =
                        getRuntimeContext().getBroadcastVariable(METRICS_SUMMARY);
                    ClusterMetricsSummary clusterSummary = (ClusterMetricsSummary) summaries.get(0);
                    return ClusterEvaluationUtil.calSilhouetteCoefficient(sample, clusterSummary);
                }
            })
            .withBroadcastSet(metricsSummary, METRICS_SUMMARY);
        DataSet<Tuple1<Double>> silhouetteCoefficient = perRecordSilhouette.aggregate(Aggregations.SUM, 0);

        vectorMetrics = metricsSummary
            .map(new ClusterEvaluationUtil.SaveDataAsParams())
            .withBroadcastSet(silhouetteCoefficient, SILHOUETTE_COEFFICIENT);
    }

    // Fold every Params record from both branches into one and emit it as a JSON string row.
    DataSet<Params> allMetrics = labelMetrics.union(vectorMetrics);
    DataSet<Row> out = allMetrics.reduceGroup(new GroupReduceFunction<Params, Row>() {

        private static final long serialVersionUID = -4726713311986089251L;

        @Override
        public void reduce(Iterable<Params> parts, Collector<Row> sink) {
            Params merged = new Params();
            for (Params part : parts) {
                merged.merge(part);
            }
            sink.collect(Row.of(merged.toJson()));
        }
    });

    this.setOutputTable(DataSetConversionUtil.toTable(
        getMLEnvironmentId(),
        out,
        new TableSchema(new String[] { EVAL_RESULT }, new TypeInformation[] { Types.STRING })));
    return this;
}
Also used : TableSchema(org.apache.flink.table.api.TableSchema) DataSet(org.apache.flink.api.java.DataSet) FastDistance(com.alibaba.alink.operator.common.distance.FastDistance) List(java.util.List) Vector(com.alibaba.alink.common.linalg.Vector) DenseVector(com.alibaba.alink.common.linalg.DenseVector) SparseVector(com.alibaba.alink.common.linalg.SparseVector) ClusterEvaluationUtil(com.alibaba.alink.operator.common.evaluation.ClusterEvaluationUtil) Params(org.apache.flink.ml.api.misc.param.Params) EvalClusterParams(com.alibaba.alink.params.evaluation.EvalClusterParams) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) LongMatrix(com.alibaba.alink.operator.common.evaluation.LongMatrix) EvaluationUtil(com.alibaba.alink.operator.common.evaluation.EvaluationUtil) ClusterEvaluationUtil(com.alibaba.alink.operator.common.evaluation.ClusterEvaluationUtil) Tuple1(org.apache.flink.api.java.tuple.Tuple1) RichMapFunction(org.apache.flink.api.common.functions.RichMapFunction) Tuple2(org.apache.flink.api.java.tuple.Tuple2) Tuple3(org.apache.flink.api.java.tuple.Tuple3) ClusterMetricsSummary(com.alibaba.alink.operator.common.evaluation.ClusterMetricsSummary) Row(org.apache.flink.types.Row) BaseMetricsSummary(com.alibaba.alink.operator.common.evaluation.BaseMetricsSummary)

Aggregations

DenseVector (com.alibaba.alink.common.linalg.DenseVector)1 SparseVector (com.alibaba.alink.common.linalg.SparseVector)1 Vector (com.alibaba.alink.common.linalg.Vector)1 BatchOperator (com.alibaba.alink.operator.batch.BatchOperator)1 FastDistance (com.alibaba.alink.operator.common.distance.FastDistance)1 BaseMetricsSummary (com.alibaba.alink.operator.common.evaluation.BaseMetricsSummary)1 ClusterEvaluationUtil (com.alibaba.alink.operator.common.evaluation.ClusterEvaluationUtil)1 ClusterMetricsSummary (com.alibaba.alink.operator.common.evaluation.ClusterMetricsSummary)1 EvaluationUtil (com.alibaba.alink.operator.common.evaluation.EvaluationUtil)1 LongMatrix (com.alibaba.alink.operator.common.evaluation.LongMatrix)1 EvalClusterParams (com.alibaba.alink.params.evaluation.EvalClusterParams)1 List (java.util.List)1 RichMapFunction (org.apache.flink.api.common.functions.RichMapFunction)1 DataSet (org.apache.flink.api.java.DataSet)1 Tuple1 (org.apache.flink.api.java.tuple.Tuple1)1 Tuple2 (org.apache.flink.api.java.tuple.Tuple2)1 Tuple3 (org.apache.flink.api.java.tuple.Tuple3)1 Params (org.apache.flink.ml.api.misc.param.Params)1 TableSchema (org.apache.flink.table.api.TableSchema)1 Row (org.apache.flink.types.Row)1