Search in sources :

Example 16 with MLTaskResponse

Use of org.opensearch.ml.common.transport.MLTaskResponse in the ml-commons project by opensearch-project.

The class MLTrainingTaskResponseTest, method fromActionResponse_Success_WithNonMLTrainingTaskResponse.

@Test
public void fromActionResponse_Success_WithNonMLTrainingTaskResponse() {
    // Original response carrying a successful training output.
    MLTaskResponse original = MLTaskResponse.builder()
        .output(MLTrainingOutput.builder().status("success").modelId("taskId").build())
        .build();

    // An ActionResponse that is not itself an MLTaskResponse; it simply delegates
    // serialization to the original response.
    ActionResponse foreignResponse = new ActionResponse() {

        @Override
        public void writeTo(StreamOutput streamOutput) throws IOException {
            original.writeTo(streamOutput);
        }
    };

    MLTaskResponse converted = MLTaskResponse.fromActionResponse(foreignResponse);

    // Conversion must yield a distinct instance with the same training output fields.
    assertNotSame(original, converted);
    MLTrainingOutput expected = (MLTrainingOutput) original.getOutput();
    MLTrainingOutput actual = (MLTrainingOutput) converted.getOutput();
    assertEquals(expected.getStatus(), actual.getStatus());
    assertEquals(expected.getModelId(), actual.getModelId());
}
Also used : MLTrainingOutput(org.opensearch.ml.common.parameter.MLTrainingOutput) MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) StreamOutput(org.opensearch.common.io.stream.StreamOutput) BytesStreamOutput(org.opensearch.common.io.stream.BytesStreamOutput) ActionResponse(org.opensearch.action.ActionResponse) Test(org.junit.Test)

Example 17 with MLTaskResponse

Use of org.opensearch.ml.common.transport.MLTaskResponse in the ml-commons project by opensearch-project.

The class MLTrainingTaskResponseTest, method writeTo.

@Test
public void writeTo() throws IOException {
    MLTaskResponse source = MLTaskResponse.builder()
        .output(MLTrainingOutput.builder().status("success").modelId("taskId").build())
        .build();

    // Serialize the response, then rebuild a fresh instance from the written bytes.
    BytesStreamOutput buffer = new BytesStreamOutput();
    source.writeTo(buffer);
    MLTaskResponse roundTripped = new MLTaskResponse(buffer.bytes().streamInput());

    // The deserialized training output must carry the original status and model id.
    MLTrainingOutput trainingOutput = (MLTrainingOutput) roundTripped.getOutput();
    assertEquals("success", trainingOutput.getStatus());
    assertEquals("taskId", trainingOutput.getModelId());
}
Also used : MLTrainingOutput(org.opensearch.ml.common.parameter.MLTrainingOutput) MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) BytesStreamOutput(org.opensearch.common.io.stream.BytesStreamOutput) Test(org.junit.Test)

Example 18 with MLTaskResponse

Use of org.opensearch.ml.common.transport.MLTaskResponse in the ml-commons project by opensearch-project.

The class MLTrainingTaskRunner, method startTrainingTask.

/**
 * Tracks the given ML task and dispatches its training to the task thread pool.
 * <p>
 * Inputs of type {@code SEARCH_QUERY} are first resolved into a {@link DataFrame} via
 * {@code mlInputDatasetHandler.parseSearchQueryInput}, and training runs on that data frame.
 * All other inputs are trained as given. Failures are reported through a listener obtained
 * from {@code wrappedCleanupListener}.
 *
 * @param mlTask ML task to track and train
 * @param mlInput ML input (algorithm and dataset)
 * @param listener listener notified with the training response or a failure
 */
public void startTrainingTask(MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> listener) {
    String taskId = mlTask.getTaskId();
    ActionListener<MLTaskResponse> internalListener = wrappedCleanupListener(listener, taskId);

    // Bump executing/request counters and cache the task before any work is dispatched.
    mlStats.getStat(ML_EXECUTING_TASK_COUNT).increment();
    mlStats.getStat(ML_TOTAL_REQUEST_COUNT).increment();
    mlStats.createCounterStatIfAbsent(requestCountStat(mlTask.getFunctionName(), ActionName.TRAIN)).increment();
    mlTaskManager.add(mlTask);

    try {
        MLInputDataType inputType = mlInput.getInputDataset().getInputDataType();
        if (!inputType.equals(MLInputDataType.SEARCH_QUERY)) {
            // Dataset is already usable: train directly on the task thread pool.
            threadPool.executor(TASK_THREAD_POOL).execute(() -> train(mlTask, mlInput, internalListener));
        } else {
            // Resolve the search query into a DataFrame first, then train on it.
            ActionListener<DataFrame> onDataFrame = ActionListener.wrap(
                dataFrame -> {
                    MLInput dataFrameInput = mlInput.toBuilder().inputDataset(new DataFrameInputDataset(dataFrame)).build();
                    train(mlTask, dataFrameInput, internalListener);
                },
                failure -> {
                    log.error("Failed to generate DataFrame from search query", failure);
                    internalListener.onFailure(failure);
                }
            );
            // The ThreadedActionListener hands the parsed result over to the task thread pool.
            mlInputDatasetHandler.parseSearchQueryInput(
                mlInput.getInputDataset(),
                new ThreadedActionListener<>(log, threadPool, TASK_THREAD_POOL, onDataFrame, false)
            );
        }
    } catch (Exception e) {
        log.error("Failed to train " + mlInput.getAlgorithm(), e);
        internalListener.onFailure(e);
    }
}
Also used : MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) DataFrame(org.opensearch.ml.common.dataframe.DataFrame)

