Search in sources:

Example 1 with MLInputDataType

use of org.opensearch.ml.common.dataset.MLInputDataType in project ml-commons by opensearch-project.

the class MLTask method parse.

/**
 * Parses an {@link MLTask} out of the object the parser is currently positioned on.
 * <p>
 * The current token must be {@code START_OBJECT}. Fields are consumed until the matching
 * {@code END_OBJECT}; fields that are not recognized are skipped together with their children.
 * Any recognized field missing from the input is left as {@code null} ({@code false} for the
 * async flag) on the resulting task.
 *
 * @param parser parser positioned at the start of the task object
 * @return the task assembled from the parsed fields
 * @throws IOException if the parser fails while reading
 */
public static MLTask parse(XContentParser parser) throws IOException {
    // Identity and classification of the task.
    String parsedTaskId = null;
    String parsedModelId = null;
    MLTaskType parsedTaskType = null;
    FunctionName parsedFunctionName = null;
    MLTaskState parsedState = null;
    MLInputDataType parsedInputType = null;
    // Execution details.
    Float parsedProgress = null;
    String parsedOutputIndex = null;
    String parsedWorkerNode = null;
    Instant parsedCreateTime = null;
    Instant parsedLastUpdateTime = null;
    String parsedError = null;
    User parsedUser = null;
    boolean parsedAsync = false;

    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
    for (XContentParser.Token token = parser.nextToken();
        token != XContentParser.Token.END_OBJECT;
        token = parser.nextToken()) {
        String field = parser.currentName();
        // Step from the field name onto its value before reading it.
        parser.nextToken();
        switch (field) {
            case TASK_ID_FIELD:
                parsedTaskId = parser.text();
                break;
            case MODEL_ID_FIELD:
                parsedModelId = parser.text();
                break;
            case TASK_TYPE_FIELD:
                parsedTaskType = MLTaskType.valueOf(parser.text());
                break;
            case FUNCTION_NAME_FIELD:
                parsedFunctionName = FunctionName.from(parser.text());
                break;
            case STATE_FIELD:
                parsedState = MLTaskState.valueOf(parser.text());
                break;
            case INPUT_TYPE_FIELD:
                parsedInputType = MLInputDataType.valueOf(parser.text());
                break;
            case PROGRESS_FIELD:
                parsedProgress = parser.floatValue();
                break;
            case OUTPUT_INDEX_FIELD:
                parsedOutputIndex = parser.text();
                break;
            case WORKER_NODE_FIELD:
                parsedWorkerNode = parser.text();
                break;
            case CREATE_TIME_FIELD:
                // Timestamps are read as epoch milliseconds.
                parsedCreateTime = Instant.ofEpochMilli(parser.longValue());
                break;
            case LAST_UPDATE_TIME_FIELD:
                parsedLastUpdateTime = Instant.ofEpochMilli(parser.longValue());
                break;
            case ERROR_FIELD:
                parsedError = parser.text();
                break;
            case USER_FIELD:
                parsedUser = User.parse(parser);
                break;
            case IS_ASYNC_TASK_FIELD:
                parsedAsync = parser.booleanValue();
                break;
            default:
                // Unknown field: ignore it, including any nested object or array.
                parser.skipChildren();
                break;
        }
    }

    return MLTask.builder()
        .taskId(parsedTaskId)
        .modelId(parsedModelId)
        .taskType(parsedTaskType)
        .functionName(parsedFunctionName)
        .state(parsedState)
        .inputType(parsedInputType)
        .progress(parsedProgress)
        .outputIndex(parsedOutputIndex)
        .workerNode(parsedWorkerNode)
        .createTime(parsedCreateTime)
        .lastUpdateTime(parsedLastUpdateTime)
        .error(parsedError)
        .user(parsedUser)
        .async(parsedAsync)
        .build();
}
Also used : User(org.opensearch.commons.authuser.User) Instant(java.time.Instant) MLInputDataType(org.opensearch.ml.common.dataset.MLInputDataType)

Example 2 with MLInputDataType

use of org.opensearch.ml.common.dataset.MLInputDataType in project ml-commons by opensearch-project.

the class MLCommonsClassLoader method loadMLInputDataSetClassMapping.

/**
 * Load ML input data set class.
 * <p>
 * Scans the {@code org.opensearch.ml.common.dataset} package for types reported as annotated
 * with {@link InputDataSet} and registers each one in {@code parameterClassMap}, keyed by the
 * {@link MLInputDataType} declared on its annotation.
 */
