Search in sources :

Example 1 with FastDistance

use of com.alibaba.alink.operator.common.distance.FastDistance in project Alink by alibaba.

the class GeoKMeansTrainBatchOp method linkFrom.

/**
 * Builds the geographic k-means training pipeline on the first input operator.
 *
 * <p>The latitude and longitude columns are read as numbers, packed into a
 * two-element dense vector per row, and prepared with a {@link HaversineDistance}.
 * Centroid initialisation is delegated to {@code initKmeansCentroids} and the
 * iterative refinement to {@code iterateICQ}; the resulting rows are stored as
 * this operator's output using the schema of {@link KMeansModelDataConverter}.
 *
 * @param inputs the upstream operators; only the first one is used
 * @return this operator, with its output set to the trained model rows
 */
@Override
public GeoKMeansTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> source = checkAndGetFirst(inputs);
    final String latCol = this.getLatitudeCol();
    final String lonCol = this.getLongitudeCol();
    final FastDistance haversine = new HaversineDistance();
    final int iterationLimit = this.getMaxIter();
    final double tolerance = this.getEpsilon();

    // Keep only the two coordinate columns, in (latitude, longitude) order.
    DataSet<Row> coordinateRows = source.select(new String[] { latCol, lonCol }).getDataSet();
    DataSet<FastDistanceVectorData> samples = coordinateRows.rebalance().map(new MapFunction<Row, FastDistanceVectorData>() {

        private static final long serialVersionUID = -5236022856006527961L;

        @Override
        public FastDistanceVectorData map(Row row) {
            double latitude = ((Number) row.getField(0)).doubleValue();
            double longitude = ((Number) row.getField(1)).doubleValue();
            Vector point = new DenseVector(new double[] { latitude, longitude });
            return haversine.prepareVectorData(Row.of(point), 0);
        }
    });

    // Every sample is a (latitude, longitude) pair, so the vector size is the constant 2.
    DataSet<Integer> dimension = MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment().fromElements(2);

    // Starting centroids, represented as FastDistanceMatrixData.
    DataSet<FastDistanceMatrixData> startCentroids = initKmeansCentroids(samples, haversine, this.getParams(), dimension, getRandomSeed());

    // No vector column is passed (null); the latitude/longitude column names are passed instead.
    DataSet<Row> trainedModel = iterateICQ(startCentroids, samples, dimension, iterationLimit, tolerance, haversine, HasKMeansWithHaversineDistanceType.DistanceType.HAVERSINE, null, latCol, lonCol);

    // Store the clustering model as this operator's output table.
    this.setOutput(trainedModel, new KMeansModelDataConverter().getModelSchema());
    return this;
}
Also used : FastDistanceVectorData(com.alibaba.alink.operator.common.distance.FastDistanceVectorData) KMeansModelDataConverter(com.alibaba.alink.operator.common.clustering.kmeans.KMeansModelDataConverter) HaversineDistance(com.alibaba.alink.operator.common.distance.HaversineDistance) FastDistanceMatrixData(com.alibaba.alink.operator.common.distance.FastDistanceMatrixData) FastDistance(com.alibaba.alink.operator.common.distance.FastDistance) Row(org.apache.flink.types.Row) Vector(com.alibaba.alink.common.linalg.Vector) DenseVector(com.alibaba.alink.common.linalg.DenseVector) DenseVector(com.alibaba.alink.common.linalg.DenseVector)

Example 2 with FastDistance

use of com.alibaba.alink.operator.common.distance.FastDistance in project Alink by alibaba.

the class KMeansTrainBatchOp method linkFrom.

/**
 * Builds the k-means training pipeline on the vector column of the first input operator.
 *
 * <p>The vector size is taken from the vector summary produced by
 * {@code StatisticsHelper.summaryHelper} and broadcast to the mapper that prepares
 * each sample for the configured {@link FastDistance}. Centroid initialisation is
 * delegated to {@code KMeansInitCentroids.initKmeansCentroids} and the iterative
 * refinement to {@code iterateICQ}; the resulting rows are stored as this
 * operator's output using the schema of {@link KMeansModelDataConverter}.
 *
 * @param inputs the upstream operators; only the first one is used
 * @return this operator, with its output set to the trained model rows
 */
