Usage of com.alibaba.alink.operator.batch.source.DataSetWrapperBatchOp in the Alink project (by Alibaba):
class Preprocessing, method generateStringIndexerModel.
/**
 * Builds the string-indexer model used to encode categorical columns.
 *
 * <p>When {@code params} names at least one categorical column, a {@link MultiStringIndexerTrainBatchOp}
 * is trained on {@code input} over those columns with alphabetical-ascending token order. Otherwise an
 * operator with no rows, carrying the {@link MultiStringIndexerModelDataConverter} model schema, is returned.
 *
 * @param input  training data; its ML environment id is propagated to every created operator
 * @param params may contain {@code HasCategoricalCols.CATEGORICAL_COLS}
 * @return the trained indexer model, or a row-less model with the indexer model schema
 */
public static BatchOperator<?> generateStringIndexerModel(BatchOperator<?> input, Params params) {
	// The parameter is optional, so only read it when it has actually been set.
	final String[] categoricalCols = params.contains(HasCategoricalCols.CATEGORICAL_COLS)
		? params.get(HasCategoricalCols.CATEGORICAL_COLS)
		: null;

	if (categoricalCols != null && categoricalCols.length > 0) {
		return new MultiStringIndexerTrainBatchOp()
			.setMLEnvironmentId(input.getMLEnvironmentId())
			.setSelectedCols(categoricalCols)
			.setStringOrderType(HasStringOrderTypeDefaultAsRandom.StringOrderType.ALPHABET_ASC)
			.linkFrom(input);
	}

	// No categorical columns: emit a model that has the right schema but no rows.
	final MultiStringIndexerModelDataConverter schemaSource = new MultiStringIndexerModelDataConverter();
	return new DataSetWrapperBatchOp(
		MLEnvironmentFactory.get(input.getMLEnvironmentId())
			.getExecutionEnvironment()
			.fromElements(1)
			.mapPartition(new MapPartitionFunction<Integer, Row>() {
				private static final long serialVersionUID = -7481931851291494026L;

				@Override
				public void mapPartition(Iterable<Integer> values, Collector<Row> out) throws Exception {
					// Deliberately collects nothing; the seed element only gives the DataSet a source.
				}
			}),
		schemaSource.getModelSchema().getFieldNames(),
		schemaSource.getModelSchema().getFieldTypes()
	).setMLEnvironmentId(input.getMLEnvironmentId());
}
Usage of com.alibaba.alink.operator.batch.source.DataSetWrapperBatchOp in the Alink project (by Alibaba):
class Preprocessing, method generateQuantileDiscretizerModel.
/**
 * Builds the quantile-discretizer model used to bucketize continuous features.
 *
 * <p>If {@code params} contains a vector column, a vector discretizer is trained on a sample of
 * {@code input}. Otherwise the continuous columns are the feature columns minus the categorical
 * columns. If any remain, a {@link QuantileDiscretizerTrainBatchOp} is trained on a sample over them.
 * If none remain, an operator with no rows, carrying the {@link QuantileDiscretizerModelDataConverter}
 * model schema, is returned.
 *
 * @param input  training data; its ML environment id is propagated to every created operator
 * @param params must provide {@code MAX_BINS} and {@code ZERO_AS_MISSING}, plus either
 *               {@code VECTOR_COL} or {@code FEATURE_COLS}. {@code CATEGORICAL_COLS} is optional.
 * @return the trained discretizer model, or a row-less model with the discretizer model schema
 */
