Use of com.alibaba.alink.common.utils.OutputColsHelper in project Alink by alibaba.
The class StreamingKMeansStreamOp, method linkFrom.
/**
* Update model with stream in1, predict for stream in2
*/
@Override
public StreamingKMeansStreamOp linkFrom(StreamOperator<?>... inputs) {
checkMinOpSize(1, inputs);
StreamOperator<?> in1 = inputs[0];
StreamOperator<?> in2 = inputs[0];
if (inputs.length > 1) {
in2 = inputs[1];
}
if (!this.getParams().contains(HasPredictionCol.PREDICTION_COL)) {
this.setPredictionCol("cluster_id");
}
// time interval for updating the model, in seconds
final long timeInterval = getParams().get(TIME_INTERVAL);
final long halfLife = getParams().get(HALF_LIFE);
final double decayFactor = Math.pow(0.5, (double) timeInterval / (double) halfLife);
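// A worked example of the decay (hypothetical numbers): with timeInterval = 60
// and halfLife = 600, decayFactor = 0.5^(60/600) = 0.5^0.1 ≈ 0.933 per update,
// so the old model's weight halves after ten updates, i.e. after one half-life.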
try {
DataStream<Row> trainingData = in1.getDataStream();
DataStream<Row> predictData = in2.getDataStream();
PredType predType = PredType.fromInputs(getParams());
OutputColsHelper outputColsHelper = null;
switch (predType) {
    case PRED: {
        outputColsHelper = new OutputColsHelper(in2.getSchema(),
            new String[] {getPredictionCol()},
            new TypeInformation[] {Types.LONG},
            getReservedCols());
        break;
    }
    case PRED_CLUS: {
        outputColsHelper = new OutputColsHelper(in2.getSchema(),
            new String[] {getPredictionCol(), getPredictionClusterCol()},
            new TypeInformation[] {Types.LONG, VectorTypes.DENSE_VECTOR},
            getReservedCols());
        break;
    }
    case PRED_DIST: {
        outputColsHelper = new OutputColsHelper(in2.getSchema(),
            new String[] {getPredictionCol(), getPredictionDistanceCol()},
            new TypeInformation[] {Types.LONG, Types.DOUBLE},
            getReservedCols());
        break;
    }
    case PRED_CLUS_DIST: {
        outputColsHelper = new OutputColsHelper(in2.getSchema(),
            new String[] {getPredictionCol(), getPredictionClusterCol(), getPredictionDistanceCol()},
            new TypeInformation[] {Types.LONG, VectorTypes.DENSE_VECTOR, Types.DOUBLE},
            getReservedCols());
        // break was missing in the original; this is the last case, so behavior is unchanged
        break;
    }
}
// for direct read
DataBridge modelDataBridge = DirectReader.collect(batchModel);
// incremental train on every window of data
DataStream<Tuple3<DenseVector[], int[], Long>> updateData = trainingData
    .flatMap(new CollectUpdateData(modelDataBridge, in1.getColNames(), timeInterval))
    .name("local_aggregate");
int taskNum = updateData.getParallelism();
DataStream<KMeansTrainModelData> streamModel = updateData
    .flatMap(new AllDataMerge(taskNum)).name("global_aggregate").setParallelism(1)
    .map(new UpdateModelOp(modelDataBridge, decayFactor)).name("update_model").setParallelism(1);
// predict
DataStream<Row> predictResult = predictData.connect(streamModel.broadcast())
    .flatMap(new PredictOp(modelDataBridge, in2.getColNames(), outputColsHelper, predType))
    .name("kmeans_prediction");
this.setOutput(predictResult, outputColsHelper.getResultSchema());
this.setSideOutputTables(outputModel(streamModel, getMLEnvironmentId()));
return this;
} catch (Exception e) {
    // Wrap with the original exception as the cause instead of discarding the stack trace.
    throw new RuntimeException(e);
}
}
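All four branches differ only in the output column names and types; the helper then drives both the result schema and the per-row merge. A minimal sketch of that core OutputColsHelper contract, using only the constructor and the two methods exercised in these snippets (the schema and values are illustrative):

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import com.alibaba.alink.common.utils.OutputColsHelper;

public class OutputColsHelperSketch {
    public static void main(String[] args) {
        TableSchema dataSchema = new TableSchema(
            new String[] {"f0", "f1"},
            new TypeInformation<?>[] {Types.STRING, Types.DOUBLE});
        // Keep only f0 from the input and append the prediction column.
        OutputColsHelper helper = new OutputColsHelper(dataSchema,
            new String[] {"cluster_id"},
            new TypeInformation[] {Types.LONG},
            new String[] {"f0"});
        Row merged = helper.getResultRow(Row.of("sample-1", 0.5), Row.of(7L));
        // expected: merged = ("sample-1", 7L), schema = (f0 STRING, cluster_id BIGINT)
        System.out.println(merged);
        System.out.println(helper.getResultSchema());
    }
}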
Use of com.alibaba.alink.common.utils.OutputColsHelper in project Alink by alibaba.
The class KeywordsExtractionStreamOp, method linkFrom.
@Override
public KeywordsExtractionStreamOp linkFrom(StreamOperator<?>... inputs) {
StreamOperator<?> in = checkAndGetFirst(inputs);
String selectedColName = this.getSelectedCol();
int textColIndex = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), selectedColName);
String outputColName = this.getOutputCol();
if (null == outputColName) {
outputColName = selectedColName;
}
OutputColsHelper outputColsHelper = new OutputColsHelper(in.getSchema(),
    new String[] {outputColName},
    new TypeInformation[] {org.apache.flink.table.api.Types.STRING()},
    in.getColNames());
DataStream<Row> res = in.getDataStream()
    .map(new KeywordsExtractionMap(this.getParams(), textColIndex, outputColsHelper));
// Set the output into table.
this.setOutput(res, outputColsHelper.getResultSchema());
return this;
}
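A hedged usage sketch of the operator above. MemSourceStreamOp and the setter names are assumptions inferred from the getSelectedCol/getOutputCol calls in linkFrom, not verified against the published API:

// Hypothetical pipeline: extract keywords from a one-column text stream.
StreamOperator<?> source = new MemSourceStreamOp(
    new Object[][] {{"Alink is a machine learning algorithm platform"}},
    new String[] {"doc"});
new KeywordsExtractionStreamOp()
    .setSelectedCol("doc")
    .setOutputCol("keywords")   // defaults to the selected column when unset
    .linkFrom(source)
    .print();
StreamOperator.execute();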
Use of com.alibaba.alink.common.utils.OutputColsHelper in project Alink by alibaba.
The class HugeIndexerStringPredictBatchOp, method linkFrom.
@Override
public HugeIndexerStringPredictBatchOp linkFrom(BatchOperator<?>... inputs) {
Params params = super.getParams();
BatchOperator<?> model = inputs[0];
BatchOperator<?> data = inputs[1];
String[] selectedColNames = params.get(HugeMultiStringIndexerPredictParams.SELECTED_COLS);
String[] outputColNames = params.get(HugeMultiStringIndexerPredictParams.OUTPUT_COLS);
if (outputColNames == null) {
outputColNames = selectedColNames;
}
String[] keepColNames = params.get(HugeMultiStringIndexerPredictParams.RESERVED_COLS);
TypeInformation[] outputColTypes = new TypeInformation[outputColNames.length];
TypeInformation[] inputColTypes = TableUtil.findColTypesWithAssert(data.getSchema(), selectedColNames);
Arrays.fill(outputColTypes, Types.STRING);
OutputColsHelper outputColsHelper = new OutputColsHelper(data.getSchema(), outputColNames, outputColTypes, keepColNames);
final int[] selectedColIdx = TableUtil.findColIndicesWithAssertAndHint(data.getSchema(), selectedColNames);
final HandleInvalid handleInvalidStrategy =
    HandleInvalid.valueOf(params.get(HugeMultiStringIndexerPredictParams.HANDLE_INVALID).toString());
DataSet<Tuple2<Long, Row>> dataWithId = DataSetUtils.zipWithUniqueId(data.getDataSet());
DataSet<Tuple2<String, Long>> modelData = getModelData(model);
// tuple: record id, column index, array index within the column, token id
DataSet<Tuple4<Long, Integer, Integer, Long>> flattened = dataWithId
    .flatMap(new RichFlatMapFunction<Tuple2<Long, Row>, Tuple4<Long, Integer, Integer, Long>>() {
private static final long serialVersionUID = -8382461068855755626L;
@Override
public void flatMap(Tuple2<Long, Row> value, Collector<Tuple4<Long, Integer, Integer, Long>> out) throws Exception {
for (int i = 0; i < selectedColIdx.length; i++) {
Object o = value.f1.getField(selectedColIdx[i]);
if (null == o) {
out.collect(Tuple4.of(value.f0, i, 0, -1L));
} else {
if (inputColTypes[i].isBasicType()) {
out.collect(Tuple4.of(value.f0, i, 0, (Long) o));
} else {
Long[] ids = (Long[]) o;
for (int j = 0; j < ids.length; j++) {
out.collect(Tuple4.of(value.f0, i, j, ids[j]));
}
}
}
}
}
}).name("flatten_pred_data").returns(new TupleTypeInfo<>(Types.LONG, Types.INT, Types.INT, Types.LONG));
// tuple: record id, column index, array index, token string
DataSet<Tuple4<Long, Integer, Integer, String>> indexed = flattened
    .leftOuterJoin(modelData).where(3).equalTo(1)
    .with(new JoinFunction<Tuple4<Long, Integer, Integer, Long>, Tuple2<String, Long>, Tuple4<Long, Integer, Integer, String>>() {
private static final long serialVersionUID = 2270459281179536013L;
@Override
public Tuple4<Long, Integer, Integer, String> join(Tuple4<Long, Integer, Integer, Long> first, Tuple2<String, Long> second) throws Exception {
if (second == null) {
return Tuple4.of(first.f0, first.f1, first.f2, "notFound");
} else {
return Tuple4.of(first.f0, first.f1, first.f2, second.f0);
}
}
}).name("map_index_to_token").returns(new TupleTypeInfo<>(Types.LONG, Types.INT, Types.INT, Types.STRING));
// tuple: record id, prediction result
DataSet<Tuple2<Long, Row>> aggregateResult = indexed.groupBy(0)
    .reduceGroup(new GroupReduceFunction<Tuple4<Long, Integer, Integer, String>, Tuple2<Long, Row>>() {
private static final long serialVersionUID = -1581264399340055162L;
@Override
public void reduce(Iterable<Tuple4<Long, Integer, Integer, String>> values, Collector<Tuple2<Long, Row>> out) throws Exception {
Long id = null;
Row r = new Row(selectedColIdx.length);
ArrayList<Tuple3<Integer, Integer, String>> list = new ArrayList<>();
for (Tuple4<Long, Integer, Integer, String> v : values) {
list.add(Tuple3.of(v.f1, v.f2, v.f3));
id = v.f0;
}
list.sort(new Comparator<Tuple3<Integer, Integer, String>>() {
@Override
public int compare(Tuple3<Integer, Integer, String> o1, Tuple3<Integer, Integer, String> o2) {
if (o1.f0.equals(o2.f0)) {
return o1.f1.compareTo(o2.f1);
}
return o1.f0.compareTo(o2.f0);
}
});
ArrayList<String> allFeatures = new ArrayList<>(list.size());
for (Tuple3<Integer, Integer, String> v : list) {
allFeatures.add(v.f2);
}
String[] originFeatures = new String[selectedColIdx.length];
int startIndex = 0, endIndex = 0;
int lastIndex = 0;
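// The list is sorted by (column index, array index); the loop below stitches
// consecutive tokens of the same column back into one comma-separated string
// per selected column.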
for (int i = 0; i < list.size(); i++) {
Tuple3<Integer, Integer, String> v = list.get(i);
if (lastIndex != v.f0) {
originFeatures[lastIndex] = StringUtils.join(allFeatures.subList(startIndex, endIndex), ",");
lastIndex = v.f0;
startIndex = i;
endIndex = i;
}
endIndex += 1;
if (i == list.size() - 1) {
originFeatures[lastIndex] = StringUtils.join(allFeatures.subList(startIndex, endIndex), ",");
}
}
for (int i = 0; i < originFeatures.length; i++) {
r.setField(i, originFeatures[i]);
}
out.collect(Tuple2.of(id, r));
}
}).name("aggregate_result").returns(new TupleTypeInfo<>(Types.LONG, new RowTypeInfo(outputColTypes)));
DataSet<Row> output = dataWithId.join(aggregateResult).where(0).equalTo(0)
    .with(new JoinFunction<Tuple2<Long, Row>, Tuple2<Long, Row>, Row>() {
private static final long serialVersionUID = 3724539437313089427L;
@Override
public Row join(Tuple2<Long, Row> first, Tuple2<Long, Row> second) throws Exception {
return outputColsHelper.getResultRow(first.f1, second.f1);
}
}).name("merge_result").returns(new RowTypeInfo(outputColsHelper.getResultSchema().getFieldTypes()));
this.setOutput(output, outputColsHelper.getResultSchema());
return this;
}
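The subtle part of this snippet is the regrouping inside aggregate_result, which restores per-column token order before joining tokens back together. A minimal, self-contained sketch of the same idea with plain JDK collections and hypothetical data (no Flink involved):

import java.util.*;

public class RegroupSketch {
    public static void main(String[] args) {
        // (column index, array index, token) triples for one record, in arrival order.
        List<Object[]> list = new ArrayList<>(Arrays.asList(
            new Object[] {1, 1, "c"},
            new Object[] {0, 0, "a"},
            new Object[] {1, 0, "b"}));
        // Sort by column index, then by position within the column.
        list.sort(Comparator.<Object[]>comparingInt(t -> (Integer) t[0])
            .thenComparingInt(t -> (Integer) t[1]));
        // Stitch consecutive tokens of the same column into one CSV string.
        Map<Integer, StringJoiner> byCol = new LinkedHashMap<>();
        for (Object[] t : list) {
            byCol.computeIfAbsent((Integer) t[0], k -> new StringJoiner(",")).add((String) t[2]);
        }
        System.out.println(byCol); // prints {0=a, 1=b,c}
    }
}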
Use of com.alibaba.alink.common.utils.OutputColsHelper in project Alink by alibaba.
The class HugeLookupBatchOp, method linkFrom.
@Override
public HugeLookupBatchOp linkFrom(BatchOperator<?>... inputs) {
checkOpSize(2, inputs);
BatchOperator<?> model = inputs[0];
BatchOperator<?> data = inputs[1];
final String[] mapKeyColNames = getMapKeyCols();
final String[] mapValueColNames = getMapValueCols();
final String[] selectedColNames = getSelectedCols();
final String[] reservedColNames = getReservedCols();
String[] outputColNames = getOutputCols();
TableSchema modelSchema = model.getSchema();
TableSchema dataSchema = data.getSchema();
if (modelSchema.getFieldNames().length != 2 && (mapKeyColNames == null || mapValueColNames == null)) {
throw new RuntimeException("LookUp err : mapKeyCols and mapValueCols should set in parameters.");
}
final int[] selectedColIndices = TableUtil.findColIndicesWithAssertAndHint(dataSchema, selectedColNames);
final int[] mapKeyColIndices = (mapKeyColNames != null) ? TableUtil.findColIndicesWithAssertAndHint(modelSchema, mapKeyColNames) : new int[] { 0 };
final int[] mapValueColIndices = (mapValueColNames != null) ? TableUtil.findColIndicesWithAssertAndHint(modelSchema, mapValueColNames) : new int[] { 1 };
for (int i = 0; i < selectedColNames.length; ++i) {
if (mapKeyColNames != null && mapValueColNames != null) {
if (!TableUtil.findColTypeWithAssertAndHint(dataSchema, selectedColNames[i])
        .equals(TableUtil.findColTypeWithAssertAndHint(modelSchema, mapKeyColNames[i]))) {
    throw new IllegalArgumentException("Data types do not match. The selected column type is "
        + TableUtil.findColTypeWithAssertAndHint(dataSchema, selectedColNames[i])
        + ", and the map key column type is "
        + TableUtil.findColTypeWithAssertAndHint(modelSchema, mapKeyColNames[i]));
}
}
}
if (null == outputColNames) {
outputColNames = mapValueColNames;
}
final TypeInformation<?>[] outputColTypes = (mapValueColNames == null)
    ? TableUtil.findColTypesWithAssertAndHint(modelSchema, new String[] {modelSchema.getFieldNames()[1]})
    : TableUtil.findColTypesWithAssertAndHint(modelSchema, mapValueColNames);
final OutputColsHelper predResultColsHelper =
    new OutputColsHelper(dataSchema, outputColNames, outputColTypes, reservedColNames);
DataSet<Row> result = data.getDataSet()
    .leftOuterJoin(model.getDataSet(), JoinHint.REPARTITION_SORT_MERGE)
    .where(selectedColIndices).equalTo(mapKeyColIndices)
    .with(new JoinFunction<Row, Row, Row>() {
@Override
public Row join(Row first, Row second) {
Row result = new Row(mapValueColIndices.length);
if (second != null) {
for (int i = 0; i < mapValueColIndices.length; ++i) {
result.setField(i, second.getField(mapValueColIndices[i]));
}
}
return predResultColsHelper.getResultRow(first, result);
}
});
setOutput(result, predResultColsHelper.getResultSchema());
return this;
}
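A hedged usage sketch of the operator above; MemSourceBatchOp and the setter names are assumptions inferred from the getters in linkFrom, not verified against the published API:

// Hypothetical two-column model table plus a one-column data table.
BatchOperator<?> model = new MemSourceBatchOp(
    new Object[][] {{1L, "apple"}, {2L, "pear"}},
    new String[] {"key", "value"});
BatchOperator<?> data = new MemSourceBatchOp(
    new Object[][] {{1L}, {3L}},
    new String[] {"id"});
new HugeLookupBatchOp()
    .setMapKeyCols("key")
    .setMapValueCols("value")
    .setSelectedCols("id")
    .setOutputCols("name")
    .linkFrom(model, data)
    .print();   // id 3 has no match, so its "name" comes back null (left outer join)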
Use of com.alibaba.alink.common.utils.OutputColsHelper in project Alink by alibaba.
The class HugeMultiIndexerStringPredictBatchOp, method linkFrom.
@Override
public HugeMultiIndexerStringPredictBatchOp linkFrom(BatchOperator<?>... inputs) {
Params params = super.getParams();
BatchOperator<?> model = inputs[0];
BatchOperator<?> data = inputs[1];
String[] selectedColNames = params.get(MultiStringIndexerPredictParams.SELECTED_COLS);
String[] outputColNames = params.get(MultiStringIndexerPredictParams.OUTPUT_COLS);
if (outputColNames == null) {
outputColNames = selectedColNames;
}
String[] keepColNames = params.get(StringIndexerPredictParams.RESERVED_COLS);
TypeInformation[] outputColTypes = new TypeInformation[outputColNames.length];
Arrays.fill(outputColTypes, Types.STRING);
OutputColsHelper outputColsHelper = new OutputColsHelper(data.getSchema(), outputColNames, outputColTypes, keepColNames);
final int[] selectedColIdx = TableUtil.findColIndicesWithAssertAndHint(data.getSchema(), selectedColNames);
final HandleInvalid handleInvalidStrategy =
    HandleInvalid.valueOf(params.get(StringIndexerPredictParams.HANDLE_INVALID).toString());
DataSet<Tuple2<Long, Row>> dataWithId = DataSetUtils.zipWithUniqueId(data.getDataSet());
DataSet<String> modelMeta = getModelMeta(model);
DataSet<Tuple3<Integer, String, Long>> modelData = getModelData(model, modelMeta, selectedColNames);
// tuple: record id, column index, token id
DataSet<Tuple3<Long, Integer, Long>> flattened = dataWithId
    .flatMap(new RichFlatMapFunction<Tuple2<Long, Row>, Tuple3<Long, Integer, Long>>() {
private static final long serialVersionUID = 7795878509849151894L;
@Override
public void flatMap(Tuple2<Long, Row> value, Collector<Tuple3<Long, Integer, Long>> out) throws Exception {
for (int i = 0; i < selectedColIdx.length; i++) {
Object o = value.f1.getField(selectedColIdx[i]);
if (o != null) {
out.collect(Tuple3.of(value.f0, i, (Long) o));
} else {
out.collect(Tuple3.of(value.f0, i, -1L));
}
}
}
}).name("flatten_pred_data").returns(new TupleTypeInfo<>(Types.LONG, Types.INT, Types.LONG));
// tuple: record id, column index, token string
DataSet<Tuple3<Long, Integer, String>> indexed = flattened
    .leftOuterJoin(modelData).where(1, 2).equalTo(0, 2)
    .with(new JoinFunction<Tuple3<Long, Integer, Long>, Tuple3<Integer, String, Long>, Tuple3<Long, Integer, String>>() {
private static final long serialVersionUID = -3177975102816197011L;
@Override
public Tuple3<Long, Integer, String> join(Tuple3<Long, Integer, Long> first, Tuple3<Integer, String, Long> second) throws Exception {
if (second == null) {
return Tuple3.of(first.f0, first.f1, "null");
} else {
return Tuple3.of(first.f0, first.f1, second.f1);
}
}
}).name("map_index_to_token").returns(new TupleTypeInfo<>(Types.LONG, Types.INT, Types.STRING));
// tuple: record id, prediction result
DataSet<Tuple2<Long, Row>> aggregateResult = indexed.groupBy(0)
    .reduceGroup(new RichGroupReduceFunction<Tuple3<Long, Integer, String>, Tuple2<Long, Row>>() {
private static final long serialVersionUID = 2318140138585310686L;
@Override
public void reduce(Iterable<Tuple3<Long, Integer, String>> values, Collector<Tuple2<Long, Row>> out) throws Exception {
Long id = null;
Row r = new Row(selectedColIdx.length);
for (Tuple3<Long, Integer, String> v : values) {
r.setField(v.f1, v.f2);
id = v.f0;
}
out.collect(Tuple2.of(id, r));
}
}).name("aggregate_result").returns(new TupleTypeInfo<>(Types.LONG, new RowTypeInfo(outputColTypes)));
DataSet<Row> output = dataWithId.join(aggregateResult).where(0).equalTo(0)
    .with(new JoinFunction<Tuple2<Long, Row>, Tuple2<Long, Row>, Row>() {
private static final long serialVersionUID = 3724539437313089427L;
@Override
public Row join(Tuple2<Long, Row> first, Tuple2<Long, Row> second) throws Exception {
return outputColsHelper.getResultRow(first.f1, second.f1);
}
}).name("merge_result").returns(new RowTypeInfo(outputColsHelper.getResultSchema().getFieldTypes()));
this.setOutput(output, outputColsHelper.getResultSchema());
return this;
}
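The left outer join in map_index_to_token keys on (column index, token id) and substitutes the literal string "null" on a miss. A plain-JDK sketch of the same lookup with hypothetical data:

import java.util.*;

public class IndexToTokenSketch {
    public static void main(String[] args) {
        // Model side: (column index, token id) -> token string.
        Map<List<Long>, String> model = new HashMap<>();
        model.put(Arrays.asList(0L, 7L), "apple");
        // Data side: (column index, token id) pairs; the second id is unseen.
        long[][] queries = {{0L, 7L}, {0L, 99L}};
        for (long[] q : queries) {
            String token = model.getOrDefault(Arrays.asList(q[0], q[1]), "null");
            System.out.println(Arrays.toString(q) + " -> " + token);
        }
    }
}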