@Override
public KMeansTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> source = checkAndGetFirst(inputs);
    final int iterationLimit = this.getMaxIter();
    final double tolerance = this.getEpsilon();
    final String vectorCol = this.getVectorCol();
    final DistanceType configuredType = getDistanceType();
    final FastDistance metric = configuredType.getFastDistance();

    Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> vectorsAndSummary = StatisticsHelper.summaryHelper(source, null, vectorCol);
    DataSet<Vector> vectors = vectorsAndSummary.f0;
    DataSet<BaseVectorSummary> summary = vectorsAndSummary.f1;

    // Derive the vector size from the summary, rejecting an empty training set.
    DataSet<Integer> dimension = summary.map(new MapFunction<BaseVectorSummary, Integer>() {

        private static final long serialVersionUID = 4184586558834055401L;

        @Override
        public Integer map(BaseVectorSummary value) {
            Preconditions.checkArgument(value.count() > 0, "The train dataset is empty!");
            return value.vectorSize();
        }
    });

    DataSet<FastDistanceVectorData> samples = vectors.rebalance().map(new RichMapFunction<Vector, FastDistanceVectorData>() {

        private static final long serialVersionUID = -7443226889326704768L;

        /** Vector size read from the broadcast set when the function is opened. */
        private int broadcastSize;

        @Override
        public void open(Configuration params) {
            broadcastSize = (int) this.getRuntimeContext().getBroadcastVariable(VECTOR_SIZE).get(0);
        }

        @Override
        public FastDistanceVectorData map(Vector value) {
            // Sparse vectors get their size set to the broadcast vector size.
            if (value instanceof SparseVector) {
                SparseVector sparse = (SparseVector) value;
                sparse.setSize(broadcastSize);
            }
            return metric.prepareVectorData(Row.of(value), 0);
        }
    }).withBroadcastSet(dimension, VECTOR_SIZE);

    DataSet<FastDistanceMatrixData> startCentroids = KMeansInitCentroids.initKmeansCentroids(samples, metric, this.getParams(), dimension, getRandomSeed());

    // The vector column name is passed; the two trailing column arguments are null.
    DataSet<Row> trainedModel = iterateICQ(startCentroids, samples, dimension, iterationLimit, tolerance, metric, HasKMeansWithHaversineDistanceType.DistanceType.valueOf(configuredType.name()), vectorCol, null, null);

    this.setOutput(trainedModel, new KMeansModelDataConverter().getModelSchema());
    return this;
}
Also used : Configuration(org.apache.flink.configuration.Configuration) DataSet(org.apache.flink.api.java.DataSet) FastDistanceVectorData(com.alibaba.alink.operator.common.distance.FastDistanceVectorData) HasKMeansWithHaversineDistanceType(com.alibaba.alink.params.shared.clustering.HasKMeansWithHaversineDistanceType) SparseVector(com.alibaba.alink.common.linalg.SparseVector) FastDistance(com.alibaba.alink.operator.common.distance.FastDistance) BaseVectorSummary(com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary) Vector(com.alibaba.alink.common.linalg.Vector) SparseVector(com.alibaba.alink.common.linalg.SparseVector) KMeansModelDataConverter(com.alibaba.alink.operator.common.clustering.kmeans.KMeansModelDataConverter) FastDistanceMatrixData(com.alibaba.alink.operator.common.distance.FastDistanceMatrixData) RichMapFunction(org.apache.flink.api.common.functions.RichMapFunction) Row(org.apache.flink.types.Row)

Example 3 with FastDistance

use of com.alibaba.alink.operator.common.distance.FastDistance in project Alink by alibaba.

the class VectorModelDataConverter method buildIndex.

/**
 * Builds the serialized index rows for the given input.
 *
 * <p>Each partition of the input is converted into {@link FastDistanceData} objects
 * by the {@link FastDistance} selected through {@code HasFastDistanceType.DISTANCE_TYPE}.
 * Every object becomes one row holding a numeric type tag (1 = matrix, 2 = vector,
 * 3 = sparse) and the object's {@code toString()} form. The rows are then passed
 * through {@code VectorModelDataConverter.save}; only the subtask with index 0
 * supplies {@code params} as meta data, all other subtasks pass {@code null}.
 *
 * @param in     the operator providing the rows to index
 * @param params the parameters; used to choose the distance and saved as meta data
 * @return the data set produced by the "build_model" step
 * @throws RuntimeException (at job runtime) if the distance produces a
 *                          {@link FastDistanceData} subtype other than the three above
 */
