Usage of com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp in project Alink by Alibaba.
Class EqualWidthDiscretizerTest, method testException.
/**
 * Expects a RuntimeException while building EqualWidthDiscretizer training operators over a
 * 0..10 number sequence in column "col0".
 */
@Test
public void testException() throws Exception {
    thrown.expect(RuntimeException.class);

    NumSeqSourceBatchOp source = new NumSeqSourceBatchOp(0, 10, "col0");

    // Both a scalar bucket count and a two-entry bucket array are set for a single column.
    // NOTE(review): presumably this conflicting configuration is what raises the expected
    // RuntimeException — confirm; if so, nothing below this statement runs.
    EqualWidthDiscretizerTrainBatchOp conflictingOp = new EqualWidthDiscretizerTrainBatchOp()
        .setNumBuckets(5)
        .setNumBucketsArray(5, 4)
        .setSelectedCols("col0")
        .linkFrom(source);
    conflictingOp.lazyPrintModelInfo();

    // Bucket array only, one entry for the one selected column.
    EqualWidthDiscretizerTrainBatchOp arrayOnlyOp = new EqualWidthDiscretizerTrainBatchOp()
        .setNumBucketsArray(5)
        .setSelectedCols("col0")
        .linkFrom(source);

    // Print the cut points of "col0" once the model rows are collected.
    arrayOnlyOp.lazyCollect(new Consumer<List<Row>>() {
        @Override
        public void accept(List<Row> modelRows) {
            EqualWidthDiscretizerModelInfoBatchOp.EqualWidthDiscretizerModelInfo modelInfo =
                new EqualWidthDiscretizerModelInfoBatchOp.EqualWidthDiscretizerModelInfo(modelRows);
            System.out.println(Arrays.toString(modelInfo.getCutsArray("col0")));
        }
    });
}
Usage of com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp in project Alink by Alibaba.
Class EqualWidthDiscretizerTest, method test.
/**
 * Fits and applies a 3-bucket EqualWidthDiscretizer pipeline on the sequence 0..10 in "col0"
 * and checks that every input row comes out of the transform.
 *
 * <p>Any exception from the pipeline is rethrown as an AssertionError with the original
 * exception attached as its cause, so the test report shows the real failure.
 */
@Test
public void test() throws Exception {
    try {
        NumSeqSourceBatchOp numSeqSourceBatchOp = new NumSeqSourceBatchOp(0, 10, "col0");
        Pipeline pipeline = new Pipeline().add(
            new EqualWidthDiscretizer().setNumBuckets(3).enableLazyPrintModelInfo().setSelectedCols("col0"));
        int numRows = pipeline.fit(numSeqSourceBatchOp).transform(numSeqSourceBatchOp).collect().size();
        // The sequence 0..10 is inclusive, so 11 rows are expected (a 1:1 transform).
        Assert.assertEquals(11, numRows);
    } catch (Exception ex) {
        // Keep the cause instead of printing it and failing with a context-free message.
        throw new AssertionError("Should not throw exception here.", ex);
    }
}
Usage of com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp in project Alink by Alibaba.
Class QuantileDiscretizerTest, method train.
/**
 * Fits and applies a 2-bucket QuantileDiscretizer pipeline on the sequence 0..1000 in "col0"
 * and checks that all 1001 rows come out of the transform.
 *
 * <p>Any exception from the pipeline is rethrown as an AssertionError with the original
 * exception attached as its cause, so the test report shows the real failure.
 */
@Test
public void train() {
    try {
        NumSeqSourceBatchOp numSeqSourceBatchOp = new NumSeqSourceBatchOp(0, 1000, "col0");
        Pipeline pipeline = new Pipeline().add(
            new QuantileDiscretizer().setNumBuckets(2).setSelectedCols(new String[] { "col0" }).enableLazyPrintModelInfo());
        Assert.assertEquals(1001, pipeline.fit(numSeqSourceBatchOp).transform(numSeqSourceBatchOp).collect().size());
    } catch (Exception ex) {
        // Keep the cause instead of printing it and failing with a context-free message.
        throw new AssertionError("Should not throw exception here.", ex);
    }
}
Usage of com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp in project Alink by Alibaba.
Class BaseKerasSequentialTrainBatchOp, method linkFrom.
/**
 * Builds the training dataflow for a Keras Sequential model and sets this operator's output
 * table to the resulting model table.
 *
 * <p>Only {@code inputs[0]} is read as training data. For non-regression tasks:
 * <ul>
 *   <li>the distinct label values are collected into one sorted list;</li>
 *   <li>the input's label column is rewritten by {@code CommonUtils.mapLabelToIndex}
 *       (presumably to each label's index in that list — confirm);</li>
 *   <li>a one-column "count" table derived from the sorted labels is passed to the trainer as
 *       a second input.</li>
 * </ul>
 *
 * <p>Training itself is delegated to {@link TF2TableModelTrainBatchOp}, driven by a JSON map of
 * user parameters. The model table is then produced by {@code ConstructModelFlatMapFunction},
 * with the trained TF model (and, for classification, the sorted labels) broadcast into it.
 *
 * @param inputs upstream operators; only the first one is used
 * @return this operator, with its output table set to the model table
 */
