Example usage of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in the Alink project by Alibaba.
From the class Preprocessing, method select.
/**
 * Projects the input batch operator down to the given columns, preserving their
 * requested order.
 *
 * @param in         the source batch operator to project
 * @param selectCols names of the columns to keep; every name must exist in
 *                   {@code in}'s schema (asserted with a hint otherwise)
 * @return a new {@link TableSourceBatchOp} containing only the selected columns,
 *         bound to the same ML environment as {@code in}
 */
public static BatchOperator<?> select(BatchOperator<?> in, String... selectCols) {
	// Resolve positions and types of the requested columns up front; both helpers
	// fail fast with a descriptive hint if a column name is missing.
	final int[] colIndices = TableUtil.findColIndicesWithAssertAndHint(in.getColNames(), selectCols);
	final TypeInformation<?>[] colTypes = TableUtil.findColTypesWithAssertAndHint(in.getSchema(), selectCols);

	return new TableSourceBatchOp(
		DataSetConversionUtil.toTable(
			in.getMLEnvironmentId(),
			in.getDataSet().map(new RichMapFunction<Row, Row>() {
				private static final long serialVersionUID = 9119490369706910594L;

				@Override
				public void open(Configuration parameters) throws Exception {
					super.open(parameters);
					LOG.info("{} open.", getRuntimeContext().getTaskName());
				}

				@Override
				public void close() throws Exception {
					super.close();
					LOG.info("{} close.", getRuntimeContext().getTaskName());
				}

				@Override
				public Row map(Row value) throws Exception {
					// Copy the selected fields into a fresh, narrower row.
					Row projected = new Row(colIndices.length);
					int pos = 0;
					for (int idx : colIndices) {
						projected.setField(pos++, value.getField(idx));
					}
					return projected;
				}
			}),
			selectCols, colTypes))
		.setMLEnvironmentId(in.getMLEnvironmentId());
}
Example usage of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in the Alink project by Alibaba.
From the class TreeModelInfoBatchOp, method combinedTreeModelFeatureImportance.
/**
 * Merges per-feature importance scores into the tree model's metadata.
 *
 * <p>The feature-importance rows (featureName, score) are collapsed into one JSON
 * string, broadcast to the model-side reducer, and stored under
 * {@link TreeModelInfo#FEATURE_IMPORTANCE} in the rewritten model rows.
 *
 * @param model             the serialized tree model as a batch operator
 * @param featureImportance rows of (feature name, importance score)
 * @return a batch operator emitting the model rows with importance metadata attached
 */
