Usage of com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler in the Alink project by Alibaba.
Snippet from class LocalPredictorTest, method testGeneModelStream.
/**
 * Fits a OneVsRest(LogisticRegression) + VectorAssembler pipeline on the Iris
 * data set and persists the fitted PipelineModel to "/tmp/rankModel.ak".
 *
 * <p>NOTE(review): the logistic regression also registers a model-stream file
 * path ("/tmp/rankModel"), presumably to exercise model hot-reloading — confirm
 * against the LocalPredictor usage this test supports.
 */
@Test
public void testGeneModelStream() throws Exception {
    BatchOperator trainData = Iris.getBatchData();

    // Base binary classifier used inside the one-vs-rest wrapper.
    LogisticRegression logistic = new LogisticRegression()
        .setFeatureCols(Iris.getFeatureColNames())
        .setLabelCol(Iris.getLabelColName())
        .setPredictionCol("pred_label")
        .setPredictionDetailCol("pred_detail")
        .setModelStreamFilePath("/tmp/rankModel")
        .setMaxIter(100);

    // Iris has 3 classes, hence numClass = 3.
    OneVsRest oneVsRest = new OneVsRest()
        .setClassifier(logistic)
        .setNumClass(3)
        .setPredictionCol("pred")
        .setPredictionDetailCol("detail");

    VectorAssembler assembler = new VectorAssembler()
        .setSelectedCols("sepal_length", "sepal_width")
        .setOutputCol("assem");

    PipelineModel fitted = new Pipeline()
        .add(oneVsRest)
        .add(assembler)
        .fit(trainData);

    // Persist the model; overwrite any file left over from a previous run.
    fitted.save().link(
        new AkSinkBatchOp()
            .setFilePath("/tmp/rankModel.ak")
            .setOverwriteSink(true));
    BatchOperator.execute();
}
Usage of com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler in the Alink project by Alibaba.
Snippet from class ModelSaveAndLoadTest, method testNestedSaveAndLoad.
/**
 * Verifies that a Pipeline nested inside another Pipeline can be fitted,
 * saved, re-loaded via {@code PipelineModel.collectLoad}, and that the loaded
 * model still transforms all 150 Iris rows.
 *
 * <p>Fix: {@code Assert.assertEquals} takes the expected value first and the
 * actual value second; the original call had them reversed, which produces
 * misleading failure messages.
 */
@Test
public void testNestedSaveAndLoad() throws Exception {
    VectorAssembler va = new VectorAssembler()
        .setSelectedCols(Iris.getFeatureColNames())
        .setOutputCol("features");
    // Small MLP (4 inputs -> 5 hidden -> 3 classes); 2 iterations keeps the test fast.
    MultilayerPerceptronClassifier classifier = new MultilayerPerceptronClassifier()
        .setVectorCol("features")
        .setLabelCol(Iris.getLabelColName())
        .setLayers(new int[] { 4, 5, 3 })
        .setMaxIter(2)
        .setPredictionCol("pred_label")
        .setPredictionDetailCol("pred_detail")
        .setReservedCols(Iris.getLabelColName());
    Pipeline inner = new Pipeline().add(va).add(classifier);
    Pipeline outer = new Pipeline().add(inner);
    // Iris has exactly 150 rows; every row must survive the round trip.
    Assert.assertEquals(
        150,
        PipelineModel.collectLoad(outer.fit(data).save()).transform(Iris.getBatchData()).count());
}
Usage of com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler in the Alink project by Alibaba.
Snippet from class PipelineModelTest, method pipelineTestSetLazy.
/**
 * Fits a OneHotEncoder + four VectorAssembler pipeline (two assemblers with
 * lazy-print hooks enabled) and checks the assembled vector size in batch
 * prediction, then runs the same model in stream prediction.
 *
 * <p>Fixes: {@code Assert.assertEquals} arguments were reversed (JUnit expects
 * the expected value first), and the doc0/doc2 branches performed the identical
 * assertion, so they are merged into one condition.
 */
