Usage of `org.opensearch.ml.common.dataset.SearchQueryInputDataset` in the ml-commons project (opensearch-project).
Class `MLInputDatasetHandler`, method `parseSearchQueryInput`.
/**
 * Create a DataFrame from the documents returned by the given search query.
 *
 * <p>The search is executed asynchronously against the dataset's indices. The {@code _source}
 * map of every returned hit becomes one input row for {@code DataFrameBuilder.load}.
 *
 * @param mlInputDataset MLInputDataset; must be of type {@code MLInputDataType.SEARCH_QUERY}
 * @param listener ActionListener that receives the built DataFrame, or a failure when the search
 *                 fails, returns no documents, or the DataFrame cannot be built
 * @throws IllegalArgumentException if the dataset is not of type SEARCH_QUERY
 */
public void parseSearchQueryInput(MLInputDataset mlInputDataset, ActionListener<DataFrame> listener) {
    if (!mlInputDataset.getInputDataType().equals(MLInputDataType.SEARCH_QUERY)) {
        throw new IllegalArgumentException("Input dataset is not SEARCH_QUERY type.");
    }
    SearchQueryInputDataset inputDataset = (SearchQueryInputDataset) mlInputDataset;
    SearchRequest searchRequest = new SearchRequest();
    searchRequest.source(inputDataset.getSearchSourceBuilder());
    List<String> indicesList = inputDataset.getIndices();
    searchRequest.indices(indicesList.toArray(new String[0]));
    client.search(searchRequest, ActionListener.wrap(r -> {
        // Decide emptiness from the hits actually returned rather than from total hits:
        // total hits is null when track_total_hits is disabled, and it may be non-zero
        // while no hit documents are returned (e.g. size=0).
        SearchHit[] searchHits = (r == null || r.getHits() == null) ? null : r.getHits().getHits();
        if (searchHits == null || searchHits.length == 0) {
            listener.onFailure(new IllegalArgumentException("No document found"));
            return;
        }
        List<Map<String, Object>> input = new ArrayList<>(searchHits.length);
        for (SearchHit hit : searchHits) {
            input.add(hit.getSourceAsMap());
        }
        listener.onResponse(DataFrameBuilder.load(input));
    }, e -> {
        // Pass the throwable separately so the stack trace is preserved in the log.
        log.error("Failed to search", e);
        listener.onFailure(e);
    }));
}
Usage of `org.opensearch.ml.common.dataset.SearchQueryInputDataset` in the ml-commons project (opensearch-project).
Class `RestActionUtils`, method `buildSearchQueryInput`.
/**
 * Build a SearchQueryInputDataset from the search body and parameters of a REST request.
 *
 * <p>The request content (or source parameter) is parsed into a SearchRequest. Its target
 * indices and search source are then wrapped into the returned dataset.
 *
 * @param request RestRequest carrying the search definition
 * @param client node client, used for its named writeable registry while parsing
 * @return SearchQueryInputDataset holding the parsed indices and search source
 * @throws IOException when the search request cannot be parsed
 */
public static SearchQueryInputDataset buildSearchQueryInput(RestRequest request, NodeClient client) throws IOException {
    final SearchRequest parsedRequest = new SearchRequest();
    // Callback used by the parser to apply a requested hit count to the search source.
    final IntConsumer sizeSetter = requestedSize -> parsedRequest.source().size(requestedSize);
    request.withContentOrSourceParamParserOrNull(
        contentParser -> parseSearchRequest(
            parsedRequest,
            request,
            contentParser,
            client.getNamedWriteableRegistry(),
            sizeSetter
        )
    );
    return new SearchQueryInputDataset(
        Arrays.asList(parsedRequest.indices()),
        parsedRequest.source()
    );
}
Usage of `org.opensearch.ml.common.dataset.SearchQueryInputDataset` in the ml-commons project (opensearch-project).
Class `MLInputDatasetHandlerTests`, method `testSearchQueryInputDatasetWithoutHits`.
/**
 * A search response that reports one total hit but carries no hit documents
 * must be reported to the listener as a failure.
 */
@SuppressWarnings("unchecked")
public void testSearchQueryInputDatasetWithoutHits() {
    searchResponse = mock(SearchResponse.class);
    SearchHit[] noDocuments = new SearchHit[0];
    TotalHits oneTotalHit = new TotalHits(1L, TotalHits.Relation.EQUAL_TO);
    SearchHits emptyHits = new SearchHits(noDocuments, oneTotalHit, 1f);
    when(searchResponse.getHits()).thenReturn(emptyHits);

    // Make the mocked client immediately answer any search with the prepared response.
    doAnswer(invocation -> {
        Object[] callArgs = invocation.getArguments();
        ActionListener<SearchResponse> searchListener = (ActionListener<SearchResponse>) callArgs[1];
        searchListener.onResponse(searchResponse);
        return null;
    }).when(client).search(any(), any());

    SearchSourceBuilder matchAllSource = new SearchSourceBuilder().query(QueryBuilders.matchAllQuery());
    SearchQueryInputDataset dataset = SearchQueryInputDataset.builder()
        .indices(Collections.singletonList("index1"))
        .searchSourceBuilder(matchAllSource)
        .build();

    mlInputDatasetHandler.parseSearchQueryInput(dataset, listener);
    verify(listener, times(1)).onFailure(any());
}
Usage of `org.opensearch.ml.common.dataset.SearchQueryInputDataset` in the ml-commons project (opensearch-project).
Class `MLInputDatasetHandlerTests`, method `testDataFrameInputDatasetWrongType`.
/**
 * Passing a SEARCH_QUERY dataset to parseDataFrameInput must be rejected
 * with an IllegalArgumentException naming the expected DATA_FRAME type.
 */
public void testDataFrameInputDatasetWrongType() {
    expectedEx.expect(IllegalArgumentException.class);
    expectedEx.expectMessage("Input dataset is not DATA_FRAME type.");
    SearchSourceBuilder matchAllSource = new SearchSourceBuilder().query(QueryBuilders.matchAllQuery());
    SearchQueryInputDataset wrongTypeDataset = SearchQueryInputDataset.builder()
        .indices(Collections.singletonList("index1"))
        .searchSourceBuilder(matchAllSource)
        .build();
    mlInputDatasetHandler.parseDataFrameInput(wrongTypeDataset);
}
Usage of `org.opensearch.ml.common.dataset.SearchQueryInputDataset` in the ml-commons project (opensearch-project).
Class `PredictionITTests`, method `initTestingData`.
/**
 * Index the ML testing data, then train a model on it via a search-query
 * dataset and block until that model is available.
 */
@Before
public void initTestingData() throws ExecutionException, InterruptedException {
    generateMLTestingData();
    SearchSourceBuilder trainingQuery = generateSearchSourceBuilder();
    List<String> trainingIndices = Collections.singletonList(TESTING_INDEX_NAME);
    MLInputDataset trainingDataset = new SearchQueryInputDataset(trainingIndices, trainingQuery);
    taskId = trainModel(trainingDataset);
    waitModelAvailable(taskId);
}
Aggregations