Example 1 with MLInputDataset

use of org.opensearch.ml.common.dataset.MLInputDataset in project ml-commons by opensearch-project.

the class PredictionITTests method testPredictionWithSearchInput.

public void testPredictionWithSearchInput() throws IOException {
    SearchSourceBuilder searchSourceBuilder = generateSearchSourceBuilder();
    MLInputDataset inputDataset = new SearchQueryInputDataset(Collections.singletonList(TESTING_INDEX_NAME), searchSourceBuilder);
    predictAndVerifyResult(taskId, inputDataset);
}
Also used : SearchQueryInputDataset(org.opensearch.ml.common.dataset.SearchQueryInputDataset) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) IntegTestUtils.generateSearchSourceBuilder(org.opensearch.ml.utils.IntegTestUtils.generateSearchSourceBuilder) SearchSourceBuilder(org.opensearch.search.builder.SearchSourceBuilder)

Example 2 with MLInputDataset

use of org.opensearch.ml.common.dataset.MLInputDataset in project ml-commons by opensearch-project.

the class PredictionITTests method testPredictionWithEmptyDataset.

public void testPredictionWithEmptyDataset() throws IOException {
    MLInputDataset emptySearchInputDataset = generateEmptyDataset();
    MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(emptySearchInputDataset).build();
    MLPredictionTaskRequest predictionRequest = new MLPredictionTaskRequest(taskId, mlInput);
    ActionFuture<MLTaskResponse> predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest);
    expectThrows(IllegalArgumentException.class, () -> predictionFuture.actionGet());
}
Also used : MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLInput(org.opensearch.ml.common.parameter.MLInput) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) MLPredictionTaskRequest(org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest)

Example 3 with MLInputDataset

use of org.opensearch.ml.common.dataset.MLInputDataset in project ml-commons by opensearch-project.

the class MLCommonsRestTestCase method trainAndPredict.

public void trainAndPredict(RestClient client, FunctionName functionName, String indexName, MLAlgoParams params, SearchSourceBuilder searchSourceBuilder, Consumer<Map<String, Object>> function) throws IOException {
    // Build the search-backed input dataset and wrap it in an MLInput for the requested algorithm.
    MLInputDataset inputData = SearchQueryInputDataset.builder().indices(ImmutableList.of(indexName)).searchSourceBuilder(searchSourceBuilder).build();
    MLInput mlInput = MLInput.builder().algorithm(functionName).parameters(params).inputDataset(inputData).build();
    Response response = TestHelper.makeRequest(client, "POST", "/_plugins/_ml/_train_predict/" + functionName.name().toLowerCase(Locale.ROOT), ImmutableMap.of(), TestHelper.toHttpEntity(mlInput), null);
    // Check the response before dereferencing it.
    assertNotNull(response);
    HttpEntity entity = response.getEntity();
    String entityString = TestHelper.httpEntityToString(entity);
    Map<?, ?> map = gson.fromJson(entityString, Map.class);
    @SuppressWarnings("unchecked")
    Map<String, Object> predictionResult = (Map<String, Object>) map.get("prediction_result");
    // Let the caller verify the prediction result, if a verifier was supplied.
    if (function != null) {
        function.accept(predictionResult);
    }
}
Also used : Response(org.opensearch.client.Response) MLInput(org.opensearch.ml.common.parameter.MLInput) HttpEntity(org.apache.http.HttpEntity) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap)

Example 4 with MLInputDataset

use of org.opensearch.ml.common.dataset.MLInputDataset in project ml-commons by opensearch-project.

the class TrainingITTests method testTrainingWithSearchInput.

@Ignore("This test case is flaky; something is off with the waitModelAvailable(taskId) method. This will be tracked in an issue and fixed later.")
public void testTrainingWithSearchInput() throws ExecutionException, InterruptedException {
    SearchSourceBuilder searchSourceBuilder = generateSearchSourceBuilder();
    MLInputDataset inputDataset = new SearchQueryInputDataset(Collections.singletonList(TESTING_INDEX_NAME), searchSourceBuilder);
    String taskId = trainModel(inputDataset);
    waitModelAvailable(taskId);
}
Also used : SearchQueryInputDataset(org.opensearch.ml.common.dataset.SearchQueryInputDataset) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) IntegTestUtils.generateSearchSourceBuilder(org.opensearch.ml.utils.IntegTestUtils.generateSearchSourceBuilder) SearchSourceBuilder(org.opensearch.search.builder.SearchSourceBuilder) Ignore(org.junit.Ignore)

Example 5 with MLInputDataset

use of org.opensearch.ml.common.dataset.MLInputDataset in project ml-commons by opensearch-project.

the class TrainingITTests method testTrainingWithEmptyDataset.

// Training a model on an empty dataset is expected to fail with IllegalArgumentException.
public void testTrainingWithEmptyDataset() {
    SearchSourceBuilder searchSourceBuilder = generateSearchSourceBuilder();
    searchSourceBuilder.query(QueryBuilders.matchQuery("noSuchName", ""));
    MLInputDataset inputDataset = new SearchQueryInputDataset(Collections.singletonList(TESTING_INDEX_NAME), searchSourceBuilder);
    MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(inputDataset).build();
    MLTrainingTaskRequest trainingRequest = new MLTrainingTaskRequest(mlInput, false);
    expectThrows(IllegalArgumentException.class, () -> client().execute(MLTrainingTaskAction.INSTANCE, trainingRequest).actionGet());
}
Also used : SearchQueryInputDataset(org.opensearch.ml.common.dataset.SearchQueryInputDataset) MLInput(org.opensearch.ml.common.parameter.MLInput) MLTrainingTaskRequest(org.opensearch.ml.common.transport.training.MLTrainingTaskRequest) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) IntegTestUtils.generateSearchSourceBuilder(org.opensearch.ml.utils.IntegTestUtils.generateSearchSourceBuilder) SearchSourceBuilder(org.opensearch.search.builder.SearchSourceBuilder)

Aggregations

MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset) 21
MLInput (org.opensearch.ml.common.parameter.MLInput) 13
Test (org.junit.Test) 9
Input (org.opensearch.ml.common.parameter.Input) 8
LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) 8
DataFrame (org.opensearch.ml.common.dataframe.DataFrame) 7
SearchQueryInputDataset (org.opensearch.ml.common.dataset.SearchQueryInputDataset) 5
KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame) 5
LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame) 5
LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame) 5
SearchSourceBuilder (org.opensearch.search.builder.SearchSourceBuilder) 5
IntegTestUtils.generateSearchSourceBuilder (org.opensearch.ml.utils.IntegTestUtils.generateSearchSourceBuilder) 4
Response (org.opensearch.client.Response) 3
DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset) 3
FunctionName (org.opensearch.ml.common.parameter.FunctionName) 3
MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput) 3
Map (java.util.Map) 2
KMeansParams (org.opensearch.ml.common.parameter.KMeansParams) 2
Model (org.opensearch.ml.common.parameter.Model) 2
ImmutableMap (com.google.common.collect.ImmutableMap) 1