Example 19 with MLTaskResponse

Use of org.opensearch.ml.common.transport.MLTaskResponse in the ml-commons project by opensearch-project.

The class MLTrainingTaskRunner, method createMLTaskAndTrain.

/**
 * Creates a training {@link MLTask} for the request and starts training it.
 * <p>
 * Async requests: the task is created through {@code mlTaskManager.createMLTask}. The caller's
 * listener then immediately receives an {@link MLTrainingOutput} holding the new task id and the
 * task's current state. Training continues afterwards, and its outcome is handled by
 * {@code handleAsyncMLTaskComplete} or {@code handleAsyncMLTaskFailure}.
 * <p>
 * Sync requests: the task is given a random UUID as its id, and the caller's listener receives
 * the training result directly.
 *
 * @param request training request carrying the ML input and the async flag
 * @param listener listener notified with the task response or a failure
 */
public void createMLTaskAndTrain(MLTrainingTaskRequest request, ActionListener<MLTaskResponse> listener) {
    MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType();
    Instant now = Instant.now();
    MLTask mlTask = MLTask.builder()
        .taskType(MLTaskType.TRAINING)
        .inputType(inputDataType)
        .functionName(request.getMlInput().getFunctionName())
        .state(MLTaskState.CREATED)
        .workerNode(clusterService.localNode().getId())
        .createTime(now)
        .lastUpdateTime(now)
        .async(request.isAsync())
        .build();
    if (request.isAsync()) {
        mlTaskManager.createMLTask(mlTask, ActionListener.wrap(r -> {
            String taskId = r.getId();
            mlTask.setTaskId(taskId);
            // Reply to the caller right away with the task id; training continues afterwards.
            listener.onResponse(new MLTaskResponse(new MLTrainingOutput(null, taskId, mlTask.getState().name())));
            ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(res -> {
                String modelId = ((MLTrainingOutput) res.getOutput()).getModelId();
                log.info("ML model trained successfully, task id: {}, model id: {}", taskId, modelId);
                mlTask.setModelId(modelId);
                handleAsyncMLTaskComplete(mlTask);
            }, ex -> {
                // Pass the exception so its cause and stack trace are logged, not just the task id.
                log.error("Failed to train ML model for task " + taskId, ex);
                handleAsyncMLTaskFailure(mlTask, ex);
            });
            startTrainingTask(mlTask, request.getMlInput(), internalListener);
        }, e -> {
            log.error("Failed to create ML task", e);
            listener.onFailure(e);
        }));
    } else {
        // No persisted task for sync requests: generate a local id for tracking.
        mlTask.setTaskId(UUID.randomUUID().toString());
        startTrainingTask(mlTask, request.getMlInput(), listener);
    }
}
Also used : IndexResponse(org.opensearch.action.index.IndexResponse) ToXContent(org.opensearch.common.xcontent.ToXContent) ThreadPool(org.opensearch.threadpool.ThreadPool) StatNames.modelCountStat(org.opensearch.ml.stats.StatNames.modelCountStat) MLTaskState(org.opensearch.ml.common.parameter.MLTaskState) MLInput(org.opensearch.ml.common.parameter.MLInput) MLInputDatasetHandler(org.opensearch.ml.indices.MLInputDatasetHandler) ThreadedActionListener(org.opensearch.action.support.ThreadedActionListener) ThreadContext(org.opensearch.common.util.concurrent.ThreadContext) MLTask(org.opensearch.ml.common.parameter.MLTask) ML_EXECUTING_TASK_COUNT(org.opensearch.ml.stats.StatNames.ML_EXECUTING_TASK_COUNT) WriteRequest(org.opensearch.action.support.WriteRequest) ActionListener(org.opensearch.action.ActionListener) MLModel(org.opensearch.ml.common.parameter.MLModel) MLIndicesHandler(org.opensearch.ml.indices.MLIndicesHandler) ML_TOTAL_MODEL_COUNT(org.opensearch.ml.stats.StatNames.ML_TOTAL_MODEL_COUNT) MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLInputDataType(org.opensearch.ml.common.dataset.MLInputDataType) Client(org.opensearch.client.Client) MLStats(org.opensearch.ml.stats.MLStats) ActionName(org.opensearch.ml.stats.ActionName) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) UUID(java.util.UUID) Instant(java.time.Instant) TransportService(org.opensearch.transport.TransportService) StatNames.requestCountStat(org.opensearch.ml.stats.StatNames.requestCountStat) MLTrainingTaskAction(org.opensearch.ml.common.transport.training.MLTrainingTaskAction) XContentBuilder(org.opensearch.common.xcontent.XContentBuilder) MLEngine(org.opensearch.ml.engine.MLEngine) StatNames.failureCountStat(org.opensearch.ml.stats.StatNames.failureCountStat) MLTrainingOutput(org.opensearch.ml.common.parameter.MLTrainingOutput) ML_MODEL_INDEX(org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX) Model(org.opensearch.ml.common.parameter.Model) 
ML_TOTAL_FAILURE_COUNT(org.opensearch.ml.stats.StatNames.ML_TOTAL_FAILURE_COUNT) MLTaskType(org.opensearch.ml.common.parameter.MLTaskType) MLCircuitBreakerService(org.opensearch.ml.common.breaker.MLCircuitBreakerService) Log4j2(lombok.extern.log4j.Log4j2) ActionListenerResponseHandler(org.opensearch.action.ActionListenerResponseHandler) ClusterService(org.opensearch.cluster.service.ClusterService) TASK_THREAD_POOL(org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL) ML_TOTAL_REQUEST_COUNT(org.opensearch.ml.stats.StatNames.ML_TOTAL_REQUEST_COUNT) XContentType(org.opensearch.common.xcontent.XContentType) IndexRequest(org.opensearch.action.index.IndexRequest) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) MLTrainingTaskRequest(org.opensearch.ml.common.transport.training.MLTrainingTaskRequest) MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLTrainingOutput(org.opensearch.ml.common.parameter.MLTrainingOutput) ThreadedActionListener(org.opensearch.action.support.ThreadedActionListener) ActionListener(org.opensearch.action.ActionListener) Instant(java.time.Instant) MLInputDataType(org.opensearch.ml.common.dataset.MLInputDataType) MLTask(org.opensearch.ml.common.parameter.MLTask)

