Usage of `com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp` in project Alink by Alibaba.
Class `StringNearestNeighborTest`, method `testStringApproxStream`.
/**
 * Fits a {@code StringApproxNearestNeighbor} model (top 3) on the batch dictionary, applies it to
 * the streaming queries, and checks every emitted row's similarity scores against the expected
 * values for its query id, with a tolerance of 0.01.
 */
@Test
public void testStringApproxStream() throws Exception {
    BatchOperator<?> dict = new MemSourceBatchOp(Arrays.asList(StringNearestNeighborBatchOpTest.dictRows), new String[] { "id", "str" });
    StreamOperator<?> query = new MemSourceStreamOp(Arrays.asList(StringNearestNeighborBatchOpTest.queryRows), new String[] { "id", "str" });
    StringApproxNearestNeighborModel nearestNeighbor = new StringApproxNearestNeighbor().setIdCol("id").setSelectedCol("str").setTopN(3).setOutputCol("output").fit(dict);
    CollectSinkStreamOp collectSinkStreamOp = new CollectSinkStreamOp().linkFrom(nearestNeighbor.transform(query));
    StreamOperator.execute();
    List<Row> res = collectSinkStreamOp.getAndRemoveValues();
    // Without this guard an empty result would skip the assertion loop and the test would pass vacuously.
    Assert.assertFalse("stream produced no rows", res.isEmpty());
    // Expected scores keyed by query id (field 0 of each output row).
    Map<Object, Double[]> score = new HashMap<>();
    score.put(1, new Double[] { 0.953125, 0.9375, 0.921875 });
    score.put(2, new Double[] { 0.953125, 0.9375, 0.921875 });
    score.put(3, new Double[] { 0.9375, 0.921875, 0.90625 });
    score.put(4, new Double[] { 0.96875, 0.90625, 0.890625 });
    score.put(5, new Double[] { 0.9375, 0.921875, 0.90625 });
    score.put(6, new Double[] { 0.96875, 0.90625, 0.890625 });
    for (Row row : res) {
        Object id = row.getField(0);
        // Field 2 is the "output" column produced by the model.
        Double[] actual = extractScore((String) row.getField(2));
        Double[] expect = score.get(id);
        Assert.assertNotNull("no expected scores for id " + id, expect);
        Assert.assertTrue("more scores than expected for id " + id, actual.length <= expect.length);
        for (int i = 0; i < actual.length; i++) {
            // JUnit's contract is assertEquals(expected, actual, delta).
            Assert.assertEquals(expect[i], actual[i], 0.01);
        }
    }
}
Usage of `com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp` in project Alink by Alibaba.
Class `StringNearestNeighborTest`, method `testTextStream`.
/**
 * Fits a {@code TextNearestNeighbor} model (top 3) on the batch dictionary, applies it to the
 * streaming queries, and checks every emitted row's similarity scores against the expected values
 * for its query id, with a tolerance of 0.01.
 */
@Test
public void testTextStream() throws Exception {
    BatchOperator<?> dict = new MemSourceBatchOp(Arrays.asList(TextApproxNearestNeighborBatchOpTest.dictRows), new String[] { "id", "str" });
    StreamOperator<?> query = new MemSourceStreamOp(Arrays.asList(TextApproxNearestNeighborBatchOpTest.queryRows), new String[] { "id", "str" });
    TextNearestNeighborModel nearestNeighbor = new TextNearestNeighbor().setIdCol("id").setSelectedCol("str").setTopN(3).setOutputCol("output").fit(dict);
    CollectSinkStreamOp collectSinkStreamOp = new CollectSinkStreamOp().linkFrom(nearestNeighbor.transform(query));
    StreamOperator.execute();
    List<Row> res = collectSinkStreamOp.getAndRemoveValues();
    // Without this guard an empty result would skip the assertion loop and the test would pass vacuously.
    Assert.assertFalse("stream produced no rows", res.isEmpty());
    // Expected scores keyed by query id (field 0 of each output row).
    Map<Object, Double[]> score = new HashMap<>();
    score.put(1, new Double[] { 0.75, 0.667, 0.333 });
    score.put(2, new Double[] { 0.667, 0.667, 0.5 });
    score.put(3, new Double[] { 0.333, 0.333, 0.25 });
    score.put(4, new Double[] { 0.75, 0.333, 0.333 });
    score.put(5, new Double[] { 0.333, 0.25, 0.25 });
    score.put(6, new Double[] { 0.333, 0.333, 0.333 });
    for (Row row : res) {
        Object id = row.getField(0);
        // Field 2 is the "output" column produced by the model.
        Double[] actual = StringNearestNeighborBatchOpTest.extractScore((String) row.getField(2));
        Double[] expect = score.get(id);
        Assert.assertNotNull("no expected scores for id " + id, expect);
        Assert.assertTrue("more scores than expected for id " + id, actual.length <= expect.length);
        for (int i = 0; i < actual.length; i++) {
            // JUnit's contract is assertEquals(expected, actual, delta).
            Assert.assertEquals(expect[i], actual[i], 0.01);
        }
    }
}
Usage of `com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp` in project Alink by Alibaba.
Class `StringSimilarityPairwiseTest`, method `testLEVENSHTEINStream`.
/**
 * Streams five string pairs through {@code StringSimilarityPairwise} with the LEVENSHTEIN metric.
 * The collected rows are sorted by id and compared with the expected rows, which carry the edit
 * distance in an extra trailing column.
 */
