Usage of `org.opensearch.index.rankeval.PrecisionAtK` in the OpenSearch project (opensearch-project).
Class `RequestConvertersTests`, method `testRankEval`:
/**
 * Verifies that {@code RequestConverters.rankEval} turns a rank-eval request into the
 * expected REST call: endpoint path, query parameters (randomized indices options plus
 * search type) and the serialized rank-eval spec as the request body.
 */
public void testRankEval() throws Exception {
    // A single rated request without any rated documents is enough to exercise the converter.
    RatedRequest onlyRatedRequest = new RatedRequest("queryId", Collections.emptyList(), new SearchSourceBuilder());
    RankEvalSpec spec = new RankEvalSpec(Collections.singletonList(onlyRatedRequest), new PrecisionAtK());
    String[] indices = randomIndicesNames(0, 5);
    RankEvalRequest rankEvalRequest = new RankEvalRequest(spec, indices);

    // Randomize the request options and record the parameters they should map to.
    Map<String, String> expectedParams = new HashMap<>();
    setRandomIndicesOptions(rankEvalRequest::indicesOptions, rankEvalRequest::indicesOptions, expectedParams);
    boolean overrideSearchType = randomBoolean();
    if (overrideSearchType) {
        rankEvalRequest.searchType(randomFrom(SearchType.CURRENTLY_SUPPORTED));
    }
    // search_type is always sent, whether it was overridden or left at its default.
    String searchTypeParam = rankEvalRequest.searchType().name().toLowerCase(Locale.ROOT);
    expectedParams.put("search_type", searchTypeParam);

    Request converted = RequestConverters.rankEval(rankEvalRequest);

    // Expected endpoint: "/" + optional comma-joined indices + "/" + the rank-eval endpoint name.
    String joinedIndices = String.join(",", indices);
    StringJoiner expectedEndpoint = new StringJoiner("/", "/", "");
    if (Strings.hasLength(joinedIndices)) {
        expectedEndpoint.add(joinedIndices);
    }
    expectedEndpoint.add(RestRankEvalAction.ENDPOINT);

    assertEquals(expectedEndpoint.toString(), converted.getEndpoint());
    assertEquals(5, converted.getParameters().size());
    assertEquals(expectedParams, converted.getParameters());
    assertToXContentBody(spec, converted.getEntity());
}
Usage of `org.opensearch.index.rankeval.PrecisionAtK` in the OpenSearch project (opensearch-project).
Class `SearchDocumentationIT`, method `testRankEval`:
/**
 * Documentation test for the rank evaluation API. It builds a Precision@K evaluation of
 * "foobar_query" (one rated document, a match query on "user") against the "posts" index.
 * It executes the request synchronously and checks the overall and per-query scores plus
 * the details of the third returned hit. It then executes the same request asynchronously
 * and waits for completion.
 * <p>
 * The {@code tag::}/{@code end::} comment pairs and the numbered {@code // <N>} callouts
 * look like snippet markers consumed by the reference-documentation build.
 * NOTE(review): confirm before editing anything between those markers, since such changes
 * would presumably show up in the published docs.
 * <p>
 * The expected values asserted below depend on the data set up by
 * {@code indexSearchTestData()}.
 */
