Example 16 with MLInputDataset

Use of org.opensearch.ml.common.dataset.MLInputDataset in project ml-commons by opensearch-project.

From class MLEngineTest, method trainLinearRegressionModel.

private Model trainLinearRegressionModel() {
    LinearRegressionParams parameters = LinearRegressionParams.builder()
            .objectiveType(LinearRegressionParams.ObjectiveType.SQUARED_LOSS)
            .optimizerType(LinearRegressionParams.OptimizerType.ADAM)
            .learningRate(0.01)
            .epsilon(1e-6)
            .beta1(0.9)
            .beta2(0.99)
            .target("price")
            .build();
    DataFrame trainDataFrame = constructLinearRegressionTrainDataFrame();
    MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(trainDataFrame).build();
    Input mlInput = MLInput.builder()
            .algorithm(FunctionName.LINEAR_REGRESSION)
            .parameters(parameters)
            .inputDataset(inputDataset)
            .build();
    return MLEngine.train(mlInput);
}
Also used: MLInput (org.opensearch.ml.common.parameter.MLInput), Input (org.opensearch.ml.common.parameter.Input), LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput), LinearRegressionParams (org.opensearch.ml.common.parameter.LinearRegressionParams), MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset), LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame), KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame), LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame), DataFrame (org.opensearch.ml.common.dataframe.DataFrame)
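Example 16 configures linear regression with squared loss and the Adam optimizer (learning rate 0.01, beta1 0.9, beta2 0.99, epsilon 1e-6). As a rough illustration of what those parameters control, here is a minimal plain-Java sketch that fits a one-feature model y ≈ w·x + b with full-batch Adam. It is a stand-in under stated assumptions, not the ml-commons training code: AdamLinearRegressionSketch and fit are hypothetical names, there is no DataFrame or "price" target column, and the step count is arbitrary.

```java
/**
 * Hypothetical sketch, not the ml-commons implementation: fits y ≈ w * x + b
 * by minimizing the mean squared loss with full-batch Adam. The
 * hyperparameters mirror the ones passed to LinearRegressionParams in
 * Example 16.
 */
final class AdamLinearRegressionSketch {

    /** Returns {weight, bias} after the given number of Adam steps. */
    static double[] fit(double[] x, double[] y, double learningRate,
                        double beta1, double beta2, double epsilon, int steps) {
        if (x.length != y.length || x.length == 0) {
            throw new IllegalArgumentException("x and y must be non-empty and the same length");
        }
        int n = x.length;
        double w = 0.0, b = 0.0;
        // First-moment (m) and second-moment (v) estimates for each parameter.
        double mW = 0.0, mB = 0.0, vW = 0.0, vB = 0.0;
        for (int t = 1; t <= steps; t++) {
            // Gradient of the mean squared loss (1/n) * sum((w*x + b - y)^2).
            double gW = 0.0, gB = 0.0;
            for (int i = 0; i < n; i++) {
                double err = w * x[i] + b - y[i];
                gW += 2.0 * err * x[i] / n;
                gB += 2.0 * err / n;
            }
            // Exponential moving averages of the gradient and squared gradient.
            mW = beta1 * mW + (1 - beta1) * gW;
            mB = beta1 * mB + (1 - beta1) * gB;
            vW = beta2 * vW + (1 - beta2) * gW * gW;
            vB = beta2 * vB + (1 - beta2) * gB * gB;
            // Bias correction for the zero-initialized moments.
            double mWHat = mW / (1 - Math.pow(beta1, t));
            double mBHat = mB / (1 - Math.pow(beta1, t));
            double vWHat = vW / (1 - Math.pow(beta2, t));
            double vBHat = vB / (1 - Math.pow(beta2, t));
            // Per-parameter step scaled by the root of the second moment.
            w -= learningRate * mWHat / (Math.sqrt(vWHat) + epsilon);
            b -= learningRate * mBHat / (Math.sqrt(vBHat) + epsilon);
        }
        return new double[] { w, b };
    }
}
```

For instance, with x = {-2, -1, 0, 1, 2} and y = 2x + 1, twenty thousand steps at the Example 16 settings should leave the weight and bias close to 2 and 1.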

Example 17 with MLInputDataset

Use of org.opensearch.ml.common.dataset.MLInputDataset in project ml-commons by opensearch-project.

From class MLEngineTest, method predictKMeans.

@Test
public void predictKMeans() {
    Model model = trainKMeansModel();
    DataFrame predictionDataFrame = constructKMeansDataFrame(10);
    MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(predictionDataFrame).build();
    Input mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(inputDataset).build();
    MLPredictionOutput output = (MLPredictionOutput) MLEngine.predict(mlInput, model);
    DataFrame predictions = output.getPredictionResult();
    Assert.assertEquals(10, predictions.size());
    predictions.forEach(row -> Assert.assertTrue(row.getValue(0).intValue() == 0 || row.getValue(0).intValue() == 1));
}
Also used: MLInput (org.opensearch.ml.common.parameter.MLInput), Input (org.opensearch.ml.common.parameter.Input), LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput), Model (org.opensearch.ml.common.parameter.Model), MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset), MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput), LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame), KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame), LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame), DataFrame (org.opensearch.ml.common.dataframe.DataFrame), Test (org.junit.Test)
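The assertion in Example 17 only checks that every predicted label is 0 or 1. The k-means assignment step that produces such labels can be sketched as follows, assuming the centroids are already known. NearestCentroidSketch and assign are hypothetical names, and the sketch says nothing about how ml-commons stores or applies its trained KMeans model.

```java
/**
 * Hypothetical sketch of k-means assignment: each row is labeled with the
 * index of its nearest centroid under squared Euclidean distance. The
 * centroids are supplied by the caller; this does not model ml-commons'
 * trained Model object.
 */
final class NearestCentroidSketch {

    static int[] assign(double[][] rows, double[][] centroids) {
        if (centroids.length == 0) {
            throw new IllegalArgumentException("need at least one centroid");
        }
        int[] labels = new int[rows.length];
        for (int r = 0; r < rows.length; r++) {
            int best = 0;
            double bestDist = Double.POSITIVE_INFINITY;
            for (int c = 0; c < centroids.length; c++) {
                // Squared distance is enough for comparison; no sqrt needed.
                double d = 0.0;
                for (int j = 0; j < rows[r].length; j++) {
                    double diff = rows[r][j] - centroids[c][j];
                    d += diff * diff;
                }
                // Strict '<' keeps the lower centroid index on ties.
                if (d < bestDist) {
                    bestDist = d;
                    best = c;
                }
            }
            labels[r] = best;
        }
        return labels;
    }
}
```

With two centroids every label is necessarily 0 or 1, which is the property Example 17 asserts for each prediction row.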

Example 18 with MLInputDataset

Use of org.opensearch.ml.common.dataset.MLInputDataset in project ml-commons by opensearch-project.

From class RestMLTrainAndPredictIT, method trainAndPredictKmeansWithParmas.

private void trainAndPredictKmeansWithParmas(KMeansParams params, Consumer<Map<Double, Integer>> function) throws IOException {
    SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
    sourceBuilder.query(new MatchAllQueryBuilder());
    sourceBuilder.size(1000);
    sourceBuilder.fetchSource(new String[] { "petal_length_in_cm", "petal_width_in_cm" }, null);
    MLInputDataset inputData = SearchQueryInputDataset.builder()
            .indices(ImmutableList.of(irisIndex))
            .searchSourceBuilder(sourceBuilder)
            .build();
    trainAndPredictKmeansWithIrisData(params, inputData, function);
}
Also used: MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset), SearchSourceBuilder (org.opensearch.search.builder.SearchSourceBuilder), MatchAllQueryBuilder (org.opensearch.index.query.MatchAllQueryBuilder)
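The search source in Example 18 matches every document, caps the hits at 1000, and uses fetchSource to return only the two petal fields (the excludes argument is null). The sketch below mirrors that intent on in-memory maps. It is a simplified, hypothetical model (SourceFilterSketch and matchAllWithIncludes are invented names): OpenSearch evaluates the query on the server and its includes also accept wildcard patterns, while this version matches exact field names only.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Hypothetical, simplified model of a match_all query with a size limit and
 * source includes. Not how OpenSearch executes searches; it only mirrors the
 * request's intent on in-memory documents.
 */
final class SourceFilterSketch {

    static List<Map<String, Object>> matchAllWithIncludes(
            List<Map<String, Object>> docs, int size, Set<String> includes) {
        List<Map<String, Object>> hits = new ArrayList<>();
        for (Map<String, Object> doc : docs) {
            if (hits.size() >= size) {
                break; // honour the size limit
            }
            // match_all: every document qualifies; keep only included fields.
            Map<String, Object> filtered = new LinkedHashMap<>();
            for (Map.Entry<String, Object> e : doc.entrySet()) {
                if (includes.contains(e.getKey())) {
                    filtered.put(e.getKey(), e.getValue());
                }
            }
            hits.add(filtered);
        }
        return hits;
    }
}
```

Fields outside the includes set, such as a class label, never reach the returned hits, which is why the KMeans input in Example 18 sees only the two petal measurements.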

Example 19 with MLInputDataset

Use of org.opensearch.ml.common.dataset.MLInputDataset in project ml-commons by opensearch-project.

From class PredictionITTests, method initTestingData.

@Before
public void initTestingData() throws ExecutionException, InterruptedException {
    generateMLTestingData();
    SearchSourceBuilder searchSourceBuilder = generateSearchSourceBuilder();
    MLInputDataset inputDataset = new SearchQueryInputDataset(Collections.singletonList(TESTING_INDEX_NAME), searchSourceBuilder);
    taskId = trainModel(inputDataset);
    waitModelAvailable(taskId);
}
Also used: SearchQueryInputDataset (org.opensearch.ml.common.dataset.SearchQueryInputDataset), MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset), IntegTestUtils.generateSearchSourceBuilder (org.opensearch.ml.utils.IntegTestUtils.generateSearchSourceBuilder), SearchSourceBuilder (org.opensearch.search.builder.SearchSourceBuilder), Before (org.junit.Before)
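Example 19 trains a model from a SearchQueryInputDataset and then calls waitModelAvailable(taskId), whose body is not shown here. One common way to write such a wait is a bounded polling loop. The sketch below assumes that approach; PollUntilSketch, waitUntil, and the attempt and sleep parameters are hypothetical and do not describe the ml-commons test utility.

```java
import java.util.function.BooleanSupplier;

/**
 * Hypothetical polling helper: checks a caller-supplied condition up to
 * maxAttempts times, sleeping between attempts. The real waitModelAvailable
 * logic is not shown in the source; this is an assumed design.
 */
final class PollUntilSketch {

    /** Returns the 1-based attempt on which the condition became true. */
    static int waitUntil(BooleanSupplier isAvailable, int maxAttempts, long sleepMillis)
            throws InterruptedException {
        for (int attempt = 1; attempt <= maxAttempts; attempt++) {
            if (isAvailable.getAsBoolean()) {
                return attempt;
            }
            if (attempt < maxAttempts) {
                Thread.sleep(sleepMillis); // wait before the next check
            }
        }
        throw new IllegalStateException("not available after " + maxAttempts + " attempts");
    }
}
```

Bounding the number of attempts keeps a test from hanging indefinitely when training fails; the exception turns that case into a clear failure instead.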

Example 20 with MLInputDataset

Use of org.opensearch.ml.common.dataset.MLInputDataset in project ml-commons by opensearch-project.

From class MLCommonsRestTestCase, method predict.

public void predict(RestClient client, FunctionName functionName, String modelId, String indexName, MLAlgoParams params, SearchSourceBuilder searchSourceBuilder, Consumer<Map<String, Object>> function) throws IOException {
    MLInputDataset inputData = SearchQueryInputDataset.builder()
            .indices(ImmutableList.of(indexName))
            .searchSourceBuilder(searchSourceBuilder)
            .build();
    MLInput mlInput = MLInput.builder().algorithm(functionName).parameters(params).inputDataset(inputData).build();
    String endpoint = "/_plugins/_ml/_predict/" + functionName.name().toLowerCase(Locale.ROOT) + "/" + modelId;
    Response response = TestHelper.makeRequest(client, "POST", endpoint, ImmutableMap.of(), TestHelper.toHttpEntity(mlInput), null);
    verifyResponse(function, response);
}
Also used: Response (org.opensearch.client.Response), MLInput (org.opensearch.ml.common.parameter.MLInput), MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset)
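Example 20 builds the predict path from the algorithm name and the model ID. The sketch below reproduces that string concatenation in isolation. PredictEndpointSketch is a hypothetical name, and FunctionNameStandIn is a local enum standing in for ml-commons' FunctionName, containing only the two constants used in these examples.

```java
import java.util.Locale;

/**
 * Hypothetical sketch of the predict-endpoint path from Example 20.
 * FunctionNameStandIn substitutes for ml-commons' FunctionName enum.
 */
final class PredictEndpointSketch {

    enum FunctionNameStandIn { KMEANS, LINEAR_REGRESSION }

    static String predictEndpoint(FunctionNameStandIn functionName, String modelId) {
        // Locale.ROOT keeps lower-casing independent of the JVM default locale.
        return "/_plugins/_ml/_predict/"
                + functionName.name().toLowerCase(Locale.ROOT)
                + "/" + modelId;
    }
}
```

Passing Locale.ROOT matters here: under a Turkish default locale, a plain "LINEAR_REGRESSION".toLowerCase() maps each I to a dotless ı, which would produce a different path from "/_plugins/_ml/_predict/linear_regression/...".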

Aggregations

MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset): 21
MLInput (org.opensearch.ml.common.parameter.MLInput): 13
Test (org.junit.Test): 9
Input (org.opensearch.ml.common.parameter.Input): 8
LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput): 8
DataFrame (org.opensearch.ml.common.dataframe.DataFrame): 7
SearchQueryInputDataset (org.opensearch.ml.common.dataset.SearchQueryInputDataset): 5
KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame): 5
LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame): 5
LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame): 5
SearchSourceBuilder (org.opensearch.search.builder.SearchSourceBuilder): 5
IntegTestUtils.generateSearchSourceBuilder (org.opensearch.ml.utils.IntegTestUtils.generateSearchSourceBuilder): 4
Response (org.opensearch.client.Response): 3
DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset): 3
FunctionName (org.opensearch.ml.common.parameter.FunctionName): 3
MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput): 3
Map (java.util.Map): 2
KMeansParams (org.opensearch.ml.common.parameter.KMeansParams): 2
Model (org.opensearch.ml.common.parameter.Model): 2
ImmutableMap (com.google.common.collect.ImmutableMap): 1