Usage of com.alibaba.alink.operator.common.similarity.lsh.BaseLSH in project Alink (by Alibaba):
class LSHModelDataConverter, method buildIndex.
/**
 * Builds the LSH index rows for the training data.
 *
 * <p>Rows of {@code in} are read positionally: field 0 is taken as the sample id and field 1 is
 * parsed as a vector. Each vector is hashed with the broadcast LSH function produced by
 * {@code LocalitySensitiveHashApproxFunctions.buildLSH}. Two kinds of rows are then serialized
 * through {@code LSHModelDataConverter#save}:
 * <ul>
 *   <li>one row per sample, holding the JSON of (id, vector string) at {@code DATA_IDNEX};</li>
 *   <li>one row per distinct hash value, holding the JSON of (hash value, ids of the samples
 *       that produced it) at {@code BUCKETS_INDEX}.</li>
 * </ul>
 * Subtask 0 of the final stage passes a copy of {@code params}, extended with the LSH function's
 * parameters, as the model meta; all other subtasks pass {@code null} meta.
 *
 * @param in     training data
 * @param params training parameters; copied (never modified) when building the meta
 * @return the serialized model rows
 */
@Override
public DataSet<Row> buildIndex(BatchOperator in, Params params) {
    DataSet<BaseLSH> lsh = LocalitySensitiveHashApproxFunctions.buildLSH(
        in, params, params.get(VectorApproxNearestNeighborTrainParams.SELECTED_COL));

    // Hash every (id, vector) row with the broadcast LSH function.
    DataSet<Tuple3<Object, Vector, int[]>> hashValue = in.getDataSet().map(
        new RichMapFunction<Row, Tuple3<Object, Vector, int[]>>() {
            private static final long serialVersionUID = 9119201008956936115L;

            @Override
            public Tuple3<Object, Vector, int[]> map(Row row) throws Exception {
                BaseLSH model = (BaseLSH) getRuntimeContext().getBroadcastVariable("lsh").get(0);
                Vector vector = VectorUtil.getVector(row.getField(1));
                Object id = row.getField(0);
                int[] hashes = model.hashFunction(vector);
                return Tuple3.of(id, vector, hashes);
            }
        }).withBroadcastSet(lsh, "lsh");

    // Invert (id -> hash values) into (hash value -> ids): one row per distinct hash value.
    DataSet<Row> bucket = hashValue
        .flatMap(new FlatMapFunction<Tuple3<Object, Vector, int[]>, Tuple2<Object, Integer>>() {
            private static final long serialVersionUID = 7401684044391240070L;

            @Override
            public void flatMap(Tuple3<Object, Vector, int[]> value, Collector<Tuple2<Object, Integer>> out)
                throws Exception {
                for (int aBucket : value.f2) {
                    out.collect(Tuple2.of(value.f0, aBucket));
                }
            }
        })
        .groupBy(1)
        .reduceGroup(new GroupReduceFunction<Tuple2<Object, Integer>, Row>() {
            private static final long serialVersionUID = -4976135470912551698L;

            @Override
            public void reduce(Iterable<Tuple2<Object, Integer>> values, Collector<Row> out) throws Exception {
                List<Object> ids = new ArrayList<>();
                Integer index = null;
                for (Tuple2<Object, Integer> t : values) {
                    ids.add(t.f0);
                    // All tuples in a group share the same key; remember it once.
                    if (null == index) {
                        index = t.f1;
                    }
                }
                Row row = new Row(ROW_SIZE);
                row.setField(BUCKETS_INDEX, JsonConverter.toJson(Tuple2.of(index, ids)));
                out.collect(row);
            }
        });

    // One row per sample carrying its id and the vector's string form.
    DataSet<Row> originData = hashValue.map(new MapFunction<Tuple3<Object, Vector, int[]>, Row>() {
        private static final long serialVersionUID = 7915820872982890995L;

        @Override
        public Row map(Tuple3<Object, Vector, int[]> value) throws Exception {
            Row row = new Row(ROW_SIZE);
            row.setField(DATA_IDNEX, JsonConverter.toJson(Tuple2.of(value.f0, value.f1.toString())));
            return row;
        }
    });

    return originData.union(bucket).mapPartition(new RichMapPartitionFunction<Row, Row>() {
        private static final long serialVersionUID = 1398487522497229248L;

        @Override
        public void mapPartition(Iterable<Row> values, Collector<Row> out) {
            Params meta = null;
            // Only subtask 0 writes the meta, so it is emitted exactly once.
            if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                BaseLSH model = (BaseLSH) getRuntimeContext().getBroadcastVariable("lsh").get(0);
                meta = buildMeta(model);
            }
            new LSHModelDataConverter().save(Tuple2.of(meta, values), out);
        }

        /** Returns a copy of the training params extended with the given LSH function's parameters. */
        private Params buildMeta(BaseLSH model) {
            // Copy so the captured training params are never mutated in place.
            Params meta = params.clone();
            if (model instanceof BucketRandomProjectionLSH) {
                BucketRandomProjectionLSH brpLsh = (BucketRandomProjectionLSH) model;
                meta.set(BucketRandomProjectionLSH.RAND_VECTORS, brpLsh.getRandVectors())
                    .set(BucketRandomProjectionLSH.RAND_NUMBER, brpLsh.getRandNumber())
                    .set(BucketRandomProjectionLSH.PROJECTION_WIDTH, brpLsh.getProjectionWidth());
            } else if (model instanceof MinHashLSH) {
                MinHashLSH minHashLsh = (MinHashLSH) model;
                meta.set(MinHashLSH.RAND_COEFFICIENTS_A, minHashLsh.getRandCoefficientsA())
                    .set(MinHashLSH.RAND_COEFFICIENTS_B, minHashLsh.getRandCoefficientsB());
            } else {
                throw new IllegalStateException("Unsupported LSH type: " + model.getClass().getName());
            }
            return meta;
        }
    }).withBroadcastSet(lsh, "lsh").name("build_model");
}
Usage of com.alibaba.alink.operator.common.similarity.lsh.BaseLSH in project Alink (by Alibaba):
class LocalitySensitiveHashApproxFunctions, method buildLSH.
/**
 * Creates the LSH function selected by the {@code METRIC} parameter.
 *
 * <p>JACCARD: a {@code MinHashLSH} built from SEED, NUM_PROJECTIONS_PER_TABLE and NUM_HASH_TABLES,
 * wrapped in a one-element DataSet of the configured ML environment.
 * EUCLIDEAN: a vector summary of {@code vectorCol} is computed first, and a
 * {@code BucketRandomProjectionLSH} is built from its vector size plus SEED,
 * NUM_PROJECTIONS_PER_TABLE, NUM_HASH_TABLES and PROJECTION_WIDTH.
 *
 * @param in        training data
 * @param params    training parameters
 * @param vectorCol name of the vector column (used only for the EUCLIDEAN summary)
 * @return a DataSet holding the LSH function
 * @throws IllegalArgumentException if the metric is neither JACCARD nor EUCLIDEAN
 */
