Search in sources:

Example 1 with BaseLSH

Use of com.alibaba.alink.operator.common.similarity.lsh.BaseLSH in project Alink by alibaba.

In class LSHModelDataConverter, method buildIndex.

/**
 * Builds the LSH index for the input records and returns it as model rows.
 *
 * <p>Steps:
 * <ol>
 *   <li>Build the LSH hash family with {@code LocalitySensitiveHashApproxFunctions.buildLSH}
 *       and broadcast it under the name {@code "lsh"}.</li>
 *   <li>Hash every record into {@code (id, vector, int[] hashValues)}.</li>
 *   <li>Group record ids by hash value into bucket rows: field {@code BUCKETS_INDEX} holds the
 *       JSON of {@code (hashValue, ids)}.</li>
 *   <li>Emit one data row per record: field {@code DATA_IDNEX} holds the JSON of
 *       {@code (id, vector.toString())}.</li>
 *   <li>Union both row sets and hand each partition to {@code LSHModelDataConverter#save}.
 *       Subtask 0 additionally passes the metadata ({@code params} plus the random parameters
 *       of the LSH family); all other subtasks pass {@code null} metadata.</li>
 * </ol>
 *
 * @param in     input records. NOTE(review): the id is read from column 0 and the vector from
 *               column 1, whereas {@code SELECTED_COL} is only used to build the LSH family.
 *               Presumably the caller has already projected {@code in} to (id, vector); confirm.
 * @param params training parameters. On subtask 0 the LSH random parameters are {@code set} on
 *               this same instance (see {@code meta = params} below).
 * @return the rows produced by {@code save} for every partition.
 */
