use of com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp in project Alink by alibaba.
the class DocWordCountStreamOpTest method testDocWordCountStream.
/**
 * Feeds the {@code rows} fixture through a {@link DocWordCountStreamOp} (document id column
 * "id", content column "sentence"), collects the stream output, and compares it with
 * {@code expected} using an order-insensitive row comparison.
 */
@Test
public void testDocWordCountStream() throws Exception {
    String[] schema = new String[] { "id", "sentence" };
    MemSourceStreamOp source = new MemSourceStreamOp(Arrays.asList(rows), schema);

    DocWordCountStreamOp wordCount = new DocWordCountStreamOp()
        .setDocIdCol("id")
        .setContentCol("sentence")
        .linkFrom(source);

    CollectSinkStreamOp collector = new CollectSinkStreamOp().linkFrom(wordCount);

    // Run the streaming job before reading what the sink collected.
    StreamOperator.execute();

    assertListRowEqualWithoutOrder(expected, collector.getAndRemoveValues());
}
use of com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp in project Alink by alibaba.
the class LogisticRegressionTest method streamTest.
/**
 * Trains three logistic regression models on the batch data (one on the feature columns, one on
 * the "vec" column, one on the "svec" column), chains the matching stream predictors over the
 * stream data, and asserts that fields 7, 8 and 9 of every collected row equal field 6.
 */
@Test
public void streamTest() throws Exception {
    String[] featureCols = new String[] { "f0", "f1", "f2", "f3" };
    String labelCol = "labels";
    String vecCol = "vec";
    String svecCol = "svec";

    BatchOperator<?> trainSet = (BatchOperator<?>) getData(true);

    // Model trained on the plain feature columns with L-BFGS.
    LogisticRegressionTrainBatchOp featureModel = new LogisticRegressionTrainBatchOp()
        .setLabelCol(labelCol)
        .setWithIntercept(false)
        .setStandardization(false)
        .setFeatureCols(featureCols)
        .setOptimMethod("lbfgs")
        .linkFrom(trainSet);

    // Model trained on the "vec" column with the default optimizer.
    LogisticRegressionTrainBatchOp vecModel = new LogisticRegressionTrainBatchOp()
        .setLabelCol(labelCol)
        .setWithIntercept(false)
        .setStandardization(false)
        .setVectorCol(vecCol)
        .linkFrom(trainSet);

    // Model trained on the "svec" column with Newton's method, capped at 10 iterations.
    LogisticRegressionTrainBatchOp svecModel = new LogisticRegressionTrainBatchOp()
        .setLabelCol(labelCol)
        .setVectorCol(svecCol)
        .setWithIntercept(false)
        .setStandardization(false)
        .setOptimMethod("newton")
        .setMaxIter(10)
        .linkFrom(trainSet);

    // Each predictor consumes the previous predictor's output, adding its own prediction column.
    StreamOperator<?> afterFeaturePred = new LogisticRegressionPredictStreamOp(featureModel)
        .setPredictionCol("lrpred")
        .linkFrom((StreamOperator<?>) getData(false));
    StreamOperator<?> afterVecPred = new LogisticRegressionPredictStreamOp(vecModel)
        .setPredictionCol("svpred")
        .linkFrom(afterFeaturePred);
    StreamOperator<?> afterSvecPred = new LogisticRegressionPredictStreamOp(svecModel)
        .setPredictionCol("dvpred")
        .linkFrom(afterVecPred);

    CollectSinkStreamOp collector = afterSvecPred.link(new CollectSinkStreamOp());
    StreamOperator.execute();

    List<Row> collected = collector.getAndRemoveValues();
    for (Row row : collected) {
        // NOTE(review): presumably field 6 is the label and fields 7-9 are the three appended
        // prediction columns -- confirm against the schema returned by getData.
        for (int col = 7; col <= 9; col++) {
            Assert.assertEquals(row.getField(6), row.getField(col));
        }
    }
}
use of com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp in project Alink by alibaba.
the class LassoRegressionTest method regressionPipelineTest.
/**
 * Fits a pipeline of three {@link LassoRegression} stages (feature columns, "vec" column and
 * "svec" column) on the batch data, then verifies the predictions produced by the fitted model
 * both in batch mode and in stream mode against the same reference values.
 */
