Use of com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp in the project Alink by Alibaba.
Class MinMaxScalerTest, method test.
/**
 * Fits a {@link MinMaxScaler} on the batch data, then checks that transforming the batch data
 * and transforming the stream data both produce the same expected field values.
 *
 * @throws Exception if collecting the batch result or executing the stream job fails
 */
@Test
public void test() throws Exception {
    BatchOperator batchData = (BatchOperator) getMultiTypeData(true);
    StreamOperator streamData = (StreamOperator) getMultiTypeData(false);
    String[] selectedColNames = new String[] { "f_long", "f_int", "f_double" };
    // The output column names are deliberately the same as the selected column names.
    MinMaxScaler scaler = new MinMaxScaler().setSelectedCols(selectedColNames).setOutputCols(selectedColNames);
    MinMaxScalerModel model = scaler.fit(batchData);

    // Batch path.
    BatchOperator res = model.transform(batchData);
    List<Row> rows = res.getDataSet().collect();
    assertMinMaxScaledRows(rows);

    // Stream path: the sink only holds values after the stream job has been executed.
    CollectSinkStreamOp collectSinkStreamOp = new CollectSinkStreamOp().linkFrom(model.transform(streamData));
    StreamOperator.execute();
    rows = collectSinkStreamOp.getAndRemoveValues();
    assertMinMaxScaledRows(rows);
}

/**
 * Sorts {@code rows} in place with {@code new RowComparator(0)} and checks fields 2, 3 and 4
 * of the first four rows against the expected scaled values.
 *
 * @param rows the transformed rows; must contain at least four elements
 */
private void assertMinMaxScaledRows(List<Row> rows) {
    rows.sort(new RowComparator(0));
    assertScaledFields(rows.get(0), 0.5, 0.5, 1.0);
    assertScaledFields(rows.get(1), 1.0, 1.0, 0.0);
    assertScaledFields(rows.get(2), null, null, 1.0);
    assertScaledFields(rows.get(3), 0.0, 0.0, null);
}

/**
 * Checks fields 2, 3 and 4 of a single row. The expected value is passed first so that a
 * failure message reports "expected" and "actual" the right way round.
 *
 * @param row       the row under test
 * @param expected2 expected value of field 2 (may be {@code null})
 * @param expected3 expected value of field 3 (may be {@code null})
 * @param expected4 expected value of field 4 (may be {@code null})
 */
private void assertScaledFields(Row row, Object expected2, Object expected3, Object expected4) {
    assertEquals(expected2, row.getField(2));
    assertEquals(expected3, row.getField(3));
    assertEquals(expected4, row.getField(4));
}
Use of com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp in the project Alink by Alibaba.
Class MultiStringIndexerTest, method testMultiStringIndexer.
/**
 * Fits a {@link MultiStringIndexer} on columns {@code f0} and {@code f1} and checks that the
 * batch and the stream transform both emit four rows whose index columns match the
 * {@code map1} / {@code map2} lookups.
 *
 * @throws Exception if collecting the batch result or executing the stream job fails
 */
@Test
public void testMultiStringIndexer() throws Exception {
    // Keep the raw source in its own variable so that the stream path below trains on the
    // same input as the batch path instead of on the already-indexed output.
    BatchOperator source = new MemSourceBatchOp(Arrays.asList(rows), new String[] { "f0", "f1" });
    MultiStringIndexer stringIndexer = new MultiStringIndexer().setSelectedCols("f0", "f1").setOutputCols("f0_index", "f1_index").setHandleInvalid("skip").setStringOrderType("frequency_desc");

    // Batch path: two input columns plus two index columns.
    BatchOperator indexed = stringIndexer.fit(source).transform(source);
    Assert.assertEquals(4, indexed.getColNames().length);
    List<Row> result = indexed.collect();
    assertIndexedRows(result);

    // Stream path: the sink only holds values after the stream job has been executed.
    StreamOperator streamData = new MemSourceStreamOp(Arrays.asList(rows), new String[] { "f0", "f1" });
    CollectSinkStreamOp collectSinkStreamOp = new CollectSinkStreamOp().linkFrom(stringIndexer.fit(source).transform(streamData));
    StreamOperator.execute();
    result = collectSinkStreamOp.getAndRemoveValues();
    assertIndexedRows(result);
}

