Usage of org.opensearch.ml.common.parameter.MLTrainingOutput in the ml-commons project by opensearch-project.
Class MLTrainingTaskRunner, method createMLTaskAndTrain:
/**
 * Creates an ML training task for {@code request} and starts training it.
 *
 * <p>The task is built with type TRAINING, state CREATED, the local node as worker, the request's
 * input type and function name, and the request's async flag.
 *
 * <p>Async requests: the task is first created through {@code mlTaskManager}. Once that succeeds,
 * the task id is assigned. The caller is then answered immediately with an
 * {@link MLTrainingOutput} carrying no model id, the task id, and the task's current state name.
 * Training then runs, and its outcome is recorded via {@code handleAsyncMLTaskComplete} or
 * {@code handleAsyncMLTaskFailure}. If creating the task fails, {@code listener} receives that
 * failure.
 *
 * <p>Sync requests: the task gets a random UUID as its id (no {@code createMLTask} call is made
 * here) and {@code listener} is passed straight to the training run.
 *
 * @param request  the training request; its ML input supplies the dataset type and function name
 * @param listener receives the task-creation acknowledgement (async) or the training result (sync)
 */
public void createMLTaskAndTrain(MLTrainingTaskRequest request, ActionListener<MLTaskResponse> listener) {
    MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType();
    Instant now = Instant.now();
    MLTask mlTask = MLTask.builder()
        .taskType(MLTaskType.TRAINING)
        .inputType(inputDataType)
        .functionName(request.getMlInput().getFunctionName())
        .state(MLTaskState.CREATED)
        .workerNode(clusterService.localNode().getId())
        .createTime(now)
        .lastUpdateTime(now)
        .async(request.isAsync())
        .build();

    if (request.isAsync()) {
        mlTaskManager.createMLTask(mlTask, ActionListener.wrap(r -> {
            String taskId = r.getId();
            mlTask.setTaskId(taskId);
            // Acknowledge with the task id right away; the model id is not known yet.
            listener.onResponse(new MLTaskResponse(new MLTrainingOutput(null, taskId, mlTask.getState().name())));

            ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(res -> {
                String modelId = ((MLTrainingOutput) res.getOutput()).getModelId();
                log.info("ML model trained successfully, task id: {}, model id: {}", taskId, modelId);
                mlTask.setModelId(modelId);
                handleAsyncMLTaskComplete(mlTask);
            }, ex -> {
                // Pass the exception as the trailing argument so its cause/stack trace is logged.
                log.error("Failed to train ML model for task {}", taskId, ex);
                handleAsyncMLTaskFailure(mlTask, ex);
            });

            try {
                startTrainingTask(mlTask, request.getMlInput(), internalListener);
            } catch (Exception startFailure) {
                // The caller was already answered above. Letting this escape would make
                // ActionListener.wrap forward it to the outer failure handler, i.e. call
                // listener.onFailure after listener.onResponse, and the task would never be
                // marked failed. Report it through the task's own failure path instead.
                internalListener.onFailure(startFailure);
            }
        }, e -> {
            log.error("Failed to create ML task", e);
            listener.onFailure(e);
        }));
    } else {
        mlTask.setTaskId(UUID.randomUUID().toString());
        startTrainingTask(mlTask, request.getMlInput(), listener);
    }
}
Usage of org.opensearch.ml.common.parameter.MLTrainingOutput in the ml-commons project by opensearch-project.
Class IntegTestUtils, method trainModel:
/**
 * Trains a KMEANS model on {@code inputDataset} via {@code MLTrainingTaskAction} and returns the
 * model id reported in the training response.
 *
 * <p>The request is built with its boolean flag set to {@code true} (presumably the async flag;
 * confirm against {@code MLTrainingTaskRequest}). The method asserts that the response is
 * non-null, the model id is non-null and non-empty, and the status is {@code "CREATED"}.
 *
 * <p>NOTE(review): if the server answers async requests with
 * {@code new MLTrainingOutput(null, taskId, ...)}, {@code getModelId()} is null here and the
 * non-null assertion fails. Verify which id this response actually carries.
 *
 * @param inputDataset the training data
 * @return the model id from the training response
 * @throws ExecutionException   declared in the signature; the calls shown do not appear to throw it
 * @throws InterruptedException declared in the signature; the calls shown do not appear to throw it
 */
public static String trainModel(MLInputDataset inputDataset) throws ExecutionException, InterruptedException {
    MLInput kmeansInput = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .inputDataset(inputDataset)
        .build();

    // TODO: also cover synchronous training (request flag = false).
    boolean async = true;
    MLTaskResponse response = client()
        .execute(MLTrainingTaskAction.INSTANCE, new MLTrainingTaskRequest(kmeansInput, async))
        .actionGet();
    assertNotNull(response);

    MLTrainingOutput output = (MLTrainingOutput) response.getOutput();
    String trainedModelId = output.getModelId();
    String reportedStatus = output.getStatus();

    assertNotNull(trainedModelId);
    assertFalse(trainedModelId.isEmpty());
    assertEquals("CREATED", reportedStatus);
    return trainedModelId;
}
Aggregations