private static void loadMLInputDataSetClassMapping() {
    Reflections reflections = new Reflections("org.opensearch.ml.common.dataset");
    Set<Class<?>> classes = reflections.getTypesAnnotatedWith(InputDataSet.class);
    for (Class<?> clazz : classes) {
        InputDataSet inputDataSet = clazz.getAnnotation(InputDataSet.class);
        if (inputDataSet == null) {
            // NOTE(review): Reflections appears to also report subtypes of annotated types by
            // default; such a type need not carry the annotation itself, in which case
            // getAnnotation returns null. Skip it rather than fail with a NullPointerException.
            continue;
        }
        MLInputDataType value = inputDataSet.value();
        if (value != null) {
            parameterClassMap.put(value, clazz);
        }
    }
}
Also used : InputDataSet(org.opensearch.ml.common.annotation.InputDataSet) MLInputDataType(org.opensearch.ml.common.dataset.MLInputDataType) Reflections(org.reflections.Reflections)

Example 3 with MLInputDataType

use of org.opensearch.ml.common.dataset.MLInputDataType in project ml-commons by opensearch-project.

the class MLTrainingTaskRunner method createMLTaskAndTrain.

/**
 * Creates an ML task for a training request and starts the training.
 * <p>
 * For an async request the task is first created through {@code mlTaskManager}; once an id is
 * assigned, the caller's listener is answered immediately with that id and the task state, and
 * training then continues with an internal listener that records completion or failure.
 * If creating the task fails, the failure is forwarded to the caller's listener.
 * <p>
 * For a non-async request the task gets a random UUID as its id and training reports directly
 * to the caller's listener.
 *
 * @param request  training request carrying the ML input and the async flag
 * @param listener listener that receives the task response or the failure
 */
public void createMLTaskAndTrain(MLTrainingTaskRequest request, ActionListener<MLTaskResponse> listener) {
    MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType();
    Instant now = Instant.now();
    MLTask mlTask = MLTask.builder()
        .taskType(MLTaskType.TRAINING)
        .inputType(inputDataType)
        .functionName(request.getMlInput().getFunctionName())
        .state(MLTaskState.CREATED)
        .workerNode(clusterService.localNode().getId())
        .createTime(now)
        .lastUpdateTime(now)
        .async(request.isAsync())
        .build();
    if (request.isAsync()) {
        mlTaskManager.createMLTask(mlTask, ActionListener.wrap(r -> {
            String taskId = r.getId();
            mlTask.setTaskId(taskId);
            // Answer the caller before training starts; the model id is not known yet, hence null.
            listener.onResponse(new MLTaskResponse(new MLTrainingOutput(null, taskId, mlTask.getState().name())));
            ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(res -> {
                String modelId = ((MLTrainingOutput) res.getOutput()).getModelId();
                log.info("ML model trained successfully, task id: {}, model id: {}", taskId, modelId);
                mlTask.setModelId(modelId);
                handleAsyncMLTaskComplete(mlTask);
            }, ex -> {
                // Pass the exception to the logger so the cause is not lost.
                log.error("Failed to train ML model for task " + taskId, ex);
                handleAsyncMLTaskFailure(mlTask, ex);
            });
            startTrainingTask(mlTask, request.getMlInput(), internalListener);
        }, e -> {
            log.error("Failed to create ML task", e);
            listener.onFailure(e);
        }));
    } else {
        mlTask.setTaskId(UUID.randomUUID().toString());
        startTrainingTask(mlTask, request.getMlInput(), listener);
    }
}
Also used : IndexResponse(org.opensearch.action.index.IndexResponse) ToXContent(org.opensearch.common.xcontent.ToXContent) ThreadPool(org.opensearch.threadpool.ThreadPool) StatNames.modelCountStat(org.opensearch.ml.stats.StatNames.modelCountStat) MLTaskState(org.opensearch.ml.common.parameter.MLTaskState) MLInput(org.opensearch.ml.common.parameter.MLInput) MLInputDatasetHandler(org.opensearch.ml.indices.MLInputDatasetHandler) ThreadedActionListener(org.opensearch.action.support.ThreadedActionListener) ThreadContext(org.opensearch.common.util.concurrent.ThreadContext) MLTask(org.opensearch.ml.common.parameter.MLTask) ML_EXECUTING_TASK_COUNT(org.opensearch.ml.stats.StatNames.ML_EXECUTING_TASK_COUNT) WriteRequest(org.opensearch.action.support.WriteRequest) ActionListener(org.opensearch.action.ActionListener) MLModel(org.opensearch.ml.common.parameter.MLModel) MLIndicesHandler(org.opensearch.ml.indices.MLIndicesHandler) ML_TOTAL_MODEL_COUNT(org.opensearch.ml.stats.StatNames.ML_TOTAL_MODEL_COUNT) MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLInputDataType(org.opensearch.ml.common.dataset.MLInputDataType) Client(org.opensearch.client.Client) MLStats(org.opensearch.ml.stats.MLStats) ActionName(org.opensearch.ml.stats.ActionName) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) UUID(java.util.UUID) Instant(java.time.Instant) TransportService(org.opensearch.transport.TransportService) StatNames.requestCountStat(org.opensearch.ml.stats.StatNames.requestCountStat) MLTrainingTaskAction(org.opensearch.ml.common.transport.training.MLTrainingTaskAction) XContentBuilder(org.opensearch.common.xcontent.XContentBuilder) MLEngine(org.opensearch.ml.engine.MLEngine) StatNames.failureCountStat(org.opensearch.ml.stats.StatNames.failureCountStat) MLTrainingOutput(org.opensearch.ml.common.parameter.MLTrainingOutput) ML_MODEL_INDEX(org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX) Model(org.opensearch.ml.common.parameter.Model) 
ML_TOTAL_FAILURE_COUNT(org.opensearch.ml.stats.StatNames.ML_TOTAL_FAILURE_COUNT) MLTaskType(org.opensearch.ml.common.parameter.MLTaskType) MLCircuitBreakerService(org.opensearch.ml.common.breaker.MLCircuitBreakerService) Log4j2(lombok.extern.log4j.Log4j2) ActionListenerResponseHandler(org.opensearch.action.ActionListenerResponseHandler) ClusterService(org.opensearch.cluster.service.ClusterService) TASK_THREAD_POOL(org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL) ML_TOTAL_REQUEST_COUNT(org.opensearch.ml.stats.StatNames.ML_TOTAL_REQUEST_COUNT) XContentType(org.opensearch.common.xcontent.XContentType) IndexRequest(org.opensearch.action.index.IndexRequest) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) MLTrainingTaskRequest(org.opensearch.ml.common.transport.training.MLTrainingTaskRequest) MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLTrainingOutput(org.opensearch.ml.common.parameter.MLTrainingOutput) ThreadedActionListener(org.opensearch.action.support.ThreadedActionListener) ActionListener(org.opensearch.action.ActionListener) Instant(java.time.Instant) MLInputDataType(org.opensearch.ml.common.dataset.MLInputDataType) MLTask(org.opensearch.ml.common.parameter.MLTask)

