Search in sources :

Example 6 with MLTask

use of org.opensearch.ml.common.parameter.MLTask in project ml-commons by opensearch-project.

From the class MLTrainingTaskRunner, method createMLTaskAndTrain.

/**
 * Builds an {@link MLTask} for a training request and starts the training.
 *
 * <p>Async request: the task is first stored through {@code mlTaskManager}. As soon as it is stored,
 * the caller's listener is answered with the new task id and the task state, and training then runs
 * with an internal listener that hands the outcome to {@code handleAsyncMLTaskComplete} or
 * {@code handleAsyncMLTaskFailure}.
 *
 * <p>Sync request: a random UUID is used as the task id and the caller's listener is passed
 * directly to {@code startTrainingTask}.
 *
 * @param request training request carrying the ML input and the async flag
 * @param listener listener that receives the task response or the failure
 */
public void createMLTaskAndTrain(MLTrainingTaskRequest request, ActionListener<MLTaskResponse> listener) {
    MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType();
    Instant now = Instant.now();
    MLTask mlTask = MLTask
        .builder()
        .taskType(MLTaskType.TRAINING)
        .inputType(inputDataType)
        .functionName(request.getMlInput().getFunctionName())
        .state(MLTaskState.CREATED)
        .workerNode(clusterService.localNode().getId())
        .createTime(now)
        .lastUpdateTime(now)
        .async(request.isAsync())
        .build();
    if (request.isAsync()) {
        mlTaskManager.createMLTask(mlTask, ActionListener.wrap(r -> {
            String taskId = r.getId();
            mlTask.setTaskId(taskId);
            // Answer the caller before training starts; the model id is not known yet, hence null.
            listener.onResponse(new MLTaskResponse(new MLTrainingOutput(null, taskId, mlTask.getState().name())));
            // From here on the outcome goes to the internal listener, not to the caller's listener.
            ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(res -> {
                String modelId = ((MLTrainingOutput) res.getOutput()).getModelId();
                log.info("ML model trained successfully, task id: {}, model id: {}", taskId, modelId);
                mlTask.setModelId(modelId);
                handleAsyncMLTaskComplete(mlTask);
            }, ex -> {
                // Pass the exception to the logger so the stack trace is not lost.
                log.error("Failed to train ML model for task " + taskId, ex);
                handleAsyncMLTaskFailure(mlTask, ex);
            });
            startTrainingTask(mlTask, request.getMlInput(), internalListener);
        }, e -> {
            log.error("Failed to create ML task", e);
            listener.onFailure(e);
        }));
    } else {
        // Sync training: no task is stored here, so a locally generated id is used.
        mlTask.setTaskId(UUID.randomUUID().toString());
        startTrainingTask(mlTask, request.getMlInput(), listener);
    }
}
Also used : IndexResponse(org.opensearch.action.index.IndexResponse) ToXContent(org.opensearch.common.xcontent.ToXContent) ThreadPool(org.opensearch.threadpool.ThreadPool) StatNames.modelCountStat(org.opensearch.ml.stats.StatNames.modelCountStat) MLTaskState(org.opensearch.ml.common.parameter.MLTaskState) MLInput(org.opensearch.ml.common.parameter.MLInput) MLInputDatasetHandler(org.opensearch.ml.indices.MLInputDatasetHandler) ThreadedActionListener(org.opensearch.action.support.ThreadedActionListener) ThreadContext(org.opensearch.common.util.concurrent.ThreadContext) MLTask(org.opensearch.ml.common.parameter.MLTask) ML_EXECUTING_TASK_COUNT(org.opensearch.ml.stats.StatNames.ML_EXECUTING_TASK_COUNT) WriteRequest(org.opensearch.action.support.WriteRequest) ActionListener(org.opensearch.action.ActionListener) MLModel(org.opensearch.ml.common.parameter.MLModel) MLIndicesHandler(org.opensearch.ml.indices.MLIndicesHandler) ML_TOTAL_MODEL_COUNT(org.opensearch.ml.stats.StatNames.ML_TOTAL_MODEL_COUNT) MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLInputDataType(org.opensearch.ml.common.dataset.MLInputDataType) Client(org.opensearch.client.Client) MLStats(org.opensearch.ml.stats.MLStats) ActionName(org.opensearch.ml.stats.ActionName) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) UUID(java.util.UUID) Instant(java.time.Instant) TransportService(org.opensearch.transport.TransportService) StatNames.requestCountStat(org.opensearch.ml.stats.StatNames.requestCountStat) MLTrainingTaskAction(org.opensearch.ml.common.transport.training.MLTrainingTaskAction) XContentBuilder(org.opensearch.common.xcontent.XContentBuilder) MLEngine(org.opensearch.ml.engine.MLEngine) StatNames.failureCountStat(org.opensearch.ml.stats.StatNames.failureCountStat) MLTrainingOutput(org.opensearch.ml.common.parameter.MLTrainingOutput) ML_MODEL_INDEX(org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX) Model(org.opensearch.ml.common.parameter.Model) 
ML_TOTAL_FAILURE_COUNT(org.opensearch.ml.stats.StatNames.ML_TOTAL_FAILURE_COUNT) MLTaskType(org.opensearch.ml.common.parameter.MLTaskType) MLCircuitBreakerService(org.opensearch.ml.common.breaker.MLCircuitBreakerService) Log4j2(lombok.extern.log4j.Log4j2) ActionListenerResponseHandler(org.opensearch.action.ActionListenerResponseHandler) ClusterService(org.opensearch.cluster.service.ClusterService) TASK_THREAD_POOL(org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL) ML_TOTAL_REQUEST_COUNT(org.opensearch.ml.stats.StatNames.ML_TOTAL_REQUEST_COUNT) XContentType(org.opensearch.common.xcontent.XContentType) IndexRequest(org.opensearch.action.index.IndexRequest) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) MLTrainingTaskRequest(org.opensearch.ml.common.transport.training.MLTrainingTaskRequest) MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLTrainingOutput(org.opensearch.ml.common.parameter.MLTrainingOutput) ThreadedActionListener(org.opensearch.action.support.ThreadedActionListener) ActionListener(org.opensearch.action.ActionListener) Instant(java.time.Instant) MLInputDataType(org.opensearch.ml.common.dataset.MLInputDataType) MLTask(org.opensearch.ml.common.parameter.MLTask)

