Search in sources :

Example 1 with ClassificationEvaluation

Use of org.apache.solr.client.solrj.io.ClassificationEvaluation in the Apache lucene-solr project.

In class StreamExpressionTest, method testBasicTextLogitStream.

/**
 * Exercises the {@code features} -> {@code train} -> {@code update} -> {@code search} streaming
 * expression pipeline for text logistic regression.
 *
 * <p>The test indexes 5000 documents: even ids carry the text {@code "a b c c d"} with outcome
 * {@code 1}, odd ids carry {@code "a b e e f"} with outcome {@code 0}. It then verifies that
 * <ul>
 *   <li>feature selection returns exactly the four discriminating terms c, d, e and f;</li>
 *   <li>the weights of the last training iteration classify one hand-built record per class
 *       with a probability close to 1 and 0 respectively;</li>
 *   <li>the 100 model iterations written to {@code destinationCollection} come back sorted by
 *       {@code iteration_i} descending, the newest model has an F1 of at least 1.0 and the
 *       expected first IDF value.</li>
 * </ul>
 *
 * <p>The test is skipped when {@code useAlias} is set.
 *
 * @throws Exception if any cluster or streaming operation fails
 */
@Test
public void testBasicTextLogitStream() throws Exception {
    Assume.assumeTrue(!useAlias);
    CollectionAdminRequest.createCollection("destinationCollection", "ml", 2, 1).process(cluster.getSolrClient());
    AbstractDistribZkTestBase.waitForRecoveriesToFinish("destinationCollection", cluster.getSolrClient().getZkStateReader(), false, true, TIMEOUT);
    // Two perfectly separable classes, 2500 documents each.
    UpdateRequest updateRequest = new UpdateRequest();
    for (int i = 0; i < 5000; i += 2) {
        updateRequest.add(id, String.valueOf(i), "tv_text", "a b c c d", "out_i", "1");
        updateRequest.add(id, String.valueOf(i + 1), "tv_text", "a b e e f", "out_i", "0");
    }
    // NOTE(review): documents go to COLLECTIONORALIAS while the expressions below read
    // "collection1"; presumably these are the same collection when useAlias is false - confirm.
    updateRequest.commit(cluster.getSolrClient(), COLLECTIONORALIAS);
    StreamExpression expression;
    TupleStream stream;
    List<Tuple> tuples;
    StreamContext streamContext = new StreamContext();
    SolrClientCache solrClientCache = new SolrClientCache();
    streamContext.setSolrClientCache(solrClientCache);
    StreamFactory factory = new StreamFactory().withCollectionZkHost("collection1", cluster.getZkServer().getZkAddress()).withCollectionZkHost("destinationCollection", cluster.getZkServer().getZkAddress()).withFunctionName("features", FeaturesSelectionStream.class).withFunctionName("train", TextLogitStream.class).withFunctionName("search", CloudSolrStream.class).withFunctionName("update", UpdateStream.class);
    try {
        expression = StreamExpressionParser.parse("features(collection1, q=\"*:*\", featureSet=\"first\", field=\"tv_text\", outcome=\"out_i\", numTerms=4)");
        stream = new FeaturesSelectionStream(expression, factory);
        stream.setStreamContext(streamContext);
        tuples = getTuples(stream);
        // A JUnit assertion rather than the `assert` keyword, which is a no-op unless the JVM
        // is started with -ea.
        assertEquals(4, tuples.size());
        HashSet<String> terms = new HashSet<>();
        for (Tuple tuple : tuples) {
            terms.add((String) tuple.get("term_s"));
        }
        // "a" and "b" occur in both classes, so only the class-specific terms are expected.
        assertTrue(terms.contains("d"));
        assertTrue(terms.contains("c"));
        assertTrue(terms.contains("e"));
        assertTrue(terms.contains("f"));
        String textLogitExpression = "train(" + "collection1, " + "features(collection1, q=\"*:*\", featureSet=\"first\", field=\"tv_text\", outcome=\"out_i\", numTerms=4)," + "q=\"*:*\", " + "name=\"model\", " + "field=\"tv_text\", " + "outcome=\"out_i\", " + "maxIterations=100)";
        stream = factory.constructStream(textLogitExpression);
        stream.setStreamContext(streamContext);
        tuples = getTuples(stream);
        // The last tuple holds the weights produced by the final training iteration.
        Tuple lastTuple = tuples.get(tuples.size() - 1);
        List<Double> lastWeights = lastTuple.getDoubles("weights_ds");
        Double[] lastWeightsArray = lastWeights.toArray(new Double[lastWeights.size()]);
        // first feature is bias value
        Double[] testRecord = { 1.0, 1.17, 0.691, 0.0, 0.0 };
        double d = sum(multiply(testRecord, lastWeightsArray));
        double prob = sigmoid(d);
        // JUnit order is (expected, actual, delta).
        assertEquals(1.0, prob, 0.1);
        // first feature is bias value
        Double[] testRecord2 = { 1.0, 0.0, 0.0, 1.17, 0.691 };
        d = sum(multiply(testRecord2, lastWeightsArray));
        prob = sigmoid(d);
        assertEquals(0.0, prob, 0.1);
        // Re-run the training, this time persisting every iteration into destinationCollection.
        stream = factory.constructStream("update(destinationCollection, batchSize=5, " + textLogitExpression + ")");
        getTuples(stream);
        cluster.getSolrClient().commit("destinationCollection");
        stream = factory.constructStream("search(destinationCollection, " + "q=*:*, " + "fl=\"iteration_i,* \", " + "rows=100, " + "sort=\"iteration_i desc\")");
        stream.setStreamContext(streamContext);
        tuples = getTuples(stream);
        // maxIterations=100, so exactly one stored model per iteration is expected.
        assertEquals(100, tuples.size());
        // Sorted by iteration_i descending: index 0 is the most recent model.
        Tuple lastModel = tuples.get(0);
        ClassificationEvaluation evaluation = ClassificationEvaluation.create(lastModel.fields);
        assertTrue(evaluation.getF1() >= 1.0);
        // Each selected term occurs in 2500 of the 5000 documents.
        assertEquals(Math.log(5000.0 / (2500 + 1)), lastModel.getDoubles("idfs_ds").get(0), 0.0001);
        // make sure the tuples is retrieved in correct order
        Tuple firstTuple = tuples.get(99);
        assertEquals(1L, (long) firstTuple.getLong("iteration_i"));
    } finally {
        // Always drop the scratch collection and release the shared client cache.
        CollectionAdminRequest.deleteCollection("destinationCollection").process(cluster.getSolrClient());
        solrClientCache.close();
    }
}
Also used : UpdateRequest(org.apache.solr.client.solrj.request.UpdateRequest) ClassificationEvaluation(org.apache.solr.client.solrj.io.ClassificationEvaluation) StreamExpression(org.apache.solr.client.solrj.io.stream.expr.StreamExpression) StreamFactory(org.apache.solr.client.solrj.io.stream.expr.StreamFactory) SolrClientCache(org.apache.solr.client.solrj.io.SolrClientCache) Tuple(org.apache.solr.client.solrj.io.Tuple) HashSet(java.util.HashSet) Test(org.junit.Test)

