Usage of org.opensearch.ml.common.transport.MLTaskResponse in the ml-commons project (opensearch-project).
From class MLTrainingTaskResponseTest, method fromActionResponse_Success_WithNonMLTrainingTaskResponse:
/**
 * Converts an arbitrary (non-MLTaskResponse) ActionResponse whose serialized form is an
 * MLTaskResponse, and checks that the result is a distinct instance carrying the same
 * training status and model id.
 */
@Test
public void fromActionResponse_Success_WithNonMLTrainingTaskResponse() {
    MLTaskResponse original = MLTaskResponse.builder()
        .output(MLTrainingOutput.builder().status("success").modelId("taskId").build())
        .build();
    // Anonymous ActionResponse that is not itself an MLTaskResponse but writes the same bytes;
    // the assertions below expect fromActionResponse to turn it into a new, equivalent instance.
    ActionResponse foreignResponse = new ActionResponse() {
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            original.writeTo(out);
        }
    };

    MLTaskResponse converted = MLTaskResponse.fromActionResponse(foreignResponse);

    assertNotSame(original, converted);
    MLTrainingOutput expected = (MLTrainingOutput) original.getOutput();
    MLTrainingOutput actual = (MLTrainingOutput) converted.getOutput();
    assertEquals(expected.getStatus(), actual.getStatus());
    assertEquals(expected.getModelId(), actual.getModelId());
}
Usage of org.opensearch.ml.common.transport.MLTaskResponse in the ml-commons project (opensearch-project).
From class MLTrainingTaskResponseTest, method writeTo:
/**
 * Serializes an MLTaskResponse into a byte buffer, reads it back through the StreamInput
 * constructor, and verifies the training output survives the round trip.
 */
@Test
public void writeTo() throws IOException {
    MLTaskResponse response = MLTaskResponse.builder()
        .output(MLTrainingOutput.builder().status("success").modelId("taskId").build())
        .build();

    BytesStreamOutput buffer = new BytesStreamOutput();
    response.writeTo(buffer);
    MLTaskResponse roundTripped = new MLTaskResponse(buffer.bytes().streamInput());

    MLTrainingOutput output = (MLTrainingOutput) roundTripped.getOutput();
    assertEquals("success", output.getStatus());
    assertEquals("taskId", output.getModelId());
}
Usage of org.opensearch.ml.common.transport.MLTaskResponse in the ml-commons project (opensearch-project).
From class MLTrainingTaskRunner, method startTrainingTask:
/**
 * Registers a training task and dispatches the actual training to the task thread pool.
 *
 * <p>The executing/total request counters and the per-function TRAIN counter are incremented,
 * and the task is added to the task manager before any work starts. When the input dataset is a
 * search query, the query is first resolved into a DataFrame, with the result delivered on the
 * task thread pool, and training runs on that DataFrame. Otherwise training is submitted
 * directly to the task thread pool.
 *
 * @param mlTask ML task to run; its task id is used to build the cleanup listener
 * @param mlInput ML input holding the algorithm and input dataset
 * @param listener receives the training result or failure, wrapped by wrappedCleanupListener
 */
