Search in sources:

Example 16 with CollectSinkStreamOp

Use of com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp in project Alink by Alibaba.

Class DocWordCountStreamOpTest, method testDocWordCountStream.

@Test
public void testDocWordCountStream() throws Exception {
    // In-memory (id, sentence) rows -> DocWordCount -> collecting sink.
    MemSourceStreamOp source = new MemSourceStreamOp(Arrays.asList(rows), new String[] { "id", "sentence" });
    DocWordCountStreamOp wordCounts = new DocWordCountStreamOp()
        .setDocIdCol("id")
        .setContentCol("sentence")
        .linkFrom(source);
    CollectSinkStreamOp collector = wordCounts.link(new CollectSinkStreamOp());

    StreamOperator.execute();

    // Compare ignoring row order (presumably because the streaming emit order is not fixed).
    assertListRowEqualWithoutOrder(expected, collector.getAndRemoveValues());
}
Also used : MemSourceStreamOp(com.alibaba.alink.operator.stream.source.MemSourceStreamOp) CollectSinkStreamOp(com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp) Test(org.junit.Test)

Example 17 with CollectSinkStreamOp

Use of com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp in project Alink by Alibaba.

Class LogisticRegressionTest, method streamTest.

@Test
public void streamTest() throws Exception {
    String[] featureCols = new String[] { "f0", "f1", "f2", "f3" };
    String labelCol = "labels";
    String denseVecCol = "vec";
    String sparseVecCol = "svec";

    // Train three logistic-regression models on the same batch data: one on the explicit
    // feature columns (lbfgs), one on the "vec" column, one on the "svec" column (newton).
    BatchOperator<?> train = (BatchOperator<?>) getData(true);
    LogisticRegressionTrainBatchOp featureModel = new LogisticRegressionTrainBatchOp()
        .setLabelCol(labelCol)
        .setWithIntercept(false)
        .setStandardization(false)
        .setFeatureCols(featureCols)
        .setOptimMethod("lbfgs")
        .linkFrom(train);
    LogisticRegressionTrainBatchOp denseVecModel = new LogisticRegressionTrainBatchOp()
        .setLabelCol(labelCol)
        .setWithIntercept(false)
        .setStandardization(false)
        .setVectorCol(denseVecCol)
        .linkFrom(train);
    LogisticRegressionTrainBatchOp sparseVecModel = new LogisticRegressionTrainBatchOp()
        .setLabelCol(labelCol)
        .setVectorCol(sparseVecCol)
        .setWithIntercept(false)
        .setStandardization(false)
        .setOptimMethod("newton")
        .setMaxIter(10)
        .linkFrom(train);

    // Chain the three predictors over the stream so every row carries all three predictions.
    StreamOperator<?> withFeaturePred = new LogisticRegressionPredictStreamOp(featureModel)
        .setPredictionCol("lrpred")
        .linkFrom((StreamOperator<?>) getData(false));
    StreamOperator<?> withDensePred = new LogisticRegressionPredictStreamOp(denseVecModel)
        .setPredictionCol("svpred")
        .linkFrom(withFeaturePred);
    StreamOperator<?> withAllPreds = new LogisticRegressionPredictStreamOp(sparseVecModel)
        .setPredictionCol("dvpred")
        .linkFrom(withDensePred);
    CollectSinkStreamOp collector = withAllPreds.link(new CollectSinkStreamOp());

    StreamOperator.execute();

    // Columns 7..9 are the appended predictions; each must equal column 6
    // (NOTE(review): presumably the label column — depends on getData's schema).
    List<Row> collected = collector.getAndRemoveValues();
    for (Row row : collected) {
        Object reference = row.getField(6);
        for (int col = 7; col < 10; col++) {
            Assert.assertEquals(reference, row.getField(col));
        }
    }
}
Also used : LogisticRegressionPredictStreamOp(com.alibaba.alink.operator.stream.classification.LogisticRegressionPredictStreamOp) CollectSinkStreamOp(com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp) Row(org.apache.flink.types.Row) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) Test(org.junit.Test)