@Override
public DataSet<Row> buildIndex(BatchOperator in, Params params) {
    // DataSet holding the hash family. Consumers below only read its first element.
    DataSet<BaseLSH> lsh = LocalitySensitiveHashApproxFunctions.buildLSH(in, params, params.get(VectorApproxNearestNeighborTrainParams.SELECTED_COL));
    // Hash every record: (id, vector) -> (id, vector, hash values).
    DataSet<Tuple3<Object, Vector, int[]>> hashValue = in.getDataSet().map(new RichMapFunction<Row, Tuple3<Object, Vector, int[]>>() {

        private static final long serialVersionUID = 9119201008956936115L;

        @Override
        public Tuple3<Object, Vector, int[]> map(Row row) throws Exception {
            // Only the first element of the broadcast "lsh" set is used.
            BaseLSH lsh = (BaseLSH) getRuntimeContext().getBroadcastVariable("lsh").get(0);
            // NOTE(review): fixed column positions, with the vector at 1 and the id at 0.
            Vector vector = VectorUtil.getVector(row.getField(1));
            Object id = row.getField(0);
            int[] hashValue = lsh.hashFunction(vector);
            return Tuple3.of(id, vector, hashValue);
        }
    }).withBroadcastSet(lsh, "lsh");
    // Explode each record into one (id, hashValue) pair per hash value, then group by the hash
    // value so that each group becomes one bucket row.
    // NOTE(review): grouping uses the raw int only. Equal values produced by different hash
    // tables therefore share a bucket, and an id is added once per matching entry in its
    // int[]. Presumably intended; confirm against BaseLSH.hashFunction.
    DataSet<Row> bucket = hashValue.flatMap(new FlatMapFunction<Tuple3<Object, Vector, int[]>, Tuple2<Object, Integer>>() {

        private static final long serialVersionUID = 7401684044391240070L;

        @Override
        public void flatMap(Tuple3<Object, Vector, int[]> value, Collector<Tuple2<Object, Integer>> out) throws Exception {
            for (int aBucket : value.f2) {
                out.collect(Tuple2.of(value.f0, aBucket));
            }
        }
    }).groupBy(1).reduceGroup(new GroupReduceFunction<Tuple2<Object, Integer>, Row>() {

        private static final long serialVersionUID = -4976135470912551698L;

        @Override
        public void reduce(Iterable<Tuple2<Object, Integer>> values, Collector<Row> out) throws Exception {
            List<Object> ids = new ArrayList<>();
            // Every tuple in the group has the same f1 (the grouping key); take it from the first.
            Integer index = null;
            for (Tuple2<Object, Integer> t : values) {
                ids.add(t.f0);
                if (null == index) {
                    index = t.f1;
                }
            }
            Row row = new Row(ROW_SIZE);
            row.setField(BUCKETS_INDEX, JsonConverter.toJson(Tuple2.of(index, ids)));
            out.collect(row);
        }
    });
    // One row per record carrying (id, vector as string) so the vectors can be restored on load.
    DataSet<Row> originData = hashValue.map(new MapFunction<Tuple3<Object, Vector, int[]>, Row>() {

        private static final long serialVersionUID = 7915820872982890995L;

        @Override
        public Row map(Tuple3<Object, Vector, int[]> value) throws Exception {
            Row row = new Row(ROW_SIZE);
            row.setField(DATA_IDNEX, JsonConverter.toJson(Tuple2.of(value.f0, value.f1.toString())));
            return row;
        }
    });
    // Serialize every partition through save(). Only subtask 0 contributes the metadata.
    return originData.union(bucket).mapPartition(new RichMapPartitionFunction<Row, Row>() {

        private static final long serialVersionUID = 1398487522497229248L;

        @Override
        public void mapPartition(Iterable<Row> values, Collector<Row> out) {
            Params meta = null;
            BaseLSH lsh = (BaseLSH) getRuntimeContext().getBroadcastVariable("lsh").get(0);
            if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                // NOTE(review): this aliases the captured params and then mutates it via set(...).
                // Working on a copy would make the intent explicit.
                meta = params;
                if (lsh instanceof BucketRandomProjectionLSH) {
                    BucketRandomProjectionLSH brpLsh = (BucketRandomProjectionLSH) lsh;
                    meta.set(BucketRandomProjectionLSH.RAND_VECTORS, brpLsh.getRandVectors()).set(BucketRandomProjectionLSH.RAND_NUMBER, brpLsh.getRandNumber()).set(BucketRandomProjectionLSH.PROJECTION_WIDTH, brpLsh.getProjectionWidth());
                } else {
                    // Any family that is not BucketRandomProjectionLSH is assumed to be MinHashLSH.
                    // Another BaseLSH subclass would fail this cast.
                    MinHashLSH minHashLSH = (MinHashLSH) lsh;
                    meta.set(MinHashLSH.RAND_COEFFICIENTS_A, minHashLSH.getRandCoefficientsA()).set(MinHashLSH.RAND_COEFFICIENTS_B, minHashLSH.getRandCoefficientsB());
                }
            }
            // meta is null on every subtask except 0. Presumably save() writes metadata only when
            // it is non-null; confirm.
            new LSHModelDataConverter().save(Tuple2.of(meta, values), out);
        }
    }).withBroadcastSet(lsh, "lsh").name("build_model");
}
Also used: BucketRandomProjectionLSH (com.alibaba.alink.operator.common.similarity.lsh.BucketRandomProjectionLSH), ArrayList (java.util.ArrayList), List (java.util.List), Vector (com.alibaba.alink.common.linalg.Vector), VectorApproxNearestNeighborTrainParams (com.alibaba.alink.params.similarity.VectorApproxNearestNeighborTrainParams), Params (org.apache.flink.ml.api.misc.param.Params), BaseLSH (com.alibaba.alink.operator.common.similarity.lsh.BaseLSH), MinHashLSH (com.alibaba.alink.operator.common.similarity.lsh.MinHashLSH), RichMapFunction (org.apache.flink.api.common.functions.RichMapFunction), Tuple2 (org.apache.flink.api.java.tuple.Tuple2), Tuple3 (org.apache.flink.api.java.tuple.Tuple3), Row (org.apache.flink.types.Row)

Example 2 with BaseLSH

Use of com.alibaba.alink.operator.common.similarity.lsh.BaseLSH in project Alink by alibaba.

In class LocalitySensitiveHashApproxFunctions, method buildLSH.

