
Example 26 with MLInput

Use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.

Source: method predict in class MachineLearningNodeClientTest.

@SuppressWarnings("unchecked")
@Test
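public void predict() {
    // client, input, output, dataFrameActionListener and machineLearningNodeClient are declared outside this method (test-class fixtures, not shown).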
    doAnswer(invocation -> {
        ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
        MLPredictionOutput predictionOutput = MLPredictionOutput.builder().status("Success").predictionResult(output).taskId("taskId").build();
        actionListener.onResponse(MLTaskResponse.builder().output(predictionOutput).build());
        return null;
    }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());
    ArgumentCaptor<MLOutput> dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class);
    MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build();
    machineLearningNodeClient.predict(null, mlInput, dataFrameActionListener);
    verify(client).execute(eq(MLPredictionTaskAction.INSTANCE), isA(MLPredictionTaskRequest.class), any(ActionListener.class));
    verify(dataFrameActionListener).onResponse(dataFrameArgumentCaptor.capture());
    assertEquals(output, ((MLPredictionOutput) dataFrameArgumentCaptor.getValue()).getPredictionResult());
}
Also used: MLTaskResponse (org.opensearch.ml.common.transport.MLTaskResponse), MLInput (org.opensearch.ml.common.parameter.MLInput), ActionListener (org.opensearch.action.ActionListener), MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput), MLOutput (org.opensearch.ml.common.parameter.MLOutput), MLPredictionTaskRequest (org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest), Test (org.junit.Test)
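The test never contacts a cluster: doAnswer makes the mocked client call the listener synchronously with a canned MLTaskResponse, and the ArgumentCaptor records what MachineLearningNodeClient forwards to the caller's listener (the unwrapped output, as the cast to MLPredictionOutput shows). The sketch below reproduces that control flow in plain Java without Mockito. Every type in it (Listener, TaskResponse, StubTransport, NodeClientSketch, CapturingListener) is a hypothetical stand-in, not the ml-commons or OpenSearch API.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

// Hypothetical callback interface, modeled loosely on OpenSearch's ActionListener.
interface Listener<T> {
    void onResponse(T response);
    void onFailure(Exception e);
}

// Hypothetical response wrapper standing in for MLTaskResponse.
final class TaskResponse<O> {
    private final O output;
    TaskResponse(O output) { this.output = output; }
    O getOutput() { return output; }
}

// Hypothetical transport whose execute() answers synchronously,
// like the doAnswer(...) stub in the test.
final class StubTransport<I, O> {
    private final Function<I, O> answer;
    final List<I> seenRequests = new ArrayList<>();
    StubTransport(Function<I, O> answer) { this.answer = answer; }
    void execute(I request, Listener<TaskResponse<O>> listener) {
        seenRequests.add(request); // plays the role of verify(client).execute(...)
        listener.onResponse(new TaskResponse<>(answer.apply(request)));
    }
}

// Hypothetical client that unwraps the task response before notifying the caller.
final class NodeClientSketch<I, O> {
    private final StubTransport<I, O> transport;
    NodeClientSketch(StubTransport<I, O> transport) { this.transport = transport; }
    void predict(I input, Listener<O> callerListener) {
        transport.execute(input, new Listener<TaskResponse<O>>() {
            @Override public void onResponse(TaskResponse<O> r) { callerListener.onResponse(r.getOutput()); }
            @Override public void onFailure(Exception e) { callerListener.onFailure(e); }
        });
    }
}

// Records every delivered value, like ArgumentCaptor.capture() on the listener mock.
final class CapturingListener<T> implements Listener<T> {
    final List<T> values = new ArrayList<>();
    final List<Exception> failures = new ArrayList<>();
    @Override public void onResponse(T response) { values.add(response); }
    @Override public void onFailure(Exception e) { failures.add(e); }
}
```

Hand-rolled stubs like these make the flow explicit; in the real test, Mockito generates the equivalent stub and captor.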

Example 27 with MLInput

Use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.

Source: method predict_Exception_WithNullAlgorithm in class MachineLearningNodeClientTest.

@Test
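public void predict_Exception_WithNullAlgorithm() {
    // exceptionRule, input, dataFrameActionListener and machineLearningNodeClient are declared outside this method (test-class fixtures, not shown).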
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("algorithm can't be null");
    MLInput mlInput = MLInput.builder().inputDataset(input).build();
    machineLearningNodeClient.predict(null, mlInput, dataFrameActionListener);
}
Also used: MLInput (org.opensearch.ml.common.parameter.MLInput), Test (org.junit.Test)
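This test expects predict to reject an MLInput built without an algorithm, throwing IllegalArgumentException with the message "algorithm can't be null". A minimal sketch of such a fail-fast guard follows. PredictRequestSketch and PredictGuard are hypothetical names, the algorithm is modeled as a plain String rather than FunctionName, and the test does not show where in ml-commons the real check lives.

```java
// Hypothetical request holder: only the fields relevant to the check are modeled.
final class PredictRequestSketch {
    final String algorithm;    // stands in for FunctionName; null when the builder never set it
    final Object inputDataset; // stands in for the MLInputDataset

    PredictRequestSketch(String algorithm, Object inputDataset) {
        this.algorithm = algorithm;
        this.inputDataset = inputDataset;
    }
}

final class PredictGuard {
    // Fails fast with the exception type and message that the test expects.
    static void requireAlgorithm(PredictRequestSketch request) {
        if (request.algorithm == null) {
            throw new IllegalArgumentException("algorithm can't be null");
        }
    }
}
```

With JUnit 4.13 or later, the same expectation can also be written with Assert.assertThrows instead of the ExpectedException rule.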

Example 28 with MLInput

Use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.

Source: method train in class MachineLearningNodeClientTest.

@SuppressWarnings("unchecked")
@Test
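public void train() {
    // client, input, trainingActionListener and machineLearningNodeClient are declared outside this method (test-class fixtures, not shown).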
    String modelId = "test_model_id";
    String status = "InProgress";
    doAnswer(invocation -> {
        ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
        MLTrainingOutput output = MLTrainingOutput.builder().status(status).modelId(modelId).build();
        actionListener.onResponse(MLTaskResponse.builder().output(output).build());
        return null;
    }).when(client).execute(eq(MLTrainingTaskAction.INSTANCE), any(), any());
    ArgumentCaptor<MLOutput> argumentCaptor = ArgumentCaptor.forClass(MLOutput.class);
    MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build();
    machineLearningNodeClient.train(mlInput, false, trainingActionListener);
    verify(client).execute(eq(MLTrainingTaskAction.INSTANCE), isA(MLTrainingTaskRequest.class), any(ActionListener.class));
    verify(trainingActionListener).onResponse(argumentCaptor.capture());
    assertEquals(modelId, ((MLTrainingOutput) argumentCaptor.getValue()).getModelId());
    assertEquals(status, ((MLTrainingOutput) argumentCaptor.getValue()).getStatus());
}
Also used: MLTaskResponse (org.opensearch.ml.common.transport.MLTaskResponse), MLTrainingOutput (org.opensearch.ml.common.parameter.MLTrainingOutput), MLInput (org.opensearch.ml.common.parameter.MLInput), ActionListener (org.opensearch.action.ActionListener), MLTrainingTaskRequest (org.opensearch.ml.common.transport.training.MLTrainingTaskRequest), MLOutput (org.opensearch.ml.common.parameter.MLOutput), Test (org.junit.Test)
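In the train test, the stubbed call answers with an MLTrainingOutput whose status is "InProgress" and which already carries a model ID; the assertions check that both fields reach the caller unchanged. The test builds that output with a fluent builder. The sketch below shows that builder style as an immutable value class; TrainingOutputSketch is hypothetical, and the real MLTrainingOutput may hold more fields or be generated differently.

```java
// Hypothetical immutable training result with a fluent builder, in the style of
// MLTrainingOutput.builder().status(...).modelId(...).build() used in the test.
final class TrainingOutputSketch {
    private final String modelId;
    private final String status;

    private TrainingOutputSketch(Builder b) {
        this.modelId = b.modelId;
        this.status = b.status;
    }

    static Builder builder() { return new Builder(); }

    String getModelId() { return modelId; }
    String getStatus() { return status; }

    static final class Builder {
        private String modelId;
        private String status;

        Builder modelId(String modelId) { this.modelId = modelId; return this; }
        Builder status(String status) { this.status = status; return this; }
        TrainingOutputSketch build() { return new TrainingOutputSketch(this); }
    }
}
```

Because the value is immutable once built, whatever the stub creates is exactly what the captor later reads back.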

Example 29 with MLInput

Use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.

Source: method predict_WithAlgoAndInputDataAndListener in class MachineLearningClientTest.

@Test
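public void predict_WithAlgoAndInputDataAndListener() {
    // input, output, dataFrameActionListener and machineLearningClient are declared outside this method (test-class fixtures, not shown).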
    MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(new DataFrameInputDataset(input)).build();
    ArgumentCaptor<MLOutput> dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class);
    machineLearningClient.predict(null, mlInput, dataFrameActionListener);
    verify(dataFrameActionListener).onResponse(dataFrameArgumentCaptor.capture());
    assertEquals(output, dataFrameArgumentCaptor.getValue());
}
Also used: MLInput (org.opensearch.ml.common.parameter.MLInput), DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset), MLOutput (org.opensearch.ml.common.parameter.MLOutput), Test (org.junit.Test)
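This test wraps the raw input in new DataFrameInputDataset(input) before passing it to inputDataset(...). The aggregation list shows that both DataFrameInputDataset and MLInputDataset live in org.opensearch.ml.common.dataset. The sketch below assumes, for illustration only, that the wrapper is one implementation of a common dataset supertype, so the input holder never depends on a concrete dataset kind. All types here are hypothetical.

```java
import java.util.List;

// Hypothetical common supertype for every kind of dataset an input can carry.
interface InputDatasetSketch {
    String kind();
}

// Hypothetical wrapper around in-memory rows, analogous to wrapping a DataFrame.
final class DataFrameDatasetSketch implements InputDatasetSketch {
    private final List<double[]> rows;

    DataFrameDatasetSketch(List<double[]> rows) { this.rows = List.copyOf(rows); }

    @Override public String kind() { return "DATA_FRAME"; }

    int size() { return rows.size(); }
}

// Hypothetical input holder: it only sees the supertype, so further dataset
// kinds could be added without changing this class.
final class InputSketch {
    final String algorithm;
    final InputDatasetSketch dataset;

    InputSketch(String algorithm, InputDatasetSketch dataset) {
        this.algorithm = algorithm;
        this.dataset = dataset;
    }
}
```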

Example 30 with MLInput

Use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.

Source: method predict_WithAlgoAndParametersAndInputDataAndModelId in class MachineLearningClientTest.

@Test
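public void predict_WithAlgoAndParametersAndInputDataAndModelId() {
    // mlParameters, input, output and machineLearningClient are declared outside this method (test-class fixtures, not shown).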
    MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).parameters(mlParameters).inputDataset(new DataFrameInputDataset(input)).build();
    assertEquals(output, machineLearningClient.predict("modelId", mlInput).actionGet());
}
Also used: MLInput (org.opensearch.ml.common.parameter.MLInput), DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset), Test (org.junit.Test)
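Unlike the listener-based call in the previous example, this test uses the overload that takes only a model ID and an input and returns a future, then blocks on actionGet(). The sketch below shows the same idea in plain Java: a default method adapts a callback-style predict into a CompletableFuture, and join() plays the role of actionGet(). AsyncPredictor is a hypothetical interface, and join() only approximates actionGet(), whose exception handling may differ.

```java
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;

// Hypothetical callback-style API: results arrive through onResponse or onFailure.
interface AsyncPredictor<I, O> {
    void predict(String modelId, I input, Consumer<O> onResponse, Consumer<Exception> onFailure);

    // Future-returning overload built on the callback form; callers may block on join().
    default CompletableFuture<O> predict(String modelId, I input) {
        CompletableFuture<O> future = new CompletableFuture<>();
        predict(modelId, input, future::complete, future::completeExceptionally);
        return future;
    }
}
```

Defining the blocking form on top of the callback form keeps a single code path for the actual work, which matches how this test can reuse the same stubbed output.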

Aggregations

MLInput (org.opensearch.ml.common.parameter.MLInput): 46
Test (org.junit.Test): 18
MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset): 13
MLTaskResponse (org.opensearch.ml.common.transport.MLTaskResponse): 12
DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset): 11
DataFrame (org.opensearch.ml.common.dataframe.DataFrame): 10
Input (org.opensearch.ml.common.parameter.Input): 9
LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput): 9
MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput): 7
MLPredictionTaskRequest (org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest): 7
MLTrainingTaskRequest (org.opensearch.ml.common.transport.training.MLTrainingTaskRequest): 7
MLOutput (org.opensearch.ml.common.parameter.MLOutput): 6
XContentParser (org.opensearch.common.xcontent.XContentParser): 5
Response (org.opensearch.client.Response): 4
Model (org.opensearch.ml.common.parameter.Model): 4
KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame): 4
LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame): 4
LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame): 4
VisibleForTesting (com.google.common.annotations.VisibleForTesting): 3
Instant (java.time.Instant): 3