@Test
public void regressionPipelineTest() throws Exception {
    BatchOperator<?> vecdata = new MemSourceBatchOp(Arrays.asList(vecRows), veccolNames);
    StreamOperator<?> svecdata = new MemSourceStreamOp(Arrays.asList(vecRows), veccolNames);
    String[] xVars = new String[] { "f0", "f1", "f2" };
    String yVar = "label";
    String vec = "vec";
    String svec = "svec";
    String[] outputCols = new String[] { "label", "linpred", "vlinpred", "svlinpred" };

    LassoRegression lasso = new LassoRegression()
        .setLabelCol(yVar)
        .setFeatureCols(xVars)
        .setLambda(0.01)
        .setMaxIter(20)
        .setOptimMethod("owlqn")
        .setPredictionCol("linpred");
    LassoRegression vlasso = new LassoRegression()
        .setLabelCol(yVar)
        .setVectorCol(vec)
        .setMaxIter(20)
        .setLambda(0.01)
        .setOptimMethod("newton")
        .setPredictionCol("vlinpred")
        .enableLazyPrintModelInfo();
    LassoRegression svlasso = new LassoRegression()
        .setLabelCol(yVar)
        .setVectorCol(svec)
        .setMaxIter(20)
        .setLambda(0.01)
        .setPredictionCol("svlinpred");

    Pipeline pl = new Pipeline().add(lasso).add(vlasso).add(svlasso);
    PipelineModel model = pl.fit(vecdata);

    // Batch path.
    BatchOperator<?> result = model.transform(vecdata).select(outputCols);
    assertLassoPredictions(result.collect());

    // Stream path: same model, same expectations.
    CollectSinkStreamOp sop = model.transform(svecdata).select(outputCols).link(new CollectSinkStreamOp());
    StreamOperator.execute();
    assertLassoPredictions(sop.getAndRemoveValues());
}

/**
 * Checks the three prediction columns (fields 1-3) of every row whose label (field 0) is
 * 16.8 or 6.7, with a tolerance of 0.01. Rows with any other label are not checked.
 *
 * @param predictions rows laid out as (label, linpred, vlinpred, svlinpred)
 */
private static void assertLassoPredictions(List<Row> predictions) {
    for (Row row : predictions) {
        double label = (double) row.getField(0);
        // JUnit's signature is assertEquals(expected, actual, delta): reference value first.
        if (label == 16.8000) {
            Assert.assertEquals(16.784611802507232, (double) row.getField(1), 0.01);
            Assert.assertEquals(16.784611802507232, (double) row.getField(2), 0.01);
            Assert.assertEquals(16.78209421260283, (double) row.getField(3), 0.01);
        } else if (label == 6.7000) {
            Assert.assertEquals(6.7713287283076, (double) row.getField(1), 0.01);
            Assert.assertEquals(6.7713287283076, (double) row.getField(2), 0.01);
            Assert.assertEquals(6.826846826823054, (double) row.getField(3), 0.01);
        }
    }
}
use of com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp in project Alink by alibaba.
the class RidgeRegressionTest method regressionPipelineTest.
/**
 * Fits a pipeline of three {@link RidgeRegression} stages (feature columns, "vec" column and
 * "svec" column) on the batch data, then verifies the predictions produced by the fitted model
 * both in batch mode and in stream mode against the same reference values.
 */
