Search in sources:

Example 1 with MLOutput

use of org.opensearch.ml.common.parameter.MLOutput in project ml-commons by opensearch-project.

From class MLTrainAndPredictTaskRunner, method trainAndPredict:

/**
 * Trains a model and predicts on {@code inputDataFrame} in a single step for {@code mlTask},
 * reporting the engine output (or the failure) to {@code listener}.
 *
 * <p>Before running, this increments the executing-task, total-request and per-function
 * TRAIN_PREDICT request counters and adds the task to {@code mlTaskManager}. Both the success
 * path and the failure path notify the listener returned by {@code wrappedCleanupListener}, so
 * whatever that wrapper attaches runs however the task ends.
 *
 * @param mlTask         task being executed; supplies the task id, function name and async flag
 * @param inputDataFrame data used for training and prediction; replaces the request's input dataset
 * @param request        training request carrying the {@link MLInput} (algorithm, parameters)
 * @param listener       receives an {@link MLTaskResponse} wrapping the engine output, or the failure
 */
private void trainAndPredict(MLTask mlTask, DataFrame inputDataFrame, MLTrainingTaskRequest request, ActionListener<MLTaskResponse> listener) {
    ActionListener<MLTaskResponse> internalListener = wrappedCleanupListener(listener, mlTask.getTaskId());
    // track ML task count and add ML task into cache
    mlStats.getStat(ML_EXECUTING_TASK_COUNT).increment();
    mlStats.getStat(ML_TOTAL_REQUEST_COUNT).increment();
    mlStats.createCounterStatIfAbsent(requestCountStat(mlTask.getFunctionName(), ActionName.TRAIN_PREDICT)).increment();
    mlTaskManager.add(mlTask);
    MLInput mlInput = request.getMlInput();
    // run train and predict
    try {
        mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
        // Use the supplied data frame rather than whatever dataset the request carried.
        MLOutput output = MLEngine.trainAndPredict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build());
        handleAsyncMLTaskComplete(mlTask);
        if (output instanceof MLPredictionOutput) {
            ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
        }
        MLTaskResponse response = MLTaskResponse.builder().output(output).build();
        log.info("Train and predict task done for algorithm: {}, task id: {}", mlTask.getFunctionName(), mlTask.getTaskId());
        internalListener.onResponse(response);
    } catch (Exception e) {
        // todo need to specify what exception
        // mlInput may be null if the request carried none; don't let the log line throw a second NPE.
        log.error("Failed to train and predict " + (mlInput == null ? null : mlInput.getAlgorithm()), e);
        // Fail through internalListener (not the raw listener) so the cleanup wrapper also runs on
        // failure, matching the success path above and MLPredictTaskRunner.predict.
        handlePredictFailure(mlTask, internalListener, e, true);
    }
}
Also used : MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLInput(org.opensearch.ml.common.parameter.MLInput) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) MLOutput(org.opensearch.ml.common.parameter.MLOutput) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput)

Example 2 with MLOutput

use of org.opensearch.ml.common.parameter.MLOutput in project ml-commons by opensearch-project.

From class MLPredictTaskRunner, method predict:

/**
 * Runs prediction for {@code mlTask} on {@code inputDataFrame} with the model stored in
 * {@code ML_MODEL_INDEX}, reporting the output (or the failure) to {@code listener}.
 *
 * <p>Before running, this increments the executing-task, total-request and per-function PREDICT
 * request counters and adds the task to {@code mlTaskManager}. The model document is fetched with
 * the thread context stashed. The stored context is restored before the response listener runs,
 * so the permission check sees the caller's user context.
 *
 * <p>Every failure is reported through {@code handlePredictFailure} with the listener returned by
 * {@code wrappedCleanupListener}. User errors (no model id, model not found, permission denied)
 * are not tracked as failures; engine or lookup errors are.
 *
 * @param mlTask         task being executed; supplies the task id, model id, function name and async flag
 * @param inputDataFrame data to predict on; replaces the request's input dataset
 * @param request        prediction request carrying the model id and {@link MLInput}
 * @param listener       receives an {@link MLTaskResponse} wrapping the engine output, or the failure
 */
