
Example 21 with MLInputDataset

Use of org.opensearch.ml.common.dataset.MLInputDataset in project ml-commons by opensearch-project.

From the class MLCommonsRestTestCase, method train.

public void train(RestClient client, FunctionName functionName, String indexName, MLAlgoParams params, SearchSourceBuilder searchSourceBuilder, Consumer<Map<String, Object>> function, boolean async) throws IOException {
    // Training data: the documents in indexName that match searchSourceBuilder.
    MLInputDataset inputData = SearchQueryInputDataset.builder().indices(ImmutableList.of(indexName)).searchSourceBuilder(searchSourceBuilder).build();
    // Request body: the algorithm, its parameters, and the input dataset.
    MLInput mlInput = MLInput.builder().algorithm(functionName).parameters(params).inputDataset(inputData).build();
    // e.g. FunctionName.KMEANS -> /_plugins/_ml/_train/kmeans
    String endpoint = "/_plugins/_ml/_train/" + functionName.name().toLowerCase(Locale.ROOT);
    if (async) {
        // Request asynchronous training instead of waiting for the model.
        endpoint += "?async=true";
    }
    Response response = TestHelper.makeRequest(client, "POST", endpoint, ImmutableMap.of(), TestHelper.toHttpEntity(mlInput), null);
    // Check the response with the caller-supplied consumer.
    verifyResponse(function, response);
}
Also used: Response (org.opensearch.client.Response), MLInput (org.opensearch.ml.common.parameter.MLInput), MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset)
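The endpoint logic in train() can be isolated into a small, dependency-free Java sketch. Algo below is a hypothetical stand-in for FunctionName (only two constants, chosen to match the KMeans and linear-regression helpers listed under Aggregations), and TrainEndpointSketch and trainEndpoint are illustrative names, not part of ml-commons. The sketch reproduces only the path construction, not the HTTP request.

```java
import java.util.Locale;

/**
 * Sketch of how train() derives the REST path for the ML Commons train API.
 * Algo is a hypothetical stand-in for FunctionName so this compiles on its own.
 */
public class TrainEndpointSketch {

    // Hypothetical stand-in; the real FunctionName enum may define other constants.
    enum Algo { KMEANS, LINEAR_REGRESSION }

    static final String TRAIN_PREFIX = "/_plugins/_ml/_train/";

    // Mirrors the string building in train(): lower-case the enum name with
    // Locale.ROOT and append ?async=true only when async training is requested.
    static String trainEndpoint(Algo algo, boolean async) {
        String endpoint = TRAIN_PREFIX + algo.name().toLowerCase(Locale.ROOT);
        if (async) {
            endpoint += "?async=true";
        }
        return endpoint;
    }

    public static void main(String[] args) {
        System.out.println(trainEndpoint(Algo.KMEANS, false));           // prints /_plugins/_ml/_train/kmeans
        System.out.println(trainEndpoint(Algo.LINEAR_REGRESSION, true)); // prints /_plugins/_ml/_train/linear_regression?async=true
    }
}
```

Using Locale.ROOT keeps the path stable regardless of the JVM's default locale; under a Turkish locale, for example, a plain toLowerCase() would turn the "I" in LINEAR_REGRESSION into a dotless ı and produce a path the plugin does not recognize.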

Aggregations

MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset): 21
MLInput (org.opensearch.ml.common.parameter.MLInput): 13
Test (org.junit.Test): 9
Input (org.opensearch.ml.common.parameter.Input): 8
LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput): 8
DataFrame (org.opensearch.ml.common.dataframe.DataFrame): 7
SearchQueryInputDataset (org.opensearch.ml.common.dataset.SearchQueryInputDataset): 5
KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame): 5
LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame): 5
LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame): 5
SearchSourceBuilder (org.opensearch.search.builder.SearchSourceBuilder): 5
IntegTestUtils.generateSearchSourceBuilder (org.opensearch.ml.utils.IntegTestUtils.generateSearchSourceBuilder): 4
Response (org.opensearch.client.Response): 3
DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset): 3
FunctionName (org.opensearch.ml.common.parameter.FunctionName): 3
MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput): 3
Map (java.util.Map): 2
KMeansParams (org.opensearch.ml.common.parameter.KMeansParams): 2
Model (org.opensearch.ml.common.parameter.Model): 2
ImmutableMap (com.google.common.collect.ImmutableMap): 1