Search in sources:

Example 1 with KDTreeModelData

Use of `com.alibaba.alink.operator.common.similarity.modeldata.KDTreeModelData` in the Alink project by Alibaba.

Found in class `KDTreeModelDataConverter`, method `loadModelData`.

/**
 * Rebuilds the KD-tree model data from its serialized rows.
 *
 * <p>Rows whose {@code TASKID_INDEX} field is {@code null} are ignored. Every other row belongs to
 * one task and is either:
 * <ul>
 *   <li>a data row: {@code DATA_IDNEX} holds a serialized {@link FastDistanceVectorData} and
 *       {@code DATA_ID_INDEX} holds its numeric position within the task; or</li>
 *   <li>a root row: {@code ROOT_IDDEX} holds the JSON-serialized {@link KDTree.TreeNode}.</li>
 * </ul>
 *
 * <p>One {@link KDTree} is built for every task that has at least one data row. Its vectors are
 * ordered by ascending data id, and its root is the node from that task's root row, or
 * {@code null} if no root row was seen. A task that has only a root row produces no tree. The
 * order of trees in the result follows {@link HashMap} iteration order and is therefore
 * unspecified.
 *
 * @param list serialized model rows
 * @return model data holding one KD-tree per task that has data rows
 */
@Override
public KDTreeModelData loadModelData(List<Row> list) {
    // Per task: vectors keyed (and therefore sorted) by their data id.
    HashMap<Long, TreeMap<Integer, FastDistanceVectorData>> data = new HashMap<>();
    // Per task: deserialized root node of the tree.
    HashMap<Long, KDTree.TreeNode> root = new HashMap<>();
    for (Row row : list) {
        Object taskIdField = row.getField(TASKID_INDEX);
        if (taskIdField == null) {
            continue;
        }
        // Read through Number, as already done for DATA_ID_INDEX, so that a task id stored as an
        // Integer is accepted; a plain (long) cast on an Object only unboxes a Long and would
        // throw ClassCastException otherwise.
        long taskId = ((Number) taskIdField).longValue();
        if (row.getField(DATA_IDNEX) != null) {
            // Only register a per-task map once the task actually has a data row.
            data.computeIfAbsent(taskId, k -> new TreeMap<>())
                .put(((Number) row.getField(DATA_ID_INDEX)).intValue(),
                    FastDistanceVectorData.fromString((String) row.getField(DATA_IDNEX)));
        } else if (row.getField(ROOT_IDDEX) != null) {
            KDTree.TreeNode node = JsonConverter.fromJson((String) row.getField(ROOT_IDDEX),
                new TypeReference<KDTree.TreeNode>() {
                }.getType());
            root.put(taskId, node);
        }
    }
    List<KDTree> treeList = new ArrayList<>();
    int vectorSize = meta.get(VECTOR_SIZE);
    EuclideanDistance distance = new EuclideanDistance();
    for (Map.Entry<Long, TreeMap<Integer, FastDistanceVectorData>> entry : data.entrySet()) {
        // TreeMap.values() yields the vectors in ascending data-id order.
        FastDistanceVectorData[] vectorData =
            entry.getValue().values().toArray(new FastDistanceVectorData[0]);
        KDTree kdTree = new KDTree(vectorData, vectorSize, distance);
        // NOTE(review): null when the task had no root row — confirm KDTree tolerates that.
        kdTree.setRoot(root.get(entry.getKey()));
        treeList.add(kdTree);
    }
    return new KDTreeModelData(treeList);
}
Also used: HashMap (java.util.HashMap), ArrayList (java.util.ArrayList), FastDistanceVectorData (com.alibaba.alink.operator.common.distance.FastDistanceVectorData), TreeMap (java.util.TreeMap), KDTreeModelData (com.alibaba.alink.operator.common.similarity.modeldata.KDTreeModelData), EuclideanDistance (com.alibaba.alink.operator.common.distance.EuclideanDistance), KDTree (com.alibaba.alink.operator.common.similarity.KDTree), Row (org.apache.flink.types.Row), Map (java.util.Map)

Aggregations

EuclideanDistance (com.alibaba.alink.operator.common.distance.EuclideanDistance): 1, FastDistanceVectorData (com.alibaba.alink.operator.common.distance.FastDistanceVectorData): 1, KDTree (com.alibaba.alink.operator.common.similarity.KDTree): 1, KDTreeModelData (com.alibaba.alink.operator.common.similarity.modeldata.KDTreeModelData): 1, ArrayList (java.util.ArrayList): 1, HashMap (java.util.HashMap): 1, Map (java.util.Map): 1, TreeMap (java.util.TreeMap): 1, Row (org.apache.flink.types.Row): 1