Search in sources:

Example 6 with BaseVectorSummary

Use of com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary in the project Alink by alibaba.

From the class StatisticsHelperTest, method summaryHelperVector:

@Test
public void summaryHelperVector() throws Exception {
    // Summarize the dense-vector column "vec" of the 4-row test batch and verify
    // both outputs of summaryHelper: the pass-through vector stream (f0) and the
    // column summary (f1).
    BatchOperator data = getDenseBatch();
    String vectorColName = "vec";
    Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> dataSet = StatisticsHelper.summaryHelper(data, null, vectorColName);
    BaseVectorSummary summary = dataSet.f1.collect().get(0);
    // JUnit convention is assertEquals(expected, actual); keeping that order makes
    // failure messages report the expected and actual values correctly.
    assertEquals(3, summary.vectorSize());
    assertEquals(4, summary.count());
    // Delta 10e-4 (= 1e-3) covers the rounded reference statistics below.
    assertEquals(4.0, summary.max(2), 10e-4);
    assertEquals(0.0, summary.min(1), 10e-4);
    assertEquals(1.25, summary.mean(2), 10e-4);
    assertEquals(8.9167, summary.variance(2), 10e-4);
    assertEquals(2.9861, summary.standardDeviation(2), 10e-4);
    assertEquals(11.0, summary.normL1(2), 10e-4);
    assertEquals(5.7446, summary.normL2(2), 10e-4);
    List<Vector> vectors = dataSet.f0.collect();
    assertEquals(4, vectors.size());
    // The vectors come back in source order for this single-parallelism test batch.
    assertArrayEquals(new double[] { 1, 1, 2.0 }, ((DenseVector) vectors.get(0)).getData(), 10e-4);
    assertArrayEquals(new double[] { 2, 2, -3.0 }, ((DenseVector) vectors.get(1)).getData(), 10e-4);
    assertArrayEquals(new double[] { 1, 3, 2.0 }, ((DenseVector) vectors.get(2)).getData(), 10e-4);
    assertArrayEquals(new double[] { 0, 0, 4.0 }, ((DenseVector) vectors.get(3)).getData(), 10e-4);
}
Also used : DataSet(org.apache.flink.api.java.DataSet) BaseVectorSummary(com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary) Vector(com.alibaba.alink.common.linalg.Vector) DenseVector(com.alibaba.alink.common.linalg.DenseVector) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) Test(org.junit.Test)

Example 7 with BaseVectorSummary

Use of com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary in the project Alink by alibaba.

From the class StatisticsHelperTest, method summaryHelperTableWithReservedCols:

@Test
public void summaryHelperTableWithReservedCols() throws Exception {
    // Summarize three numeric columns of the test table while carrying the "id"
    // column through as a reserved column; verify the summary (f1) and the
    // (vector, reserved-row) pairs (f0).
    BatchOperator data = getBatchTable();
    String[] selectedColNames = new String[] { "f_long", "f_int", "f_double" };
    String[] reservedColNames = new String[] { "id" };
    Tuple2<DataSet<Tuple2<Vector, Row>>, DataSet<BaseVectorSummary>> dataSet = StatisticsHelper.summaryHelper(data, selectedColNames, null, reservedColNames);
    BaseVectorSummary summary = dataSet.f1.collect().get(0);
    // JUnit convention is assertEquals(expected, actual); keeping that order makes
    // failure messages report the expected and actual values correctly.
    assertEquals(3, summary.vectorSize());
    assertEquals(4, summary.count());
    // Delta 10e-4 (= 1e-3) covers the rounded reference statistics below.
    assertEquals(4.0, summary.max(2), 10e-4);
    assertEquals(0.0, summary.min(1), 10e-4);
    assertEquals(1.25, summary.mean(2), 10e-4);
    assertEquals(8.9167, summary.variance(2), 10e-4);
    assertEquals(2.9861, summary.standardDeviation(2), 10e-4);
    assertEquals(11.0, summary.normL1(2), 10e-4);
    assertEquals(5.7446, summary.normL2(2), 10e-4);
    List<Tuple2<Vector, Row>> tuple2s = dataSet.f0.collect();
    assertEquals(4, tuple2s.size());
    assertArrayEquals(new double[] { 1, 1, 2.0 }, ((DenseVector) tuple2s.get(0).f0).getData(), 10e-4);
    assertArrayEquals(new double[] { 2, 2, -3.0 }, ((DenseVector) tuple2s.get(1).f0).getData(), 10e-4);
    assertArrayEquals(new double[] { 1, 3, 2.0 }, ((DenseVector) tuple2s.get(2).f0).getData(), 10e-4);
    assertArrayEquals(new double[] { 0, 0, 4.0 }, ((DenseVector) tuple2s.get(3).f0).getData(), 10e-4);
    // The reserved "id" column is preserved alongside each vector, in source order.
    assertEquals(1, tuple2s.get(0).f1.getField(0));
    assertEquals(2, tuple2s.get(1).f1.getField(0));
    assertEquals(3, tuple2s.get(2).f1.getField(0));
    assertEquals(4, tuple2s.get(3).f1.getField(0));
}
Also used : DataSet(org.apache.flink.api.java.DataSet) Tuple2(org.apache.flink.api.java.tuple.Tuple2) BaseVectorSummary(com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary) Row(org.apache.flink.types.Row) Vector(com.alibaba.alink.common.linalg.Vector) DenseVector(com.alibaba.alink.common.linalg.DenseVector) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) Test(org.junit.Test)

