Usage of com.alibaba.alink.pipeline.dataproc.StandardScalerModel in the Alink project by Alibaba.
From class StandardScalerTest, method test:
/**
 * Exercises StandardScaler training and prediction for every combination of
 * {@code withMean} / {@code withStd}, on both the batch and the stream test data.
 *
 * <p>Batch results are checked through {@code lazyCollect} callbacks; stream results are only
 * printed. The expected values encode non-null input rows (f0, f1) = (-1, -3), (1, 2), (4, 2),
 * i.e. column means 4/3 and 1/3 and sample standard deviations of about 2.5166 and 2.8868.
 *
 * <p>NOTE(review): there is no explicit batch {@code execute()} here; the lazy batch assertions
 * presumably fire when a later job (e.g. a stream predictor loading its model) triggers batch
 * execution. Confirm they actually run, otherwise they are vacuous.
 */
@Test
public void test() throws Exception {
    BatchOperator<?> batchData = new TableSourceBatchOp(GenerateData.getBatchTable());
    StreamOperator<?> streamData = new TableSourceStreamOp(GenerateData.getStreamTable());

    // withMean + withStd, via the explicit train / predict operators: (x - mean) / std.
    StandardScalerTrainBatchOp op = new StandardScalerTrainBatchOp()
        .setWithMean(true)
        .setWithStd(true)
        .setSelectedCols("f0", "f1")
        .linkFrom(batchData);
    new StandardScalerPredictBatchOp()
        .setOutputCols("f0_1", "f1_1")
        .linkFrom(op, batchData)
        .lazyCollect(expectScaledRows(
            Row.of(-1., -3., -0.9272, -1.1547),
            Row.of(1., 2., -0.1325, 0.5774),
            Row.of(4., 2., 1.0596, 0.5774)));
    new StandardScalerPredictStreamOp(op).setOutputCols("f0_1", "f1_1").linkFrom(streamData).print();

    // withMean only: x - mean.
    StandardScalerModel model1 = fitScaler(batchData, true, false);
    model1.transform(batchData).lazyCollect(expectScaledRows(
        Row.of(-1., -3., -2.3333, -3.3333),
        Row.of(1., 2., -0.3333, 1.6667),
        Row.of(4., 2., 2.6667, 1.6667)));
    model1.transform(streamData).print();

    // withStd only: x / std.
    StandardScalerModel model2 = fitScaler(batchData, false, true);
    model2.transform(batchData).lazyCollect(expectScaledRows(
        Row.of(-1., -3., -0.3974, -1.0392),
        Row.of(1., 2., 0.3974, 0.6928),
        Row.of(4., 2., 1.5894, 0.6928)));
    model2.transform(streamData).print();

    // Neither: the output columns equal the inputs.
    StandardScalerModel model3 = fitScaler(batchData, false, false);
    model3.transform(batchData).lazyCollect(expectScaledRows(
        Row.of(-1., -3., -1., -3.),
        Row.of(1., 2., 1., 2.),
        Row.of(4., 2., 4., 2.)));
    model3.transform(streamData).print();

    StreamOperator.execute();
}

/**
 * Fits a StandardScaler pipeline stage on columns f0/f1, writing results to f0_1/f1_1.
 *
 * @param data     batch training data
 * @param withMean whether to subtract the column mean
 * @param withStd  whether to divide by the column standard deviation
 * @return the fitted model
 */
private StandardScalerModel fitScaler(BatchOperator<?> data, boolean withMean, boolean withStd) {
    return new StandardScaler()
        .setWithMean(withMean)
        .setWithStd(withStd)
        .setSelectedCols("f0", "f1")
        .setOutputCols("f0_1", "f1_1")
        .fit(data);
}

/**
 * Builds a callback that verifies a collected batch prediction of the scaler test data.
 *
 * <p>The collected rows are sorted with {@code compare}. The first row must be the all-null
 * input row, followed by exactly {@code expectedDataRows} in order (each compared with
 * {@code assertRow}).
 *
 * @param expectedDataRows expected non-null rows (f0, f1, f0_1, f1_1), in sorted order
 * @return a consumer suitable for {@code lazyCollect}
 */
private Consumer<List<Row>> expectScaledRows(Row... expectedDataRows) {
    return rows -> {
        rows.sort(compare);
        // One all-null row plus the data rows; catches missing or unexpected extra rows.
        assertEquals(expectedDataRows.length + 1, rows.size());
        // JUnit convention: expected value first, actual value second.
        assertEquals(Row.of(null, null, null, null), rows.get(0));
        for (int i = 0; i < expectedDataRows.length; i++) {
            // Same argument order as the original assertRow calls (actual row first).
            assertRow(rows.get(i + 1), expectedDataRows[i]);
        }
    };
}
Aggregations