Usage of org.opensearch.ml.common.parameter.MLTrainingOutput in the ml-commons project (opensearch-project).
Class MLTrainingTaskRunner, method train:
/**
 * Runs a training task: trains a model for {@code mlInput}, stores it in the ML model index and
 * reports the outcome to {@code actionListener}.
 *
 * <p>The task state is set to RUNNING before training. Every failure (training, model-index
 * initialization, or indexing) bumps the per-function TRAIN failure counter and the total failure
 * counter before it is forwarded to {@code actionListener}. On success the total and per-function
 * model counters are bumped. The caller then receives an {@link MLTrainingOutput} holding the
 * indexed document id and a COMPLETED status. The task id is included only when the task is async.
 *
 * @param mlTask the task being executed; its id, function name and async flag are read
 * @param mlInput the algorithm and input data passed to {@code MLEngine.train}
 * @param actionListener receives the resulting {@link MLTaskResponse} or the failure
 */
private void train(MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
    // Wrap the caller's listener so that every failure path below updates the failure stats.
    ActionListener<MLTaskResponse> listener = ActionListener.wrap(actionListener::onResponse, e -> {
        mlStats.createCounterStatIfAbsent(failureCountStat(mlTask.getFunctionName(), ActionName.TRAIN)).increment();
        mlStats.getStat(ML_TOTAL_FAILURE_COUNT).increment();
        actionListener.onFailure(e);
    });
    try {
        // run training
        mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
        Model model = MLEngine.train(mlInput);
        mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(indexCreated -> {
            if (!indexCreated) {
                // The index initialized here is the model index, so report that one (not the task index).
                listener.onFailure(new RuntimeException("No response to create ML model index"));
                return;
            }
            // TODO: put the user into model for backend role based access control.
            MLModel mlModel = new MLModel(mlInput.getAlgorithm(), model);
            // Index with a stashed (empty) thread context. NOTE(review): presumably so the write to the
            // model index is not subject to the caller's security context — confirm.
            try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
                ActionListener<IndexResponse> indexResponseListener = ActionListener.wrap(r -> {
                    log.info("Model data indexing done, result:{}, model id: {}", r.getResult(), r.getId());
                    mlStats.getStat(ML_TOTAL_MODEL_COUNT).increment();
                    mlStats.createCounterStatIfAbsent(modelCountStat(mlTask.getFunctionName())).increment();
                    // Only async callers need the task id to poll for status.
                    String returnedTaskId = mlTask.isAsync() ? mlTask.getTaskId() : null;
                    MLTrainingOutput output = new MLTrainingOutput(r.getId(), returnedTaskId, MLTaskState.COMPLETED.name());
                    listener.onResponse(MLTaskResponse.builder().output(output).build());
                }, listener::onFailure);
                IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX);
                indexRequest.source(mlModel.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS));
                // Refresh immediately so the new model document is searchable as soon as we respond.
                indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
                // Restore the caller's thread context before any response/failure callback runs.
                client.index(indexRequest, ActionListener.runBefore(indexResponseListener, () -> context.restore()));
            } catch (Exception e) {
                log.error("Failed to save ML model", e);
                listener.onFailure(e);
            }
        }, e -> {
            log.error("Failed to init ML model index", e);
            listener.onFailure(e);
        }));
    } catch (Exception e) {
        // todo need to specify what exception
        log.error("Failed to train " + mlInput.getAlgorithm(), e);
        listener.onFailure(e);
    }
}
Usage of org.opensearch.ml.common.parameter.MLTrainingOutput in the ml-commons project (opensearch-project).
Class MLTrainingTaskResponseTest, method fromActionResponse_Success_WithMLTrainingTaskResponse:
@Test
public void fromActionResponse_Success_WithMLTrainingTaskResponse() {
    // An argument that is already an MLTaskResponse is expected back as the very same instance.
    MLTrainingOutput trainingOutput = MLTrainingOutput.builder()
        .modelId("taskId")
        .status("success")
        .build();
    MLTaskResponse original = MLTaskResponse.builder().output(trainingOutput).build();
    MLTaskResponse converted = MLTaskResponse.fromActionResponse(original);
    assertSame(original, converted);
}
Usage of org.opensearch.ml.common.parameter.MLTrainingOutput in the ml-commons project (opensearch-project).
Class MachineLearningNodeClientTest, method train:
@SuppressWarnings("unchecked")
@Test
public void train() {
    final String modelId = "test_model_id";
    final String status = "InProgress";
    ArgumentCaptor<MLOutput> outputCaptor = ArgumentCaptor.forClass(MLOutput.class);

    // Stub the client so the training action immediately answers with a canned training output.
    doAnswer(invocation -> {
        ActionListener<MLTaskResponse> responseListener = invocation.getArgument(2);
        MLTaskResponse cannedResponse = MLTaskResponse.builder()
            .output(MLTrainingOutput.builder().modelId(modelId).status(status).build())
            .build();
        responseListener.onResponse(cannedResponse);
        return null;
    }).when(client).execute(eq(MLTrainingTaskAction.INSTANCE), any(), any());

    MLInput trainingInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build();
    machineLearningNodeClient.train(trainingInput, false, trainingActionListener);

    verify(client).execute(eq(MLTrainingTaskAction.INSTANCE), isA(MLTrainingTaskRequest.class), any(ActionListener.class));
    verify(trainingActionListener).onResponse(outputCaptor.capture());
    MLTrainingOutput captured = (MLTrainingOutput) outputCaptor.getValue();
    assertEquals(modelId, captured.getModelId());
    assertEquals(status, captured.getStatus());
}
Usage of org.opensearch.ml.common.parameter.MLTrainingOutput in the ml-commons project (opensearch-project).
Class MLTrainingTaskResponseTest, method fromActionResponse_Success_WithNonMLTrainingTaskResponse:
@Test
public void fromActionResponse_Success_WithNonMLTrainingTaskResponse() {
    MLTrainingOutput trainingOutput = MLTrainingOutput.builder().modelId("taskId").status("success").build();
    MLTaskResponse source = MLTaskResponse.builder().output(trainingOutput).build();
    // A plain ActionResponse whose serialized form is that of `source`; the test expects
    // fromActionResponse to hand back a distinct instance carrying the same training output.
    ActionResponse wrapped = new ActionResponse() {
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            source.writeTo(out);
        }
    };
    MLTaskResponse rebuilt = MLTaskResponse.fromActionResponse(wrapped);
    assertNotSame(source, rebuilt);
    MLTrainingOutput sourceOutput = (MLTrainingOutput) source.getOutput();
    MLTrainingOutput rebuiltOutput = (MLTrainingOutput) rebuilt.getOutput();
    assertEquals(sourceOutput.getStatus(), rebuiltOutput.getStatus());
    assertEquals(sourceOutput.getModelId(), rebuiltOutput.getModelId());
}
Usage of org.opensearch.ml.common.parameter.MLTrainingOutput in the ml-commons project (opensearch-project).
Class MLTrainingTaskResponseTest, method writeTo:
@Test
public void writeTo() throws IOException {
    // Round-trip an MLTaskResponse through a byte stream and check the training output survives.
    MLTaskResponse original = MLTaskResponse.builder()
        .output(MLTrainingOutput.builder().modelId("taskId").status("success").build())
        .build();
    BytesStreamOutput buffer = new BytesStreamOutput();
    original.writeTo(buffer);
    MLTaskResponse restored = new MLTaskResponse(buffer.bytes().streamInput());
    MLTrainingOutput restoredOutput = (MLTrainingOutput) restored.getOutput();
    assertEquals("success", restoredOutput.getStatus());
    assertEquals("taskId", restoredOutput.getModelId());
}
Aggregations