Example 8 with BaseVectorSummary

Use of com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary in the project Alink by alibaba.

From the class LdaTrainBatchOp, method online:

private void online(Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> dataAndStat, int numTopic, int numIter, double alpha, double beta, DataSet<DocCountVectorizerModelData> resDocCountModel, int gammaShape, Integer seed) {
    // Online (variational) LDA training. A prior value of -1 is a sentinel
    // meaning "use the symmetric default 1 / numTopic".
    if (beta == -1) {
        beta = 1.0 / numTopic;
    }
    if (alpha == -1) {
        alpha = 1.0 / numTopic;
    }
    double learningOffset = getParams().get(ONLINE_LEARNING_OFFSET);
    double learningDecay = getParams().get(LEARNING_DECAY);
    double subSamplingRate = getParams().get(SUBSAMPLING_RATE);
    boolean optimizeDocConcentration = getParams().get(OPTIMIZE_DOC_CONCENTRATION);
    DataSet<Vector> data = dataAndStat.f0;
    // (document count, vocabulary size) derived from the corpus summary;
    // broadcast to the initializer and the iteration below.
    DataSet<Tuple2<Long, Integer>> corpusShape = dataAndStat.f1.map(new MapFunction<BaseVectorSummary, Tuple2<Long, Integer>>() {

        private static final long serialVersionUID = 1305270477796787466L;

        @Override
        public Tuple2<Long, Integer> map(BaseVectorSummary summary) {
            return Tuple2.of(summary.count(), summary.vectorSize());
        }
    });
    DataSet<Tuple2<DenseMatrix, DenseMatrix>> initialModel = data
        .mapPartition(new OnlineInit(numTopic, gammaShape, alpha, seed))
        .name("init lambda")
        .withBroadcastSet(corpusShape, LdaVariable.shape);
    // Each round: gather corpus statistics, all-reduce the partial aggregates,
    // update lambda/alpha, then evaluate the log-likelihood; runs numIter rounds.
    DataSet<Row> ldaModelData = new IterativeComQueue()
        .initWithPartitionedData(LdaVariable.data, data)
        .initWithBroadcastData(LdaVariable.shape, corpusShape)
        .initWithBroadcastData(LdaVariable.initModel, initialModel)
        .add(new OnlineCorpusStep(numTopic, subSamplingRate, gammaShape, seed))
        .add(new AllReduce(LdaVariable.wordTopicStat))
        .add(new AllReduce(LdaVariable.logPhatPart))
        .add(new AllReduce(LdaVariable.nonEmptyWordCount))
        .add(new AllReduce(LdaVariable.nonEmptyDocCount))
        .add(new UpdateLambdaAndAlpha(numTopic, learningOffset, learningDecay, subSamplingRate, optimizeDocConcentration, beta))
        .add(new OnlineLogLikelihood(beta, numTopic, numIter, gammaShape, seed))
        .add(new AllReduce(LdaVariable.logLikelihood))
        .closeWith(new BuildOnlineLdaModel(numTopic, beta))
        .setMaxIter(numIter)
        .exec();
    DataSet<Row> model = ldaModelData
        .flatMap(new BuildResModel(seed))
        .withBroadcastSet(resDocCountModel, "DocCountModel");
    setOutput(model, new LdaModelDataConverter().getModelSchema());
    saveWordTopicModelAndPerplexity(model, numTopic, true);
}
Also used : IterativeComQueue(com.alibaba.alink.common.comqueue.IterativeComQueue) AllReduce(com.alibaba.alink.common.comqueue.communication.AllReduce) UpdateLambdaAndAlpha(com.alibaba.alink.operator.common.clustering.lda.UpdateLambdaAndAlpha) OnlineLogLikelihood(com.alibaba.alink.operator.common.clustering.lda.OnlineLogLikelihood) BuildOnlineLdaModel(com.alibaba.alink.operator.common.clustering.lda.BuildOnlineLdaModel) Tuple2(org.apache.flink.api.java.tuple.Tuple2) OnlineCorpusStep(com.alibaba.alink.operator.common.clustering.lda.OnlineCorpusStep) LdaModelDataConverter(com.alibaba.alink.operator.common.clustering.LdaModelDataConverter) BaseVectorSummary(com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary) Row(org.apache.flink.types.Row) Vector(com.alibaba.alink.common.linalg.Vector) SparseVector(com.alibaba.alink.common.linalg.SparseVector)

