use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.
the class BaseKerasSequentialTrainBatchOp method linkFrom.
@Override
public T linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = inputs[0];
    Params params = getParams();
    TaskType taskType = params.get(HasTaskType.TASK_TYPE);
    boolean isReg = TaskType.REGRESSION.equals(taskType);
    String tensorCol = getTensorCol();
    String labelCol = getLabelCol();
    TypeInformation<?> labelType = TableUtil.findColTypeWithAssertAndHint(in.getSchema(), labelCol);
    DataSet<List<Object>> sortedLabels = null;
    BatchOperator<?> numLabelsOp = null;
    if (!isReg) {
        // Collect distinct labels per partition, then merge them into one sorted list.
        sortedLabels = in.select(labelCol).getDataSet()
            .mapPartition(new MapPartitionFunction<Row, Object>() {
                @Override
                public void mapPartition(Iterable<Row> iterable, Collector<Object> collector) throws Exception {
                    Set<Object> distinctValue = new HashSet<>();
                    for (Row row : iterable) {
                        distinctValue.add(row.getField(0));
                    }
                    for (Object obj : distinctValue) {
                        collector.collect(obj);
                    }
                }
            })
            .reduceGroup(new GroupReduceFunction<Object, List<Object>>() {
                @Override
                public void reduce(Iterable<Object> iterable, Collector<List<Object>> collector) throws Exception {
                    Set<Object> distinctValue = new TreeSet<>();
                    for (Object obj : iterable) {
                        distinctValue.add(obj);
                    }
                    collector.collect(new ArrayList<>(distinctValue));
                }
            });
        in = CommonUtils.mapLabelToIndex(in, labelCol, sortedLabels);
        numLabelsOp = new TableSourceBatchOp(DataSetConversionUtil.toTable(
            getMLEnvironmentId(),
            sortedLabels.map(new CountLabelsMapFunction()),
            new String[] {"count"},
            new TypeInformation<?>[] {Types.INT}))
            .setMLEnvironmentId(getMLEnvironmentId());
    }
    Boolean removeCheckpointBeforeTraining = getRemoveCheckpointBeforeTraining();
    if (null == removeCheckpointBeforeTraining) {
        // Default to cleaning the checkpoint directory before training.
        removeCheckpointBeforeTraining = true;
    }
    Map<String, Object> modelConfig = new HashMap<>();
    modelConfig.put("layers", getLayers());
    Map<String, String> userParams = new HashMap<>();
    if (removeCheckpointBeforeTraining) {
        userParams.put(DLConstants.REMOVE_CHECKPOINT_BEFORE_TRAINING, "true");
    }
    userParams.put("tensor_cols", JsonConverter.toJson(new String[] {tensorCol}));
    userParams.put("label_col", labelCol);
    userParams.put("label_type", "float");
    userParams.put("batch_size", String.valueOf(getBatchSize()));
    userParams.put("num_epochs", String.valueOf(getNumEpochs()));
    userParams.put("model_config", JsonConverter.toJson(modelConfig));
    userParams.put("optimizer", getOptimizer());
    if (!StringUtils.isNullOrWhitespaceOnly(getCheckpointFilePath())) {
        userParams.put("model_dir", getCheckpointFilePath());
    }
    ExecutionEnvironment env = MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment();
    if (env.getParallelism() == 1) {
        userParams.put("ALINK:ONLY_ONE_WORKER", "true");
    }
    userParams.put("validation_split", String.valueOf(getValidationSplit()));
    userParams.put("save_best_only", String.valueOf(getSaveBestOnly()));
    userParams.put("best_exporter_metric", getBestMetric());
    userParams.put("save_checkpoints_epochs", String.valueOf(getSaveCheckpointsEpochs()));
    if (params.contains(BaseKerasSequentialTrainParams.SAVE_CHECKPOINTS_SECS)) {
        userParams.put("save_checkpoints_secs", String.valueOf(getSaveCheckpointsSecs()));
    }
    TF2TableModelTrainBatchOp trainBatchOp = new TF2TableModelTrainBatchOp(params)
        .setSelectedCols(tensorCol, labelCol)
        .setUserFiles(RES_PY_FILES)
        .setMainScriptFile(MAIN_SCRIPT_FILE_NAME)
        .setUserParams(JsonConverter.toJson(userParams))
        .setIntraOpParallelism(getIntraOpParallelism())
        .setNumPSs(getNumPSs())
        .setNumWorkers(getNumWorkers())
        .setPythonEnv(params.get(HasPythonEnv.PYTHON_ENV));
    if (isReg) {
        trainBatchOp = trainBatchOp.linkFrom(in);
    } else {
        trainBatchOp = trainBatchOp.linkFrom(in, numLabelsOp);
    }
    String tfOutputSignatureDef = getTfOutputSignatureDef(taskType);
    FlatMapOperator<Row, Row> constructModelFlatMapOperator = new NumSeqSourceBatchOp()
        .setFrom(0)
        .setTo(0)
        .setMLEnvironmentId(getMLEnvironmentId())
        .getDataSet()
        .flatMap(new ConstructModelFlatMapFunction(params, new String[] {tensorCol},
            tfOutputSignatureDef, TF_OUTPUT_SIGNATURE_TYPE, null, true))
        .withBroadcastSet(trainBatchOp.getDataSet(), CommonUtils.TF_MODEL_BC_NAME);
    BatchOperator<?> modelOp;
    if (isReg) {
        DataSet<Row> modelDataSet = constructModelFlatMapOperator;
        modelOp = new TableSourceBatchOp(DataSetConversionUtil.toTable(
            getMLEnvironmentId(), modelDataSet,
            new TFTableModelRegressionModelDataConverter(labelType).getModelSchema()))
            .setMLEnvironmentId(getMLEnvironmentId());
    } else {
        DataSet<Row> modelDataSet = constructModelFlatMapOperator
            .withBroadcastSet(sortedLabels, CommonUtils.SORTED_LABELS_BC_NAME);
        modelOp = new TableSourceBatchOp(DataSetConversionUtil.toTable(
            getMLEnvironmentId(), modelDataSet,
            new TFTableModelClassificationModelDataConverter(labelType).getModelSchema()))
            .setMLEnvironmentId(getMLEnvironmentId());
    }
    this.setOutputTable(modelOp.getOutputTable());
    return (T) this;
}
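The snippet above repeatedly wraps a DataSet<Row> into a BatchOperator via DataSetConversionUtil.toTable and TableSourceBatchOp. Below is a minimal, self-contained sketch of that pattern; the column names, sample rows, and the use of the default ML environment id are illustrative assumptions, not code from Alink itself.

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.types.Row;
import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.TableSourceBatchOp;

public class WrapDataSetSketch {
    public static void main(String[] args) throws Exception {
        long mlEnvId = MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID;
        // Hypothetical label rows standing in for an intermediate DataSet.
        DataSet<Row> rows = MLEnvironmentFactory.get(mlEnvId)
            .getExecutionEnvironment()
            .fromElements(Row.of("cat", 0L), Row.of("dog", 1L));
        // Same wrapping pattern as numLabelsOp and modelOp above: DataSet -> Table -> BatchOperator.
        BatchOperator<?> op = new TableSourceBatchOp(
            DataSetConversionUtil.toTable(mlEnvId, rows,
                new String[] {"label", "index"},
                new TypeInformation<?>[] {Types.STRING, Types.LONG}))
            .setMLEnvironmentId(mlEnvId);
        op.print();
    }
}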
use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.
the class WordCountUtil method trans.
private static BatchOperator<?> trans(BatchOperator<?> in, String[] selectedColNames, String[] keepColNames,
                                      BatchOperator<?> indexedVocab, String wordColName, String idxColName,
                                      boolean isWord, String wordDelimiter) {
    String[] colnames = in.getColNames();
    TypeInformation<?>[] coltypes = in.getColTypes();
    int[] colIdxs = findColIdx(selectedColNames, colnames, coltypes);
    int[] appendIdxs = findAppendColIdx(keepColNames, colnames);
    // only 2 cols: word, idx
    DataSet<Row> voc = indexedVocab.select(wordColName + "," + idxColName).getDataSet();
    DataSet<Row> contentMapping = in.getDataSet()
        .map(new GenContentMapping(colIdxs, appendIdxs, isWord, wordDelimiter))
        .withBroadcastSet(voc, "vocabulary");
    int transColSize = colIdxs.length;
    int keepColSize = keepColNames == null ? 0 : keepColNames.length;
    int outputColSize = transColSize + keepColSize;
    String[] names = new String[outputColSize];
    TypeInformation<?>[] types = new TypeInformation<?>[outputColSize];
    int i = 0;
    for (; i < transColSize; ++i) {
        names[i] = colnames[colIdxs[i]];
        types[i] = isWord ? Types.DOUBLE : Types.STRING;
    }
    for (; i < outputColSize; ++i) {
        names[i] = colnames[appendIdxs[i - transColSize]];
        types[i] = coltypes[appendIdxs[i - transColSize]];
    }
    return new TableSourceBatchOp(DataSetConversionUtil.toTable(
        in.getMLEnvironmentId(), contentMapping, names, types))
        .setMLEnvironmentId(in.getMLEnvironmentId());
}
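GenContentMapping receives the vocabulary through Flink's broadcast-variable mechanism (the withBroadcastSet call above). Here is a hedged sketch of that pattern with a hypothetical stand-in mapper; GenContentMapping's real logic additionally splits content by the delimiter and handles the keep columns.

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.types.Row;

// Hypothetical stand-in for GenContentMapping: looks up each word's index
// in the broadcast vocabulary.
public class VocabLookupMapper extends RichMapFunction<Row, Row> {
    private transient Map<String, Long> vocab;

    @Override
    public void open(Configuration parameters) {
        vocab = new HashMap<>();
        List<Row> broadcast = getRuntimeContext().getBroadcastVariable("vocabulary");
        for (Row r : broadcast) {
            vocab.put((String) r.getField(0), (Long) r.getField(1));
        }
    }

    @Override
    public Row map(Row value) {
        return Row.of(vocab.get((String) value.getField(0)));
    }
}

// Usage mirrors the call in trans:
// in.getDataSet().map(new VocabLookupMapper()).withBroadcastSet(voc, "vocabulary");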
use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.
the class LdaTrainBatchOp method linkFrom.
@Override
public LdaTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    int parallelism = BatchOperator.getExecutionEnvironmentFromOps(in).getParallelism();
    long mlEnvId = getMLEnvironmentId();
    int numTopic = getTopicNum();
    int numIter = getNumIter();
    Integer seed = getRandomSeed();
    boolean setSeed = (seed != null);
    String vectorColName = getSelectedCol();
    Method optimizer = getMethod();
    final DataSet<DocCountVectorizerModelData> resDocCountModel =
        DocCountVectorizerTrainBatchOp.generateDocCountModel(getParams(), in);
    int index = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), vectorColName);
    DataSet<Row> resRow = in.getDataSet()
        .flatMap(new Document2Vector(index))
        .withBroadcastSet(resDocCountModel, "DocCountModel");
    TypeInformation<?>[] types = in.getColTypes().clone();
    types[index] = TypeInformation.of(SparseVector.class);
    BatchOperator<?> trainData = new TableSourceBatchOp(
        DataSetConversionUtil.toTable(mlEnvId, resRow, in.getColNames(), types))
        .setMLEnvironmentId(mlEnvId);
    Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> dataAndStat =
        StatisticsHelper.summaryHelper(trainData, null, vectorColName);
    if (setSeed) {
        // With a fixed seed, shuffle the data deterministically: hash each vector,
        // partition by the hash, then sort within each partition.
        DataSet<Tuple2<Long, Vector>> hashValue = dataAndStat.f0
            .map(new MapHashValue(seed))
            .partitionCustom(new Partitioner<Long>() {
                private static final long serialVersionUID = 5179898093029365608L;

                @Override
                public int partition(Long key, int numPartitions) {
                    return (int) (Math.abs(key) % ((long) numPartitions));
                }
            }, 0);
        dataAndStat.f0 = hashValue
            .mapPartition(new MapPartitionFunction<Tuple2<Long, Vector>, Vector>() {
                private static final long serialVersionUID = -550512476573928350L;

                @Override
                public void mapPartition(Iterable<Tuple2<Long, Vector>> values, Collector<Vector> out)
                        throws Exception {
                    List<Tuple2<Long, Vector>> listValues = Lists.newArrayList(values);
                    listValues.sort(new Comparator<Tuple2<Long, Vector>>() {
                        @Override
                        public int compare(Tuple2<Long, Vector> o1, Tuple2<Long, Vector> o2) {
                            int compare1 = o1.f0.compareTo(o2.f0);
                            if (compare1 == 0) {
                                // Break hash ties by the vectors' string representations.
                                return o1.f1.toString().compareTo(o2.f1.toString());
                            }
                            return compare1;
                        }
                    });
                    listValues.forEach(x -> out.collect(x.f1));
                }
            })
            .setParallelism(parallelism);
    }
    double beta = getParams().get(BETA);
    double alpha = getParams().get(ALPHA);
    int gammaShape = 250;
    switch (optimizer) {
        case EM:
            gibbsSample(dataAndStat, numTopic, numIter, alpha, beta, resDocCountModel, seed);
            break;
        case Online:
            online(dataAndStat, numTopic, numIter, alpha, beta, resDocCountModel, gammaShape, seed);
            break;
        default:
            throw new NotImplementedException("Optimizer not supported.");
    }
    return this;
}
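A hypothetical usage sketch of this operator follows. The setter names are assumed to mirror the getters used above (setTopicNum, setNumIter, setRandomSeed, and so on), and the in-memory input data is made up for illustration.

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.clustering.LdaTrainBatchOp;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;

public class LdaTrainSketch {
    public static void main(String[] args) throws Exception {
        // Hypothetical two-document corpus with a single "doc" column.
        BatchOperator<?> docs = new MemSourceBatchOp(
            new Object[][] {{"machine learning is fun"}, {"topic models find topics"}},
            new String[] {"doc"});
        LdaTrainBatchOp lda = new LdaTrainBatchOp()
            .setSelectedCol("doc")
            .setTopicNum(2)
            .setNumIter(20)
            // A fixed seed triggers the deterministic re-partitioning shown above.
            .setRandomSeed(2021)
            .linkFrom(docs);
        lda.print();
    }
}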
use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.
the class GraphEmbedding method trans2Index.
/**
 * Transform vertices into indices.
 * vocab: schema {NODE_COL: originalType, NODE_INDEX_COL: long}
 * indexedGraph: schema {SOURCE_COL: long, TARGET_COL: long, WEIGHT_COL: double}
 * indexWithType: returned only if in2 is not null, schema {NODE_INDEX_COL: long, NODE_TYPE_COL: string}
 *
 * @param in1 the graph data
 * @param in2 the vertex list with vertex types, optional
 * @param params user-specified parameters
 * @return vocab and indexedGraph, plus indexWithType when in2 is not null
 */
public static BatchOperator[] trans2Index(BatchOperator in1, BatchOperator in2, Params params) {
    String sourceColName = params.get(HasSourceCol.SOURCE_COL);
    String targetColName = params.get(HasTargetCol.TARGET_COL);
    String clause;
    if (params.contains(HasWeightCol.WEIGHT_COL)) {
        String weightColName = params.get(HasWeightCol.WEIGHT_COL);
        clause = "`" + sourceColName + "`, `" + targetColName + "`, `" + weightColName + "`";
    } else {
        clause = "`" + sourceColName + "`, `" + targetColName + "`, 1.0";
    }
    BatchOperator in = in1.select(clause).as(SOURCE_COL + ", " + TARGET_COL + ", " + WEIGHT_COL);
    // Count the times each vertex appears in the edges.
    BatchOperator wordCnt = WordCountUtil.count(
        new UnionAllBatchOp()
            .setMLEnvironmentId(in1.getMLEnvironmentId())
            .linkFrom(in.select(SOURCE_COL), in.select(TARGET_COL))
            .as(NODE_COL),
        NODE_COL);
    // Name each vocab entry with its index.
    BatchOperator vocab = WordCountUtil.randomIndexVocab(wordCnt, 0)
        .select(WordCountUtil.WORD_COL_NAME + " AS " + NODE_COL + ", "
            + WordCountUtil.INDEX_COL_NAME + " AS " + NODE_INDEX_COL);
    // Transform the input and the vocab to DataSet<Tuple>.
    DataSet<Tuple> inDataSet = in.getDataSet().map(
        new MapFunction<Row, Tuple3<Comparable, Comparable, Comparable>>() {
            private static final long serialVersionUID = 8473819294214049730L;

            @Override
            public Tuple3<Comparable, Comparable, Comparable> map(Row value) throws Exception {
                return Tuple3.of((Comparable) value.getField(0), (Comparable) value.getField(1),
                    (Comparable) value.getField(2));
            }
        });
    DataSet<Tuple2> vocabDataSet = vocab.getDataSet().map(
        new MapFunction<Row, Tuple2<Comparable, Long>>() {
            private static final long serialVersionUID = 7241884458236714150L;

            @Override
            public Tuple2<Comparable, Long> map(Row value) throws Exception {
                return Tuple2.of((Comparable) value.getField(0), (Long) value.getField(1));
            }
        });
    // Join twice: first replace the source vertex with its index, then the target.
    DataSet<Tuple> joinWithSourceColTuple = HackBatchOpJoin.join(inDataSet, vocabDataSet, 0, 0,
        new int[][] {{1, 1}, {0, 1}, {0, 2}});
    DataSet<Tuple> indexGraphTuple = HackBatchOpJoin.join(joinWithSourceColTuple, vocabDataSet, 1, 0,
        new int[][] {{0, 0}, {1, 1}, {0, 2}});
    // Build the output BatchOperator.
    TypeInformation<?>[] inTypes = in.getColTypes();
    TypeInformation<?>[] vocabTypes = vocab.getColTypes();
    BatchOperator indexedGraphBatchOp = new TableSourceBatchOp(DataSetConversionUtil.toTable(
        in.getMLEnvironmentId(),
        indexGraphTuple.map(new MapFunction<Tuple, Row>() {
            private static final long serialVersionUID = -5386264086074581748L;

            @Override
            public Row map(Tuple value) throws Exception {
                Row res = new Row(3);
                res.setField(0, value.getField(0));
                res.setField(1, value.getField(1));
                res.setField(2, value.getField(2));
                return res;
            }
        }),
        new String[] {SOURCE_COL, TARGET_COL, WEIGHT_COL},
        new TypeInformation<?>[] {vocabTypes[1], vocabTypes[1], inTypes[2]}));
    if (null == in2) {
        return new BatchOperator[] {vocab, indexedGraphBatchOp};
    } else {
        BatchOperator in2Selected = in2
            .select("`" + params.get(HasVertexCol.VERTEX_COL) + "`, `"
                + params.get(HasTypeCol.TYPE_COL) + "`")
            .as(TEMP_NODE_COL + ", " + NODE_TYPE_COL);
        TypeInformation<?>[] types = new TypeInformation[2];
        types[1] = in2.getColTypes()[TableUtil.findColIndex(in2.getSchema(), params.get(HasTypeCol.TYPE_COL))];
        types[0] = vocab.getColTypes()[TableUtil.findColIndex(vocab.getSchema(), NODE_INDEX_COL)];
        DataSet<Tuple> in2Tuple = in2Selected.getDataSet().map(
            new MapFunction<Row, Tuple2<Comparable, Comparable>>() {
                private static final long serialVersionUID = 3459700988499538679L;

                @Override
                public Tuple2<Comparable, Comparable> map(Row value) throws Exception {
                    Tuple2<Comparable, Comparable> res = new Tuple2<>();
                    res.setField(value.getField(0), 0);
                    res.setField(value.getField(1), 1);
                    return res;
                }
            });
        DataSet<Row> indexWithTypeRow = HackBatchOpJoin.join(in2Tuple, vocabDataSet, 0, 0,
            new int[][] {{1, 1}, {0, 1}})
            .map(new MapFunction<Tuple, Row>() {
                private static final long serialVersionUID = -5747375637774394150L;

                @Override
                public Row map(Tuple value) throws Exception {
                    int length = value.getArity();
                    Row res = new Row(length);
                    for (int i = 0; i < length; i++) {
                        res.setField(i, value.getField(i));
                    }
                    return res;
                }
            });
        BatchOperator indexWithType = new TableSourceBatchOp(DataSetConversionUtil.toTable(
            in.getMLEnvironmentId(), indexWithTypeRow,
            new String[] {NODE_INDEX_COL, NODE_TYPE_COL}, types))
            .setMLEnvironmentId(in.getMLEnvironmentId());
        return new BatchOperator[] {vocab, indexedGraphBatchOp, indexWithType};
    }
}
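The two HackBatchOpJoin calls perform an index substitution: the first join swaps each edge's source vertex for its long index, the second does the same for the target. HackBatchOpJoin's projection-array semantics are internal to Alink, so the following is only a rough plain-Flink equivalent of the two-step substitution, with made-up edge and vocabulary data.

import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;

public class IndexEdgesSketch {
    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        DataSet<Tuple3<String, String, Double>> edges =
            env.fromElements(Tuple3.of("a", "b", 1.0), Tuple3.of("b", "c", 2.0));
        DataSet<Tuple2<String, Long>> vocab =
            env.fromElements(Tuple2.of("a", 0L), Tuple2.of("b", 1L), Tuple2.of("c", 2L));
        DataSet<Tuple3<Long, Long, Double>> indexed = edges
            // First join: replace the source vertex with its index.
            .join(vocab).where(0).equalTo(0)
            .with(new JoinFunction<Tuple3<String, String, Double>, Tuple2<String, Long>,
                    Tuple3<Long, String, Double>>() {
                @Override
                public Tuple3<Long, String, Double> join(Tuple3<String, String, Double> e,
                                                         Tuple2<String, Long> v) {
                    return Tuple3.of(v.f1, e.f1, e.f2);
                }
            })
            // Second join: replace the target vertex with its index.
            .join(vocab).where(1).equalTo(0)
            .with(new JoinFunction<Tuple3<Long, String, Double>, Tuple2<String, Long>,
                    Tuple3<Long, Long, Double>>() {
                @Override
                public Tuple3<Long, Long, Double> join(Tuple3<Long, String, Double> e,
                                                       Tuple2<String, Long> v) {
                    return Tuple3.of(e.f0, v.f1, e.f2);
                }
            });
        indexed.print();
    }
}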
use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.
the class ModelExporterUtils method postOrderDeserialize.
private static <T extends PipelineStageBase<?>> List<T> postOrderDeserialize(
        StageNode[] stages, BatchOperator<?> unpacked, final TableSchema schema, final int offset) {
    if (stages == null || stages.length == 0) {
        return new ArrayList<>();
    }
    final long[] id = new long[] {stages.length - 1};
    final BatchOperator<?>[] deserialized = new BatchOperator<?>[] {unpacked};
    List<T> result = new ArrayList<>();
    Consumer<Integer> deserializer = lp -> {
        StageNode stageNode = stages[lp];
        try {
            if (stageNode.identifier != null) {
                stageNode.stage = (PipelineStageBase<?>) Class.forName(stageNode.identifier)
                    .getConstructor(Params.class)
                    .newInstance(stageNode.params);
            }
        } catch (ClassNotFoundException | NoSuchMethodException | InstantiationException
            | IllegalAccessException | InvocationTargetException ex) {
            throw new IllegalArgumentException(ex);
        }
        if (stageNode.children == null) {
            // Leaf node: carve this stage's model rows out of the packed data.
            if (stageNode.parent >= 0 && stageNode.schemaIndices != null && stageNode.colNames != null) {
                final long localId = id[0];
                final int[] localSchemaIndices = stageNode.schemaIndices;
                BatchOperator<?> model = new TableSourceBatchOp(DataSetConversionUtil.toTable(
                    deserialized[0].getMLEnvironmentId(),
                    deserialized[0].getDataSet()
                        .filter(new FilterFunction<Row>() {
                            private static final long serialVersionUID = 355683133177055891L;

                            @Override
                            public boolean filter(Row value) {
                                return value.getField(0).equals(localId);
                            }
                        })
                        .map(new MapFunction<Row, Row>() {
                            private static final long serialVersionUID = -4286266312978550037L;

                            @Override
                            public Row map(Row value) throws Exception {
                                Row ret = new Row(localSchemaIndices.length);
                                for (int i = 0; i < localSchemaIndices.length; ++i) {
                                    ret.setField(i, value.getField(localSchemaIndices[i] + offset));
                                }
                                return ret;
                            }
                        })
                        .returns(new RowTypeInfo(stageNode.types)),
                    new TableSchema(stageNode.colNames, stageNode.types)))
                    .setMLEnvironmentId(deserialized[0].getMLEnvironmentId());
                ((ModelBase<?>) stageNode.stage).setModelData(model);
                // Drop the consumed rows from the packed data.
                deserialized[0] = new TableSourceBatchOp(DataSetConversionUtil.toTable(
                    deserialized[0].getMLEnvironmentId(),
                    deserialized[0].getDataSet().filter(new FilterFunction<Row>() {
                        private static final long serialVersionUID = -2803966833769030531L;

                        @Override
                        public boolean filter(Row value) {
                            return !value.getField(0).equals(localId);
                        }
                    }),
                    schema))
                    .setMLEnvironmentId(deserialized[0].getMLEnvironmentId());
            }
        } else {
            // Inner node: its children have already been deserialized.
            List<T> pipelineStageBases = new ArrayList<>();
            for (int i = 0; i < stageNode.children.length; ++i) {
                pipelineStageBases.add((T) stages[stageNode.children[i]].stage);
            }
            if (stageNode.stage == null) {
                result.addAll(pipelineStageBases);
                return;
            }
            if (stageNode.stage instanceof Pipeline) {
                stageNode.stage = new Pipeline(pipelineStageBases.toArray(new PipelineStageBase<?>[0]));
            } else if (stageNode.stage instanceof PipelineModel) {
                stageNode.stage = new PipelineModel(pipelineStageBases.toArray(new TransformerBase<?>[0]));
            } else {
                throw new IllegalArgumentException("Unsupported stage.");
            }
        }
        id[0]--;
    };
    postOrder(stages, deserializer);
    return result;
}
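postOrder drives the Consumer so that children are visited before their parent, which is why an inner node can assemble its children's already-built stages. A simplified, self-contained stand-in for that traversal is below; Node here is a minimal hypothetical replacement for Alink's StageNode, not the real class.

import java.util.function.Consumer;

public class PostOrderSketch {
    // Minimal stand-in for StageNode: children are indices into the array; null means leaf.
    static class Node {
        final String name;
        final int[] children;
        Node(String name, int[] children) { this.name = name; this.children = children; }
    }

    // Visit children before their parent, mirroring postOrderDeserialize's contract.
    static void postOrder(Node[] nodes, int root, Consumer<Integer> visitor) {
        if (nodes[root].children != null) {
            for (int child : nodes[root].children) {
                postOrder(nodes, child, visitor);
            }
        }
        visitor.accept(root);
    }

    public static void main(String[] args) {
        Node[] nodes = {
            new Node("pipeline", new int[] {1, 2}),
            new Node("tokenizer", null),
            new Node("classifier", null),
        };
        postOrder(nodes, 0, i -> System.out.println(nodes[i].name));
        // Prints: tokenizer, classifier, pipeline.
    }
}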