Use of com.alibaba.alink.common.utils.OutputColsHelper in project Alink by alibaba.
The class Mapper, method initializeSliced.
protected final void initializeSliced() {
    if (null != ioSchema) {
        OutputColsHelper outputColsHelper = new OutputColsHelper(
            getDataSchema(), ioSchema.f1, ioSchema.f2, ioSchema.f3);
        this.transformer = new MemoryTransformer(
            TableUtil.findColIndicesWithAssertAndHint(getDataSchema(), outputColsHelper.getReservedColumns()),
            TableUtil.findColIndicesWithAssertAndHint(getOutputSchema(), outputColsHelper.getReservedColumns()));
        this.selection = new SlicedSelectedSampleThreadLocal(
            TableUtil.findColIndicesWithAssertAndHint(getDataSchema(), ioSchema.f0));
        this.result = new SlicedSlicedResultThreadLocal(
            TableUtil.findColIndicesWithAssertAndHint(getOutputSchema(), ioSchema.f1));
    }
}
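The snippets on this page only exercise OutputColsHelper indirectly, so here is a minimal standalone sketch (ours, not from the Alink sources) of the pattern they all rely on: declare which output columns a prediction will produce, then let the helper report the reserved input columns, derive the result schema, and merge a prediction row back into an input row. The class name and column names are invented; the constructor and the accessors getReservedColumns, getResultSchema and getResultRow are used as in the snippets below, but treat the exact merge semantics as an assumption to verify.

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import com.alibaba.alink.common.utils.OutputColsHelper;

public class OutputColsHelperSketch {
    public static void main(String[] args) throws Exception {
        // Input schema: two columns, "id" and "text".
        TableSchema dataSchema = new TableSchema(
            new String[] {"id", "text"},
            new TypeInformation[] {Types.LONG, Types.STRING});

        // Append one prediction column "keywords" of type STRING. With this
        // three-argument constructor (as in KeywordsExtractionBatchOp below),
        // every input column not overwritten by an output column is reserved.
        OutputColsHelper helper = new OutputColsHelper(dataSchema, "keywords", Types.STRING);

        System.out.println(String.join(",", helper.getReservedColumns())); // expected: id,text
        System.out.println(helper.getResultSchema());                      // expected: id, text, keywords

        // Merge an input row with a one-column prediction row.
        Row merged = helper.getResultRow(Row.of(1L, "hello world"), Row.of("hello"));
        System.out.println(merged);                                        // expected: 1,hello world,hello
    }
}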
Use of com.alibaba.alink.common.utils.OutputColsHelper in project Alink by alibaba.
The class KeywordsExtractionBatchOp, method linkFrom.
@Override
public KeywordsExtractionBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    final String docId = "doc_alink_id";
    String selectedColName = this.getSelectedCol();
    TableUtil.assertSelectedColExist(in.getColNames(), selectedColName);
    String outputColName = this.getOutputCol();
    if (null == outputColName) {
        outputColName = selectedColName;
    }
    OutputColsHelper outputColsHelper = new OutputColsHelper(in.getSchema(), outputColName, Types.STRING);
    final Integer topN = this.getTopN();
    Method method = this.getMethod();
    BatchOperator inWithId = new TableSourceBatchOp(
        AppendIdBatchOp.appendId(in.getDataSet(), in.getSchema(), getMLEnvironmentId()))
        .setMLEnvironmentId(getMLEnvironmentId());
    DataSet<Row> weights;
    StopWordsRemoverBatchOp filterOp = new StopWordsRemoverBatchOp()
        .setMLEnvironmentId(getMLEnvironmentId())
        .setSelectedCol(selectedColName)
        .setOutputCol("selectedColName"); // literal column name, reused below
    BatchOperator filtered = filterOp.linkFrom(inWithId);
    switch (method) {
        case TF_IDF: {
            DocWordCountBatchOp wordCount = new DocWordCountBatchOp()
                .setMLEnvironmentId(getMLEnvironmentId())
                .setDocIdCol(AppendIdBatchOp.appendIdColName)
                .setContentCol("selectedColName");
            TfidfBatchOp tfIdf = new TfidfBatchOp()
                .setMLEnvironmentId(getMLEnvironmentId())
                .setDocIdCol(AppendIdBatchOp.appendIdColName)
                .setWordCol("word")
                .setCountCol("cnt");
            BatchOperator op = filtered.link(wordCount).link(tfIdf);
            weights = op.select(AppendIdBatchOp.appendIdColName + ", " + "word, tfidf").getDataSet();
            break;
        }
        case TEXT_RANK: {
            DataSet<Row> data = filtered.select(AppendIdBatchOp.appendIdColName + ", selectedColName").getDataSet();
            // Initialize the TextRank class, which runs the text rank algorithm.
            final Params params = getParams();
            weights = data.flatMap(new FlatMapFunction<Row, Row>() {
                private static final long serialVersionUID = -4083643981693873537L;

                @Override
                public void flatMap(Row row, Collector<Row> collector) throws Exception {
                    // For each row, apply the text rank algorithm to get the keywords.
                    Row[] out = TextRank.getKeyWords(row,
                        params.get(KeywordsExtractionParams.DAMPING_FACTOR),
                        params.get(KeywordsExtractionParams.WINDOW_SIZE),
                        params.get(KeywordsExtractionParams.MAX_ITER),
                        params.get(KeywordsExtractionParams.EPSILON));
                    for (int i = 0; i < out.length; i++) {
                        collector.collect(out[i]);
                    }
                }
            });
            break;
        }
        default: {
            throw new RuntimeException("Not support this type!");
        }
    }
    // Group the word weights by document id, then keep the topN words per document.
    DataSet<Row> res = weights.groupBy(new KeySelector<Row, String>() {
        private static final long serialVersionUID = 801794449492798203L;

        @Override
        public String getKey(Row row) {
            Object obj = row.getField(0);
            if (obj == null) {
                return "NULL";
            }
            return row.getField(0).toString();
        }
    }).reduceGroup(new GroupReduceFunction<Row, Row>() {
        private static final long serialVersionUID = -4051509261188494119L;

        @Override
        public void reduce(Iterable<Row> rows, Collector<Row> collector) {
            List<Row> list = new ArrayList<>();
            for (Row row : rows) {
                list.add(row);
            }
            // Sort the words of one document by weight, descending.
            Collections.sort(list, new Comparator<Row>() {
                @Override
                public int compare(Row row1, Row row2) {
                    Double v1 = (double) row1.getField(2);
                    Double v2 = (double) row2.getField(2);
                    return v2.compareTo(v1);
                }
            });
            int len = Math.min(list.size(), topN);
            Row out = new Row(2);
            StringBuilder builder = new StringBuilder();
            for (int i = 0; i < len; i++) {
                builder.append(list.get(i).getField(1).toString());
                if (i != len - 1) {
                    builder.append(" ");
                }
            }
            out.setField(0, list.get(0).getField(0));
            out.setField(1, builder.toString());
            collector.collect(out);
        }
    });
    // Convert the result DataSet into a Table.
    Table tmpTable = DataSetConversionUtil.toTable(getMLEnvironmentId(), res,
        new String[] {docId, outputColName},
        new TypeInformation[] {Types.LONG, Types.STRING});
    StringBuilder selectClause = new StringBuilder("a." + outputColName);
    String[] keepColNames = outputColsHelper.getReservedColumns();
    for (int i = 0; i < keepColNames.length; i++) {
        selectClause.append("," + keepColNames[i]);
    }
    // Join the extracted keywords back to the original columns on the appended doc id.
    JoinBatchOp join = new JoinBatchOp()
        .setMLEnvironmentId(getMLEnvironmentId())
        .setType("join")
        .setSelectClause(selectClause.toString())
        .setJoinPredicate(docId + "=" + AppendIdBatchOp.appendIdColName);
    this.setOutputTable(join.linkFrom(
        new TableSourceBatchOp(tmpTable).setMLEnvironmentId(getMLEnvironmentId()),
        inWithId).getOutputTable());
    return this;
}
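For context, a hypothetical driver for the operator above (not from the Alink sources). The setter names mirror the getters used in linkFrom (getSelectedCol, getTopN, getMethod), and MemSourceBatchOp is assumed as a convenient in-memory source; verify both against the actual Params definitions before relying on them.

// Hypothetical driver, assuming MemSourceBatchOp(Object[][], String[]) and
// String-valued setMethod; the column name "content" is ours.
BatchOperator<?> docs = new MemSourceBatchOp(
    new Object[][] {
        {"Alink is a machine learning algorithm platform based on Flink."},
        {"Flink is a framework for stateful computations over data streams."}
    },
    new String[] {"content"});

BatchOperator<?> keywords = docs.link(
    new KeywordsExtractionBatchOp()
        .setSelectedCol("content") // the document column; required
        .setTopN(3)                // keep the three highest-weighted words
        .setMethod("TEXT_RANK"));  // or "TF_IDF", per the switch above
keywords.print();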
Use of com.alibaba.alink.common.utils.OutputColsHelper in project Alink by alibaba.
The class HugeMultiStringIndexerPredictBatchOp, method linkFrom.
@Override
public HugeMultiStringIndexerPredictBatchOp linkFrom(BatchOperator<?>... inputs) {
    Params params = super.getParams();
    BatchOperator model = inputs[0];
    BatchOperator data = inputs[1];
    String[] selectedColNames = params.get(MultiStringIndexerPredictParams.SELECTED_COLS);
    String[] outputColNames = params.get(MultiStringIndexerPredictParams.OUTPUT_COLS);
    if (outputColNames == null) {
        outputColNames = selectedColNames;
    }
    String[] keepColNames = params.get(StringIndexerPredictParams.RESERVED_COLS);
    TypeInformation[] outputColTypes = new TypeInformation[outputColNames.length];
    Arrays.fill(outputColTypes, Types.LONG);
    OutputColsHelper outputColsHelper = new OutputColsHelper(data.getSchema(), outputColNames, outputColTypes, keepColNames);
    final int[] selectedColIdx = TableUtil.findColIndicesWithAssertAndHint(data.getSchema(), selectedColNames);
    final HandleInvalid handleInvalidStrategy =
        HandleInvalid.valueOf(params.get(StringIndexerPredictParams.HANDLE_INVALID).toString());
    DataSet<Tuple2<Long, Row>> dataWithId = DataSetUtils.zipWithUniqueId(data.getDataSet());
    DataSet<String> modelMeta = getModelMeta(model);
    DataSet<Tuple3<Integer, String, Long>> modelData = getModelData(model, modelMeta, selectedColNames);
    // tuple: column index, default token id
    DataSet<Tuple2<Integer, Long>> defaultIndex = modelData.<Tuple2<Integer, Long>>project(0, 2)
        .mapPartition(new MapPartitionFunction<Tuple2<Integer, Long>, Tuple2<Integer, Long>>() {
            @Override
            public void mapPartition(Iterable<Tuple2<Integer, Long>> iterable,
                                     Collector<Tuple2<Integer, Long>> collector) throws Exception {
                // Pre-aggregate within the partition: max token id per column.
                HashMap<Integer, Long> map = new HashMap<>();
                for (Tuple2<Integer, Long> value : iterable) {
                    map.put(value.f0, Math.max(map.getOrDefault(value.f0, 0L), value.f1));
                }
                map.forEach((key, value) -> collector.collect(Tuple2.of(key, value)));
            }
        })
        .groupBy(0)
        .reduce(new ReduceFunction<Tuple2<Integer, Long>>() {
            private static final long serialVersionUID = 5053931294560858595L;

            @Override
            public Tuple2<Integer, Long> reduce(Tuple2<Integer, Long> value1, Tuple2<Integer, Long> value2)
                throws Exception {
                return Tuple2.of(value1.f0, Math.max(value1.f1, value2.f1));
            }
        })
        .map(new MapFunction<Tuple2<Integer, Long>, Tuple2<Integer, Long>>() {
            private static final long serialVersionUID = 2371384596429653822L;

            @Override
            public Tuple2<Integer, Long> map(Tuple2<Integer, Long> value) throws Exception {
                // The default id for unseen tokens of a column is its max token id plus one.
                return Tuple2.of(value.f0, value.f1 + 1L);
            }
        })
        .name("get_default_index")
        .returns(new TupleTypeInfo<>(Types.INT, Types.LONG));
    // tuple: record id, column index, token
    DataSet<Tuple3<Long, Integer, String>> flattened = dataWithId
        .flatMap(new RichFlatMapFunction<Tuple2<Long, Row>, Tuple3<Long, Integer, String>>() {
            private static final long serialVersionUID = -8382461068855755626L;

            @Override
            public void flatMap(Tuple2<Long, Row> value, Collector<Tuple3<Long, Integer, String>> out)
                throws Exception {
                for (int i = 0; i < selectedColIdx.length; i++) {
                    Object o = value.f1.getField(selectedColIdx[i]);
                    if (o != null) {
                        out.collect(Tuple3.of(value.f0, i, String.valueOf(o)));
                    }
                }
            }
        })
        .name("flatten_pred_data")
        .returns(new TupleTypeInfo<>(Types.LONG, Types.INT, Types.STRING));
    // tuple: record id, column index, token id
    DataSet<Tuple3<Long, Integer, Long>> indexedNulTokens = dataWithId
        .flatMap(new FlatMapFunction<Tuple2<Long, Row>, Tuple3<Long, Integer, Long>>() {
            private static final long serialVersionUID = 4078100010408649546L;

            @Override
            public void flatMap(Tuple2<Long, Row> value, Collector<Tuple3<Long, Integer, Long>> out)
                throws Exception {
                for (int i = 0; i < selectedColIdx.length; i++) {
                    Object o = value.f1.getField(selectedColIdx[i]);
                    if (o == null) {
                        // Null values are ignored during training, so they are always
                        // treated as "unseen" tokens.
                        out.collect(Tuple3.of(value.f0, i, -1L));
                    }
                }
            }
        })
        .name("map_null_token_to_index")
        .returns(new TupleTypeInfo<>(Types.LONG, Types.INT, Types.LONG));
    // tuple: record id, column index, token index
    DataSet<Tuple3<Long, Integer, Long>> indexed = flattened
        .leftOuterJoin(modelData)
        .where(1, 2).equalTo(0, 1)
        .with(new JoinFunction<Tuple3<Long, Integer, String>, Tuple3<Integer, String, Long>,
            Tuple3<Long, Integer, Long>>() {
            private static final long serialVersionUID = -2049684621727655644L;

            @Override
            public Tuple3<Long, Integer, Long> join(Tuple3<Long, Integer, String> first,
                                                    Tuple3<Integer, String, Long> second) throws Exception {
                if (second == null) {
                    // The token is not in the model: mark it as unseen with -1.
                    return Tuple3.of(first.f0, first.f1, -1L);
                } else {
                    return Tuple3.of(first.f0, first.f1, second.f2);
                }
            }
        })
        .name("map_token_to_index")
        .returns(new TupleTypeInfo<>(Types.LONG, Types.INT, Types.LONG));
    // tuple: record id, prediction result
    DataSet<Tuple2<Long, Row>> aggregateResult = indexed.union(indexedNulTokens)
        .groupBy(0)
        .reduceGroup(new RichGroupReduceFunction<Tuple3<Long, Integer, Long>, Tuple2<Long, Row>>() {
            private static final long serialVersionUID = -1581264399340055162L;

            transient Map<Integer, Long> defaultIndex;

            @Override
            public void open(Configuration parameters) throws Exception {
                // The default indices are only needed for the KEEP strategy.
                if (handleInvalidStrategy.equals(SKIP) || handleInvalidStrategy.equals(ERROR)) {
                    return;
                }
                List<Tuple2<Integer, Long>> bc = getRuntimeContext().getBroadcastVariable("defaultIndex");
                defaultIndex = new HashMap<>();
                for (int i = 0; i < bc.size(); i++) {
                    defaultIndex.put(bc.get(i).f0, bc.get(i).f1);
                }
            }

            @Override
            public void reduce(Iterable<Tuple3<Long, Integer, Long>> values, Collector<Tuple2<Long, Row>> out)
                throws Exception {
                Long id = null;
                Row r = new Row(selectedColIdx.length);
                for (Tuple3<Long, Integer, Long> v : values) {
                    Long index = v.f2;
                    if (index == -1L) {
                        switch (handleInvalidStrategy) {
                            case KEEP:
                                index = defaultIndex.get(v.f1);
                                index = index == null ? 0L : index;
                                break;
                            case SKIP:
                                index = null;
                                break;
                            case ERROR:
                                throw new RuntimeException("Unknown token.");
                        }
                    }
                    int col = v.f1;
                    r.setField(col, index);
                    id = v.f0;
                }
                out.collect(Tuple2.of(id, r));
            }
        })
        .withBroadcastSet(defaultIndex, "defaultIndex")
        .name("aggregate_result")
        .returns(new TupleTypeInfo<>(Types.LONG, new RowTypeInfo(outputColTypes)));
    DataSet<Row> output = dataWithId.join(aggregateResult)
        .where(0).equalTo(0)
        .with(new JoinFunction<Tuple2<Long, Row>, Tuple2<Long, Row>, Row>() {
            private static final long serialVersionUID = 3724539437313089427L;

            @Override
            public Row join(Tuple2<Long, Row> first, Tuple2<Long, Row> second) throws Exception {
                return outputColsHelper.getResultRow(first.f1, second.f1);
            }
        })
        .name("merge_result")
        .returns(new RowTypeInfo(outputColsHelper.getResultSchema().getFieldTypes()));
    this.setOutput(output, outputColsHelper.getResultSchema());
    return this;
}
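A hypothetical wiring of this operator (not from the Alink sources): as the first lines of linkFrom show, inputs[0] is the string-indexer model and inputs[1] the data to index. The training operator, the setter names, and the "keep" value are assumptions to verify against the actual API; the behavior comments restate what the code above does.

// Hypothetical wiring, assuming MemSourceBatchOp and MultiStringIndexerTrainBatchOp
// behave as their names suggest; column names are ours.
BatchOperator<?> input = new MemSourceBatchOp(
    new Object[][] {{"book", "red"}, {"pen", "blue"}, {"book", "green"}},
    new String[] {"item", "color"});

BatchOperator<?> train = new MultiStringIndexerTrainBatchOp()
    .setSelectedCols("item", "color")
    .linkFrom(input);

BatchOperator<?> indexed = new HugeMultiStringIndexerPredictBatchOp()
    .setSelectedCols("item", "color") // string columns to map to LONG ids
    .setHandleInvalid("keep")         // unseen tokens get (max id in that column) + 1
    .linkFrom(train, input);          // inputs[0] = model, inputs[1] = data
indexed.print();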
Use of com.alibaba.alink.common.utils.OutputColsHelper in project Alink by alibaba.
The class HugeStringIndexerPredictBatchOp, method linkFrom.
@Override
public HugeStringIndexerPredictBatchOp linkFrom(BatchOperator<?>... inputs) {
    Params params = super.getParams();
    BatchOperator model = inputs[0];
    BatchOperator data = inputs[1];
    String[] selectedColNames = params.get(HugeMultiStringIndexerPredictParams.SELECTED_COLS);
    String[] outputColNames = params.get(HugeMultiStringIndexerPredictParams.OUTPUT_COLS);
    if (outputColNames == null) {
        outputColNames = selectedColNames;
    }
    String[] keepColNames = params.get(HugeMultiStringIndexerPredictParams.RESERVED_COLS);
    TypeInformation[] outputColTypes = new TypeInformation[outputColNames.length];
    Arrays.fill(outputColTypes, Types.LONG);
    OutputColsHelper outputColsHelper = new OutputColsHelper(data.getSchema(), outputColNames, outputColTypes, keepColNames);
    final int[] selectedColIdx = TableUtil.findColIndicesWithAssertAndHint(data.getSchema(), selectedColNames);
    final HandleInvalid handleInvalidStrategy =
        HandleInvalid.valueOf(params.get(HugeMultiStringIndexerPredictParams.HANDLE_INVALID).toString());
    DataSet<Tuple2<Long, Row>> dataWithId = DataSetUtils.zipWithUniqueId(data.getDataSet());
    DataSet<Tuple2<String, Long>> modelData = getModelData(model);
    // default token id: the max token id in the model, plus one
    DataSet<Long> defaultIndex = modelData
        .mapPartition(new MapPartitionFunction<Tuple2<String, Long>, Long>() {
            @Override
            public void mapPartition(Iterable<Tuple2<String, Long>> iterable, Collector<Long> collector)
                throws Exception {
                // Pre-aggregate within the partition: max token id.
                Long max = 0L;
                for (Tuple2<String, Long> value : iterable) {
                    max = Math.max(max, value.f1);
                }
                collector.collect(max);
            }
        })
        .reduceGroup(new GroupReduceFunction<Long, Long>() {
            @Override
            public void reduce(Iterable<Long> iterable, Collector<Long> collector) throws Exception {
                Long max = 0L;
                for (Long value : iterable) {
                    max = Math.max(max, value);
                }
                collector.collect(max + 1L);
            }
        })
        .name("get_default_index")
        .returns(Types.LONG);
    // tuple: record id, column index, token
    DataSet<Tuple3<Long, Integer, String>> flattened = dataWithId
        .flatMap(new RichFlatMapFunction<Tuple2<Long, Row>, Tuple3<Long, Integer, String>>() {
            private static final long serialVersionUID = 1917562146323592635L;

            @Override
            public void flatMap(Tuple2<Long, Row> value, Collector<Tuple3<Long, Integer, String>> out)
                throws Exception {
                for (int i = 0; i < selectedColIdx.length; i++) {
                    Object o = value.f1.getField(selectedColIdx[i]);
                    if (o != null) {
                        out.collect(Tuple3.of(value.f0, i, String.valueOf(o)));
                    }
                }
            }
        })
        .name("flatten_pred_data")
        .returns(new TupleTypeInfo<>(Types.LONG, Types.INT, Types.STRING));
    // tuple: record id, column index, token id
    DataSet<Tuple3<Long, Integer, Long>> indexedNulTokens = dataWithId
        .flatMap(new FlatMapFunction<Tuple2<Long, Row>, Tuple3<Long, Integer, Long>>() {
            private static final long serialVersionUID = 2893742534366079246L;

            @Override
            public void flatMap(Tuple2<Long, Row> value, Collector<Tuple3<Long, Integer, Long>> out)
                throws Exception {
                for (int i = 0; i < selectedColIdx.length; i++) {
                    Object o = value.f1.getField(selectedColIdx[i]);
                    if (o == null) {
                        // Null values are ignored during training, so they are always
                        // treated as "unseen" tokens.
                        out.collect(Tuple3.of(value.f0, i, -1L));
                    }
                }
            }
        })
        .name("map_null_token_to_index")
        .returns(new TupleTypeInfo<>(Types.LONG, Types.INT, Types.LONG));
    // tuple: record id, column index, token index
    DataSet<Tuple3<Long, Integer, Long>> indexed = flattened
        .leftOuterJoin(modelData)
        .where(2).equalTo(0)
        .with(new JoinFunction<Tuple3<Long, Integer, String>, Tuple2<String, Long>,
            Tuple3<Long, Integer, Long>>() {
            private static final long serialVersionUID = 2270459281179536013L;

            @Override
            public Tuple3<Long, Integer, Long> join(Tuple3<Long, Integer, String> first,
                                                    Tuple2<String, Long> second) throws Exception {
                if (second == null) {
                    // The token is not in the model: mark it as unseen with -1.
                    return Tuple3.of(first.f0, first.f1, -1L);
                } else {
                    return Tuple3.of(first.f0, first.f1, second.f1);
                }
            }
        })
        .name("map_token_to_index")
        .returns(new TupleTypeInfo<>(Types.LONG, Types.INT, Types.LONG));
    // tuple: record id, prediction result
    DataSet<Tuple2<Long, Row>> aggregateResult = indexed.union(indexedNulTokens)
        .groupBy(0)
        .reduceGroup(new RichGroupReduceFunction<Tuple3<Long, Integer, Long>, Tuple2<Long, Row>>() {
            private static final long serialVersionUID = -1581264399340055162L;

            transient Long defaultIndex;

            @Override
            public void open(Configuration parameters) throws Exception {
                // The default index is only needed for the KEEP strategy.
                if (handleInvalidStrategy.equals(SKIP) || handleInvalidStrategy.equals(ERROR)) {
                    return;
                }
                List<Long> bc = getRuntimeContext().getBroadcastVariable("defaultIndex");
                defaultIndex = bc.get(0);
            }

            @Override
            public void reduce(Iterable<Tuple3<Long, Integer, Long>> values, Collector<Tuple2<Long, Row>> out)
                throws Exception {
                Long id = null;
                Row r = new Row(selectedColIdx.length);
                for (Tuple3<Long, Integer, Long> v : values) {
                    Long index = v.f2;
                    if (index == -1L) {
                        switch (handleInvalidStrategy) {
                            case KEEP:
                                index = defaultIndex;
                                index = index == null ? 0L : index;
                                break;
                            case SKIP:
                                index = null;
                                break;
                            case ERROR:
                                throw new RuntimeException("Unknown token.");
                        }
                    }
                    int col = v.f1;
                    r.setField(col, index);
                    id = v.f0;
                }
                out.collect(Tuple2.of(id, r));
            }
        })
        .withBroadcastSet(defaultIndex, "defaultIndex")
        .name("aggregate_result")
        .returns(new TupleTypeInfo<>(Types.LONG, new RowTypeInfo(outputColTypes)));
    DataSet<Row> output = dataWithId.join(aggregateResult)
        .where(0).equalTo(0)
        .with(new JoinFunction<Tuple2<Long, Row>, Tuple2<Long, Row>, Row>() {
            private static final long serialVersionUID = 3724539437313089427L;

            @Override
            public Row join(Tuple2<Long, Row> first, Tuple2<Long, Row> second) throws Exception {
                return outputColsHelper.getResultRow(first.f1, second.f1);
            }
        })
        .name("merge_result")
        .returns(new RowTypeInfo(outputColsHelper.getResultSchema().getFieldTypes()));
    this.setOutput(output, outputColsHelper.getResultSchema());
    return this;
}
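To contrast with the multi-column variant above: here the left outer join is on the token alone (where(2).equalTo(0)), so every selected column is looked up in one shared token-to-id dictionary and ids are comparable across columns. A hypothetical wiring follows (not from the Alink sources; the model is assumed to come from StringIndexerTrainBatchOp, and the setter names are unverified).

// Hypothetical wiring: one dictionary, applied to several columns.
BatchOperator<?> input = new MemSourceBatchOp(
    new Object[][] {{"book", "book"}, {"pen", "pencil"}},
    new String[] {"left", "right"});

BatchOperator<?> train = new StringIndexerTrainBatchOp()
    .setSelectedCol("left")           // the dictionary is built from one column
    .linkFrom(input);

BatchOperator<?> indexed = new HugeStringIndexerPredictBatchOp()
    .setSelectedCols("left", "right") // both columns share the same dictionary
    .setHandleInvalid("keep")         // unseen tokens (e.g. "pencil") get (global max id) + 1
    .linkFrom(train, input);
indexed.print();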