Example 9 with BaseVectorSummary

Use of com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary in the project Alink by alibaba.

From the class KDTreeModelDataConverter, method buildIndex:

@Override
public DataSet<Row> buildIndex(BatchOperator in, Params params) {
    // KD-tree indexing is only defined here for Euclidean distance; fail fast otherwise.
    Preconditions.checkArgument(params.get(VectorApproxNearestNeighborTrainParams.METRIC).equals(VectorApproxNearestNeighborTrainParams.Metric.EUCLIDEAN), "KDTree solver only supports Euclidean distance!");
    EuclideanDistance distance = new EuclideanDistance();
    // statistics.f1 is a single-element summary of the selected vector column; it is
    // broadcast under the name "vectorSize" so each subtask can read the dimensionality.
    Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> statistics = StatisticsHelper.summaryHelper(in, null, params.get(VectorApproxNearestNeighborTrainParams.SELECTED_COL));
    // First pass: per rebalanced partition, build a local KD-tree and emit one row per
    // data point plus a final row carrying the serialized tree root.
    return in.getDataSet().rebalance().mapPartition(new RichMapPartitionFunction<Row, Row>() {

        private static final long serialVersionUID = 6654757741959479783L;

        @Override
        public void mapPartition(Iterable<Row> values, Collector<Row> out) throws Exception {
            BaseVectorSummary summary = (BaseVectorSummary) getRuntimeContext().getBroadcastVariable("vectorSize").get(0);
            int vectorSize = summary.vectorSize();
            List<FastDistanceVectorData> list = new ArrayList<>();
            for (Row row : values) {
                FastDistanceVectorData vector = distance.prepareVectorData(row, 1, 0);
                list.add(vector);
                // NOTE(review): this overwrites the broadcast vectorSize with the size of
                // the last vector seen — presumably all rows share one size; confirm upstream.
                vectorSize = vector.getVector().size();
            }
            if (list.size() > 0) {
                FastDistanceVectorData[] vectorArray = list.toArray(new FastDistanceVectorData[0]);
                KDTree tree = new KDTree(vectorArray, vectorSize, distance);
                tree.buildTree();
                int taskId = getRuntimeContext().getIndexOfThisSubtask();
                // The same Row instance is reused for every emit; each setField below
                // rewrites its slots before the next collect.
                Row row = new Row(ROW_SIZE);
                row.setField(TASKID_INDEX, (long) taskId);
                for (int i = 0; i < vectorArray.length; i++) {
                    row.setField(DATA_ID_INDEX, (long) i);
                    // DATA_IDNEX / ROOT_IDDEX are index constants declared elsewhere in
                    // this class (names appear misspelled there; kept as-is).
                    row.setField(DATA_IDNEX, vectorArray[i].toString());
                    out.collect(row);
                }
                // Last record of the partition carries the tree root instead of data.
                row.setField(DATA_ID_INDEX, null);
                row.setField(DATA_IDNEX, null);
                row.setField(ROOT_IDDEX, JsonConverter.toJson(tree.getRoot()));
                out.collect(row);
            }
        }
    }).withBroadcastSet(statistics.f1, "vectorSize").mapPartition(new RichMapPartitionFunction<Row, Row>() {

        private static final long serialVersionUID = 6849403933586157611L;

        @Override
        public void mapPartition(Iterable<Row> values, Collector<Row> out) throws Exception {
            // Second pass: only subtask 0 attaches the model meta (with VECTOR_SIZE
            // filled in from the broadcast summary); every subtask streams its rows
            // through the model-data converter.
            Params meta = null;
            if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                meta = params;
                BaseVectorSummary summary = (BaseVectorSummary) getRuntimeContext().getBroadcastVariable("vectorSize").get(0);
                int vectorSize = summary.vectorSize();
                meta.set(VECTOR_SIZE, vectorSize);
            }
            new KDTreeModelDataConverter().save(Tuple2.of(meta, values), out);
        }
    }).withBroadcastSet(statistics.f1, "vectorSize");
}
Also used : DataSet(org.apache.flink.api.java.DataSet) RichMapPartitionFunction(org.apache.flink.api.common.functions.RichMapPartitionFunction) ArrayList(java.util.ArrayList) VectorApproxNearestNeighborTrainParams(com.alibaba.alink.params.similarity.VectorApproxNearestNeighborTrainParams) Params(org.apache.flink.ml.api.misc.param.Params) FastDistanceVectorData(com.alibaba.alink.operator.common.distance.FastDistanceVectorData) EuclideanDistance(com.alibaba.alink.operator.common.distance.EuclideanDistance) KDTree(com.alibaba.alink.operator.common.similarity.KDTree) BaseVectorSummary(com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary) Collector(org.apache.flink.util.Collector) Row(org.apache.flink.types.Row)