Example 2 with ClassificationEvaluation

Use of org.apache.solr.client.solrj.io.ClassificationEvaluation in the Apache lucene-solr project.

In class TextLogitStream, method read.

/**
 * Runs one training iteration and returns the resulting model as a tuple.
 *
 * <p>Once {@code maxIterations} has been exceeded, a tuple containing only {@code EOF=true} is
 * returned. Otherwise the terms are loaded on the first call (while {@code idfs} is still
 * {@code null}), every shard is asked for its result, the per-shard weights are averaged, the
 * per-shard errors are summed and the per-shard evaluations are merged. The learning rate is
 * halved when the error did not decrease compared to the previous iteration and raised by 5%
 * when it did; this adjustment is skipped on the first iteration.
 *
 * @return the model tuple for this iteration, or an {@code EOF} tuple when training is finished
 * @throws IOException if the supplied weights do not match the number of terms, if the calling
 *     thread is interrupted while waiting for a shard, or if any other failure occurs (the
 *     original exception is attached as the cause)
 */
public Tuple read() throws IOException {
    try {
        if (++iteration > maxIterations) {
            // Iteration budget exhausted: signal end of stream.
            Map<String, Object> eof = new HashMap<>();
            eof.put("EOF", true);
            return new Tuple(eof);
        }

        if (this.idfs == null) {
            loadTerms();
            // One weight per term plus one extra weight (the tests treat it as the bias).
            if (weights != null && terms.size() + 1 != weights.size()) {
                // The original format string had three specifiers but only two arguments,
                // which made String.format throw MissingFormatArgumentException instead of
                // reporting the real problem.
                throw new IOException(String.format(Locale.ROOT, "invalid expression for model %s - the number of weights must be %d, found %d", name, terms.size() + 1, weights.size()));
            }
        }

        List<List<Double>> allWeights = new ArrayList<>();
        this.evaluation = new ClassificationEvaluation();
        this.error = 0;
        for (Future<Tuple> logitCall : callShards(getShardUrls())) {
            Tuple tuple = logitCall.get();
            // Unchecked: the shard response is presumed to carry a list of doubles here.
            List<Double> shardWeights = (List<Double>) tuple.get("weights");
            allWeights.add(shardWeights);
            this.error += tuple.getDouble("error");
            Map shardEvaluation = (Map) tuple.get("evaluation");
            this.evaluation.addEvaluation(shardEvaluation);
        }
        this.weights = averageWeights(allWeights);

        Map<String, Object> map = new HashMap<>();
        map.put(ID, name + "_" + iteration);
        map.put("name_s", name);
        map.put("field_s", field);
        map.put("terms_ss", terms);
        map.put("iteration_i", iteration);
        if (weights != null) {
            map.put("weights_ds", weights);
        }
        map.put("error_d", error);
        evaluation.putToMap(map);
        // Report the learning rate that was used for this iteration, before it is adapted.
        map.put("alpha_d", this.learningRate);
        map.put("idfs_ds", this.idfs);

        // No previous error to compare against on the first iteration.
        if (iteration != 1) {
            if (lastError <= error) {
                this.learningRate *= 0.5;
            } else {
                this.learningRate *= 1.05;
            }
        }
        lastError = error;
        return new Tuple(map);
    } catch (IOException e) {
        // Already the declared type; do not bury it in a second IOException.
        throw e;
    } catch (InterruptedException e) {
        // Preserve the interrupt status for callers further up the stack.
        Thread.currentThread().interrupt();
        throw new IOException(e);
    } catch (Exception e) {
        throw new IOException(e);
    }
}
Also used : HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) ClassificationEvaluation(org.apache.solr.client.solrj.io.ClassificationEvaluation) NamedList(org.apache.solr.common.util.NamedList) List(java.util.List) IOException(java.io.IOException) Map(java.util.Map) Tuple(org.apache.solr.client.solrj.io.Tuple)

Aggregations

ClassificationEvaluation (org.apache.solr.client.solrj.io.ClassificationEvaluation)2 Tuple (org.apache.solr.client.solrj.io.Tuple)2 IOException (java.io.IOException)1 ArrayList (java.util.ArrayList)1 HashMap (java.util.HashMap)1 HashSet (java.util.HashSet)1 List (java.util.List)1 Map (java.util.Map)1 SolrClientCache (org.apache.solr.client.solrj.io.SolrClientCache)1 StreamExpression (org.apache.solr.client.solrj.io.stream.expr.StreamExpression)1 StreamFactory (org.apache.solr.client.solrj.io.stream.expr.StreamFactory)1 UpdateRequest (org.apache.solr.client.solrj.request.UpdateRequest)1 NamedList (org.apache.solr.common.util.NamedList)1 Test (org.junit.Test)1