Example 18 with CollectSinkStreamOp

Use of com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp in project Alink by Alibaba.

Class LassoRegressionTest, method regressionPipelineTest.

@Test
public void regressionPipelineTest() throws Exception {
    BatchOperator<?> vecdata = new MemSourceBatchOp(Arrays.asList(vecRows), veccolNames);
    StreamOperator<?> svecdata = new MemSourceStreamOp(Arrays.asList(vecRows), veccolNames);
    String[] xVars = new String[] { "f0", "f1", "f2" };
    String yVar = "label";
    String vec = "vec";
    String svec = "svec";
    // Three Lasso models over the same data: explicit feature columns (owlqn),
    // the "vec" column (newton), and the "svec" column (optimizer left unset).
    LassoRegression lasso = new LassoRegression().setLabelCol(yVar).setFeatureCols(xVars).setLambda(0.01).setMaxIter(20).setOptimMethod("owlqn").setPredictionCol("linpred");
    LassoRegression vlasso = new LassoRegression().setLabelCol(yVar).setVectorCol(vec).setMaxIter(20).setLambda(0.01).setOptimMethod("newton").setPredictionCol("vlinpred").enableLazyPrintModelInfo();
    LassoRegression svlasso = new LassoRegression().setLabelCol(yVar).setVectorCol(svec).setMaxIter(20).setLambda(0.01).setPredictionCol("svlinpred");
    Pipeline pl = new Pipeline().add(lasso).add(vlasso).add(svlasso);
    PipelineModel model = pl.fit(vecdata);

    // Batch inference.
    BatchOperator<?> result = model.transform(vecdata).select(new String[] { "label", "linpred", "vlinpred", "svlinpred" });
    assertLassoPredictions(result.collect());

    // Stream inference must reproduce the same reference values.
    CollectSinkStreamOp sop = model.transform(svecdata).select(new String[] { "label", "linpred", "vlinpred", "svlinpred" }).link(new CollectSinkStreamOp());
    StreamOperator.execute();
    assertLassoPredictions(sop.getAndRemoveValues());
}

/**
 * Checks rows shaped (label, linpred, vlinpred, svlinpred) against reference predictions.
 *
 * <p>Only rows whose label is exactly 16.8 or 6.7 are checked; other rows are ignored.
 *
 * @param rows collected prediction rows; must be non-empty so the check cannot pass vacuously
 */
private static void assertLassoPredictions(List<Row> rows) {
    Assert.assertFalse("no prediction rows were collected", rows.isEmpty());
    for (Row row : rows) {
        double label = (double) row.getField(0);
        if (label == 16.8000) {
            Assert.assertEquals(16.784611802507232, (double) row.getField(1), 0.01);
            Assert.assertEquals(16.784611802507232, (double) row.getField(2), 0.01);
            Assert.assertEquals(16.78209421260283, (double) row.getField(3), 0.01);
        } else if (label == 6.7000) {
            Assert.assertEquals(6.7713287283076, (double) row.getField(1), 0.01);
            Assert.assertEquals(6.7713287283076, (double) row.getField(2), 0.01);
            Assert.assertEquals(6.826846826823054, (double) row.getField(3), 0.01);
        }
    }
}
Also used : MemSourceBatchOp(com.alibaba.alink.operator.batch.source.MemSourceBatchOp) MemSourceStreamOp(com.alibaba.alink.operator.stream.source.MemSourceStreamOp) CollectSinkStreamOp(com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp) Row(org.apache.flink.types.Row) Pipeline(com.alibaba.alink.pipeline.Pipeline) PipelineModel(com.alibaba.alink.pipeline.PipelineModel) Test(org.junit.Test)

Example 19 with CollectSinkStreamOp

Use of com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp in project Alink by Alibaba.

Class RidgeRegressionTest, method regressionPipelineTest.

