Use of com.alibaba.alink.operator.common.similarity.KDTree in project Alink by Alibaba.
Class KDTreeModelDataConverter, method buildIndex.
/**
 * Builds the serialized KD-tree model rows for the given training data.
 *
 * <p>The pipeline has two stages. The first rebalances the input and builds one
 * {@link KDTree} per partition, emitting one row per indexed vector followed by a
 * single row carrying the JSON-encoded tree root. The second stage hands those rows
 * to {@code KDTreeModelDataConverter#save}; only subtask 0 attaches the metadata
 * (the caller's params plus {@code VECTOR_SIZE}).
 *
 * @param in     training data operator
 * @param params training parameters; the metric must be Euclidean
 * @return the data set of model rows
 */
@Override
public DataSet<Row> buildIndex(BatchOperator in, Params params) {
    Preconditions.checkArgument(
        params.get(VectorApproxNearestNeighborTrainParams.METRIC)
            .equals(VectorApproxNearestNeighborTrainParams.Metric.EUCLIDEAN),
        "KDTree solver only supports Euclidean distance!");
    EuclideanDistance distance = new EuclideanDistance();
    Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> statistics = StatisticsHelper.summaryHelper(
        in, null, params.get(VectorApproxNearestNeighborTrainParams.SELECTED_COL));

    // Stage 1: one KD-tree per partition, flattened into rows.
    DataSet<Row> treeRows = in.getDataSet().rebalance().mapPartition(new RichMapPartitionFunction<Row, Row>() {
        private static final long serialVersionUID = 6654757741959479783L;

        @Override
        public void mapPartition(Iterable<Row> values, Collector<Row> out) throws Exception {
            List<BaseVectorSummary> summaries = getRuntimeContext().getBroadcastVariable("vectorSize");
            int dimension = summaries.get(0).vectorSize();

            List<FastDistanceVectorData> prepared = new ArrayList<>();
            for (Row input : values) {
                FastDistanceVectorData item = distance.prepareVectorData(input, 1, 0);
                prepared.add(item);
                // The size of the most recently read vector replaces the broadcast value.
                dimension = item.getVector().size();
            }
            if (prepared.isEmpty()) {
                // Empty partition: nothing to index, nothing to emit.
                return;
            }

            FastDistanceVectorData[] vectors = prepared.toArray(new FastDistanceVectorData[0]);
            KDTree kdTree = new KDTree(vectors, dimension, distance);
            kdTree.buildTree();

            long taskId = getRuntimeContext().getIndexOfThisSubtask();
            // A single Row instance is reused for every emitted record.
            Row record = new Row(ROW_SIZE);
            record.setField(TASKID_INDEX, taskId);

            long dataId = 0L;
            for (FastDistanceVectorData vector : vectors) {
                record.setField(DATA_ID_INDEX, dataId);
                record.setField(DATA_IDNEX, vector.toString());
                out.collect(record);
                dataId++;
            }

            // Final record of the partition: no vector payload, only the tree root.
            record.setField(DATA_ID_INDEX, null);
            record.setField(DATA_IDNEX, null);
            record.setField(ROOT_IDDEX, JsonConverter.toJson(kdTree.getRoot()));
            out.collect(record);
        }
    }).withBroadcastSet(statistics.f1, "vectorSize");

    // Stage 2: serialize the rows; metadata is attached by subtask 0 only.
    return treeRows.mapPartition(new RichMapPartitionFunction<Row, Row>() {
        private static final long serialVersionUID = 6849403933586157611L;

        @Override
        public void mapPartition(Iterable<Row> values, Collector<Row> out) throws Exception {
            Params meta = null;
            if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                List<BaseVectorSummary> summaries = getRuntimeContext().getBroadcastVariable("vectorSize");
                params.set(VECTOR_SIZE, summaries.get(0).vectorSize());
                meta = params;
            }
            new KDTreeModelDataConverter().save(Tuple2.of(meta, values), out);
        }
    }).withBroadcastSet(statistics.f1, "vectorSize");
}
Use of com.alibaba.alink.operator.common.similarity.KDTree in project Alink by Alibaba.
Class KDTreeModelData, method computeDistiance.
/**
 * Queries one KD-tree for the neighbours of {@code input}.
 *
 * <p>When {@code topN} is non-null, the tree's top-N search is used and, if a radius
 * with a non-null distance is also supplied, each candidate is additionally filtered
 * through {@code getQueueComparator()}. When {@code topN} is null, a range search
 * with {@code radius.f0} is performed and the distance of every hit is recomputed.
 *
 * <p>The method name (including its spelling) is dictated by the overridden method.
 *
 * @param input  the query; must be a {@link FastDistanceVectorData}
 * @param index  position of the tree to query in {@code treeList}
 * @param topN   number of neighbours to return, or {@code null} for a radius query
 * @param radius radius bound in {@code f0}; required when {@code topN} is {@code null}
 * @return list of (distance, field 0 of the matched row) pairs
 * @throws IllegalArgumentException if neither {@code topN} nor a radius value is given
 */
