Search in sources :

Example 86 with SparseVector

use of com.alibaba.alink.common.linalg.SparseVector in project Alink by alibaba.

the class OneHotModelMapperTest method testNullModel.

/**
 * Loads {@code nullModel} into a {@link OneHotModelMapper} configured for VECTOR encoding of the
 * columns "cnt", "word" and "docid", and checks that mapping {@code defaultRow} produces, for each
 * of the three columns, a size-2 sparse vector whose only entry is 1.0 at index 1.
 */
@Test
public void testNullModel() throws Exception {
    Params params = new Params()
        .set(OneHotPredictParams.ENCODE, HasEncodeWithoutWoe.Encode.VECTOR)
        .set(OneHotPredictParams.SELECTED_COLS, new String[] { "cnt", "word", "docid" });
    OneHotModelMapper mapper = new OneHotModelMapper(modelSchema, dataSchema, params);
    mapper.loadModel(nullModel);
    Row expected = Row.of(
        new SparseVector(2, new int[] { 1 }, new double[] { 1.0 }),
        new SparseVector(2, new int[] { 1 }, new double[] { 1.0 }),
        new SparseVector(2, new int[] { 1 }, new double[] { 1.0 }));
    // JUnit's signature is assertEquals(expected, actual); passing them in that order keeps the
    // "expected ... but was ..." failure message truthful.
    Assert.assertEquals(expected, mapper.map(defaultRow));
}
Also used : OneHotPredictParams(com.alibaba.alink.params.feature.OneHotPredictParams) Params(org.apache.flink.ml.api.misc.param.Params) SparseVector(com.alibaba.alink.common.linalg.SparseVector) Test(org.junit.Test)

Example 87 with SparseVector

use of com.alibaba.alink.common.linalg.SparseVector in project Alink by alibaba.

the class OneHotModelMapperTest method testHandleInvalidError.

/**
 * Exercises a {@link OneHotModelMapper} configured with VECTOR encoding, HANDLE_INVALID = ERROR,
 * selected columns "cnt" and "word", and DROP_LAST = false.
 *
 * <p>With {@code model} loaded, {@code nullElseRow} must map to the expected row. With
 * {@code newModel} loaded, the mapping either yields the second expected row or throws an
 * exception whose message is "Unseen token: 梅".
 */
@Test
public void testHandleInvalidError() throws Exception {
    Params params = new Params()
        .set(OneHotPredictParams.ENCODE, HasEncodeWithoutWoe.Encode.VECTOR)
        .set(OneHotPredictParams.HANDLE_INVALID, HasHandleInvalid.HandleInvalid.ERROR)
        .set(OneHotPredictParams.SELECTED_COLS, new String[] { "cnt", "word" })
        .set(OneHotPredictParams.DROP_LAST, false);
    OneHotModelMapper mapper = new OneHotModelMapper(modelSchema, dataSchema, params);
    mapper.loadModel(model);
    // JUnit's signature is assertEquals(expected, actual); the expected value goes first so that
    // failure messages label the two values correctly.
    Row expectedWithModel = Row.of(
        null,
        new SparseVector(8, new int[] { 7 }, new double[] { 1.0 }),
        new SparseVector(6, new int[] { 2 }, new double[] { 1.0 }));
    assertEquals(expectedWithModel, mapper.map(nullElseRow));
    mapper.loadModel(newModel);
    // NOTE(review): this try/catch also passes when map() does NOT throw and the row matches;
    // confirm whether the exception is meant to be mandatory before tightening it.
    try {
        Row expectedWithNewModel = Row.of(
            null,
            null,
            new SparseVector(5, new int[] { 2 }, new double[] { 1.0 }));
        assertEquals(expectedWithNewModel, mapper.map(nullElseRow));
    } catch (Exception e) {
        assertEquals("Unseen token: 梅", e.getMessage());
    }
}
Also used : OneHotPredictParams(com.alibaba.alink.params.feature.OneHotPredictParams) Params(org.apache.flink.ml.api.misc.param.Params) SparseVector(com.alibaba.alink.common.linalg.SparseVector) Test(org.junit.Test)