Example 7 with MLTask

use of org.opensearch.ml.common.parameter.MLTask in project ml-commons by opensearch-project.

From the class MLPredictTaskRunner, method startPredictionTask.

/**
 * Starts a prediction task for the given request.
 *
 * <p>A non-async {@link MLTask} with a random UUID as task id is built first. Input of type
 * {@code SEARCH_QUERY} is resolved to a {@link DataFrame} through {@code mlInputDatasetHandler},
 * with the result delivered on {@code TASK_THREAD_POOL}. Any other input is parsed into a
 * {@link DataFrame} on the calling thread and the prediction is submitted to the
 * {@code TASK_THREAD_POOL} executor.
 *
 * @param request MLPredictionTaskRequest
 * @param listener Action listener
 */
public void startPredictionTask(MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener) {
    MLInput mlInput = request.getMlInput();
    MLInputDataType inputDataType = mlInput.getInputDataset().getInputDataType();
    Instant now = Instant.now();
    MLTask mlTask = MLTask
        .builder()
        .taskId(UUID.randomUUID().toString())
        .modelId(request.getModelId())
        .taskType(MLTaskType.PREDICTION)
        .inputType(inputDataType)
        .functionName(mlInput.getFunctionName())
        .state(MLTaskState.CREATED)
        .workerNode(clusterService.localNode().getId())
        .createTime(now)
        .lastUpdateTime(now)
        .async(false)
        .build();

    if (!inputDataType.equals(MLInputDataType.SEARCH_QUERY)) {
        // The data frame is parsed on the calling thread; only the prediction is handed to the task pool.
        DataFrame parsedFrame = mlInputDatasetHandler.parseDataFrameInput(mlInput.getInputDataset());
        threadPool.executor(TASK_THREAD_POOL).execute(() -> predict(mlTask, parsedFrame, request, listener));
        return;
    }

    // Search-query input: the data frame arrives through a listener that runs on the task pool.
    ActionListener<DataFrame> onDataFrame = ActionListener.wrap(
        frame -> predict(mlTask, frame, request, listener),
        failure -> {
            log.error("Failed to generate DataFrame from search query", failure);
            handleAsyncMLTaskFailure(mlTask, failure);
            listener.onFailure(failure);
        }
    );
    mlInputDatasetHandler
        .parseSearchQueryInput(
            mlInput.getInputDataset(),
            new ThreadedActionListener<>(log, threadPool, TASK_THREAD_POOL, onDataFrame, false)
        );
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) Instant(java.time.Instant) MLInputDataType(org.opensearch.ml.common.dataset.MLInputDataType) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) MLTask(org.opensearch.ml.common.parameter.MLTask)

