Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.
Class GeoKMeansTest, method before().
@Before
public void before() {
	Row[] rows = new Row[] {
		Row.of(0, 0, 0), Row.of(1, 8, 8), Row.of(2, 1, 2),
		Row.of(3, 9, 10), Row.of(4, 3, 1), Row.of(5, 10, 7)
	};
	inputBatchOp = new TableSourceBatchOp(
		MLEnvironmentFactory.getDefault().createBatchTable(rows, new String[] {"id", "f0", "f1"}));
	inputStreamOp = new TableSourceStreamOp(
		MLEnvironmentFactory.getDefault().createStreamTable(rows, new String[] {"id", "f0", "f1"}));
	expectedPrediction = new double[] {185.31, 117.08, 117.18, 183.04, 185.32, 183.70};
}
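For reference, TableSourceBatchOp merely wraps an existing Flink Table as a batch source; when the data already lives in a Row array, Alink's MemSourceBatchOp builds the Table and the source in one step. A minimal sketch of the same fixture (row data shortened) under that assumption:

// Sketch: the same in-memory fixture via MemSourceBatchOp,
// which wraps the rows into a Table internally.
Row[] rows = new Row[] {Row.of(0, 0, 0), Row.of(1, 8, 8), Row.of(2, 1, 2)};
BatchOperator<?> source = new MemSourceBatchOp(rows, new String[] {"id", "f0", "f1"});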
Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.
Class CorrelationBatchOp, method linkFrom().
@Override
public CorrelationBatchOp linkFrom(BatchOperator<?>... inputs) {
	BatchOperator<?> in = checkAndGetFirst(inputs);
	String[] selectedColNames = this.getParams().get(SELECTED_COLS);
	if (selectedColNames == null) {
		selectedColNames = in.getColNames();
	}
	// The selected columns must be numeric (double or bigint).
	TableUtil.assertNumericalCols(in.getSchema(), selectedColNames);
	Method corrType = getMethod();
	if (Method.PEARSON == corrType) {
		DataSet<Tuple2<TableSummary, CorrelationResult>> srt =
			StatisticsHelper.pearsonCorrelation(in, selectedColNames);
		DataSet<Row> result = srt.flatMap(
			new FlatMapFunction<Tuple2<TableSummary, CorrelationResult>, Row>() {
				private static final long serialVersionUID = -4498296161046449646L;

				@Override
				public void flatMap(Tuple2<TableSummary, CorrelationResult> summary, Collector<Row> collector) {
					new CorrelationDataConverter().save(summary.f1, collector);
				}
			});
		this.setOutput(result, new CorrelationDataConverter().getModelSchema());
	} else {
		// Spearman correlation: rank the data, then compute Pearson correlation on the ranks.
		DataSet<Row> data = inputs[0].select(selectedColNames).getDataSet();
		DataSet<Row> rank = SpearmanCorrelation.calcRank(data, false);
		TypeInformation[] colTypes = new TypeInformation[selectedColNames.length];
		for (int i = 0; i < colTypes.length; i++) {
			colTypes[i] = Types.DOUBLE;
		}
		BatchOperator<?> rankOp = new TableSourceBatchOp(
				DataSetConversionUtil.toTable(getMLEnvironmentId(), rank, selectedColNames, colTypes))
			.setMLEnvironmentId(getMLEnvironmentId());
		CorrelationBatchOp corrBatchOp = new CorrelationBatchOp()
			.setMLEnvironmentId(getMLEnvironmentId())
			.setSelectedCols(selectedColNames);
		rankOp.link(corrBatchOp);
		this.setOutput(corrBatchOp.getDataSet(), corrBatchOp.getSchema());
	}
	return this;
}
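From a caller's perspective, the else branch means Spearman correlation is computed as Pearson correlation over ranks, handled transparently inside the operator. A hedged usage sketch (the source operator and column names "f0"/"f1" are placeholders, and the string-valued setMethod setter is assumed to follow Alink's usual param conventions):

// Sketch: correlation matrix over two numeric columns.
CorrelationBatchOp corr = new CorrelationBatchOp()
	.setSelectedCols("f0", "f1")   // defaults to all columns when unset
	.setMethod("SPEARMAN");        // PEARSON is the default
source.link(corr).print();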
Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.
Class VectorCorrelationBatchOp, method linkFrom().
@Override
public VectorCorrelationBatchOp linkFrom(BatchOperator<?>... inputs) {
	BatchOperator<?> in = checkAndGetFirst(inputs);
	String vectorColName = getSelectedCol();
	Method corrType = getMethod();
	if (Method.PEARSON == corrType) {
		DataSet<Tuple2<BaseVectorSummary, CorrelationResult>> srt =
			StatisticsHelper.vectorPearsonCorrelation(in, vectorColName);
		DataSet<Row> result = srt.flatMap(
			new FlatMapFunction<Tuple2<BaseVectorSummary, CorrelationResult>, Row>() {
				private static final long serialVersionUID = 2134644397476490118L;

				@Override
				public void flatMap(Tuple2<BaseVectorSummary, CorrelationResult> srt, Collector<Row> collector) throws Exception {
					new CorrelationDataConverter().save(srt.f1, collector);
				}
			});
		this.setOutput(result, new CorrelationDataConverter().getModelSchema());
	} else {
		// Spearman correlation: rank the vector data, then compute Pearson correlation on the ranks.
		DataSet<Row> data = StatisticsHelper.transformToColumns(in, null, vectorColName, null);
		DataSet<Row> rank = SpearmanCorrelation.calcRank(data, true);
		BatchOperator<?> rankOp = new TableSourceBatchOp(
				DataSetConversionUtil.toTable(getMLEnvironmentId(), rank,
					new String[] {"col"}, new TypeInformation[] {Types.STRING}))
			.setMLEnvironmentId(getMLEnvironmentId());
		VectorCorrelationBatchOp corrBatchOp = new VectorCorrelationBatchOp()
			.setMLEnvironmentId(getMLEnvironmentId())
			.setSelectedCol("col");
		rankOp.link(corrBatchOp);
		this.setOutput(corrBatchOp.getDataSet(), corrBatchOp.getSchema());
	}
	return this;
}
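The vector variant is the same algorithm applied to the components of a single vector column; note how the Spearman branch writes the ranked vectors into a column named "col" before relinking. A hypothetical usage sketch (the vector column name "vec" is a placeholder):

// Sketch: correlation over the components of a vector column.
VectorCorrelationBatchOp vecCorr = new VectorCorrelationBatchOp()
	.setSelectedCol("vec")
	.setMethod("SPEARMAN");
source.link(vecCorr).print();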
Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.
Class PipelineModel, method collectLocalPredictor().
@Override
public LocalPredictor collectLocalPredictor(TableSchema inputSchema) throws Exception {
	if (params.get(ModelStreamScanParams.MODEL_STREAM_FILE_PATH) != null) {
		// Model-stream case: serialize the pipeline stages, extend the model rows
		// with the model-stream columns, and load the result into a PipelineModelMapper.
		BatchOperator<?> modelSave =
			ModelExporterUtils.serializePipelineStages(Arrays.asList(transformers), params);
		TableSchema extendSchema = getOutSchema(this, inputSchema);
		BatchOperator<?> model = new TableSourceBatchOp(
			DataSetConversionUtil.toTable(
				modelSave.getMLEnvironmentId(),
				modelSave.getDataSet().map(
					new PipelineModelMapper.ExtendPipelineModelRow(extendSchema.getFieldNames().length + 1)),
				PipelineModelMapper.getExtendModelSchema(
					modelSave.getSchema(), extendSchema.getFieldNames(), extendSchema.getFieldTypes())));
		List<Row> modelRows = model.collect();
		ModelMapper mapper = new PipelineModelMapper(model.getSchema(), inputSchema, this.params);
		mapper.loadModel(modelRows);
		return new LocalPredictor(mapper);
	}
	if (null == transformers || transformers.length == 0) {
		throw new RuntimeException("PipelineModel is empty.");
	}
	// Collect the model data of every model-backed transformer in one pass.
	List<BatchOperator<?>> allModelData = new ArrayList<>();
	for (TransformerBase<?> transformer : transformers) {
		if (!(transformer instanceof LocalPredictable)) {
			throw new RuntimeException(transformer.getClass().toString() + " does not support local prediction.");
		}
		if (transformer instanceof ModelBase) {
			allModelData.add(((ModelBase<?>) transformer).getModelData());
		}
	}
	List<List<Row>> allModelDataRows;
	if (!allModelData.isEmpty()) {
		allModelDataRows = BatchOperator.collect(allModelData.toArray(new BatchOperator<?>[0]));
	} else {
		allModelDataRows = new ArrayList<>();
	}
	// Build one Mapper per stage, threading each stage's output schema into the next.
	TableSchema schema = inputSchema;
	int numMapperModel = 0;
	List<Mapper> mappers = new ArrayList<>();
	for (TransformerBase<?> transformer : transformers) {
		Mapper mapper;
		if (transformer instanceof MapModel) {
			mapper = ModelExporterUtils.createMapperFromStage(
				transformer, ((MapModel<?>) transformer).modelData.getSchema(), schema,
				allModelDataRows.get(numMapperModel));
			numMapperModel++;
		} else if (transformer instanceof BaseRecommender) {
			mapper = ModelExporterUtils.createMapperFromStage(
				transformer, ((BaseRecommender<?>) transformer).modelData.getSchema(), schema,
				allModelDataRows.get(numMapperModel));
			numMapperModel++;
		} else {
			mapper = ModelExporterUtils.createMapperFromStage(transformer, null, schema, null);
		}
		mappers.add(mapper);
		schema = mapper.getOutputSchema();
	}
	return new LocalPredictor(mappers.toArray(new Mapper[0]));
}
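The point of collectLocalPredictor is that prediction then runs in-process, row by row, with no Flink job. A hedged usage sketch (the pipelineModel variable and schema string are placeholders; CsvUtil.schemaStr2Schema is assumed as the inverse of the schema2SchemaStr call seen elsewhere on this page):

// Sketch: collect a local predictor and score a single row in-process.
LocalPredictor predictor = pipelineModel.collectLocalPredictor(
	CsvUtil.schemaStr2Schema("id int, f0 int, f1 int"));
Row result = predictor.map(Row.of(0, 1, 2));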
Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.
Class BaseEasyTransferTrainBatchOp, method linkFrom().
@Override
public T linkFrom(BatchOperator<?>... inputs) {
	BatchOperator<?> in = inputs[0];
	Params params = getParams();
	TaskType taskType = params.get(HasTaskType.TASK_TYPE);
	String labelCol = params.get(HasLabelCol.LABEL_COL);
	TypeInformation<?> labelType = TableUtil.findColTypeWithAssertAndHint(in.getSchema(), labelCol);
	DataSet<List<Object>> sortedLabels = null;
	if (TaskType.CLASSIFICATION.equals(taskType)) {
		// For classification, map the distinct labels to integer indices.
		sortedLabels = in.select(labelCol).distinct().getDataSet()
			.reduceGroup(new SortLabelsReduceGroupFunction());
		in = mapLabelToIntIndex(in, labelCol, sortedLabels);
	}
	// Tokenize the input with a BERT tokenizer and keep the preprocessing
	// pipeline so it can be packaged with the trained model.
	BertTokenizer bertTokenizer = new BertTokenizer(params.clone())
		.set(HasMaxSeqLengthDefaultAsNull.MAX_SEQ_LENGTH, params.get(HasMaxSeqLength.MAX_SEQ_LENGTH));
	PipelineModel preprocessPipelineModel = new PipelineModel(bertTokenizer);
	in = preprocessPipelineModel.transform(in);
	BatchOperator<?> preprocessPipelineModelOp = preprocessPipelineModel.save();
	String preprocessPipelineModelSchemaStr = CsvUtil.schema2SchemaStr(preprocessPipelineModelOp.getSchema());
	Map<String, String> userParams = new HashMap<>();
	String bertModelName = params.get(HasBertModelName.BERT_MODEL_NAME);
	String bertModelCkptPath =
		params.contains(HasModelPath.MODEL_PATH) && (null != params.get(HasModelPath.MODEL_PATH))
			? params.get(HasModelPath.MODEL_PATH)
			: BertResources.getBertModelCkpt(bertModelName);
	String checkpointFilePath = params.get(HasCheckpointFilePathDefaultAsNull.CHECKPOINT_FILE_PATH);
	if (!StringUtils.isNullOrWhitespaceOnly(checkpointFilePath)) {
		userParams.put("model_dir", checkpointFilePath);
	}
	ExternalFilesConfig externalFilesConfig = params.contains(HasUserFiles.USER_FILES)
		? ExternalFilesConfig.fromJson(params.get(HasUserFiles.USER_FILES))
		: new ExternalFilesConfig();
	if (PythonFileUtils.isLocalFile(bertModelCkptPath)) {
		// Should be a directory.
		userParams.put("pretrained_ckpt_path", bertModelCkptPath.substring("file://".length()));
	} else {
		externalFilesConfig.addFilePaths(bertModelCkptPath);
		userParams.put("pretrained_ckpt_path", PythonFileUtils.getCompressedFileName(bertModelCkptPath));
	}
	Map<String, Map<String, Object>> config = getConfig(getParams(), false);
	String configJson = JsonConverter.toJson(config);
	LOG.info("EasyTransfer config: {}", configJson);
	if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
		System.out.println("EasyTransfer config: " + configJson);
	}
	BertTaskName taskName = params.get(HasTaskName.TASK_NAME);
	userParams.put("app_name", taskName.name());
	EasyTransferConfigTrainBatchOp trainBatchOp = new EasyTransferConfigTrainBatchOp()
		.setSelectedCols(ArrayUtils.add(SAFE_MODEL_INPUTS, labelCol))
		.setConfigJson(configJson)
		.setUserFiles(externalFilesConfig)
		.setUserParams(JsonConverter.toJson(userParams))
		.setNumWorkers(params.get(HasNumWorkersDefaultAsNull.NUM_WORKERS))
		.setNumPSs(params.get(HasNumPssDefaultAsNull.NUM_PSS))
		.setPythonEnv(params.get(HasPythonEnv.PYTHON_ENV))
		.setIntraOpParallelism(params.get(HasIntraOpParallelism.INTRA_OP_PARALLELISM))
		.setMLEnvironmentId(getMLEnvironmentId());
	BatchOperator<?>[] tfInputs = new BatchOperator<?>[inputs.length];
	tfInputs[0] = in;
	System.arraycopy(inputs, 1, tfInputs, 1, inputs.length - 1);
	BatchOperator<?> tfModel = trainBatchOp.linkFrom(tfInputs);
	String tfOutputSignatureDef = EasyTransferUtils.getTfOutputSignatureDef(taskType);
	// Gather all model rows onto a single partition and assemble the final model there.
	MapPartitionOperator<Row, Row> constructModelMapPartitionOperator = tfModel.getDataSet()
		.partitionCustom(new Partitioner<Long>() {
			@Override
			public int partition(Long key, int numPartitions) {
				return 0;
			}
		}, 0)
		.mapPartition(new ConstructModelMapPartitionFunction(
			params, SAFE_MODEL_INPUTS, tfOutputSignatureDef, TF_OUTPUT_SIGNATURE_TYPE,
			preprocessPipelineModelSchemaStr))
		.withBroadcastSet(preprocessPipelineModelOp.getDataSet(), PREPROCESS_PIPELINE_MODEL_BC_NAME);
	DataSet<Row> modelDataSet = TaskType.CLASSIFICATION.equals(taskType)
		? constructModelMapPartitionOperator.withBroadcastSet(sortedLabels, SORTED_LABELS_BC_NAME)
		: constructModelMapPartitionOperator;
	BatchOperator<?> modelOp = new TableSourceBatchOp(
			DataSetConversionUtil.toTable(getMLEnvironmentId(), modelDataSet,
				new TFTableModelClassificationModelDataConverter(labelType).getModelSchema()))
		.setMLEnvironmentId(getMLEnvironmentId());
	this.setOutputTable(modelOp.getOutputTable());
	return (T) this;
}
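The closing lines show the pattern that recurs in every snippet on this page: convert a DataSet<Row> into a Table with DataSetConversionUtil.toTable, then expose it as a reusable operator through TableSourceBatchOp. Reduced to its core (mlEnvId and rows are placeholders):

// Sketch: the recurring DataSet-to-BatchOperator bridge.
BatchOperator<?> op = new TableSourceBatchOp(
		DataSetConversionUtil.toTable(mlEnvId, rows,
			new String[] {"col"}, new TypeInformation[] {Types.STRING}))
	.setMLEnvironmentId(mlEnvId);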