public void testRankEval() throws Exception {
indexSearchTestData();
RestHighLevelClient client = highLevelClient();
// The inner braces scope the variables used by the documentation snippets.
{
// tag::rank-eval-request-basic
// <1>
EvaluationMetric metric = new PrecisionAtK();
List<RatedDocument> ratedDocs = new ArrayList<>();
// <2>
ratedDocs.add(new RatedDocument("posts", "1", 1));
SearchSourceBuilder searchQuery = new SearchSourceBuilder();
// <3>
searchQuery.query(QueryBuilders.matchQuery("user", "foobar"));
// <4>
RatedRequest ratedRequest = new RatedRequest("foobar_query", ratedDocs, searchQuery);
List<RatedRequest> ratedRequests = Arrays.asList(ratedRequest);
RankEvalSpec specification = // <5>
new RankEvalSpec(ratedRequests, metric);
// <6>
RankEvalRequest request = new RankEvalRequest(specification, new String[] { "posts" });
// end::rank-eval-request-basic
// tag::rank-eval-execute
RankEvalResponse response = client.rankEval(request, RequestOptions.DEFAULT);
// end::rank-eval-execute
// tag::rank-eval-response
// <1>
double evaluationResult = response.getMetricScore();
assertEquals(1.0 / 3.0, evaluationResult, 0.0);
Map<String, EvalQueryQuality> partialResults = response.getPartialResults();
EvalQueryQuality evalQuality = // <2>
partialResults.get("foobar_query");
assertEquals("foobar_query", evalQuality.getId());
// <3>
double qualityLevel = evalQuality.metricScore();
assertEquals(1.0 / 3.0, qualityLevel, 0.0);
List<RatedSearchHit> hitsAndRatings = evalQuality.getHitsAndRatings();
RatedSearchHit ratedSearchHit = hitsAndRatings.get(2);
// <4>
assertEquals("3", ratedSearchHit.getSearchHit().getId());
// <5>
assertFalse(ratedSearchHit.getRating().isPresent());
MetricDetail metricDetails = evalQuality.getMetricDetails();
String metricName = metricDetails.getMetricName();
// <6>
assertEquals(PrecisionAtK.NAME, metricName);
PrecisionAtK.Detail detail = (PrecisionAtK.Detail) metricDetails;
// <7>
assertEquals(1, detail.getRelevantRetrieved());
assertEquals(3, detail.getRetrieved());
// end::rank-eval-response
// The listener below is deliberately empty: it only illustrates the callback shape.
// tag::rank-eval-execute-listener
ActionListener<RankEvalResponse> listener = new ActionListener<RankEvalResponse>() {
@Override
public void onResponse(RankEvalResponse response) {
// <1>
}
@Override
public void onFailure(Exception e) {
// <2>
}
};
// end::rank-eval-execute-listener
// Replace the empty listener by a blocking listener in test
// (LatchedActionListener presumably counts the latch down once the wrapped listener is notified).
final CountDownLatch latch = new CountDownLatch(1);
listener = new LatchedActionListener<>(listener, latch);
// tag::rank-eval-execute-async
// <1>
client.rankEvalAsync(request, RequestOptions.DEFAULT, listener);
// end::rank-eval-execute-async
// Fail the test if the async call has not completed within 30 seconds.
assertTrue(latch.await(30L, TimeUnit.SECONDS));
}
}
Usage of `org.opensearch.index.rankeval.PrecisionAtK` in the OpenSearch project (opensearch-project).
Class `RankEvalIT`, method `testMetrics`:
/**
 * Test case checks that the default metrics are registered and usable. Each metric is run
 * against the same rated requests on "index" and "index2", and its overall score is compared
 * with a pre-computed expected value.
 */
public void testMetrics() throws IOException {
    List<RatedRequest> specifications = createTestEvaluationSpec();
    List<Supplier<EvaluationMetric>> metrics = Arrays.asList(
        PrecisionAtK::new,
        RecallAtK::new,
        MeanReciprocalRank::new,
        DiscountedCumulativeGain::new,
        () -> new ExpectedReciprocalRank(1)
    );
    // expectedScores[i] is the score expected for metrics.get(i).
    double[] expectedScores = new double[] { 0.4285714285714286, 1.0, 0.75, 1.6408962261063627, 0.4407738095238095 };
    // Keep the two parallel lists in lockstep. Otherwise an extra expected score would be silently
    // ignored, and a missing one would only surface as an ArrayIndexOutOfBoundsException.
    assertEquals("every metric needs exactly one expected score", metrics.size(), expectedScores.length);
    for (int i = 0; i < expectedScores.length; i++) {
        EvaluationMetric metric = metrics.get(i).get();
        RankEvalSpec spec = new RankEvalSpec(specifications, metric);
        RankEvalRequest rankEvalRequest = new RankEvalRequest(spec, new String[] { "index", "index2" });
        RankEvalResponse response = execute(rankEvalRequest, highLevelClient()::rankEval, highLevelClient()::rankEvalAsync);
        // A delta of Double.MIN_VALUE makes this an effectively exact comparison.
        assertEquals(
            "unexpected metric score for " + metric.getClass().getSimpleName(),
            expectedScores[i],
            response.getMetricScore(),
            Double.MIN_VALUE
        );
    }
}
Usage of `org.opensearch.index.rankeval.PrecisionAtK` in the OpenSearch project (opensearch-project).
Class `RankEvalIT`, method `testRankEvalRequest`:
/**
 * Test case retrieves all seven indexed documents (from "index" and "index2") and checks the
 * Prec@10 calculation, where all unlabeled documents are treated as not relevant. It then closes
 * "index2" and checks that, with {@code ignore_unavailable=true}, the evaluation still succeeds
 * for both queries without returning hits from the closed index.
 */