Example 4 with MLInputDataType

use of org.opensearch.ml.common.dataset.MLInputDataType in project ml-commons by opensearch-project.

the class MLPredictTaskRunner method startPredictionTask.

/**
 * Builds a non-async prediction task for the request and runs the prediction.
 * <p>
 * When the input data type is {@code SEARCH_QUERY}, the input dataset is handed to
 * {@code mlInputDatasetHandler.parseSearchQueryInput} and the prediction runs once the resulting
 * data frame is delivered. Otherwise the data frame is parsed directly and the prediction is
 * submitted to the {@code TASK_THREAD_POOL} executor.
 *
 * @param request  prediction request carrying the model id and the ML input
 * @param listener listener that receives the task response or the failure
 */
public void startPredictionTask(MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener) {
    MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType();
    Instant createdAt = Instant.now();
    MLTask mlTask = MLTask.builder()
        .taskId(UUID.randomUUID().toString())
        .modelId(request.getModelId())
        .taskType(MLTaskType.PREDICTION)
        .inputType(inputDataType)
        .functionName(request.getMlInput().getFunctionName())
        .state(MLTaskState.CREATED)
        .workerNode(clusterService.localNode().getId())
        .createTime(createdAt)
        .lastUpdateTime(createdAt)
        .async(false)
        .build();

    MLInput mlInput = request.getMlInput();
    boolean fromSearchQuery = mlInput.getInputDataset().getInputDataType().equals(MLInputDataType.SEARCH_QUERY);
    if (!fromSearchQuery) {
        // Input is not a search query: parse the data frame here, then predict on the task thread pool.
        DataFrame inputDataFrame = mlInputDatasetHandler.parseDataFrameInput(mlInput.getInputDataset());
        threadPool.executor(TASK_THREAD_POOL).execute(() -> predict(mlTask, inputDataFrame, request, listener));
        return;
    }

    // Search-query input: predict once the data frame has been produced from the query.
    ActionListener<DataFrame> onDataFrame = ActionListener.wrap(
        dataFrame -> predict(mlTask, dataFrame, request, listener),
        failure -> {
            log.error("Failed to generate DataFrame from search query", failure);
            handleAsyncMLTaskFailure(mlTask, failure);
            listener.onFailure(failure);
        }
    );
    mlInputDatasetHandler.parseSearchQueryInput(
        mlInput.getInputDataset(),
        new ThreadedActionListener<>(log, threadPool, TASK_THREAD_POOL, onDataFrame, false)
    );
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) Instant(java.time.Instant) MLInputDataType(org.opensearch.ml.common.dataset.MLInputDataType) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) MLTask(org.opensearch.ml.common.parameter.MLTask)

