use of org.opensearch.ml.common.parameter.MLTask in project ml-commons by opensearch-project.
the class MLTrainingTaskRunner method createMLTaskAndTrain.
/**
 * Builds an ML task for a training request and starts training.
 *
 * <p>When the request is async, the task is first stored through
 * {@code mlTaskManager.createMLTask}. Once the task document exists, the caller's
 * listener is answered immediately with the new task id and the task state, and
 * training is started with an internal listener that records the outcome on the task.
 * When the request is not async, a random UUID is used as the task id and the caller's
 * listener is handed directly to {@code startTrainingTask}.
 *
 * @param request training request carrying the ML input and the async flag
 * @param listener listener that receives the task response, or the failure if the
 *                 task could not be created
 */
public void createMLTaskAndTrain(MLTrainingTaskRequest request, ActionListener<MLTaskResponse> listener) {
    MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType();
    Instant now = Instant.now();
    MLTask mlTask = MLTask
        .builder()
        .taskType(MLTaskType.TRAINING)
        .inputType(inputDataType)
        .functionName(request.getMlInput().getFunctionName())
        .state(MLTaskState.CREATED)
        .workerNode(clusterService.localNode().getId())
        .createTime(now)
        .lastUpdateTime(now)
        .async(request.isAsync())
        .build();
    if (request.isAsync()) {
        mlTaskManager.createMLTask(mlTask, ActionListener.wrap(r -> {
            String taskId = r.getId();
            mlTask.setTaskId(taskId);
            // Answer the caller right away; the model id is not known yet, hence null.
            listener.onResponse(new MLTaskResponse(new MLTrainingOutput(null, taskId, mlTask.getState().name())));
            ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(res -> {
                String modelId = ((MLTrainingOutput) res.getOutput()).getModelId();
                log.info("ML model trained successfully, task id: {}, model id: {}", taskId, modelId);
                mlTask.setModelId(modelId);
                handleAsyncMLTaskComplete(mlTask);
            }, ex -> {
                // Pass the throwable so the cause and stack trace end up in the log.
                log.error("Failed to train ML model for task " + taskId, ex);
                handleAsyncMLTaskFailure(mlTask, ex);
            });
            // NOTE(review): if startTrainingTask throws synchronously here, the outer
            // failure handler below would call listener.onFailure after onResponse
            // has already been sent - confirm whether that can happen.
            startTrainingTask(mlTask, request.getMlInput(), internalListener);
        }, e -> {
            log.error("Failed to create ML task", e);
            listener.onFailure(e);
        }));
    } else {
        mlTask.setTaskId(UUID.randomUUID().toString());
        startTrainingTask(mlTask, request.getMlInput(), listener);
    }
}
use of org.opensearch.ml.common.parameter.MLTask in project ml-commons by opensearch-project.
the class MLPredictTaskRunner method startPredictionTask.
/**
 * Starts a prediction task for the given request.
 *
 * <p>A non-async {@code MLTask} with a random UUID as its id is built first. If the
 * input dataset is a search query, the query is resolved into a {@code DataFrame}
 * through {@code mlInputDatasetHandler.parseSearchQueryInput} and prediction runs from
 * the resulting callback, which is dispatched on {@code TASK_THREAD_POOL}. Otherwise
 * the dataset is parsed into a {@code DataFrame} on the calling thread and prediction
 * is submitted to the {@code TASK_THREAD_POOL} executor.
 *
 * @param request MLPredictionTaskRequest carrying the model id and ML input
 * @param listener action listener passed on to {@code predict}; it is also notified
 *                 when the search-query input cannot be turned into a data frame
 */
public void startPredictionTask(MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener) {
    MLInput mlInput = request.getMlInput();
    MLInputDataType inputDataType = mlInput.getInputDataset().getInputDataType();
    Instant createdAt = Instant.now();
    MLTask mlTask = MLTask
        .builder()
        .taskId(UUID.randomUUID().toString())
        .modelId(request.getModelId())
        .taskType(MLTaskType.PREDICTION)
        .inputType(inputDataType)
        .functionName(mlInput.getFunctionName())
        .state(MLTaskState.CREATED)
        .workerNode(clusterService.localNode().getId())
        .createTime(createdAt)
        .lastUpdateTime(createdAt)
        .async(false)
        .build();

    if (!inputDataType.equals(MLInputDataType.SEARCH_QUERY)) {
        // Input is not a search query: parse it here, then predict on the task thread pool.
        DataFrame inputDataFrame = mlInputDatasetHandler.parseDataFrameInput(mlInput.getInputDataset());
        threadPool.executor(TASK_THREAD_POOL).execute(() -> predict(mlTask, inputDataFrame, request, listener));
        return;
    }

    // Search-query input: resolve the query first and predict from the callback.
    ActionListener<DataFrame> onDataFrame = ActionListener.wrap(
        dataFrame -> predict(mlTask, dataFrame, request, listener),
        e -> {
            log.error("Failed to generate DataFrame from search query", e);
            handleAsyncMLTaskFailure(mlTask, e);
            listener.onFailure(e);
        }
    );
    mlInputDatasetHandler
        .parseSearchQueryInput(
            mlInput.getInputDataset(),
            new ThreadedActionListener<>(log, threadPool, TASK_THREAD_POOL, onDataFrame, false)
        );
}
use of org.opensearch.ml.common.parameter.MLTask in project ml-commons by opensearch-project.
the class GetTaskTransportAction method doExecute.
/**
 * Looks up an ML task by id in {@code ML_TASK_INDEX} and returns it to the caller.
 *
 * <p>The get request is issued under a stashed thread context. The stored context is
 * restored before either response handler runs, so the caller's listener is invoked
 * with the original context rather than the stashed one.
 *
 * <p>Outcomes reported to {@code actionListener}:
 * <ul>
 *   <li>the parsed {@code MLTask} wrapped in an {@code MLTaskGetResponse} when the
 *       document exists;</li>
 *   <li>{@code MLResourceNotFoundException} when the document does not exist or the
 *       get fails with {@code IndexNotFoundException};</li>
 *   <li>the original exception for parse errors and any other failure.</li>
 * </ul>
 *
 * @param task the transport task (not used by this method)
 * @param request action request, converted to an {@code MLTaskGetRequest}
 * @param actionListener listener that receives the task or the failure
 */
