Search in sources:

Example 11 with MLTaskResponse

Use of org.opensearch.ml.common.transport.MLTaskResponse in the ml-commons project by opensearch-project.

Method predict of class MachineLearningNodeClientTest.

/**
 * Verifies that {@code predict} dispatches an {@link MLPredictionTaskRequest} through
 * {@code MLPredictionTaskAction.INSTANCE} and hands the prediction result produced by the
 * (mocked) transport client to the caller's listener.
 */
@SuppressWarnings("unchecked")
@Test
public void predict() {
    // Stub the transport client: reply through the listener passed as argument index 2
    // with a prediction output that wraps the shared `output` fixture.
    doAnswer(call -> {
        MLPredictionOutput stubbedOutput = MLPredictionOutput.builder()
            .status("Success")
            .predictionResult(output)
            .taskId("taskId")
            .build();
        ActionListener<MLTaskResponse> listener = call.getArgument(2);
        listener.onResponse(MLTaskResponse.builder().output(stubbedOutput).build());
        return null;
    }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());

    MLInput kmeansInput = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .inputDataset(input)
        .build();
    ArgumentCaptor<MLOutput> outputCaptor = ArgumentCaptor.forClass(MLOutput.class);

    machineLearningNodeClient.predict(null, kmeansInput, dataFrameActionListener);

    verify(client).execute(eq(MLPredictionTaskAction.INSTANCE), isA(MLPredictionTaskRequest.class), any(ActionListener.class));
    verify(dataFrameActionListener).onResponse(outputCaptor.capture());
    MLPredictionOutput received = (MLPredictionOutput) outputCaptor.getValue();
    assertEquals(output, received.getPredictionResult());
}
Also used: MLTaskResponse (org.opensearch.ml.common.transport.MLTaskResponse), MLInput (org.opensearch.ml.common.parameter.MLInput), ActionListener (org.opensearch.action.ActionListener), MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput), MLOutput (org.opensearch.ml.common.parameter.MLOutput), MLPredictionTaskRequest (org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest), Test (org.junit.Test)

Example 12 with MLTaskResponse

Use of org.opensearch.ml.common.transport.MLTaskResponse in the ml-commons project by opensearch-project.

Method train of class MachineLearningNodeClientTest.

/**
 * Verifies that {@code train} dispatches an {@link MLTrainingTaskRequest} through
 * {@code MLTrainingTaskAction.INSTANCE} and forwards the model id and status reported by
 * the (mocked) transport client to the caller's listener.
 */
@SuppressWarnings("unchecked")
@Test
public void train() {
    String expectedModelId = "test_model_id";
    String expectedStatus = "InProgress";

    // Stub the transport client: reply through the listener passed as argument index 2
    // with a training output carrying the expected model id and status.
    doAnswer(call -> {
        MLTrainingOutput stubbedOutput = MLTrainingOutput.builder()
            .status(expectedStatus)
            .modelId(expectedModelId)
            .build();
        ActionListener<MLTaskResponse> listener = call.getArgument(2);
        listener.onResponse(MLTaskResponse.builder().output(stubbedOutput).build());
        return null;
    }).when(client).execute(eq(MLTrainingTaskAction.INSTANCE), any(), any());

    MLInput kmeansInput = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .inputDataset(input)
        .build();
    ArgumentCaptor<MLOutput> outputCaptor = ArgumentCaptor.forClass(MLOutput.class);

    machineLearningNodeClient.train(kmeansInput, false, trainingActionListener);

    verify(client).execute(eq(MLTrainingTaskAction.INSTANCE), isA(MLTrainingTaskRequest.class), any(ActionListener.class));
    verify(trainingActionListener).onResponse(outputCaptor.capture());
    MLTrainingOutput received = (MLTrainingOutput) outputCaptor.getValue();
    assertEquals(expectedModelId, received.getModelId());
    assertEquals(expectedStatus, received.getStatus());
}
Also used: MLTaskResponse (org.opensearch.ml.common.transport.MLTaskResponse), MLTrainingOutput (org.opensearch.ml.common.parameter.MLTrainingOutput), MLInput (org.opensearch.ml.common.parameter.MLInput), ActionListener (org.opensearch.action.ActionListener), MLTrainingTaskRequest (org.opensearch.ml.common.transport.training.MLTrainingTaskRequest), MLOutput (org.opensearch.ml.common.parameter.MLOutput), Test (org.junit.Test)

Example 13 with MLTaskResponse

Use of org.opensearch.ml.common.transport.MLTaskResponse in the ml-commons project by opensearch-project.

Method fromActionResponse_WithNonMLPredictionTaskResponse of class MLPredictionTaskResponseTest.

/**
 * Verifies that {@link MLTaskResponse#fromActionResponse} yields a new MLTaskResponse
 * (not the same instance) when given an ActionResponse that is not itself an
 * MLTaskResponse. The stand-in response's {@code writeTo} delegates to the original,
 * and the rebuilt output must agree on task id, status and prediction-result size.
 */
