Use of org.opensearch.ml.common.transport.training.MLTrainingTaskRequest in project ml-commons by opensearch-project.
Class TrainingITTests, method testTrainingWithEmptyDataset.
/**
 * Trains a model with an empty dataset.
 *
 * <p>Builds a KMEANS training request over {@code TESTING_INDEX_NAME} using a match
 * query on the field {@code "noSuchName"} with an empty value, then expects executing
 * the training action to throw {@link IllegalArgumentException}.
 */
public void testTrainingWithEmptyDataset() {
    SearchSourceBuilder sourceBuilder = generateSearchSourceBuilder();
    sourceBuilder.query(QueryBuilders.matchQuery("noSuchName", ""));

    MLInputDataset dataset = new SearchQueryInputDataset(Collections.singletonList(TESTING_INDEX_NAME), sourceBuilder);
    MLInput input = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .inputDataset(dataset)
        .build();
    MLTrainingTaskRequest request = new MLTrainingTaskRequest(input, false);

    expectThrows(
        IllegalArgumentException.class,
        () -> client().execute(MLTrainingTaskAction.INSTANCE, request).actionGet()
    );
}
Use of org.opensearch.ml.common.transport.training.MLTrainingTaskRequest in project ml-commons by opensearch-project.
Class MachineLearningNodeClient, method trainAndPredict.
/**
 * Validates the input, wraps it in an {@link MLTrainingTaskRequest} and executes
 * {@code MLTrainAndPredictionTaskAction} through the client.
 *
 * @param mlInput  input to validate and send with the request
 * @param listener listener handed to {@code getMlPredictionTaskResponseActionListener}
 *                 to build the listener passed to the client
 */
@Override
public void trainAndPredict(MLInput mlInput, ActionListener<MLOutput> listener) {
    validateMLInput(mlInput, true);
    client.execute(
        MLTrainAndPredictionTaskAction.INSTANCE,
        MLTrainingTaskRequest.builder().mlInput(mlInput).build(),
        getMlPredictionTaskResponseActionListener(listener)
    );
}
Use of org.opensearch.ml.common.transport.training.MLTrainingTaskRequest in project ml-commons by opensearch-project.
Class TransportTrainAndPredictionTaskAction, method doExecute.
/**
 * Converts the incoming action request into an {@link MLTrainingTaskRequest} and
 * hands it to {@code mlTrainAndPredictTaskRunner} together with the transport
 * service and the response listener.
 *
 * @param task     the task for this execution (not used here)
 * @param request  the action request to convert via {@code MLTrainingTaskRequest.fromActionRequest}
 * @param listener listener passed on to the runner
 */
@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<MLTaskResponse> listener) {
    mlTrainAndPredictTaskRunner.run(
        MLTrainingTaskRequest.fromActionRequest(request),
        transportService,
        listener
    );
}
Use of org.opensearch.ml.common.transport.training.MLTrainingTaskRequest in project ml-commons by opensearch-project.
Class MachineLearningNodeClient, method train.
/**
 * Validates the input, builds an {@link MLTrainingTaskRequest} carrying the input
 * and the async flag, and executes {@code MLTrainingTaskAction} through the client.
 *
 * @param mlInput   input to validate and send with the request
 * @param asyncTask value set on the request's {@code async} property
 * @param listener  listener handed to {@code getMlPredictionTaskResponseActionListener}
 *                  to build the listener passed to the client
 */
@Override
public void train(MLInput mlInput, boolean asyncTask, ActionListener<MLOutput> listener) {
    validateMLInput(mlInput, true);
    MLTrainingTaskRequest request = MLTrainingTaskRequest.builder()
        .mlInput(mlInput)
        .async(asyncTask)
        .build();
    client.execute(MLTrainingTaskAction.INSTANCE, request, getMlPredictionTaskResponseActionListener(listener));
}
Use of org.opensearch.ml.common.transport.training.MLTrainingTaskRequest in project ml-commons by opensearch-project.
Class MLTrainingTaskRunner, method createMLTaskAndTrain.
/**
 * Builds an {@link MLTask} for the training request and starts training.
 *
 * <p>The task is created in state {@code CREATED} with the local node as worker node.
 * <ul>
 *   <li>Async request: the task is created through {@code mlTaskManager}; once an id is
 *       returned, the caller's listener is answered right away with an
 *       {@link MLTrainingOutput} holding the task id and current task state, and training
 *       continues with an internal listener that records completion or failure on the task.</li>
 *   <li>Sync request: the task gets a random UUID as id and training is started with the
 *       caller's listener directly.</li>
 * </ul>
 *
 * @param request  training request providing the ML input and the async flag
 * @param listener listener notified with the task response, or with the failure when the
 *                 ML task cannot be created
 */
public void createMLTaskAndTrain(MLTrainingTaskRequest request, ActionListener<MLTaskResponse> listener) {
    MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType();
    // Use one timestamp so createTime and lastUpdateTime start out identical.
    Instant now = Instant.now();
    MLTask mlTask = MLTask.builder()
        .taskType(MLTaskType.TRAINING)
        .inputType(inputDataType)
        .functionName(request.getMlInput().getFunctionName())
        .state(MLTaskState.CREATED)
        .workerNode(clusterService.localNode().getId())
        .createTime(now)
        .lastUpdateTime(now)
        .async(request.isAsync())
        .build();

    if (request.isAsync()) {
        mlTaskManager.createMLTask(mlTask, ActionListener.wrap(r -> {
            String taskId = r.getId();
            mlTask.setTaskId(taskId);
            // Answer the caller before training starts; the model id is not known yet, hence null.
            listener.onResponse(new MLTaskResponse(new MLTrainingOutput(null, taskId, mlTask.getState().name())));
            ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(res -> {
                String modelId = ((MLTrainingOutput) res.getOutput()).getModelId();
                log.info("ML model trained successfully, task id: {}, model id: {}", taskId, modelId);
                mlTask.setModelId(modelId);
                handleAsyncMLTaskComplete(mlTask);
            }, ex -> {
                // Pass the exception to the logger so the cause and stack trace are not lost.
                log.error("Failed to train ML model for task " + taskId, ex);
                handleAsyncMLTaskFailure(mlTask, ex);
            });
            // NOTE(review): if this call throws synchronously, ActionListener.wrap presumably
            // routes the exception to the failure handler below after the caller's listener
            // has already received a response — confirm whether that double notification matters.
            startTrainingTask(mlTask, request.getMlInput(), internalListener);
        }, e -> {
            log.error("Failed to create ML task", e);
            listener.onFailure(e);
        }));
    } else {
        mlTask.setTaskId(UUID.randomUUID().toString());
        startTrainingTask(mlTask, request.getMlInput(), listener);
    }
}
Aggregations