Search in sources:

Example 1 with KDTree

Use of com.alibaba.alink.operator.common.similarity.KDTree in project Alink by Alibaba.

From class KDTreeModelDataConverter, method buildIndex.

/**
 * Builds a forest of KD-trees, one per parallel subtask, and serializes it into model rows.
 *
 * <p>The input is rebalanced. Each non-empty subtask builds a {@link KDTree} over its vectors and
 * emits one row per vector (task id, position of the vector in the subtask's array, serialized
 * vector), followed by one row holding the task id and the JSON-serialized tree root. Those rows
 * are then written through {@code save} together with the model meta. Subtask 0 derives the meta
 * from a copy of {@code params} plus {@code VECTOR_SIZE}.
 *
 * @param in     the training data
 * @param params the training parameters; {@code METRIC} must be Euclidean and
 *               {@code SELECTED_COL} names the vector column used for the vector-size summary
 * @return the serialized model rows
 * @throws IllegalArgumentException if the configured metric is not Euclidean
 */
@Override
public DataSet<Row> buildIndex(BatchOperator in, Params params) {
    Preconditions.checkArgument(params.get(VectorApproxNearestNeighborTrainParams.METRIC).equals(VectorApproxNearestNeighborTrainParams.Metric.EUCLIDEAN), "KDTree solver only supports Euclidean distance!");
    EuclideanDistance distance = new EuclideanDistance();
    Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> statistics = StatisticsHelper.summaryHelper(in, null, params.get(VectorApproxNearestNeighborTrainParams.SELECTED_COL));
    return in.getDataSet().rebalance().mapPartition(new RichMapPartitionFunction<Row, Row>() {

        private static final long serialVersionUID = 6654757741959479783L;

        @Override
        public void mapPartition(Iterable<Row> values, Collector<Row> out) throws Exception {
            // Use the vector size summarized over the whole input. That is the value stored in the
            // model meta (VECTOR_SIZE) and used when the trees are rebuilt at load time, so the
            // tree built here must agree with it. The size of whichever vector happens to come
            // last in this partition is not guaranteed to match it.
            BaseVectorSummary summary = (BaseVectorSummary) getRuntimeContext().getBroadcastVariable("vectorSize").get(0);
            int vectorSize = summary.vectorSize();
            List<FastDistanceVectorData> list = new ArrayList<>();
            for (Row row : values) {
                // NOTE(review): rows are read via fixed column indices (1, 0) rather than via
                // SELECTED_COL; confirm the layout of the rows produced by `in`.
                list.add(distance.prepareVectorData(row, 1, 0));
            }
            if (!list.isEmpty()) {
                FastDistanceVectorData[] vectorArray = list.toArray(new FastDistanceVectorData[0]);
                KDTree tree = new KDTree(vectorArray, vectorSize, distance);
                tree.buildTree();
                int taskId = getRuntimeContext().getIndexOfThisSubtask();
                Row row = new Row(ROW_SIZE);
                row.setField(TASKID_INDEX, (long) taskId);
                // Vectors are emitted after buildTree(), each tagged with its array position so the
                // loader can restore the same order.
                for (int i = 0; i < vectorArray.length; i++) {
                    row.setField(DATA_ID_INDEX, (long) i);
                    row.setField(DATA_IDNEX, vectorArray[i].toString());
                    out.collect(row);
                }
                // The final row of this task carries only the serialized tree root.
                row.setField(DATA_ID_INDEX, null);
                row.setField(DATA_IDNEX, null);
                row.setField(ROOT_IDDEX, JsonConverter.toJson(tree.getRoot()));
                out.collect(row);
            }
        }
    }).withBroadcastSet(statistics.f1, "vectorSize").mapPartition(new RichMapPartitionFunction<Row, Row>() {

        private static final long serialVersionUID = 6849403933586157611L;

        @Override
        public void mapPartition(Iterable<Row> values, Collector<Row> out) throws Exception {
            // Only subtask 0 attaches the model meta; every other subtask passes null.
            Params meta = null;
            if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                // Work on a copy so the captured training params are not mutated.
                meta = params.clone();
                BaseVectorSummary summary = (BaseVectorSummary) getRuntimeContext().getBroadcastVariable("vectorSize").get(0);
                int vectorSize = summary.vectorSize();
                meta.set(VECTOR_SIZE, vectorSize);
            }
            new KDTreeModelDataConverter().save(Tuple2.of(meta, values), out);
        }
    }).withBroadcastSet(statistics.f1, "vectorSize");
}
Also used: DataSet (org.apache.flink.api.java.DataSet), RichMapPartitionFunction (org.apache.flink.api.common.functions.RichMapPartitionFunction), ArrayList (java.util.ArrayList), VectorApproxNearestNeighborTrainParams (com.alibaba.alink.params.similarity.VectorApproxNearestNeighborTrainParams), Params (org.apache.flink.ml.api.misc.param.Params), FastDistanceVectorData (com.alibaba.alink.operator.common.distance.FastDistanceVectorData), EuclideanDistance (com.alibaba.alink.operator.common.distance.EuclideanDistance), KDTree (com.alibaba.alink.operator.common.similarity.KDTree), BaseVectorSummary (com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary), Collector (org.apache.flink.util.Collector), Row (org.apache.flink.types.Row)

