Use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.
Class MachineLearningNodeClientTest, method predict.
/**
 * Checks that {@code machineLearningNodeClient.predict} issues an
 * {@code MLPredictionTaskAction} request through the mocked client and hands
 * the output carried by the response to the supplied listener.
 */
@SuppressWarnings("unchecked")
@Test
public void predict() {
    // Request under test, plus a captor for whatever the listener ends up receiving.
    MLInput kmeansInput = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .inputDataset(input)
        .build();
    ArgumentCaptor<MLOutput> responseCaptor = ArgumentCaptor.forClass(MLOutput.class);

    // Stub the client: the third argument of execute(...) is the listener,
    // which is answered immediately with a canned prediction output.
    doAnswer(invocation -> {
        ActionListener<MLTaskResponse> listener = invocation.getArgument(2);
        MLPredictionOutput stubbedOutput = MLPredictionOutput.builder()
            .status("Success")
            .predictionResult(output)
            .taskId("taskId")
            .build();
        MLTaskResponse stubbedResponse = MLTaskResponse.builder().output(stubbedOutput).build();
        listener.onResponse(stubbedResponse);
        return null;
    }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());

    machineLearningNodeClient.predict(null, kmeansInput, dataFrameActionListener);

    // The request must have gone through the client as a prediction task request ...
    verify(client).execute(
        eq(MLPredictionTaskAction.INSTANCE),
        isA(MLPredictionTaskRequest.class),
        any(ActionListener.class)
    );
    // ... and the listener must have received the stubbed prediction result.
    verify(dataFrameActionListener).onResponse(responseCaptor.capture());
    MLPredictionOutput delivered = (MLPredictionOutput) responseCaptor.getValue();
    assertEquals(output, delivered.getPredictionResult());
}
Use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.
Class MachineLearningNodeClientTest, method predict_Exception_WithNullAlgorithm.
/**
 * Expects an {@link IllegalArgumentException} whose message contains
 * "algorithm can't be null" when an {@code MLInput} is created without an
 * algorithm and passed to {@code predict}.
 */
@Test
public void predict_Exception_WithNullAlgorithm() {
    // Expectations must be registered before the code that throws.
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("algorithm can't be null");

    // No algorithm(...) call on the builder: this is the condition under test.
    MLInput inputWithoutAlgorithm = MLInput.builder()
        .inputDataset(input)
        .build();

    machineLearningNodeClient.predict(null, inputWithoutAlgorithm, dataFrameActionListener);
}
Use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.
Class MachineLearningNodeClientTest, method train.
/**
 * Checks that {@code machineLearningNodeClient.train} issues an
 * {@code MLTrainingTaskAction} request through the mocked client and hands
 * the training output (model id and status) to the supplied listener.
 */
@SuppressWarnings("unchecked")
@Test
public void train() {
    String modelId = "test_model_id";
    String status = "InProgress";

    // Request under test, plus a captor for whatever the listener ends up receiving.
    MLInput kmeansInput = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .inputDataset(input)
        .build();
    ArgumentCaptor<MLOutput> responseCaptor = ArgumentCaptor.forClass(MLOutput.class);

    // Stub the client: the third argument of execute(...) is the listener,
    // which is answered immediately with a canned training output.
    doAnswer(invocation -> {
        ActionListener<MLTaskResponse> listener = invocation.getArgument(2);
        MLTrainingOutput stubbedOutput = MLTrainingOutput.builder()
            .status(status)
            .modelId(modelId)
            .build();
        MLTaskResponse stubbedResponse = MLTaskResponse.builder().output(stubbedOutput).build();
        listener.onResponse(stubbedResponse);
        return null;
    }).when(client).execute(eq(MLTrainingTaskAction.INSTANCE), any(), any());

    machineLearningNodeClient.train(kmeansInput, false, trainingActionListener);

    // The request must have gone through the client as a training task request ...
    verify(client).execute(
        eq(MLTrainingTaskAction.INSTANCE),
        isA(MLTrainingTaskRequest.class),
        any(ActionListener.class)
    );
    // ... and the listener must have received the stubbed model id and status.
    verify(trainingActionListener).onResponse(responseCaptor.capture());
    MLTrainingOutput delivered = (MLTrainingOutput) responseCaptor.getValue();
    assertEquals(modelId, delivered.getModelId());
    assertEquals(status, delivered.getStatus());
}
Use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.
Class MachineLearningClientTest, method predict_WithAlgoAndInputDataAndListener.
/**
 * Calls the listener-based {@code predict} overload with a null model id and a
 * KMEANS input wrapping a data frame, then asserts that the listener receives
 * the {@code output} fixture.
 */
@Test
public void predict_WithAlgoAndInputDataAndListener() {
    DataFrameInputDataset dataset = new DataFrameInputDataset(input);
    MLInput kmeansInput = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .inputDataset(dataset)
        .build();
    ArgumentCaptor<MLOutput> responseCaptor = ArgumentCaptor.forClass(MLOutput.class);

    machineLearningClient.predict(null, kmeansInput, dataFrameActionListener);

    // Whatever reached the listener must equal the expected output fixture.
    verify(dataFrameActionListener).onResponse(responseCaptor.capture());
    assertEquals(output, responseCaptor.getValue());
}
Use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.
Class MachineLearningClientTest, method predict_WithAlgoAndParametersAndInputDataAndModelId.
/**
 * Calls the future-returning {@code predict} overload with an explicit model
 * id and a KMEANS input carrying parameters and a data frame, then asserts
 * that the resolved result equals the {@code output} fixture.
 */
@Test
public void predict_WithAlgoAndParametersAndInputDataAndModelId() {
    DataFrameInputDataset dataset = new DataFrameInputDataset(input);
    MLInput kmeansInput = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .parameters(mlParameters)
        .inputDataset(dataset)
        .build();

    // actionGet() resolves the returned future to its result.
    assertEquals(
        output,
        machineLearningClient.predict("modelId", kmeansInput).actionGet()
    );
}
Aggregations