/**
 * Builds the LSH hash family selected by the {@code METRIC} parameter.
 *
 * <ul>
 *   <li>{@code JACCARD}: a {@link MinHashLSH} built only from {@code SEED},
 *       {@code NUM_PROJECTIONS_PER_TABLE} and {@code NUM_HASH_TABLES}. No pass over {@code in}
 *       is needed.</li>
 *   <li>{@code EUCLIDEAN}: a {@link BucketRandomProjectionLSH} whose dimension is taken from a
 *       vector summary of {@code vectorCol}, plus {@code PROJECTION_WIDTH}.</li>
 * </ul>
 *
 * <p>For {@code EUCLIDEAN} one hash family is emitted per partition that received a summary.
 * All of them are built from the same parameters, so consumers may read just the first element.
 *
 * @param in        input data; only read for {@code EUCLIDEAN}, to compute the vector summary
 * @param params    training parameters (metric, seed, projection and table counts, width)
 * @param vectorCol name of the vector column summarized for {@code EUCLIDEAN}
 * @return a {@code DataSet} containing the hash family
 * @throws IllegalArgumentException if the metric is neither {@code JACCARD} nor {@code EUCLIDEAN}
 */
public static DataSet<BaseLSH> buildLSH(BatchOperator in, Params params, String vectorCol) {
    DataSet<BaseLSH> lsh;
    VectorApproxNearestNeighborTrainParams.Metric metric = params.get(VectorApproxNearestNeighborTrainParams.METRIC);
    switch (metric) {
        case JACCARD: {
            lsh = MLEnvironmentFactory.get(params.get(HasMLEnvironmentId.ML_ENVIRONMENT_ID))
                .getExecutionEnvironment()
                .fromElements(new MinHashLSH(
                    params.get(VectorApproxNearestNeighborTrainParams.SEED),
                    params.get(VectorApproxNearestNeighborTrainParams.NUM_PROJECTIONS_PER_TABLE),
                    params.get(VectorApproxNearestNeighborTrainParams.NUM_HASH_TABLES)));
            break;
        }
        case EUCLIDEAN: {
            // The random projections need the vector dimension, which is only known after a
            // summary pass over the data.
            Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> statistics = StatisticsHelper.summaryHelper(in, null, vectorCol);
            lsh = statistics.f1.mapPartition(new MapPartitionFunction<BaseVectorSummary, BaseLSH>() {

                private static final long serialVersionUID = -3698577489884292933L;

                @Override
                public void mapPartition(Iterable<BaseVectorSummary> values, Collector<BaseLSH> out) {
                    List<BaseVectorSummary> tensorInfo = new ArrayList<>();
                    values.forEach(tensorInfo::add);
                    // The runtime may invoke mapPartition on a parallel subtask that received no
                    // summary. Emit nothing there instead of failing on get(0).
                    if (tensorInfo.isEmpty()) {
                        return;
                    }
                    out.collect(new BucketRandomProjectionLSH(
                        params.get(VectorApproxNearestNeighborTrainParams.SEED),
                        tensorInfo.get(0).vectorSize(),
                        params.get(VectorApproxNearestNeighborTrainParams.NUM_PROJECTIONS_PER_TABLE),
                        params.get(VectorApproxNearestNeighborTrainParams.NUM_HASH_TABLES),
                        params.get(VectorApproxNearestNeighborTrainParams.PROJECTION_WIDTH)));
                }
            });
            break;
        }
        default: {
            throw new IllegalArgumentException("Not support " + metric);
        }
    }
    return lsh;
}
Also used: BaseLSH (com.alibaba.alink.operator.common.similarity.lsh.BaseLSH), MinHashLSH (com.alibaba.alink.operator.common.similarity.lsh.MinHashLSH), Tuple2 (org.apache.flink.api.java.tuple.Tuple2), BaseVectorSummary (com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary), BucketRandomProjectionLSH (com.alibaba.alink.operator.common.similarity.lsh.BucketRandomProjectionLSH), VectorApproxNearestNeighborTrainParams (com.alibaba.alink.params.similarity.VectorApproxNearestNeighborTrainParams), ArrayList (java.util.ArrayList), List (java.util.List), Vector (com.alibaba.alink.common.linalg.Vector)

Example 3 with BaseLSH

Use of com.alibaba.alink.operator.common.similarity.lsh.BaseLSH in project Alink by alibaba.

In class LSHModelDataConverter, method loadModelData.

