use of org.opensearch.ml.common.parameter.MLInput in project ml-commons by opensearch-project.
the class MLInputTests method testParseKmeansInputQuery.
/**
 * Parses a KMEANS request body that carries an {@code input_query} and an
 * {@code input_index}, then checks that the resulting search-query dataset
 * renders its search source as the expected JSON string.
 *
 * @throws IOException declared for the parsing calls
 */
public void testParseKmeansInputQuery() throws IOException {
    String expectedQuery = "{\"size\":10,\"query\":{\"bool\":{\"filter\":[{\"term\":{\"k1\":{\"value\":1,\"boost\":1.0}}}],\"adjust_pure_negative\":true,\"boost\":1.0}}}";
    String requestBody = "{\"input_query\":{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"k1\":1}}]}},\"size\":10},\"input_index\":[\"test_data\"]}";

    MLInput parsedInput = MLInput.parse(parser(requestBody), FunctionName.KMEANS.name());

    // The query input is expected to come back as a search-query dataset.
    SearchQueryInputDataset searchDataset = (SearchQueryInputDataset) parsedInput.getInputDataset();
    String actualQuery = searchDataset.getSearchSourceBuilder().toString();
    assertEquals(expectedQuery, actualQuery);
}
use of org.opensearch.ml.common.parameter.MLInput in project ml-commons by opensearch-project.
the class MLInputTests method testParseKmeansInputDataFrame.
/**
 * Parses a KMEANS request body that carries an inline {@code input_data} data frame
 * (two columns, two rows) and checks the column metadata and the first row of the
 * parsed data frame.
 *
 * @throws IOException declared for the parsing calls
 */
public void testParseKmeansInputDataFrame() throws IOException {
    String requestBody = "{\"input_data\":{\"column_metas\":[{\"name\":\"total_sum\",\"column_type\":\"DOUBLE\"},{\"name\":\"is_error\","
        + "\"column_type\":\"BOOLEAN\"}],\"rows\":[{\"values\":[{\"column_type\":\"DOUBLE\",\"value\":15},"
        + "{\"column_type\":\"BOOLEAN\",\"value\":false}]},{\"values\":[{\"column_type\":\"DOUBLE\",\"value\":100},"
        + "{\"column_type\":\"BOOLEAN\",\"value\":true}]}]}}";

    MLInput parsedInput = MLInput.parse(parser(requestBody), FunctionName.KMEANS.name());
    DataFrame frame = ((DataFrameInputDataset) parsedInput.getInputDataset()).getDataFrame();

    // Column metadata: two columns, checked one column at a time.
    assertEquals(2, frame.columnMetas().length);
    assertEquals("total_sum", frame.columnMetas()[0].getName());
    assertEquals(ColumnType.DOUBLE, frame.columnMetas()[0].getColumnType());
    assertEquals("is_error", frame.columnMetas()[1].getName());
    assertEquals(ColumnType.BOOLEAN, frame.columnMetas()[1].getColumnType());

    // First row: the integer literal 15 in the request is expected back as the double 15.0.
    assertEquals(ColumnType.DOUBLE, frame.getRow(0).getValue(0).columnType());
    assertEquals(15.0, frame.getRow(0).getValue(0).getValue());
    assertEquals(ColumnType.BOOLEAN, frame.getRow(0).getValue(1).columnType());
    assertEquals(false, frame.getRow(0).getValue(1).getValue());
}
use of org.opensearch.ml.common.parameter.MLInput in project ml-commons by opensearch-project.
the class MLCommonsRestTestCase method trainAndPredict.
/**
 * Sends a train-and-predict request for the given function over a search-query input
 * dataset and passes the {@code prediction_result} section of the JSON response to the
 * supplied callback.
 *
 * @param client REST client used to send the request
 * @param functionName algorithm to run; its lower-cased name is the last segment of the request path
 * @param indexName index the search-query input dataset reads from
 * @param params algorithm parameters placed in the request
 * @param searchSourceBuilder search source used by the input dataset
 * @param function callback that receives the {@code prediction_result} map; skipped when {@code null}
 * @throws IOException propagated from the request and entity-reading helpers
 */
public void trainAndPredict(RestClient client, FunctionName functionName, String indexName, MLAlgoParams params, SearchSourceBuilder searchSourceBuilder, Consumer<Map<String, Object>> function) throws IOException {
    MLInputDataset inputData = SearchQueryInputDataset.builder()
        .indices(ImmutableList.of(indexName))
        .searchSourceBuilder(searchSourceBuilder)
        .build();
    MLInput mlInput = MLInput.builder().algorithm(functionName).parameters(params).inputDataset(inputData).build();
    String endpoint = "/_plugins/_ml/_train_predict/" + functionName.name().toLowerCase(Locale.ROOT);
    Response response = TestHelper.makeRequest(client, "POST", endpoint, ImmutableMap.of(), TestHelper.toHttpEntity(mlInput), null);

    // Assert before dereferencing: checking after getEntity() would surface a null
    // response as a NullPointerException instead of an assertion failure.
    assertNotNull(response);
    HttpEntity entity = response.getEntity();
    String entityString = TestHelper.httpEntityToString(entity);

    Map<?, ?> responseMap = gson.fromJson(entityString, Map.class);
    Map<String, Object> predictionResult = (Map<String, Object>) responseMap.get("prediction_result");
    if (function != null) {
        function.accept(predictionResult);
    }
}
use of org.opensearch.ml.common.parameter.MLInput in project ml-commons by opensearch-project.
the class RestMLTrainAndPredictIT method trainAndPredictKmeansWithIrisData.
/**
 * Sends a KMEANS train-and-predict request and counts how often each distinct
 * first-column value occurs across the rows of the {@code prediction_result}.
 * The resulting value-to-count map is passed to the supplied callback.
 *
 * @param params KMEANS parameters placed in the request
 * @param inputData input dataset placed in the request
 * @param function callback that receives the per-value row counts
 * @throws IOException propagated from the request and entity-reading helpers
 */
