Usage of com.alibaba.alink.operator.batch.feature.QuantileDiscretizerTrainBatchOp in the Alink project by Alibaba.
Example from class Preprocessing, method generateQuantileDiscretizerModel:
/**
 * Trains the quantile-discretizer model used to bin features.
 *
 * <p>Three cases are handled:
 * <ul>
 *   <li>{@code params} contains {@code HasVectorCol.VECTOR_COL}: a {@code VectorTrain} operator is trained on
 *   that vector column, with {@code HasMaxBins.MAX_BINS} as the bucket count.</li>
 *   <li>Otherwise the continuous columns are the feature columns minus the categorical columns. If any
 *   remain, a {@code QuantileDiscretizerTrainBatchOp} is trained on them with the same bucket count.</li>
 *   <li>If no continuous column remains, the result is an operator that has the model schema of
 *   {@code QuantileDiscretizerModelDataConverter} but produces no rows.</li>
 * </ul>
 *
 * <p>Training consumes the output of {@code sample(input, params)}. NOTE(review): presumably a
 * down-sampled copy of {@code input}; confirm in {@code sample}. The {@code ZERO_AS_MISSING} setting is
 * forwarded to whichever trainer is built.
 *
 * @param input  training data; its ML environment id is propagated to every operator created here
 * @param params column selection (vector / feature / categorical columns), MAX_BINS and ZERO_AS_MISSING
 * @return the (possibly empty) model operator
 */
public static BatchOperator<?> generateQuantileDiscretizerModel(BatchOperator<?> input, Params params) {
	final Long envId = input.getMLEnvironmentId();

	if (params.contains(HasVectorCol.VECTOR_COL)) {
		// The sampled input is obtained before the trainer is constructed, matching the
		// evaluation order of a direct `sample(...).linkTo(new ...)` call chain.
		BatchOperator<?> sampledForVector = sample(input, params);
		BatchOperator<?> vectorTrainer = new VectorTrain(
				new Params().set(ZERO_AS_MISSING, params.get(ZERO_AS_MISSING)))
			.setMLEnvironmentId(envId)
			.setVectorCol(params.get(HasVectorCol.VECTOR_COL))
			.setNumBuckets(params.get(HasMaxBins.MAX_BINS));
		return sampledForVector.linkTo(vectorTrainer);
	}

	// Continuous columns = feature columns that are not declared categorical.
	String[] numericCols = ArrayUtils.removeElements(
		params.get(HasFeatureCols.FEATURE_COLS),
		params.get(HasCategoricalCols.CATEGORICAL_COLS)
	);

	boolean noNumericCols = numericCols == null || numericCols.length == 0;
	if (noNumericCols) {
		// Nothing to discretize: return an operator carrying the model-data schema but no rows.
		// The single seed element only exists so that a DataSet can be created; the partition
		// function below discards it, leaving an empty, correctly-typed DataSet<Row>.
		QuantileDiscretizerModelDataConverter schemaSource = new QuantileDiscretizerModelDataConverter();
		return new DataSetWrapperBatchOp(
			MLEnvironmentFactory.get(envId)
				.getExecutionEnvironment()
				.fromElements(1)
				.mapPartition(new MapPartitionFunction<Integer, Row>() {
					private static final long serialVersionUID = 2328781103352773618L;

					@Override
					public void mapPartition(Iterable<Integer> values, Collector<Row> out) throws Exception {
						// intentionally emits nothing
					}
				}),
			schemaSource.getModelSchema().getFieldNames(),
			schemaSource.getModelSchema().getFieldTypes()
		).setMLEnvironmentId(envId);
	}

	BatchOperator<?> sampled = sample(input, params);
	BatchOperator<?> columnTrainer = new QuantileDiscretizerTrainBatchOp(
			new Params().set(ZERO_AS_MISSING, params.get(ZERO_AS_MISSING)))
		.setMLEnvironmentId(envId)
		.setSelectedCols(numericCols)
		.setNumBuckets(params.get(HasMaxBins.MAX_BINS));
	return sampled.linkTo(columnTrainer);
}
Usage of com.alibaba.alink.operator.batch.feature.QuantileDiscretizerTrainBatchOp in the Alink project by Alibaba.
Example from class PipelineSaveAndLoadTest, method test2:
/**
 * Trains and saves a standalone quantile-discretizer model, then fits and saves a three-stage pipeline
 * whose last stage reuses that saved model.
 *
 * <p>Both files are written with overwrite enabled. The first one is produced by {@code AkSinkBatchOp}
 * despite its {@code .csv} extension. NOTE(review): the paths are Unix-specific and may be shared with
 * other tests in this class, so they are kept as-is. The test has no assertions; it only checks that
 * training, fitting and saving complete without throwing.
 */
@Test
public void test2() throws Exception {
	final String quantileModelPath = "/tmp/model2.csv";

	CsvSourceBatchOp iris = new CsvSourceBatchOp()
		.setSchemaStr("sepal_length double, sepal_width double, petal_length double, petal_width double, category string")
		.setFilePath("https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/iris.csv");

	// Step 1: train a 2-bucket discretizer on petal_length and persist it.
	QuantileDiscretizerTrainBatchOp petalLengthTrainer = new QuantileDiscretizerTrainBatchOp()
		.setNumBuckets(2)
		.setSelectedCols("petal_length")
		.linkFrom(iris);
	AkSinkBatchOp quantileModelSink = new AkSinkBatchOp()
		.setFilePath(quantileModelPath)
		.setOverwriteSink(true);
	petalLengthTrainer.link(quantileModelSink);
	BatchOperator.execute();

	// Step 2: build a pipeline whose third stage loads the model saved in step 1, fit it, and save it.
	final String pipelineModelPath = "/tmp/model23424.csv";
	QuantileDiscretizer sepalLengthDiscretizer = new QuantileDiscretizer()
		.setNumBuckets(2)
		.setSelectedCols("sepal_length");
	Binarizer petalWidthBinarizer = new Binarizer()
		.setSelectedCol("petal_width")
		.setThreshold(1.0);
	AkSourceBatchOp savedQuantileModel = new AkSourceBatchOp().setFilePath(quantileModelPath);
	QuantileDiscretizerModel petalLengthModel = new QuantileDiscretizerModel()
		.setSelectedCols("petal_length")
		.setModelData(savedQuantileModel);

	PipelineModel fittedPipeline =
		new Pipeline(sepalLengthDiscretizer, petalWidthBinarizer, petalLengthModel).fit(iris);
	fittedPipeline.save(pipelineModelPath, true);
	BatchOperator.execute();
}
Aggregations