public static DataSet<BaseLSH> buildLSH(BatchOperator in, Params params, String vectorCol) {
    DataSet<BaseLSH> lsh;
    VectorApproxNearestNeighborTrainParams.Metric metric = params.get(VectorApproxNearestNeighborTrainParams.METRIC);
    switch (metric) {
        case JACCARD: {
            lsh = MLEnvironmentFactory.get(params.get(HasMLEnvironmentId.ML_ENVIRONMENT_ID))
                .getExecutionEnvironment()
                .fromElements(new MinHashLSH(
                    params.get(VectorApproxNearestNeighborTrainParams.SEED),
                    params.get(VectorApproxNearestNeighborTrainParams.NUM_PROJECTIONS_PER_TABLE),
                    params.get(VectorApproxNearestNeighborTrainParams.NUM_HASH_TABLES)));
            break;
        }
        case EUCLIDEAN: {
            Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> statistics =
                StatisticsHelper.summaryHelper(in, null, vectorCol);
            lsh = statistics.f1.mapPartition(new MapPartitionFunction<BaseVectorSummary, BaseLSH>() {
                private static final long serialVersionUID = -3698577489884292933L;

                @Override
                public void mapPartition(Iterable<BaseVectorSummary> values, Collector<BaseLSH> out) {
                    // Only the first summary is needed (for its vector size). A partition that
                    // received no summary emits nothing instead of failing on an empty list.
                    for (BaseVectorSummary summary : values) {
                        out.collect(new BucketRandomProjectionLSH(
                            params.get(VectorApproxNearestNeighborTrainParams.SEED),
                            summary.vectorSize(),
                            params.get(VectorApproxNearestNeighborTrainParams.NUM_PROJECTIONS_PER_TABLE),
                            params.get(VectorApproxNearestNeighborTrainParams.NUM_HASH_TABLES),
                            params.get(VectorApproxNearestNeighborTrainParams.PROJECTION_WIDTH)));
                        break;
                    }
                }
            });
            break;
        }
        default: {
            throw new IllegalArgumentException("Not support " + metric);
        }
    }
    return lsh;
}
Usage of com.alibaba.alink.operator.common.similarity.lsh.BaseLSH in project Alink (by Alibaba):
class LSHModelDataConverter, method loadModelData.
/**
 * Rebuilds the LSH model from serialized model rows and the converter's {@code meta}.
 *
 * <p>A row with a non-null {@code BUCKETS_INDEX} field is parsed as JSON (hash value, ids) into
 * the bucket index. Otherwise, a row with a non-null {@code DATA_IDNEX} field is parsed as JSON
 * (id, vector string) into the sample map. The LSH function is reconstructed from the parameters
 * stored in {@code meta}, according to its METRIC.
 *
 * @param list serialized model rows
 * @return the bucket index, the sample vectors and the LSH function
 * @throws IllegalArgumentException if the stored metric is neither JACCARD nor EUCLIDEAN
 */
