Example use of org.apache.solr.client.solrj.io.stream.HashJoinStream in the lucene-solr project by Apache.
Taken from the class GraphExpressionTest, method testGatherNodesFriendsStream.
/**
 * Exercises {@code gatherNodes} over a small "who messaged whom" graph indexed as
 * {@code from_s -> to_s} edges.
 *
 * <p>The scenarios covered, in order:
 * <ol>
 *   <li>a one-hop walk from the literal root {@code bill};</li>
 *   <li>the same walk with {@code scatter="branches, leaves"} and {@code trackTraversal};</li>
 *   <li>a one-hop walk whose root nodes come from a nested {@code search};</li>
 *   <li>that search-rooted walk with scattering of branches and leaves;</li>
 *   <li>a two-hop walk built by nesting one {@code gatherNodes} inside another;</li>
 *   <li>two traversals in a single expression, joined through {@code hashJoin};</li>
 *   <li>the two-hop walk with scattering, checking the level of every node;</li>
 *   <li>the same after adding edges back to {@code bill}, checking bill's ancestors.</li>
 * </ol>
 *
 * <p>Results are sorted by {@code node} before asserting, so the expected positions below are
 * alphabetical rather than traversal order.
 *
 * @throws Exception if indexing, stream construction or stream execution fails
 */
