Example 1 with OneVsRest

Use of com.alibaba.alink.pipeline.classification.OneVsRest in project Alink by alibaba.

In class Chap13, method c_3.

static void c_3() throws Exception {
    AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TRAIN_FILE);
    AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + SPARSE_TEST_FILE);
    BatchOperator.setParallelism(1);
    new OneVsRest().setClassifier(new LogisticRegression().setVectorCol(VECTOR_COL_NAME).setLabelCol(LABEL_COL_NAME).setPredictionCol(PREDICTION_COL_NAME))
        .setNumClass(10).fit(train_data).transform(test_data)
        .link(new EvalMultiClassBatchOp().setLabelCol(LABEL_COL_NAME).setPredictionCol(PREDICTION_COL_NAME).lazyPrintMetrics("OneVsRest - LogisticRegression"));
    new OneVsRest().setClassifier(new LinearSvm().setVectorCol(VECTOR_COL_NAME).setLabelCol(LABEL_COL_NAME).setPredictionCol(PREDICTION_COL_NAME))
        .setNumClass(10).fit(train_data).transform(test_data)
        .link(new EvalMultiClassBatchOp().setLabelCol(LABEL_COL_NAME).setPredictionCol(PREDICTION_COL_NAME).lazyPrintMetrics("OneVsRest - LinearSvm"));
    BatchOperator.execute();
}
Also used : EvalMultiClassBatchOp(com.alibaba.alink.operator.batch.evaluation.EvalMultiClassBatchOp) AkSourceBatchOp(com.alibaba.alink.operator.batch.source.AkSourceBatchOp) OneVsRest(com.alibaba.alink.pipeline.classification.OneVsRest) LinearSvm(com.alibaba.alink.pipeline.classification.LinearSvm) LogisticRegression(com.alibaba.alink.pipeline.classification.LogisticRegression)

Example 2 with OneVsRest

Use of com.alibaba.alink.pipeline.classification.OneVsRest in project Alink by alibaba.

In class LocalPredictorTest, method testGeneModelStream.

@Test
public void testGeneModelStream() throws Exception {
    BatchOperator data = Iris.getBatchData();
    LogisticRegression lr = new LogisticRegression()
        .setFeatureCols(Iris.getFeatureColNames()).setLabelCol(Iris.getLabelColName())
        .setPredictionCol("pred_label").setPredictionDetailCol("pred_detail")
        .setModelStreamFilePath("/tmp/rankModel").setMaxIter(100);
    OneVsRest oneVsRest = new OneVsRest()
        .setClassifier(lr).setNumClass(3)
        .setPredictionCol("pred").setPredictionDetailCol("detail");
    VectorAssembler va = new VectorAssembler()
        .setSelectedCols("sepal_length", "sepal_width").setOutputCol("assem");
    Pipeline pipeline = new Pipeline().add(oneVsRest).add(va);
    PipelineModel model = pipeline.fit(data);
    model.save().link(new AkSinkBatchOp().setFilePath("/tmp/rankModel.ak").setOverwriteSink(true));
    BatchOperator.execute();
}
Also used : OneVsRest(com.alibaba.alink.pipeline.classification.OneVsRest) VectorAssembler(com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler) LogisticRegression(com.alibaba.alink.pipeline.classification.LogisticRegression) AkSinkBatchOp(com.alibaba.alink.operator.batch.sink.AkSinkBatchOp) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) Test(org.junit.Test)

Example 3 with OneVsRest

Use of com.alibaba.alink.pipeline.classification.OneVsRest in project Alink by alibaba.

In class Chap12, method c_4.

static void c_4() throws Exception {
    AkSourceBatchOp train_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE);
    AkSourceBatchOp test_data = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE);
    new OneVsRest().setClassifier(new LogisticRegression().setFeatureCols(FEATURE_COL_NAMES).setLabelCol(LABEL_COL_NAME).setPredictionCol(PREDICTION_COL_NAME))
        .setNumClass(3).fit(train_data).transform(test_data)
        .link(new EvalMultiClassBatchOp().setLabelCol(LABEL_COL_NAME).setPredictionCol(PREDICTION_COL_NAME).lazyPrintMetrics("OneVsRest_LogisticRegression"));
    new OneVsRest().setClassifier(new GbdtClassifier().setFeatureCols(FEATURE_COL_NAMES).setLabelCol(LABEL_COL_NAME).setPredictionCol(PREDICTION_COL_NAME))
        .setNumClass(3).fit(train_data).transform(test_data)
        .link(new EvalMultiClassBatchOp().setLabelCol(LABEL_COL_NAME).setPredictionCol(PREDICTION_COL_NAME).lazyPrintMetrics("OneVsRest_GBDT"));
    new OneVsRest().setClassifier(new LinearSvm().setFeatureCols(FEATURE_COL_NAMES).setLabelCol(LABEL_COL_NAME).setPredictionCol(PREDICTION_COL_NAME))
        .setNumClass(3).fit(train_data).transform(test_data)
        .link(new EvalMultiClassBatchOp().setLabelCol(LABEL_COL_NAME).setPredictionCol(PREDICTION_COL_NAME).lazyPrintMetrics("OneVsRest_LinearSvm"));
    BatchOperator.execute();
}
Also used : EvalMultiClassBatchOp(com.alibaba.alink.operator.batch.evaluation.EvalMultiClassBatchOp) AkSourceBatchOp(com.alibaba.alink.operator.batch.source.AkSourceBatchOp) OneVsRest(com.alibaba.alink.pipeline.classification.OneVsRest) GbdtClassifier(com.alibaba.alink.pipeline.classification.GbdtClassifier) LinearSvm(com.alibaba.alink.pipeline.classification.LinearSvm) LogisticRegression(com.alibaba.alink.pipeline.classification.LogisticRegression)

Aggregations

LogisticRegression (com.alibaba.alink.pipeline.classification.LogisticRegression): 3
OneVsRest (com.alibaba.alink.pipeline.classification.OneVsRest): 3
EvalMultiClassBatchOp (com.alibaba.alink.operator.batch.evaluation.EvalMultiClassBatchOp): 2
AkSourceBatchOp (com.alibaba.alink.operator.batch.source.AkSourceBatchOp): 2
LinearSvm (com.alibaba.alink.pipeline.classification.LinearSvm): 2
BatchOperator (com.alibaba.alink.operator.batch.BatchOperator): 1
AkSinkBatchOp (com.alibaba.alink.operator.batch.sink.AkSinkBatchOp): 1
GbdtClassifier (com.alibaba.alink.pipeline.classification.GbdtClassifier): 1
VectorAssembler (com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler): 1
Test (org.junit.Test): 1