Search in sources :

Example 6 with MLPredictionOutput

Use of org.opensearch.ml.common.parameter.MLPredictionOutput in the ml-commons project by opensearch-project.

From class MachineLearningNodeClientTest, method predict.

@SuppressWarnings("unchecked")
@Test
public void predict() {
    // Stub the transport client: when the prediction action is executed, complete the
    // listener passed as the third argument with a successful prediction wrapping `output`.
    doAnswer(invocation -> {
        MLPredictionOutput stubbedOutput = MLPredictionOutput.builder()
            .status("Success")
            .predictionResult(output)
            .taskId("taskId")
            .build();
        MLTaskResponse stubbedResponse = MLTaskResponse.builder().output(stubbedOutput).build();
        ActionListener<MLTaskResponse> listener = invocation.getArgument(2);
        listener.onResponse(stubbedResponse);
        return null;
    }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());

    MLInput predictInput = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .inputDataset(input)
        .build();
    ArgumentCaptor<MLOutput> outputCaptor = ArgumentCaptor.forClass(MLOutput.class);

    machineLearningNodeClient.predict(null, predictInput, dataFrameActionListener);

    // The client must dispatch an MLPredictionTaskRequest and relay the stubbed output unchanged.
    verify(client).execute(eq(MLPredictionTaskAction.INSTANCE), isA(MLPredictionTaskRequest.class), any(ActionListener.class));
    verify(dataFrameActionListener).onResponse(outputCaptor.capture());
    MLPredictionOutput received = (MLPredictionOutput) outputCaptor.getValue();
    assertEquals(output, received.getPredictionResult());
}
Also used : MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLInput(org.opensearch.ml.common.parameter.MLInput) ActionListener(org.opensearch.action.ActionListener) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput) MLOutput(org.opensearch.ml.common.parameter.MLOutput) MLPredictionTaskRequest(org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest) Test(org.junit.Test)

Example 7 with MLPredictionOutput

Use of org.opensearch.ml.common.parameter.MLPredictionOutput in the ml-commons project by opensearch-project.

From class MLPredictionTaskResponseTest, method fromActionResponse_WithNonMLPredictionTaskResponse.

@Test
public void fromActionResponse_WithNonMLPredictionTaskResponse() {
    // Build the single input row explicitly instead of via double-brace initialization, which
    // would create an anonymous HashMap subclass holding a hidden reference to this test instance.
    HashMap<String, Object> row = new HashMap<>();
    row.put("key1", 2.0D);
    MLPredictionOutput output = MLPredictionOutput.builder().taskId("taskId").status("Success").predictionResult(DataFrameBuilder.load(Collections.singletonList(row))).build();
    MLTaskResponse response = MLTaskResponse.builder().output(output).build();
    // A plain ActionResponse (not an MLTaskResponse) that only knows how to serialize the real
    // response; the test expects fromActionResponse to yield a distinct MLTaskResponse whose
    // fields match the original.
    ActionResponse actionResponse = new ActionResponse() {

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            response.writeTo(out);
        }
    };
    MLTaskResponse result = MLTaskResponse.fromActionResponse(actionResponse);
    assertNotSame(response, result);
    MLPredictionOutput mlPredictionOutput = (MLPredictionOutput) response.getOutput();
    MLPredictionOutput resultMlPredictionOutput = (MLPredictionOutput) result.getOutput();
    assertEquals(mlPredictionOutput.getTaskId(), resultMlPredictionOutput.getTaskId());
    assertEquals(mlPredictionOutput.getStatus(), resultMlPredictionOutput.getStatus());
    assertEquals(mlPredictionOutput.getPredictionResult().size(), resultMlPredictionOutput.getPredictionResult().size());
}
Also used : MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) HashMap(java.util.HashMap) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput) StreamOutput(org.opensearch.common.io.stream.StreamOutput) BytesStreamOutput(org.opensearch.common.io.stream.BytesStreamOutput) ActionResponse(org.opensearch.action.ActionResponse) Test(org.junit.Test)

Example 8 with MLPredictionOutput

Use of org.opensearch.ml.common.parameter.MLPredictionOutput in the ml-commons project by opensearch-project.

From class MLPredictionTaskResponseTest, method fromActionResponse_WithMLPredictionTaskResponse.

@Test
public void fromActionResponse_WithMLPredictionTaskResponse() {
    // Build the single input row explicitly instead of via double-brace initialization, which
    // would create an anonymous HashMap subclass holding a hidden reference to this test instance.
    HashMap<String, Object> row = new HashMap<>();
    row.put("key1", 2.0D);
    MLPredictionOutput output = MLPredictionOutput.builder().taskId("taskId").status("Success").predictionResult(DataFrameBuilder.load(Collections.singletonList(row))).build();
    MLTaskResponse response = MLTaskResponse.builder().output(output).build();
    // When the argument is already an MLTaskResponse, the same instance is expected back.
    assertSame(response, MLTaskResponse.fromActionResponse(response));
}
Also used : MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) HashMap(java.util.HashMap) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput) Test(org.junit.Test)

