
Example 11 with MultilayerPerceptronClassifier

Usage of com.alibaba.alink.pipeline.classification.MultilayerPerceptronClassifier in the Alink project by Alibaba.

From class PipelineSaveAndLoadTest, method testLocalPredictorMultiFile.

@Test
public void testLocalPredictorMultiFile() throws Exception {
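    // Assemble the four Iris feature columns into one vector column, "features".
    VectorAssembler va = new VectorAssembler().setSelectedCols(Iris.getFeatureColNames()).setOutputCol("features");
    // MLP layer sizes: 4 inputs, 5 hidden units, 3 outputs (one per Iris class).
    MultilayerPerceptronClassifier classifier = new MultilayerPerceptronClassifier()
        .setVectorCol("features").setLabelCol(Iris.getLabelColName())
        .setLayers(new int[] { 4, 5, 3 }).setMaxIter(30)
        .setPredictionCol("pred_label").setPredictionDetailCol("pred_detail")
        .setReservedCols(Iris.getLabelColName());
    Pipeline pipeline = new Pipeline().add(va).add(classifier);
    // `data` and `folder` are test-class fields not shown here (the training data and a JUnit TemporaryFolder).
    PipelineModel model = pipeline.fit(data);
    FilePath filePath = new FilePath(folder.newFile().getAbsolutePath());
    // Save with overwrite enabled; the last argument presumably splits the model into 2 files (hence "MultiFile").
    model.save(filePath, true, 2);
    // Run the batch job that actually writes the model.
    BatchOperator.execute();
    // Input schema: the original data columns plus a dense-vector "features" column.
    LocalPredictor localPredictor = new LocalPredictor(filePath,
        new TableSchema(ArrayUtils.add(data.getColNames(), "features"), ArrayUtils.add(data.getColTypes(), VectorTypes.DENSE_VECTOR)));
    // Score a single row in-process, without submitting a Flink job.
    Row result = localPredictor.map(Row.of(5.1, 3.5, 1.4, 0.2, "Iris-setosanew", new DenseVector(new double[] { 5.1, 3.5, 1.4, 0.2 })));
    System.out.println(JsonConverter.toJson(result));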
}
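The setLayers(new int[] { 4, 5, 3 }) call fixes the network shape: 4 inputs (the assembled Iris features), one hidden layer of 5 units, and 3 outputs, one per Iris class. The stdlib-only Java sketch below illustrates what scoring one row through a network of that shape looks like. It is not Alink's implementation: it assumes sigmoid hidden units and a softmax output layer (the design Spark ML's MLP classifier uses), the class name MlpForwardSketch is hypothetical, and the weights are arbitrary placeholders, so the printed label carries no meaning.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch: forward pass through a 4-5-3 multilayer perceptron.
 * Assumes sigmoid hidden units and a softmax output; weights are made up.
 */
public class MlpForwardSketch {

    // Fully connected layer: out[j] = b[j] + sum_i w[j][i] * in[i].
    static double[] dense(double[] in, double[][] w, double[] b) {
        double[] out = new double[b.length];
        for (int j = 0; j < b.length; j++) {
            double s = b[j];
            for (int i = 0; i < in.length; i++) {
                s += w[j][i] * in[i];
            }
            out[j] = s;
        }
        return out;
    }

    static double[] sigmoid(double[] z) {
        double[] out = new double[z.length];
        for (int i = 0; i < z.length; i++) {
            out[i] = 1.0 / (1.0 + Math.exp(-z[i]));
        }
        return out;
    }

    // Numerically stable softmax: subtract the largest logit before exponentiating.
    static double[] softmax(double[] z) {
        double max = Arrays.stream(z).max().getAsDouble();
        double[] out = new double[z.length];
        double sum = 0.0;
        for (int i = 0; i < z.length; i++) {
            out[i] = Math.exp(z[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < z.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    // Index of the most probable class.
    static int argmax(double[] p) {
        int best = 0;
        for (int i = 1; i < p.length; i++) {
            if (p[i] > p[best]) {
                best = i;
            }
        }
        return best;
    }

    // Placeholder weights (rows x cols), filled deterministically instead of trained.
    static double[][] fill(int rows, int cols, double scale) {
        double[][] w = new double[rows][cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                w[r][c] = scale * Math.sin(r * cols + c + 1);
            }
        }
        return w;
    }

    // 4 inputs -> 5 sigmoid hidden units -> 3 softmax outputs (zero biases).
    static double[] predictProba(double[] x) {
        double[] hidden = sigmoid(dense(x, fill(5, 4, 0.5), new double[5]));
        return softmax(dense(hidden, fill(3, 5, 1.0), new double[3]));
    }

    public static void main(String[] args) {
        String[] labels = { "Iris-setosa", "Iris-versicolor", "Iris-virginica" };
        double[] proba = predictProba(new double[] { 5.1, 3.5, 1.4, 0.2 });
        // Analogous to pred_label (most likely class) and pred_detail (per-class probabilities).
        System.out.println("pred_label = " + labels[argmax(proba)]);
        System.out.println("pred_detail = " + Arrays.toString(proba));
    }
}
```

In the actual test, the 30 training iterations (setMaxIter(30)) learn the weights, and LocalPredictor applies the saved model to the row without a cluster; the sketch only shows the shape of that computation.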
Also used: FilePath (com.alibaba.alink.common.io.filesystem.FilePath), MultilayerPerceptronClassifier (com.alibaba.alink.pipeline.classification.MultilayerPerceptronClassifier), TableSchema (org.apache.flink.table.api.TableSchema), VectorAssembler (com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler), Row (org.apache.flink.types.Row), DenseVector (com.alibaba.alink.common.linalg.DenseVector), Test (org.junit.Test).