@Override
public DataSet<Row> buildIndex(BatchOperator in, Params params) {
    DataSet<Row> dataSet = in.getDataSet();
    FastDistance fastDistance = params.get(HasFastDistanceType.DISTANCE_TYPE).getFastDistance();
    DataSet<Row> index = dataSet.mapPartition(new RichMapPartitionFunction<Row, Row>() {

        private static final long serialVersionUID = -6035963841026118219L;

        @Override
        public void mapPartition(Iterable<Row> values, Collector<Row> out) throws Exception {
            List<FastDistanceData> list = fastDistance.prepareMatrixData(values, 1, 0);
            for (FastDistanceData fastDistanceData : list) {
                // Only the type tag depends on the concrete subtype; the payload is
                // always the object's toString() form.
                final long typeTag;
                if (fastDistanceData instanceof FastDistanceMatrixData) {
                    typeTag = 1L;
                } else if (fastDistanceData instanceof FastDistanceVectorData) {
                    typeTag = 2L;
                } else if (fastDistanceData instanceof FastDistanceSparseData) {
                    typeTag = 3L;
                } else {
                    // Separating space added: the message previously fused the class
                    // name with "is not supported!".
                    throw new RuntimeException(fastDistanceData.getClass().getName() + " is not supported!");
                }
                Row row = new Row(ROW_SIZE);
                row.setField(FASTDISTANCE_TYPE_INDEX, typeTag);
                row.setField(DATA_INDEX, fastDistanceData.toString());
                out.collect(row);
            }
        }
    });
    return index.mapPartition(new RichMapPartitionFunction<Row, Row>() {

        private static final long serialVersionUID = 661383020005730224L;

        @Override
        public void mapPartition(Iterable<Row> values, Collector<Row> out) throws Exception {
            // Only subtask 0 passes the params as meta; every other subtask passes null.
            Params meta = null;
            if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                meta = params;
            }
            new VectorModelDataConverter().save(Tuple2.of(meta, values), out);
        }
    }).name("build_model");
}
Also used : RichMapPartitionFunction(org.apache.flink.api.common.functions.RichMapPartitionFunction) Params(org.apache.flink.ml.api.misc.param.Params) FastDistanceVectorData(com.alibaba.alink.operator.common.distance.FastDistanceVectorData) FastDistanceSparseData(com.alibaba.alink.operator.common.distance.FastDistanceSparseData) FastDistanceMatrixData(com.alibaba.alink.operator.common.distance.FastDistanceMatrixData) FastDistance(com.alibaba.alink.operator.common.distance.FastDistance) Collector(org.apache.flink.util.Collector) FastDistanceData(com.alibaba.alink.operator.common.distance.FastDistanceData) ArrayList(java.util.ArrayList) List(java.util.List) Row(org.apache.flink.types.Row)

Example 4 with FastDistance

use of com.alibaba.alink.operator.common.distance.FastDistance in project Alink by alibaba.

the class EvalClusterBatchOp method linkFrom.

/**
 * Builds the cluster-evaluation pipeline on the first input operator.
 *
 * <p>Two groups of metrics are computed and merged into one {@link Params} object:
 * <ul>
 *   <li>label metrics, only when a label column is configured, derived from a
 *       {@link LongMatrix} built over the label and prediction columns;</li>
 *   <li>vector metrics when a vector column is configured, otherwise the result of
 *       {@code BasicClusterParams} over the prediction column alone.</li>
 * </ul>
 * The merged parameters are serialized to JSON and emitted as a single-column
 * ({@code EVAL_RESULT}, STRING) output table.
 *
 * @param inputs the upstream operators; the first one supplies the data
 * @return this operator, with its output table set to the evaluation result
 */
