Search in sources:

Example 1 with ModelBase

Use of com.alibaba.alink.pipeline.ModelBase in project Alink by Alibaba.

Class GridSearchCVTest, method testSplit.

/**
 * Smoke test for {@link GridSearchTVSplit}.
 *
 * <p>It tunes a random-forest pipeline on a small in-memory dataset using a single
 * train/validation split (80% train). It then prints the result of transforming the same data
 * with the fitted model.
 *
 * <p>Note: although this lives in {@code GridSearchCVTest}, it exercises the train/validation
 * split variant, not k-fold cross-validation. It only checks that fitting and transforming run
 * without throwing; it makes no assertions on the output.
 */
@Test
public void testSplit() throws Exception {
    // Four distinct rows repeated: label is 0 for "A"/"B" rows and 1 for "C"/"D" rows.
    List<Row> rows = Arrays.asList(
        Row.of(1.0, "A", 0, 0, 0),
        Row.of(2.0, "B", 1, 1, 0),
        Row.of(3.0, "C", 2, 2, 1),
        Row.of(4.0, "D", 3, 3, 1),
        Row.of(1.0, "A", 0, 0, 0),
        Row.of(2.0, "B", 1, 1, 0),
        Row.of(3.0, "C", 2, 2, 1),
        Row.of(4.0, "D", 3, 3, 1),
        Row.of(1.0, "A", 0, 0, 0),
        Row.of(2.0, "B", 1, 1, 0),
        Row.of(3.0, "C", 2, 2, 1));
    String[] colNames = new String[] {"f0", "f1", "f2", "f3", "label"};
    MemSourceBatchOp data = new MemSourceBatchOp(rows, colNames);

    String[] featureColNames = new String[] {colNames[0], colNames[1], colNames[2], colNames[3]};
    // f1 holds string values, so it is declared categorical.
    String[] categoricalColNames = new String[] {colNames[1]};
    String labelColName = colNames[4];

    RandomForestClassifier rf = new RandomForestClassifier()
        .setFeatureCols(featureColNames)
        .setCategoricalCols(categoricalColNames)
        .setLabelCol(labelColName)
        .setPredictionCol("pred_result")
        .setPredictionDetailCol("pred_detail")
        .setSubsamplingRatio(1.0);
    Pipeline pipeline = new Pipeline(rf);

    // Each grid axis has exactly one value, so only one parameter combination is evaluated.
    ParamGrid paramGrid = new ParamGrid()
        .addGrid(rf, "SUBSAMPLING_RATIO", new Double[] {1.0})
        .addGrid(rf, "NUM_TREES", new Integer[] {3});

    BinaryClassificationTuningEvaluator tuningEvaluator = new BinaryClassificationTuningEvaluator()
        .setLabelCol(labelColName)
        .setPredictionDetailCol("pred_detail")
        .setTuningBinaryClassMetric("Accuracy");

    GridSearchTVSplit tvSplit = new GridSearchTVSplit()
        .setEstimator(pipeline)
        .setParamGrid(paramGrid)
        .setTuningEvaluator(tuningEvaluator)
        .setTrainRatio(0.8);

    // ModelBase is generic; use a wildcard rather than the raw type.
    ModelBase<?> tunedModel = tvSplit.fit(data);
    tunedModel.transform(data).print();
}
Also used: MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp), Row (org.apache.flink.types.Row), ModelBase (com.alibaba.alink.pipeline.ModelBase), RandomForestClassifier (com.alibaba.alink.pipeline.classification.RandomForestClassifier), Pipeline (com.alibaba.alink.pipeline.Pipeline), Test (org.junit.Test)

Example 2 with ModelBase

Use of com.alibaba.alink.pipeline.ModelBase in project Alink by Alibaba.

Class OneVsRest, method fit.

/**
 * Trains one binary classifier per class (one-vs-rest) and packs them into a single
 * {@link OneVsRestModel}.
 *
 * <p>For each class index the wrapped {@code classifier} is fit on the data returned by
 * {@code generateTrainData}, with its positive-label value set to {@code "1"}. The model data
 * of the result is the concatenation of three tables:
 * <ul>
 *   <li>a meta table (number of classes, the binary classifier's class name and params, the
 *       label type, and the per-class model schema);</li>
 *   <li>the union of all per-class model tables;</li>
 *   <li>the table of all labels.</li>
 * </ul>
 *
 * @param input training data, expected to contain the label column configured on the classifier
 * @return the assembled one-vs-rest model
 * @throws IllegalArgumentException if the configured number of classes is not positive
 */