private void predict(MLTask mlTask, DataFrame inputDataFrame, MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener) {
    ActionListener<MLTaskResponse> internalListener = wrappedCleanupListener(listener, mlTask.getTaskId());
    // track ML task count and add ML task into cache
    mlStats.getStat(ML_EXECUTING_TASK_COUNT).increment();
    mlStats.getStat(ML_TOTAL_REQUEST_COUNT).increment();
    mlStats.createCounterStatIfAbsent(requestCountStat(mlTask.getFunctionName(), ActionName.PREDICT)).increment();
    mlTaskManager.add(mlTask);
    // run predict
    if (request.getModelId() != null) {
        // search model by model id.
        try (ThreadContext.StoredContext context = threadPool.getThreadContext().stashContext()) {
            MLInput mlInput = request.getMlInput();
            ActionListener<GetResponse> getResponseListener = ActionListener.wrap(r -> {
                if (r == null || !r.isExists()) {
                    // Route through handlePredictFailure like the other user errors (invalid model id,
                    // permission denied) so the task is failed consistently, not just the listener notified.
                    ResourceNotFoundException notFound = new ResourceNotFoundException("No model found, please check the modelId.");
                    handlePredictFailure(mlTask, internalListener, notFound, false);
                    return;
                }
                Map<String, Object> source = r.getSourceAsMap();
                // The stored context has been restored by runBefore, so this reads the caller's user.
                User requestUser = getUserContext(client);
                User resourceUser = User.parse((String) source.get(USER));
                if (!checkUserPermissions(requestUser, resourceUser, request.getModelId())) {
                    // The backend roles of request user and resource user doesn't have intersection
                    OpenSearchException e = new OpenSearchException("User: " + requestUser.getName() + " does not have permissions to run predict by model: " + request.getModelId());
                    handlePredictFailure(mlTask, internalListener, e, false);
                    return;
                }
                Model model = new Model();
                model.setName((String) source.get(MLModel.MODEL_NAME));
                model.setVersion((Integer) source.get(MLModel.MODEL_VERSION));
                // Model content is stored as a Base64 string in the model document.
                byte[] decoded = Base64.getDecoder().decode((String) source.get(MLModel.MODEL_CONTENT));
                model.setContent(decoded);
                // run predict
                mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
                MLOutput output = MLEngine.predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model);
                if (output instanceof MLPredictionOutput) {
                    ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
                }
                // Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state
                handleAsyncMLTaskComplete(mlTask);
                MLTaskResponse response = MLTaskResponse.builder().output(output).build();
                internalListener.onResponse(response);
            }, e -> {
                // Reached for lookup failures and for exceptions thrown by the response handler above.
                log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + mlTask.getModelId(), e);
                handlePredictFailure(mlTask, internalListener, e, true);
            });
            GetRequest getRequest = new GetRequest(ML_MODEL_INDEX, mlTask.getModelId());
            client.get(getRequest, ActionListener.runBefore(getResponseListener, () -> context.restore()));
        } catch (Exception e) {
            log.error("Failed to get model " + mlTask.getModelId(), e);
            handlePredictFailure(mlTask, internalListener, e, true);
        }
    } else {
        IllegalArgumentException e = new IllegalArgumentException("ModelId is invalid");
        log.error("ModelId is invalid", e);
        handlePredictFailure(mlTask, internalListener, e, false);
    }
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) User(org.opensearch.commons.authuser.User) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) ThreadContext(org.opensearch.common.util.concurrent.ThreadContext) MLOutput(org.opensearch.ml.common.parameter.MLOutput) GetResponse(org.opensearch.action.get.GetResponse) OpenSearchException(org.opensearch.OpenSearchException) ResourceNotFoundException(org.opensearch.ResourceNotFoundException) MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) GetRequest(org.opensearch.action.get.GetRequest) MLModel(org.opensearch.ml.common.parameter.MLModel) Model(org.opensearch.ml.common.parameter.Model) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput)

Example 3 with MLOutput

use of org.opensearch.ml.common.parameter.MLOutput in project ml-commons by opensearch-project.

From class MachineLearningClientTest, method predict_WithAlgoAndInputDataAndParametersAndListener:

/**
 * Verifies that predicting with an algorithm, parameters and a data-frame input set
 * (passing {@code null} as the first argument) hands exactly the expected output to the listener.
 */
@Test
public void predict_WithAlgoAndInputDataAndParametersAndListener() {
    ArgumentCaptor<MLOutput> outputCaptor = ArgumentCaptor.forClass(MLOutput.class);
    DataFrameInputDataset dataset = new DataFrameInputDataset(input);
    MLInput mlInput = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .parameters(mlParameters)
        .inputDataset(dataset)
        .build();

    machineLearningClient.predict(null, mlInput, dataFrameActionListener);

    verify(dataFrameActionListener).onResponse(outputCaptor.capture());
    MLOutput captured = outputCaptor.getValue();
    assertEquals(output, captured);
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) MLOutput(org.opensearch.ml.common.parameter.MLOutput) Test(org.junit.Test)

Example 4 with MLOutput

use of org.opensearch.ml.common.parameter.MLOutput in project ml-commons by opensearch-project.

From class MachineLearningNodeClientTest, method predict:

/**
 * Verifies that the node client sends predictions through {@code MLPredictionTaskAction} and
 * passes the {@link MLPredictionOutput} from the task response on to the caller's listener.
 */