Example 10 with BaseVectorSummary

Use of com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary in the project Alink by alibaba.

From the class LocalitySensitiveHashApproxFunctions, method buildLSH:

public static DataSet<BaseLSH> buildLSH(BatchOperator in, Params params, String vectorCol) {
    // Build the locality-sensitive-hash family selected by the METRIC parameter.
    DataSet<BaseLSH> lsh;
    VectorApproxNearestNeighborTrainParams.Metric metric = params.get(VectorApproxNearestNeighborTrainParams.METRIC);
    switch (metric) {
        case JACCARD:
            // MinHash needs no data statistics; wrap a single configured instance.
            lsh = MLEnvironmentFactory.get(params.get(HasMLEnvironmentId.ML_ENVIRONMENT_ID))
                .getExecutionEnvironment()
                .fromElements(new MinHashLSH(
                    params.get(VectorApproxNearestNeighborTrainParams.SEED),
                    params.get(VectorApproxNearestNeighborTrainParams.NUM_PROJECTIONS_PER_TABLE),
                    params.get(VectorApproxNearestNeighborTrainParams.NUM_HASH_TABLES)));
            break;
        case EUCLIDEAN:
            // Bucketed random projection needs the vector dimensionality, which is
            // read from the summary of the input's vector column.
            Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> vectorStats = StatisticsHelper.summaryHelper(in, null, vectorCol);
            lsh = vectorStats.f1.mapPartition(new MapPartitionFunction<BaseVectorSummary, BaseLSH>() {

                private static final long serialVersionUID = -3698577489884292933L;

                @Override
                public void mapPartition(Iterable<BaseVectorSummary> values, Collector<BaseLSH> out) {
                    List<BaseVectorSummary> summaries = new ArrayList<>();
                    for (BaseVectorSummary value : values) {
                        summaries.add(value);
                    }
                    out.collect(new BucketRandomProjectionLSH(
                        params.get(VectorApproxNearestNeighborTrainParams.SEED),
                        summaries.get(0).vectorSize(),
                        params.get(VectorApproxNearestNeighborTrainParams.NUM_PROJECTIONS_PER_TABLE),
                        params.get(VectorApproxNearestNeighborTrainParams.NUM_HASH_TABLES),
                        params.get(VectorApproxNearestNeighborTrainParams.PROJECTION_WIDTH)));
                }
            });
            break;
        default:
            throw new IllegalArgumentException("Not support " + metric);
    }
    return lsh;
}
Also used : BaseLSH(com.alibaba.alink.operator.common.similarity.lsh.BaseLSH) MinHashLSH(com.alibaba.alink.operator.common.similarity.lsh.MinHashLSH) Tuple2(org.apache.flink.api.java.tuple.Tuple2) BaseVectorSummary(com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary) BucketRandomProjectionLSH(com.alibaba.alink.operator.common.similarity.lsh.BucketRandomProjectionLSH) VectorApproxNearestNeighborTrainParams(com.alibaba.alink.params.similarity.VectorApproxNearestNeighborTrainParams) ArrayList(java.util.ArrayList) List(java.util.List) Vector(com.alibaba.alink.common.linalg.Vector)

Aggregations

BaseVectorSummary (com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary)24 Row (org.apache.flink.types.Row)13 Vector (com.alibaba.alink.common.linalg.Vector)11 DenseVector (com.alibaba.alink.common.linalg.DenseVector)9 SparseVector (com.alibaba.alink.common.linalg.SparseVector)9 BatchOperator (com.alibaba.alink.operator.batch.BatchOperator)9 DataSet (org.apache.flink.api.java.DataSet)9 Tuple2 (org.apache.flink.api.java.tuple.Tuple2)8 Test (org.junit.Test)8 ArrayList (java.util.ArrayList)7 Params (org.apache.flink.ml.api.misc.param.Params)5 MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp)4 IterativeComQueue (com.alibaba.alink.common.comqueue.IterativeComQueue)3 AllReduce (com.alibaba.alink.common.comqueue.communication.AllReduce)3 VectorSummarizerBatchOp (com.alibaba.alink.operator.batch.statistics.VectorSummarizerBatchOp)3 LdaModelDataConverter (com.alibaba.alink.operator.common.clustering.LdaModelDataConverter)3 MapFunction (org.apache.flink.api.common.functions.MapFunction)3 RichMapFunction (org.apache.flink.api.common.functions.RichMapFunction)3 Configuration (org.apache.flink.configuration.Configuration)3 DenseMatrix (com.alibaba.alink.common.linalg.DenseMatrix)2