@Override
public EvalClusterBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator in = checkAndGetFirst(inputs);
    String labelColName = this.getLabelCol();
    String predResultColName = this.getPredictionCol();
    String vectorColName = this.getVectorCol();
    DistanceType distanceType = getDistanceType();
    FastDistance distance = distanceType.getFastDistance();
    // A data set holding one empty Params; used as the label metrics when no label column is set.
    DataSet<Params> empty = MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment().fromElements(new Params());
    DataSet<Params> labelMetrics = empty, vectorMetrics;
    if (null != labelColName) {
        // Field 0 of each row is the label, field 1 is the prediction.
        DataSet<Row> data = in.select(new String[] { labelColName, predResultColName }).getDataSet();
        // Label values (field 0) of rows accepted by checkRowFieldNotNull, reduced by
        // DistinctLabelIndexMap; presumably a label -> index map — confirm against EvaluationUtil.
        DataSet<Tuple1<Map<Object, Integer>>> labels = data.flatMap(new FlatMapFunction<Row, Object>() {

            private static final long serialVersionUID = 6181506719667975996L;

            @Override
            public void flatMap(Row row, Collector<Object> collector) {
                if (EvaluationUtil.checkRowFieldNotNull(row)) {
                    collector.collect(row.getField(0));
                }
            }
        }).reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(false, null, null, false)).project(0);
        // Same construction as above, but over the prediction values (field 1).
        DataSet<Tuple1<Map<Object, Integer>>> predictions = data.flatMap(new FlatMapFunction<Row, Object>() {

            private static final long serialVersionUID = 619373417169823128L;

            @Override
            public void flatMap(Row row, Collector<Object> collector) {
                if (EvaluationUtil.checkRowFieldNotNull(row)) {
                    collector.collect(row.getField(1));
                }
            }
        }).reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(false, null, null, false)).project(0);
        // Per-partition LongMatrix results are summed into one matrix, which is then
        // converted to Params together with the broadcast label and prediction maps.
        labelMetrics = data.rebalance().mapPartition(new CalLocalPredResult()).withBroadcastSet(labels, LABELS).withBroadcastSet(predictions, PREDICTIONS).reduce(new ReduceFunction<LongMatrix>() {

            private static final long serialVersionUID = 3340266128816528106L;

            @Override
            public LongMatrix reduce(LongMatrix value1, LongMatrix value2) {
                // Accumulates into value1 in place and returns it.
                value1.plusEqual(value2);
                return value1;
            }
        }).map(new RichMapFunction<LongMatrix, Params>() {

            private static final long serialVersionUID = -4218363116865487327L;

            @Override
            public Params map(LongMatrix value) {
                // Each broadcast set is read through its first element only.
                List<Tuple1<Map<Object, Integer>>> labels = getRuntimeContext().getBroadcastVariable(LABELS);
                List<Tuple1<Map<Object, Integer>>> predictions = getRuntimeContext().getBroadcastVariable(PREDICTIONS);
                return ClusterEvaluationUtil.extractParamsFromConfusionMatrix(value, labels.get(0).f0, predictions.get(0).f0);
            }
        }).withBroadcastSet(labels, LABELS).withBroadcastSet(predictions, PREDICTIONS);
    }
    if (null != vectorColName) {
        // NOTE(review): this reads inputs[0] directly while the rest of the method uses `in`
        // (the result of checkAndGetFirst) — presumably the same operator; confirm.
        Tuple2<DataSet<Tuple2<Vector, Row>>, DataSet<BaseVectorSummary>> statistics = StatisticsHelper.summaryHelper(inputs[0], null, vectorColName, new String[] { predResultColName });
        // (vector, cluster id as String) pairs produced by FilterEmptyRow, which receives the vector summary as a broadcast set.
        DataSet<Tuple2<Vector, String>> nonEmpty = statistics.f0.flatMap(new FilterEmptyRow()).withBroadcastSet(statistics.f1, VECTOR_SIZE);
        // One Tuple3 per group of field 1 (the String), computed by CalcMeanAndSum.
        DataSet<Tuple3<String, DenseVector, DenseVector>> meanAndSum = nonEmpty.groupBy(1).reduceGroup(new CalcMeanAndSum(distance)).withBroadcastSet(statistics.f1, VECTOR_SIZE);
        // Co-groups samples with their group's Tuple3 (sample field 1 == Tuple3 field 0),
        // then reduces the per-group summaries into one with ReduceBaseMetrics.
        DataSet<BaseMetricsSummary> metricsSummary = nonEmpty.coGroup(meanAndSum).where(1).equalTo(0).with(new CalcClusterMetricsSummary(distance)).withBroadcastSet(meanAndSum, MEAN_AND_SUM).reduce(new EvaluationUtil.ReduceBaseMetrics());
        // Per-sample value from calSilhouetteCoefficient, summed over all samples (field 0).
        DataSet<Tuple1<Double>> silhouetteCoefficient = nonEmpty.map(new RichMapFunction<Tuple2<Vector, String>, Tuple1<Double>>() {

            private static final long serialVersionUID = 116926378586242272L;

            @Override
            public Tuple1<Double> map(Tuple2<Vector, String> value) {
                // The broadcast summary is read through its first element and cast to ClusterMetricsSummary.
                List<BaseMetricsSummary> list = getRuntimeContext().getBroadcastVariable(METRICS_SUMMARY);
                return ClusterEvaluationUtil.calSilhouetteCoefficient(value, (ClusterMetricsSummary) list.get(0));
            }
        }).withBroadcastSet(metricsSummary, METRICS_SUMMARY).aggregate(Aggregations.SUM, 0);
        // SaveDataAsParams converts the summary to Params with the summed silhouette value available as a broadcast set.
        vectorMetrics = metricsSummary.map(new ClusterEvaluationUtil.SaveDataAsParams()).withBroadcastSet(silhouetteCoefficient, SILHOUETTE_COEFFICIENT);
    } else {
        // Without a vector column, the vector metrics come from the prediction column alone via BasicClusterParams.
        vectorMetrics = in.select(predResultColName).getDataSet().reduceGroup(new BasicClusterParams());
    }
    // Merge every Params from both metric groups into one and emit it as a JSON string.
    DataSet<Row> out = labelMetrics.union(vectorMetrics).reduceGroup(new GroupReduceFunction<Params, Row>() {

        private static final long serialVersionUID = -4726713311986089251L;

        @Override
        public void reduce(Iterable<Params> values, Collector<Row> out) {
            Params params = new Params();
            for (Params p : values) {
                params.merge(p);
            }
            out.collect(Row.of(params.toJson()));
        }
    });
    // The output table has a single STRING column named by EVAL_RESULT.
    this.setOutputTable(DataSetConversionUtil.toTable(getMLEnvironmentId(), out, new TableSchema(new String[] { EVAL_RESULT }, new TypeInformation[] { Types.STRING })));
    return this;
}
Also used : TableSchema(org.apache.flink.table.api.TableSchema) DataSet(org.apache.flink.api.java.DataSet) FastDistance(com.alibaba.alink.operator.common.distance.FastDistance) List(java.util.List) Vector(com.alibaba.alink.common.linalg.Vector) DenseVector(com.alibaba.alink.common.linalg.DenseVector) SparseVector(com.alibaba.alink.common.linalg.SparseVector) ClusterEvaluationUtil(com.alibaba.alink.operator.common.evaluation.ClusterEvaluationUtil) Params(org.apache.flink.ml.api.misc.param.Params) EvalClusterParams(com.alibaba.alink.params.evaluation.EvalClusterParams) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) LongMatrix(com.alibaba.alink.operator.common.evaluation.LongMatrix) EvaluationUtil(com.alibaba.alink.operator.common.evaluation.EvaluationUtil) ClusterEvaluationUtil(com.alibaba.alink.operator.common.evaluation.ClusterEvaluationUtil) Tuple1(org.apache.flink.api.java.tuple.Tuple1) RichMapFunction(org.apache.flink.api.common.functions.RichMapFunction) Tuple2(org.apache.flink.api.java.tuple.Tuple2) Tuple3(org.apache.flink.api.java.tuple.Tuple3) ClusterMetricsSummary(com.alibaba.alink.operator.common.evaluation.ClusterMetricsSummary) Row(org.apache.flink.types.Row) BaseMetricsSummary(com.alibaba.alink.operator.common.evaluation.BaseMetricsSummary)