@Test
public void pipelineTestSetLazy() throws Exception {
    String[] binaryNames = new String[] { "docid", "word", "cnt" };
    TableSchema schema = new TableSchema(
        new String[] { "id", "docid", "word", "cnt" },
        new TypeInformation<?>[] { Types.STRING, Types.STRING, Types.STRING, Types.LONG });
    // Training rows include nulls in "word" and "docid" to exercise null handling.
    Row[] array = new Row[] {
        Row.of("0", "doc0", "天", 4L),
        Row.of("1", "doc0", "地", 5L),
        Row.of("2", "doc0", "人", 1L),
        Row.of("3", "doc1", null, 3L),
        Row.of("4", null, "人", 2L),
        Row.of("5", "doc1", "合", 4L),
        Row.of("6", "doc1", "一", 4L),
        Row.of("7", "doc2", "清", 3L),
        Row.of("8", "doc2", "一", 2L),
        Row.of("9", "doc2", "色", 2L) };
    BatchOperator batchSource = new MemSourceBatchOp(Arrays.asList(array), schema);

    OneHotEncoder oneHot = new OneHotEncoder()
        .setSelectedCols(binaryNames)
        .setOutputCols("results")
        .setDropLast(false);
    // va and va4 enable lazy printing of transform data/stat; va2/va3 are plain.
    VectorAssembler va = new VectorAssembler()
        .setSelectedCols(new String[] { "cnt", "results" })
        .enableLazyPrintTransformData(10, "xxxxxx")
        .setOutputCol("outN");
    VectorAssembler va2 = new VectorAssembler()
        .setSelectedCols(new String[] { "cnt", "results" })
        .setOutputCol("outN");
    VectorAssembler va3 = new VectorAssembler()
        .setSelectedCols(new String[] { "cnt", "results" })
        .setOutputCol("outN");
    VectorAssembler va4 = new VectorAssembler()
        .setSelectedCols(new String[] { "cnt", "results" })
        .enableLazyPrintTransformStat("xxxxxx4")
        .setOutputCol("outN");

    Pipeline pl = new Pipeline().add(oneHot).add(va).add(va2).add(va3).add(va4);
    PipelineModel model = pl.fit(batchSource);

    Row[] parray = new Row[] {
        Row.of("0", "doc0", "天", 4L),
        Row.of("1", "doc2", null, 3L) };

    // batch predict
    MemSourceBatchOp predData = new MemSourceBatchOp(Arrays.asList(parray), schema);
    BatchOperator result = model.transform(predData).select(new String[] { "docid", "outN" });
    List<Row> rows = result.collect();
    for (Row row : rows) {
        String docId = row.getField(0).toString();
        // Both prediction docs should yield an assembled vector of size 19
        // (one-hot encodings of the three binary columns plus "cnt").
        if (docId.equals("doc0") || docId.equals("doc2")) {
            Assert.assertEquals(19, VectorUtil.getVector(row.getField(1).toString()).size());
        }
    }

    // stream predict
    MemSourceStreamOp predSData = new MemSourceStreamOp(Arrays.asList(parray), schema);
    model.transform(predSData).print();
    StreamOperator.execute();
}
Usage of com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler in the Alink project by Alibaba.
Snippet from class PipelineSaveAndLoadTest, method testNewSaveToFile.
/**
 * Saves a fitted VectorAssembler + MLP pipeline to a temporary file with
 * overwrite enabled, loads it back, and checks the loaded model transforms
 * all 150 Iris rows.
 *
 * <p>Fix: {@code Assert.assertEquals} takes the expected value first; the
 * original call had the arguments reversed.
 */
@Test
public void testNewSaveToFile() throws Exception {
    VectorAssembler va = new VectorAssembler()
        .setSelectedCols(Iris.getFeatureColNames())
        .setOutputCol("features");
    MultilayerPerceptronClassifier classifier = new MultilayerPerceptronClassifier()
        .setVectorCol("features")
        .setLabelCol(Iris.getLabelColName())
        .setLayers(new int[] { 4, 5, 3 })
        .setMaxIter(30)
        .setPredictionCol("pred_label")
        .setPredictionDetailCol("pred_detail")
        .setReservedCols(Iris.getLabelColName());
    Pipeline pipeline = new Pipeline().add(va).add(classifier);
    PipelineModel model = pipeline.fit(data);

    // Save into a JUnit-managed temp file (auto-cleaned); overwrite = true.
    FilePath filePath = new FilePath(folder.newFile().getAbsolutePath());
    model.save(filePath, true);
    BatchOperator.execute();

    PipelineModel modelLoaded = PipelineModel.load(filePath);
    // Iris has exactly 150 rows.
    Assert.assertEquals(150, modelLoaded.transform(Iris.getBatchData()).count());
}
Usage of com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler in the Alink project by Alibaba.
Snippet from class PipelineSaveAndLoadTest, method testNewSaveToFileMultiFile.
/**
 * Same round trip as testNewSaveToFile, but saves the model with a parallelism
 * of 2 (multi-file output), then loads it back and verifies all 150 Iris rows
 * are transformed.
 *
 * <p>Fix: {@code Assert.assertEquals} takes the expected value first; the
 * original call had the arguments reversed.
 */
@Test
public void testNewSaveToFileMultiFile() throws Exception {
    VectorAssembler va = new VectorAssembler()
        .setSelectedCols(Iris.getFeatureColNames())
        .setOutputCol("features");
    MultilayerPerceptronClassifier classifier = new MultilayerPerceptronClassifier()
        .setVectorCol("features")
        .setLabelCol(Iris.getLabelColName())
        .setLayers(new int[] { 4, 5, 3 })
        .setMaxIter(30)
        .setPredictionCol("pred_label")
        .setPredictionDetailCol("pred_detail")
        .setReservedCols(Iris.getLabelColName());
    Pipeline pipeline = new Pipeline().add(va).add(classifier);
    PipelineModel model = pipeline.fit(data);

    // Save with numFiles = 2 to exercise the multi-file persistence path.
    FilePath filePath = new FilePath(folder.newFile().getAbsolutePath());
    model.save(filePath, true, 2);
    BatchOperator.execute();

    PipelineModel modelLoaded = PipelineModel.load(filePath);
    // Iris has exactly 150 rows.
    Assert.assertEquals(150, modelLoaded.transform(Iris.getBatchData()).count());
}
Aggregations