Example 88 with SparseVector

use of com.alibaba.alink.common.linalg.SparseVector in project Alink by alibaba.

the class QuantileDiscretizerModelMapperTest method testVector.

/**
 * Loads {@code model} into a {@link QuantileDiscretizerModelMapper} configured for VECTOR
 * encoding of columns "col2" and "col3" with DROP_LAST = false, and checks the mapped output for
 * both {@code defaultRow} and {@code nullElseRow}: the first field is a string and the two
 * selected columns become size-4 sparse vectors with a single 1.0 entry.
 */
@Test
public void testVector() throws Exception {
    Params params = new Params()
        .set(QuantileDiscretizerPredictParams.ENCODE, HasEncodeWithoutWoe.Encode.VECTOR)
        .set(QuantileDiscretizerPredictParams.SELECTED_COLS, new String[] { "col2", "col3" })
        .set(QuantileDiscretizerPredictParams.DROP_LAST, false);
    QuantileDiscretizerModelMapper mapper = new QuantileDiscretizerModelMapper(modelSchema, dataSchema, params);
    mapper.loadModel(model);
    // JUnit's signature is assertEquals(expected, actual); the expected row is passed first so a
    // failure report names the right value as "expected".
    Row expectedDefault = Row.of(
        "a",
        new SparseVector(4, new int[] { 0 }, new double[] { 1.0 }),
        new SparseVector(4, new int[] { 2 }, new double[] { 1.0 }));
    assertEquals(expectedDefault, mapper.map(defaultRow));
    Row expectedNullElse = Row.of(
        "b",
        new SparseVector(4, new int[] { 3 }, new double[] { 1.0 }),
        new SparseVector(4, new int[] { 1 }, new double[] { 1.0 }));
    assertEquals(expectedNullElse, mapper.map(nullElseRow));
}
Also used : QuantileDiscretizerPredictParams(com.alibaba.alink.params.feature.QuantileDiscretizerPredictParams) Params(org.apache.flink.ml.api.misc.param.Params) SparseVector(com.alibaba.alink.common.linalg.SparseVector) Test(org.junit.Test)

Example 89 with SparseVector

use of com.alibaba.alink.common.linalg.SparseVector in project Alink by alibaba.

the class SosBatchOp method linkFrom.

/**
 * Takes the first input operator, numbers its rows, hands the vector column of every row (as a
 * {@link DenseVector}) to {@link SOSImpl#outlierSelection}, and joins the resulting per-row double
 * back onto the original rows as one extra trailing column named by the prediction-column
 * parameter.
 *
 * @param inputs the upstream operators; only the one returned by {@code checkAndGetFirst} is used
 * @return this operator, with its output set to the input columns plus the appended DOUBLE column
 */