Aggregations

FastDistance (com.alibaba.alink.operator.common.distance.FastDistance)4 Row (org.apache.flink.types.Row)4 Vector (com.alibaba.alink.common.linalg.Vector)3 FastDistanceMatrixData (com.alibaba.alink.operator.common.distance.FastDistanceMatrixData)3 FastDistanceVectorData (com.alibaba.alink.operator.common.distance.FastDistanceVectorData)3 DenseVector (com.alibaba.alink.common.linalg.DenseVector)2 SparseVector (com.alibaba.alink.common.linalg.SparseVector)2 KMeansModelDataConverter (com.alibaba.alink.operator.common.clustering.kmeans.KMeansModelDataConverter)2 List (java.util.List)2 RichMapFunction (org.apache.flink.api.common.functions.RichMapFunction)2 DataSet (org.apache.flink.api.java.DataSet)2 Params (org.apache.flink.ml.api.misc.param.Params)2 BatchOperator (com.alibaba.alink.operator.batch.BatchOperator)1 FastDistanceData (com.alibaba.alink.operator.common.distance.FastDistanceData)1 FastDistanceSparseData (com.alibaba.alink.operator.common.distance.FastDistanceSparseData)1 HaversineDistance (com.alibaba.alink.operator.common.distance.HaversineDistance)1 BaseMetricsSummary (com.alibaba.alink.operator.common.evaluation.BaseMetricsSummary)1 ClusterEvaluationUtil (com.alibaba.alink.operator.common.evaluation.ClusterEvaluationUtil)1 ClusterMetricsSummary (com.alibaba.alink.operator.common.evaluation.ClusterMetricsSummary)1 EvaluationUtil (com.alibaba.alink.operator.common.evaluation.EvaluationUtil)1