Usage of com.alibaba.alink.pipeline.classification.NaiveBayesTextClassifier in the Alink project by Alibaba.
Class Chap23, method c_1.
/**
 * Compares four binary classifiers on bag-of-words vectors read from LibSVM files.
 *
 * <p>The four models are:
 * <ol>
 *   <li>multinomial Naive Bayes;</li>
 *   <li>Binarizer followed by Bernoulli Naive Bayes;</li>
 *   <li>logistic regression;</li>
 *   <li>logistic regression whose {@code MAX_ITER} is chosen by a 6-fold cross-validated grid
 *       search on AUC.</li>
 * </ol>
 *
 * <p>Training data comes from {@code ORIGIN_DATA_DIR + "train/labeledBow.feat"} and test data
 * from {@code ORIGIN_DATA_DIR + "test/labeledBow.feat"}. These files are presumably the IMDB
 * review data; confirm against the data setup. Rows whose numeric {@code label} is greater
 * than 5 are relabelled {@code "pos"}, all others {@code "neg"}. {@code "pos"} is the positive
 * class in every evaluation. Each experiment is run with its own {@code BatchOperator.execute()}.
 *
 * <p>Side effect: calls {@code AlinkGlobalConfiguration.setPrintProcessInfo(true)} before the
 * grid search and does not reset it, so the flag stays enabled after this method returns.
 *
 * @throws Exception if building or executing any of the batch jobs fails
 */
static void c_1() throws Exception {
	BatchOperator<?> trainSet = readLabeledBow("train");
	trainSet.lazyPrint(1, "train_set");
	// Distribution of the raw numeric labels in the training data.
	trainSet
		.groupBy("label", "label, COUNT(label) AS cnt")
		.orderBy("label", 100)
		.lazyPrint(-1, "labels of train_set");

	BatchOperator<?> testSet = readLabeledBow("test");

	// Collapse the numeric labels into a binary "pos"/"neg" label for all models below.
	trainSet = toPosNegLabel(trainSet);
	testSet = toPosNegLabel(testSet);
	trainSet.lazyPrint(1, "train_set");

	// 1) Multinomial Naive Bayes on the raw vectors.
	new NaiveBayesTextClassifier()
		.setModelType("Multinomial")
		.setVectorCol(VECTOR_COL_NAME)
		.setLabelCol(LABEL_COL_NAME)
		.setPredictionCol(PREDICTION_COL_NAME)
		.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
		.enableLazyPrintModelInfo()
		.fit(trainSet)
		.transform(testSet)
		.link(newPosNegEvaluator("NaiveBayesTextClassifier + Multinomial"));
	BatchOperator.execute();

	// 2) Binarize the vectors, then Bernoulli Naive Bayes.
	new Pipeline()
		.add(new Binarizer()
			.setSelectedCol(VECTOR_COL_NAME)
			.enableLazyPrintTransformData(1, "After Binarizer"))
		.add(new NaiveBayesTextClassifier()
			.setModelType("Bernoulli")
			.setVectorCol(VECTOR_COL_NAME)
			.setLabelCol(LABEL_COL_NAME)
			.setPredictionCol(PREDICTION_COL_NAME)
			.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
			.enableLazyPrintModelInfo())
		.fit(trainSet)
		.transform(testSet)
		.link(newPosNegEvaluator("Binarizer + NaiveBayesTextClassifier + Bernoulli"));
	BatchOperator.execute();

	// 3) Logistic regression with its default parameters.
	new LogisticRegression()
		.setVectorCol(VECTOR_COL_NAME)
		.setLabelCol(LABEL_COL_NAME)
		.setPredictionCol(PREDICTION_COL_NAME)
		.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
		.enableLazyPrintTrainInfo("< LR train info >")
		.enableLazyPrintModelInfo("< LR model info >")
		.fit(trainSet)
		.transform(testSet)
		.link(newPosNegEvaluator("LogisticRegression"));
	BatchOperator.execute();

	// 4) Logistic regression with MAX_ITER selected by 6-fold cross-validated grid search on AUC.
	// NOTE: this global flag is not reset afterwards.
	AlinkGlobalConfiguration.setPrintProcessInfo(true);
	LogisticRegression lr = new LogisticRegression()
		.setVectorCol(VECTOR_COL_NAME)
		.setLabelCol(LABEL_COL_NAME)
		.setPredictionCol(PREDICTION_COL_NAME)
		.setPredictionDetailCol(PRED_DETAIL_COL_NAME);
	GridSearchCV gridSearch = new GridSearchCV()
		.setEstimator(new Pipeline().add(lr))
		.setParamGrid(new ParamGrid()
			.addGrid(lr, LogisticRegression.MAX_ITER, new Integer[] {10, 20, 30, 40, 50, 60, 80, 100}))
		.setTuningEvaluator(new BinaryClassificationTuningEvaluator()
			.setLabelCol(LABEL_COL_NAME)
			.setPositiveLabelValueString("pos")
			.setPredictionDetailCol(PRED_DETAIL_COL_NAME)
			.setTuningBinaryClassMetric(TuningBinaryClassMetric.AUC))
		.setNumFolds(6)
		.enableLazyPrintTrainInfo();
	GridSearchCVModel bestModel = gridSearch.fit(trainSet);
	// Distinct title so these metrics are not confused with the untuned model from step 3.
	bestModel
		.transform(testSet)
		.link(newPosNegEvaluator("GridSearchCV + LogisticRegression"));
	BatchOperator.execute();
}

/** Reads {@code ORIGIN_DATA_DIR/<split>/labeledBow.feat} as zero-based LibSVM data. */
private static BatchOperator<?> readLabeledBow(String split) {
	return new LibSvmSourceBatchOp()
		.setFilePath(ORIGIN_DATA_DIR + split + File.separator + "labeledBow.feat")
		.setStartIndex(0);
}

/**
 * Maps numeric {@code label > 5} to "pos" and everything else to "neg". The result is stored
 * under {@code LABEL_COL_NAME}, and {@code features} is renamed to {@code VECTOR_COL_NAME}.
 */
private static BatchOperator<?> toPosNegLabel(BatchOperator<?> data) {
	return data.select(
		"CASE WHEN label>5 THEN 'pos' ELSE 'neg' END AS " + LABEL_COL_NAME + ", "
			+ "features AS " + VECTOR_COL_NAME);
}

/** Builds a binary evaluator ("pos" = positive class) that lazily prints metrics under {@code title}. */
private static EvalBinaryClassBatchOp newPosNegEvaluator(String title) {
	EvalBinaryClassBatchOp evaluator = new EvalBinaryClassBatchOp()
		.setPositiveLabelValueString("pos")
		.setLabelCol(LABEL_COL_NAME)
		.setPredictionDetailCol(PRED_DETAIL_COL_NAME);
	evaluator.lazyPrintMetrics(title);
	return evaluator;
}
Aggregations