Search in sources:

Example 6 with MLInput

Use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.

From the class MLInputTests, method testParseKmeansInputQuery.

/**
 * Checks that a KMEANS request body containing an {@code input_query} section is parsed into a
 * {@link SearchQueryInputDataset}, and that the string form of its search source builder equals
 * the expected query JSON.
 *
 * @throws IOException declared because parser creation / parsing may throw it
 */
public void testParseKmeansInputQuery() throws IOException {
    // String form the parsed search source is expected to produce.
    final String expectedSource = "{\"size\":10,\"query\":{\"bool\":{\"filter\":[{\"term\":{\"k1\":{\"value\":1,\"boost\":1.0}}}],\"adjust_pure_negative\":true,\"boost\":1.0}}}";
    // Request body: a bool/filter term query limited to 10 hits against the "test_data" index.
    final String requestBody = "{\"input_query\":{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"k1\":1}}]}},\"size\":10},\"input_index\":[\"test_data\"]}";

    XContentParser requestParser = parser(requestBody);
    MLInput parsedInput = MLInput.parse(requestParser, FunctionName.KMEANS.name());

    SearchQueryInputDataset queryDataset = (SearchQueryInputDataset) parsedInput.getInputDataset();
    String actualSource = queryDataset.getSearchSourceBuilder().toString();
    assertEquals(expectedSource, actualSource);
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) SearchQueryInputDataset(org.opensearch.ml.common.dataset.SearchQueryInputDataset) XContentParser(org.opensearch.common.xcontent.XContentParser)

Example 7 with MLInput

Use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.

From the class MLInputTests, method testParseKmeansInputDataFrame.

/**
 * Checks that a KMEANS request body containing an inline {@code input_data} data frame is parsed
 * into a {@link DataFrameInputDataset} whose column metadata and first-row values match the
 * request.
 *
 * @throws IOException declared because parser creation / parsing may throw it
 */
public void testParseKmeansInputDataFrame() throws IOException {
    // Two columns (DOUBLE "total_sum", BOOLEAN "is_error") and two rows of values.
    final String requestBody = "{\"input_data\":{\"column_metas\":["
        + "{\"name\":\"total_sum\",\"column_type\":\"DOUBLE\"},"
        + "{\"name\":\"is_error\",\"column_type\":\"BOOLEAN\"}],"
        + "\"rows\":["
        + "{\"values\":[{\"column_type\":\"DOUBLE\",\"value\":15},{\"column_type\":\"BOOLEAN\",\"value\":false}]},"
        + "{\"values\":[{\"column_type\":\"DOUBLE\",\"value\":100},{\"column_type\":\"BOOLEAN\",\"value\":true}]}"
        + "]}}";

    MLInput parsedInput = MLInput.parse(parser(requestBody), FunctionName.KMEANS.name());
    DataFrameInputDataset frameDataset = (DataFrameInputDataset) parsedInput.getInputDataset();
    DataFrame frame = frameDataset.getDataFrame();

    // Column metadata: count, then types, then names.
    assertEquals(2, frame.columnMetas().length);
    assertEquals(ColumnType.DOUBLE, frame.columnMetas()[0].getColumnType());
    assertEquals(ColumnType.BOOLEAN, frame.columnMetas()[1].getColumnType());
    assertEquals("total_sum", frame.columnMetas()[0].getName());
    assertEquals("is_error", frame.columnMetas()[1].getName());

    // First row: value types, then the values themselves (the integer 15 is expected back as 15.0).
    assertEquals(ColumnType.DOUBLE, frame.getRow(0).getValue(0).columnType());
    assertEquals(ColumnType.BOOLEAN, frame.getRow(0).getValue(1).columnType());
    assertEquals(15.0, frame.getRow(0).getValue(0).getValue());
    assertEquals(false, frame.getRow(0).getValue(1).getValue());
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) XContentParser(org.opensearch.common.xcontent.XContentParser)

Example 8 with MLInput

Use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.

From the class MLCommonsRestTestCase, method trainAndPredict.