public static BatchOperator<?> generateQuantileDiscretizerModel(BatchOperator<?> input, Params params) {
	if (params.contains(HasVectorCol.VECTOR_COL)) {
		return sample(input, params).linkTo(
			new VectorTrain(new Params().set(ZERO_AS_MISSING, params.get(ZERO_AS_MISSING)))
				.setMLEnvironmentId(input.getMLEnvironmentId())
				.setVectorCol(params.get(HasVectorCol.VECTOR_COL))
				.setNumBuckets(params.get(HasMaxBins.MAX_BINS))
		);
	}

	String[] featureColNames = params.get(HasFeatureCols.FEATURE_COLS);
	// CATEGORICAL_COLS is optional (generateStringIndexerModel guards it the same way). Reading it
	// unconditionally throws when it is unset and has no default. A null here makes
	// ArrayUtils.removeElements remove nothing.
	String[] categoricalColNames = params.contains(HasCategoricalCols.CATEGORICAL_COLS)
		? params.get(HasCategoricalCols.CATEGORICAL_COLS)
		: null;
	String[] continuousColNames = ArrayUtils.removeElements(featureColNames, categoricalColNames);

	if (continuousColNames != null && continuousColNames.length > 0) {
		return sample(input, params).linkTo(
			new QuantileDiscretizerTrainBatchOp(new Params().set(ZERO_AS_MISSING, params.get(ZERO_AS_MISSING)))
				.setMLEnvironmentId(input.getMLEnvironmentId())
				.setSelectedCols(continuousColNames)
				.setNumBuckets(params.get(HasMaxBins.MAX_BINS))
		);
	}

	// No continuous columns: emit a model that has the right schema but no rows.
	QuantileDiscretizerModelDataConverter emptyModel = new QuantileDiscretizerModelDataConverter();
	return new DataSetWrapperBatchOp(
		MLEnvironmentFactory.get(input.getMLEnvironmentId())
			.getExecutionEnvironment()
			.fromElements(1)
			.mapPartition(new MapPartitionFunction<Integer, Row>() {
				private static final long serialVersionUID = 2328781103352773618L;

				@Override
				public void mapPartition(Iterable<Integer> values, Collector<Row> out) throws Exception {
					// Deliberately collects nothing; the seed element only gives the DataSet a source.
				}
			}),
		emptyModel.getModelSchema().getFieldNames(),
		emptyModel.getModelSchema().getFieldTypes()
	).setMLEnvironmentId(input.getMLEnvironmentId());
}
Usage of com.alibaba.alink.operator.batch.source.DataSetWrapperBatchOp in the Alink project (by Alibaba):
class ItemCfTrainBatchOp, method linkFrom.
/**
 * Trains an item-based collaborative-filtering model from (user, item[, rate]) rows.
 *
 * <p>Pipeline:
 * <ol>
 *   <li>Encode user and item values to indices with a one-hot model in INDEX mode.</li>
 *   <li>Build one sparse vector per item (via ItemVectorGenerator) and per-user item lists (via UserItemVectorGenerator).</li>
 *   <li>Find each item's nearest neighbours with the vector nearest-neighbour train/predict operators.</li>
 *   <li>On subtask 0, attach metadata (rate column, item tokens, user/item type strings) and save
 *       everything through ItemCfRecommModelDataConverter.</li>
 * </ol>
 *
 * <p>If no rate column is configured, only the JACCARD similarity type is accepted.
 *
 * @param inputs the first input is the training table
 * @return this operator, with its output set to the saved model rows
 */