Example 8 with MLTask

use of org.opensearch.ml.common.parameter.MLTask in project ml-commons by opensearch-project.

From the class GetTaskTransportAction, method doExecute.

/**
 * Looks up the document with the requested task id in {@code ML_TASK_INDEX} and answers with it
 * parsed as an {@link MLTask}.
 *
 * <p>The get request runs with the thread context stashed. The caller's listener is wrapped so the
 * stored context is restored before it is notified from the get callback.
 *
 * <p>The listener fails with {@link MLResourceNotFoundException} when the document does not exist
 * or the index is missing, and with the original exception for parse errors and any other failure.
 *
 * @param task the transport task (not used here)
 * @param request request that is converted to an {@link MLTaskGetRequest}
 * @param actionListener listener that receives the task or the failure
 */
@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<MLTaskGetResponse> actionListener) {
    MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.fromActionRequest(request);
    String taskId = mlTaskGetRequest.getTaskId();
    GetRequest getRequest = new GetRequest(ML_TASK_INDEX).id(taskId);
    try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
        // Restore the caller's thread context before the caller's listener runs in the async callback.
        ActionListener<MLTaskGetResponse> restoringListener = ActionListener.runBefore(actionListener, () -> context.restore());
        client.get(getRequest, ActionListener.wrap(r -> {
            log.info("Completed Get Task Request, id:{}", taskId);
            if (r != null && r.isExists()) {
                try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
                    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
                    MLTask mlTask = MLTask.parse(parser);
                    restoringListener.onResponse(MLTaskGetResponse.builder().mlTask(mlTask).build());
                } catch (Exception e) {
                    log.error("Failed to parse ml task " + r.getId(), e);
                    restoringListener.onFailure(e);
                }
            } else {
                restoringListener.onFailure(new MLResourceNotFoundException("Fail to find task"));
            }
        }, e -> {
            if (e instanceof IndexNotFoundException) {
                // A missing task index is reported the same way as a missing task document.
                restoringListener.onFailure(new MLResourceNotFoundException("Fail to find task"));
            } else {
                log.error("Failed to get ML task " + taskId, e);
                restoringListener.onFailure(e);
            }
        }));
    } catch (Exception e) {
        // Synchronous failure: try-with-resources has already closed the stored context here.
        log.error("Failed to get ML task " + taskId, e);
        actionListener.onFailure(e);
    }
}
Also used : Client(org.opensearch.client.Client) HandledTransportAction(org.opensearch.action.support.HandledTransportAction) MLTaskGetAction(org.opensearch.ml.common.transport.task.MLTaskGetAction) MLTaskGetResponse(org.opensearch.ml.common.transport.task.MLTaskGetResponse) IndexNotFoundException(org.opensearch.index.IndexNotFoundException) GetRequest(org.opensearch.action.get.GetRequest) XContentParserUtils.ensureExpectedToken(org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken) MLNodeUtils.createXContentParserFromRegistry(org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry) ActionRequest(org.opensearch.action.ActionRequest) Task(org.opensearch.tasks.Task) ThreadContext(org.opensearch.common.util.concurrent.ThreadContext) MLTask(org.opensearch.ml.common.parameter.MLTask) TransportService(org.opensearch.transport.TransportService) MLResourceNotFoundException(org.opensearch.ml.common.exception.MLResourceNotFoundException) XContentParser(org.opensearch.common.xcontent.XContentParser) ActionFilters(org.opensearch.action.support.ActionFilters) MLTaskGetRequest(org.opensearch.ml.common.transport.task.MLTaskGetRequest) NamedXContentRegistry(org.opensearch.common.xcontent.NamedXContentRegistry) Log4j2(lombok.extern.log4j.Log4j2) Inject(org.opensearch.common.inject.Inject) ML_TASK_INDEX(org.opensearch.ml.indices.MLIndicesHandler.ML_TASK_INDEX) ActionListener(org.opensearch.action.ActionListener) MLResourceNotFoundException(org.opensearch.ml.common.exception.MLResourceNotFoundException) GetRequest(org.opensearch.action.get.GetRequest) MLTaskGetRequest(org.opensearch.ml.common.transport.task.MLTaskGetRequest) ThreadContext(org.opensearch.common.util.concurrent.ThreadContext) IndexNotFoundException(org.opensearch.index.IndexNotFoundException) MLTaskGetRequest(org.opensearch.ml.common.transport.task.MLTaskGetRequest) XContentParser(org.opensearch.common.xcontent.XContentParser) 
IndexNotFoundException(org.opensearch.index.IndexNotFoundException) MLResourceNotFoundException(org.opensearch.ml.common.exception.MLResourceNotFoundException) MLTask(org.opensearch.ml.common.parameter.MLTask)