@Override
public SosBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    final String vectorCol = getVectorCol();
    final String scoreCol = getPredictionCol();
    final int vectorPos = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), vectorCol);

    // Tag every row with its zipWithIndex position, narrowed from Long to Integer via intValue().
    DataSet<Tuple2<Long, Row>> zipped = DataSetUtils.zipWithIndex(in.getDataSet());
    DataSet<Tuple2<Integer, Row>> indexedRows = zipped.map(new MapFunction<Tuple2<Long, Row>, Tuple2<Integer, Row>>() {

        private static final long serialVersionUID = -2866816204744880078L;

        @Override
        public Tuple2<Integer, Row> map(Tuple2<Long, Row> entry) throws Exception {
            return new Tuple2<>(entry.f0.intValue(), entry.f1);
        }
    });

    // Extract the vector column of each row; sparse vectors are converted to dense ones.
    DataSet<Tuple2<Integer, DenseVector>> denseInput = indexedRows.map(new MapFunction<Tuple2<Integer, Row>, Tuple2<Integer, DenseVector>>() {

        private static final long serialVersionUID = 5798841020968290284L;

        @Override
        public Tuple2<Integer, DenseVector> map(Tuple2<Integer, Row> entry) throws Exception {
            Vector parsed = VectorUtil.getVector(entry.f1.getField(vectorPos));
            if (parsed == null) {
                // A missing vector is forwarded as a null payload under the same index.
                return new Tuple2<>(entry.f0, null);
            }
            if (parsed instanceof DenseVector) {
                return new Tuple2<>(entry.f0, (DenseVector) parsed);
            }
            return new Tuple2<>(entry.f0, ((SparseVector) parsed).toDenseVector());
        }
    });

    SOSImpl sos = new SOSImpl(this.getParams());
    DataSet<Tuple2<Integer, Double>> scores = sos.outlierSelection(denseInput);

    // Join the scores back to the source rows on the row index (field 0 of both sides).
    DataSet<Row> joined = scores.join(indexedRows).where(0).equalTo(0).with(new JoinFunction<Tuple2<Integer, Double>, Tuple2<Integer, Row>, Row>() {

        private static final long serialVersionUID = 7086848937713200592L;

        @Override
        public Row join(Tuple2<Integer, Double> scored, Tuple2<Integer, Row> original) throws Exception {
            int arity = original.f1.getArity();
            Row widened = new Row(arity + 1);
            for (int pos = 0; pos < arity; pos++) {
                widened.setField(pos, original.f1.getField(pos));
            }
            // The score occupies the single extra slot after the copied fields.
            widened.setField(arity, scored.f1);
            return widened;
        }
    }).returns(new RowTypeInfo(ArrayUtils.add(in.getColTypes(), Types.DOUBLE)));

    String[] outputNames = ArrayUtils.add(in.getColNames(), scoreCol);
    this.setOutput(joined, outputNames);
    return this;
}
Also used : RowTypeInfo(org.apache.flink.api.java.typeutils.RowTypeInfo) Vector(com.alibaba.alink.common.linalg.Vector) DenseVector(com.alibaba.alink.common.linalg.DenseVector) SparseVector(com.alibaba.alink.common.linalg.SparseVector) SOSImpl(com.alibaba.alink.operator.common.outlier.SOSImpl) Tuple2(org.apache.flink.api.java.tuple.Tuple2) JoinFunction(org.apache.flink.api.common.functions.JoinFunction) Row(org.apache.flink.types.Row) DenseVector(com.alibaba.alink.common.linalg.DenseVector)

Example 90 with SparseVector

use of com.alibaba.alink.common.linalg.SparseVector in project Alink by alibaba.

the class OnlineCorpusStep method onlineCorpusUpdate.

/**
 * Runs one online update pass over a selection of the given documents and accumulates statistics.
 *
 * <p>For every document index returned by {@code generateOnlineDocs}, the document vector is cast
 * to {@link SparseVector}, resized to {@code vocabularySize} and stripped of zero values (the
 * vectors in {@code data} are modified in place). A fresh gamma is then drawn with
 * {@code LdaUtil.geneGamma}, the topic distribution is computed with
 * {@code LdaUtil.getTopicDistributionMethod}, and its results are added into the two accumulators.
 *
 * @param data             the documents; each element must be a {@link SparseVector}
 * @param lambda           matrix passed to {@code LdaUtil.expDirichletExpectation}
 * @param alpha            matrix forwarded to {@code LdaUtil.getTopicDistributionMethod}
 * @param gammad           never read: a newly generated gamma is used for every document
 * @param vocabularySize   size every document vector is set to; column count of the word-topic matrix
 * @param numTopic         number of topics; row count of both accumulators
 * @param subSamplingRate  forwarded to {@code generateOnlineDocs}
 * @param random           random source for {@code generateOnlineDocs} and {@code LdaUtil.geneGamma}
 * @param gammaShape       forwarded to {@code LdaUtil.geneGamma}
 * @return a tuple of (word-topic statistics, accumulated {@code dirichletExpectationVec} values per
 *         topic, sum of the visited documents' values, number of visited documents)
 */