@Override
public ItemCfTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
BatchOperator<?> in = checkAndGetFirst(inputs);
final String userCol = getUserCol();
final String itemCol = getItemCol();
final String rateCol = getRateCol();
final TypeInformation<?> userType = TableUtil.findColTypeWithAssertAndHint(in.getSchema(), userCol);
final String itemType = FlinkTypeConverter.getTypeString(TableUtil.findColTypeWithAssertAndHint(in.getSchema(), itemCol));
// Without explicit ratings, only set-based (Jaccard) similarity is allowed.
if (null == rateCol) {
Preconditions.checkArgument(getSimilarityType().equals(SimilarityType.JACCARD), "When rateCol is not given, only Jaccard calc is supported!");
}
// Narrow the input to user, item and (if present) rate. The order matters for the positional groupBy calls below.
String[] selectedCols = (null == rateCol ? new String[] { userCol, itemCol } : new String[] { userCol, itemCol, rateCol });
in = in.select(selectedCols);
OneHotTrainBatchOp oneHot = new OneHotTrainBatchOp().setSelectedCols(userCol, itemCol).linkFrom(in);
String USER_ENCODE = "userEncode";
// The user index goes to a new USER_ENCODE column. itemCol is listed as its own output column, so it is
// overwritten in place by its index encoding.
BatchOperator<?> userEncode = new OneHotPredictBatchOp().setSelectedCols(userCol, itemCol).setOutputCols(USER_ENCODE, itemCol).setEncode(HasEncodeWithoutWoe.Encode.INDEX).linkFrom(oneHot, in);
// Per-item vectors: groupBy(1) is itemCol in this selection (USER_ENCODE, itemCol[, rateCol]).
// One-hot side output 0 is broadcast as USER_NUM. Presumably it carries token counts; TODO confirm.
DataSet<Row> itemVector = userEncode.select(null == rateCol ? new String[] { USER_ENCODE, itemCol } : new String[] { USER_ENCODE, itemCol, rateCol }).getDataSet().groupBy(1).reduceGroup(new ItemVectorGenerator(rateCol, userCol)).withBroadcastSet(oneHot.getSideOutput(0).getDataSet(), USER_NUM).name("GenerateItemVector");
// Per-user item lists: groupBy(0) is the raw (unencoded) userCol in this selection.
DataSet<Row> userVector = userEncode.select(null == rateCol ? new String[] { userCol, itemCol } : new String[] { userCol, itemCol, rateCol }).getDataSet().groupBy(0).reduceGroup(new UserItemVectorGenerator(rateCol, itemCol)).withBroadcastSet(oneHot.getSideOutput(0).getDataSet(), USER_NUM).name("GetUserItems");
// Wrap item vectors as a table of (Long id, sparse vector) named by COL_NAMES.
BatchOperator<?> items = new DataSetWrapperBatchOp(itemVector, COL_NAMES, new TypeInformation[] { Types.LONG, VectorTypes.SPARSE_VECTOR });
// The nearest-neighbour metric is chosen by matching the similarity type's enum name.
BatchOperator<?> train = new VectorNearestNeighborTrainBatchOp().setIdCol(COL_NAMES[0]).setSelectedCol(COL_NAMES[1]).setMetric(HasFastMetric.Metric.valueOf(this.getSimilarityType().name())).setMLEnvironmentId(getMLEnvironmentId()).linkFrom(items);
// NOTE(review): topN is maxNeighborNumber + 1, presumably because each item is returned as its own
// neighbour. The radius 1.0 - similarityThreshold assumes distance = 1 - similarity. Confirm both.
BatchOperator<?> op = new VectorNearestNeighborPredictBatchOp().setSelectedCol(COL_NAMES[1]).setReservedCols(COL_NAMES[0]).setTopN(this.getMaxNeighborNumber() + 1).setRadius(1.0 - this.getSimilarityThreshold()).linkFrom(train, items);
DataSet<Row> itemSimilarities = op.select(new String[] { COL_NAMES[0], COL_NAMES[1] }).getDataSet().mapPartition(new ItemSimilarityVectorGenerator(itemCol)).withBroadcastSet(oneHot.getSideOutput(0).getDataSet(), USER_NUM).name("CalcItemSimilarity");
// Drop one-hot model rows whose field 0 equals 0L. Presumably field 0 is the column index and 0 is
// userCol (the first selected column), so item-related rows remain. TODO confirm.
// NOTE(review): value.getField(0).equals(...) throws NPE if field 0 is ever null.
DataSet<Row> itemMapTable = oneHot.getDataSet().filter(new FilterFunction<Row>() {
private static final long serialVersionUID = 7406134775433418651L;
@Override
public boolean filter(Row value) {
return !value.getField(0).equals(0L);
}
});
Params params = getParams();
// Save user rows and item-similarity rows together. Only subtask 0 builds the metadata; the other
// subtasks pass a null meta to the converter.
DataSet<Row> outs = userVector.union(itemSimilarities).mapPartition(new RichMapPartitionFunction<Row, Row>() {
private static final long serialVersionUID = 3779020277896699637L;
@Override
public void mapPartition(Iterable<Row> values, Collector<Row> out) {
Params meta = null;
if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
// Recover item tokens in index order from the broadcast (filtered) one-hot model.
List<Row> modelRows = getRuntimeContext().getBroadcastVariable("ITEM_MAP");
MultiStringIndexerModelData modelData = new OneHotModelDataConverter().load(modelRows).modelData;
String[] items = new String[(int) modelData.getNumberOfTokensOfColumn(itemCol)];
for (int i = 0; i < items.length; i++) {
items[i] = modelData.getToken(itemCol, (long) i);
}
// This set(...) mutates the task-side deserialized copy of the captured params.
meta = params.set(ItemCfRecommTrainParams.RATE_COL, rateCol).set(ItemCfRecommModelDataConverter.ITEMS, items).set(ItemCfRecommModelDataConverter.ITEM_TYPE, itemType).set(ItemCfRecommModelDataConverter.USER_TYPE, FlinkTypeConverter.getTypeString(userType));
}
new ItemCfRecommModelDataConverter(userCol, userType, itemCol).save(Tuple2.of(meta, values), out);
}
}).withBroadcastSet(itemMapTable, "ITEM_MAP").name("build_model");
this.setOutput(outs, new ItemCfRecommModelDataConverter(userCol, userType, itemCol).getModelSchema());
return this;
}
Usage of com.alibaba.alink.operator.batch.source.DataSetWrapperBatchOp in the Alink project (by Alibaba):
class LeaveKObjectOutBatchOp, method linkFrom.
/**
 * Splits the input into train and test parts per group, then zips the test part with Zipped2KObjectBatchOp.
 *
 * <p>Rows are grouped by the group column and tagged by {@code Split}. Rows tagged {@code true} become
 * the training side-output table. Rows tagged {@code false} are wrapped as a table with the input's
 * schema and fed to {@link Zipped2KObjectBatchOp}, which is configured with this operator's params. The
 * zipped result becomes the main output.
 *
 * @param inputs the first input is the table to split
 * @return this operator
 */