@Override
public T linkFrom(BatchOperator<?>... inputs) {
BatchOperator<?> in = inputs[0];
Params params = getParams();
TaskType taskType = params.get(HasTaskType.TASK_TYPE);
boolean isReg = TaskType.REGRESSION.equals(taskType);
String tensorCol = getTensorCol();
String labelCol = getLabelCol();
TypeInformation<?> labelType = TableUtil.findColTypeWithAssertAndHint(in.getSchema(), labelCol);
// Both stay null for regression; they are only built and used on the classification path.
DataSet<List<Object>> sortedLabels = null;
BatchOperator<?> numLabelsOp = null;
if (!isReg) {
// Stage 1: de-duplicate labels within each partition.
// Stage 2: reduceGroup over the whole (ungrouped) DataSet merges them into one sorted list.
// Anonymous classes (not lambdas) keep the generic types visible to Flink's type extraction.
sortedLabels = in.select(labelCol).getDataSet().mapPartition(new MapPartitionFunction<Row, Object>() {
@Override
public void mapPartition(Iterable<Row> iterable, Collector<Object> collector) throws Exception {
Set<Object> distinctValue = new HashSet<>();
for (Row row : iterable) {
distinctValue.add(row.getField(0));
}
for (Object obj : distinctValue) {
collector.collect(obj);
}
}
}).reduceGroup(new GroupReduceFunction<Object, List<Object>>() {
@Override
public void reduce(Iterable<Object> iterable, Collector<List<Object>> collector) throws Exception {
// TreeSet orders by natural ordering, so labels must be mutually Comparable.
// NOTE(review): a null label would make TreeSet.add throw NullPointerException.
Set<Object> distinctValue = new TreeSet<>();
for (Object obj : iterable) {
distinctValue.add(obj);
}
collector.collect(new ArrayList<>(distinctValue));
}
});
in = CommonUtils.mapLabelToIndex(in, labelCol, sortedLabels);
// Single INT column "count" computed from the sorted labels by CountLabelsMapFunction
// (presumably the number of classes — confirm); fed to the trainer as its second input.
numLabelsOp = new TableSourceBatchOp(DataSetConversionUtil.toTable(getMLEnvironmentId(), sortedLabels.map(new CountLabelsMapFunction()), new String[] { "count" }, new TypeInformation<?>[] { Types.INT })).setMLEnvironmentId(getMLEnvironmentId());
}
Boolean removeCheckpointBeforeTraining = getRemoveCheckpointBeforeTraining();
if (null == removeCheckpointBeforeTraining) {
// default to clean checkpoint
removeCheckpointBeforeTraining = true;
}
// Model definition and training options, serialized to JSON for the Python main script.
Map<String, Object> modelConfig = new HashMap<>();
modelConfig.put("layers", getLayers());
Map<String, String> userParams = new HashMap<>();
if (removeCheckpointBeforeTraining) {
userParams.put(DLConstants.REMOVE_CHECKPOINT_BEFORE_TRAINING, "true");
}
userParams.put("tensor_cols", JsonConverter.toJson(new String[] { tensorCol }));
userParams.put("label_col", labelCol);
userParams.put("label_type", "float");
userParams.put("batch_size", String.valueOf(getBatchSize()));
userParams.put("num_epochs", String.valueOf(getNumEpochs()));
userParams.put("model_config", JsonConverter.toJson(modelConfig));
userParams.put("optimizer", getOptimizer());
if (!StringUtils.isNullOrWhitespaceOnly(getCheckpointFilePath())) {
userParams.put("model_dir", getCheckpointFilePath());
}
// Flag set only when the Flink execution environment runs with parallelism 1.
ExecutionEnvironment env = MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment();
if (env.getParallelism() == 1) {
userParams.put("ALINK:ONLY_ONE_WORKER", "true");
}
userParams.put("validation_split", String.valueOf(getValidationSplit()));
userParams.put("save_best_only", String.valueOf(getSaveBestOnly()));
userParams.put("best_exporter_metric", getBestMetric());
userParams.put("save_checkpoints_epochs", String.valueOf(getSaveCheckpointsEpochs()));
// save_checkpoints_secs is forwarded only when explicitly present in the params.
if (params.contains(BaseKerasSequentialTrainParams.SAVE_CHECKPOINTS_SECS)) {
userParams.put("save_checkpoints_secs", String.valueOf(getSaveCheckpointsSecs()));
}
TF2TableModelTrainBatchOp trainBatchOp = new TF2TableModelTrainBatchOp(params).setSelectedCols(tensorCol, labelCol).setUserFiles(RES_PY_FILES).setMainScriptFile(MAIN_SCRIPT_FILE_NAME).setUserParams(JsonConverter.toJson(userParams)).setIntraOpParallelism(getIntraOpParallelism()).setNumPSs(getNumPSs()).setNumWorkers(getNumWorkers()).setPythonEnv(params.get(HasPythonEnv.PYTHON_ENV));
// Classification also passes the label-count table as a second input.
if (isReg) {
trainBatchOp = trainBatchOp.linkFrom(in);
} else {
trainBatchOp = trainBatchOp.linkFrom(in, numLabelsOp);
}
String tfOutputSignatureDef = getTfOutputSignatureDef(taskType);
// A NumSeq source from 0 to 0 is only a driver dataset (presumably a single row — confirm) so
// ConstructModelFlatMapFunction runs once, with the trained TF model broadcast into it.
FlatMapOperator<Row, Row> constructModelFlatMapOperator = new NumSeqSourceBatchOp().setFrom(0).setTo(0).setMLEnvironmentId(getMLEnvironmentId()).getDataSet().flatMap(new ConstructModelFlatMapFunction(params, new String[] { tensorCol }, tfOutputSignatureDef, TF_OUTPUT_SIGNATURE_TYPE, null, true)).withBroadcastSet(trainBatchOp.getDataSet(), CommonUtils.TF_MODEL_BC_NAME);
BatchOperator<?> modelOp;
if (isReg) {
DataSet<Row> modelDataSet = constructModelFlatMapOperator;
modelOp = new TableSourceBatchOp(DataSetConversionUtil.toTable(getMLEnvironmentId(), modelDataSet, new TFTableModelRegressionModelDataConverter(labelType).getModelSchema())).setMLEnvironmentId(getMLEnvironmentId());
} else {
// Classification additionally broadcasts the sorted label list into the model builder
// (presumably so predicted indices can be mapped back to labels — confirm).
DataSet<Row> modelDataSet = constructModelFlatMapOperator.withBroadcastSet(sortedLabels, CommonUtils.SORTED_LABELS_BC_NAME);
modelOp = new TableSourceBatchOp(DataSetConversionUtil.toTable(getMLEnvironmentId(), modelDataSet, new TFTableModelClassificationModelDataConverter(labelType).getModelSchema())).setMLEnvironmentId(getMLEnvironmentId());
}
this.setOutputTable(modelOp.getOutputTable());
// Unchecked self-type cast: T is this operator's own subclass type.
return (T) this;
}
Usage of com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp in project Alink by Alibaba.
Class OperatorConstructorTest, method testConstructor.
/**
 * Instantiates {@code clazz} through each public constructor whose signature this test knows
 * how to satisfy, and asserts that every created instance returns non-null params.
 *
 * <p>Supported signatures:
 * <ul>
 *   <li>{@code ()}</li>
 *   <li>{@code (Params)}</li>
 *   <li>{@code (BatchOperator)}</li>
 *   <li>{@code (BatchOperator, Params)}</li>
 * </ul>
 * Wherever a {@code BatchOperator} is required, a {@code NumSeqSourceBatchOp(1)} is passed as a
 * stand-in model. Constructors with any other signature are skipped.
 *
 * @param clazz the operator class under test
 * @param <T> the operator type
 * @throws AssertionError if a constructor cannot be invoked or throws (the constructor's own
 *     exception is attached as the cause), or if an instance returns null from getParams()
 */
