Search in sources :

Example 6 with TupleStream

use of org.apache.solr.client.solrj.io.stream.TupleStream in project lucene-solr by apache.

In class GraphExpressionTest, the method testScoreNodesStream:

/**
 * Exercises {@code scoreNodes} wrapped around a two-step {@code gatherNodes} walk.
 *
 * <p>Sixteen basket/product documents are indexed, then the walk starts at {@code product3},
 * gathers the baskets that contain it, and from those baskets gathers the co-occurring products
 * together with count/avg/sum/min/max metrics on {@code price_f}. The result is sorted by
 * {@code nodeScore desc} and the top three tuples are checked, first with the default term
 * frequency and then with {@code termFreq="avg(price_f)"}.
 *
 * @throws Exception if indexing or streaming fails
 */
@Test
public void testScoreNodesStream() throws Exception {
    new UpdateRequest()
        .add(id, "0", "basket_s", "basket1", "product_s", "product1", "price_f", "1")
        .add(id, "1", "basket_s", "basket1", "product_s", "product3", "price_f", "1")
        .add(id, "2", "basket_s", "basket1", "product_s", "product5", "price_f", "100")
        .add(id, "3", "basket_s", "basket2", "product_s", "product1", "price_f", "1")
        .add(id, "4", "basket_s", "basket2", "product_s", "product6", "price_f", "1")
        .add(id, "5", "basket_s", "basket2", "product_s", "product7", "price_f", "1")
        .add(id, "6", "basket_s", "basket3", "product_s", "product4", "price_f", "1")
        .add(id, "7", "basket_s", "basket3", "product_s", "product3", "price_f", "1")
        .add(id, "8", "basket_s", "basket3", "product_s", "product1", "price_f", "1")
        .add(id, "9", "basket_s", "basket4", "product_s", "product4", "price_f", "1")
        .add(id, "10", "basket_s", "basket4", "product_s", "product3", "price_f", "1")
        .add(id, "11", "basket_s", "basket4", "product_s", "product1", "price_f", "1")
        .add(id, "12", "basket_s", "basket5", "product_s", "product1", "price_f", "1")
        .add(id, "13", "basket_s", "basket6", "product_s", "product1", "price_f", "1")
        .add(id, "14", "basket_s", "basket7", "product_s", "product1", "price_f", "1")
        .add(id, "15", "basket_s", "basket4", "product_s", "product1", "price_f", "1")
        .commit(cluster.getSolrClient(), COLLECTION);

    SolrClientCache cache = new SolrClientCache();
    // The cache is released in the finally block so that a failed assertion or a streaming
    // error does not leave it open.
    try {
        StreamFactory factory = new StreamFactory()
            .withCollectionZkHost("collection1", cluster.getZkServer().getZkAddress())
            .withDefaultZkHost(cluster.getZkServer().getZkAddress())
            .withFunctionName("gatherNodes", GatherNodesStream.class)
            .withFunctionName("scoreNodes", ScoreNodesStream.class)
            .withFunctionName("search", CloudSolrStream.class)
            .withFunctionName("sort", SortStream.class)
            .withFunctionName("count", CountMetric.class)
            .withFunctionName("avg", MeanMetric.class)
            .withFunctionName("sum", SumMetric.class)
            .withFunctionName("min", MinMetric.class)
            .withFunctionName("max", MaxMetric.class);

        // Inner walk: baskets that contain product3.
        String expr = "gatherNodes(collection1, " + "walk=\"product3->product_s\"," + "gather=\"basket_s\")";
        // Outer walk: products in those baskets, scored and sorted by nodeScore.
        String expr2 = "sort(by=\"nodeScore desc\", " + "scoreNodes(gatherNodes(collection1, " + expr + "," + "walk=\"node->basket_s\"," + "gather=\"product_s\", " + "count(*), " + "avg(price_f), " + "sum(price_f), " + "min(price_f), " + "max(price_f))))";

        TupleStream stream = factory.constructStream(expr2);
        StreamContext context = new StreamContext();
        context.setSolrClientCache(cache);
        stream.setStreamContext(context);
        List<Tuple> tuples = getTuples(stream);

        Tuple tuple0 = tuples.get(0);
        assert (tuple0.getString("node").equals("product4"));
        assert (tuple0.getLong("docFreq") == 2);
        assert (tuple0.getLong("count(*)") == 2);

        Tuple tuple1 = tuples.get(1);
        assert (tuple1.getString("node").equals("product1"));
        assert (tuple1.getLong("docFreq") == 8);
        assert (tuple1.getLong("count(*)") == 3);

        Tuple tuple2 = tuples.get(2);
        assert (tuple2.getString("node").equals("product5"));
        assert (tuple2.getLong("docFreq") == 1);
        assert (tuple2.getLong("count(*)") == 1);

        // Test using a different termFreq field than the default count(*)
        expr2 = "sort(by=\"nodeScore desc\", " + "scoreNodes(termFreq=\"avg(price_f)\",gatherNodes(collection1, " + expr + "," + "walk=\"node->basket_s\"," + "gather=\"product_s\", " + "count(*), " + "avg(price_f), " + "sum(price_f), " + "min(price_f), " + "max(price_f))))";

        stream = factory.constructStream(expr2);
        context = new StreamContext();
        context.setSolrClientCache(cache);
        stream.setStreamContext(context);
        tuples = getTuples(stream);

        tuple0 = tuples.get(0);
        assert (tuple0.getString("node").equals("product5"));
        assert (tuple0.getLong("docFreq") == 1);
        assert (tuple0.getDouble("avg(price_f)") == 100);

        tuple1 = tuples.get(1);
        assert (tuple1.getString("node").equals("product4"));
        assert (tuple1.getLong("docFreq") == 2);
        assert (tuple1.getDouble("avg(price_f)") == 1);

        tuple2 = tuples.get(2);
        assert (tuple2.getString("node").equals("product1"));
        assert (tuple2.getLong("docFreq") == 8);
        assert (tuple2.getDouble("avg(price_f)") == 1);
    } finally {
        cache.close();
    }
}
Also used : SortStream(org.apache.solr.client.solrj.io.stream.SortStream) ScoreNodesStream(org.apache.solr.client.solrj.io.stream.ScoreNodesStream) TupleStream(org.apache.solr.client.solrj.io.stream.TupleStream) MeanMetric(org.apache.solr.client.solrj.io.stream.metrics.MeanMetric) UpdateRequest(org.apache.solr.client.solrj.request.UpdateRequest) StreamContext(org.apache.solr.client.solrj.io.stream.StreamContext) StreamFactory(org.apache.solr.client.solrj.io.stream.expr.StreamFactory) SolrClientCache(org.apache.solr.client.solrj.io.SolrClientCache) MinMetric(org.apache.solr.client.solrj.io.stream.metrics.MinMetric) Tuple(org.apache.solr.client.solrj.io.Tuple) Test(org.junit.Test)