@Override
protected ArrayList<Tuple2<Double, Object>> computeDistiance(Object input, Integer index, Integer topN, Tuple2<Double, Object> radius) {
    KDTree tree = treeList.get(index);
    FastDistanceVectorData query = (FastDistanceVectorData) input;
    ArrayList<Tuple2<Double, Object>> tupleList = new ArrayList<>();

    if (null != topN) {
        Tuple2<Double, Row>[] treeTopN = tree.getTopN(topN, query);
        // The radius filter only applies when a radius value was actually supplied.
        boolean unbounded = null == radius || radius.f0 == null;
        for (Tuple2<Double, Row> hit : treeTopN) {
            Tuple2<Double, Object> tuple = Tuple2.of(hit.f0, hit.f1.getField(0));
            if (unbounded || this.getQueueComparator().compare(radius, tuple) <= 0) {
                tupleList.add(tuple);
            }
        }
        return tupleList;
    }

    // Without topN the query is radius-only; fail with a clear message instead of
    // an anonymous NullPointerException when the radius is missing.
    if (null == radius || radius.f0 == null) {
        throw new IllegalArgumentException("Either topN or radius must be specified for a KDTree query!");
    }
    List<FastDistanceVectorData> inRange = tree.rangeSearch(radius.f0, query);
    for (FastDistanceVectorData data : inRange) {
        Double dist = distance.calc(data, query).get(0, 0);
        tupleList.add(Tuple2.of(dist, data.getRows()[0].getField(0)));
    }
    return tupleList;
}
Use of com.alibaba.alink.operator.common.similarity.KDTree in project Alink by Alibaba.
Class KDTreeModelDataConverter, method loadModelData.
/**
 * Rebuilds the KD-tree model from its serialized rows.
 *
 * <p>Rows without a task id are skipped. For the rest, a row with a vector payload
 * contributes one vector (keyed by its data id) to its task, and a row with a root
 * payload contributes that task's deserialized tree root. One {@link KDTree} is then
 * created per task that has at least one vector.
 *
 * @param list the serialized model rows
 * @return model data wrapping the reconstructed trees
 */
@Override
public KDTreeModelData loadModelData(List<Row> list) {
    HashMap<Long, TreeMap<Integer, FastDistanceVectorData>> vectorsByTask = new HashMap<>();
    HashMap<Long, KDTree.TreeNode> rootByTask = new HashMap<>();

    for (Row row : list) {
        Object taskField = row.getField(TASKID_INDEX);
        if (taskField == null) {
            continue;
        }
        long taskId = (long) taskField;

        if (row.getField(DATA_IDNEX) != null) {
            int dataId = ((Number) row.getField(DATA_ID_INDEX)).intValue();
            FastDistanceVectorData vector = FastDistanceVectorData.fromString((String) row.getField(DATA_IDNEX));
            // The per-task map is created lazily, only once a vector row is seen.
            vectorsByTask.computeIfAbsent(taskId, key -> new TreeMap<>()).put(dataId, vector);
        } else if (row.getField(ROOT_IDDEX) != null) {
            KDTree.TreeNode rootNode = JsonConverter.fromJson(
                (String) row.getField(ROOT_IDDEX),
                new TypeReference<KDTree.TreeNode>() {
                }.getType());
            rootByTask.put(taskId, rootNode);
        }
    }

    List<KDTree> treeList = new ArrayList<>();
    int vectorSize = meta.get(VECTOR_SIZE);
    EuclideanDistance distance = new EuclideanDistance();
    for (Map.Entry<Long, TreeMap<Integer, FastDistanceVectorData>> taskEntry : vectorsByTask.entrySet()) {
        // TreeMap#values() yields the vectors in ascending data-id order.
        FastDistanceVectorData[] orderedVectors =
            taskEntry.getValue().values().toArray(new FastDistanceVectorData[0]);
        KDTree kdTree = new KDTree(orderedVectors, vectorSize, distance);
        kdTree.setRoot(rootByTask.get(taskEntry.getKey()));
        treeList.add(kdTree);
    }
    return new KDTreeModelData(treeList);
}
Aggregations