Example 2 with KDTree

Use of com.alibaba.alink.operator.common.similarity.KDTree in project Alink by Alibaba.

From class KDTreeModelData, method computeDistiance.

/**
 * Searches the tree at position {@code index} of {@code treeList} for neighbours of {@code input}.
 *
 * <p>When {@code topN} is given, the tree's top-N candidates are returned. A candidate is kept
 * only if there is no radius ({@code radius} or {@code radius.f0} is null) or if
 * {@code getQueueComparator().compare(radius, candidate) <= 0}.
 *
 * <p>When {@code topN} is null, a range search with {@code radius.f0} is run instead. Each hit is
 * paired with its distance to the query as computed by {@code distance}.
 *
 * <p>Each result value is field 0 of the hit's (first) row; presumably this is the id.
 *
 * @param input  the query; cast to {@code FastDistanceVectorData}
 * @param index  which tree in {@code treeList} to search
 * @param topN   number of nearest candidates to request, or null for a range search
 * @param radius optional distance bound; required when {@code topN} is null
 * @return (distance, value) pairs for the retained candidates
 * @throws NullPointerException if both {@code topN} and {@code radius} are null
 */
@Override
protected ArrayList<Tuple2<Double, Object>> computeDistiance(Object input, Integer index, Integer topN, Tuple2<Double, Object> radius) {
    KDTree tree = treeList.get(index);
    ArrayList<Tuple2<Double, Object>> result = new ArrayList<>();

    if (topN == null) {
        // Range query: every indexed vector within radius.f0 of the query point.
        for (FastDistanceVectorData hit : tree.rangeSearch(radius.f0, (FastDistanceVectorData) input)) {
            Double hitDistance = distance.calc(hit, (FastDistanceVectorData) input).get(0, 0);
            result.add(Tuple2.of(hitDistance, hit.getRows()[0].getField(0)));
        }
        return result;
    }

    // Top-N query, optionally filtered by the radius bound.
    boolean unbounded = radius == null || radius.f0 == null;
    for (Tuple2<Double, Row> candidate : tree.getTopN(topN, (FastDistanceVectorData) input)) {
        Tuple2<Double, Object> scored = Tuple2.of(candidate.f0, candidate.f1.getField(0));
        if (unbounded || this.getQueueComparator().compare(radius, scored) <= 0) {
            result.add(scored);
        }
    }
    return result;
}
Also used: KDTree (com.alibaba.alink.operator.common.similarity.KDTree), Tuple2 (org.apache.flink.api.java.tuple.Tuple2), ArrayList (java.util.ArrayList), FastDistanceVectorData (com.alibaba.alink.operator.common.distance.FastDistanceVectorData)

Example 3 with KDTree

Use of com.alibaba.alink.operator.common.similarity.KDTree in project Alink by Alibaba.

From class KDTreeModelDataConverter, method loadModelData.

