Search in sources :

Example 46 with SparseVector

Use of com.alibaba.alink.common.linalg.SparseVector in project Alink by alibaba.

Class ItemCfRecommKernel, method recommendItems.

/**
 * Builds the top-N item recommendations for one user.
 *
 * <p>Each candidate item's score is the sum, over the items the user rated, of
 * (rating * similarity) taken from {@code model.itemSimilarityList}. The score handed to the
 * top-N queue is that sum divided by the size of the user's known-item set.
 *
 * @param userId       user to recommend for
 * @param model        loaded item-CF model data
 * @param topN         number of items to keep
 * @param excludeKnown when true, items already in the user's known-item set are skipped
 * @param res          scratch buffer indexed by item index; it is zero-filled on entry (even when
 *                     the method then returns {@code null}) and left holding the raw sums
 * @param objectName   name passed through to {@code serializeQueue}
 * @param objType      type passed through to {@code serializeQueue}
 * @return the serialized top-N result, or {@code null} if the model holds no rating vector for
 *         {@code userId}
 */
static MTable recommendItems(Object userId, ItemCfRecommData model, int topN, boolean excludeKnown, double[] res, String objectName, TypeInformation<?> objType) {
    // The buffer is reset before the lookup, so callers always observe a cleared buffer.
    Arrays.fill(res, 0.0);
    PriorityQueue<RecommItemTopKResult> topK = new PriorityQueue<>(Comparator.comparing(r -> r.similarity));

    SparseVector userRates = model.userItemRates.get(userId);
    if (userRates == null) {
        return null;
    }
    Set<Integer> knownItems = model.userItems.get(userId);
    int[] ratedItems = userRates.getIndices();
    double[] ratings = userRates.getValues();

    // Spread each rating over the rated item's similarity neighbours.
    int pos = 0;
    while (pos < ratedItems.length) {
        int rated = ratedItems[pos];
        if (model.itemSimilarityList[rated] != null) {
            for (Tuple2<Integer, Double> neighbour : model.itemSimilarityList[rated]) {
                res[neighbour.f0] += neighbour.f1 * ratings[pos];
            }
        }
        pos++;
    }

    // Offer every (non-excluded) item to the bounded queue; updateQueue returns the new head value.
    double queueHead = 0;
    for (int item = 0; item < res.length; item++) {
        if (!excludeKnown || !knownItems.contains(item)) {
            queueHead = updateQueue(topK, topN, res[item] / knownItems.size(), model.items[item], queueHead);
        }
    }
    return serializeQueue(topK, KObjectUtil.SCORE_NAME, objectName, objType);
}
Also used: Arrays(java.util.Arrays) Tuple2(org.apache.flink.api.java.tuple.Tuple2) PriorityQueue(java.util.PriorityQueue) Set(java.util.Set) TableSchema(org.apache.flink.table.api.TableSchema) HashMap(java.util.HashMap) BaseItemsPerUserRecommParams(com.alibaba.alink.params.recommendation.BaseItemsPerUserRecommParams) BaseSimilarItemsRecommParams(com.alibaba.alink.params.recommendation.BaseSimilarItemsRecommParams) Serializable(java.io.Serializable) ArrayList(java.util.ArrayList) List(java.util.List) MTable(com.alibaba.alink.common.MTable) Map(java.util.Map) Row(org.apache.flink.types.Row) Queue(java.util.Queue) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) SparseVector(com.alibaba.alink.common.linalg.SparseVector) Comparator(java.util.Comparator) Params(org.apache.flink.ml.api.misc.param.Params) FlinkTypeConverter(com.alibaba.alink.operator.common.io.types.FlinkTypeConverter) Collections(java.util.Collections)

Example 47 with SparseVector

Use of com.alibaba.alink.common.linalg.SparseVector in project Alink by alibaba.

Class ItemCfRecommModelDataConverter, method load.

/**
 * Deserializes item-CF model rows into an {@link ItemCfRecommData}, building only the lookup
 * structures needed by the configured {@code recommType}.
 *
 * <p>Every row carries its payload in field 2. When field 0 (user id) or field 1 (numeric item
 * id) is set, the payload is a sparse vector. Otherwise it is the JSON-encoded meta
 * {@link Params}, from which the item list, item type and rate column are read.
 *
 * @param rows serialized model rows
 * @return the populated model data
 * @throws UnsupportedOperationException if {@code recommType} has no loader branch
 */
