Use of com.alibaba.alink.pipeline.classification.MultilayerPerceptronClassifier in the project Alink by Alibaba.
The example is the method testLocalPredictorMultiFile of the class PipelineSaveAndLoadTest.
/**
 * Fits a two-stage pipeline (vector assembling followed by a multilayer perceptron classifier),
 * saves the fitted model to a temporary file, loads it back through a {@link LocalPredictor},
 * and prints the prediction for a single hand-written row as JSON.
 *
 * <p>The test performs no assertions; it passes as long as no step throws.
 *
 * @throws Exception if fitting, saving, loading, or predicting fails
 */
@Test
public void testLocalPredictorMultiFile() throws Exception {
    // Stage 1: assemble the Iris feature columns into the "features" column.
    VectorAssembler assembler = new VectorAssembler()
        .setSelectedCols(Iris.getFeatureColNames())
        .setOutputCol("features");

    // Stage 2: classifier reading the assembled vector; only the label column is reserved.
    MultilayerPerceptronClassifier mlp = new MultilayerPerceptronClassifier()
        .setVectorCol("features")
        .setLabelCol(Iris.getLabelColName())
        .setLayers(new int[] { 4, 5, 3 })
        .setMaxIter(30)
        .setPredictionCol("pred_label")
        .setPredictionDetailCol("pred_detail")
        .setReservedCols(Iris.getLabelColName());

    Pipeline twoStagePipeline = new Pipeline().add(assembler).add(mlp);
    PipelineModel fittedModel = twoStagePipeline.fit(data);

    // NOTE(review): the arguments (true, 2) presumably mean "overwrite" and "number of files",
    // going by the test name - confirm against PipelineModel.save.
    FilePath modelPath = new FilePath(folder.newFile().getAbsolutePath());
    fittedModel.save(modelPath, true, 2);
    BatchOperator.execute();

    // The predictor's input schema is the training schema plus a dense-vector "features" column.
    TableSchema inputSchema = new TableSchema(
        ArrayUtils.add(data.getColNames(), "features"),
        ArrayUtils.add(data.getColTypes(), VectorTypes.DENSE_VECTOR));
    LocalPredictor predictor = new LocalPredictor(modelPath, inputSchema);

    DenseVector featureVector = new DenseVector(new double[] { 5.1, 3.5, 1.4, 0.2 });
    Row inputRow = Row.of(5.1, 3.5, 1.4, 0.2, "Iris-setosanew", featureVector);
    Row prediction = predictor.map(inputRow);

    String predictionJson = JsonConverter.toJson(prediction);
    System.out.println(predictionJson);
}
Aggregations