Example 9 with MLTask

use of org.opensearch.ml.common.parameter.MLTask in project ml-commons by opensearch-project.

From the class MLTaskManager, method createMLTask.

/**
 * Create ML task. Will init ML task index first if absent.
 *
 * <p>Once {@code initMLTaskIndex} reports success, the task is serialized to JSON and indexed into
 * {@code ML_TASK_INDEX} with an immediate refresh. The index request runs with the thread context
 * stashed; the stored context is restored before {@code listener} is notified of the index result.
 *
 * <p>The listener fails if index initialization reports {@code false} or fails, or if building or
 * sending the index request throws.
 *
 * @param mlTask ML task
 * @param listener action listener
 */
public void createMLTask(MLTask mlTask, ActionListener<IndexResponse> listener) {
    mlIndicesHandler.initMLTaskIndex(ActionListener.wrap(indexCreated -> {
        if (!indexCreated) {
            listener.onFailure(new RuntimeException("No response to create ML task index"));
            return;
        }
        IndexRequest request = new IndexRequest(ML_TASK_INDEX);
        try (
            XContentBuilder builder = XContentFactory.jsonBuilder();
            ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()
        ) {
            request.source(mlTask.toXContent(builder, ToXContent.EMPTY_PARAMS)).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
            // Restore the stored thread context before the caller's listener runs in the async callback.
            client.index(request, ActionListener.runBefore(listener, () -> context.restore()));
        } catch (Exception e) {
            // Message previously said "AD task" (copy-paste leftover); this code creates an ML task.
            log.error("Failed to create ML task for " + mlTask.getFunctionName() + ", " + mlTask.getTaskType(), e);
            listener.onFailure(e);
        }
    }, e -> {
        log.error("Failed to create ML index", e);
        listener.onFailure(e);
    }));
}
Also used : Client(org.opensearch.client.Client) Semaphore(java.util.concurrent.Semaphore) IndexResponse(org.opensearch.action.index.IndexResponse) UpdateResponse(org.opensearch.action.update.UpdateResponse) ToXContent(org.opensearch.common.xcontent.ToXContent) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) MLTaskState(org.opensearch.ml.common.parameter.MLTaskState) LAST_UPDATE_TIME_FIELD(org.opensearch.ml.common.parameter.MLTask.LAST_UPDATE_TIME_FIELD) HashMap(java.util.HashMap) Instant(java.time.Instant) ThreadContext(org.opensearch.common.util.concurrent.ThreadContext) MLTask(org.opensearch.ml.common.parameter.MLTask) RestStatus(org.opensearch.rest.RestStatus) TimeUnit(java.util.concurrent.TimeUnit) XContentBuilder(org.opensearch.common.xcontent.XContentBuilder) WriteRequest(org.opensearch.action.support.WriteRequest) Map(java.util.Map) Log4j2(lombok.extern.log4j.Log4j2) XContentFactory(org.opensearch.common.xcontent.XContentFactory) UpdateRequest(org.opensearch.action.update.UpdateRequest) ML_TASK_INDEX(org.opensearch.ml.indices.MLIndicesHandler.ML_TASK_INDEX) ActionListener(org.opensearch.action.ActionListener) IndexRequest(org.opensearch.action.index.IndexRequest) MLIndicesHandler(org.opensearch.ml.indices.MLIndicesHandler) IndexRequest(org.opensearch.action.index.IndexRequest) XContentBuilder(org.opensearch.common.xcontent.XContentBuilder)