@Test
public void testGatherNodesFriendsStream() throws Exception {
  new UpdateRequest()
      .add(id, "0", "from_s", "bill", "to_s", "jim", "message_t", "Hello jim")
      .add(id, "1", "from_s", "bill", "to_s", "sam", "message_t", "Hello sam")
      .add(id, "2", "from_s", "bill", "to_s", "max", "message_t", "Hello max")
      .add(id, "3", "from_s", "max", "to_s", "kip", "message_t", "Hello kip")
      .add(id, "4", "from_s", "sam", "to_s", "steve", "message_t", "Hello steve")
      .add(id, "5", "from_s", "jim", "to_s", "ann", "message_t", "Hello steve")
      .commit(cluster.getSolrClient(), COLLECTION);

  StreamFactory factory = new StreamFactory()
      .withCollectionZkHost("collection1", cluster.getZkServer().getZkAddress())
      .withFunctionName("gatherNodes", GatherNodesStream.class)
      .withFunctionName("search", CloudSolrStream.class)
      .withFunctionName("count", CountMetric.class)
      .withFunctionName("hashJoin", HashJoinStream.class)
      .withFunctionName("avg", MeanMetric.class)
      .withFunctionName("sum", SumMetric.class)
      .withFunctionName("min", MinMetric.class)
      .withFunctionName("max", MaxMetric.class);

  // The client cache is shared by every stream context below. It is closed in the finally
  // block so that a failing assertion does not leak it.
  SolrClientCache cache = new SolrClientCache();
  try {
    // Test a single hop from a literal root node
    String expr = "gatherNodes(collection1, " + "walk=\"bill->from_s\"," + "gather=\"to_s\")";
    GatherNodesStream stream = (GatherNodesStream) factory.constructStream(expr);
    StreamContext context = new StreamContext();
    context.setSolrClientCache(cache);
    stream.setStreamContext(context);
    List<Tuple> tuples = getTuples(stream);
    Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
    assertTrue(tuples.size() == 3);
    assertTrue(tuples.get(0).getString("node").equals("jim"));
    assertTrue(tuples.get(1).getString("node").equals("max"));
    assertTrue(tuples.get(2).getString("node").equals("sam"));

    // Test scatter branches, leaves and trackTraversal
    expr = "gatherNodes(collection1, " + "walk=\"bill->from_s\"," + "gather=\"to_s\","
        + "scatter=\"branches, leaves\", trackTraversal=\"true\")";
    stream = (GatherNodesStream) factory.constructStream(expr);
    context = new StreamContext();
    context.setSolrClientCache(cache);
    stream.setStreamContext(context);
    tuples = getTuples(stream);
    Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
    assertTrue(tuples.size() == 4);
    assertTrue(tuples.get(0).getString("node").equals("bill"));
    assertTrue(tuples.get(0).getLong("level").equals(Long.valueOf(0L)));
    assertTrue(tuples.get(0).getStrings("ancestors").size() == 0);
    assertTrue(tuples.get(1).getString("node").equals("jim"));
    assertTrue(tuples.get(1).getLong("level").equals(Long.valueOf(1L)));
    // JUnit assertions (not the Java assert keyword) so these checks run even without -ea.
    List<String> ancestors = tuples.get(1).getStrings("ancestors");
    assertTrue("jim should have exactly one ancestor", ancestors.size() == 1);
    assertTrue(ancestors.get(0).equals("bill"));
    assertTrue(tuples.get(2).getString("node").equals("max"));
    assertTrue(tuples.get(2).getLong("level").equals(Long.valueOf(1L)));
    ancestors = tuples.get(2).getStrings("ancestors");
    assertTrue("max should have exactly one ancestor", ancestors.size() == 1);
    assertTrue(ancestors.get(0).equals("bill"));
    assertTrue(tuples.get(3).getString("node").equals("sam"));
    assertTrue(tuples.get(3).getLong("level").equals(Long.valueOf(1L)));
    ancestors = tuples.get(3).getStrings("ancestors");
    assertTrue("sam should have exactly one ancestor", ancestors.size() == 1);
    assertTrue(ancestors.get(0).equals("bill"));

    // Test query root
    expr = "gatherNodes(collection1, "
        + "search(collection1, q=\"message_t:jim\", fl=\"from_s\", sort=\"from_s asc\"),"
        + "walk=\"from_s->from_s\"," + "gather=\"to_s\")";
    stream = (GatherNodesStream) factory.constructStream(expr);
    context = new StreamContext();
    context.setSolrClientCache(cache);
    stream.setStreamContext(context);
    tuples = getTuples(stream);
    Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
    assertTrue(tuples.size() == 3);
    assertTrue(tuples.get(0).getString("node").equals("jim"));
    assertTrue(tuples.get(1).getString("node").equals("max"));
    assertTrue(tuples.get(2).getString("node").equals("sam"));

    // Test query root scatter branches
    expr = "gatherNodes(collection1, "
        + "search(collection1, q=\"message_t:jim\", fl=\"from_s\", sort=\"from_s asc\"),"
        + "walk=\"from_s->from_s\"," + "gather=\"to_s\", scatter=\"branches, leaves\")";
    stream = (GatherNodesStream) factory.constructStream(expr);
    context = new StreamContext();
    context.setSolrClientCache(cache);
    stream.setStreamContext(context);
    tuples = getTuples(stream);
    Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
    assertTrue(tuples.size() == 4);
    assertTrue(tuples.get(0).getString("node").equals("bill"));
    assertTrue(tuples.get(0).getLong("level").equals(Long.valueOf(0L)));
    assertTrue(tuples.get(1).getString("node").equals("jim"));
    assertTrue(tuples.get(1).getLong("level").equals(Long.valueOf(1L)));
    assertTrue(tuples.get(2).getString("node").equals("max"));
    assertTrue(tuples.get(2).getLong("level").equals(Long.valueOf(1L)));
    assertTrue(tuples.get(3).getString("node").equals("sam"));
    assertTrue(tuples.get(3).getLong("level").equals(Long.valueOf(1L)));

    // Test a two-hop walk: the inner gatherNodes supplies the roots of the outer one
    expr = "gatherNodes(collection1, "
        + "search(collection1, q=\"message_t:jim\", fl=\"from_s\", sort=\"from_s asc\"),"
        + "walk=\"from_s->from_s\"," + "gather=\"to_s\")";
    String expr2 = "gatherNodes(collection1, " + expr + "," + "walk=\"node->from_s\","
        + "gather=\"to_s\")";
    stream = (GatherNodesStream) factory.constructStream(expr2);
    context = new StreamContext();
    context.setSolrClientCache(cache);
    stream.setStreamContext(context);
    tuples = getTuples(stream);
    Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
    assertTrue(tuples.size() == 3);
    assertTrue(tuples.get(0).getString("node").equals("ann"));
    assertTrue(tuples.get(1).getString("node").equals("kip"));
    assertTrue(tuples.get(2).getString("node").equals("steve"));

    // Test two traversals in the same expression
    String expr3 = "hashJoin(" + expr2 + ", hashed=" + expr2 + ", on=\"node\")";
    HashJoinStream hstream = (HashJoinStream) factory.constructStream(expr3);
    context = new StreamContext();
    context.setSolrClientCache(cache);
    hstream.setStreamContext(context);
    tuples = getTuples(hstream);
    Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
    assertTrue(tuples.size() == 3);
    assertTrue(tuples.get(0).getString("node").equals("ann"));
    assertTrue(tuples.get(1).getString("node").equals("kip"));
    assertTrue(tuples.get(2).getString("node").equals("steve"));

    // Test the two-hop walk with scatter branches, leaves
    expr = "gatherNodes(collection1, "
        + "search(collection1, q=\"message_t:jim\", fl=\"from_s\", sort=\"from_s asc\"),"
        + "walk=\"from_s->from_s\"," + "gather=\"to_s\")";
    expr2 = "gatherNodes(collection1, " + expr + "," + "walk=\"node->from_s\","
        + "gather=\"to_s\", scatter=\"branches, leaves\")";
    stream = (GatherNodesStream) factory.constructStream(expr2);
    context = new StreamContext();
    context.setSolrClientCache(cache);
    stream.setStreamContext(context);
    tuples = getTuples(stream);
    Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
    assertTrue(tuples.size() == 7);
    assertTrue(tuples.get(0).getString("node").equals("ann"));
    assertTrue(tuples.get(0).getLong("level").equals(Long.valueOf(2L)));
    assertTrue(tuples.get(1).getString("node").equals("bill"));
    assertTrue(tuples.get(1).getLong("level").equals(Long.valueOf(0L)));
    assertTrue(tuples.get(2).getString("node").equals("jim"));
    assertTrue(tuples.get(2).getLong("level").equals(Long.valueOf(1L)));
    assertTrue(tuples.get(3).getString("node").equals("kip"));
    assertTrue(tuples.get(3).getLong("level").equals(Long.valueOf(2L)));
    assertTrue(tuples.get(4).getString("node").equals("max"));
    assertTrue(tuples.get(4).getLong("level").equals(Long.valueOf(1L)));
    assertTrue(tuples.get(5).getString("node").equals("sam"));
    assertTrue(tuples.get(5).getLong("level").equals(Long.valueOf(1L)));
    assertTrue(tuples.get(6).getString("node").equals("steve"));
    assertTrue(tuples.get(6).getLong("level").equals(Long.valueOf(2L)));

    // Add cycles back to bill from both jim and sam
    new UpdateRequest()
        .add(id, "6", "from_s", "jim", "to_s", "bill", "message_t", "Hello steve")
        .add(id, "7", "from_s", "sam", "to_s", "bill", "message_t", "Hello steve")
        .commit(cluster.getSolrClient(), COLLECTION);
    expr = "gatherNodes(collection1, "
        + "search(collection1, q=\"message_t:jim\", fl=\"from_s\", sort=\"from_s asc\"),"
        + "walk=\"from_s->from_s\"," + "gather=\"to_s\", trackTraversal=\"true\")";
    expr2 = "gatherNodes(collection1, " + expr + "," + "walk=\"node->from_s\","
        + "gather=\"to_s\", scatter=\"branches, leaves\", trackTraversal=\"true\")";
    stream = (GatherNodesStream) factory.constructStream(expr2);
    context = new StreamContext();
    context.setSolrClientCache(cache);
    stream.setStreamContext(context);
    tuples = getTuples(stream);
    Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
    assertTrue(tuples.size() == 7);
    assertTrue(tuples.get(0).getString("node").equals("ann"));
    assertTrue(tuples.get(0).getLong("level").equals(Long.valueOf(2L)));
    // Bill stays at level 0 but should now have two ancestors: jim and sam
    assertTrue(tuples.get(1).getString("node").equals("bill"));
    assertTrue(tuples.get(1).getLong("level").equals(Long.valueOf(0L)));
    assertTrue(tuples.get(1).getStrings("ancestors").size() == 2);
    List<String> anc = tuples.get(1).getStrings("ancestors");
    // Sort so the check does not depend on the order the ancestors were recorded in
    Collections.sort(anc);
    assertTrue(anc.get(0).equals("jim"));
    assertTrue(anc.get(1).equals("sam"));
    assertTrue(tuples.get(2).getString("node").equals("jim"));
    assertTrue(tuples.get(2).getLong("level").equals(Long.valueOf(1L)));
    assertTrue(tuples.get(3).getString("node").equals("kip"));
    assertTrue(tuples.get(3).getLong("level").equals(Long.valueOf(2L)));
    assertTrue(tuples.get(4).getString("node").equals("max"));
    assertTrue(tuples.get(4).getLong("level").equals(Long.valueOf(1L)));
    assertTrue(tuples.get(5).getString("node").equals("sam"));
    assertTrue(tuples.get(5).getLong("level").equals(Long.valueOf(1L)));
    assertTrue(tuples.get(6).getString("node").equals("steve"));
    assertTrue(tuples.get(6).getLong("level").equals(Long.valueOf(2L)));
  } finally {
    cache.close();
  }
}
Aggregations