@Test
public void regressionPipelineTest() throws Exception {
    BatchOperator<?> vecdata = new MemSourceBatchOp(Arrays.asList(vecrows), veccolNames);
    StreamOperator<?> svecdata = new MemSourceStreamOp(Arrays.asList(vecrows), veccolNames);
    String[] xVars = new String[] { "f0", "f1", "f2" };
    String yVar = "label";
    String vec = "vec";
    String svec = "svec";
    String[] outputCols = new String[] { "label", "linpred", "vlinpred", "svlinpred" };

    RidgeRegression ridge = new RidgeRegression()
        .setLabelCol(yVar)
        .setFeatureCols(xVars)
        .setLambda(0.01)
        .setMaxIter(10)
        .setPredictionCol("linpred");
    RidgeRegression vridge = new RidgeRegression()
        .setLabelCol(yVar)
        .setVectorCol(vec)
        .setLambda(0.01)
        .setMaxIter(10)
        .setOptimMethod("newton")
        .setPredictionCol("vlinpred");
    RidgeRegression svridge = new RidgeRegression()
        .setLabelCol(yVar)
        .setVectorCol(svec)
        .setLambda(0.01)
        .setMaxIter(10)
        .setPredictionCol("svlinpred");

    Pipeline pl = new Pipeline().add(ridge).add(vridge).add(svridge);
    PipelineModel model = pl.fit(vecdata);

    // Batch path.
    BatchOperator<?> result = model.transform(vecdata).select(outputCols);
    assertRidgePredictions(result.collect());

    // Stream path: same model, same expectations.
    CollectSinkStreamOp sop = model.transform(svecdata).select(outputCols).link(new CollectSinkStreamOp());
    StreamOperator.execute();
    assertRidgePredictions(sop.getAndRemoveValues());
}

/**
 * Checks the three prediction columns (fields 1-3) of every row whose label (field 0) is
 * 16.8 or 6.7, with a tolerance of 0.01. Rows with any other label are not checked.
 *
 * @param predictions rows laid out as (label, linpred, vlinpred, svlinpred)
 */
private static void assertRidgePredictions(List<Row> predictions) {
    for (Row row : predictions) {
        double label = (double) row.getField(0);
        // JUnit's signature is assertEquals(expected, actual, delta): reference value first.
        if (label == 16.8000) {
            Assert.assertEquals(16.77322547668301, (double) row.getField(1), 0.01);
            Assert.assertEquals(16.620448399254673, (double) row.getField(2), 0.01);
            Assert.assertEquals(16.384437074591887, (double) row.getField(3), 0.01);
        } else if (label == 6.7000) {
            Assert.assertEquals(6.932628087721653, (double) row.getField(1), 0.01);
            Assert.assertEquals(6.775060404865803, (double) row.getField(2), 0.01);
            Assert.assertEquals(7.425378715755974, (double) row.getField(3), 0.01);
        }
    }
}
use of com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp in project Alink by alibaba.
the class StringNearestNeighborTest method testStringStream.
/**
 * Fits a {@link StringNearestNeighbor} model (top 3 neighbours) on the dictionary rows, applies
 * it to the query rows in stream mode, and compares the scores extracted from each output row
 * with the reference scores for that row's id, using a tolerance of 0.01.
 */
@Test
public void testStringStream() throws Exception {
    String[] schema = new String[] { "id", "str" };
    // Wildcard-parameterised instead of raw operator types.
    BatchOperator<?> dict = new MemSourceBatchOp(Arrays.asList(StringNearestNeighborBatchOpTest.dictRows), schema);
    StreamOperator<?> query = new MemSourceStreamOp(Arrays.asList(StringNearestNeighborBatchOpTest.queryRows), schema);

    StringNearestNeighborModel model = new StringNearestNeighbor()
        .setIdCol("id")
        .setSelectedCol("str")
        .setTopN(3)
        .setOutputCol("output")
        .fit(dict);

    CollectSinkStreamOp collectSinkStreamOp = new CollectSinkStreamOp().linkFrom(model.transform(query));
    StreamOperator.execute();
    List<Row> res = collectSinkStreamOp.getAndRemoveValues();

    // Reference scores keyed by the query row id (field 0).
    Map<Object, Double[]> score = new HashMap<>();
    score.put(1, new Double[] { 0.75, 0.667, 0.333 });
    score.put(2, new Double[] { 0.667, 0.667, 0.5 });
    score.put(3, new Double[] { 0.333, 0.333, 0.25 });
    score.put(4, new Double[] { 0.75, 0.333, 0.333 });
    score.put(5, new Double[] { 0.333, 0.25, 0.25 });
    score.put(6, new Double[] { 0.333, 0.333, 0.333 });

    for (Row row : res) {
        Double[] actual = StringNearestNeighborBatchOpTest.extractScore((String) row.getField(2));
        Double[] expect = score.get(row.getField(0));
        for (int i = 0; i < actual.length; i++) {
            // JUnit's signature is assertEquals(expected, actual, delta): reference value first.
            Assert.assertEquals(expect[i], actual[i], 0.01);
        }
    }
}
Aggregations