Use of org.apache.solr.client.solrj.io.ClassificationEvaluation in project lucene-solr by apache.
From the class StreamExpressionTest, method testBasicTextLogitStream.
@Test
public void testBasicTextLogitStream() throws Exception {
  Assume.assumeTrue(!useAlias);

  CollectionAdminRequest.createCollection("destinationCollection", "ml", 2, 1).process(cluster.getSolrClient());
  AbstractDistribZkTestBase.waitForRecoveriesToFinish("destinationCollection", cluster.getSolrClient().getZkStateReader(), false, true, TIMEOUT);

  // Index 5000 documents: even ids get the positive-outcome text, odd ids the negative-outcome text.
  UpdateRequest updateRequest = new UpdateRequest();
  for (int i = 0; i < 5000; i += 2) {
    updateRequest.add(id, String.valueOf(i), "tv_text", "a b c c d", "out_i", "1");
    updateRequest.add(id, String.valueOf(i + 1), "tv_text", "a b e e f", "out_i", "0");
  }
  updateRequest.commit(cluster.getSolrClient(), COLLECTIONORALIAS);

  StreamExpression expression;
  TupleStream stream;
  List<Tuple> tuples;
  StreamContext streamContext = new StreamContext();
  SolrClientCache solrClientCache = new SolrClientCache();
  streamContext.setSolrClientCache(solrClientCache);

  StreamFactory factory = new StreamFactory()
      .withCollectionZkHost("collection1", cluster.getZkServer().getZkAddress())
      .withCollectionZkHost("destinationCollection", cluster.getZkServer().getZkAddress())
      .withFunctionName("features", FeaturesSelectionStream.class)
      .withFunctionName("train", TextLogitStream.class)
      .withFunctionName("search", CloudSolrStream.class)
      .withFunctionName("update", UpdateStream.class);
  try {
    // Select the four most discriminating terms from the tv_text field.
    expression = StreamExpressionParser.parse("features(collection1, q=\"*:*\", featureSet=\"first\", field=\"tv_text\", outcome=\"out_i\", numTerms=4)");
    stream = new FeaturesSelectionStream(expression, factory);
    stream.setStreamContext(streamContext);
    tuples = getTuples(stream);

    assertEquals(4, tuples.size());

    HashSet<String> terms = new HashSet<>();
    for (Tuple tuple : tuples) {
      terms.add((String) tuple.get("term_s"));
    }
    assertTrue(terms.contains("d"));
    assertTrue(terms.contains("c"));
    assertTrue(terms.contains("e"));
    assertTrue(terms.contains("f"));
    // Train a logistic regression model on the selected features.
    String textLogitExpression = "train(" +
        "collection1, " +
        "features(collection1, q=\"*:*\", featureSet=\"first\", field=\"tv_text\", outcome=\"out_i\", numTerms=4)," +
        "q=\"*:*\", " +
        "name=\"model\", " +
        "field=\"tv_text\", " +
        "outcome=\"out_i\", " +
        "maxIterations=100)";
    stream = factory.constructStream(textLogitExpression);
    stream.setStreamContext(streamContext);
    tuples = getTuples(stream);

    Tuple lastTuple = tuples.get(tuples.size() - 1);
    List<Double> lastWeights = lastTuple.getDoubles("weights_ds");
    Double[] lastWeightsArray = lastWeights.toArray(new Double[lastWeights.size()]);

    // The first feature is the bias value; a positive-class record should score near 1.0.
    Double[] testRecord = {1.0, 1.17, 0.691, 0.0, 0.0};
    double d = sum(multiply(testRecord, lastWeightsArray));
    double prob = sigmoid(d);
    assertEquals(1.0, prob, 0.1);

    // The first feature is the bias value; a negative-class record should score near 0.0.
    Double[] testRecord2 = {1.0, 0.0, 0.0, 1.17, 0.691};
    d = sum(multiply(testRecord2, lastWeightsArray));
    prob = sigmoid(d);
    assertEquals(0.0, prob, 0.1);
    // Persist one model tuple per training iteration into the destination collection.
    stream = factory.constructStream("update(destinationCollection, batchSize=5, " + textLogitExpression + ")");
    getTuples(stream);
    cluster.getSolrClient().commit("destinationCollection");

    stream = factory.constructStream("search(destinationCollection, " +
        "q=*:*, " +
        "fl=\"iteration_i,* \", " +
        "rows=100, " +
        "sort=\"iteration_i desc\")");
    stream.setStreamContext(streamContext);
    tuples = getTuples(stream);
    assertEquals(100, tuples.size());

    Tuple lastModel = tuples.get(0);
    ClassificationEvaluation evaluation = ClassificationEvaluation.create(lastModel.fields);
    assertTrue(evaluation.getF1() >= 1.0);
    assertEquals(Math.log(5000.0 / (2500 + 1)), lastModel.getDoubles("idfs_ds").get(0), 0.0001);

    // Make sure the tuples are retrieved in the correct order.
    Tuple firstTuple = tuples.get(99);
    assertEquals(1L, (long) firstTuple.getLong("iteration_i"));
  } finally {
    CollectionAdminRequest.deleteCollection("destinationCollection").process(cluster.getSolrClient());
    solrClientCache.close();
  }
}
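The test relies on a handful of helpers defined elsewhere in StreamExpressionTest: getTuples drains a stream, while sum, multiply, and sigmoid score a feature record against the learned weights. A minimal sketch of what they might look like, inferred from the call sites above (the exact bodies in the test class may differ):

// Drains a TupleStream into a list, stopping at the EOF tuple.
List<Tuple> getTuples(TupleStream tupleStream) throws IOException {
  List<Tuple> tuples = new ArrayList<>();
  tupleStream.open();
  try {
    for (Tuple t = tupleStream.read(); !t.EOF; t = tupleStream.read()) {
      tuples.add(t);
    }
  } finally {
    tupleStream.close();
  }
  return tuples;
}

// Element-wise product of a feature record and the learned weights.
double[] multiply(Double[] record, Double[] weights) {
  double[] out = new double[record.length];
  for (int i = 0; i < record.length; i++) {
    out[i] = record[i] * weights[i];
  }
  return out;
}

// Sums the element-wise products, completing the dot product.
double sum(double[] values) {
  double total = 0.0;
  for (double v : values) {
    total += v;
  }
  return total;
}

// Standard logistic function, mapping a raw score to a probability in (0, 1).
double sigmoid(double in) {
  return 1.0 / (1.0 + Math.exp(-in));
}

Together these compute sigmoid(record · weights), which is exactly the probability the test checks against 1.0 and 0.0.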
Use of org.apache.solr.client.solrj.io.ClassificationEvaluation in project lucene-solr by apache.
From the class TextLogitStream, method read.
public Tuple read() throws IOException {
  try {
    if (++iteration > maxIterations) {
      // Training is complete: emit the EOF tuple that terminates the stream.
      Map map = new HashMap();
      map.put("EOF", true);
      return new Tuple(map);
    } else {
      if (this.idfs == null) {
        // Lazily load the feature terms and their idfs on the first read.
        loadTerms();
        if (weights != null && terms.size() + 1 != weights.size()) {
          throw new IOException(String.format(Locale.ROOT, "invalid expression - the number of weights must be %d, found %d", terms.size() + 1, weights.size()));
        }
      }
      // Run one training pass on each shard in parallel, then average the shard weights.
      List<List<Double>> allWeights = new ArrayList<>();
      this.evaluation = new ClassificationEvaluation();
      this.error = 0;
      for (Future<Tuple> logitCall : callShards(getShardUrls())) {
        Tuple tuple = logitCall.get();
        List<Double> shardWeights = (List<Double>) tuple.get("weights");
        allWeights.add(shardWeights);
        this.error += tuple.getDouble("error");
        Map shardEvaluation = (Map) tuple.get("evaluation");
        this.evaluation.addEvaluation(shardEvaluation);
      }
      this.weights = averageWeights(allWeights);

      // Emit one tuple per iteration describing the current model state.
      Map map = new HashMap();
      map.put(ID, name + "_" + iteration);
      map.put("name_s", name);
      map.put("field_s", field);
      map.put("terms_ss", terms);
      map.put("iteration_i", iteration);
      if (weights != null) {
        map.put("weights_ds", weights);
      }
      map.put("error_d", error);
      evaluation.putToMap(map);
      map.put("alpha_d", this.learningRate);
      map.put("idfs_ds", this.idfs);

      // Adapt the learning rate: halve it when the error grows, grow it slightly when the error shrinks.
      if (iteration != 1) {
        if (lastError <= error) {
          this.learningRate *= 0.5;
        } else {
          this.learningRate *= 1.05;
        }
      }
      lastError = error;
      return new Tuple(map);
    }
  } catch (Exception e) {
    throw new IOException(e);
  }
}
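Each call to read() above emits one model tuple per gradient-descent iteration, and ClassificationEvaluation.create can rebuild the evaluation from a stored tuple's fields, as the test shows. A hedged sketch of consuming the training stream directly; the variable logitStream stands for an already-constructed TextLogitStream and is an assumption, not part of the source:

// Read model tuples until EOF; the last one carries the final weights and metrics.
logitStream.open();
try {
  Tuple model = null;
  for (Tuple t = logitStream.read(); !t.EOF; t = logitStream.read()) {
    // One tuple per iteration; keep only the most recent.
    model = t;
  }
  if (model != null) {
    List<Double> finalWeights = model.getDoubles("weights_ds");
    ClassificationEvaluation eval = ClassificationEvaluation.create(model.fields);
    System.out.println("F1 after " + model.getLong("iteration_i")
        + " iterations: " + eval.getF1());
  }
} finally {
  logitStream.close();
}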