@Override
public ItemCfRecommData load(List<Row> rows) {
    ItemCfRecommData modelData = new ItemCfRecommData();
    switch (recommType) {
        case USERS_PER_ITEM:
            {
                String[] items;
                modelData.itemSimilarities = new HashMap<>();
                for (Row row : rows) {
                    if (row.getField(0) != null) {
                        Object userId = row.getField(0);
                        SparseVector vector = VectorUtil.getSparseVector(row.getField(2));
                        if (modelData.userRateList == null) {
                            // Sized from the first user vector seen; assumes all user vectors share this size.
                            modelData.userRateList = new List[vector.size()];
                        }
                        double[] value = vector.getValues();
                        int[] key = vector.getIndices();
                        // Transpose: userRateList[itemIndex] collects (userId, rate) pairs.
                        for (int j = 0; j < key.length; j++) {
                            if (modelData.userRateList[key[j]] == null) {
                                modelData.userRateList[key[j]] = new ArrayList<>();
                            }
                            modelData.userRateList[key[j]].add(Tuple2.of(userId, value[j]));
                        }
                    } else if (row.getField(1) != null) {
                        modelData.itemSimilarities.put(((Number) row.getField(1)).intValue(), VectorUtil.getSparseVector(row.getField(2)));
                    } else {
                        modelData.meta = Params.fromJson((String) row.getField(2));
                        items = modelData.meta.get(ITEMS);
                        TypeInformation<?> itemType = FlinkTypeConverter.getFlinkType(modelData.meta.get(ITEM_TYPE));
                        modelData.itemMap = new HashMap<>();
                        for (int i = 0; i < items.length; i++) {
                            modelData.itemMap.put(EvaluationUtil.castTo(items[i], itemType), i);
                        }
                        modelData.rateCol = modelData.meta.get(ItemCfRecommTrainParams.RATE_COL);
                    }
                }
                modelData.itemUsers = new HashMap<>();
                // With no user rows the list was never allocated; leave itemUsers empty instead of failing.
                if (modelData.userRateList != null) {
                    for (int i = 0; i < modelData.userRateList.length; i++) {
                        Set<Object> users = new HashSet<>();
                        // A slot no user vector referenced stays null; map it to an empty user set.
                        if (modelData.userRateList[i] != null) {
                            for (Tuple2<Object, Double> t : modelData.userRateList[i]) {
                                users.add(t.f0);
                            }
                        }
                        modelData.itemUsers.put(i, users);
                    }
                }
                break;
            }
        case SIMILAR_ITEMS:
        case SIMILAR_USERS:
            {
                modelData.itemSimilarities = new HashMap<>();
                for (Row row : rows) {
                    if (row.getField(1) != null) {
                        modelData.itemSimilarities.put(((Number) row.getField(1)).intValue(), VectorUtil.getSparseVector(row.getField(2)));
                    } else if (row.getField(0) == null) {
                        modelData.meta = Params.fromJson((String) row.getField(2));
                        String[] items = modelData.meta.get(ITEMS);
                        modelData.items = new Object[items.length];
                        TypeInformation<?> itemType = FlinkTypeConverter.getFlinkType(modelData.meta.get(ITEM_TYPE));
                        modelData.itemMap = new HashMap<>();
                        for (int i = 0; i < items.length; i++) {
                            modelData.items[i] = EvaluationUtil.castTo(items[i], itemType);
                            modelData.itemMap.put(modelData.items[i], i);
                        }
                        modelData.rateCol = modelData.meta.get(ItemCfRecommTrainParams.RATE_COL);
                    }
                }
                break;
            }
        case ITEMS_PER_USER:
            {
                modelData.userItemRates = new HashMap<>();
                for (Row row : rows) {
                    if (row.getField(0) != null) {
                        modelData.userItemRates.put(row.getField(0), VectorUtil.getSparseVector(row.getField(2)));
                    } else if (row.getField(1) != null) {
                        Integer itemId = ((Number) row.getField(1)).intValue();
                        SparseVector vector = VectorUtil.getSparseVector(row.getField(2));
                        if (modelData.itemSimilarityList == null) {
                            // Sized from the first item vector seen; assumes all item vectors share this size.
                            modelData.itemSimilarityList = new List[vector.size()];
                        }
                        double[] value = vector.getValues();
                        int[] key = vector.getIndices();
                        // Transpose: itemSimilarityList[otherItem] collects (itemId, similarity) pairs.
                        for (int j = 0; j < key.length; j++) {
                            if (modelData.itemSimilarityList[key[j]] == null) {
                                modelData.itemSimilarityList[key[j]] = new ArrayList<>();
                            }
                            modelData.itemSimilarityList[key[j]].add(Tuple2.of(itemId, value[j]));
                        }
                    } else {
                        modelData.meta = Params.fromJson((String) row.getField(2));
                        String[] items = modelData.meta.get(ITEMS);
                        modelData.items = new Object[items.length];
                        TypeInformation<?> itemType = FlinkTypeConverter.getFlinkType(modelData.meta.get(ITEM_TYPE));
                        for (int i = 0; i < items.length; i++) {
                            modelData.items[i] = EvaluationUtil.castTo(items[i], itemType);
                        }
                        modelData.rateCol = modelData.meta.get(ItemCfRecommTrainParams.RATE_COL);
                    }
                }
                // Per user, the set of item indices present in that user's rating vector.
                modelData.userItems = new HashMap<>();
                for (Map.Entry<Object, SparseVector> entry : modelData.userItemRates.entrySet()) {
                    Set<Integer> items = new HashSet<>();
                    for (int key : entry.getValue().getIndices()) {
                        items.add(key);
                    }
                    modelData.userItems.put(entry.getKey(), items);
                }
                break;
            }
        case RATE:
            {
                modelData.userItemRates = new HashMap<>();
                modelData.itemSimilarities = new HashMap<>();
                for (Row row : rows) {
                    if (row.getField(0) != null) {
                        modelData.userItemRates.put(row.getField(0), VectorUtil.getSparseVector(row.getField(2)));
                    } else if (row.getField(1) != null) {
                        modelData.itemSimilarities.put(((Number) row.getField(1)).intValue(), VectorUtil.getSparseVector(row.getField(2)));
                    } else {
                        modelData.meta = Params.fromJson((String) row.getField(2));
                        modelData.rateCol = modelData.meta.get(ItemCfRecommTrainParams.RATE_COL);
                        String[] items = modelData.meta.get(ITEMS);
                        modelData.items = new Object[items.length];
                        TypeInformation<?> itemType = FlinkTypeConverter.getFlinkType(modelData.meta.get(ITEM_TYPE));
                        modelData.itemMap = new HashMap<>();
                        for (int i = 0; i < items.length; i++) {
                            modelData.items[i] = EvaluationUtil.castTo(items[i], itemType);
                            modelData.itemMap.put(modelData.items[i], i);
                        }
                    }
                }
                break;
            }
        default:
            {
                // UnsupportedOperationException is a RuntimeException, so existing catch sites still apply.
                throw new UnsupportedOperationException("Not support yet!");
            }
    }
    return modelData;
}
Also used: Set(java.util.Set) HashSet(java.util.HashSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) SparseVector(com.alibaba.alink.common.linalg.SparseVector) Tuple2(org.apache.flink.api.java.tuple.Tuple2) List(java.util.List) Row(org.apache.flink.types.Row)