Example 20 with MLTaskResponse

Use of org.opensearch.ml.common.transport.MLTaskResponse in the ml-commons project by opensearch-project.

The class TrainingITTests, method testTrainingWithoutDataset.

// Train a model without a dataset: the request must fail validation.
public void testTrainingWithoutDataset() {
    // KMEANS input with no input dataset attached.
    MLTrainingTaskRequest request = new MLTrainingTaskRequest(MLInput.builder().algorithm(FunctionName.KMEANS).build(), true);
    expectThrows(
        ActionRequestValidationException.class,
        () -> client().execute(MLTrainingTaskAction.INSTANCE, request).actionGet()
    );
}
Also used : MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLInput(org.opensearch.ml.common.parameter.MLInput) MLTrainingTaskRequest(org.opensearch.ml.common.transport.training.MLTrainingTaskRequest)

Aggregations

MLTaskResponse (org.opensearch.ml.common.transport.MLTaskResponse)22 MLInput (org.opensearch.ml.common.parameter.MLInput)13 Test (org.junit.Test)9 MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput)8 MLTrainingOutput (org.opensearch.ml.common.parameter.MLTrainingOutput)7 MLPredictionTaskRequest (org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest)6 DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset)5 MLTrainingTaskRequest (org.opensearch.ml.common.transport.training.MLTrainingTaskRequest)5 HashMap (java.util.HashMap)4 ActionListener (org.opensearch.action.ActionListener)4 BytesStreamOutput (org.opensearch.common.io.stream.BytesStreamOutput)4 XContentBuilder (org.opensearch.common.xcontent.XContentBuilder)4 MLOutput (org.opensearch.ml.common.parameter.MLOutput)4 ThreadContext (org.opensearch.common.util.concurrent.ThreadContext)3 DataFrame (org.opensearch.ml.common.dataframe.DataFrame)3 MLModel (org.opensearch.ml.common.parameter.MLModel)3 Model (org.opensearch.ml.common.parameter.Model)3 Instant (java.time.Instant)2 UUID (java.util.UUID)2 Log4j2 (lombok.extern.log4j.Log4j2)2