Search in sources :

Example 1 with HaversineDistance

Use of com.alibaba.alink.operator.common.distance.HaversineDistance in the Alink project by Alibaba.

In the class GeoKMeansTrainBatchOp, method linkFrom:

/**
 * Links this operator to its input and produces the clustering model from the
 * configured latitude and longitude columns, measuring distance with
 * {@link HaversineDistance}.
 *
 * <p>Each input row is reduced to a two-element dense vector
 * {@code [latitude, longitude]}; initial centroids are obtained from
 * {@code initKmeansCentroids} and refined by {@code iterateICQ}. The resulting
 * rows are set as this operator's output using the schema of
 * {@link KMeansModelDataConverter}.
 *
 * @param inputs upstream operators; the one returned by {@code checkAndGetFirst}
 *               supplies the training rows
 * @return this operator, after its output has been set
 */
@Override
public GeoKMeansTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> source = checkAndGetFirst(inputs);

    // Read the training parameters once, up front.
    final String latCol = this.getLatitudeCol();
    final String lonCol = this.getLongitudeCol();
    final int iterationLimit = this.getMaxIter();
    final double convergenceTol = this.getEpsilon();
    final FastDistance haversine = new HaversineDistance();

    // Keep only the two coordinate columns (latitude first) and rebalance the rows.
    DataSet<Row> coordinateRows = source.select(new String[] { latCol, lonCol }).getDataSet().rebalance();

    // Turn every (latitude, longitude) row into the distance's prepared vector form.
    DataSet<FastDistanceVectorData> points = coordinateRows.map(new MapFunction<Row, FastDistanceVectorData>() {

        private static final long serialVersionUID = -5236022856006527961L;

        @Override
        public FastDistanceVectorData map(Row row) {
            double latitude = ((Number) row.getField(0)).doubleValue();
            double longitude = ((Number) row.getField(1)).doubleValue();
            Vector point = new DenseVector(new double[] { latitude, longitude });
            return haversine.prepareVectorData(Row.of(point), 0);
        }
    });

    // The vectors always hold exactly two components: latitude and longitude.
    DataSet<Integer> dimension = MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment().fromElements(2);

    DataSet<FastDistanceMatrixData> seedCentroids = initKmeansCentroids(points, haversine, this.getParams(), dimension, getRandomSeed());
    DataSet<Row> trainedCentroids = iterateICQ(seedCentroids, points, dimension, iterationLimit, convergenceTol, haversine, HasKMeansWithHaversineDistanceType.DistanceType.HAVERSINE, null, latCol, lonCol);

    // Publish the trained centroids as this operator's model table.
    this.setOutput(trainedCentroids, new KMeansModelDataConverter().getModelSchema());
    return this;
}
Also used : FastDistanceVectorData(com.alibaba.alink.operator.common.distance.FastDistanceVectorData) KMeansModelDataConverter(com.alibaba.alink.operator.common.clustering.kmeans.KMeansModelDataConverter) HaversineDistance(com.alibaba.alink.operator.common.distance.HaversineDistance) FastDistanceMatrixData(com.alibaba.alink.operator.common.distance.FastDistanceMatrixData) FastDistance(com.alibaba.alink.operator.common.distance.FastDistance) Row(org.apache.flink.types.Row) Vector(com.alibaba.alink.common.linalg.Vector) DenseVector(com.alibaba.alink.common.linalg.DenseVector)

Aggregations

DenseVector (com.alibaba.alink.common.linalg.DenseVector)1 Vector (com.alibaba.alink.common.linalg.Vector)1 KMeansModelDataConverter (com.alibaba.alink.operator.common.clustering.kmeans.KMeansModelDataConverter)1 FastDistance (com.alibaba.alink.operator.common.distance.FastDistance)1 FastDistanceMatrixData (com.alibaba.alink.operator.common.distance.FastDistanceMatrixData)1 FastDistanceVectorData (com.alibaba.alink.operator.common.distance.FastDistanceVectorData)1 HaversineDistance (com.alibaba.alink.operator.common.distance.HaversineDistance)1 Row (org.apache.flink.types.Row)1