Example 10 with MLTask

use of org.opensearch.ml.common.parameter.MLTask in project ml-commons by opensearch-project.

From the class MLTaskManager, method updateTaskStateAndError.

/**
 * Sets the state and error on the task returned by {@code get(taskId)}.
 *
 * <p>For async tasks, the non-null values among {@code state} and {@code error} are also passed on
 * to {@code updateMLTask}. The in-memory task is always overwritten with both values, null or not.
 *
 * @param taskId id of the task to update
 * @param state new task state; may be null
 * @param error new error message; may be null
 * @param isAsyncTask whether the change is also forwarded through {@code updateMLTask}
 * @throws IllegalArgumentException if {@code contains(taskId)} is false
 */
public synchronized void updateTaskStateAndError(String taskId, MLTaskState state, String error, boolean isAsyncTask) {
    if (!contains(taskId)) {
        throw new IllegalArgumentException("Task not found");
    }

    MLTask trackedTask = get(taskId);
    trackedTask.setState(state);
    trackedTask.setError(error);

    if (!isAsyncTask) {
        return;
    }

    // Only non-null values are forwarded.
    Map<String, Object> changes = new HashMap<>();
    if (state != null) {
        changes.put(MLTask.STATE_FIELD, state.name());
    }
    if (error != null) {
        changes.put(MLTask.ERROR_FIELD, error);
    }
    updateMLTask(taskId, changes, 0);
}
Also used : ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) MLTask(org.opensearch.ml.common.parameter.MLTask)

Aggregations

MLTask (org.opensearch.ml.common.parameter.MLTask)11 Instant (java.time.Instant)5 Log4j2 (lombok.extern.log4j.Log4j2)4 ActionListener (org.opensearch.action.ActionListener)4 Client (org.opensearch.client.Client)4 ThreadContext (org.opensearch.common.util.concurrent.ThreadContext)4 XContentBuilder (org.opensearch.common.xcontent.XContentBuilder)4 DataFrame (org.opensearch.ml.common.dataframe.DataFrame)4 MLInputDataType (org.opensearch.ml.common.dataset.MLInputDataType)4 MLInput (org.opensearch.ml.common.parameter.MLInput)4 IndexRequest (org.opensearch.action.index.IndexRequest)3 IndexResponse (org.opensearch.action.index.IndexResponse)3 WriteRequest (org.opensearch.action.support.WriteRequest)3 ToXContent (org.opensearch.common.xcontent.ToXContent)3 MLTaskState (org.opensearch.ml.common.parameter.MLTaskState)3 MLIndicesHandler (org.opensearch.ml.indices.MLIndicesHandler)3 HashMap (java.util.HashMap)2 UUID (java.util.UUID)2 ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap)2 ActionListenerResponseHandler (org.opensearch.action.ActionListenerResponseHandler)2