/**
 * Sends a train-and-predict REST request for the given algorithm, using a search-query dataset
 * over {@code indexName}, and hands the {@code prediction_result} section of the JSON response
 * to the supplied callback.
 *
 * @param client              REST client used to issue the request
 * @param functionName        algorithm to run; its lower-cased name is appended to the endpoint path
 * @param indexName           index the search query is executed against
 * @param params              algorithm parameters placed into the request
 * @param searchSourceBuilder query selecting the input documents
 * @param function            callback receiving the {@code prediction_result} map; may be {@code null},
 *                            in which case the result is not inspected
 * @throws IOException if the request or reading the response entity fails
 */
public void trainAndPredict(RestClient client, FunctionName functionName, String indexName, MLAlgoParams params, SearchSourceBuilder searchSourceBuilder, Consumer<Map<String, Object>> function) throws IOException {
    MLInputDataset inputData = SearchQueryInputDataset.builder().indices(ImmutableList.of(indexName)).searchSourceBuilder(searchSourceBuilder).build();
    MLInput mlInput = MLInput.builder().algorithm(functionName).parameters(params).inputDataset(inputData).build();
    String endpoint = "/_plugins/_ml/_train_predict/" + functionName.name().toLowerCase(Locale.ROOT);
    Response response = TestHelper.makeRequest(client, "POST", endpoint, ImmutableMap.of(), TestHelper.toHttpEntity(mlInput), null);
    // Check the response BEFORE dereferencing it; asserting after getEntity() could never fire,
    // because a null response would already have thrown a NullPointerException.
    assertNotNull(response);
    HttpEntity entity = response.getEntity();
    String entityString = TestHelper.httpEntityToString(entity);
    Map<?, ?> responseMap = gson.fromJson(entityString, Map.class);
    // The JSON is deserialized into untyped maps, so this cast cannot be checked at compile time.
    @SuppressWarnings("unchecked")
    Map<String, Object> predictionResult = (Map<String, Object>) responseMap.get("prediction_result");
    if (function != null) {
        function.accept(predictionResult);
    }
}
Also used : Response(org.opensearch.client.Response) MLInput(org.opensearch.ml.common.parameter.MLInput) HttpEntity(org.apache.http.HttpEntity) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap)

Example 9 with MLInput

Use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.

From the class RestMLTrainAndPredictIT, method trainAndPredictKmeansWithIrisData.

/**
 * Sends a KMEANS train-and-predict REST request, counts how many returned rows carry each
 * distinct first-column value, and passes that count map to the supplied callback.
 *
 * @param params    KMEANS parameters placed into the request
 * @param inputData dataset to train and predict on
 * @param function  callback receiving a map from first-column value to its number of rows
 * @throws IOException if the request or reading the response entity fails
 */
private void trainAndPredictKmeansWithIrisData(KMeansParams params, MLInputDataset inputData, Consumer<Map<Double, Integer>> function) throws IOException {
    MLInput kmeansInput = MLInput.builder().algorithm(FunctionName.KMEANS).parameters(params).inputDataset(inputData).build();
    Response kmeansResponse = TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/_train_predict/kmeans", ImmutableMap.of(), TestHelper.toHttpEntity(kmeansInput), null);
    // Check the response BEFORE dereferencing it; asserting after getEntity() could never fire,
    // because a null response would already have thrown a NullPointerException.
    assertNotNull(kmeansResponse);
    HttpEntity entity = kmeansResponse.getEntity();
    String entityString = TestHelper.httpEntityToString(entity);
    Map<?, ?> responseMap = gson.fromJson(entityString, Map.class);
    Map<?, ?> predictionResult = (Map<?, ?>) responseMap.get("prediction_result");
    ArrayList<?> rows = (ArrayList<?>) predictionResult.get("rows");
    Map<Double, Integer> clusterCount = new HashMap<>();
    for (Object row : rows) {
        // Each row is {"values": [{"value": <number>, ...}, ...]}; only the first value is counted.
        ArrayList<?> values = (ArrayList<?>) ((Map<?, ?>) row).get("values");
        Map<?, ?> firstValue = (Map<?, ?>) values.get(0);
        Double value = (Double) firstValue.get("value");
        // Insert 1 for a new key, otherwise add 1 to the existing count.
        clusterCount.merge(value, 1, Integer::sum);
    }
    function.accept(clusterCount);
}
Also used : Response(org.opensearch.client.Response) MLInput(org.opensearch.ml.common.parameter.MLInput) HttpEntity(org.apache.http.HttpEntity) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) ImmutableMap(com.google.common.collect.ImmutableMap) HashMap(java.util.HashMap) Map(java.util.Map)

