use of com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler in project Alink by alibaba.
the class ModelSaveAndLoadTest method testSaveLoadSave.
/**
 * Fits a two-stage pipeline (vector assembler followed by a multilayer perceptron
 * classifier), round-trips the fitted model through {@code save()} and
 * {@code PipelineModel.collectLoad}, and checks that a local predictor built from the
 * reloaded model emits a row with the expected number of fields.
 *
 * @throws Exception if fitting, saving, loading, or local prediction fails
 */
@Test
public void testSaveLoadSave() throws Exception {
    VectorAssembler va = new VectorAssembler()
        .setSelectedCols(Iris.getFeatureColNames())
        .setOutputCol("features");
    MultilayerPerceptronClassifier classifier = new MultilayerPerceptronClassifier()
        .setVectorCol("features")
        .setLabelCol(Iris.getLabelColName())
        .setLayers(new int[] { 4, 5, 3 })
        .setMaxIter(30)
        .setPredictionCol("pred_label")
        .setPredictionDetailCol("pred_detail")
        .setReservedCols(Iris.getLabelColName());
    Pipeline pipeline = new Pipeline().add(va).add(classifier);

    // Save the fitted model and immediately load it back before predicting.
    PipelineModel model = PipelineModel.collectLoad(pipeline.fit(data).save());
    LocalPredictor localPredictor = model.collectLocalPredictor(data.getSchema());
    Row pred = localPredictor.map(Row.of(4.8, 3.4, 1.9, 0.2, "Iris-setosa"));

    // Expected value goes first so a failure reports "expected 3 but was N".
    // NOTE(review): 3 presumably corresponds to the reserved label column plus
    // "pred_label" and "pred_detail" -- confirm against the predictor's output schema.
    Assert.assertEquals(3, pred.getArity());
}
use of com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler in project Alink by alibaba.
the class PipelineModelTest method getPipeline.
/**
 * Builds the four-stage pipeline used by these tests: a binarizer, a vector
 * assembler, a quantile discretizer, and an LDA stage, added in that order.
 *
 * @return a new, unfitted pipeline containing the four stages
 */
protected Pipeline getPipeline() {
    // Model-backed mapper: buckets "sepal_length" into two quantile bins.
    QuantileDiscretizer discretizerStage = new QuantileDiscretizer()
        .setNumBuckets(2)
        .setSelectedCols("sepal_length");

    // Single-input / single-output mapper: thresholds "petal_width" into "bina".
    Binarizer binarizerStage = new Binarizer()
        .setSelectedCol("petal_width")
        .setOutputCol("bina")
        .setReservedCols("sepal_length", "petal_width", "petal_length", "category")
        .setThreshold(1.0);

    // Multi-input / single-output mapper: packs two columns into the "assem" vector.
    VectorAssembler assemblerStage = new VectorAssembler()
        .setSelectedCols("sepal_length", "petal_width")
        .setOutputCol("assem")
        .setReservedCols("sepal_length", "petal_width", "petal_length", "category");

    // LDA over the "category" column with a fixed seed.
    Lda ldaStage = new Lda()
        .setPredictionCol("lda_pred")
        .setPredictionDetailCol("lda_pred_detail")
        .setSelectedCol("category")
        .setTopicNum(2)
        .setRandomSeed(0);

    Pipeline assembled = new Pipeline()
        .add(binarizerStage)
        .add(assemblerStage)
        .add(discretizerStage)
        .add(ldaStage);
    return assembled;
}
use of com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler in project Alink by alibaba.
the class PipelineSaveAndLoadTest method testSaveLoadSave.
/**
 * Round-trips an unfitted pipeline through {@code save()} and
 * {@code Pipeline.collectLoad}, fits the reloaded pipeline, round-trips the resulting
 * model through {@code save()} and {@code PipelineModel.collectLoad}, and checks that
 * a local predictor built from it emits a row with the expected number of fields.
 *
 * @throws Exception if saving, loading, fitting, or local prediction fails
 */