@Override
public LSHModelData loadModelData(List<Row> list) {
    Map<Integer, List<Object>> indexMap = new HashMap<>();
    Map<Object, Vector> data = new HashMap<>();
    for (Row row : list) {
        Object bucketJson = row.getField(BUCKETS_INDEX);
        if (bucketJson != null) {
            Tuple2<Integer, List<Object>> bucket = JsonConverter.fromJson((String) bucketJson,
                new TypeReference<Tuple2<Integer, List<Object>>>() {
                }.getType());
            indexMap.put(bucket.f0, bucket.f1);
            continue;
        }
        Object sampleJson = row.getField(DATA_IDNEX);
        if (sampleJson != null) {
            Tuple2<Object, String> sample = JsonConverter.fromJson((String) sampleJson,
                new TypeReference<Tuple2<Object, String>>() {
                }.getType());
            data.put(sample.f0, VectorUtil.getVector(sample.f1));
        }
    }

    BaseLSH lsh;
    VectorApproxNearestNeighborTrainParams.Metric metric = meta.get(VectorApproxNearestNeighborTrainParams.METRIC);
    // Mirrors the metric dispatch used when the model was built; unknown metrics are rejected
    // rather than silently decoded as a bucket random projection model.
    switch (metric) {
        case JACCARD:
            lsh = new MinHashLSH(meta.get(MinHashLSH.RAND_COEFFICIENTS_A),
                meta.get(MinHashLSH.RAND_COEFFICIENTS_B));
            break;
        case EUCLIDEAN:
            lsh = new BucketRandomProjectionLSH(meta.get(BucketRandomProjectionLSH.RAND_VECTORS),
                meta.get(BucketRandomProjectionLSH.RAND_NUMBER),
                meta.get(BucketRandomProjectionLSH.PROJECTION_WIDTH));
            break;
        default:
            throw new IllegalArgumentException("Not support " + metric);
    }
    return new LSHModelData(indexMap, data, lsh);
}
Aggregations