Example 10 with MLInput

Use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.

From the class MLTrainingTaskRunner, method train.

/**
 * Trains a model for the given input, stores it in the ML model index, and reports the outcome
 * through {@code actionListener}.
 *
 * <p>Steps, in order: mark the task as RUNNING, train via {@code MLEngine.train}, ensure the model
 * index exists, index the trained model, then respond with an {@code MLTrainingOutput} holding
 * the indexed document id. Every failure path goes through a wrapped listener that increments the
 * failure counters before notifying the caller.
 *
 * @param mlTask         task being executed; supplies the task id, function name and async flag
 * @param mlInput        training input; supplies the algorithm and is passed to the engine
 * @param actionListener caller's listener, notified with the training output or the failure
 */
private void train(MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
    // Wrap the caller's listener: responses pass straight through; failures first bump the
    // per-function TRAIN failure counter and the total failure counter.
    ActionListener<MLTaskResponse> listener = ActionListener.wrap(r -> actionListener.onResponse(r), e -> {
        mlStats.createCounterStatIfAbsent(failureCountStat(mlTask.getFunctionName(), ActionName.TRAIN)).increment();
        mlStats.getStat(ML_TOTAL_FAILURE_COUNT).increment();
        actionListener.onFailure(e);
    });
    try {
        // Mark the task as running, then train in the calling thread.
        mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
        Model model = MLEngine.train(mlInput);
        // Persisting the model continues in the callback passed to initModelIndexIfAbsent.
        mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(indexCreated -> {
            if (!indexCreated) {
                // NOTE(review): this is the model-index path, yet the message says "ML task index" —
                // looks like a copy/paste leftover; confirm before changing, callers may match on it.
                listener.onFailure(new RuntimeException("No response to create ML task index"));
                return;
            }
            // TODO: put the user into model for backend role based access control.
            MLModel mlModel = new MLModel(mlInput.getAlgorithm(), model);
            // Stash the thread context for the duration of the index call; the stored context is
            // closed by try-with-resources and is also restored via runBefore ahead of the
            // index-response listener.
            try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
                ActionListener<IndexResponse> indexResponseListener = ActionListener.wrap(r -> {
                    log.info("Model data indexing done, result:{}, model id: {}", r.getResult(), r.getId());
                    mlStats.getStat(ML_TOTAL_MODEL_COUNT).increment();
                    mlStats.createCounterStatIfAbsent(modelCountStat(mlTask.getFunctionName())).increment();
                    // The task id is only returned for async tasks; otherwise it is null.
                    String returnedTaskId = mlTask.isAsync() ? mlTask.getTaskId() : null;
                    // The id of the indexed document is reported as the first field of the output.
                    MLTrainingOutput output = new MLTrainingOutput(r.getId(), returnedTaskId, MLTaskState.COMPLETED.name());
                    listener.onResponse(MLTaskResponse.builder().output(output).build());
                }, e -> {
                    listener.onFailure(e);
                });
                IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX);
                indexRequest.source(mlModel.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS));
                // Refresh policy IMMEDIATE is requested for this write.
                indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
                client.index(indexRequest, ActionListener.runBefore(indexResponseListener, () -> context.restore()));
            } catch (Exception e) {
                // Covers failures thrown while building or submitting the index request.
                log.error("Failed to save ML model", e);
                listener.onFailure(e);
            }
        }, e -> {
            log.error("Failed to init ML model index", e);
            listener.onFailure(e);
        }));
    } catch (Exception e) {
        // TODO: narrow this to the specific exception types that can be thrown here.
        log.error("Failed to train " + mlInput.getAlgorithm(), e);
        listener.onFailure(e);
    }
}
Also used : IndexResponse(org.opensearch.action.index.IndexResponse) ToXContent(org.opensearch.common.xcontent.ToXContent) ThreadPool(org.opensearch.threadpool.ThreadPool) StatNames.modelCountStat(org.opensearch.ml.stats.StatNames.modelCountStat) MLTaskState(org.opensearch.ml.common.parameter.MLTaskState) MLInput(org.opensearch.ml.common.parameter.MLInput) MLInputDatasetHandler(org.opensearch.ml.indices.MLInputDatasetHandler) ThreadedActionListener(org.opensearch.action.support.ThreadedActionListener) ThreadContext(org.opensearch.common.util.concurrent.ThreadContext) MLTask(org.opensearch.ml.common.parameter.MLTask) ML_EXECUTING_TASK_COUNT(org.opensearch.ml.stats.StatNames.ML_EXECUTING_TASK_COUNT) WriteRequest(org.opensearch.action.support.WriteRequest) ActionListener(org.opensearch.action.ActionListener) MLModel(org.opensearch.ml.common.parameter.MLModel) MLIndicesHandler(org.opensearch.ml.indices.MLIndicesHandler) ML_TOTAL_MODEL_COUNT(org.opensearch.ml.stats.StatNames.ML_TOTAL_MODEL_COUNT) MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLInputDataType(org.opensearch.ml.common.dataset.MLInputDataType) Client(org.opensearch.client.Client) MLStats(org.opensearch.ml.stats.MLStats) ActionName(org.opensearch.ml.stats.ActionName) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) UUID(java.util.UUID) Instant(java.time.Instant) TransportService(org.opensearch.transport.TransportService) StatNames.requestCountStat(org.opensearch.ml.stats.StatNames.requestCountStat) MLTrainingTaskAction(org.opensearch.ml.common.transport.training.MLTrainingTaskAction) XContentBuilder(org.opensearch.common.xcontent.XContentBuilder) MLEngine(org.opensearch.ml.engine.MLEngine) StatNames.failureCountStat(org.opensearch.ml.stats.StatNames.failureCountStat) MLTrainingOutput(org.opensearch.ml.common.parameter.MLTrainingOutput) ML_MODEL_INDEX(org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX) Model(org.opensearch.ml.common.parameter.Model) 
ML_TOTAL_FAILURE_COUNT(org.opensearch.ml.stats.StatNames.ML_TOTAL_FAILURE_COUNT) MLTaskType(org.opensearch.ml.common.parameter.MLTaskType) MLCircuitBreakerService(org.opensearch.ml.common.breaker.MLCircuitBreakerService) Log4j2(lombok.extern.log4j.Log4j2) ActionListenerResponseHandler(org.opensearch.action.ActionListenerResponseHandler) ClusterService(org.opensearch.cluster.service.ClusterService) TASK_THREAD_POOL(org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL) ML_TOTAL_REQUEST_COUNT(org.opensearch.ml.stats.StatNames.ML_TOTAL_REQUEST_COUNT) XContentType(org.opensearch.common.xcontent.XContentType) IndexRequest(org.opensearch.action.index.IndexRequest) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) MLTrainingTaskRequest(org.opensearch.ml.common.transport.training.MLTrainingTaskRequest) MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLTrainingOutput(org.opensearch.ml.common.parameter.MLTrainingOutput) ThreadedActionListener(org.opensearch.action.support.ThreadedActionListener) ActionListener(org.opensearch.action.ActionListener) MLModel(org.opensearch.ml.common.parameter.MLModel) Model(org.opensearch.ml.common.parameter.Model) MLModel(org.opensearch.ml.common.parameter.MLModel) IndexRequest(org.opensearch.action.index.IndexRequest)

Aggregations

MLInput (org.opensearch.ml.common.parameter.MLInput) 46, Test (org.junit.Test) 18, MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset) 13, MLTaskResponse (org.opensearch.ml.common.transport.MLTaskResponse) 12, DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset) 11, DataFrame (org.opensearch.ml.common.dataframe.DataFrame) 10, Input (org.opensearch.ml.common.parameter.Input) 9, LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) 9, MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput) 7, MLPredictionTaskRequest (org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest) 7, MLTrainingTaskRequest (org.opensearch.ml.common.transport.training.MLTrainingTaskRequest) 7, MLOutput (org.opensearch.ml.common.parameter.MLOutput) 6, XContentParser (org.opensearch.common.xcontent.XContentParser) 5, Response (org.opensearch.client.Response) 4, Model (org.opensearch.ml.common.parameter.Model) 4, KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame) 4, LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame) 4, LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame) 4, VisibleForTesting (com.google.common.annotations.VisibleForTesting) 3, Instant (java.time.Instant) 3