Example usage of com.alibaba.alink.operator.stream.classification.LogisticRegressionPredictStreamOp in the Alink project by Alibaba.
From class LogisticRegressionTest, method streamTest.
/**
 * Trains three logistic regression models in batch mode and checks their streaming predictions.
 *
 * <p>The three models use, respectively:
 * <ul>
 *   <li>the numeric feature columns, with L-BFGS;</li>
 *   <li>a vector column, with the default optimizer;</li>
 *   <li>a sparse vector column, with Newton's method and at most 10 iterations.</li>
 * </ul>
 *
 * <p>All three models are chained as {@link LogisticRegressionPredictStreamOp}s over the streaming
 * version of the data. The test then asserts that every appended prediction equals field 6 of the
 * same row.
 *
 * <p>NOTE(review): field 6 is assumed to be the label column produced by {@code getData(false)},
 * and fields 7-9 the three appended prediction columns. Confirm against the schema of
 * {@code getData}.
 */
@Test
public void streamTest() throws Exception {
    String[] xVars = new String[] { "f0", "f1", "f2", "f3" };
    String yVar = "labels";
    String vectorName = "vec";
    String svectorName = "svec";
    BatchOperator<?> trainData = (BatchOperator<?>) getData(true);

    // Model trained on the individual numeric feature columns.
    LogisticRegressionTrainBatchOp featureColsModel = new LogisticRegressionTrainBatchOp()
        .setLabelCol(yVar)
        .setWithIntercept(false)
        .setStandardization(false)
        .setFeatureCols(xVars)
        .setOptimMethod("lbfgs")
        .linkFrom(trainData);
    // Model trained on the vector column.
    LogisticRegressionTrainBatchOp vectorModel = new LogisticRegressionTrainBatchOp()
        .setLabelCol(yVar)
        .setWithIntercept(false)
        .setStandardization(false)
        .setVectorCol(vectorName)
        .linkFrom(trainData);
    // Model trained on the sparse vector column.
    LogisticRegressionTrainBatchOp sparseVectorModel = new LogisticRegressionTrainBatchOp()
        .setLabelCol(yVar)
        .setVectorCol(svectorName)
        .setWithIntercept(false)
        .setStandardization(false)
        .setOptimMethod("newton")
        .setMaxIter(10)
        .linkFrom(trainData);

    // Each predictor appends one prediction column to the stream produced by the previous one.
    StreamOperator<?> result1 = new LogisticRegressionPredictStreamOp(featureColsModel)
        .setPredictionCol("lrpred")
        .linkFrom((StreamOperator<?>) getData(false));
    StreamOperator<?> result2 = new LogisticRegressionPredictStreamOp(vectorModel)
        .setPredictionCol("svpred")
        .linkFrom(result1);
    StreamOperator<?> result3 = new LogisticRegressionPredictStreamOp(sparseVectorModel)
        .setPredictionCol("dvpred")
        .linkFrom(result2);

    CollectSinkStreamOp sop = result3.link(new CollectSinkStreamOp());
    StreamOperator.execute();
    List<Row> rows = sop.getAndRemoveValues();

    // Without this guard, an empty result would make the per-row loop below pass vacuously.
    Assert.assertFalse("stream prediction produced no rows", rows.isEmpty());

    // Index layout as used by the original test; see NOTE(review) in the Javadoc.
    final int expectedIdx = 6;
    final int firstPredictionIdx = 7;
    final int numModels = 3;
    for (Row row : rows) {
        for (int i = firstPredictionIdx; i < firstPredictionIdx + numModels; ++i) {
            Assert.assertEquals(row.getField(expectedIdx), row.getField(i));
        }
    }
}
Aggregations