Search in sources:

Example 1 with MLModel

Use of org.opensearch.ml.common.parameter.MLModel in project ml-commons by opensearch-project.

From class MLTrainingTaskRunner, method train.

/**
 * Trains a model from {@code mlInput}, stores it in the ML model index and reports the result.
 *
 * <p>The task is first marked {@link MLTaskState#RUNNING}. On success, the total model counter and
 * the per-function model counter are incremented and {@code actionListener} receives an
 * {@link MLTrainingOutput} holding the indexed model id, the task id (only for async tasks,
 * otherwise {@code null}) and the {@code COMPLETED} state. On any failure, the per-function train
 * failure counter and the total failure counter are incremented before the exception is forwarded.
 *
 * @param mlTask         task being executed; supplies the task id, function name and async flag
 * @param mlInput        training input handed to {@link MLEngine#train}
 * @param actionListener receives the training response or the failure
 */
private void train(MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
    // Wrap the caller's listener so that every failure path below records the failure stats.
    ActionListener<MLTaskResponse> listener = ActionListener.wrap(actionListener::onResponse, e -> {
        mlStats.createCounterStatIfAbsent(failureCountStat(mlTask.getFunctionName(), ActionName.TRAIN)).increment();
        mlStats.getStat(ML_TOTAL_FAILURE_COUNT).increment();
        actionListener.onFailure(e);
    });
    try {
        // run training
        mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
        Model model = MLEngine.train(mlInput);
        mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(indexCreated -> {
            if (!indexCreated) {
                // This callback belongs to the model index initialization, not the task index.
                listener.onFailure(new RuntimeException("No response to create ML model index"));
                return;
            }
            // TODO: put the user into model for backend role based access control.
            MLModel mlModel = new MLModel(mlInput.getAlgorithm(), model);
            // The index write runs under a stashed thread context; the original context is
            // restored (runBefore) before the response listener is invoked.
            try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
                ActionListener<IndexResponse> indexResponseListener = ActionListener.wrap(r -> {
                    log.info("Model data indexing done, result:{}, model id: {}", r.getResult(), r.getId());
                    mlStats.getStat(ML_TOTAL_MODEL_COUNT).increment();
                    mlStats.createCounterStatIfAbsent(modelCountStat(mlTask.getFunctionName())).increment();
                    // Only async callers need the task id to poll the task later.
                    String returnedTaskId = mlTask.isAsync() ? mlTask.getTaskId() : null;
                    MLTrainingOutput output = new MLTrainingOutput(r.getId(), returnedTaskId, MLTaskState.COMPLETED.name());
                    listener.onResponse(MLTaskResponse.builder().output(output).build());
                }, listener::onFailure);
                IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX);
                indexRequest.source(mlModel.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS));
                // Make the model visible to searches as soon as indexing returns.
                indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
                client.index(indexRequest, ActionListener.runBefore(indexResponseListener, context::restore));
            } catch (Exception e) {
                log.error("Failed to save ML model", e);
                listener.onFailure(e);
            }
        }, e -> {
            log.error("Failed to init ML model index", e);
            listener.onFailure(e);
        }));
    } catch (Exception e) {
        // todo need to specify what exception
        log.error("Failed to train {}", mlInput.getAlgorithm(), e);
        listener.onFailure(e);
    }
}
Also used : IndexResponse(org.opensearch.action.index.IndexResponse) ToXContent(org.opensearch.common.xcontent.ToXContent) ThreadPool(org.opensearch.threadpool.ThreadPool) StatNames.modelCountStat(org.opensearch.ml.stats.StatNames.modelCountStat) MLTaskState(org.opensearch.ml.common.parameter.MLTaskState) MLInput(org.opensearch.ml.common.parameter.MLInput) MLInputDatasetHandler(org.opensearch.ml.indices.MLInputDatasetHandler) ThreadedActionListener(org.opensearch.action.support.ThreadedActionListener) ThreadContext(org.opensearch.common.util.concurrent.ThreadContext) MLTask(org.opensearch.ml.common.parameter.MLTask) ML_EXECUTING_TASK_COUNT(org.opensearch.ml.stats.StatNames.ML_EXECUTING_TASK_COUNT) WriteRequest(org.opensearch.action.support.WriteRequest) ActionListener(org.opensearch.action.ActionListener) MLModel(org.opensearch.ml.common.parameter.MLModel) MLIndicesHandler(org.opensearch.ml.indices.MLIndicesHandler) ML_TOTAL_MODEL_COUNT(org.opensearch.ml.stats.StatNames.ML_TOTAL_MODEL_COUNT) MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLInputDataType(org.opensearch.ml.common.dataset.MLInputDataType) Client(org.opensearch.client.Client) MLStats(org.opensearch.ml.stats.MLStats) ActionName(org.opensearch.ml.stats.ActionName) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) UUID(java.util.UUID) Instant(java.time.Instant) TransportService(org.opensearch.transport.TransportService) StatNames.requestCountStat(org.opensearch.ml.stats.StatNames.requestCountStat) MLTrainingTaskAction(org.opensearch.ml.common.transport.training.MLTrainingTaskAction) XContentBuilder(org.opensearch.common.xcontent.XContentBuilder) MLEngine(org.opensearch.ml.engine.MLEngine) StatNames.failureCountStat(org.opensearch.ml.stats.StatNames.failureCountStat) MLTrainingOutput(org.opensearch.ml.common.parameter.MLTrainingOutput) ML_MODEL_INDEX(org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX) Model(org.opensearch.ml.common.parameter.Model) 
ML_TOTAL_FAILURE_COUNT(org.opensearch.ml.stats.StatNames.ML_TOTAL_FAILURE_COUNT) MLTaskType(org.opensearch.ml.common.parameter.MLTaskType) MLCircuitBreakerService(org.opensearch.ml.common.breaker.MLCircuitBreakerService) Log4j2(lombok.extern.log4j.Log4j2) ActionListenerResponseHandler(org.opensearch.action.ActionListenerResponseHandler) ClusterService(org.opensearch.cluster.service.ClusterService) TASK_THREAD_POOL(org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL) ML_TOTAL_REQUEST_COUNT(org.opensearch.ml.stats.StatNames.ML_TOTAL_REQUEST_COUNT) XContentType(org.opensearch.common.xcontent.XContentType) IndexRequest(org.opensearch.action.index.IndexRequest) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) MLTrainingTaskRequest(org.opensearch.ml.common.transport.training.MLTrainingTaskRequest) MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLTrainingOutput(org.opensearch.ml.common.parameter.MLTrainingOutput) ThreadedActionListener(org.opensearch.action.support.ThreadedActionListener) ActionListener(org.opensearch.action.ActionListener) MLModel(org.opensearch.ml.common.parameter.MLModel) Model(org.opensearch.ml.common.parameter.Model) MLModel(org.opensearch.ml.common.parameter.MLModel) IndexRequest(org.opensearch.action.index.IndexRequest)

