Usage of `org.opensearch.ml.common.transport.MLTaskResponse` in the ml-commons project (opensearch-project).
Class `MachineLearningNodeClientTest`, method `predict`:
/**
 * Checks that {@code predict} sends an {@link MLPredictionTaskRequest} through the transport
 * client and forwards the resulting {@link MLPredictionOutput} to the caller's listener.
 *
 * <p>{@code input}, {@code output}, {@code client}, {@code machineLearningNodeClient} and
 * {@code dataFrameActionListener} are fixtures declared elsewhere in this test class.
 */
@SuppressWarnings("unchecked")
@Test
public void predict() {
    MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build();

    // Stub the transport call: complete the listener (3rd argument) with a successful
    // prediction output that wraps the `output` fixture.
    doAnswer(invocation -> {
        ActionListener<MLTaskResponse> listener = invocation.getArgument(2);
        MLPredictionOutput stubbedOutput = MLPredictionOutput.builder()
                .status("Success")
                .predictionResult(output)
                .taskId("taskId")
                .build();
        listener.onResponse(MLTaskResponse.builder().output(stubbedOutput).build());
        return null;
    }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());

    // First argument is passed as null (presumably the model id — confirm against the client API).
    machineLearningNodeClient.predict(null, mlInput, dataFrameActionListener);

    verify(client).execute(eq(MLPredictionTaskAction.INSTANCE), isA(MLPredictionTaskRequest.class), any(ActionListener.class));

    ArgumentCaptor<MLOutput> outputCaptor = ArgumentCaptor.forClass(MLOutput.class);
    verify(dataFrameActionListener).onResponse(outputCaptor.capture());
    MLPredictionOutput delivered = (MLPredictionOutput) outputCaptor.getValue();
    assertEquals(output, delivered.getPredictionResult());
}
Usage of `org.opensearch.ml.common.transport.MLTaskResponse` in the ml-commons project (opensearch-project).
Class `MachineLearningNodeClientTest`, method `train`:
/**
 * Checks that {@code train} sends an {@link MLTrainingTaskRequest} through the transport
 * client and hands the resulting {@link MLTrainingOutput} (model id and status) to the
 * caller's listener.
 *
 * <p>{@code input}, {@code client}, {@code machineLearningNodeClient} and
 * {@code trainingActionListener} are fixtures declared elsewhere in this test class.
 */
@SuppressWarnings("unchecked")
@Test
public void train() {
    final String modelId = "test_model_id";
    final String status = "InProgress";
    MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build();

    // Stub the transport call: complete the listener (3rd argument) with a training output
    // carrying the expected model id and status.
    doAnswer(invocation -> {
        ActionListener<MLTaskResponse> listener = invocation.getArgument(2);
        MLTrainingOutput trainingOutput = MLTrainingOutput.builder().status(status).modelId(modelId).build();
        listener.onResponse(MLTaskResponse.builder().output(trainingOutput).build());
        return null;
    }).when(client).execute(eq(MLTrainingTaskAction.INSTANCE), any(), any());

    // Second argument is `false` (presumably the async flag — confirm against the client API).
    machineLearningNodeClient.train(mlInput, false, trainingActionListener);

    verify(client).execute(eq(MLTrainingTaskAction.INSTANCE), isA(MLTrainingTaskRequest.class), any(ActionListener.class));

    ArgumentCaptor<MLOutput> outputCaptor = ArgumentCaptor.forClass(MLOutput.class);
    verify(trainingActionListener).onResponse(outputCaptor.capture());
    MLTrainingOutput delivered = (MLTrainingOutput) outputCaptor.getValue();
    assertEquals(modelId, delivered.getModelId());
    assertEquals(status, delivered.getStatus());
}
Usage of `org.opensearch.ml.common.transport.MLTaskResponse` in the ml-commons project (opensearch-project).
Class `MLPredictionTaskResponseTest`, method `fromActionResponse_WithNonMLPredictionTaskResponse`:
/**
 * Checks that {@link MLTaskResponse#fromActionResponse} rebuilds a new {@link MLTaskResponse}
 * from a generic {@link ActionResponse} whose serialized form is that of an MLTaskResponse.
 * The rebuilt copy must be a different instance with matching task id, status and
 * prediction-result size.
 */