@SuppressWarnings("unchecked")
@Test
public void predict() {
    // Stub the transport call: reply immediately with a successful prediction wrapping `output`.
    doAnswer(invocation -> {
        MLPredictionOutput stubbedOutput = MLPredictionOutput.builder()
            .status("Success")
            .predictionResult(output)
            .taskId("taskId")
            .build();
        MLTaskResponse stubbedResponse = MLTaskResponse.builder().output(stubbedOutput).build();
        ActionListener<MLTaskResponse> transportListener = invocation.getArgument(2);
        transportListener.onResponse(stubbedResponse);
        return null;
    }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());

    MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build();
    ArgumentCaptor<MLOutput> outputCaptor = ArgumentCaptor.forClass(MLOutput.class);

    machineLearningNodeClient.predict(null, mlInput, dataFrameActionListener);

    verify(client).execute(eq(MLPredictionTaskAction.INSTANCE), isA(MLPredictionTaskRequest.class), any(ActionListener.class));
    verify(dataFrameActionListener).onResponse(outputCaptor.capture());
    MLPredictionOutput received = (MLPredictionOutput) outputCaptor.getValue();
    assertEquals(output, received.getPredictionResult());
}
Also used : MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLInput(org.opensearch.ml.common.parameter.MLInput) ActionListener(org.opensearch.action.ActionListener) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput) MLOutput(org.opensearch.ml.common.parameter.MLOutput) MLPredictionTaskRequest(org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest) Test(org.junit.Test)

Example 5 with MLOutput

use of org.opensearch.ml.common.parameter.MLOutput in project ml-commons by opensearch-project.

From class MachineLearningNodeClientTest, method train:

/**
 * Verifies that the node client sends training requests through {@code MLTrainingTaskAction} and
 * forwards the resulting {@link MLTrainingOutput} (model id and status) to the caller's listener.
 */
@SuppressWarnings("unchecked")
@Test
public void train() {
    final String modelId = "test_model_id";
    final String status = "InProgress";
    // Stub the transport call: reply immediately with a training output carrying the expected id and status.
    doAnswer(invocation -> {
        MLTrainingOutput stubbedOutput = MLTrainingOutput.builder().status(status).modelId(modelId).build();
        ActionListener<MLTaskResponse> transportListener = invocation.getArgument(2);
        transportListener.onResponse(MLTaskResponse.builder().output(stubbedOutput).build());
        return null;
    }).when(client).execute(eq(MLTrainingTaskAction.INSTANCE), any(), any());

    MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build();
    ArgumentCaptor<MLOutput> outputCaptor = ArgumentCaptor.forClass(MLOutput.class);

    machineLearningNodeClient.train(mlInput, false, trainingActionListener);

    verify(client).execute(eq(MLTrainingTaskAction.INSTANCE), isA(MLTrainingTaskRequest.class), any(ActionListener.class));
    verify(trainingActionListener).onResponse(outputCaptor.capture());
    MLTrainingOutput received = (MLTrainingOutput) outputCaptor.getValue();
    assertEquals(modelId, received.getModelId());
    assertEquals(status, received.getStatus());
}
Also used : MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLTrainingOutput(org.opensearch.ml.common.parameter.MLTrainingOutput) MLInput(org.opensearch.ml.common.parameter.MLInput) ActionListener(org.opensearch.action.ActionListener) MLTrainingTaskRequest(org.opensearch.ml.common.transport.training.MLTrainingTaskRequest) MLOutput(org.opensearch.ml.common.parameter.MLOutput) Test(org.junit.Test)

Aggregations

MLInput (org.opensearch.ml.common.parameter.MLInput)6 MLOutput (org.opensearch.ml.common.parameter.MLOutput)6 Test (org.junit.Test)4 DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset)4 MLTaskResponse (org.opensearch.ml.common.transport.MLTaskResponse)4 MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput)3 ActionListener (org.opensearch.action.ActionListener)2 OpenSearchException (org.opensearch.OpenSearchException)1 ResourceNotFoundException (org.opensearch.ResourceNotFoundException)1 GetRequest (org.opensearch.action.get.GetRequest)1 GetResponse (org.opensearch.action.get.GetResponse)1 ThreadContext (org.opensearch.common.util.concurrent.ThreadContext)1 User (org.opensearch.commons.authuser.User)1 MLModel (org.opensearch.ml.common.parameter.MLModel)1 MLTrainingOutput (org.opensearch.ml.common.parameter.MLTrainingOutput)1 Model (org.opensearch.ml.common.parameter.Model)1 MLPredictionTaskRequest (org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest)1 MLTrainingTaskRequest (org.opensearch.ml.common.transport.training.MLTrainingTaskRequest)1