@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<MLTaskGetResponse> actionListener) {
    MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.fromActionRequest(request);
    String taskId = mlTaskGetRequest.getTaskId();
    GetRequest getRequest = new GetRequest(ML_TASK_INDEX).id(taskId);
    try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
        // runBefore restores the stashed context before the async handlers below run.
        client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
            log.info("Completed Get Task Request, id:{}", taskId);
            if (r != null && r.isExists()) {
                try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
                    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
                    MLTask mlTask = MLTask.parse(parser);
                    actionListener.onResponse(MLTaskGetResponse.builder().mlTask(mlTask).build());
                } catch (Exception e) {
                    log.error("Failed to parse ml task " + r.getId(), e);
                    actionListener.onFailure(e);
                }
            } else {
                actionListener.onFailure(new MLResourceNotFoundException("Fail to find task"));
            }
        }, e -> {
            if (e instanceof IndexNotFoundException) {
                // A missing task index is reported the same way as a missing task.
                actionListener.onFailure(new MLResourceNotFoundException("Fail to find task"));
            } else {
                log.error("Failed to get ML task " + taskId, e);
                actionListener.onFailure(e);
            }
        }), () -> context.restore()));
    } catch (Exception e) {
        log.error("Failed to get ML task " + taskId, e);
        actionListener.onFailure(e);
    }
}
use of org.opensearch.ml.common.parameter.MLTask in project ml-commons by opensearch-project.
the class MLTaskManager method createMLTask.
/**
 * Create ML task. Will init ML task index first if absent.
 *
 * <p>After {@code mlIndicesHandler.initMLTaskIndex} reports success, the task is
 * serialized to JSON and indexed into {@code ML_TASK_INDEX} with an immediate refresh
 * policy. The index call is made under a stashed thread context, which is restored
 * before {@code listener} is notified.
 *
 * @param mlTask ML task to persist
 * @param listener action listener that receives the index response, or the failure if
 *                 the index could not be initialized or the task could not be indexed
 */
public void createMLTask(MLTask mlTask, ActionListener<IndexResponse> listener) {
    mlIndicesHandler.initMLTaskIndex(ActionListener.wrap(indexCreated -> {
        if (!indexCreated) {
            listener.onFailure(new RuntimeException("No response to create ML task index"));
            return;
        }
        IndexRequest request = new IndexRequest(ML_TASK_INDEX);
        try (
            XContentBuilder builder = XContentFactory.jsonBuilder();
            ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()
        ) {
            request.source(mlTask.toXContent(builder, ToXContent.EMPTY_PARAMS)).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
            // Restore the stashed context before the caller's listener is invoked.
            client.index(request, ActionListener.runBefore(listener, () -> context.restore()));
        } catch (Exception e) {
            // The message previously said "AD task"; this method creates an ML task.
            log.error("Failed to create ML task for " + mlTask.getFunctionName() + ", " + mlTask.getTaskType(), e);
            listener.onFailure(e);
        }
    }, e -> {
        log.error("Failed to create ML index", e);
        listener.onFailure(e);
    }));
}
use of org.opensearch.ml.common.parameter.MLTask in project ml-commons by opensearch-project.
the class MLTaskManager method updateTaskStateAndError.
/**
 * Sets the state and error of the task registered under {@code taskId}.
 *
 * <p>The task object returned by {@code get(taskId)} is always updated with both
 * values, including {@code null}. For async tasks the non-null values are additionally
 * passed to {@code updateMLTask} as a field map keyed by {@code MLTask.STATE_FIELD}
 * and {@code MLTask.ERROR_FIELD}.
 *
 * @param taskId id of the task to update
 * @param state new task state; may be null, in which case it is left out of the field map
 * @param error new error message; may be null, in which case it is left out of the field map
 * @param isAsyncTask whether the change is also sent through {@code updateMLTask}
 * @throws IllegalArgumentException if {@code contains(taskId)} is false
 */
public synchronized void updateTaskStateAndError(String taskId, MLTaskState state, String error, boolean isAsyncTask) {
    if (!contains(taskId)) {
        throw new IllegalArgumentException("Task not found");
    }

    MLTask mlTask = get(taskId);
    mlTask.setState(state);
    mlTask.setError(error);

    if (!isAsyncTask) {
        return;
    }

    Map<String, Object> changedFields = new HashMap<>();
    if (state != null) {
        changedFields.put(MLTask.STATE_FIELD, state.name());
    }
    if (error != null) {
        changedFields.put(MLTask.ERROR_FIELD, error);
    }
    // NOTE(review): the meaning of the trailing 0 is defined by updateMLTask - confirm there.
    updateMLTask(taskId, changedFields, 0);
}
Aggregations