@Override
public OneVsRestModel fit(BatchOperator<?> input) {
    String labelColName = classifier.getParams().get(HasLabelCol.LABEL_COL);
    BatchOperator<?> allLabels = getAllLabels(input, labelColName);
    // NOTE(review): numClasses comes from getNumClass(), not from allLabels; confirm the
    // configured value always agrees with the labels actually present in the data.
    int numClasses = getNumClass();
    if (numClasses <= 0) {
        // Without this check: a negative value throws NegativeArraySizeException, and zero
        // fails later at models[0].
        throw new IllegalArgumentException("numClasses must be positive, got: " + numClasses);
    }
    int labelColIdx = TableUtil.findColIndexWithAssertAndHint(input.getColNames(), labelColName);
    TypeInformation<?> labelColType = input.getColTypes()[labelColIdx];

    // The positive-label value is the same "1" for every class, so set it once before the loop.
    // Because this mutates the classifier's params, BIN_CLS_PARAMS below records it as well.
    this.classifier.set(HasPositiveLabelValueString.POS_LABEL_VAL_STR, "1");
    ModelBase<?>[] models = new ModelBase<?>[numClasses];
    for (int iCls = 0; iCls < numClasses; iCls++) {
        BatchOperator<?> trainData = generateTrainData(input, allLabels, iCls, labelColIdx);
        models[iCls] = this.classifier.fit(trainData);
    }
    Table modelData = unionAllModels(models);

    // The model column names/types are taken from the first model only; presumably every
    // per-class model shares one schema. TODO confirm.
    Params meta = new Params()
        .set(ModelParamName.NUM_CLASSES, numClasses)
        .set(ModelParamName.BIN_CLS_CLASS_NAME, this.classifier.getClass().getCanonicalName())
        .set(ModelParamName.BIN_CLS_PARAMS, this.classifier.getParams().toJson())
        .set(ModelParamName.LABEL_TYPE_NAME, FlinkTypeConverter.getTypeString(labelColType))
        .set(ModelParamName.MODEL_COL_NAMES,
            models[0].getModelData().getSchema().getFieldNames())
        .set(ModelParamName.MODEL_COL_TYPES,
            toJdbcColTypes(models[0].getModelData().getSchema().getFieldTypes()));
    Table modelMeta = createModelMeta(meta, allLabels);

    OneVsRestModel oneVsRestModel =
        new OneVsRestModel(classifier.getParams().clone().merge(this.getParams()));
    oneVsRestModel.setModelData(BatchOperator.fromTable(TableUtil.concatTables(
        new Table[] {modelMeta, modelData, allLabels.getOutputTable()}, getMLEnvironmentId())));
    return oneVsRestModel;
}
Also used: Table (org.apache.flink.table.api.Table), OneVsRestTrainParams (com.alibaba.alink.params.classification.OneVsRestTrainParams), Params (org.apache.flink.ml.api.misc.param.Params), OneVsRestPredictParams (com.alibaba.alink.params.classification.OneVsRestPredictParams), HasPositiveLabelValueString (com.alibaba.alink.params.shared.linear.HasPositiveLabelValueString), ModelBase (com.alibaba.alink.pipeline.ModelBase)

Aggregations

- ModelBase (com.alibaba.alink.pipeline.ModelBase): 2
- MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp): 1
- OneVsRestPredictParams (com.alibaba.alink.params.classification.OneVsRestPredictParams): 1
- OneVsRestTrainParams (com.alibaba.alink.params.classification.OneVsRestTrainParams): 1
- HasPositiveLabelValueString (com.alibaba.alink.params.shared.linear.HasPositiveLabelValueString): 1
- Pipeline (com.alibaba.alink.pipeline.Pipeline): 1
- RandomForestClassifier (com.alibaba.alink.pipeline.classification.RandomForestClassifier): 1
- Params (org.apache.flink.ml.api.misc.param.Params): 1
- Table (org.apache.flink.table.api.Table): 1
- Row (org.apache.flink.types.Row): 1
- Test (org.junit.Test): 1