private static BatchOperator<?> combinedTreeModelFeatureImportance(BatchOperator<?> model, BatchOperator<?> featureImportance) {
	// Collapse all (name, score) rows into a single JSON map string.
	DataSet<String> importanceJson = featureImportance.getDataSet().reduceGroup(new GroupReduceFunction<Row, String>() {
		private static final long serialVersionUID = -1576541700351312745L;

		@Override
		public void reduce(Iterable<Row> values, Collector<String> out) throws Exception {
			Map<String, Double> importance = new HashMap<>();
			for (Row val : values) {
				importance.put(String.valueOf(val.getField(0)), ((Number) val.getField(1)).doubleValue());
			}
			out.collect(JsonConverter.toJson(importance));
		}
	});
	// NOTE(review): this anonymous class reuses the serialVersionUID of the one
	// above; harmless (different classes), but distinct values would be clearer.
	DataSet<Row> combined = model.getDataSet().reduceGroup(new RichGroupReduceFunction<Row, Row>() {
		private static final long serialVersionUID = -1576541700351312745L;

		private transient String featureImportanceJson;

		@Override
		public void open(Configuration parameters) throws Exception {
			super.open(parameters);
			// The importance DataSet has exactly one element; take it directly.
			featureImportanceJson = getRuntimeContext().getBroadcastVariableWithInitializer(
				"importanceJson",
				new BroadcastVariableInitializer<String, String>() {
					@Override
					public String initializeBroadcastVariable(Iterable<String> data) {
						return data.iterator().next();
					}
				});
		}

		@Override
		public void reduce(Iterable<Row> values, Collector<Row> out) throws Exception {
			List<Row> modelRows = new ArrayList<>();
			for (Row val : values) {
				modelRows.add(val);
			}
			// Renamed from "model" to stop shadowing the enclosing method's
			// "model" parameter (a BatchOperator) — the old name made
			// "model.save(model, out)" needlessly confusing.
			TreeModelDataConverter converter = new TreeModelDataConverter().load(modelRows);
			converter.meta.set(TreeModelInfo.FEATURE_IMPORTANCE, featureImportanceJson);
			converter.save(converter, out);
		}
	}).withBroadcastSet(importanceJson, "importanceJson");
	// NOTE(review): unlike Preprocessing.select, the returned op does not call
	// setMLEnvironmentId(model.getMLEnvironmentId()); confirm the constructor's
	// default environment is the intended one.
	return new TableSourceBatchOp(DataSetConversionUtil.toTable(model.getMLEnvironmentId(), combined, model.getColNames(), model.getColTypes()));
}
Example usage of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in the Alink project by Alibaba.
From the class SampleWithSizeBatchOpTest, method test.
@Test
public void test() throws Exception {
	// Sample exactly 5 rows (with replacement) from the batch table.
	TableSourceBatchOp tableSourceBatchOp = new TableSourceBatchOp(getBatchTable());
	long cnt = tableSourceBatchOp.link(new SampleWithSizeBatchOp(5, true)).count();
	// Explicit check instead of the `assert` keyword: Java assertions are
	// disabled unless the JVM runs with -ea, so a bare `assert` in a unit
	// test can silently pass even when the condition is false.
	if (cnt != 5) {
		throw new AssertionError("expected 5 sampled rows, got " + cnt);
	}
}
Example usage of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in the Alink project by Alibaba.
From the class StandardScalerTest, method testModelInfo.
@Test
public void testModelInfo() {
	// Wildcard type parameter avoids raw-type usage on the operator handle.
	BatchOperator<?> batchData = new TableSourceBatchOp(GenerateData.getBatchTable());
	// Train a scaler that both centers (mean) and scales (std) column f0.
	StandardScalerTrainBatchOp trainOp = new StandardScalerTrainBatchOp()
		.setWithMean(true)
		.setWithStd(true)
		.setSelectedCols("f0")
		.linkFrom(batchData);
	// Smoke-test the model-info accessors; this only verifies they run
	// without throwing, the printed values are not asserted.
	StandardScalerModelInfo modelInfo = trainOp.getModelInfoBatchOp().collectModelInfo();
	System.out.println(modelInfo.getMeans().length);
	System.out.println(modelInfo.getStdDevs().length);
	System.out.println(modelInfo.isWithMeans());
	System.out.println(modelInfo.isWithStdDevs());
	System.out.println(modelInfo.toString());
}
Example usage of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in the Alink project by Alibaba.
From the class StandardScalerTest, method test.
@Test
public void test() throws Exception {
	// Wildcard type parameters avoid raw-type usage on the operator handles.
	BatchOperator<?> batchData = new TableSourceBatchOp(GenerateData.getBatchTable());
	StreamOperator<?> streamData = new TableSourceStreamOp(GenerateData.getStreamTable());

	// Case 1 (operator API): center AND scale — z-score normalization.
	StandardScalerTrainBatchOp op = new StandardScalerTrainBatchOp()
		.setWithMean(true).setWithStd(true).setSelectedCols("f0", "f1").linkFrom(batchData);
	// Lambdas replace the anonymous Consumer<List<Row>> classes; Consumer is a
	// functional interface, so behavior is identical.
	new StandardScalerPredictBatchOp().setOutputCols("f0_1", "f1_1").linkFrom(op, batchData).lazyCollect(rows -> {
		rows.sort(compare);
		assertEquals(rows.get(0), Row.of(null, null, null, null));
		assertRow(rows.get(1), Row.of(-1., -3., -0.9272, -1.1547));
		assertRow(rows.get(2), Row.of(1., 2., -0.1325, 0.5774));
		assertRow(rows.get(3), Row.of(4., 2., 1.0596, 0.5774));
	});
	new StandardScalerPredictStreamOp(op).setOutputCols("f0_1", "f1_1").linkFrom(streamData).print();

	// Case 2 (pipeline API): center only — values are shifted by the mean.
	StandardScalerModel model1 = new StandardScaler().setWithMean(true).setWithStd(false)
		.setSelectedCols("f0", "f1").setOutputCols("f0_1", "f1_1").fit(batchData);
	model1.transform(batchData).lazyCollect(rows -> {
		rows.sort(compare);
		assertEquals(rows.get(0), Row.of(null, null, null, null));
		assertRow(rows.get(1), Row.of(-1., -3., -2.3333, -3.3333));
		assertRow(rows.get(2), Row.of(1., 2., -0.3333, 1.6666));
		assertRow(rows.get(3), Row.of(4., 2., 2.6666, 1.6666));
	});
	model1.transform(streamData).print();

	// Case 3: scale only — values are divided by the standard deviation.
	StandardScalerModel model2 = new StandardScaler().setWithMean(false).setWithStd(true)
		.setSelectedCols("f0", "f1").setOutputCols("f0_1", "f1_1").fit(batchData);
	model2.transform(batchData).lazyCollect(rows -> {
		rows.sort(compare);
		assertEquals(rows.get(0), Row.of(null, null, null, null));
		assertRow(rows.get(1), Row.of(-1., -3., -0.3974, -1.0392));
		assertRow(rows.get(2), Row.of(1., 2., 0.3974, 0.6928));
		assertRow(rows.get(3), Row.of(4., 2., 1.5894, 0.6928));
	});
	model2.transform(streamData).print();

	// Case 4: neither — the transform is the identity on the selected columns.
	StandardScalerModel model3 = new StandardScaler().setWithMean(false).setWithStd(false)
		.setSelectedCols("f0", "f1").setOutputCols("f0_1", "f1_1").fit(batchData);
	model3.transform(batchData).lazyCollect(rows -> {
		rows.sort(compare);
		assertEquals(rows.get(0), Row.of(null, null, null, null));
		assertRow(rows.get(1), Row.of(-1., -3., -1., -3.));
		assertRow(rows.get(2), Row.of(1., 2., 1., 2.));
		assertRow(rows.get(3), Row.of(4., 2., 4., 2.));
	});
	model3.transform(streamData).print();

	// Triggers execution of all the stream-side print() sinks above.
	StreamOperator.execute();
}
Aggregations