Example 48 with SparseVector

Use of com.alibaba.alink.common.linalg.SparseVector in project Alink by alibaba.

Class BucketRandomProjectionLSHTest, method testHashFunction.

/**
 * Pins the hash output of a fixed-parameter BucketRandomProjectionLSH for one dense and one
 * sparse five-dimensional input.
 */
@Test
public void testHashFunction() {
    // Fixed constructor arguments; the expected hashes below are tied to these values.
    BucketRandomProjectionLSH hasher = new BucketRandomProjectionLSH(0, 5, 2, 2, 1);

    Vector denseInput = new DenseVector(new double[] { 1, 2, 3, 4, 5 });
    int[] expectedDenseHash = new int[] { -348137008, 1394862530 };
    Assert.assertArrayEquals(expectedDenseHash, hasher.hashFunction(denseInput));

    // Sparse input of size 5 with non-zeros at indices 0 and 4.
    Vector sparseInput = new SparseVector(5, new int[] { 0, 4 }, new double[] { 1.0, 4.0 });
    int[] expectedSparseHash = new int[] { -802232505, 1759100286 };
    Assert.assertArrayEquals(expectedSparseHash, hasher.hashFunction(sparseInput));
}
Also used: BucketRandomProjectionLSH(com.alibaba.alink.operator.common.similarity.lsh.BucketRandomProjectionLSH) SparseVector(com.alibaba.alink.common.linalg.SparseVector) Vector(com.alibaba.alink.common.linalg.Vector) DenseVector(com.alibaba.alink.common.linalg.DenseVector) Test(org.junit.Test)

Example 49 with SparseVector