Example 2 with MLModel

Use of org.opensearch.ml.common.parameter.MLModel in project ml-commons by opensearch-project.

From class MachineLearningNodeClient, method getModel.

/**
 * Fetches a single model by id through {@code MLModelGetAction}.
 *
 * @param modelId  id of the model to fetch
 * @param listener receives the {@link MLModel} carried by the get response, or the failure
 */
@Override
public void getModel(String modelId, ActionListener<MLModel> listener) {
    MLModelGetRequest getRequest = MLModelGetRequest.builder()
        .modelId(modelId)
        .build();
    client.execute(MLModelGetAction.INSTANCE, getRequest, ActionListener.wrap(
        actionResponse -> {
            // Convert the generic transport response into the typed get-model response.
            MLModelGetResponse getResponse = MLModelGetResponse.fromActionResponse(actionResponse);
            listener.onResponse(getResponse.getMlModel());
        },
        failure -> listener.onFailure(failure)));
}
Also used : MLOutput(org.opensearch.ml.common.parameter.MLOutput) FieldDefaults(lombok.experimental.FieldDefaults) MLModelDeleteRequest(org.opensearch.ml.common.transport.model.MLModelDeleteRequest) RequiredArgsConstructor(lombok.RequiredArgsConstructor) MLInput(org.opensearch.ml.common.parameter.MLInput) MLTrainAndPredictionTaskAction(org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction) Input(org.opensearch.ml.common.parameter.Input) Function(java.util.function.Function) MLPredictionTaskRequest(org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest) Output(org.opensearch.ml.common.parameter.Output) AccessLevel(lombok.AccessLevel) MLPredictionTaskAction(org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction) DeleteResponse(org.opensearch.action.delete.DeleteResponse) SearchRequest(org.opensearch.action.search.SearchRequest) ActionListener(org.opensearch.action.ActionListener) ActionResponse(org.opensearch.action.ActionResponse) SearchResponse(org.opensearch.action.search.SearchResponse) MLModel(org.opensearch.ml.common.parameter.MLModel) MLModelGetResponse(org.opensearch.ml.common.transport.model.MLModelGetResponse) MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLExecuteTaskAction(org.opensearch.ml.common.transport.execute.MLExecuteTaskAction) MLModelDeleteAction(org.opensearch.ml.common.transport.model.MLModelDeleteAction) NodeClient(org.opensearch.client.node.NodeClient) MLModelGetRequest(org.opensearch.ml.common.transport.model.MLModelGetRequest) MLModelSearchAction(org.opensearch.ml.common.transport.model.MLModelSearchAction) MLExecuteTaskRequest(org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest) MLTrainingTaskAction(org.opensearch.ml.common.transport.training.MLTrainingTaskAction) MLModelGetAction(org.opensearch.ml.common.transport.model.MLModelGetAction) MLExecuteTaskResponse(org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse) 
MLTrainingTaskRequest(org.opensearch.ml.common.transport.training.MLTrainingTaskRequest) MLModelGetRequest(org.opensearch.ml.common.transport.model.MLModelGetRequest)