/**
 * Checks that {@code result} has exactly four rows and that, for every row, field 2 equals
 * {@code map1.get(field 0)} and field 3 equals {@code map2.get(field 1)}.
 *
 * @param result the transformed rows
 */
private void assertIndexedRows(List<Row> result) {
    Assert.assertEquals(4, result.size());
    for (Row row : result) {
        String token1 = (String) row.getField(0);
        Long token2 = (Long) row.getField(1);
        Assert.assertEquals(map1.get(token1), row.getField(2));
        Assert.assertEquals(map2.get(token2), row.getField(3));
    }
}
Use of com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp in the project Alink by Alibaba.
Class StandardScalerTest, method test.
/**
 * Fits a {@link StandardScaler} (with mean and with std enabled) on the batch data, then checks
 * that transforming the batch data and transforming the stream data both produce the same
 * expected field values.
 *
 * @throws Exception if collecting the batch result or executing the stream job fails
 */
@Test
public void test() throws Exception {
    BatchOperator batchData = (BatchOperator) getMultiTypeData(true);
    StreamOperator streamData = (StreamOperator) getMultiTypeData(false);
    String[] selectedColNames = new String[] { "f_long", "f_int", "f_double" };
    StandardScaler scaler = new StandardScaler().setSelectedCols(selectedColNames).setWithMean(true).setWithStd(true);
    scaler.enableLazyPrintModelInfo();
    StandardScalerModel model = scaler.fit(batchData);

    // Batch path.
    BatchOperator res = model.transform(batchData);
    List<Row> rows = res.getDataSet().collect();
    assertStandardizedRows(rows);

    // Stream path: the sink only holds values after the stream job has been executed.
    CollectSinkStreamOp collectSinkStreamOp = new CollectSinkStreamOp().linkFrom(model.transform(streamData));
    StreamOperator.execute();
    rows = collectSinkStreamOp.getAndRemoveValues();
    assertStandardizedRows(rows);
}

/**
 * Sorts {@code rows} in place with {@code new RowComparator(0)} and checks fields 2, 3 and 4
 * of the first four rows against the expected standardized values.
 *
 * @param rows the transformed rows; must contain at least four elements
 */
private void assertStandardizedRows(List<Row> rows) {
    rows.sort(new RowComparator(0));
    assertStandardizedFields(rows.get(0), 0.0, 0.0, 0.0);
    assertStandardizedFields(rows.get(1), 1.0, 1.0, null);
    assertStandardizedFields(rows.get(2), null, null, null);
    assertStandardizedFields(rows.get(3), -1.0, -1.0, null);
}

/**
 * Checks fields 2, 3 and 4 of a single row. The expected value is passed first so that a
 * failure message reports "expected" and "actual" the right way round.
 *
 * @param row       the row under test
 * @param expected2 expected value of field 2 (may be {@code null})
 * @param expected3 expected value of field 3 (may be {@code null})
 * @param expected4 expected value of field 4 (may be {@code null})
 */
private void assertStandardizedFields(Row row, Object expected2, Object expected3, Object expected4) {
    assertEquals(expected2, row.getField(2));
    assertEquals(expected3, row.getField(3));
    assertEquals(expected4, row.getField(4));
}
Use of com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp in the project Alink by Alibaba.
Class StringIndexerTest, method testFrequencyAsc.
/**
 * Runs a {@link StringIndexer} configured with the {@code "frequency_asc"} string order over
 * column {@code f0}, first through the batch transform and then through the stream transform,
 * and hands each collected result to {@code checkResult} together with the expected tokens.
 *
 * @throws Exception if collecting the batch result or executing the stream job fails
 */