@Test
public void fromActionResponse_WithNonMLPredictionTaskResponse() {
    // A plain HashMap local instead of double-brace initialization, which would declare an
    // anonymous HashMap subclass just to call put().
    HashMap<String, Object> row = new HashMap<>();
    row.put("key1", 2.0D);
    MLPredictionOutput output = MLPredictionOutput.builder()
            .taskId("taskId")
            .status("Success")
            .predictionResult(DataFrameBuilder.load(Collections.singletonList(row)))
            .build();
    MLTaskResponse response = MLTaskResponse.builder().output(output).build();

    // Not an MLTaskResponse itself: it only writes `response`'s bytes, so fromActionResponse
    // has to go through serialization to rebuild the object.
    ActionResponse actionResponse = new ActionResponse() {
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            response.writeTo(out);
        }
    };

    MLTaskResponse result = MLTaskResponse.fromActionResponse(actionResponse);
    assertNotSame(response, result);

    MLPredictionOutput mlPredictionOutput = (MLPredictionOutput) response.getOutput();
    MLPredictionOutput resultMlPredictionOutput = (MLPredictionOutput) result.getOutput();
    assertEquals(mlPredictionOutput.getTaskId(), resultMlPredictionOutput.getTaskId());
    assertEquals(mlPredictionOutput.getStatus(), resultMlPredictionOutput.getStatus());
    assertEquals(mlPredictionOutput.getPredictionResult().size(), resultMlPredictionOutput.getPredictionResult().size());
}
Usage of `org.opensearch.ml.common.transport.MLTaskResponse` in the ml-commons project (opensearch-project).
Class `MLPredictionTaskResponseTest`, method `fromActionResponse_WithMLPredictionTaskResponse`:
/**
 * Checks that {@link MLTaskResponse#fromActionResponse} returns its argument unchanged
 * (same instance) when it is already an {@link MLTaskResponse}.
 */
@Test
public void fromActionResponse_WithMLPredictionTaskResponse() {
    // A plain HashMap local instead of double-brace initialization, which would declare an
    // anonymous HashMap subclass just to call put().
    HashMap<String, Object> row = new HashMap<>();
    row.put("key1", 2.0D);
    MLPredictionOutput output = MLPredictionOutput.builder()
            .taskId("taskId")
            .status("Success")
            .predictionResult(DataFrameBuilder.load(Collections.singletonList(row)))
            .build();
    MLTaskResponse response = MLTaskResponse.builder().output(output).build();

    assertSame(response, MLTaskResponse.fromActionResponse(response));
}
Usage of `org.opensearch.ml.common.transport.MLTaskResponse` in the ml-commons project (opensearch-project).
Class `MLPredictionTaskResponseTest`, method `writeTo_Success`:
/**
 * Round-trips an {@link MLTaskResponse} through {@code writeTo} and the stream constructor.
 * Checks that the deserialized prediction output keeps its task id, status and
 * prediction-result size.
 *
 * @throws IOException if serialization or deserialization fails
 */
@Test
public void writeTo_Success() throws IOException {
    // A plain HashMap local instead of double-brace initialization, which would declare an
    // anonymous HashMap subclass just to call put().
    HashMap<String, Object> row = new HashMap<>();
    row.put("key1", 2.0D);
    MLPredictionOutput output = MLPredictionOutput.builder()
            .taskId("taskId")
            .status("Success")
            .predictionResult(DataFrameBuilder.load(Collections.singletonList(row)))
            .build();
    MLTaskResponse response = MLTaskResponse.builder().output(output).build();

    // The output stream is Closeable, so close it deterministically. All reads happen inside
    // the try block, before it is closed.
    try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()) {
        response.writeTo(bytesStreamOutput);

        // Keep the round-tripped copy in its own variable (rather than overwriting
        // `response`) so the assertions clearly target the deserialized object.
        MLTaskResponse deserialized = new MLTaskResponse(bytesStreamOutput.bytes().streamInput());
        MLPredictionOutput mlPredictionOutput = (MLPredictionOutput) deserialized.getOutput();
        assertEquals("taskId", mlPredictionOutput.getTaskId());
        assertEquals("Success", mlPredictionOutput.getStatus());
        assertEquals(1, mlPredictionOutput.getPredictionResult().size());
    }
}
Aggregations