Example 3 with MLModel

Use of org.opensearch.ml.common.parameter.MLModel in project ml-commons by opensearch-project.

From class GetModelTransportAction, method doExecute.

/**
 * Loads the model with the requested id from the ML model index.
 *
 * <p>The get request runs under a stashed thread context, and the caller's context is restored
 * before {@code actionListener} is invoked. A missing document or a missing model index is
 * reported as {@link MLResourceNotFoundException}. Parse failures and other get failures are
 * logged and forwarded unchanged.
 *
 * @param task           the transport task (not used by this method)
 * @param request        request convertible to {@link MLModelGetRequest}
 * @param actionListener receives the parsed {@link MLModel} wrapped in a response, or the failure
 */
@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<MLModelGetResponse> actionListener) {
    MLModelGetRequest mlModelGetRequest = MLModelGetRequest.fromActionRequest(request);
    String modelId = mlModelGetRequest.getModelId();
    GetRequest getRequest = new GetRequest(ML_MODEL_INDEX).id(modelId);
    try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
        // Restore the caller's thread context before any listener callback runs, so the
        // response is not delivered under the stashed context.
        client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
            log.info("Completed Get Model Request, id:{}", modelId);
            if (r != null && r.isExists()) {
                try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
                    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
                    MLModel mlModel = MLModel.parse(parser);
                    actionListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build());
                } catch (Exception e) {
                    log.error("Failed to parse ml model {}", r.getId(), e);
                    actionListener.onFailure(e);
                }
            } else {
                actionListener.onFailure(new MLResourceNotFoundException("Fail to find model"));
            }
        }, e -> {
            // A missing model index means no model can exist yet: report it as not found.
            if (e instanceof IndexNotFoundException) {
                actionListener.onFailure(new MLResourceNotFoundException("Fail to find model"));
            } else {
                log.error("Failed to get ML model {}", modelId, e);
                actionListener.onFailure(e);
            }
        }), context::restore));
    } catch (Exception e) {
        log.error("Failed to get ML model {}", modelId, e);
        actionListener.onFailure(e);
    }
}
Also used : FieldDefaults(lombok.experimental.FieldDefaults) Client(org.opensearch.client.Client) HandledTransportAction(org.opensearch.action.support.HandledTransportAction) MLModelGetRequest(org.opensearch.ml.common.transport.model.MLModelGetRequest) IndexNotFoundException(org.opensearch.index.IndexNotFoundException) GetRequest(org.opensearch.action.get.GetRequest) XContentParserUtils.ensureExpectedToken(org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken) MLNodeUtils.createXContentParserFromRegistry(org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry) ActionRequest(org.opensearch.action.ActionRequest) Task(org.opensearch.tasks.Task) ThreadContext(org.opensearch.common.util.concurrent.ThreadContext) TransportService(org.opensearch.transport.TransportService) MLResourceNotFoundException(org.opensearch.ml.common.exception.MLResourceNotFoundException) XContentParser(org.opensearch.common.xcontent.XContentParser) ActionFilters(org.opensearch.action.support.ActionFilters) AccessLevel(lombok.AccessLevel) MLModelGetAction(org.opensearch.ml.common.transport.model.MLModelGetAction) ML_MODEL_INDEX(org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX) NamedXContentRegistry(org.opensearch.common.xcontent.NamedXContentRegistry) Log4j2(lombok.extern.log4j.Log4j2) Inject(org.opensearch.common.inject.Inject) ActionListener(org.opensearch.action.ActionListener) MLModel(org.opensearch.ml.common.parameter.MLModel) MLModelGetResponse(org.opensearch.ml.common.transport.model.MLModelGetResponse) MLResourceNotFoundException(org.opensearch.ml.common.exception.MLResourceNotFoundException) MLModelGetRequest(org.opensearch.ml.common.transport.model.MLModelGetRequest) GetRequest(org.opensearch.action.get.GetRequest) ThreadContext(org.opensearch.common.util.concurrent.ThreadContext) IndexNotFoundException(org.opensearch.index.IndexNotFoundException) MLModel(org.opensearch.ml.common.parameter.MLModel) 
MLModelGetRequest(org.opensearch.ml.common.transport.model.MLModelGetRequest) XContentParser(org.opensearch.common.xcontent.XContentParser) IndexNotFoundException(org.opensearch.index.IndexNotFoundException) MLResourceNotFoundException(org.opensearch.ml.common.exception.MLResourceNotFoundException)