Example 7 with TupleStream

use of org.apache.solr.client.solrj.io.stream.TupleStream in project lucene-solr by apache.

In class GraphMLResponseWriter, the method write:

/**
 * Writes the tuple stream stored in the request context to {@code writer} as a GraphML document.
 *
 * <p>If the response carries an exception, or the stream in the request context is a
 * {@code GraphHandler.DummyErrorStream}, the exception's stack trace is written instead of a
 * graph and the method returns.
 *
 * <p>Otherwise each tuple read from the stream becomes a {@code <node>} element whose id is the
 * tuple's {@code node} field (prefixed with {@code collection + "."} for multi-collection
 * traversals). Every other field except {@code ancestors} and {@code collection} is emitted as a
 * {@code <data>} child, and each entry in {@code ancestors} becomes an {@code <edge>} pointing at
 * the node.
 *
 * @param writer destination for the GraphML output
 * @param req    request whose context supplies the {@code "stream"} and {@code "traversal"} entries
 * @param res    response checked for a previously recorded exception
 * @throws IOException if opening, reading or closing the stream fails, or a direct write to
 *                     {@code writer} fails
 */
public void write(Writer writer, SolrQueryRequest req, SolrQueryResponse res) throws IOException {
    Exception responseError = res.getException();
    if (responseError != null) {
        responseError.printStackTrace(new PrintWriter(writer));
        return;
    }
    TupleStream stream = (TupleStream) req.getContext().get("stream");
    if (stream instanceof GraphHandler.DummyErrorStream) {
        Exception streamError = ((GraphHandler.DummyErrorStream) stream).getException();
        streamError.printStackTrace(new PrintWriter(writer));
        return;
    }
    Traversal traversal = (Traversal) req.getContext().get("traversal");
    PrintWriter printWriter = new PrintWriter(writer);
    try {
        stream.open();
        // Running counter used as the id of each emitted <edge>.
        int edgeCount = 0;
        printWriter.println("<?xml version=\"1.0\" encoding=\"UTF-8\"?>");
        printWriter.println("<graphml xmlns=\"http://graphml.graphdrawing.org/xmlns\" ");
        printWriter.println("xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\" ");
        printWriter.print("xsi:schemaLocation=\"http://graphml.graphdrawing.org/xmlns ");
        printWriter.println("http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd\">");
        printWriter.println("<graph id=\"G\" edgedefault=\"directed\">");
        while (true) {
            //Output the graph
            Tuple tuple = stream.read();
            if (tuple.EOF) {
                break;
            }
            String id = tuple.getString("node");
            if (traversal.isMultiCollection()) {
                id = tuple.getString("collection") + "." + id;
            }
            writer.write("<node id=\"" + replace(id) + "\"");
            // Collect the fields to emit as <data>; node, ancestors and collection are
            // represented structurally and therefore skipped here.
            List<String> outfields = new ArrayList<>();
            Iterator<String> keys = tuple.fields.keySet().iterator();
            while (keys.hasNext()) {
                String key = keys.next();
                if (!(key.equals("node") || key.equals("ancestors") || key.equals("collection"))) {
                    outfields.add(key);
                }
            }
            if (!outfields.isEmpty()) {
                printWriter.println(">");
                for (String nodeAttribute : outfields) {
                    Object o = tuple.get(nodeAttribute);
                    if (o != null) {
                        // NOTE(review): ids are passed through replace() but keys/values here are
                        // written verbatim - confirm whether they need the same treatment.
                        printWriter.println("<data key=\"" + nodeAttribute + "\">" + o.toString() + "</data>");
                    }
                }
                printWriter.println("</node>");
            } else {
                // No attributes: close the node as an empty element.
                printWriter.println("/>");
            }
            List<String> ancestors = tuple.getStrings("ancestors");
            if (ancestors != null) {
                for (String ancestor : ancestors) {
                    ++edgeCount;
                    writer.write("<edge id=\"" + edgeCount + "\" ");
                    writer.write(" source=\"" + replace(ancestor) + "\" ");
                    printWriter.println(" target=\"" + replace(id) + "\"/>");
                }
            }
        }
        writer.write("</graph></graphml>");
    } finally {
        stream.close();
    }
}
Also used : GraphHandler(org.apache.solr.handler.GraphHandler) ArrayList(java.util.ArrayList) Traversal(org.apache.solr.client.solrj.io.graph.Traversal) IOException(java.io.IOException) TupleStream(org.apache.solr.client.solrj.io.stream.TupleStream) Tuple(org.apache.solr.client.solrj.io.Tuple) PrintWriter(java.io.PrintWriter)