@SuppressWarnings({ "unchecked", "rawtypes" })
public <T extends WithParams> void testConstructor(Class<T> clazz) {
    for (Constructor<?> constructor : clazz.getConstructors()) {
        Class<?>[] paramTypes = constructor.getParameterTypes();
        int nParams = paramTypes.length;
        T instance = null;
        try {
            if (nParams == 0) {
                instance = (T) constructor.newInstance();
            } else if (nParams == 1 && paramTypes[0].equals(Params.class)) {
                instance = (T) constructor.newInstance(new Params());
            } else if (nParams == 1 && paramTypes[0].equals(BatchOperator.class)) {
                // fake model
                BatchOperator model = new NumSeqSourceBatchOp(1);
                instance = (T) constructor.newInstance(model);
            } else if (nParams == 2
                && paramTypes[0].equals(BatchOperator.class)
                && paramTypes[1].equals(Params.class)) {
                // fake model
                BatchOperator model = new NumSeqSourceBatchOp(1);
                instance = (T) constructor.newInstance(model, new Params());
            }
            // Any other signature is intentionally not exercised by this test.
        } catch (java.lang.reflect.InvocationTargetException ex) {
            // newInstance wraps whatever the constructor threw; InvocationTargetException's own
            // toString() carries no message, so surface the wrapped cause instead.
            Throwable cause = ex.getCause();
            throw new AssertionError("Constructor " + constructor + " threw: " + cause, cause);
        } catch (Exception ex) {
            throw new AssertionError("Could not invoke constructor " + constructor + ": " + ex, ex);
        }
        if (null != instance) {
            Assert.assertNotNull(instance.getParams());
        }
    }
}
Aggregations