Use of com.alibaba.alink.operator.common.evaluation.ClusterMetricsSummary in project Alink by alibaba.
The method linkFrom of the class EvalClusterBatchOp.
/**
 * Builds the cluster-evaluation dataflow for the first input operator and registers the
 * result as this operator's output table.
 *
 * <p>Two groups of metrics are produced and merged into one {@code Params} object:
 * <ul>
 *   <li>Label-based metrics, computed only when a label column is configured; otherwise a
 *       data set holding one empty {@code Params} is used as a placeholder.</li>
 *   <li>Vector-based metrics when a vector column is configured; otherwise whatever
 *       {@code BasicClusterParams} derives from the prediction column alone.</li>
 * </ul>
 * The merged parameters are written out via {@code toJson()} into a row with a single
 * STRING column named by {@code EVAL_RESULT}.
 *
 * @param inputs input operators; only the one returned by {@code checkAndGetFirst} is used
 * @return this operator, with its output table set
 */
@Override
public EvalClusterBatchOp linkFrom(BatchOperator<?>... inputs) {
    // Wildcard type instead of a raw type so that select()/getDataSet() keep their generics.
    BatchOperator<?> in = checkAndGetFirst(inputs);
    String labelColName = this.getLabelCol();
    String predResultColName = this.getPredictionCol();
    String vectorColName = this.getVectorCol();
    DistanceType distanceType = getDistanceType();
    FastDistance distance = distanceType.getFastDistance();

    // Placeholder so the union below always has a label-metrics side, even without labels.
    DataSet<Params> empty = MLEnvironmentFactory.get(getMLEnvironmentId())
        .getExecutionEnvironment()
        .fromElements(new Params());
    DataSet<Params> labelMetrics = empty;
    DataSet<Params> vectorMetrics;

    if (null != labelColName) {
        DataSet<Row> data = in.select(new String[] { labelColName, predResultColName }).getDataSet();

        // Values of field 0 (label column) from rows that pass checkRowFieldNotNull,
        // reduced by DistinctLabelIndexMap and projected to its first field.
        DataSet<Tuple1<Map<Object, Integer>>> labels = data.flatMap(new FlatMapFunction<Row, Object>() {
            private static final long serialVersionUID = 6181506719667975996L;

            @Override
            public void flatMap(Row row, Collector<Object> collector) {
                if (EvaluationUtil.checkRowFieldNotNull(row)) {
                    collector.collect(row.getField(0));
                }
            }
        }).reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(false, null, null, false)).project(0);

        // Same construction for field 1 (prediction column).
        DataSet<Tuple1<Map<Object, Integer>>> predictions = data.flatMap(new FlatMapFunction<Row, Object>() {
            private static final long serialVersionUID = 619373417169823128L;

            @Override
            public void flatMap(Row row, Collector<Object> collector) {
                if (EvaluationUtil.checkRowFieldNotNull(row)) {
                    collector.collect(row.getField(1));
                }
            }
        }).reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(false, null, null, false)).project(0);

        // Per-partition matrices are summed element-wise into one matrix, which is then
        // converted to Params together with the two broadcast index maps.
        // NOTE(review): the matrix is presumably a label-vs-prediction confusion matrix,
        // judging by extractParamsFromConfusionMatrix - confirm in CalLocalPredResult.
        labelMetrics = data.rebalance()
            .mapPartition(new CalLocalPredResult())
            .withBroadcastSet(labels, LABELS)
            .withBroadcastSet(predictions, PREDICTIONS)
            .reduce(new ReduceFunction<LongMatrix>() {
                private static final long serialVersionUID = 3340266128816528106L;

                @Override
                public LongMatrix reduce(LongMatrix value1, LongMatrix value2) {
                    value1.plusEqual(value2);
                    return value1;
                }
            })
            .map(new RichMapFunction<LongMatrix, Params>() {
                private static final long serialVersionUID = -4218363116865487327L;

                @Override
                public Params map(LongMatrix value) {
                    List<Tuple1<Map<Object, Integer>>> labels = getRuntimeContext().getBroadcastVariable(LABELS);
                    List<Tuple1<Map<Object, Integer>>> predictions = getRuntimeContext().getBroadcastVariable(PREDICTIONS);
                    return ClusterEvaluationUtil.extractParamsFromConfusionMatrix(value, labels.get(0).f0, predictions.get(0).f0);
                }
            })
            .withBroadcastSet(labels, LABELS)
            .withBroadcastSet(predictions, PREDICTIONS);
    }

    if (null != vectorColName) {
        // Use the validated handle `in` rather than indexing `inputs` again, consistent
        // with the other branches of this method.
        Tuple2<DataSet<Tuple2<Vector, Row>>, DataSet<BaseVectorSummary>> statistics =
            StatisticsHelper.summaryHelper(in, null, vectorColName, new String[] { predResultColName });

        DataSet<Tuple2<Vector, String>> nonEmpty = statistics.f0
            .flatMap(new FilterEmptyRow())
            .withBroadcastSet(statistics.f1, VECTOR_SIZE);

        // Grouped by tuple field 1, i.e. the String paired with each vector.
        DataSet<Tuple3<String, DenseVector, DenseVector>> meanAndSum = nonEmpty
            .groupBy(1)
            .reduceGroup(new CalcMeanAndSum(distance))
            .withBroadcastSet(statistics.f1, VECTOR_SIZE);

        // Match each group of samples with its own meanAndSum entry; the full meanAndSum
        // set is additionally broadcast to the co-group function.
        DataSet<BaseMetricsSummary> metricsSummary = nonEmpty
            .coGroup(meanAndSum)
            .where(1)
            .equalTo(0)
            .with(new CalcClusterMetricsSummary(distance))
            .withBroadcastSet(meanAndSum, MEAN_AND_SUM)
            .reduce(new EvaluationUtil.ReduceBaseMetrics());

        // One value per sample from calSilhouetteCoefficient, summed over field 0.
        DataSet<Tuple1<Double>> silhouetteCoefficient = nonEmpty
            .map(new RichMapFunction<Tuple2<Vector, String>, Tuple1<Double>>() {
                private static final long serialVersionUID = 116926378586242272L;

                @Override
                public Tuple1<Double> map(Tuple2<Vector, String> value) {
                    List<BaseMetricsSummary> list = getRuntimeContext().getBroadcastVariable(METRICS_SUMMARY);
                    return ClusterEvaluationUtil.calSilhouetteCoefficient(value, (ClusterMetricsSummary) list.get(0));
                }
            })
            .withBroadcastSet(metricsSummary, METRICS_SUMMARY)
            .aggregate(Aggregations.SUM, 0);

        vectorMetrics = metricsSummary
            .map(new ClusterEvaluationUtil.SaveDataAsParams())
            .withBroadcastSet(silhouetteCoefficient, SILHOUETTE_COEFFICIENT);
    } else {
        // No vectors available: derive metrics from the prediction column only.
        vectorMetrics = in.select(predResultColName).getDataSet().reduceGroup(new BasicClusterParams());
    }

    // Merge every Params produced above into one object and emit it as a JSON string.
    DataSet<Row> out = labelMetrics.union(vectorMetrics).reduceGroup(new GroupReduceFunction<Params, Row>() {
        private static final long serialVersionUID = -4726713311986089251L;

        @Override
        public void reduce(Iterable<Params> values, Collector<Row> out) {
            Params params = new Params();
            for (Params p : values) {
                params.merge(p);
            }
            out.collect(Row.of(params.toJson()));
        }
    });

    this.setOutputTable(DataSetConversionUtil.toTable(
        getMLEnvironmentId(),
        out,
        new TableSchema(new String[] { EVAL_RESULT }, new TypeInformation[] { Types.STRING })));
    return this;
}
Aggregations