@Test
public void testFrequencyAsc() throws Exception {
    BatchOperator batchInput = new MemSourceBatchOp(Arrays.asList(rows), new String[] { "f0" });
    StringIndexer indexer = new StringIndexer()
        .setSelectedCol("f0")
        .setOutputCol("f0_indexed")
        .setStringOrderType("frequency_asc");

    // Batch path: fit on the batch input and transform that same input.
    BatchOperator<?> batchIndexed = indexer.fit(batchInput).transform(batchInput);
    List<Row> batchRows = batchIndexed.collect();
    checkResult(batchRows, new String[] { "tennis", "basketball", "football" });

    // Stream path: fit again on the batch input, then transform the stream source.
    StreamOperator streamInput = new MemSourceStreamOp(Arrays.asList(rows), new String[] { "f0" });
    StreamOperator<?> streamIndexed = indexer.fit(batchInput).transform(streamInput);
    CollectSinkStreamOp sink = new CollectSinkStreamOp().linkFrom(streamIndexed);
    // The sink is read only after the stream job has been executed.
    StreamOperator.execute();
    List<Row> streamRows = sink.getAndRemoveValues();
    checkResult(streamRows, new String[] { "tennis", "basketball", "football" });
}
Use of com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp in the project Alink by Alibaba.
Class LinearRegressionTest, method regressionPipelineTest.
/**
 * Trains a three-stage {@link Pipeline} of {@link LinearRegression} estimators (one on feature
 * columns, two on vector columns) and checks the predictions of the batch and the stream
 * transform for the rows labelled 16.8 and 6.7.
 *
 * @throws Exception if collecting the batch result or executing the stream job fails
 */
@Test
public void regressionPipelineTest() throws Exception {
    BatchOperator<?> vecdata = new MemSourceBatchOp(Arrays.asList(vecrows), veccolNames);
    StreamOperator<?> svecdata = new MemSourceStreamOp(Arrays.asList(vecrows), veccolNames);
    String[] xVars = new String[] { "f0", "f1", "f2" };
    String yVar = "label";
    String vec = "vec";
    String svec = "svec";
    LinearRegression linear = new LinearRegression().setLabelCol(yVar).setFeatureCols(xVars).setMaxIter(20).setOptimMethod("newton").setPredictionCol("linpred");
    LinearRegression vlinear = new LinearRegression().setLabelCol(yVar).setVectorCol(vec).setMaxIter(20).setPredictionCol("vlinpred");
    LinearRegression svlinear = new LinearRegression().setLabelCol(yVar).setVectorCol(svec).setMaxIter(20).setPredictionCol("svlinpred");
    svlinear.enableLazyPrintModelInfo();
    svlinear.enableLazyPrintTrainInfo();
    Pipeline pl = new Pipeline().add(linear).add(vlinear).add(svlinear);
    PipelineModel model = pl.fit(vecdata);

    // Field order relied on by the assertions: 0 = label, 1..3 = the three prediction columns.
    String[] outputCols = new String[] { "label", "linpred", "vlinpred", "svlinpred" };

    // Batch path.
    BatchOperator<?> result = model.transform(vecdata).select(outputCols);
    List<Row> data = result.collect();
    assertPredictions(data);

    // Stream path: the sink only holds values after the stream job has been executed.
    CollectSinkStreamOp sop = model.transform(svecdata).select(outputCols).link(new CollectSinkStreamOp());
    StreamOperator.execute();
    List<Row> rows = sop.getAndRemoveValues();
    assertPredictions(rows);
}

/**
 * Checks the three prediction fields of every row whose label (field 0) is exactly 16.8 or
 * 6.7. Rows with any other label are not checked.
 *
 * @param rows rows laid out as label, linpred, vlinpred, svlinpred
 */
private void assertPredictions(List<Row> rows) {
    for (Row row : rows) {
        double label = (double) row.getField(0);
        if (label == 16.8000) {
            assertPredictionRow(row, 16.814789059973744, 16.814789059973744, 16.814788687904162);
        } else if (label == 6.7000) {
            assertPredictionRow(row, 6.773942836224718, 6.773942836224718, 6.773943529327923);
        }
    }
}

/**
 * Compares fields 1 to 3 of {@code row} with the expected predictions, using an absolute
 * tolerance of 0.01. The expected value is passed first so that a failure message reports
 * "expected" and "actual" the right way round.
 *
 * @param row               the row under test
 * @param expectedLinpred   expected value of field 1 ({@code linpred})
 * @param expectedVlinpred  expected value of field 2 ({@code vlinpred})
 * @param expectedSvlinpred expected value of field 3 ({@code svlinpred})
 */
private void assertPredictionRow(Row row, double expectedLinpred, double expectedVlinpred, double expectedSvlinpred) {
    final double tolerance = 0.01;
    Assert.assertEquals(expectedLinpred, (double) row.getField(1), tolerance);
    Assert.assertEquals(expectedVlinpred, (double) row.getField(2), tolerance);
    Assert.assertEquals(expectedSvlinpred, (double) row.getField(3), tolerance);
}
Aggregations