Use of com.alibaba.alink.common.linalg.SparseVector in project Alink by alibaba.

Class BucketRandomProjectionLSHTest, method testDistance.

/**
 * Checks keyDistance on a dense pair and on a sparse pair. The expected values coincide with
 * the Euclidean distance between each pair.
 */
@Test
public void testDistance() {
    BucketRandomProjectionLSH hasher = new BucketRandomProjectionLSH(0, 5, 2, 2, 1);

    // Element-wise difference is (1, -1, 0, 0, -1); sqrt(3) ~= 1.732.
    Vector denseLeft = new DenseVector(new double[] { 1, 0, 0, 2, 0 });
    Vector denseRight = new DenseVector(new double[] { 0, 1, 0, 2, 1 });
    Assert.assertEquals(1.732, hasher.keyDistance(denseLeft, denseRight), 0.001);

    // Indices 1, 3, 4 and 7 are set in only one of the two vectors; sqrt(4) = 2.
    Vector sparseLeft = new SparseVector(10, new int[] { 0, 4, 5, 7, 9 }, new double[] { 1.0, 1.0, 1.0, 1.0, 1.0 });
    Vector sparseRight = new SparseVector(10, new int[] { 0, 1, 3, 5, 9 }, new double[] { 1.0, 1.0, 1.0, 1.0, 1.0 });
    Assert.assertEquals(2.0, hasher.keyDistance(sparseLeft, sparseRight), 0.001);
}
Also used: BucketRandomProjectionLSH(com.alibaba.alink.operator.common.similarity.lsh.BucketRandomProjectionLSH) SparseVector(com.alibaba.alink.common.linalg.SparseVector) Vector(com.alibaba.alink.common.linalg.Vector) DenseVector(com.alibaba.alink.common.linalg.DenseVector) Test(org.junit.Test)

Example 50 with SparseVector

Use of com.alibaba.alink.common.linalg.SparseVector in project Alink by alibaba.

Class MinHashLSHTest, method testHashFunction.

/**
 * Pins the hash output of a fixed-parameter MinHashLSH for one dense and one sparse
 * five-dimensional input.
 */
@Test
public void testHashFunction() {
    // Fixed constructor arguments; the expected hashes below are tied to these values.
    MinHashLSH hasher = new MinHashLSH(0, 2, 2);

    Vector denseInput = new DenseVector(new double[] { 1, 2, 3, 4, 5 });
    int[] expectedDenseHash = new int[] { 478212008, -1798305157 };
    Assert.assertArrayEquals(expectedDenseHash, hasher.hashFunction(denseInput));

    // Sparse input of size 5 with non-zeros at indices 0 and 4.
    Vector sparseInput = new SparseVector(5, new int[] { 0, 4 }, new double[] { 1.0, 4.0 });
    int[] expectedSparseHash = new int[] { -967745172, -594675602 };
    Assert.assertArrayEquals(expectedSparseHash, hasher.hashFunction(sparseInput));
}
Also used: MinHashLSH(com.alibaba.alink.operator.common.similarity.lsh.MinHashLSH) SparseVector(com.alibaba.alink.common.linalg.SparseVector) Vector(com.alibaba.alink.common.linalg.Vector) DenseVector(com.alibaba.alink.common.linalg.DenseVector) Test(org.junit.Test)

Aggregations

SparseVector (com.alibaba.alink.common.linalg.SparseVector)125 Test (org.junit.Test)63 DenseVector (com.alibaba.alink.common.linalg.DenseVector)60 Params (org.apache.flink.ml.api.misc.param.Params)45 Row (org.apache.flink.types.Row)45 Vector (com.alibaba.alink.common.linalg.Vector)40 TableSchema (org.apache.flink.table.api.TableSchema)27 ArrayList (java.util.ArrayList)21 TypeInformation (org.apache.flink.api.common.typeinfo.TypeInformation)15 HashMap (java.util.HashMap)12 Tuple2 (org.apache.flink.api.java.tuple.Tuple2)12 List (java.util.List)11 DenseMatrix (com.alibaba.alink.common.linalg.DenseMatrix)10 MTable (com.alibaba.alink.common.MTable)7 BaseVectorSummary (com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary)6 CollectSinkStreamOp (com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp)6 Map (java.util.Map)6 MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp)5 VectorAssemblerParams (com.alibaba.alink.params.dataproc.vector.VectorAssemblerParams)5 OneHotPredictParams (com.alibaba.alink.params.feature.OneHotPredictParams)5