Use of org.opensearch.ml.common.dataset.SearchQueryInputDataset in project ml-commons by opensearch-project.
Class PredictionITTests, method testPredictionWithSearchInput.
/**
 * Builds a {@link SearchQueryInputDataset} over {@code TESTING_INDEX_NAME} from the generated
 * search source and hands it, together with {@code taskId}, to {@code predictAndVerifyResult}.
 *
 * @throws IOException declared by this test; the visible calls do not show which one raises it
 */
public void testPredictionWithSearchInput() throws IOException {
    MLInputDataset searchDataset = new SearchQueryInputDataset(
        Collections.singletonList(TESTING_INDEX_NAME),
        generateSearchSourceBuilder()
    );
    // NOTE(review): taskId is not declared in this method; presumably a field set during setup.
    predictAndVerifyResult(taskId, searchDataset);
}
Use of org.opensearch.ml.common.dataset.SearchQueryInputDataset in project ml-commons by opensearch-project.
Class MLInputTests, method testParseKmeansInputQuery.
/**
 * Parses a K-Means request body carrying {@code input_query} and {@code input_index}, then
 * checks that the parsed dataset's search source renders as the expected query string.
 *
 * @throws IOException declared by this test; the visible calls do not show which one raises it
 */
public void testParseKmeansInputQuery() throws IOException {
    String requestBody = "{\"input_query\":{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"k1\":1}}]}},\"size\":10},\"input_index\":[\"test_data\"]}";
    // Rendered form of the parsed query: it additionally carries boost / adjust_pure_negative entries.
    String expectedQuery = "{\"size\":10,\"query\":{\"bool\":{\"filter\":[{\"term\":{\"k1\":{\"value\":1,\"boost\":1.0}}}],\"adjust_pure_negative\":true,\"boost\":1.0}}}";

    MLInput parsedInput = MLInput.parse(parser(requestBody), FunctionName.KMEANS.name());
    SearchQueryInputDataset parsedDataset = (SearchQueryInputDataset) parsedInput.getInputDataset();

    assertEquals(expectedQuery, parsedDataset.getSearchSourceBuilder().toString());
}
Use of org.opensearch.ml.common.dataset.SearchQueryInputDataset in project ml-commons by opensearch-project.
Class TrainingITTests, method testTrainingWithSearchInput.
/**
 * Trains a model from a search-query dataset over {@code TESTING_INDEX_NAME} and then calls
 * {@code waitModelAvailable} with the returned task id. Currently disabled via {@code @Ignore}.
 *
 * @throws ExecutionException declared by this test
 * @throws InterruptedException declared by this test
 */
@Ignore(
    "This test case is flaky, something is off with waitModelAvailable(taskId) method."
        + " This issue will be tracked in an issue and will be fixed later"
)
public void testTrainingWithSearchInput() throws ExecutionException, InterruptedException {
    MLInputDataset searchDataset = new SearchQueryInputDataset(
        Collections.singletonList(TESTING_INDEX_NAME),
        generateSearchSourceBuilder()
    );
    String trainingTaskId = trainModel(searchDataset);
    waitModelAvailable(trainingTaskId);
}
Use of org.opensearch.ml.common.dataset.SearchQueryInputDataset in project ml-commons by opensearch-project.
Class TrainingITTests, method testTrainingWithEmptyDataset.
/**
 * Trains a model with an empty dataset and expects the training request to be rejected with
 * an {@link IllegalArgumentException}.
 */
public void testTrainingWithEmptyDataset() {
    SearchSourceBuilder emptyResultSource = generateSearchSourceBuilder();
    // Presumably no document matches this field/value pair, which is what makes the dataset empty.
    emptyResultSource.query(QueryBuilders.matchQuery("noSuchName", ""));

    MLInputDataset emptyDataset = new SearchQueryInputDataset(
        Collections.singletonList(TESTING_INDEX_NAME),
        emptyResultSource
    );
    MLInput kmeansInput = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .inputDataset(emptyDataset)
        .build();
    MLTrainingTaskRequest request = new MLTrainingTaskRequest(kmeansInput, false);

    expectThrows(
        IllegalArgumentException.class,
        () -> client().execute(MLTrainingTaskAction.INSTANCE, request).actionGet()
    );
}
Use of org.opensearch.ml.common.dataset.SearchQueryInputDataset in project ml-commons by opensearch-project.
Class MLInputDatasetHandlerTests, method testSearchQueryInputDatasetWithHits.
/**
 * Stubs {@code client.search} to answer with a single-hit {@link SearchResponse}, runs
 * {@code parseSearchQueryInput} on a match-all query dataset, and verifies that the listener
 * receives exactly one {@link DataFrame} response.
 */
@SuppressWarnings("unchecked")
public void testSearchQueryInputDatasetWithHits() {
    // One hit whose source is a small JSON document.
    BytesReference bytesArray = new BytesArray("{\"taskId\":\"111\"}");
    SearchHit hit = new SearchHit(1);
    hit.sourceRef(bytesArray);
    SearchHits hits = new SearchHits(new SearchHit[] { hit }, new TotalHits(1L, TotalHits.Relation.EQUAL_TO), 1f);

    searchResponse = mock(SearchResponse.class);
    when(searchResponse.getHits()).thenReturn(hits);

    // Complete the search callback (second argument of client.search) with the mocked response.
    doAnswer(invocation -> {
        // Named differently from the outer `listener` so the two are not confused.
        ActionListener<SearchResponse> searchListener = (ActionListener<SearchResponse>) invocation.getArguments()[1];
        searchListener.onResponse(searchResponse);
        return null;
    }).when(client).search(any(), any());

    SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset
        .builder()
        .indices(Collections.singletonList("index1"))
        .searchSourceBuilder(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery()))
        .build();
    mlInputDatasetHandler.parseSearchQueryInput(searchQueryInputDataset, listener);

    ArgumentCaptor<DataFrame> captor = ArgumentCaptor.forClass(DataFrame.class);
    verify(listener, times(1)).onResponse(captor.capture());
    // JUnit's contract is assertEquals(expected, actual); keeping that order makes failure messages read correctly.
    Assert.assertEquals(1, captor.getAllValues().size());
}
Aggregations