/**
 * Rebuilds the in-memory LSH model from its serialized rows.
 *
 * <p>Each row is one of two kinds:
 * <ul>
 *   <li>a bucket row ({@code BUCKETS_INDEX} non-null): JSON of {@code (hashValue, ids)};</li>
 *   <li>a data row ({@code DATA_IDNEX} non-null): JSON of {@code (id, vectorString)}.</li>
 * </ul>
 * Rows with neither field set are ignored. The hash family is restored from the {@code meta}
 * field: {@link MinHashLSH} when the metric equals {@code JACCARD}, otherwise
 * {@link BucketRandomProjectionLSH}.
 *
 * @param list the model rows
 * @return the bucket index, the id-to-vector map and the reconstructed hash family
 */
@Override
public LSHModelData loadModelData(List<Row> list) {
    Map<Integer, List<Object>> bucketToIds = new HashMap<>();
    Map<Object, Vector> idToVector = new HashMap<>();

    for (Row modelRow : list) {
        Object bucketJson = modelRow.getField(BUCKETS_INDEX);
        if (bucketJson != null) {
            Tuple2<Integer, List<Object>> bucketEntry = JsonConverter.fromJson(
                (String) bucketJson,
                new TypeReference<Tuple2<Integer, List<Object>>>() {
                }.getType());
            bucketToIds.put(bucketEntry.f0, bucketEntry.f1);
            continue;
        }

        Object dataJson = modelRow.getField(DATA_IDNEX);
        if (dataJson == null) {
            continue;
        }
        Tuple2<Object, String> idAndVector = JsonConverter.fromJson(
            (String) dataJson,
            new TypeReference<Tuple2<Object, String>>() {
            }.getType());
        idToVector.put(idAndVector.f0, VectorUtil.getVector(idAndVector.f1));
    }

    VectorApproxNearestNeighborTrainParams.Metric metric = meta.get(VectorApproxNearestNeighborTrainParams.METRIC);
    BaseLSH hashFamily = metric.equals(VectorApproxNearestNeighborTrainParams.Metric.JACCARD)
        ? new MinHashLSH(
            meta.get(MinHashLSH.RAND_COEFFICIENTS_A),
            meta.get(MinHashLSH.RAND_COEFFICIENTS_B))
        : new BucketRandomProjectionLSH(
            meta.get(BucketRandomProjectionLSH.RAND_VECTORS),
            meta.get(BucketRandomProjectionLSH.RAND_NUMBER),
            meta.get(BucketRandomProjectionLSH.PROJECTION_WIDTH));

    return new LSHModelData(bucketToIds, idToVector, hashFamily);
}
Also used: HashMap (java.util.HashMap), LSHModelData (com.alibaba.alink.operator.common.similarity.modeldata.LSHModelData), BaseLSH (com.alibaba.alink.operator.common.similarity.lsh.BaseLSH), MinHashLSH (com.alibaba.alink.operator.common.similarity.lsh.MinHashLSH), Tuple2 (org.apache.flink.api.java.tuple.Tuple2), BucketRandomProjectionLSH (com.alibaba.alink.operator.common.similarity.lsh.BucketRandomProjectionLSH), ArrayList (java.util.ArrayList), List (java.util.List), Row (org.apache.flink.types.Row), TypeReference (org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference), Vector (com.alibaba.alink.common.linalg.Vector)

Aggregations

Vector (com.alibaba.alink.common.linalg.Vector): 3, BaseLSH (com.alibaba.alink.operator.common.similarity.lsh.BaseLSH): 3, BucketRandomProjectionLSH (com.alibaba.alink.operator.common.similarity.lsh.BucketRandomProjectionLSH): 3, MinHashLSH (com.alibaba.alink.operator.common.similarity.lsh.MinHashLSH): 3, ArrayList (java.util.ArrayList): 3, List (java.util.List): 3, Tuple2 (org.apache.flink.api.java.tuple.Tuple2): 3, VectorApproxNearestNeighborTrainParams (com.alibaba.alink.params.similarity.VectorApproxNearestNeighborTrainParams): 2, Row (org.apache.flink.types.Row): 2, LSHModelData (com.alibaba.alink.operator.common.similarity.modeldata.LSHModelData): 1, BaseVectorSummary (com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary): 1, HashMap (java.util.HashMap): 1, RichMapFunction (org.apache.flink.api.common.functions.RichMapFunction): 1, Tuple3 (org.apache.flink.api.java.tuple.Tuple3): 1, Params (org.apache.flink.ml.api.misc.param.Params): 1, TypeReference (org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference): 1