Example 21 with MLTaskResponse

Use of org.opensearch.ml.common.transport.MLTaskResponse in project ml-commons by opensearch-project.

Source: class IntegTestUtils, method trainModel.

```java
// Train a model.
public static String trainModel(MLInputDataset inputDataset) throws ExecutionException, InterruptedException {
    MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(inputDataset).build();
    // TODO: support train test in sync way
    MLTrainingTaskRequest trainingRequest = new MLTrainingTaskRequest(mlInput, true);
    ActionFuture<MLTaskResponse> trainingFuture = client().execute(MLTrainingTaskAction.INSTANCE, trainingRequest);
    MLTaskResponse trainingResponse = trainingFuture.actionGet();
    assertNotNull(trainingResponse);
    MLTrainingOutput modelTrainingOutput = (MLTrainingOutput) trainingResponse.getOutput();
    String modelId = modelTrainingOutput.getModelId();
    String status = modelTrainingOutput.getStatus();
    assertNotNull(modelId);
    assertFalse(modelId.isEmpty());
    assertEquals("CREATED", status);
    return modelId;
}
```
Also used: MLTaskResponse (org.opensearch.ml.common.transport.MLTaskResponse), MLTrainingOutput (org.opensearch.ml.common.parameter.MLTrainingOutput), MLInput (org.opensearch.ml.common.parameter.MLInput), MLTrainingTaskRequest (org.opensearch.ml.common.transport.training.MLTrainingTaskRequest)
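
The flow in trainModel (execute returns a future immediately, actionGet() blocks until the response arrives, then the ID and status are checked) can be mimicked with plain JDK types. This is a minimal sketch only: TrainingOutput, submitTraining, and the value "model-123" are hypothetical stand-ins, not ml-commons APIs, and Future.get() plays the role of actionGet().

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;

public class TrainAndWaitSketch {

    // Hypothetical stand-in for the two fields trainModel reads from MLTrainingOutput.
    record TrainingOutput(String modelId, String status) {}

    // Hypothetical stand-in for client().execute(...): starts the work in the
    // background and returns a Future instead of the finished result.
    static Future<TrainingOutput> submitTraining() {
        return CompletableFuture.supplyAsync(() -> new TrainingOutput("model-123", "CREATED"));
    }

    // Mirrors trainModel: block for the response, validate it, return the ID.
    static String trainAndWait() throws ExecutionException, InterruptedException {
        TrainingOutput output = submitTraining().get(); // blocks, like actionGet()
        if (output == null) {
            throw new IllegalStateException("no training response");
        }
        String modelId = output.modelId();
        if (modelId == null || modelId.isEmpty()) {
            throw new IllegalStateException("missing model ID");
        }
        if (!"CREATED".equals(output.status())) {
            throw new IllegalStateException("unexpected status: " + output.status());
        }
        return modelId;
    }

    public static void main(String[] args) throws Exception {
        System.out.println(trainAndWait()); // prints model-123
    }
}
```

Blocking on the future keeps the test linear and easy to read; the trade-off is that the calling thread waits for the whole round trip.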

Example 22 with MLTaskResponse

Use of org.opensearch.ml.common.transport.MLTaskResponse in project ml-commons by opensearch-project.

Source: class IntegTestUtils, method predictAndVerifyResult.

```java
// Predict with the model generated, and verify the prediction result.
public static void predictAndVerifyResult(String taskId, MLInputDataset inputDataset) throws IOException {
    MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(inputDataset).build();
    MLPredictionTaskRequest predictionRequest = new MLPredictionTaskRequest(taskId, mlInput);
    ActionFuture<MLTaskResponse> predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest);
    MLTaskResponse predictionResponse = predictionFuture.actionGet();
    XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
    builder.startObject();
    MLPredictionOutput mlPredictionOutput = (MLPredictionOutput) predictionResponse.getOutput();
    mlPredictionOutput.getPredictionResult().toXContent(builder, ToXContent.EMPTY_PARAMS);
    builder.endObject();
    String jsonStr = Strings.toString(builder);
    String expectedStr1 = "{\"column_metas\":[{\"name\":\"ClusterID\",\"column_type\":\"INTEGER\"}],"
        + "\"rows\":[{\"values\":[{\"column_type\":\"INTEGER\",\"value\":0}]}]}";
    String expectedStr2 = "{\"column_metas\":[{\"name\":\"ClusterID\",\"column_type\":\"INTEGER\"}],"
        + "\"rows\":[{\"values\":[{\"column_type\":\"INTEGER\",\"value\":1}]}]}";
    // The prediction result would not be a fixed value.
    assertTrue(expectedStr1.equals(jsonStr) || expectedStr2.equals(jsonStr));
}
```
Also used: MLTaskResponse (org.opensearch.ml.common.transport.MLTaskResponse), MLInput (org.opensearch.ml.common.parameter.MLInput), MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput), MLPredictionTaskRequest (org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest), XContentBuilder (org.opensearch.common.xcontent.XContentBuilder)
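
The final assertion in predictAndVerifyResult accepts two JSON strings because, as the test's own comment notes, the cluster ID assigned to the predicted row is not fixed (k-means cluster labels generally depend on initialization). A minimal JDK-only sketch of the same check, generalized to a list of candidate IDs; expectedPrediction and matchesAnyCluster are hypothetical helpers, not part of ml-commons.

```java
import java.util.List;

public class ClusterIdCheckSketch {

    // Hypothetical helper: the JSON the test expects for a single predicted
    // row assigned to the given cluster ID (same shape as expectedStr1/expectedStr2).
    static String expectedPrediction(int clusterId) {
        return "{\"column_metas\":[{\"name\":\"ClusterID\",\"column_type\":\"INTEGER\"}],"
            + "\"rows\":[{\"values\":[{\"column_type\":\"INTEGER\",\"value\":" + clusterId + "}]}]}";
    }

    // Accepts the result if it equals the expected JSON for any candidate ID,
    // e.g. 0 and 1, the two IDs the original test accepts.
    static boolean matchesAnyCluster(String actualJson, List<Integer> candidateIds) {
        return candidateIds.stream()
            .map(ClusterIdCheckSketch::expectedPrediction)
            .anyMatch(actualJson::equals);
    }

    public static void main(String[] args) {
        String actual = expectedPrediction(1);
        System.out.println(matchesAnyCluster(actual, List.of(0, 1))); // prints true
        System.out.println(matchesAnyCluster(actual, List.of(0)));    // prints false
    }
}
```

Building the expected strings from the candidate IDs avoids duplicating the long JSON literal once per acceptable cluster label.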

Aggregations (classes used alongside MLTaskResponse across the examples, with usage counts):

- MLTaskResponse (org.opensearch.ml.common.transport.MLTaskResponse): 22
- MLInput (org.opensearch.ml.common.parameter.MLInput): 13
- Test (org.junit.Test): 9
- MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput): 8
- MLTrainingOutput (org.opensearch.ml.common.parameter.MLTrainingOutput): 7
- MLPredictionTaskRequest (org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest): 6
- DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset): 5
- MLTrainingTaskRequest (org.opensearch.ml.common.transport.training.MLTrainingTaskRequest): 5
- HashMap (java.util.HashMap): 4
- ActionListener (org.opensearch.action.ActionListener): 4
- BytesStreamOutput (org.opensearch.common.io.stream.BytesStreamOutput): 4
- XContentBuilder (org.opensearch.common.xcontent.XContentBuilder): 4
- MLOutput (org.opensearch.ml.common.parameter.MLOutput): 4
- ThreadContext (org.opensearch.common.util.concurrent.ThreadContext): 3
- DataFrame (org.opensearch.ml.common.dataframe.DataFrame): 3
- MLModel (org.opensearch.ml.common.parameter.MLModel): 3
- Model (org.opensearch.ml.common.parameter.Model): 3
- Instant (java.time.Instant): 2
- UUID (java.util.UUID): 2
- Log4j2 (lombok.extern.log4j.Log4j2): 2