@Test
public void regressionPipelineTest() throws Exception {
    BatchOperator<?> vecdata = new MemSourceBatchOp(Arrays.asList(vecrows), veccolNames);
    StreamOperator<?> svecdata = new MemSourceStreamOp(Arrays.asList(vecrows), veccolNames);
    String[] xVars = new String[] { "f0", "f1", "f2" };
    String yVar = "label";
    String vec = "vec";
    String svec = "svec";
    // Three Ridge models over the same data: explicit feature columns, the "vec" column
    // (newton), and the "svec" column.
    RidgeRegression ridge = new RidgeRegression().setLabelCol(yVar).setFeatureCols(xVars).setLambda(0.01).setMaxIter(10).setPredictionCol("linpred");
    RidgeRegression vridge = new RidgeRegression().setLabelCol(yVar).setVectorCol(vec).setLambda(0.01).setMaxIter(10).setOptimMethod("newton").setPredictionCol("vlinpred");
    RidgeRegression svridge = new RidgeRegression().setLabelCol(yVar).setVectorCol(svec).setLambda(0.01).setMaxIter(10).setPredictionCol("svlinpred");
    Pipeline pl = new Pipeline().add(ridge).add(vridge).add(svridge);
    PipelineModel model = pl.fit(vecdata);

    // Batch inference.
    BatchOperator<?> result = model.transform(vecdata).select(new String[] { "label", "linpred", "vlinpred", "svlinpred" });
    assertRidgePredictions(result.collect());

    // Stream inference must reproduce the same reference values.
    CollectSinkStreamOp sop = model.transform(svecdata).select(new String[] { "label", "linpred", "vlinpred", "svlinpred" }).link(new CollectSinkStreamOp());
    StreamOperator.execute();
    assertRidgePredictions(sop.getAndRemoveValues());
}

/**
 * Checks rows shaped (label, linpred, vlinpred, svlinpred) against reference predictions.
 *
 * <p>Only rows whose label is exactly 16.8 or 6.7 are checked; other rows are ignored.
 *
 * @param rows collected prediction rows; must be non-empty so the check cannot pass vacuously
 */
private static void assertRidgePredictions(List<Row> rows) {
    Assert.assertFalse("no prediction rows were collected", rows.isEmpty());
    for (Row row : rows) {
        double label = (double) row.getField(0);
        if (label == 16.8000) {
            Assert.assertEquals(16.77322547668301, (double) row.getField(1), 0.01);
            Assert.assertEquals(16.620448399254673, (double) row.getField(2), 0.01);
            Assert.assertEquals(16.384437074591887, (double) row.getField(3), 0.01);
        } else if (label == 6.7000) {
            Assert.assertEquals(6.932628087721653, (double) row.getField(1), 0.01);
            Assert.assertEquals(6.775060404865803, (double) row.getField(2), 0.01);
            Assert.assertEquals(7.425378715755974, (double) row.getField(3), 0.01);
        }
    }
}
Also used : MemSourceBatchOp(com.alibaba.alink.operator.batch.source.MemSourceBatchOp) MemSourceStreamOp(com.alibaba.alink.operator.stream.source.MemSourceStreamOp) CollectSinkStreamOp(com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp) Row(org.apache.flink.types.Row) Pipeline(com.alibaba.alink.pipeline.Pipeline) PipelineModel(com.alibaba.alink.pipeline.PipelineModel) Test(org.junit.Test)

Example 20 with CollectSinkStreamOp

Use of com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp in project Alink by Alibaba.

Class StringNearestNeighborTest, method testStringStream.