Example 9 with MLPredictionOutput

Use of org.opensearch.ml.common.parameter.MLPredictionOutput in the ml-commons project by opensearch-project.

From class MLPredictionTaskResponseTest, method writeTo_Success.

@Test
public void writeTo_Success() throws IOException {
    // Build the single input row explicitly instead of via double-brace initialization, which
    // would create an anonymous HashMap subclass holding a hidden reference to this test instance.
    HashMap<String, Object> row = new HashMap<>();
    row.put("key1", 2.0D);
    MLPredictionOutput output = MLPredictionOutput.builder().taskId("taskId").status("Success").predictionResult(DataFrameBuilder.load(Collections.singletonList(row))).build();
    MLTaskResponse response = MLTaskResponse.builder().output(output).build();
    // Round-trip: serialize the response, then read it back through the StreamInput constructor.
    // A separate variable keeps it obvious that the assertions target the deserialized copy.
    BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
    response.writeTo(bytesStreamOutput);
    MLTaskResponse parsedResponse = new MLTaskResponse(bytesStreamOutput.bytes().streamInput());
    MLPredictionOutput mlPredictionOutput = (MLPredictionOutput) parsedResponse.getOutput();
    assertEquals("taskId", mlPredictionOutput.getTaskId());
    assertEquals("Success", mlPredictionOutput.getStatus());
    assertEquals(1, mlPredictionOutput.getPredictionResult().size());
}
Also used : MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) HashMap(java.util.HashMap) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput) BytesStreamOutput(org.opensearch.common.io.stream.BytesStreamOutput) Test(org.junit.Test)

Example 10 with MLPredictionOutput

Use of org.opensearch.ml.common.parameter.MLPredictionOutput in the ml-commons project by opensearch-project.

From class KMeansTest, method trainAndPredict.

@Test
public void trainAndPredict() {
    int centroids = 2;
    KMeansParams baseParameters = KMeansParams.builder().distanceType(KMeansParams.DistanceType.EUCLIDEAN).iterations(10).centroids(centroids).build();
    KMeansParams.DistanceType[] distanceTypes = {
        KMeansParams.DistanceType.EUCLIDEAN,
        KMeansParams.DistanceType.COSINE,
        KMeansParams.DistanceType.L1
    };
    // Apply the same checks to every distance type. Previously only EUCLIDEAN verified that the
    // predicted cluster ids fall within [0, centroids).
    for (KMeansParams.DistanceType distanceType : distanceTypes) {
        KMeansParams parameters = baseParameters.toBuilder().distanceType(distanceType).build();
        KMeans kMeans = new KMeans(parameters);
        MLPredictionOutput output = (MLPredictionOutput) kMeans.trainAndPredict(trainDataFrame);
        DataFrame predictions = output.getPredictionResult();
        Assert.assertEquals("prediction count for " + distanceType, trainSize, predictions.size());
        predictions.forEach(row -> {
            int clusterId = row.getValue(0).intValue();
            Assert.assertTrue("cluster id " + clusterId + " out of range for " + distanceType, clusterId >= 0 && clusterId < centroids);
        });
    }
}
Also used : KMeansParams(org.opensearch.ml.common.parameter.KMeansParams) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) KMeansHelper.constructKMeansDataFrame(org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame) Test(org.junit.Test)

Aggregations

MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput)19 Test (org.junit.Test)16 DataFrame (org.opensearch.ml.common.dataframe.DataFrame)8 Model (org.opensearch.ml.common.parameter.Model)8 MLTaskResponse (org.opensearch.ml.common.transport.MLTaskResponse)8 MLInput (org.opensearch.ml.common.parameter.MLInput)7 KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame)5 HashMap (java.util.HashMap)4 LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame)4 LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame)4 DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset)3 MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset)3 Input (org.opensearch.ml.common.parameter.Input)3 LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput)3 MLOutput (org.opensearch.ml.common.parameter.MLOutput)3 BytesStreamOutput (org.opensearch.common.io.stream.BytesStreamOutput)2 XContentBuilder (org.opensearch.common.xcontent.XContentBuilder)2 DefaultDataFrame (org.opensearch.ml.common.dataframe.DefaultDataFrame)2 MLPredictionTaskRequest (org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest)2 OpenSearchException (org.opensearch.OpenSearchException)1