Example 5 with MLInputDataType

use of org.opensearch.ml.common.dataset.MLInputDataType in project ml-commons by opensearch-project.

the class MLTrainAndPredictTaskRunner method startTrainAndPredictionTask.

/**
 * Builds a non-async train-and-predict task for the request and runs it.
 * <p>
 * When the input data type is {@code SEARCH_QUERY}, the input dataset is handed to
 * {@code mlInputDatasetHandler.parseSearchQueryInput} and train-and-predict runs once the
 * resulting data frame is delivered. Otherwise the data frame is parsed directly and the work
 * is submitted to the {@code TASK_THREAD_POOL} executor.
 *
 * @param request  MLTrainingTaskRequest carrying the ML input
 * @param listener listener that receives the task response or the failure
 */
public void startTrainAndPredictionTask(MLTrainingTaskRequest request, ActionListener<MLTaskResponse> listener) {
    MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType();
    Instant createdAt = Instant.now();
    MLTask mlTask = MLTask.builder()
        .taskId(UUID.randomUUID().toString())
        .taskType(MLTaskType.TRAINING_AND_PREDICTION)
        .inputType(inputDataType)
        .functionName(request.getMlInput().getFunctionName())
        .state(MLTaskState.CREATED)
        .workerNode(clusterService.localNode().getId())
        .createTime(createdAt)
        .lastUpdateTime(createdAt)
        .async(false)
        .build();

    MLInput mlInput = request.getMlInput();
    boolean fromSearchQuery = mlInput.getInputDataset().getInputDataType().equals(MLInputDataType.SEARCH_QUERY);
    if (!fromSearchQuery) {
        // Input is not a search query: parse the data frame here, then run on the task thread pool.
        DataFrame inputDataFrame = mlInputDatasetHandler.parseDataFrameInput(mlInput.getInputDataset());
        threadPool.executor(TASK_THREAD_POOL).execute(() -> trainAndPredict(mlTask, inputDataFrame, request, listener));
        return;
    }

    // Search-query input: run once the data frame has been produced from the query.
    ActionListener<DataFrame> onDataFrame = ActionListener.wrap(
        dataFrame -> trainAndPredict(mlTask, dataFrame, request, listener),
        failure -> {
            log.error("Failed to generate DataFrame from search query", failure);
            handlePredictFailure(mlTask, listener, failure, false);
        }
    );
    mlInputDatasetHandler.parseSearchQueryInput(
        mlInput.getInputDataset(),
        new ThreadedActionListener<>(log, threadPool, TASK_THREAD_POOL, onDataFrame, false)
    );
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) Instant(java.time.Instant) MLInputDataType(org.opensearch.ml.common.dataset.MLInputDataType) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) MLTask(org.opensearch.ml.common.parameter.MLTask)

Aggregations

MLInputDataType (org.opensearch.ml.common.dataset.MLInputDataType): 5, Instant (java.time.Instant): 4, DataFrame (org.opensearch.ml.common.dataframe.DataFrame): 3, MLInput (org.opensearch.ml.common.parameter.MLInput): 3, MLTask (org.opensearch.ml.common.parameter.MLTask): 3, UUID (java.util.UUID): 1, Log4j2 (lombok.extern.log4j.Log4j2): 1, ActionListener (org.opensearch.action.ActionListener): 1, ActionListenerResponseHandler (org.opensearch.action.ActionListenerResponseHandler): 1, IndexRequest (org.opensearch.action.index.IndexRequest): 1, IndexResponse (org.opensearch.action.index.IndexResponse): 1, ThreadedActionListener (org.opensearch.action.support.ThreadedActionListener): 1, WriteRequest (org.opensearch.action.support.WriteRequest): 1, Client (org.opensearch.client.Client): 1, ClusterService (org.opensearch.cluster.service.ClusterService): 1, ThreadContext (org.opensearch.common.util.concurrent.ThreadContext): 1, ToXContent (org.opensearch.common.xcontent.ToXContent): 1, XContentBuilder (org.opensearch.common.xcontent.XContentBuilder): 1, XContentType (org.opensearch.common.xcontent.XContentType): 1, User (org.opensearch.commons.authuser.User): 1