@Test
public void fromActionResponse_WithNonMLPredictionTaskResponse() {
    // A plain HashMap instead of double-brace initialization: the latter creates an
    // anonymous HashMap subclass that also captures a reference to this test instance.
    HashMap<String, Object> row = new HashMap<>();
    row.put("key1", 2.0D);
    MLPredictionOutput output = MLPredictionOutput.builder()
        .taskId("taskId")
        .status("Success")
        .predictionResult(DataFrameBuilder.load(Collections.singletonList(row)))
        .build();
    MLTaskResponse response = MLTaskResponse.builder().output(output).build();
    // Not an MLTaskResponse, but writes exactly the same wire content as `response`.
    ActionResponse actionResponse = new ActionResponse() {

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            response.writeTo(out);
        }
    };
    MLTaskResponse result = MLTaskResponse.fromActionResponse(actionResponse);
    assertNotSame(response, result);
    MLPredictionOutput mlPredictionOutput = (MLPredictionOutput) response.getOutput();
    MLPredictionOutput resultMlPredictionOutput = (MLPredictionOutput) result.getOutput();
    assertEquals(mlPredictionOutput.getTaskId(), resultMlPredictionOutput.getTaskId());
    assertEquals(mlPredictionOutput.getStatus(), resultMlPredictionOutput.getStatus());
    assertEquals(mlPredictionOutput.getPredictionResult().size(), resultMlPredictionOutput.getPredictionResult().size());
}
Also used: MLTaskResponse (org.opensearch.ml.common.transport.MLTaskResponse), HashMap (java.util.HashMap), MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput), StreamOutput (org.opensearch.common.io.stream.StreamOutput), BytesStreamOutput (org.opensearch.common.io.stream.BytesStreamOutput), ActionResponse (org.opensearch.action.ActionResponse), Test (org.junit.Test)

Example 14 with MLTaskResponse

Use of org.opensearch.ml.common.transport.MLTaskResponse in the ml-commons project by opensearch-project.

Method fromActionResponse_WithMLPredictionTaskResponse of class MLPredictionTaskResponseTest.

/**
 * Verifies that {@link MLTaskResponse#fromActionResponse} returns the very same instance
 * when the argument is already an MLTaskResponse.
 */
@Test
public void fromActionResponse_WithMLPredictionTaskResponse() {
    // A plain HashMap instead of double-brace initialization: the latter creates an
    // anonymous HashMap subclass that also captures a reference to this test instance.
    HashMap<String, Object> row = new HashMap<>();
    row.put("key1", 2.0D);
    MLPredictionOutput output = MLPredictionOutput.builder()
        .taskId("taskId")
        .status("Success")
        .predictionResult(DataFrameBuilder.load(Collections.singletonList(row)))
        .build();
    MLTaskResponse response = MLTaskResponse.builder().output(output).build();
    assertSame(response, MLTaskResponse.fromActionResponse(response));
}
Also used: MLTaskResponse (org.opensearch.ml.common.transport.MLTaskResponse), HashMap (java.util.HashMap), MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput), Test (org.junit.Test)

Example 15 with MLTaskResponse

Use of org.opensearch.ml.common.transport.MLTaskResponse in the ml-commons project by opensearch-project.

Method writeTo_Success of class MLPredictionTaskResponseTest.

/**
 * Verifies that an MLTaskResponse written with {@code writeTo} can be read back through
 * the {@code MLTaskResponse(StreamInput)} constructor with its task id, status and
 * prediction-result size intact.
 *
 * @throws IOException if serialization or deserialization fails
 */
@Test
public void writeTo_Success() throws IOException {
    // A plain HashMap instead of double-brace initialization: the latter creates an
    // anonymous HashMap subclass that also captures a reference to this test instance.
    HashMap<String, Object> row = new HashMap<>();
    row.put("key1", 2.0D);
    MLPredictionOutput output = MLPredictionOutput.builder()
        .taskId("taskId")
        .status("Success")
        .predictionResult(DataFrameBuilder.load(Collections.singletonList(row)))
        .build();
    MLTaskResponse response = MLTaskResponse.builder().output(output).build();
    BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
    response.writeTo(bytesStreamOutput);
    // Keep the deserialized copy in its own variable so it is clear the assertions
    // check the round-tripped object rather than the original.
    MLTaskResponse parsedResponse = new MLTaskResponse(bytesStreamOutput.bytes().streamInput());
    MLPredictionOutput mlPredictionOutput = (MLPredictionOutput) parsedResponse.getOutput();
    assertEquals("taskId", mlPredictionOutput.getTaskId());
    assertEquals("Success", mlPredictionOutput.getStatus());
    assertEquals(1, mlPredictionOutput.getPredictionResult().size());
}
Also used: MLTaskResponse (org.opensearch.ml.common.transport.MLTaskResponse), HashMap (java.util.HashMap), MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput), BytesStreamOutput (org.opensearch.common.io.stream.BytesStreamOutput), Test (org.junit.Test)

Aggregations

MLTaskResponse (org.opensearch.ml.common.transport.MLTaskResponse): 22; MLInput (org.opensearch.ml.common.parameter.MLInput): 13; Test (org.junit.Test): 9; MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput): 8; MLTrainingOutput (org.opensearch.ml.common.parameter.MLTrainingOutput): 7; MLPredictionTaskRequest (org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest): 6; DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset): 5; MLTrainingTaskRequest (org.opensearch.ml.common.transport.training.MLTrainingTaskRequest): 5; HashMap (java.util.HashMap): 4; ActionListener (org.opensearch.action.ActionListener): 4; BytesStreamOutput (org.opensearch.common.io.stream.BytesStreamOutput): 4; XContentBuilder (org.opensearch.common.xcontent.XContentBuilder): 4; MLOutput (org.opensearch.ml.common.parameter.MLOutput): 4; ThreadContext (org.opensearch.common.util.concurrent.ThreadContext): 3; DataFrame (org.opensearch.ml.common.dataframe.DataFrame): 3; MLModel (org.opensearch.ml.common.parameter.MLModel): 3; Model (org.opensearch.ml.common.parameter.Model): 3; Instant (java.time.Instant): 2; UUID (java.util.UUID): 2; Log4j2 (lombok.extern.log4j.Log4j2): 2