Usage of `org.opensearch.ml.common.dataset.MLInputDataset` in the ml-commons project (opensearch-project).
Class `PredictionITTests`, method `testPredictionWithSearchInput`:
/**
 * Predicts over the documents that the shared search source selects from
 * {@code TESTING_INDEX_NAME}, and checks the outcome with {@code predictAndVerifyResult}.
 * The model is identified by the {@code taskId} field, which is set up outside this method.
 *
 * @throws IOException propagated from the helpers used here
 */
public void testPredictionWithSearchInput() throws IOException {
    final SearchSourceBuilder query = generateSearchSourceBuilder();
    final MLInputDataset searchInput =
        new SearchQueryInputDataset(Collections.singletonList(TESTING_INDEX_NAME), query);
    predictAndVerifyResult(taskId, searchInput);
}
Usage of `org.opensearch.ml.common.dataset.MLInputDataset` in the ml-commons project (opensearch-project).
Class `PredictionITTests`, method `testPredictionWithEmptyDataset`:
/**
 * Verifies that a KMEANS prediction over the dataset returned by
 * {@code generateEmptyDataset()} fails with {@link IllegalArgumentException}
 * when the result is awaited.
 *
 * @throws IOException propagated from the helpers used here
 */
public void testPredictionWithEmptyDataset() throws IOException {
    final MLInputDataset emptyInput = generateEmptyDataset();
    final MLInput input = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .inputDataset(emptyInput)
        .build();
    final MLPredictionTaskRequest request = new MLPredictionTaskRequest(taskId, input);
    // The request is dispatched here; the failure is expected only when the future is resolved.
    final ActionFuture<MLTaskResponse> future = client().execute(MLPredictionTaskAction.INSTANCE, request);
    expectThrows(IllegalArgumentException.class, future::actionGet);
}
Usage of `org.opensearch.ml.common.dataset.MLInputDataset` in the ml-commons project (opensearch-project).
Class `MLCommonsRestTestCase`, method `trainAndPredict`:
/**
 * Trains and predicts in one call through the {@code _train_predict} REST endpoint.
 * The input is the set of documents that {@code searchSourceBuilder} selects from {@code indexName}.
 *
 * @param client REST client used to send the request
 * @param functionName algorithm to run; its lower-cased name (Locale.ROOT) is the last path segment
 * @param indexName index that the input documents are read from
 * @param params algorithm parameters placed on the {@code MLInput}
 * @param searchSourceBuilder query selecting the input documents
 * @param function optional callback receiving the {@code prediction_result} object parsed from
 *     the response body; may be {@code null}, in which case the result is not inspected
 * @throws IOException propagated from the request / entity helpers
 */
public void trainAndPredict(RestClient client, FunctionName functionName, String indexName, MLAlgoParams params, SearchSourceBuilder searchSourceBuilder, Consumer<Map<String, Object>> function) throws IOException {
    MLInputDataset inputData = SearchQueryInputDataset.builder().indices(ImmutableList.of(indexName)).searchSourceBuilder(searchSourceBuilder).build();
    // Named generically: the input targets whichever algorithm was requested, not only k-means.
    MLInput mlInput = MLInput.builder().algorithm(functionName).parameters(params).inputDataset(inputData).build();
    String endpoint = "/_plugins/_ml/_train_predict/" + functionName.name().toLowerCase(Locale.ROOT);
    Response response = TestHelper.makeRequest(client, "POST", endpoint, ImmutableMap.of(), TestHelper.toHttpEntity(mlInput), null);
    // Assert before dereferencing. Otherwise a null response surfaces as an NPE from getEntity()
    // instead of a clear assertion failure.
    assertNotNull(response);
    HttpEntity entity = response.getEntity();
    String entityString = TestHelper.httpEntityToString(entity);
    // Gson returns a raw Map for Map.class; the JSON response body is an object with string keys.
    @SuppressWarnings("unchecked")
    Map<String, Object> map = gson.fromJson(entityString, Map.class);
    @SuppressWarnings("unchecked")
    Map<String, Object> predictionResult = (Map<String, Object>) map.get("prediction_result");
    if (function != null) {
        function.accept(predictionResult);
    }
}
Usage of `org.opensearch.ml.common.dataset.MLInputDataset` in the ml-commons project (opensearch-project).
Class `TrainingITTests`, method `testTrainingWithSearchInput`:
/**
 * Trains a model on the documents that the shared search source selects from
 * {@code TESTING_INDEX_NAME}, then waits for the model produced by that training task.
 * Currently disabled as flaky; the reason is given in the {@code @Ignore} message below.
 *
 * @throws ExecutionException propagated from the helpers used here
 * @throws InterruptedException propagated from the helpers used here
 */
@Ignore("This test case is flaky, something is off with waitModelAvailable(taskId) method." + " This issue will be tracked in an issue and will be fixed later")
public void testTrainingWithSearchInput() throws ExecutionException, InterruptedException {
    final SearchSourceBuilder source = generateSearchSourceBuilder();
    final MLInputDataset dataset =
        new SearchQueryInputDataset(Collections.singletonList(TESTING_INDEX_NAME), source);
    waitModelAvailable(trainModel(dataset));
}
Usage of `org.opensearch.ml.common.dataset.MLInputDataset` in the ml-commons project (opensearch-project).
Class `TrainingITTests`, method `testTrainingWithEmptyDataset`:
/**
 * Trains a KMEANS model on an empty dataset and expects {@link IllegalArgumentException}.
 * The dataset is a search over {@code TESTING_INDEX_NAME} whose query is meant to match no documents.
 */
public void testTrainingWithEmptyDataset() {
    final SearchSourceBuilder noMatchSource = generateSearchSourceBuilder();
    // Match on a field name presumably absent from the test index, so the query selects nothing.
    noMatchSource.query(QueryBuilders.matchQuery("noSuchName", ""));
    final MLInputDataset emptyDataset =
        new SearchQueryInputDataset(Collections.singletonList(TESTING_INDEX_NAME), noMatchSource);
    final MLTrainingTaskRequest request = new MLTrainingTaskRequest(
        MLInput.builder()
            .algorithm(FunctionName.KMEANS)
            .inputDataset(emptyDataset)
            .build(),
        false
    );
    // Both dispatch and await happen inside expectThrows, so a failure at either stage is accepted.
    expectThrows(IllegalArgumentException.class, () -> {
        client().execute(MLTrainingTaskAction.INSTANCE, request).actionGet();
    });
}
Aggregations