@Test
public void testLEVENSHTEINStream() throws Exception {
    List<Row> inputRows = Arrays.asList(
        Row.of(0, "abcde", "aabce"),
        Row.of(1, "aacedw", "aabbed"),
        Row.of(2, "cdefa", "bbcefa"),
        Row.of(3, "bdefh", "ddeac"),
        Row.of(4, "acedm", "aeefbc"));
    StreamOperator<?> source = new MemSourceStreamOp(inputRows, "id int, text1 string, text2 string");

    StringSimilarityPairwise levenshtein = new StringSimilarityPairwise()
        .setSelectedCols("text1", "text2")
        .setMetric("LEVENSHTEIN")
        .setOutputCol("output");
    CollectSinkStreamOp sink = new CollectSinkStreamOp().linkFrom(levenshtein.transform(source));
    StreamOperator.execute();

    // Sort by the id column so the comparison does not depend on stream arrival order.
    List<Row> actualRows = sink.getAndRemoveValues();
    actualRows.sort(new RowComparator(0));

    List<Row> expectedRows = Arrays.asList(
        Row.of(0, "abcde", "aabce", 2.0),
        Row.of(1, "aacedw", "aabbed", 3.0),
        Row.of(2, "cdefa", "bbcefa", 3.0),
        Row.of(3, "bdefh", "ddeac", 3.0),
        Row.of(4, "acedm", "aeefbc", 4.0));
    assertListRowEqual(expectedRows, actualRows, 0);
}
Usage of `com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp` in project Alink by Alibaba.
Class `StringSimilarityPairwiseTest`, method `testStream`.
/**
 * Streams five string pairs (Chinese and English) through {@code StringSimilarityPairwise} with
 * the COSINE metric and window size 4. The collected rows are sorted by ID and compared with the
 * expected rows, which carry the similarity in an extra trailing column.
 */
@Test
public void testStream() throws Exception {
    String leftCol = "col0";
    String rightCol = "col1";
    Row[] inputRows = new Row[] {
        Row.of(1L, "北京", "北京"),
        Row.of(2L, "北京欢迎", "中国人民"),
        Row.of(3L, "Beijing", "Beijing"),
        Row.of(4L, "Beijing", "Chinese"),
        Row.of(5L, "Good Morning!", "Good Evening!")
    };
    StreamOperator<?> source = new MemSourceStreamOp(Arrays.asList(inputRows), new String[] { "ID", leftCol, rightCol });

    StringSimilarityPairwise cosine = new StringSimilarityPairwise()
        .setSelectedCols(leftCol, rightCol)
        .setMetric("COSINE")
        .setOutputCol("COSINE")
        .setWindowSize(4);
    StreamOperator<?> scored = cosine.transform(source);
    CollectSinkStreamOp sink = new CollectSinkStreamOp().linkFrom(scored);
    StreamOperator.execute();

    // Sort by the ID column so the comparison does not depend on stream arrival order.
    List<Row> actualRows = sink.getAndRemoveValues();
    actualRows.sort(new RowComparator(0));

    Row[] expectedRows = new Row[] {
        Row.of(1L, "北京", "北京", 1.0),
        Row.of(2L, "北京欢迎", "中国人民", 0.0),
        Row.of(3L, "Beijing", "Beijing", 1.0),
        Row.of(4L, "Beijing", "Chinese", 0.0),
        Row.of(5L, "Good Morning!", "Good Evening!", 0.4)
    };
    assertListRowEqual(Arrays.asList(expectedRows), actualRows, 0);
}
Usage of `com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp` in project Alink by Alibaba.
Class `OneHotTest`, method `pipelineTest`.
/**
 * Fits a two-stage pipeline (one-hot encoding, then vector assembly with the "cnt" column) on
 * batch training data. It then checks that batch prediction and stream prediction on the same two
 * rows both yield the expected assembled sparse vectors.
 */
@Test
public void pipelineTest() throws Exception {
    OneHotEncoder encoder = new OneHotEncoder()
        .setSelectedCols(binaryNames)
        .setOutputCols("results")
        .setDropLast(false)
        .enableLazyPrintModelInfo();
    VectorAssembler assembler = new VectorAssembler()
        .setSelectedCols(new String[] { "cnt", "results" })
        .enableLazyPrintTransformStat("xxxxxx")
        .setOutputCol("outN");
    PipelineModel model = new Pipeline().add(encoder).add(assembler).fit((BatchOperator<?>) getData(true));

    Row[] predictRows = new Row[] {
        Row.of("0", "doc0", "天", 4L),
        Row.of("1", "doc2", null, 3L)
    };
    List<Row> expectedRow = Arrays.asList(
        Row.of("0", new SparseVector(19, new int[] { 0, 3, 10, 16 }, new double[] { 4.0, 1.0, 1.0, 1.0 })),
        Row.of("1", new SparseVector(19, new int[] { 0, 1, 12, 15 }, new double[] { 3.0, 1.0, 1.0, 1.0 })));

    // Batch prediction.
    MemSourceBatchOp batchInput = new MemSourceBatchOp(Arrays.asList(predictRows), schema);
    List<Row> batchResult = model.transform(batchInput).select("id, outN").collect();
    assertListRowEqual(expectedRow, batchResult, 0);

    // Stream prediction on the same rows must agree with the batch result.
    MemSourceStreamOp streamInput = new MemSourceStreamOp(Arrays.asList(predictRows), schema);
    CollectSinkStreamOp streamSink = model.transform(streamInput).select("id, outN").link(new CollectSinkStreamOp());
    StreamOperator.execute();
    assertListRowEqual(expectedRow, streamSink.getAndRemoveValues(), 0);
}
Aggregations