Use of com.alibaba.alink.pipeline.Pipeline in project Alink by Alibaba.
From the class LogisticRegressionMixVecTest, method batchMixVecTest14.
/**
 * Fits and applies a one-stage pipeline holding a {@code VectorAssembler} configured with
 * selected columns {@code svec, vec, f0, f1, f2, f3} and output column {@code allvec}.
 *
 * <p>The test makes no assertions on the output; it passes as long as fitting, transforming
 * and collecting complete without throwing.
 */
@Test
public void batchMixVecTest14() {
    BatchOperator<?> source = (BatchOperator<?>) getData();

    String[] inputCols = new String[] { "svec", "vec", "f0", "f1", "f2", "f3" };
    VectorAssembler assembler = new VectorAssembler()
        .setSelectedCols(inputCols)
        .setOutputCol("allvec");

    Pipeline assemblerPipeline = new Pipeline().add(assembler);
    PipelineModel fitted = assemblerPipeline.fit(source);

    // The same operator is used for fitting and for transformation.
    BatchOperator<?> assembled = fitted.transform(source);
    assembled.collect();
}
Use of com.alibaba.alink.pipeline.Pipeline in project Alink by Alibaba.
From the class NaiveBayesTextTest, method testPipelineBatch.
/**
 * Trains a pipeline of two Bernoulli {@code NaiveBayesTextClassifier} stages, one reading the
 * {@code vec} column and one reading the {@code svec} column, and checks on the batch path
 * that both prediction columns equal the {@code labels} column for every row.
 *
 * <p>The stream path only prints the transformed data and executes the job; it carries no
 * assertions.
 *
 * @throws Exception if the batch collection or the stream job execution fails
 */
@Test
public void testPipelineBatch() throws Exception {
    String labelName = "labels";
    NaiveBayesTextClassifier vnb = new NaiveBayesTextClassifier()
        .setModelType("Bernoulli")
        .setLabelCol(labelName)
        .setVectorCol("vec")
        .setPredictionCol("predvResult")
        .setPredictionDetailCol("predvResultColName")
        .setSmoothing(0.5);
    NaiveBayesTextClassifier svnb = new NaiveBayesTextClassifier()
        .setModelType("Bernoulli")
        .setLabelCol(labelName)
        .setVectorCol("svec")
        .setPredictionCol("predsvResult")
        .setPredictionDetailCol("predsvResultColName")
        .setSmoothing(0.5);
    Pipeline pl = new Pipeline().add(vnb).add(svnb);

    // Wildcard-parameterized operator types are used instead of raw types so that
    // collect() yields a properly typed List<Row> without an unchecked conversion.
    PipelineModel model = pl.fit((BatchOperator<?>) getData(true));
    BatchOperator<?> result = model.transform((BatchOperator<?>) getData(true))
        .select(new String[] { "labels", "predvResult", "predsvResult" });
    List<Row> data = result.collect();
    for (Row row : data) {
        // Field 0 is the label; fields 1 and 2 are the two prediction columns.
        for (int i = 1; i < 3; ++i) {
            Assert.assertEquals(row.getField(0), row.getField(i));
        }
    }

    // below is stream test code.
    model.transform((StreamOperator<?>) getData(false)).print();
    StreamOperator.execute();
}
Use of com.alibaba.alink.pipeline.Pipeline in project Alink by Alibaba.
From the class SvmTest, method pipelineTest.
/**
 * Chains three {@code LinearSvm} stages in one pipeline (feature-column based, {@code vec}
 * based and {@code svec} based), fits it on the batch data, and compares predictions with the
 * {@code labels} column on both the batch and the stream path.
 *
 * <p>NOTE(review): four columns are selected, but only the columns at index 1 and 2
 * ({@code svmpred}, {@code vsvmpred}) are compared with the label; {@code svsvmpred} at
 * index 3 is never asserted. Confirm whether that omission is intentional.
 *
 * @throws Exception if the batch collection or the stream job execution fails
 */