@Test
public void testStringStream() throws Exception {
    // Use wildcard-parameterized operator types (consistent with the rest of the suite) instead of raw types.
    BatchOperator<?> dict = new MemSourceBatchOp(Arrays.asList(StringNearestNeighborBatchOpTest.dictRows), new String[] { "id", "str" });
    StreamOperator<?> query = new MemSourceStreamOp(Arrays.asList(StringNearestNeighborBatchOpTest.queryRows), new String[] { "id", "str" });
    StringNearestNeighborModel model = new StringNearestNeighbor().setIdCol("id").setSelectedCol("str").setTopN(3).setOutputCol("output").fit(dict);
    CollectSinkStreamOp collectSinkStreamOp = new CollectSinkStreamOp().linkFrom(model.transform(query));
    StreamOperator.execute();
    List<Row> res = collectSinkStreamOp.getAndRemoveValues();

    // Expected top-3 similarity scores per query id.
    Map<Object, Double[]> score = new HashMap<>();
    score.put(1, new Double[] { 0.75, 0.667, 0.333 });
    score.put(2, new Double[] { 0.667, 0.667, 0.5 });
    score.put(3, new Double[] { 0.333, 0.333, 0.25 });
    score.put(4, new Double[] { 0.75, 0.333, 0.333 });
    score.put(5, new Double[] { 0.333, 0.25, 0.25 });
    score.put(6, new Double[] { 0.333, 0.333, 0.333 });

    Assert.assertFalse("no neighbor rows were collected", res.isEmpty());
    for (Row row : res) {
        Double[] actual = StringNearestNeighborBatchOpTest.extractScore((String) row.getField(2));
        Double[] expect = score.get(row.getField(0));
        // Fail with a readable message instead of an NPE when an id has no reference scores.
        Assert.assertNotNull("no expected scores for id " + row.getField(0), expect);
        // A truncated score list must not pass silently.
        Assert.assertEquals(expect.length, actual.length);
        for (int i = 0; i < actual.length; i++) {
            Assert.assertEquals(expect[i], actual[i], 0.01);
        }
    }
}
Also used : MemSourceStreamOp(com.alibaba.alink.operator.stream.source.MemSourceStreamOp) HashMap(java.util.HashMap) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) MemSourceBatchOp(com.alibaba.alink.operator.batch.source.MemSourceBatchOp) CollectSinkStreamOp(com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp) Row(org.apache.flink.types.Row) StreamOperator(com.alibaba.alink.operator.stream.StreamOperator) Test(org.junit.Test) StringNearestNeighborBatchOpTest(com.alibaba.alink.operator.batch.similarity.StringNearestNeighborBatchOpTest) TextApproxNearestNeighborBatchOpTest(com.alibaba.alink.operator.batch.similarity.TextApproxNearestNeighborBatchOpTest)

Aggregations

CollectSinkStreamOp (com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp)80 Test (org.junit.Test)76 Row (org.apache.flink.types.Row)72 MemSourceStreamOp (com.alibaba.alink.operator.stream.source.MemSourceStreamOp)60 MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp)30 RowComparator (com.alibaba.alink.operator.common.dataproc.SortUtils.RowComparator)25 StreamOperator (com.alibaba.alink.operator.stream.StreamOperator)25 BatchOperator (com.alibaba.alink.operator.batch.BatchOperator)20 Pipeline (com.alibaba.alink.pipeline.Pipeline)9 PipelineModel (com.alibaba.alink.pipeline.PipelineModel)9 Timestamp (java.sql.Timestamp)8 SparseVector (com.alibaba.alink.common.linalg.SparseVector)6 StringNearestNeighborBatchOpTest (com.alibaba.alink.operator.batch.similarity.StringNearestNeighborBatchOpTest)6 TextApproxNearestNeighborBatchOpTest (com.alibaba.alink.operator.batch.similarity.TextApproxNearestNeighborBatchOpTest)6 OverCountWindowStreamOp (com.alibaba.alink.operator.stream.feature.OverCountWindowStreamOp)6 ArrayList (java.util.ArrayList)6 HashMap (java.util.HashMap)6 MTable (com.alibaba.alink.common.MTable)3 DenseVector (com.alibaba.alink.common.linalg.DenseVector)3 TableSchema (org.apache.flink.table.api.TableSchema)3