Use of org.opensearch.ml.common.dataset.MLInputDataType in the opensearch-project/ml-commons repository.
Example 1: the parse method of the MLTask class.
/**
 * Reconstructs an {@code MLTask} from its XContent (JSON) representation.
 * Unknown fields are skipped so newer documents remain readable by older code.
 *
 * @param parser an XContent parser positioned at the task object's START_OBJECT token
 * @return the deserialized task
 * @throws IOException if the underlying content stream cannot be read
 */
public static MLTask parse(XContentParser parser) throws IOException {
    String parsedTaskId = null;
    String parsedModelId = null;
    MLTaskType parsedTaskType = null;
    FunctionName parsedFunctionName = null;
    MLTaskState parsedState = null;
    MLInputDataType parsedInputType = null;
    Float parsedProgress = null;
    String parsedOutputIndex = null;
    String parsedWorkerNode = null;
    Instant parsedCreateTime = null;
    Instant parsedLastUpdateTime = null;
    String parsedError = null;
    User parsedUser = null;
    boolean parsedAsync = false;

    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
    while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
        String currentField = parser.currentName();
        parser.nextToken(); // advance onto the field's value token
        switch (currentField) {
            case TASK_ID_FIELD:
                parsedTaskId = parser.text();
                break;
            case MODEL_ID_FIELD:
                parsedModelId = parser.text();
                break;
            case TASK_TYPE_FIELD:
                parsedTaskType = MLTaskType.valueOf(parser.text());
                break;
            case FUNCTION_NAME_FIELD:
                parsedFunctionName = FunctionName.from(parser.text());
                break;
            case STATE_FIELD:
                parsedState = MLTaskState.valueOf(parser.text());
                break;
            case INPUT_TYPE_FIELD:
                parsedInputType = MLInputDataType.valueOf(parser.text());
                break;
            case PROGRESS_FIELD:
                parsedProgress = parser.floatValue();
                break;
            case OUTPUT_INDEX_FIELD:
                parsedOutputIndex = parser.text();
                break;
            case WORKER_NODE_FIELD:
                parsedWorkerNode = parser.text();
                break;
            case CREATE_TIME_FIELD:
                // Timestamps are persisted as epoch milliseconds.
                parsedCreateTime = Instant.ofEpochMilli(parser.longValue());
                break;
            case LAST_UPDATE_TIME_FIELD:
                parsedLastUpdateTime = Instant.ofEpochMilli(parser.longValue());
                break;
            case ERROR_FIELD:
                parsedError = parser.text();
                break;
            case USER_FIELD:
                parsedUser = User.parse(parser);
                break;
            case IS_ASYNC_TASK_FIELD:
                parsedAsync = parser.booleanValue();
                break;
            default:
                // Unrecognized field: skip its value (including any nested structure).
                parser.skipChildren();
                break;
        }
    }
    return MLTask.builder()
        .taskId(parsedTaskId)
        .modelId(parsedModelId)
        .taskType(parsedTaskType)
        .functionName(parsedFunctionName)
        .state(parsedState)
        .inputType(parsedInputType)
        .progress(parsedProgress)
        .outputIndex(parsedOutputIndex)
        .workerNode(parsedWorkerNode)
        .createTime(parsedCreateTime)
        .lastUpdateTime(parsedLastUpdateTime)
        .error(parsedError)
        .user(parsedUser)
        .async(parsedAsync)
        .build();
}
Use of org.opensearch.ml.common.dataset.MLInputDataType in the opensearch-project/ml-commons repository.
Example 2: the loadMLInputDataSetClassMapping method of the MLCommonsClassLoader class.
/**
 * Scans the {@code org.opensearch.ml.common.dataset} package for classes annotated
 * with {@link InputDataSet} and registers each one in {@code parameterClassMap},
 * keyed by the annotation's {@link MLInputDataType} value.
 */
private static void loadMLInputDataSetClassMapping() {
    Reflections scanner = new Reflections("org.opensearch.ml.common.dataset");
    for (Class<?> annotatedClass : scanner.getTypesAnnotatedWith(InputDataSet.class)) {
        MLInputDataType dataType = annotatedClass.getAnnotation(InputDataSet.class).value();
        if (dataType == null) {
            // Defensive guard kept from the original implementation; skip unregisterable entries.
            continue;
        }
        parameterClassMap.put(dataType, annotatedClass);
    }
}
Use of org.opensearch.ml.common.dataset.MLInputDataType in the opensearch-project/ml-commons repository.
Example 3: the createMLTaskAndTrain method of the MLTrainingTaskRunner class.
/**
 * Creates an ML training task and starts model training.
 * <p>
 * For an async request, the task is first persisted via {@code mlTaskManager}; the
 * listener is answered immediately with the task id while training continues in the
 * background. For a sync request, a random task id is assigned (nothing is persisted
 * up front) and the listener completes when training does.
 *
 * @param request  the training request carrying the ML input and async flag
 * @param listener notified with the training response, or the failure
 */