@Test
public void pipelineTest() throws Exception {
    String[] featureCols = new String[] { "f0", "f1", "f2", "f3" };
    String labelCol = "labels";
    String denseCol = "vec";
    String sparseCol = "svec";

    LinearSvm featureSvm = new LinearSvm()
        .setLabelCol(labelCol)
        .setFeatureCols(featureCols)
        .setOptimMethod("gd")
        .setPredictionCol("svmpred");
    LinearSvm denseSvm = new LinearSvm()
        .setLabelCol(labelCol)
        .setVectorCol(denseCol)
        .setPredictionCol("vsvmpred")
        .enableLazyPrintModelInfo()
        .enableLazyPrintTrainInfo();
    LinearSvm sparseSvm = new LinearSvm()
        .setLabelCol(labelCol)
        .setVectorCol(sparseCol)
        .setOptimMethod("sgd")
        .setMaxIter(10)
        .setPredictionCol("svsvmpred")
        .setPredictionDetailCol("detail");

    Pipeline svmPipeline = new Pipeline().add(featureSvm).add(denseSvm).add(sparseSvm);
    BatchOperator<?> batchSource = (BatchOperator<?>) getData(true);
    PipelineModel fitted = svmPipeline.fit(batchSource);

    // Batch path: label first, followed by the three prediction columns.
    BatchOperator<?> batchPred = fitted.transform(batchSource)
        .select(new String[] { "labels", "svmpred", "vsvmpred", "svsvmpred" });
    List<Row> batchRows = batchPred.collect();
    for (Row r : batchRows) {
        // Only the first two prediction columns are checked (indices 1 and 2).
        Assert.assertEquals(r.getField(0), r.getField(1));
        Assert.assertEquals(r.getField(0), r.getField(2));
    }

    // Stream path: same selection, gathered through a collecting sink.
    CollectSinkStreamOp sink = fitted.transform((StreamOperator<?>) getData(false))
        .select(new String[] { "labels", "svmpred", "vsvmpred", "svsvmpred" })
        .link(new CollectSinkStreamOp());
    StreamOperator.execute();
    List<Row> streamRows = sink.getAndRemoveValues();
    for (Row r : streamRows) {
        Assert.assertEquals(r.getField(0), r.getField(1));
        Assert.assertEquals(r.getField(0), r.getField(2));
    }
}
Use of com.alibaba.alink.pipeline.Pipeline in project Alink by Alibaba.
From the class BisectingKMeansTest, method test.
/**
 * Fits a single-stage {@code BisectingKMeans} pipeline (k = 2, at most 10 iterations) on
 * {@code inputBatchOp} and passes the {@code id}/{@code pred} columns of both the batch and
 * the stream predictions to {@code verifyPredResult}.
 *
 * @throws Exception if the batch collection or the stream job execution fails
 */
@Test
public void test() throws Exception {
    BisectingKMeans clusterer = new BisectingKMeans()
        .setVectorCol("vector")
        .setPredictionCol("pred")
        .setK(2)
        .setMaxIter(10);
    Pipeline clusterPipeline = new Pipeline().add(clusterer);
    PipelineModel fitted = clusterPipeline.fit(inputBatchOp);

    // Batch prediction.
    BatchOperator<?> batchPred = fitted.transform(inputBatchOp)
        .select(new String[] { "id", "pred" });
    verifyPredResult(batchPred.collect());

    // Stream prediction, gathered through a collecting sink once the job has run.
    CollectSinkStreamOp sink = fitted.transform(inputStreamOp)
        .select(new String[] { "id", "pred" })
        .link(new CollectSinkStreamOp());
    StreamOperator.execute();
    verifyPredResult(sink.getAndRemoveValues());
}
Use of com.alibaba.alink.pipeline.Pipeline in project Alink by Alibaba.
From the class GeoKMeansTest, method testGeoKmeans.
/**
 * Fits a single-stage {@code GeoKMeans} pipeline (latitude column {@code f0}, longitude
 * column {@code f1}, k = 2) on {@code inputBatchOp} and passes the full batch and stream
 * prediction outputs to {@code verifyPredResult}.
 *
 * @throws Exception if the batch collection or the stream job execution fails
 */
@Test
public void testGeoKmeans() throws Exception {
    GeoKMeans clusterer = new GeoKMeans()
        .setLatitudeCol("f0")
        .setLongitudeCol("f1")
        .setPredictionCol("pred")
        .setPredictionDistanceCol("distance")
        .setK(2);
    Pipeline geoPipeline = new Pipeline().add(clusterer);
    PipelineModel fitted = geoPipeline.fit(inputBatchOp);

    // Batch prediction: all output columns are kept.
    BatchOperator<?> batchPred = fitted.transform(inputBatchOp);
    verifyPredResult(batchPred.collect());

    // Stream prediction, gathered through a collecting sink once the job has run.
    CollectSinkStreamOp sink = fitted.transform(inputStreamOp).link(new CollectSinkStreamOp());
    StreamOperator.execute();
    verifyPredResult(sink.getAndRemoveValues());
}
Aggregations