public static Tuple4<DenseMatrix, DenseMatrix, Long, Long> onlineCorpusUpdate(List<Vector> data, DenseMatrix lambda, DenseMatrix alpha, DenseMatrix gammad, int vocabularySize, int numTopic, double subSamplingRate, RandomDataGenerator random, int gammaShape) {
    DenseMatrix wordTopicStat = DenseMatrix.zeros(numTopic, vocabularySize);
    DenseMatrix logPhatPart = new DenseMatrix(numTopic, 1);
    DenseMatrix expELogBeta = LdaUtil.expDirichletExpectation(lambda).transpose();
    long wordTotal = 0;
    long docTotal = 0;

    // Which documents are visited, and in what order, is decided entirely by generateOnlineDocs.
    int[] selectedDocs = generateOnlineDocs(data.size(), subSamplingRate, random);
    for (int docPos : selectedDocs) {
        SparseVector doc = (SparseVector) data.get(docPos);
        doc.setSize(vocabularySize);
        doc.removeZeroValues();

        double[] counts = doc.getValues();
        int valueCount = doc.numberOfValues();
        for (int i = 0; i < valueCount; i++) {
            // Compound assignment narrows each double to long as it is added.
            wordTotal += counts[i];
        }

        DenseMatrix initialGamma = LdaUtil.geneGamma(numTopic, gammaShape, random);
        Tuple2<DenseMatrix, DenseMatrix> inferred = LdaUtil.getTopicDistributionMethod(doc, expELogBeta, alpha, initialGamma, numTopic);

        // inferred.f1 is indexed by (topic, position of the entry inside the sparse vector).
        DenseMatrix perEntryTopics = inferred.f1;
        int[] termIds = doc.getIndices();
        for (int i = 0; i < termIds.length; i++) {
            for (int topic = 0; topic < numTopic; topic++) {
                wordTopicStat.add(topic, termIds[i], perEntryTopics.get(topic, i));
            }
        }

        DenseMatrix gammaExpectation = LdaUtil.dirichletExpectationVec(inferred.f0);
        for (int topic = 0; topic < numTopic; topic++) {
            logPhatPart.add(topic, 0, gammaExpectation.get(topic, 0));
        }
        docTotal++;
    }
    return new Tuple4<>(wordTopicStat, logPhatPart, wordTotal, docTotal);
}
Also used : Tuple4(org.apache.flink.api.java.tuple.Tuple4) SparseVector(com.alibaba.alink.common.linalg.SparseVector) Vector(com.alibaba.alink.common.linalg.Vector) SparseVector(com.alibaba.alink.common.linalg.SparseVector) DenseMatrix(com.alibaba.alink.common.linalg.DenseMatrix)

Aggregations

SparseVector (com.alibaba.alink.common.linalg.SparseVector)125 Test (org.junit.Test)63 DenseVector (com.alibaba.alink.common.linalg.DenseVector)60 Params (org.apache.flink.ml.api.misc.param.Params)45 Row (org.apache.flink.types.Row)45 Vector (com.alibaba.alink.common.linalg.Vector)40 TableSchema (org.apache.flink.table.api.TableSchema)27 ArrayList (java.util.ArrayList)21 TypeInformation (org.apache.flink.api.common.typeinfo.TypeInformation)15 HashMap (java.util.HashMap)12 Tuple2 (org.apache.flink.api.java.tuple.Tuple2)12 List (java.util.List)11 DenseMatrix (com.alibaba.alink.common.linalg.DenseMatrix)10 MTable (com.alibaba.alink.common.MTable)7 BaseVectorSummary (com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary)6 CollectSinkStreamOp (com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp)6 Map (java.util.Map)6 MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp)5 VectorAssemblerParams (com.alibaba.alink.params.dataproc.vector.VectorAssemblerParams)5 OneHotPredictParams (com.alibaba.alink.params.feature.OneHotPredictParams)5