public void startTrainingTask(MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> listener) {
    String taskId = mlTask.getTaskId();
    ActionListener<MLTaskResponse> cleanupListener = wrappedCleanupListener(listener, taskId);

    // Record the request in the stats counters, then cache the task.
    mlStats.getStat(ML_EXECUTING_TASK_COUNT).increment();
    mlStats.getStat(ML_TOTAL_REQUEST_COUNT).increment();
    mlStats.createCounterStatIfAbsent(requestCountStat(mlTask.getFunctionName(), ActionName.TRAIN)).increment();
    mlTaskManager.add(mlTask);

    try {
        boolean fromSearchQuery = mlInput.getInputDataset().getInputDataType().equals(MLInputDataType.SEARCH_QUERY);
        if (!fromSearchQuery) {
            threadPool.executor(TASK_THREAD_POOL).execute(() -> train(mlTask, mlInput, cleanupListener));
            return;
        }

        // Resolve the search query into a DataFrame first, then train on the materialized data.
        ActionListener<DataFrame> onDataFrame = ActionListener.wrap(
            dataFrame -> {
                MLInput resolvedInput = mlInput.toBuilder().inputDataset(new DataFrameInputDataset(dataFrame)).build();
                train(mlTask, resolvedInput, cleanupListener);
            },
            failure -> {
                log.error("Failed to generate DataFrame from search query", failure);
                cleanupListener.onFailure(failure);
            }
        );
        mlInputDatasetHandler.parseSearchQueryInput(
            mlInput.getInputDataset(),
            new ThreadedActionListener<>(log, threadPool, TASK_THREAD_POOL, onDataFrame, false)
        );
    } catch (Exception e) {
        log.error("Failed to train " + mlInput.getAlgorithm(), e);
        cleanupListener.onFailure(e);
    }
}
Usage of org.opensearch.ml.common.transport.MLTaskResponse in the ml-commons project (opensearch-project).
From class MLTrainingTaskRunner, method createMLTaskAndTrain:
/**
 * Builds a TRAINING {@code MLTask} for the request (state CREATED, this node as worker) and
 * starts training it.
 *
 * <p>Async requests: the task is first created through {@code mlTaskManager.createMLTask}. The
 * caller's listener is answered immediately with the new task id and the task's current state.
 * Training then continues, and its outcome is recorded on the task through
 * {@code handleAsyncMLTaskComplete} or {@code handleAsyncMLTaskFailure} instead of being sent to
 * the caller's listener.
 *
 * <p>Sync requests: the task is given a random UUID as its id, and the caller's listener receives
 * the training result directly from {@code startTrainingTask}.
 *
 * @param request training request carrying the ML input and the async flag
 * @param listener receives the task response, or the failure to create the task
 */
public void createMLTaskAndTrain(MLTrainingTaskRequest request, ActionListener<MLTaskResponse> listener) {
    MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType();
    Instant now = Instant.now();
    MLTask mlTask = MLTask.builder()
        .taskType(MLTaskType.TRAINING)
        .inputType(inputDataType)
        .functionName(request.getMlInput().getFunctionName())
        .state(MLTaskState.CREATED)
        .workerNode(clusterService.localNode().getId())
        .createTime(now)
        .lastUpdateTime(now)
        .async(request.isAsync())
        .build();
    if (request.isAsync()) {
        mlTaskManager.createMLTask(mlTask, ActionListener.wrap(r -> {
            String taskId = r.getId();
            mlTask.setTaskId(taskId);
            // Reply right away with the task id; the training outcome is recorded on the task,
            // not sent to this listener.
            listener.onResponse(new MLTaskResponse(new MLTrainingOutput(null, taskId, mlTask.getState().name())));
            ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(res -> {
                String modelId = ((MLTrainingOutput) res.getOutput()).getModelId();
                log.info("ML model trained successfully, task id: {}, model id: {}", taskId, modelId);
                mlTask.setModelId(modelId);
                handleAsyncMLTaskComplete(mlTask);
            }, ex -> {
                // Pass the exception so the cause and stack trace of the background failure are logged.
                log.error("Failed to train ML model for task " + taskId, ex);
                handleAsyncMLTaskFailure(mlTask, ex);
            });
            startTrainingTask(mlTask, request.getMlInput(), internalListener);
        }, e -> {
            log.error("Failed to create ML task", e);
            listener.onFailure(e);
        }));
    } else {
        mlTask.setTaskId(UUID.randomUUID().toString());
        startTrainingTask(mlTask, request.getMlInput(), listener);
    }
}
Usage of org.opensearch.ml.common.transport.MLTaskResponse in the ml-commons project (opensearch-project).
From class TrainingITTests, method testTrainingWithoutDataset:
/**
 * Submitting a KMEANS training request that has no input dataset must fail with an
 * {@link ActionRequestValidationException}.
 */
public void testTrainingWithoutDataset() {
    MLTrainingTaskRequest requestWithoutData = new MLTrainingTaskRequest(
        MLInput.builder().algorithm(FunctionName.KMEANS).build(),
        true
    );
    expectThrows(
        ActionRequestValidationException.class,
        () -> client().execute(MLTrainingTaskAction.INSTANCE, requestWithoutData).actionGet()
    );
}
Aggregations