/**
 * Rebuilds the KD-tree forest from serialized model rows.
 *
 * <p>Rows without a task id ({@code TASKID_INDEX}) are ignored. For the remaining rows, one of
 * two cases applies:
 * <ul>
 *   <li>If {@code DATA_IDNEX} is set, the row carries one serialized vector, positioned by the id
 *       in {@code DATA_ID_INDEX}.</li>
 *   <li>Otherwise, if {@code ROOT_IDDEX} is set, the row carries the JSON-serialized root of that
 *       task's tree.</li>
 * </ul>
 * One {@link KDTree} is created for every task that has at least one vector, using
 * {@code VECTOR_SIZE} from the model meta and a Euclidean distance.
 *
 * @param list the serialized model rows
 * @return the model data wrapping one tree per task
 */
@Override
public KDTreeModelData loadModelData(List<Row> list) {
    HashMap<Long, TreeMap<Integer, FastDistanceVectorData>> vectorsByTask = new HashMap<>();
    HashMap<Long, KDTree.TreeNode> rootByTask = new HashMap<>();
    for (Row row : list) {
        if (row.getField(TASKID_INDEX) == null) {
            continue;
        }
        long taskId = (long) row.getField(TASKID_INDEX);
        if (row.getField(DATA_IDNEX) != null) {
            // A TreeMap keyed by data id restores the vectors' original array order,
            // independent of the order in which the rows arrive.
            int dataId = ((Number) row.getField(DATA_ID_INDEX)).intValue();
            FastDistanceVectorData vector = FastDistanceVectorData.fromString((String) row.getField(DATA_IDNEX));
            vectorsByTask.computeIfAbsent(taskId, ignored -> new TreeMap<>()).put(dataId, vector);
        } else if (row.getField(ROOT_IDDEX) != null) {
            KDTree.TreeNode rootNode = JsonConverter.fromJson((String) row.getField(ROOT_IDDEX), new TypeReference<KDTree.TreeNode>() {
            }.getType());
            rootByTask.put(taskId, rootNode);
        }
    }

    List<KDTree> treeList = new ArrayList<>();
    int vectorSize = meta.get(VECTOR_SIZE);
    EuclideanDistance distance = new EuclideanDistance();
    vectorsByTask.forEach((taskId, vectors) -> {
        FastDistanceVectorData[] vectorData = vectors.values().toArray(new FastDistanceVectorData[0]);
        KDTree kdTree = new KDTree(vectorData, vectorSize, distance);
        // A task with vectors but no root row gets a null root here.
        // NOTE(review): confirm KDTree tolerates this case.
        kdTree.setRoot(rootByTask.get(taskId));
        treeList.add(kdTree);
    });
    return new KDTreeModelData(treeList);
}
Also used: HashMap (java.util.HashMap), ArrayList (java.util.ArrayList), FastDistanceVectorData (com.alibaba.alink.operator.common.distance.FastDistanceVectorData), TreeMap (java.util.TreeMap), KDTreeModelData (com.alibaba.alink.operator.common.similarity.modeldata.KDTreeModelData), EuclideanDistance (com.alibaba.alink.operator.common.distance.EuclideanDistance), KDTree (com.alibaba.alink.operator.common.similarity.KDTree), Row (org.apache.flink.types.Row), Map (java.util.Map)

Aggregations

FastDistanceVectorData (com.alibaba.alink.operator.common.distance.FastDistanceVectorData): 3, KDTree (com.alibaba.alink.operator.common.similarity.KDTree): 3, ArrayList (java.util.ArrayList): 3, EuclideanDistance (com.alibaba.alink.operator.common.distance.EuclideanDistance): 2, Row (org.apache.flink.types.Row): 2, KDTreeModelData (com.alibaba.alink.operator.common.similarity.modeldata.KDTreeModelData): 1, BaseVectorSummary (com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary): 1, VectorApproxNearestNeighborTrainParams (com.alibaba.alink.params.similarity.VectorApproxNearestNeighborTrainParams): 1, HashMap (java.util.HashMap): 1, Map (java.util.Map): 1, TreeMap (java.util.TreeMap): 1, RichMapPartitionFunction (org.apache.flink.api.common.functions.RichMapPartitionFunction): 1, DataSet (org.apache.flink.api.java.DataSet): 1, Tuple2 (org.apache.flink.api.java.tuple.Tuple2): 1, Params (org.apache.flink.ml.api.misc.param.Params): 1, Collector (org.apache.flink.util.Collector): 1