@Test
public void testSaveLoadSave() throws Exception {
    VectorAssembler va = new VectorAssembler()
        .setSelectedCols(Iris.getFeatureColNames())
        .setOutputCol("features");
    MultilayerPerceptronClassifier classifier = new MultilayerPerceptronClassifier()
        .setVectorCol("features")
        .setLabelCol(Iris.getLabelColName())
        .setLayers(new int[] { 4, 5, 3 })
        .setMaxIter(30)
        .setPredictionCol("pred_label")
        .setPredictionDetailCol("pred_detail")
        .setReservedCols(Iris.getLabelColName());
    Pipeline pipeline = new Pipeline().add(va).add(classifier);

    // First round-trip: the pipeline definition itself.
    Pipeline pipeline1 = Pipeline.collectLoad(pipeline.save());
    // Second round-trip: the model fitted from the reloaded pipeline.
    PipelineModel model = PipelineModel.collectLoad(pipeline1.fit(data).save());
    LocalPredictor localPredictor = model.collectLocalPredictor(data.getSchema());
    Row pred = localPredictor.map(Row.of(4.8, 3.4, 1.9, 0.2, "Iris-setosa"));

    // Expected value goes first so a failure reports "expected 3 but was N".
    // NOTE(review): 3 presumably corresponds to the reserved label column plus
    // "pred_label" and "pred_detail" -- confirm against the predictor's output schema.
    Assert.assertEquals(3, pred.getArity());
}
use of com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler in project Alink by alibaba.
the class PipelineSaveAndLoadTest method testNewSave.
/**
 * Fits the pipeline, saves the fitted model to a batch operator, reloads it with
 * {@code PipelineModel.collectLoad}, and checks that transforming the Iris batch data
 * with the reloaded model yields the expected row count.
 *
 * @throws Exception if fitting, saving, loading, or transforming fails
 */
@Test
public void testNewSave() throws Exception {
    VectorAssembler va = new VectorAssembler()
        .setSelectedCols(Iris.getFeatureColNames())
        .setOutputCol("features");
    MultilayerPerceptronClassifier classifier = new MultilayerPerceptronClassifier()
        .setVectorCol("features")
        .setLabelCol(Iris.getLabelColName())
        .setLayers(new int[] { 4, 5, 3 })
        .setMaxIter(30)
        .setPredictionCol("pred_label")
        .setPredictionDetailCol("pred_detail")
        .setReservedCols(Iris.getLabelColName());
    Pipeline pipeline = new Pipeline().add(va).add(classifier);

    PipelineModel model = pipeline.fit(data);
    BatchOperator<?> saved = model.save();
    PipelineModel modelLoaded = PipelineModel.collectLoad(saved);

    // Expected value goes first so a failure reports "expected 150 but was N".
    Assert.assertEquals(150, modelLoaded.transform(Iris.getBatchData()).count());
}
use of com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler in project Alink by alibaba.
the class PipelineSaveAndLoadTest method testNestedSaveAndLoad.
/**
 * Wraps the assembler and classifier in an inner pipeline nested inside an outer one,
 * round-trips the outer pipeline through {@code save()} and
 * {@code Pipeline.collectLoad}, fits it, round-trips the fitted model as well, and
 * checks that transforming the Iris batch data yields the expected row count.
 *
 * @throws Exception if saving, loading, fitting, or transforming fails
 */
@Test
public void testNestedSaveAndLoad() throws Exception {
    VectorAssembler va = new VectorAssembler()
        .setSelectedCols(Iris.getFeatureColNames())
        .setOutputCol("features");
    MultilayerPerceptronClassifier classifier = new MultilayerPerceptronClassifier()
        .setVectorCol("features")
        .setLabelCol(Iris.getLabelColName())
        .setLayers(new int[] { 4, 5, 3 })
        .setMaxIter(2)
        .setPredictionCol("pred_label")
        .setPredictionDetailCol("pred_detail")
        .setReservedCols(Iris.getLabelColName());

    // The outer pipeline has a single stage, which is itself a pipeline.
    Pipeline pipeline = new Pipeline().add(new Pipeline().add(va).add(classifier));
    Pipeline pipeline1 = Pipeline.collectLoad(pipeline.save());
    PipelineModel reloadedModel = PipelineModel.collectLoad(pipeline1.fit(data).save());

    // Expected value goes first so a failure reports "expected 150 but was N".
    Assert.assertEquals(150, reloadedModel.transform(Iris.getBatchData()).count());
}
Aggregations