public void createMLTaskAndTrain(MLTrainingTaskRequest request, ActionListener<MLTaskResponse> listener) {
    MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType();
    Instant now = Instant.now();
    MLTask mlTask = MLTask.builder()
        .taskType(MLTaskType.TRAINING)
        .inputType(inputDataType)
        .functionName(request.getMlInput().getFunctionName())
        .state(MLTaskState.CREATED)
        .workerNode(clusterService.localNode().getId())
        .createTime(now)
        .lastUpdateTime(now)
        .async(request.isAsync())
        .build();
    if (request.isAsync()) {
        mlTaskManager.createMLTask(mlTask, ActionListener.wrap(r -> {
            String taskId = r.getId();
            mlTask.setTaskId(taskId);
            // Respond right away with the task id; the caller polls for completion.
            listener.onResponse(new MLTaskResponse(new MLTrainingOutput(null, taskId, mlTask.getState().name())));
            ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(res -> {
                String modelId = ((MLTrainingOutput) res.getOutput()).getModelId();
                log.info("ML model trained successfully, task id: {}, model id: {}", taskId, modelId);
                mlTask.setModelId(modelId);
                handleAsyncMLTaskComplete(mlTask);
            }, ex -> {
                // Fix: pass the exception to the logger so the stack trace is preserved
                // (the original concatenated the task id and dropped the throwable).
                log.error("Failed to train ML model for task {}", taskId, ex);
                handleAsyncMLTaskFailure(mlTask, ex);
            });
            startTrainingTask(mlTask, request.getMlInput(), internalListener);
        }, e -> {
            log.error("Failed to create ML task", e);
            listener.onFailure(e);
        }));
    } else {
        // Synchronous path: no persisted task document, so generate a local id.
        mlTask.setTaskId(UUID.randomUUID().toString());
        startTrainingTask(mlTask, request.getMlInput(), listener);
    }
}
Use of org.opensearch.ml.common.dataset.MLInputDataType in the opensearch-project/ml-commons repository.
Example 4: the startPredictionTask method of the MLPredictTaskRunner class.
/**
 * Starts a prediction task for the given request.
 * <p>
 * If the input dataset is a search query, it is first resolved into a DataFrame
 * asynchronously; otherwise the dataset is parsed directly and prediction is
 * dispatched onto the task thread pool.
 *
 * @param request  the prediction request (model id plus ML input)
 * @param listener notified with the prediction response, or the failure
 */
public void startPredictionTask(MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener) {
    MLInput mlInput = request.getMlInput();
    MLInputDataType inputDataType = mlInput.getInputDataset().getInputDataType();
    Instant creationTime = Instant.now();
    MLTask mlTask = MLTask.builder()
        .taskId(UUID.randomUUID().toString())
        .modelId(request.getModelId())
        .taskType(MLTaskType.PREDICTION)
        .inputType(inputDataType)
        .functionName(mlInput.getFunctionName())
        .state(MLTaskState.CREATED)
        .workerNode(clusterService.localNode().getId())
        .createTime(creationTime)
        .lastUpdateTime(creationTime)
        .async(false)
        .build();
    if (mlInput.getInputDataset().getInputDataType().equals(MLInputDataType.SEARCH_QUERY)) {
        // Search-query input: materialize the DataFrame asynchronously, then predict.
        ActionListener<DataFrame> dataFrameListener = ActionListener.wrap(
            dataFrame -> predict(mlTask, dataFrame, request, listener),
            e -> {
                log.error("Failed to generate DataFrame from search query", e);
                handleAsyncMLTaskFailure(mlTask, e);
                listener.onFailure(e);
            }
        );
        mlInputDatasetHandler.parseSearchQueryInput(
            mlInput.getInputDataset(),
            new ThreadedActionListener<>(log, threadPool, TASK_THREAD_POOL, dataFrameListener, false)
        );
    } else {
        DataFrame parsedDataFrame = mlInputDatasetHandler.parseDataFrameInput(mlInput.getInputDataset());
        threadPool.executor(TASK_THREAD_POOL).execute(() -> predict(mlTask, parsedDataFrame, request, listener));
    }
}
Use of org.opensearch.ml.common.dataset.MLInputDataType in the opensearch-project/ml-commons repository.
Example 5: the startTrainAndPredictionTask method of the MLTrainAndPredictTaskRunner class.
/**
 * Starts a combined train-and-predict task for the given request.
 * <p>
 * If the input dataset is a search query, it is first resolved into a DataFrame
 * asynchronously; otherwise the dataset is parsed directly and the work is
 * dispatched onto the task thread pool.
 *
 * @param request  the training request (note: originally misdocumented as MLPredictionTaskRequest)
 * @param listener notified with the train-and-predict response, or the failure
 */
public void startTrainAndPredictionTask(MLTrainingTaskRequest request, ActionListener<MLTaskResponse> listener) {
    MLInput mlInput = request.getMlInput();
    MLInputDataType inputDataType = mlInput.getInputDataset().getInputDataType();
    Instant creationTime = Instant.now();
    MLTask mlTask = MLTask.builder()
        .taskId(UUID.randomUUID().toString())
        .taskType(MLTaskType.TRAINING_AND_PREDICTION)
        .inputType(inputDataType)
        .functionName(mlInput.getFunctionName())
        .state(MLTaskState.CREATED)
        .workerNode(clusterService.localNode().getId())
        .createTime(creationTime)
        .lastUpdateTime(creationTime)
        .async(false)
        .build();
    if (mlInput.getInputDataset().getInputDataType().equals(MLInputDataType.SEARCH_QUERY)) {
        // Search-query input: materialize the DataFrame asynchronously, then run the task.
        ActionListener<DataFrame> dataFrameListener = ActionListener.wrap(
            dataFrame -> trainAndPredict(mlTask, dataFrame, request, listener),
            e -> {
                log.error("Failed to generate DataFrame from search query", e);
                handlePredictFailure(mlTask, listener, e, false);
            }
        );
        mlInputDatasetHandler.parseSearchQueryInput(
            mlInput.getInputDataset(),
            new ThreadedActionListener<>(log, threadPool, TASK_THREAD_POOL, dataFrameListener, false)
        );
    } else {
        DataFrame parsedDataFrame = mlInputDatasetHandler.parseDataFrameInput(mlInput.getInputDataset());
        threadPool.executor(TASK_THREAD_POOL).execute(() -> trainAndPredict(mlTask, parsedDataFrame, request, listener));
    }
}
Aggregations