public void testRankEvalRequest() throws IOException {
    List<RatedRequest> specifications = createTestEvaluationSpec();
    PrecisionAtK metric = new PrecisionAtK(1, false, 10);
    RankEvalSpec spec = new RankEvalSpec(specifications, metric);
    RankEvalRequest rankEvalRequest = new RankEvalRequest(spec, new String[] { "index", "index2" });
    RankEvalResponse response = execute(rankEvalRequest, highLevelClient()::rankEval, highLevelClient()::rankEvalAsync);
    // The expected Prec@10 is 5/7 for the first query and 1/7 for the second; divide their sum by 2 for the average.
    double expectedPrecision = (1.0 / 7.0 + 5.0 / 7.0) / 2.0;
    assertEquals(expectedPrecision, response.getMetricScore(), Double.MIN_VALUE);
    Map<String, EvalQueryQuality> partialResults = response.getPartialResults();
    assertEquals(2, partialResults.size());

    EvalQueryQuality amsterdamQueryQuality = partialResults.get("amsterdam_query");
    assertEquals(2, filterUnratedDocuments(amsterdamQueryQuality.getHitsAndRatings()).size());
    List<RatedSearchHit> hitsAndRatings = amsterdamQueryQuality.getHitsAndRatings();
    assertEquals(7, hitsAndRatings.size());
    for (RatedSearchHit hit : hitsAndRatings) {
        String id = hit.getSearchHit().getId();
        if (id.equals("berlin") || id.equals("amsterdam5")) {
            assertFalse(hit.getRating().isPresent());
        } else {
            assertEquals(1, hit.getRating().getAsInt());
        }
    }

    EvalQueryQuality berlinQueryQuality = partialResults.get("berlin_query");
    assertEquals(6, filterUnratedDocuments(berlinQueryQuality.getHitsAndRatings()).size());
    hitsAndRatings = berlinQueryQuality.getHitsAndRatings();
    assertEquals(7, hitsAndRatings.size());
    for (RatedSearchHit hit : hitsAndRatings) {
        String id = hit.getSearchHit().getId();
        if (id.equals("berlin")) {
            assertEquals(1, hit.getRating().getAsInt());
        } else {
            assertFalse(hit.getRating().isPresent());
        }
    }

    // Now try this when index2 is closed. With ignore_unavailable=true the closed index should be
    // skipped instead of failing the evaluation of either query.
    client().performRequest(new Request("POST", "index2/_close"));
    rankEvalRequest.indicesOptions(IndicesOptions.fromParameters(null, "true", null, "false", SearchRequest.DEFAULT_INDICES_OPTIONS));
    response = execute(rankEvalRequest, highLevelClient()::rankEval, highLevelClient()::rankEvalAsync);
    // Previously this response was never inspected, so the closed-index scenario went unchecked.
    Map<String, EvalQueryQuality> resultsWithIndex2Closed = response.getPartialResults();
    assertEquals(2, resultsWithIndex2Closed.size());
    for (EvalQueryQuality queryQuality : resultsWithIndex2Closed.values()) {
        for (RatedSearchHit hit : queryQuality.getHitsAndRatings()) {
            assertNotEquals("index2", hit.getSearchHit().getIndex());
        }
    }
}
Aggregations