Use of com.alibaba.alink.common.linalg.SparseVector in project Alink by alibaba, in the class ItemCfRecommKernel, method recommendItems.
static MTable recommendItems(Object userId, ItemCfRecommData model, int topN, boolean excludeKnown,
                             double[] res, String objectName, TypeInformation<?> objType) {
    Arrays.fill(res, 0.0);
    PriorityQueue<RecommItemTopKResult> queue = new PriorityQueue<>(Comparator.comparing(o -> o.similarity));
    SparseVector itemRate = model.userItemRates.get(userId);
    if (null == itemRate) {
        return null;
    }
    Set<Integer> items = model.userItems.get(userId);
    int[] key = itemRate.getIndices();
    double[] value = itemRate.getValues();
    // Accumulate a score for every candidate item: each rated item contributes rating * similarity.
    for (int i = 0; i < key.length; i++) {
        if (model.itemSimilarityList[key[i]] != null) {
            for (Tuple2<Integer, Double> t : model.itemSimilarityList[key[i]]) {
                res[t.f0] += t.f1 * value[i];
            }
        }
    }
    double head = 0;
    for (int i = 0; i < res.length; i++) {
        // Optionally skip items the user has already rated.
        if (excludeKnown && items.contains(i)) {
            continue;
        }
        head = updateQueue(queue, topN, res[i] / items.size(), model.items[i], head);
    }
    return serializeQueue(queue, KObjectUtil.SCORE_NAME, objectName, objType);
}
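For context, the scoring loop above implements the standard item-based CF prediction: every item the user has rated contributes rating times similarity to each related candidate, and the accumulated score is divided by the number of rated items before ranking. Below is a minimal, self-contained sketch of the same accumulation over a plain SparseVector; the class name and all data are invented for illustration, and the similarity matrix is dense only for brevity.

import com.alibaba.alink.common.linalg.SparseVector;

public class ItemCfScoringSketch {
    public static void main(String[] args) {
        // Hypothetical data: the user rated items 0 and 2 (out of 4 items).
        SparseVector userRates = new SparseVector(4, new int[] {0, 2}, new double[] {5.0, 3.0});
        // Hypothetical item-item similarities, one row per item.
        double[][] sim = {
            {1.0, 0.2, 0.0, 0.5},
            {0.2, 1.0, 0.3, 0.0},
            {0.0, 0.3, 1.0, 0.4},
            {0.5, 0.0, 0.4, 1.0}
        };
        double[] score = new double[4];
        int[] rated = userRates.getIndices();
        double[] rates = userRates.getValues();
        // score[j] = sum over rated items i of rate(i) * sim(i, j), as in recommendItems above.
        for (int k = 0; k < rated.length; k++) {
            for (int j = 0; j < sim.length; j++) {
                score[j] += sim[rated[k]][j] * rates[k];
            }
        }
        for (int j = 0; j < score.length; j++) {
            // The kernel divides by the number of rated items before ranking the candidates.
            System.out.println("item " + j + " -> " + score[j] / rated.length);
        }
    }
}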
Use of com.alibaba.alink.common.linalg.SparseVector in project Alink by alibaba, in the class ItemCfRecommModelDataConverter, method load.
@Override
public ItemCfRecommData load(List<Row> rows) {
    ItemCfRecommData modelData = new ItemCfRecommData();
    switch (recommType) {
        case USERS_PER_ITEM: {
            String[] items;
            modelData.itemSimilarities = new HashMap<>();
            for (Row row : rows) {
                if (row.getField(0) != null) {
                    // User row: scatter the user's ratings into per-item lists of (userId, rating).
                    Object userId = row.getField(0);
                    SparseVector vector = VectorUtil.getSparseVector(row.getField(2));
                    if (modelData.userRateList == null) {
                        modelData.userRateList = new List[vector.size()];
                    }
                    double[] value = vector.getValues();
                    int[] key = vector.getIndices();
                    for (int j = 0; j < key.length; j++) {
                        if (modelData.userRateList[key[j]] == null) {
                            modelData.userRateList[key[j]] = new ArrayList<>();
                        }
                        modelData.userRateList[key[j]].add(Tuple2.of(userId, value[j]));
                    }
                } else if (row.getField(1) != null) {
                    // Item row: the item's similarity vector, keyed by item index.
                    modelData.itemSimilarities.put(((Number) row.getField(1)).intValue(),
                        VectorUtil.getSparseVector(row.getField(2)));
                } else {
                    // Meta row: Params JSON with the item list, item type and rate column.
                    modelData.meta = Params.fromJson((String) row.getField(2));
                    items = modelData.meta.get(ITEMS);
                    TypeInformation<?> itemType = FlinkTypeConverter.getFlinkType(modelData.meta.get(ITEM_TYPE));
                    modelData.itemMap = new HashMap<>();
                    for (int i = 0; i < items.length; i++) {
                        modelData.itemMap.put(EvaluationUtil.castTo(items[i], itemType), i);
                    }
                    modelData.rateCol = modelData.meta.get(ItemCfRecommTrainParams.RATE_COL);
                }
            }
            // For each item index, collect the set of users who rated it.
            modelData.itemUsers = new HashMap<>();
            for (int i = 0; i < modelData.userRateList.length; i++) {
                Set<Object> users = new HashSet<>();
                for (Tuple2<Object, Double> t : modelData.userRateList[i]) {
                    users.add(t.f0);
                }
                modelData.itemUsers.put(i, users);
            }
            break;
        }
        case SIMILAR_ITEMS:
        case SIMILAR_USERS: {
            modelData.itemSimilarities = new HashMap<>();
            for (Row row : rows) {
                if (row.getField(1) != null) {
                    modelData.itemSimilarities.put(((Number) row.getField(1)).intValue(),
                        VectorUtil.getSparseVector(row.getField(2)));
                } else if (row.getField(0) == null) {
                    modelData.meta = Params.fromJson((String) row.getField(2));
                    String[] items = modelData.meta.get(ITEMS);
                    modelData.items = new Object[items.length];
                    TypeInformation<?> itemType = FlinkTypeConverter.getFlinkType(modelData.meta.get(ITEM_TYPE));
                    modelData.itemMap = new HashMap<>();
                    for (int i = 0; i < items.length; i++) {
                        modelData.items[i] = EvaluationUtil.castTo(items[i], itemType);
                        modelData.itemMap.put(modelData.items[i], i);
                    }
                    modelData.rateCol = modelData.meta.get(ItemCfRecommTrainParams.RATE_COL);
                }
            }
            break;
        }
        case ITEMS_PER_USER: {
            modelData.userItemRates = new HashMap<>();
            for (Row row : rows) {
                if (row.getField(0) != null) {
                    // User row: the user's item-rating vector.
                    modelData.userItemRates.put(row.getField(0), VectorUtil.getSparseVector(row.getField(2)));
                } else if (row.getField(1) != null) {
                    // Item row: scatter the item's similarities into per-item lists of (itemId, similarity).
                    Integer itemId = ((Number) row.getField(1)).intValue();
                    SparseVector vector = VectorUtil.getSparseVector(row.getField(2));
                    if (modelData.itemSimilarityList == null) {
                        modelData.itemSimilarityList = new List[vector.size()];
                    }
                    double[] value = vector.getValues();
                    int[] key = vector.getIndices();
                    for (int j = 0; j < key.length; j++) {
                        if (modelData.itemSimilarityList[key[j]] == null) {
                            modelData.itemSimilarityList[key[j]] = new ArrayList<>();
                        }
                        modelData.itemSimilarityList[key[j]].add(Tuple2.of(itemId, value[j]));
                    }
                } else {
                    modelData.meta = Params.fromJson((String) row.getField(2));
                    String[] items = modelData.meta.get(ITEMS);
                    modelData.items = new Object[items.length];
                    TypeInformation<?> itemType = FlinkTypeConverter.getFlinkType(modelData.meta.get(ITEM_TYPE));
                    for (int i = 0; i < items.length; i++) {
                        modelData.items[i] = EvaluationUtil.castTo(items[i], itemType);
                    }
                    modelData.rateCol = modelData.meta.get(ItemCfRecommTrainParams.RATE_COL);
                }
            }
            // Record, for each user, the set of item indices the user has rated.
            modelData.userItems = new HashMap<>();
            for (Map.Entry<Object, SparseVector> entry : modelData.userItemRates.entrySet()) {
                Set<Integer> items = new HashSet<>();
                for (int key : entry.getValue().getIndices()) {
                    items.add(key);
                }
                modelData.userItems.put(entry.getKey(), items);
            }
            break;
        }
        case RATE: {
            modelData.userItemRates = new HashMap<>();
            modelData.itemSimilarities = new HashMap<>();
            for (Row row : rows) {
                if (row.getField(0) != null) {
                    modelData.userItemRates.put(row.getField(0), VectorUtil.getSparseVector(row.getField(2)));
                } else if (row.getField(1) != null) {
                    modelData.itemSimilarities.put(((Number) row.getField(1)).intValue(),
                        VectorUtil.getSparseVector(row.getField(2)));
                } else {
                    modelData.meta = Params.fromJson((String) row.getField(2));
                    modelData.rateCol = modelData.meta.get(ItemCfRecommTrainParams.RATE_COL);
                    String[] items = modelData.meta.get(ITEMS);
                    modelData.items = new Object[items.length];
                    TypeInformation<?> itemType = FlinkTypeConverter.getFlinkType(modelData.meta.get(ITEM_TYPE));
                    modelData.itemMap = new HashMap<>();
                    for (int i = 0; i < items.length; i++) {
                        modelData.items[i] = EvaluationUtil.castTo(items[i], itemType);
                        modelData.itemMap.put(modelData.items[i], i);
                    }
                }
            }
            break;
        }
        default: {
            throw new RuntimeException("Not support yet!");
        }
    }
    return modelData;
}
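Across all branches the converter assumes a three-column model row: field 0 carries a user id, field 1 an item index, and field 2 either a rating/similarity vector or the meta Params JSON; data rows have exactly one of the first two fields set, and the meta row has both unset. The rough sketch below shapes rows for the ITEMS_PER_USER branch; the values are invented, the class name is hypothetical, and storing the SparseVector object directly in field 2 is an assumption (VectorUtil.getSparseVector is called with an Object above, so real models may store a serialized string form instead).

import java.util.Arrays;
import java.util.List;
import org.apache.flink.types.Row;
import com.alibaba.alink.common.linalg.SparseVector;

public class ItemCfModelRowsSketch {
    // Hypothetical model rows matching the layout load(...) expects: (userId, itemIndex, payload).
    static List<Row> sampleRows() {
        return Arrays.asList(
            // User row: field 0 = user id, field 2 = the user's item-rating vector.
            Row.of("u1", null, new SparseVector(3, new int[] {0, 2}, new double[] {4.0, 5.0})),
            // Item row: field 1 = item index, field 2 = that item's similarity vector.
            Row.of(null, 0, new SparseVector(3, new int[] {1, 2}, new double[] {0.3, 0.8})),
            // Meta row: fields 0 and 1 are null, field 2 is the Params JSON produced at training time
            // (placeholder string here).
            Row.of(null, null, "{}")
        );
    }
}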
Use of com.alibaba.alink.common.linalg.SparseVector in project Alink by alibaba, in the class BucketRandomProjectionLSHTest, method testHashFunction.
@Test
public void testHashFunction() {
    BucketRandomProjectionLSH lsh = new BucketRandomProjectionLSH(0, 5, 2, 2, 1);
    Vector vec1 = new DenseVector(new double[] { 1, 2, 3, 4, 5 });
    Assert.assertArrayEquals(new int[] { -348137008, 1394862530 }, lsh.hashFunction(vec1));
    Vector vec2 = new SparseVector(5, new int[] { 0, 4 }, new double[] { 1.0, 4.0 });
    Assert.assertArrayEquals(new int[] { -802232505, 1759100286 }, lsh.hashFunction(vec2));
}
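The expected hash codes come from a bucketed random-projection scheme: conceptually each hash value is floor((w . x + b) / r) for a random direction w, a random offset b, and a bucket width r, so vectors that are close in Euclidean space tend to fall into the same bucket. The sketch below illustrates that idea generically; it is not Alink's implementation, and the seed, dimension, and bucket width are arbitrary.

import java.util.Random;

public class BucketRandomProjectionSketch {
    public static void main(String[] args) {
        Random rnd = new Random(0);
        int dim = 5;
        double bucketWidth = 2.0;
        // One random projection direction and one random offset inside a bucket.
        double[] w = new double[dim];
        for (int i = 0; i < dim; i++) {
            w[i] = rnd.nextGaussian();
        }
        double b = rnd.nextDouble() * bucketWidth;
        double[] x = { 1, 2, 3, 4, 5 };
        double dot = 0;
        for (int i = 0; i < dim; i++) {
            dot += w[i] * x[i];
        }
        // Bucket index: nearby points (small Euclidean distance) tend to share buckets.
        long bucket = (long) Math.floor((dot + b) / bucketWidth);
        System.out.println("bucket = " + bucket);
    }
}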
Use of com.alibaba.alink.common.linalg.SparseVector in project Alink by alibaba, in the class BucketRandomProjectionLSHTest, method testDistance.
@Test
public void testDistance() {
    BucketRandomProjectionLSH lsh = new BucketRandomProjectionLSH(0, 5, 2, 2, 1);
    Vector vec1 = new DenseVector(new double[] { 1, 0, 0, 2, 0 });
    Vector vec2 = new DenseVector(new double[] { 0, 1, 0, 2, 1 });
    Assert.assertEquals(1.732, lsh.keyDistance(vec1, vec2), 0.001);
    vec1 = new SparseVector(10, new int[] { 0, 4, 5, 7, 9 }, new double[] { 1.0, 1.0, 1.0, 1.0, 1.0 });
    vec2 = new SparseVector(10, new int[] { 0, 1, 3, 5, 9 }, new double[] { 1.0, 1.0, 1.0, 1.0, 1.0 });
    Assert.assertEquals(2.0, lsh.keyDistance(vec1, vec2), 0.001);
}
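The expected values confirm that keyDistance here is plain Euclidean distance: the dense pair differs by 1.0 in three coordinates, giving sqrt(3) ≈ 1.732, and the sparse pair differs in four coordinates, giving sqrt(4) = 2.0. The sketch below recomputes the sparse case with a simple stand-alone helper; it is a reimplementation for illustration, not Alink's keyDistance.

import com.alibaba.alink.common.linalg.SparseVector;

public class EuclideanKeyDistanceSketch {
    // Euclidean distance between two same-sized Alink SparseVectors, expanded to dense arrays for simplicity.
    static double euclidean(SparseVector a, SparseVector b) {
        double[] da = new double[a.size()];
        double[] db = new double[b.size()];
        int[] ia = a.getIndices();
        double[] va = a.getValues();
        for (int i = 0; i < ia.length; i++) {
            da[ia[i]] = va[i];
        }
        int[] ib = b.getIndices();
        double[] vb = b.getValues();
        for (int i = 0; i < ib.length; i++) {
            db[ib[i]] = vb[i];
        }
        double sum = 0;
        for (int i = 0; i < da.length; i++) {
            double d = da[i] - db[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    public static void main(String[] args) {
        SparseVector v1 = new SparseVector(10, new int[] { 0, 4, 5, 7, 9 }, new double[] { 1.0, 1.0, 1.0, 1.0, 1.0 });
        SparseVector v2 = new SparseVector(10, new int[] { 0, 1, 3, 5, 9 }, new double[] { 1.0, 1.0, 1.0, 1.0, 1.0 });
        // The vectors differ in four coordinates by 1.0 each, so this prints 2.0,
        // matching the keyDistance expectation in the test above.
        System.out.println(euclidean(v1, v2));
    }
}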
Use of com.alibaba.alink.common.linalg.SparseVector in project Alink by alibaba, in the class MinHashLSHTest, method testHashFunction.
@Test
public void testHashFunction() {
    MinHashLSH lsh = new MinHashLSH(0, 2, 2);
    Vector vec1 = new DenseVector(new double[] { 1, 2, 3, 4, 5 });
    Assert.assertArrayEquals(new int[] { 478212008, -1798305157 }, lsh.hashFunction(vec1));
    Vector vec2 = new SparseVector(5, new int[] { 0, 4 }, new double[] { 1.0, 4.0 });
    Assert.assertArrayEquals(new int[] { -967745172, -594675602 }, lsh.hashFunction(vec2));
}
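MinHash-style LSH treats the positions of a vector's nonzero entries as a set and, for each of several random hash functions, keeps the minimum hash value over that set; matching minima between two vectors estimate their Jaccard similarity. The sketch below shows the idea generically; the hash family and the prime are illustrative choices, not Alink's internals.

import java.util.Random;

public class MinHashSketch {
    public static void main(String[] args) {
        // Treat the nonzero indices of a vector as a set of elements,
        // e.g. {0, 4} for the sparse vector in the test above.
        int[] elements = { 0, 4 };
        int numHashFunctions = 2;
        int prime = 2038074743;  // a large prime defining the random affine hash family
        Random rnd = new Random(0);
        for (int k = 0; k < numHashFunctions; k++) {
            // h_k(x) = (a * x + b) mod p; the MinHash signature keeps the minimum over the set.
            long a = 1 + rnd.nextInt(prime - 1);
            long b = rnd.nextInt(prime);
            long min = Long.MAX_VALUE;
            for (int x : elements) {
                min = Math.min(min, (a * x + b) % prime);
            }
            System.out.println("minhash[" + k + "] = " + min);
        }
    }
}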