Aggregations

ActionListener (org.opensearch.action.ActionListener)3 MLModel (org.opensearch.ml.common.parameter.MLModel)3 AccessLevel (lombok.AccessLevel)2 FieldDefaults (lombok.experimental.FieldDefaults)2 Log4j2 (lombok.extern.log4j.Log4j2)2 Client (org.opensearch.client.Client)2 ThreadContext (org.opensearch.common.util.concurrent.ThreadContext)2 MLInput (org.opensearch.ml.common.parameter.MLInput)2 MLTaskResponse (org.opensearch.ml.common.transport.MLTaskResponse)2 MLModelGetAction (org.opensearch.ml.common.transport.model.MLModelGetAction)2 MLModelGetRequest (org.opensearch.ml.common.transport.model.MLModelGetRequest)2 MLModelGetResponse (org.opensearch.ml.common.transport.model.MLModelGetResponse)2 MLTrainingTaskAction (org.opensearch.ml.common.transport.training.MLTrainingTaskAction)2 MLTrainingTaskRequest (org.opensearch.ml.common.transport.training.MLTrainingTaskRequest)2 ML_MODEL_INDEX (org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX)2 TransportService (org.opensearch.transport.TransportService)2 Instant (java.time.Instant)1 UUID (java.util.UUID)1 Function (java.util.function.Function)1 RequiredArgsConstructor (lombok.RequiredArgsConstructor)1