@Override
public LeaveKObjectOutBatchOp linkFrom(BatchOperator<?>... inputs) {
	final BatchOperator<?> source = checkAndGetFirst(inputs);
	final Double fraction = getFraction();
	final Integer k = getK();

	final DataSet<Row> sourceRows = source.getDataSet();
	final int groupColIdx = TableUtil.findColIndexWithAssertAndHint(source.getSchema(), this.getGroupCol());
	final DataSet<Tuple2<Boolean, Row>> tagged = sourceRows.groupBy(groupColIdx).reduceGroup(new Split(fraction, k));

	// Rows tagged true go to the training side output.
	final DataSet<Row> trainRows = tagged.flatMap(new FlatMapFunction<Tuple2<Boolean, Row>, Row>() {
		private static final long serialVersionUID = 8194568313949873147L;

		@Override
		public void flatMap(Tuple2<Boolean, Row> entry, Collector<Row> collector) {
			if (!entry.f0) {
				return;
			}
			collector.collect(entry.f1);
		}
	});

	// Rows tagged false form the held-out test part.
	final DataSet<Row> testRows = tagged.flatMap(new FlatMapFunction<Tuple2<Boolean, Row>, Row>() {
		private static final long serialVersionUID = 7652429568716812411L;

		@Override
		public void flatMap(Tuple2<Boolean, Row> entry, Collector<Row> collector) {
			if (entry.f0) {
				return;
			}
			collector.collect(entry.f1);
		}
	});

	final BatchOperator<?> testOp = new DataSetWrapperBatchOp(testRows, source.getColNames(), source.getColTypes())
		.setMLEnvironmentId(getMLEnvironmentId());
	final Zipped2KObjectBatchOp zipped = new Zipped2KObjectBatchOp(getParams()).linkFrom(testOp);

	this.setOutput(zipped.getDataSet(), zipped.getSchema());
	this.setSideOutputTables(new Table[] {
		DataSetConversionUtil.toTable(getMLEnvironmentId(), trainRows, source.getSchema())
	});
	return this;
}
Usage of com.alibaba.alink.operator.batch.source.DataSetWrapperBatchOp in the Alink project (by Alibaba):
class LeaveTopKObjectOutBatchOp, method linkFrom.
/**
 * Splits the input into train and test parts per group, using the rate column and rate threshold, then
 * zips the test part with Zipped2KObjectBatchOp.
 *
 * <p>Rows are grouped by the group column and tagged by {@code Split}, which also receives the rate
 * threshold and the rate column's index. Rows tagged {@code true} become the training side-output
 * table. Rows tagged {@code false} are wrapped as a table with the input's schema and fed to
 * {@link Zipped2KObjectBatchOp}, with the rate column passed as its INFO_COLS.
 *
 * @param inputs the first input is the table to split
 * @return this operator
 */
@Override
public LeaveTopKObjectOutBatchOp linkFrom(BatchOperator<?>... inputs) {
	BatchOperator<?> in = checkAndGetFirst(inputs);
	final Double testFraction = getFraction();
	final Integer testK = getK();
	final int rateIdx = TableUtil.findColIndexWithAssertAndHint(in.getSchema(), this.getRateCol());
	final int groupIdx = TableUtil.findColIndexWithAssertAndHint(in.getSchema(), this.getGroupCol());

	DataSet<Tuple2<Boolean, Row>> splits = in.getDataSet()
		.groupBy(groupIdx)
		.reduceGroup(new Split(testFraction, testK, this.getRateThreshold(), rateIdx));

	// Rows tagged true go to the training side output.
	DataSet<Row> train = splits.flatMap(new FlatMapFunction<Tuple2<Boolean, Row>, Row>() {
		private static final long serialVersionUID = -2766287446658278413L;

		@Override
		public void flatMap(Tuple2<Boolean, Row> value, Collector<Row> out) {
			if (value.f0) {
				out.collect(value.f1);
			}
		}
	});

	// Rows tagged false form the held-out test part.
	DataSet<Row> test = splits.flatMap(new FlatMapFunction<Tuple2<Boolean, Row>, Row>() {
		private static final long serialVersionUID = 3051286291048876503L;

		@Override
		public void flatMap(Tuple2<Boolean, Row> value, Collector<Row> out) {
			if (!value.f0) {
				out.collect(value.f1);
			}
		}
	});

	BatchOperator<?> testOp = new DataSetWrapperBatchOp(test, in.getColNames(), in.getColTypes())
		.setMLEnvironmentId(getMLEnvironmentId());
	// Clone before adding INFO_COLS. Calling set(...) directly on getParams() would permanently write
	// INFO_COLS into this operator's own params as a side effect of linkFrom.
	Zipped2KObjectBatchOp op = new Zipped2KObjectBatchOp(
		getParams().clone().set(Zipped2KObjectParams.INFO_COLS, new String[] {getRateCol()})
	).linkFrom(testOp);

	this.setOutput(op.getDataSet(), op.getSchema());
	this.setSideOutputTables(new Table[] {DataSetConversionUtil.toTable(getMLEnvironmentId(), train, in.getSchema())});
	return this;
}
Aggregations