Use of org.opensearch.ml.common.parameter.MLOutput in project ml-commons by opensearch-project.
The class MLTrainAndPredictTaskRunner, method trainAndPredict.
private void trainAndPredict(MLTask mlTask, DataFrame inputDataFrame, MLTrainingTaskRequest request, ActionListener<MLTaskResponse> listener) {
    ActionListener<MLTaskResponse> internalListener = wrappedCleanupListener(listener, mlTask.getTaskId());
    // track ML task count and add ML task into cache
    mlStats.getStat(ML_EXECUTING_TASK_COUNT).increment();
    mlStats.getStat(ML_TOTAL_REQUEST_COUNT).increment();
    mlStats.createCounterStatIfAbsent(requestCountStat(mlTask.getFunctionName(), ActionName.TRAIN_PREDICT)).increment();
    mlTaskManager.add(mlTask);
    MLInput mlInput = request.getMlInput();
    // run train and predict
    try {
        mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
        MLOutput output = MLEngine.trainAndPredict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build());
        handleAsyncMLTaskComplete(mlTask);
        if (output instanceof MLPredictionOutput) {
            ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
        }
        MLTaskResponse response = MLTaskResponse.builder().output(output).build();
        log.info("Train and predict task done for algorithm: {}, task id: {}", mlTask.getFunctionName(), mlTask.getTaskId());
        internalListener.onResponse(response);
    } catch (Exception e) {
        // todo need to specify what exception
        log.error("Failed to train and predict " + mlInput.getAlgorithm(), e);
        handlePredictFailure(mlTask, listener, e, true);
        return;
    }
}
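For context, a minimal caller-side sketch of the same flow outside the task runner. DataFrameBuilder.load(...), the placeholder column "k1", and the helper method name are assumptions for illustration; the MLInput builder, DataFrameInputDataset, and MLEngine.trainAndPredict calls mirror the snippet above and the tests below.
// Sketch only, not taken from ml-commons: build a DataFrame-backed MLInput and run
// train-and-predict synchronously, the same MLEngine call the task runner wraps above.
private MLOutput trainAndPredictKmeansSketch() {
    Map<String, Object> row = new HashMap<>();
    row.put("k1", 2.0D); // placeholder feature column
    DataFrame inputDataFrame = DataFrameBuilder.load(Collections.singletonList(row));
    MLInput mlInput = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .inputDataset(new DataFrameInputDataset(inputDataFrame))
        .build();
    MLOutput output = MLEngine.trainAndPredict(mlInput);
    if (output instanceof MLPredictionOutput) {
        // the runner above marks the task COMPLETED; a direct MLEngine call just returns the predictions
        ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
    }
    return output;
}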
Use of org.opensearch.ml.common.parameter.MLOutput in project ml-commons by opensearch-project.
The class MLPredictTaskRunner, method predict.
private void predict(MLTask mlTask, DataFrame inputDataFrame, MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener) {
    ActionListener<MLTaskResponse> internalListener = wrappedCleanupListener(listener, mlTask.getTaskId());
    // track ML task count and add ML task into cache
    mlStats.getStat(ML_EXECUTING_TASK_COUNT).increment();
    mlStats.getStat(ML_TOTAL_REQUEST_COUNT).increment();
    mlStats.createCounterStatIfAbsent(requestCountStat(mlTask.getFunctionName(), ActionName.PREDICT)).increment();
    mlTaskManager.add(mlTask);
    // run predict
    if (request.getModelId() != null) {
        // search model by model id.
        try (ThreadContext.StoredContext context = threadPool.getThreadContext().stashContext()) {
            MLInput mlInput = request.getMlInput();
            ActionListener<GetResponse> getResponseListener = ActionListener.wrap(r -> {
                if (r == null || !r.isExists()) {
                    internalListener.onFailure(new ResourceNotFoundException("No model found, please check the modelId."));
                    return;
                }
                Map<String, Object> source = r.getSourceAsMap();
                User requestUser = getUserContext(client);
                User resourceUser = User.parse((String) source.get(USER));
                if (!checkUserPermissions(requestUser, resourceUser, request.getModelId())) {
                    // The backend roles of the request user and the resource user have no intersection
                    OpenSearchException e = new OpenSearchException("User: " + requestUser.getName() + " does not have permissions to run predict by model: " + request.getModelId());
                    handlePredictFailure(mlTask, internalListener, e, false);
                    return;
                }
                Model model = new Model();
                model.setName((String) source.get(MLModel.MODEL_NAME));
                model.setVersion((Integer) source.get(MLModel.MODEL_VERSION));
                byte[] decoded = Base64.getDecoder().decode((String) source.get(MLModel.MODEL_CONTENT));
                model.setContent(decoded);
                // run predict
                mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
                MLOutput output = MLEngine.predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model);
                if (output instanceof MLPredictionOutput) {
                    ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
                }
                // Once prediction completes, reduce ML_EXECUTING_TASK_COUNT and update task state
                handleAsyncMLTaskComplete(mlTask);
                MLTaskResponse response = MLTaskResponse.builder().output(output).build();
                internalListener.onResponse(response);
            }, e -> {
                log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + mlTask.getModelId(), e);
                handlePredictFailure(mlTask, internalListener, e, true);
            });
            GetRequest getRequest = new GetRequest(ML_MODEL_INDEX, mlTask.getModelId());
            client.get(getRequest, ActionListener.runBefore(getResponseListener, () -> context.restore()));
        } catch (Exception e) {
            log.error("Failed to get model " + mlTask.getModelId(), e);
            handlePredictFailure(mlTask, internalListener, e, true);
        }
    } else {
        IllegalArgumentException e = new IllegalArgumentException("ModelId is invalid");
        log.error("ModelId is invalid", e);
        handlePredictFailure(mlTask, internalListener, e, false);
    }
}
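On the client side, the model-id branch above is reached through MachineLearningNodeClient.predict(modelId, mlInput, listener), the same call the node client test further down exercises. A hedged sketch; the model id "my_model_id", the helper method name, and the class-level log field are placeholders.
// Sketch only: send a prediction request against a stored model id and inspect the MLOutput
// that MLPredictTaskRunner builds once the model document has been fetched and run.
private void predictByModelIdSketch(MachineLearningNodeClient machineLearningNodeClient, DataFrame inputDataFrame) {
    MLInput mlInput = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .inputDataset(new DataFrameInputDataset(inputDataFrame))
        .build();
    machineLearningNodeClient.predict("my_model_id", mlInput, ActionListener.wrap(mlOutput -> {
        if (mlOutput instanceof MLPredictionOutput) {
            log.info("prediction status: {}", ((MLPredictionOutput) mlOutput).getStatus());
        }
    }, e -> log.error("Failed to predict with model my_model_id", e)));
}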
Use of org.opensearch.ml.common.parameter.MLOutput in project ml-commons by opensearch-project.
The class MachineLearningClientTest, method predict_WithAlgoAndInputDataAndParametersAndListener.
@Test
public void predict_WithAlgoAndInputDataAndParametersAndListener() {
    MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).parameters(mlParameters).inputDataset(new DataFrameInputDataset(input)).build();
    ArgumentCaptor<MLOutput> dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class);
    machineLearningClient.predict(null, mlInput, dataFrameActionListener);
    verify(dataFrameActionListener).onResponse(dataFrameArgumentCaptor.capture());
    assertEquals(output, dataFrameArgumentCaptor.getValue());
}
Use of org.opensearch.ml.common.parameter.MLOutput in project ml-commons by opensearch-project.
The class MachineLearningNodeClientTest, method predict.
@SuppressWarnings("unchecked")
@Test
public void predict() {
    doAnswer(invocation -> {
        ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
        MLPredictionOutput predictionOutput = MLPredictionOutput.builder().status("Success").predictionResult(output).taskId("taskId").build();
        actionListener.onResponse(MLTaskResponse.builder().output(predictionOutput).build());
        return null;
    }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());
    ArgumentCaptor<MLOutput> dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class);
    MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build();
    machineLearningNodeClient.predict(null, mlInput, dataFrameActionListener);
    verify(client).execute(eq(MLPredictionTaskAction.INSTANCE), isA(MLPredictionTaskRequest.class), any(ActionListener.class));
    verify(dataFrameActionListener).onResponse(dataFrameArgumentCaptor.capture());
    assertEquals(output, ((MLPredictionOutput) dataFrameArgumentCaptor.getValue()).getPredictionResult());
}
Use of org.opensearch.ml.common.parameter.MLOutput in project ml-commons by opensearch-project.
The class MachineLearningNodeClientTest, method train.
@SuppressWarnings("unchecked")
@Test
public void train() {
    String modelId = "test_model_id";
    String status = "InProgress";
    doAnswer(invocation -> {
        ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
        MLTrainingOutput output = MLTrainingOutput.builder().status(status).modelId(modelId).build();
        actionListener.onResponse(MLTaskResponse.builder().output(output).build());
        return null;
    }).when(client).execute(eq(MLTrainingTaskAction.INSTANCE), any(), any());
    ArgumentCaptor<MLOutput> argumentCaptor = ArgumentCaptor.forClass(MLOutput.class);
    MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build();
    machineLearningNodeClient.train(mlInput, false, trainingActionListener);
    verify(client).execute(eq(MLTrainingTaskAction.INSTANCE), isA(MLTrainingTaskRequest.class), any(ActionListener.class));
    verify(trainingActionListener).onResponse(argumentCaptor.capture());
    assertEquals(modelId, ((MLTrainingOutput) argumentCaptor.getValue()).getModelId());
    assertEquals(status, ((MLTrainingOutput) argumentCaptor.getValue()).getStatus());
}
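To round this out, a hedged usage sketch of the same train call outside the mock, reading the status and model id from the MLTrainingOutput delivered to the listener. The helper method name, the DataFrame argument, and the class-level log field are placeholders; the train(mlInput, asyncTask, listener) signature matches the test above.
// Sketch only: train k-means and consume MLTrainingOutput in an inline listener.
// The mocked test above verifies the same status and modelId fields.
private void trainKmeansSketch(MachineLearningNodeClient machineLearningNodeClient, DataFrame trainingFrame) {
    MLInput mlInput = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .inputDataset(new DataFrameInputDataset(trainingFrame))
        .build();
    machineLearningNodeClient.train(mlInput, false, ActionListener.wrap(mlOutput -> {
        MLTrainingOutput trainingOutput = (MLTrainingOutput) mlOutput;
        log.info("training status: {}, model id: {}", trainingOutput.getStatus(), trainingOutput.getModelId());
    }, e -> log.error("Failed to train", e)));
}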