Example 1 with NumSeqSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp in project Alink by Alibaba.

From class EqualWidthDiscretizerTest, method testException.

@Test
public void testException() throws Exception {
    thrown.expect(RuntimeException.class);
    NumSeqSourceBatchOp numSeqSourceBatchOp = new NumSeqSourceBatchOp(0, 10, "col0");
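    // both a single bucket count and a two-element bucket array are configured for one selected column below,
    // which is the setup this test expects to fail with a RuntimeException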
    EqualWidthDiscretizerTrainBatchOp op = new EqualWidthDiscretizerTrainBatchOp()
        .setNumBuckets(5)
        .setNumBucketsArray(5, 4)
        .setSelectedCols("col0")
        .linkFrom(numSeqSourceBatchOp);
    op.lazyPrintModelInfo();
    op = new EqualWidthDiscretizerTrainBatchOp()
        .setNumBucketsArray(5)
        .setSelectedCols("col0")
        .linkFrom(numSeqSourceBatchOp);
    op.lazyCollect(new Consumer<List<Row>>() {

        @Override
        public void accept(List<Row> rows) {
            System.out.println(Arrays.toString(
                new EqualWidthDiscretizerModelInfoBatchOp.EqualWidthDiscretizerModelInfo(rows)
                    .getCutsArray("col0")));
        }
    });
}
Also used : NumSeqSourceBatchOp(com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp) EqualWidthDiscretizerTrainBatchOp(com.alibaba.alink.operator.batch.feature.EqualWidthDiscretizerTrainBatchOp) List(java.util.List) Row(org.apache.flink.types.Row) Test(org.junit.Test)
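
For context, a minimal standalone use of NumSeqSourceBatchOp might look like the sketch below (a sketch only, assuming a default local MLEnvironment; print() is expected to trigger the batch job and dump the generated rows):

NumSeqSourceBatchOp numSeq = new NumSeqSourceBatchOp(0, 10, "col0");
// a single column "col0" holding the sequence of numbers from 0 to 10
numSeq.print();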

Example 2 with NumSeqSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp in project Alink by Alibaba.

From class EqualWidthDiscretizerTest, method test.

@Test
public void test() throws Exception {
    try {
        NumSeqSourceBatchOp numSeqSourceBatchOp = new NumSeqSourceBatchOp(0, 10, "col0");
        Pipeline pipeline = new Pipeline().add(
            new EqualWidthDiscretizer()
                .setNumBuckets(3)
                .enableLazyPrintModelInfo()
                .setSelectedCols("col0"));
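        // collect() triggers execution of the batch job; the model info requested via enableLazyPrintModelInfo is printed when the job runs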
        pipeline.fit(numSeqSourceBatchOp).transform(numSeqSourceBatchOp).collect();
    } catch (Exception ex) {
        ex.printStackTrace();
        Assert.fail("Should not throw exception here.");
    }
}
Also used : NumSeqSourceBatchOp(com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp) ExpectedException(org.junit.rules.ExpectedException) Pipeline(com.alibaba.alink.pipeline.Pipeline) Test(org.junit.Test)

Example 3 with NumSeqSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp in project Alink by Alibaba.

From class QuantileDiscretizerTest, method train.

@Test
public void train() {
    try {
        NumSeqSourceBatchOp numSeqSourceBatchOp = new NumSeqSourceBatchOp(0, 1000, "col0");
        Pipeline pipeline = new Pipeline().add(
            new QuantileDiscretizer()
                .setNumBuckets(2)
                .setSelectedCols(new String[] { "col0" })
                .enableLazyPrintModelInfo());
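        // the sequence 0..1000 is inclusive, so 1001 transformed rows are expected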
        Assert.assertEquals(1001,
            pipeline.fit(numSeqSourceBatchOp).transform(numSeqSourceBatchOp).collect().size());
    } catch (Exception ex) {
        ex.printStackTrace();
        Assert.fail("Should not throw exception here.");
    }
}
Also used : NumSeqSourceBatchOp(com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp) Pipeline(com.alibaba.alink.pipeline.Pipeline) Test(org.junit.Test)

Example 4 with NumSeqSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp in project Alink by Alibaba.

From class BaseKerasSequentialTrainBatchOp, method linkFrom.

@Override
public T linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = inputs[0];
    Params params = getParams();
    TaskType taskType = params.get(HasTaskType.TASK_TYPE);
    boolean isReg = TaskType.REGRESSION.equals(taskType);
    String tensorCol = getTensorCol();
    String labelCol = getLabelCol();
    TypeInformation<?> labelType = TableUtil.findColTypeWithAssertAndHint(in.getSchema(), labelCol);
    DataSet<List<Object>> sortedLabels = null;
    BatchOperator<?> numLabelsOp = null;
    if (!isReg) {
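        // for classification: collect the distinct label values per partition, then merge them into one sorted list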
        sortedLabels = in.select(labelCol).getDataSet().mapPartition(new MapPartitionFunction<Row, Object>() {

            @Override
            public void mapPartition(Iterable<Row> iterable, Collector<Object> collector) throws Exception {
                Set<Object> distinctValue = new HashSet<>();
                for (Row row : iterable) {
                    distinctValue.add(row.getField(0));
                }
                for (Object obj : distinctValue) {
                    collector.collect(obj);
                }
            }
        }).reduceGroup(new GroupReduceFunction<Object, List<Object>>() {

            @Override
            public void reduce(Iterable<Object> iterable, Collector<List<Object>> collector) throws Exception {
                Set<Object> distinctValue = new TreeSet<>();
                for (Object obj : iterable) {
                    distinctValue.add(obj);
                }
                collector.collect(new ArrayList<>(distinctValue));
            }
        });
        in = CommonUtils.mapLabelToIndex(in, labelCol, sortedLabels);
        numLabelsOp = new TableSourceBatchOp(
            DataSetConversionUtil.toTable(
                getMLEnvironmentId(),
                sortedLabels.map(new CountLabelsMapFunction()),
                new String[] { "count" },
                new TypeInformation<?>[] { Types.INT }))
            .setMLEnvironmentId(getMLEnvironmentId());
    }
    Boolean removeCheckpointBeforeTraining = getRemoveCheckpointBeforeTraining();
    if (null == removeCheckpointBeforeTraining) {
        // default to clean checkpoint
        removeCheckpointBeforeTraining = true;
    }
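    // assemble the Keras layer config and the training hyper-parameters that are serialized as JSON for the Python-side trainer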
    Map<String, Object> modelConfig = new HashMap<>();
    modelConfig.put("layers", getLayers());
    Map<String, String> userParams = new HashMap<>();
    if (removeCheckpointBeforeTraining) {
        userParams.put(DLConstants.REMOVE_CHECKPOINT_BEFORE_TRAINING, "true");
    }
    userParams.put("tensor_cols", JsonConverter.toJson(new String[] { tensorCol }));
    userParams.put("label_col", labelCol);
    userParams.put("label_type", "float");
    userParams.put("batch_size", String.valueOf(getBatchSize()));
    userParams.put("num_epochs", String.valueOf(getNumEpochs()));
    userParams.put("model_config", JsonConverter.toJson(modelConfig));
    userParams.put("optimizer", getOptimizer());
    if (!StringUtils.isNullOrWhitespaceOnly(getCheckpointFilePath())) {
        userParams.put("model_dir", getCheckpointFilePath());
    }
    ExecutionEnvironment env = MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment();
    if (env.getParallelism() == 1) {
        userParams.put("ALINK:ONLY_ONE_WORKER", "true");
    }
    userParams.put("validation_split", String.valueOf(getValidationSplit()));
    userParams.put("save_best_only", String.valueOf(getSaveBestOnly()));
    userParams.put("best_exporter_metric", getBestMetric());
    userParams.put("save_checkpoints_epochs", String.valueOf(getSaveCheckpointsEpochs()));
    if (params.contains(BaseKerasSequentialTrainParams.SAVE_CHECKPOINTS_SECS)) {
        userParams.put("save_checkpoints_secs", String.valueOf(getSaveCheckpointsSecs()));
    }
    TF2TableModelTrainBatchOp trainBatchOp = new TF2TableModelTrainBatchOp(params)
        .setSelectedCols(tensorCol, labelCol)
        .setUserFiles(RES_PY_FILES)
        .setMainScriptFile(MAIN_SCRIPT_FILE_NAME)
        .setUserParams(JsonConverter.toJson(userParams))
        .setIntraOpParallelism(getIntraOpParallelism())
        .setNumPSs(getNumPSs())
        .setNumWorkers(getNumWorkers())
        .setPythonEnv(params.get(HasPythonEnv.PYTHON_ENV));
    if (isReg) {
        trainBatchOp = trainBatchOp.linkFrom(in);
    } else {
        trainBatchOp = trainBatchOp.linkFrom(in, numLabelsOp);
    }
    String tfOutputSignatureDef = getTfOutputSignatureDef(taskType);
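    // a one-row NumSeqSourceBatchOp serves as a dummy driver; the trained TF model is attached via broadcast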
    FlatMapOperator<Row, Row> constructModelFlatMapOperator = new NumSeqSourceBatchOp()
        .setFrom(0)
        .setTo(0)
        .setMLEnvironmentId(getMLEnvironmentId())
        .getDataSet()
        .flatMap(new ConstructModelFlatMapFunction(
            params, new String[] { tensorCol }, tfOutputSignatureDef, TF_OUTPUT_SIGNATURE_TYPE, null, true))
        .withBroadcastSet(trainBatchOp.getDataSet(), CommonUtils.TF_MODEL_BC_NAME);
    BatchOperator<?> modelOp;
    if (isReg) {
        DataSet<Row> modelDataSet = constructModelFlatMapOperator;
        modelOp = new TableSourceBatchOp(
            DataSetConversionUtil.toTable(
                getMLEnvironmentId(),
                modelDataSet,
                new TFTableModelRegressionModelDataConverter(labelType).getModelSchema()))
            .setMLEnvironmentId(getMLEnvironmentId());
    } else {
        DataSet<Row> modelDataSet = constructModelFlatMapOperator.withBroadcastSet(sortedLabels, CommonUtils.SORTED_LABELS_BC_NAME);
        modelOp = new TableSourceBatchOp(
            DataSetConversionUtil.toTable(
                getMLEnvironmentId(),
                modelDataSet,
                new TFTableModelClassificationModelDataConverter(labelType).getModelSchema()))
            .setMLEnvironmentId(getMLEnvironmentId());
    }
    this.setOutputTable(modelOp.getOutputTable());
    return (T) this;
}
Also used : ExecutionEnvironment(org.apache.flink.api.java.ExecutionEnvironment) ConstructModelFlatMapFunction(com.alibaba.alink.operator.common.tensorflow.CommonUtils.ConstructModelFlatMapFunction) DataSet(org.apache.flink.api.java.DataSet) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) HasTaskType(com.alibaba.alink.params.dl.HasTaskType) Collector(org.apache.flink.util.Collector) NumSeqSourceBatchOp(com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp) GroupReduceFunction(org.apache.flink.api.common.functions.GroupReduceFunction) BaseKerasSequentialTrainParams(com.alibaba.alink.params.tensorflow.kerasequential.BaseKerasSequentialTrainParams) Params(org.apache.flink.ml.api.misc.param.Params) TFTableModelRegressionModelDataConverter(com.alibaba.alink.operator.common.regression.tensorflow.TFTableModelRegressionModelDataConverter) TFTableModelClassificationModelDataConverter(com.alibaba.alink.operator.common.classification.tensorflow.TFTableModelClassificationModelDataConverter) TF2TableModelTrainBatchOp(com.alibaba.alink.operator.batch.tensorflow.TF2TableModelTrainBatchOp) CountLabelsMapFunction(com.alibaba.alink.operator.common.tensorflow.CommonUtils.CountLabelsMapFunction) Row(org.apache.flink.types.Row)

Example 5 with NumSeqSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp in project Alink by Alibaba.

From class OperatorConstructorTest, method testConstructor.

@SuppressWarnings({ "unchecked", "rawtypes" })
public <T extends WithParams> void testConstructor(Class<T> clazz) {
    Constructor<?>[] constructors = clazz.getConstructors();
    for (Constructor<?> constructor : constructors) {
        Parameter[] parameters = constructor.getParameters();
        int nParams = parameters.length;
        T instance = null;
        try {
            if (nParams == 0) {
                instance = (T) constructor.newInstance();
            } else if ((nParams == 1) && (parameters[0].getType().equals(Params.class))) {
                Params params = new Params();
                instance = (T) constructor.newInstance(params);
            } else if ((nParams == 1) && (parameters[0].getType().equals(BatchOperator.class))) {
                // fake model
                BatchOperator model = new NumSeqSourceBatchOp(1);
                instance = (T) constructor.newInstance(model);
            } else if ((nParams == 2) && (parameters[0].getType().equals(BatchOperator.class)) && (parameters[1].getType().equals(Params.class))) {
                // fake model
                BatchOperator model = new NumSeqSourceBatchOp(1);
                Params params = new Params();
                instance = (T) constructor.newInstance(model, params);
            } else {
            // System.out.println(clazz.getCanonicalName());
            }
        } catch (Exception ex) {
            Assert.fail(ex.toString());
        }
        if (null != instance) {
            Assert.assertNotNull(instance.getParams());
        }
    }
}
Also used : NumSeqSourceBatchOp(com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp) Constructor(java.lang.reflect.Constructor) Parameter(java.lang.reflect.Parameter) WithParams(org.apache.flink.ml.api.misc.param.WithParams) Params(org.apache.flink.ml.api.misc.param.Params) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator)

Aggregations

NumSeqSourceBatchOp (com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp)5
Test (org.junit.Test)3
Pipeline (com.alibaba.alink.pipeline.Pipeline)2
Params (org.apache.flink.ml.api.misc.param.Params)2
Row (org.apache.flink.types.Row)2
BatchOperator (com.alibaba.alink.operator.batch.BatchOperator)1
EqualWidthDiscretizerTrainBatchOp (com.alibaba.alink.operator.batch.feature.EqualWidthDiscretizerTrainBatchOp)1
TableSourceBatchOp (com.alibaba.alink.operator.batch.source.TableSourceBatchOp)1
TF2TableModelTrainBatchOp (com.alibaba.alink.operator.batch.tensorflow.TF2TableModelTrainBatchOp)1
TFTableModelClassificationModelDataConverter (com.alibaba.alink.operator.common.classification.tensorflow.TFTableModelClassificationModelDataConverter)1
TFTableModelRegressionModelDataConverter (com.alibaba.alink.operator.common.regression.tensorflow.TFTableModelRegressionModelDataConverter)1
ConstructModelFlatMapFunction (com.alibaba.alink.operator.common.tensorflow.CommonUtils.ConstructModelFlatMapFunction)1
CountLabelsMapFunction (com.alibaba.alink.operator.common.tensorflow.CommonUtils.CountLabelsMapFunction)1
HasTaskType (com.alibaba.alink.params.dl.HasTaskType)1
BaseKerasSequentialTrainParams (com.alibaba.alink.params.tensorflow.kerasequential.BaseKerasSequentialTrainParams)1
Constructor (java.lang.reflect.Constructor)1
Parameter (java.lang.reflect.Parameter)1
List (java.util.List)1
GroupReduceFunction (org.apache.flink.api.common.functions.GroupReduceFunction)1
TypeInformation (org.apache.flink.api.common.typeinfo.TypeInformation)1