Aggregations

TupleStream (org.apache.solr.client.solrj.io.stream.TupleStream): 7; Tuple (org.apache.solr.client.solrj.io.Tuple): 5; StreamContext (org.apache.solr.client.solrj.io.stream.StreamContext): 3; StreamFactory (org.apache.solr.client.solrj.io.stream.expr.StreamFactory): 3; Test (org.junit.Test): 3; IOException (java.io.IOException): 2; ArrayList (java.util.ArrayList): 2; Map (java.util.Map): 2; SolrClientCache (org.apache.solr.client.solrj.io.SolrClientCache): 2; Traversal (org.apache.solr.client.solrj.io.graph.Traversal): 2; SortStream (org.apache.solr.client.solrj.io.stream.SortStream): 2; MeanMetric (org.apache.solr.client.solrj.io.stream.metrics.MeanMetric): 2; MinMetric (org.apache.solr.client.solrj.io.stream.metrics.MinMetric): 2; UpdateRequest (org.apache.solr.client.solrj.request.UpdateRequest): 2; ModifiableSolrParams (org.apache.solr.common.params.ModifiableSolrParams): 2; ByteArrayInputStream (java.io.ByteArrayInputStream): 1; PrintWriter (java.io.PrintWriter): 1; StringWriter (java.io.StringWriter): 1; SQLException (java.sql.SQLException): 1; HashMap (java.util.HashMap): 1