Search in sources:

Example 1 with HashJoinStream

Use of org.apache.solr.client.solrj.io.stream.HashJoinStream in the lucene-solr project by Apache.

From the class GraphExpressionTest, method testGatherNodesFriendsStream.

@Test
/**
 * Exercises the {@code gatherNodes} graph expression over a small "friends" graph, where each
 * indexed document is a directed edge {@code from_s -> to_s}.
 *
 * <p>Covers: a single hop from a literal root, {@code scatter} of branches/leaves with
 * {@code trackTraversal} (levels and ancestors), a query-rooted traversal, a two-hop nested
 * traversal, two traversals combined in one {@code hashJoin}, and traversal after cycles back to
 * the root are added.
 */
public void testGatherNodesFriendsStream() throws Exception {
    // Initial edges: bill->jim, bill->sam, bill->max, max->kip, sam->steve, jim->ann.
    new UpdateRequest()
        .add(id, "0", "from_s", "bill", "to_s", "jim", "message_t", "Hello jim")
        .add(id, "1", "from_s", "bill", "to_s", "sam", "message_t", "Hello sam")
        .add(id, "2", "from_s", "bill", "to_s", "max", "message_t", "Hello max")
        .add(id, "3", "from_s", "max", "to_s", "kip", "message_t", "Hello kip")
        .add(id, "4", "from_s", "sam", "to_s", "steve", "message_t", "Hello steve")
        .add(id, "5", "from_s", "jim", "to_s", "ann", "message_t", "Hello steve")
        .commit(cluster.getSolrClient(), COLLECTION);

    StreamFactory factory = new StreamFactory()
        .withCollectionZkHost("collection1", cluster.getZkServer().getZkAddress())
        .withFunctionName("gatherNodes", GatherNodesStream.class)
        .withFunctionName("search", CloudSolrStream.class)
        .withFunctionName("count", CountMetric.class)
        .withFunctionName("hashJoin", HashJoinStream.class)
        .withFunctionName("avg", MeanMetric.class)
        .withFunctionName("sum", SumMetric.class)
        .withFunctionName("min", MinMetric.class)
        .withFunctionName("max", MaxMetric.class);

    SolrClientCache cache = new SolrClientCache();
    // Close the shared client cache even when an assertion fails part-way through.
    try {
        // Single hop from the literal root "bill".
        String expr = "gatherNodes(collection1, " + "walk=\"bill->from_s\"," + "gather=\"to_s\")";
        GatherNodesStream stream = (GatherNodesStream) factory.constructStream(expr);
        StreamContext context = new StreamContext();
        context.setSolrClientCache(cache);
        stream.setStreamContext(context);
        List<Tuple> tuples = getTuples(stream);
        Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
        assertTrue(tuples.size() == 3);
        assertTrue(tuples.get(0).getString("node").equals("jim"));
        assertTrue(tuples.get(1).getString("node").equals("max"));
        assertTrue(tuples.get(2).getString("node").equals("sam"));

        // Scatter branches and leaves with trackTraversal: the root appears at level 0 with no
        // ancestors, and each gathered node at level 1 with "bill" as its only ancestor.
        expr = "gatherNodes(collection1, " + "walk=\"bill->from_s\"," + "gather=\"to_s\"," + "scatter=\"branches, leaves\", trackTraversal=\"true\")";
        stream = (GatherNodesStream) factory.constructStream(expr);
        context = new StreamContext();
        context.setSolrClientCache(cache);
        stream.setStreamContext(context);
        tuples = getTuples(stream);
        Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
        assertTrue(tuples.size() == 4);
        assertTrue(tuples.get(0).getString("node").equals("bill"));
        assertTrue(tuples.get(0).getLong("level").equals(Long.valueOf(0)));
        assertTrue(tuples.get(0).getStrings("ancestors").size() == 0);
        assertTrue(tuples.get(1).getString("node").equals("jim"));
        assertTrue(tuples.get(1).getLong("level").equals(Long.valueOf(1)));
        List<String> ancestors = tuples.get(1).getStrings("ancestors");
        // Use assertTrue rather than the `assert` keyword, which is a no-op unless -ea is set.
        assertTrue(ancestors.size() == 1);
        assertTrue(ancestors.get(0).equals("bill"));
        assertTrue(tuples.get(2).getString("node").equals("max"));
        assertTrue(tuples.get(2).getLong("level").equals(Long.valueOf(1)));
        ancestors = tuples.get(2).getStrings("ancestors");
        assertTrue(ancestors.size() == 1);
        assertTrue(ancestors.get(0).equals("bill"));
        assertTrue(tuples.get(3).getString("node").equals("sam"));
        assertTrue(tuples.get(3).getLong("level").equals(Long.valueOf(1)));
        ancestors = tuples.get(3).getStrings("ancestors");
        assertTrue(ancestors.size() == 1);
        assertTrue(ancestors.get(0).equals("bill"));

        // Root the traversal with a search query instead of a literal node.
        expr = "gatherNodes(collection1, " + "search(collection1, q=\"message_t:jim\", fl=\"from_s\", sort=\"from_s asc\")," + "walk=\"from_s->from_s\"," + "gather=\"to_s\")";
        stream = (GatherNodesStream) factory.constructStream(expr);
        context = new StreamContext();
        context.setSolrClientCache(cache);
        stream.setStreamContext(context);
        tuples = getTuples(stream);
        Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
        assertTrue(tuples.size() == 3);
        assertTrue(tuples.get(0).getString("node").equals("jim"));
        assertTrue(tuples.get(1).getString("node").equals("max"));
        assertTrue(tuples.get(2).getString("node").equals("sam"));

        // Query root with scatter of branches and leaves.
        expr = "gatherNodes(collection1, " + "search(collection1, q=\"message_t:jim\", fl=\"from_s\", sort=\"from_s asc\")," + "walk=\"from_s->from_s\"," + "gather=\"to_s\", scatter=\"branches, leaves\")";
        stream = (GatherNodesStream) factory.constructStream(expr);
        context = new StreamContext();
        context.setSolrClientCache(cache);
        stream.setStreamContext(context);
        tuples = getTuples(stream);
        Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
        assertTrue(tuples.size() == 4);
        assertTrue(tuples.get(0).getString("node").equals("bill"));
        assertTrue(tuples.get(0).getLong("level").equals(Long.valueOf(0)));
        assertTrue(tuples.get(1).getString("node").equals("jim"));
        assertTrue(tuples.get(1).getLong("level").equals(Long.valueOf(1)));
        assertTrue(tuples.get(2).getString("node").equals("max"));
        assertTrue(tuples.get(2).getLong("level").equals(Long.valueOf(1)));
        assertTrue(tuples.get(3).getString("node").equals("sam"));
        assertTrue(tuples.get(3).getLong("level").equals(Long.valueOf(1)));

        // Two hops: the outer gatherNodes walks from the nodes produced by the inner one.
        expr = "gatherNodes(collection1, " + "search(collection1, q=\"message_t:jim\", fl=\"from_s\", sort=\"from_s asc\")," + "walk=\"from_s->from_s\"," + "gather=\"to_s\")";
        String expr2 = "gatherNodes(collection1, " + expr + "," + "walk=\"node->from_s\"," + "gather=\"to_s\")";
        stream = (GatherNodesStream) factory.constructStream(expr2);
        context = new StreamContext();
        context.setSolrClientCache(cache);
        stream.setStreamContext(context);
        tuples = getTuples(stream);
        Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
        assertTrue(tuples.size() == 3);
        assertTrue(tuples.get(0).getString("node").equals("ann"));
        assertTrue(tuples.get(1).getString("node").equals("kip"));
        assertTrue(tuples.get(2).getString("node").equals("steve"));

        // Two traversals in the same expression, joined on "node".
        String expr3 = "hashJoin(" + expr2 + ", hashed=" + expr2 + ", on=\"node\")";
        HashJoinStream hstream = (HashJoinStream) factory.constructStream(expr3);
        context = new StreamContext();
        context.setSolrClientCache(cache);
        hstream.setStreamContext(context);
        tuples = getTuples(hstream);
        Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
        assertTrue(tuples.size() == 3);
        assertTrue(tuples.get(0).getString("node").equals("ann"));
        assertTrue(tuples.get(1).getString("node").equals("kip"));
        assertTrue(tuples.get(2).getString("node").equals("steve"));

        // Two hops with scatter of branches and leaves: expect nodes at levels 0, 1 and 2.
        expr = "gatherNodes(collection1, " + "search(collection1, q=\"message_t:jim\", fl=\"from_s\", sort=\"from_s asc\")," + "walk=\"from_s->from_s\"," + "gather=\"to_s\")";
        expr2 = "gatherNodes(collection1, " + expr + "," + "walk=\"node->from_s\"," + "gather=\"to_s\", scatter=\"branches, leaves\")";
        stream = (GatherNodesStream) factory.constructStream(expr2);
        context = new StreamContext();
        context.setSolrClientCache(cache);
        stream.setStreamContext(context);
        tuples = getTuples(stream);
        Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
        assertTrue(tuples.size() == 7);
        assertTrue(tuples.get(0).getString("node").equals("ann"));
        assertTrue(tuples.get(0).getLong("level").equals(Long.valueOf(2)));
        assertTrue(tuples.get(1).getString("node").equals("bill"));
        assertTrue(tuples.get(1).getLong("level").equals(Long.valueOf(0)));
        assertTrue(tuples.get(2).getString("node").equals("jim"));
        assertTrue(tuples.get(2).getLong("level").equals(Long.valueOf(1)));
        assertTrue(tuples.get(3).getString("node").equals("kip"));
        assertTrue(tuples.get(3).getLong("level").equals(Long.valueOf(2)));
        assertTrue(tuples.get(4).getString("node").equals("max"));
        assertTrue(tuples.get(4).getLong("level").equals(Long.valueOf(1)));
        assertTrue(tuples.get(5).getString("node").equals("sam"));
        assertTrue(tuples.get(5).getLong("level").equals(Long.valueOf(1)));
        assertTrue(tuples.get(6).getString("node").equals("steve"));
        assertTrue(tuples.get(6).getLong("level").equals(Long.valueOf(2)));

        // Add edges jim->bill and sam->bill, creating cycles back to the root.
        new UpdateRequest()
            .add(id, "6", "from_s", "jim", "to_s", "bill", "message_t", "Hello steve")
            .add(id, "7", "from_s", "sam", "to_s", "bill", "message_t", "Hello steve")
            .commit(cluster.getSolrClient(), COLLECTION);
        expr = "gatherNodes(collection1, " + "search(collection1, q=\"message_t:jim\", fl=\"from_s\", sort=\"from_s asc\")," + "walk=\"from_s->from_s\"," + "gather=\"to_s\", trackTraversal=\"true\")";
        expr2 = "gatherNodes(collection1, " + expr + "," + "walk=\"node->from_s\"," + "gather=\"to_s\", scatter=\"branches, leaves\", trackTraversal=\"true\")";
        stream = (GatherNodesStream) factory.constructStream(expr2);
        context = new StreamContext();
        context.setSolrClientCache(cache);
        stream.setStreamContext(context);
        tuples = getTuples(stream);
        Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
        // The cycles must not add new nodes: the same 7 nodes at the same levels are expected.
        assertTrue(tuples.size() == 7);
        assertTrue(tuples.get(0).getString("node").equals("ann"));
        assertTrue(tuples.get(0).getLong("level").equals(Long.valueOf(2)));
        // Bill stays at level 0 but now has two ancestors: jim and sam.
        assertTrue(tuples.get(1).getString("node").equals("bill"));
        assertTrue(tuples.get(1).getLong("level").equals(Long.valueOf(0)));
        assertTrue(tuples.get(1).getStrings("ancestors").size() == 2);
        ancestors = tuples.get(1).getStrings("ancestors");
        Collections.sort(ancestors);
        assertTrue(ancestors.get(0).equals("jim"));
        assertTrue(ancestors.get(1).equals("sam"));
        assertTrue(tuples.get(2).getString("node").equals("jim"));
        assertTrue(tuples.get(2).getLong("level").equals(Long.valueOf(1)));
        assertTrue(tuples.get(3).getString("node").equals("kip"));
        assertTrue(tuples.get(3).getLong("level").equals(Long.valueOf(2)));
        assertTrue(tuples.get(4).getString("node").equals("max"));
        assertTrue(tuples.get(4).getLong("level").equals(Long.valueOf(1)));
        assertTrue(tuples.get(5).getString("node").equals("sam"));
        assertTrue(tuples.get(5).getLong("level").equals(Long.valueOf(1)));
        assertTrue(tuples.get(6).getString("node").equals("steve"));
        assertTrue(tuples.get(6).getLong("level").equals(Long.valueOf(2)));
    } finally {
        cache.close();
    }
}
Also used: UpdateRequest(org.apache.solr.client.solrj.request.UpdateRequest) StreamContext(org.apache.solr.client.solrj.io.stream.StreamContext) CountMetric(org.apache.solr.client.solrj.io.stream.metrics.CountMetric) MinMetric(org.apache.solr.client.solrj.io.stream.metrics.MinMetric) MeanMetric(org.apache.solr.client.solrj.io.stream.metrics.MeanMetric) StreamFactory(org.apache.solr.client.solrj.io.stream.expr.StreamFactory) SolrClientCache(org.apache.solr.client.solrj.io.SolrClientCache) FieldComparator(org.apache.solr.client.solrj.io.comp.FieldComparator) Tuple(org.apache.solr.client.solrj.io.Tuple) HashJoinStream(org.apache.solr.client.solrj.io.stream.HashJoinStream) Test(org.junit.Test)

Aggregations

SolrClientCache (org.apache.solr.client.solrj.io.SolrClientCache)1 Tuple (org.apache.solr.client.solrj.io.Tuple)1 FieldComparator (org.apache.solr.client.solrj.io.comp.FieldComparator)1 HashJoinStream (org.apache.solr.client.solrj.io.stream.HashJoinStream)1 StreamContext (org.apache.solr.client.solrj.io.stream.StreamContext)1 StreamFactory (org.apache.solr.client.solrj.io.stream.expr.StreamFactory)1 CountMetric (org.apache.solr.client.solrj.io.stream.metrics.CountMetric)1 MeanMetric (org.apache.solr.client.solrj.io.stream.metrics.MeanMetric)1 MinMetric (org.apache.solr.client.solrj.io.stream.metrics.MinMetric)1 UpdateRequest (org.apache.solr.client.solrj.request.UpdateRequest)1 Test (org.junit.Test)1