private void trainAndPredictKmeansWithIrisData(KMeansParams params, MLInputDataset inputData, Consumer<Map<Double, Integer>> function) throws IOException {
    MLInput kmeansInput = MLInput.builder().algorithm(FunctionName.KMEANS).parameters(params).inputDataset(inputData).build();
    Response kmeansResponse = TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/_train_predict/kmeans", ImmutableMap.of(), TestHelper.toHttpEntity(kmeansInput), null);

    // Assert before dereferencing: checking after getEntity() would surface a null
    // response as a NullPointerException instead of an assertion failure.
    assertNotNull(kmeansResponse);
    HttpEntity entity = kmeansResponse.getEntity();
    String entityString = TestHelper.httpEntityToString(entity);

    Map<?, ?> responseMap = gson.fromJson(entityString, Map.class);
    Map<?, ?> predictionResult = (Map<?, ?>) responseMap.get("prediction_result");
    ArrayList<?> rows = (ArrayList<?>) predictionResult.get("rows");

    Map<Double, Integer> clusterCount = new HashMap<>();
    for (Object row : rows) {
        // Only the first entry of the row's "values" list is counted.
        ArrayList<?> values = (ArrayList<?>) ((Map<?, ?>) row).get("values");
        Double value = (Double) ((Map<?, ?>) values.get(0)).get("value");
        clusterCount.merge(value, 1, Integer::sum);
    }
    function.accept(clusterCount);
}
use of org.opensearch.ml.common.parameter.MLInput in project ml-commons by opensearch-project.
the class MLTrainingTaskRunner method train.
/**
 * Marks the task as running, trains a model from the given input, makes sure the model
 * index exists, indexes the trained model and reports the outcome to the caller.
 *
 * <p>On success the caller receives an {@code MLTrainingOutput} carrying the id of the
 * indexed model document, the task id (only for async tasks) and the COMPLETED state name.
 * Failures reported through the internal wrapper listener first increment the per-function
 * and total failure counters and are then forwarded to the caller.
 *
 * @param mlTask task describing this training run
 * @param mlInput training input; its algorithm is stored with the model
 * @param actionListener receives the training output or the failure
 */
private void train(MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
    // Wrapper around the caller's listener that records failure stats before forwarding.
    ActionListener<MLTaskResponse> countingListener = ActionListener.wrap(response -> actionListener.onResponse(response), failure -> {
        mlStats.createCounterStatIfAbsent(failureCountStat(mlTask.getFunctionName(), ActionName.TRAIN)).increment();
        mlStats.getStat(ML_TOTAL_FAILURE_COUNT).increment();
        actionListener.onFailure(failure);
    });
    try {
        // run training
        mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
        Model trainedModel = MLEngine.train(mlInput);
        mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(indexCreated -> {
            if (indexCreated) {
                // TODO: put the user into model for backend role based access control.
                MLModel mlModel = new MLModel(mlInput.getAlgorithm(), trainedModel);
                // The thread context is stashed while the index request is issued and
                // restored again right before the index response listener runs.
                try (ThreadContext.StoredContext storedContext = client.threadPool().getThreadContext().stashContext()) {
                    ActionListener<IndexResponse> modelIndexedListener = ActionListener.wrap(indexResponse -> {
                        log.info("Model data indexing done, result:{}, model id: {}", indexResponse.getResult(), indexResponse.getId());
                        mlStats.getStat(ML_TOTAL_MODEL_COUNT).increment();
                        mlStats.createCounterStatIfAbsent(modelCountStat(mlTask.getFunctionName())).increment();
                        // The task id is only returned for async tasks.
                        String returnedTaskId = null;
                        if (mlTask.isAsync()) {
                            returnedTaskId = mlTask.getTaskId();
                        }
                        MLTrainingOutput trainingOutput = new MLTrainingOutput(indexResponse.getId(), returnedTaskId, MLTaskState.COMPLETED.name());
                        countingListener.onResponse(MLTaskResponse.builder().output(trainingOutput).build());
                    }, indexFailure -> countingListener.onFailure(indexFailure));

                    XContentBuilder modelSource = mlModel.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS);
                    IndexRequest saveModelRequest = new IndexRequest(ML_MODEL_INDEX);
                    saveModelRequest.source(modelSource);
                    saveModelRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
                    client.index(saveModelRequest, ActionListener.runBefore(modelIndexedListener, () -> storedContext.restore()));
                } catch (Exception saveFailure) {
                    log.error("Failed to save ML model", saveFailure);
                    countingListener.onFailure(saveFailure);
                }
            } else {
                countingListener.onFailure(new RuntimeException("No response to create ML task index"));
            }
        }, initFailure -> {
            log.error("Failed to init ML model index", initFailure);
            countingListener.onFailure(initFailure);
        }));
    } catch (Exception trainFailure) {
        // todo need to specify what exception
        log.error("Failed to train " + mlInput.getAlgorithm(), trainFailure);
        countingListener.onFailure(trainFailure);
    }
}
Aggregations