Usage of org.opensearch.ml.common.parameter.MLPredictionOutput in the ml-commons project (opensearch-project).
Class MachineLearningNodeClientTest, method predict.
@SuppressWarnings("unchecked")
@Test
public void predict() {
    // Stub the transport client: any prediction request is answered immediately by
    // handing a canned successful MLPredictionOutput to the listener (3rd argument).
    doAnswer(invocation -> {
        ActionListener<MLTaskResponse> listener = invocation.getArgument(2);
        MLPredictionOutput cannedOutput = MLPredictionOutput.builder()
            .taskId("taskId")
            .status("Success")
            .predictionResult(output)
            .build();
        listener.onResponse(MLTaskResponse.builder().output(cannedOutput).build());
        return null;
    }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());

    MLInput predictInput = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .inputDataset(input)
        .build();
    machineLearningNodeClient.predict(null, predictInput, dataFrameActionListener);

    // The client must have dispatched a prediction request and relayed the stubbed output.
    ArgumentCaptor<MLOutput> outputCaptor = ArgumentCaptor.forClass(MLOutput.class);
    verify(client).execute(eq(MLPredictionTaskAction.INSTANCE), isA(MLPredictionTaskRequest.class), any(ActionListener.class));
    verify(dataFrameActionListener).onResponse(outputCaptor.capture());
    MLPredictionOutput received = (MLPredictionOutput) outputCaptor.getValue();
    assertEquals(output, received.getPredictionResult());
}
Usage of org.opensearch.ml.common.parameter.MLPredictionOutput in the ml-commons project (opensearch-project).
Class MLPredictionTaskResponseTest, method fromActionResponse_WithNonMLPredictionTaskResponse.
/**
 * Verifies that {@code MLTaskResponse.fromActionResponse} rebuilds an equivalent
 * response, as a new object, when given a generic {@link ActionResponse} whose
 * serialized form is that of an {@code MLTaskResponse}.
 */
@Test
public void fromActionResponse_WithNonMLPredictionTaskResponse() {
    // Explicitly typed single-row input. This avoids double-brace initialization, which
    // creates an anonymous HashMap subclass that captures the enclosing test instance.
    MLPredictionOutput output = MLPredictionOutput.builder()
        .taskId("taskId")
        .status("Success")
        .predictionResult(DataFrameBuilder.load(
            Collections.singletonList(Collections.<String, Object>singletonMap("key1", 2.0D))))
        .build();
    MLTaskResponse response = MLTaskResponse.builder().output(output).build();

    // A plain ActionResponse that only knows how to serialize itself as `response`,
    // so fromActionResponse has to take the serialize/deserialize path.
    ActionResponse actionResponse = new ActionResponse() {
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            response.writeTo(out);
        }
    };

    MLTaskResponse result = MLTaskResponse.fromActionResponse(actionResponse);

    assertNotSame(response, result);
    MLPredictionOutput expected = (MLPredictionOutput) response.getOutput();
    MLPredictionOutput actual = (MLPredictionOutput) result.getOutput();
    assertEquals(expected.getTaskId(), actual.getTaskId());
    assertEquals(expected.getStatus(), actual.getStatus());
    assertEquals(expected.getPredictionResult().size(), actual.getPredictionResult().size());
}
Usage of org.opensearch.ml.common.parameter.MLPredictionOutput in the ml-commons project (opensearch-project).
Class MLPredictionTaskResponseTest, method fromActionResponse_WithMLPredictionTaskResponse.
/**
 * Verifies that {@code MLTaskResponse.fromActionResponse} returns its argument
 * unchanged (same instance) when it is already an {@code MLTaskResponse}.
 */
@Test
public void fromActionResponse_WithMLPredictionTaskResponse() {
    // Explicitly typed single-row input; avoids double-brace initialization, which
    // creates an anonymous HashMap subclass holding a reference to the test instance.
    MLPredictionOutput output = MLPredictionOutput.builder()
        .taskId("taskId")
        .status("Success")
        .predictionResult(DataFrameBuilder.load(
            Collections.singletonList(Collections.<String, Object>singletonMap("key1", 2.0D))))
        .build();
    MLTaskResponse response = MLTaskResponse.builder().output(output).build();

    assertSame(response, MLTaskResponse.fromActionResponse(response));
}
Usage of org.opensearch.ml.common.parameter.MLPredictionOutput in the ml-commons project (opensearch-project).
Class MLPredictionTaskResponseTest, method writeTo_Success.
/**
 * Round-trips an {@code MLTaskResponse} through {@code writeTo} and the
 * {@code StreamInput} constructor. Checks that task id, status and prediction row
 * count survive serialization.
 */
@Test
public void writeTo_Success() throws IOException {
    // Explicitly typed single-row input; avoids double-brace initialization, which
    // creates an anonymous HashMap subclass holding a reference to the test instance.
    MLPredictionOutput output = MLPredictionOutput.builder()
        .taskId("taskId")
        .status("Success")
        .predictionResult(DataFrameBuilder.load(
            Collections.singletonList(Collections.<String, Object>singletonMap("key1", 2.0D))))
        .build();
    MLTaskResponse response = MLTaskResponse.builder().output(output).build();

    BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
    response.writeTo(bytesStreamOutput);
    // Keep the original and the deserialized copy in separate locals so it is clear
    // which object the assertions inspect.
    MLTaskResponse deserialized = new MLTaskResponse(bytesStreamOutput.bytes().streamInput());

    MLPredictionOutput mlPredictionOutput = (MLPredictionOutput) deserialized.getOutput();
    assertEquals("taskId", mlPredictionOutput.getTaskId());
    assertEquals("Success", mlPredictionOutput.getStatus());
    assertEquals(1, mlPredictionOutput.getPredictionResult().size());
}
Usage of org.opensearch.ml.common.parameter.MLPredictionOutput in the ml-commons project (opensearch-project).
Class KMeansTest, method trainAndPredict.
/**
 * Trains and predicts with KMeans once per supported distance type. For every
 * type it checks that one prediction is produced per training row and that each
 * predicted cluster id lies in {@code [0, centroids)}.
 *
 * <p>The original version only checked cluster ids for EUCLIDEAN. COSINE and L1
 * were checked for row count alone, so an invalid label from those paths went unnoticed.
 */
@Test
public void trainAndPredict() {
    final int centroids = 2;
    KMeansParams baseParameters = KMeansParams.builder()
        .distanceType(KMeansParams.DistanceType.EUCLIDEAN)
        .iterations(10)
        .centroids(centroids)
        .build();
    KMeansParams.DistanceType[] distanceTypes = {
        KMeansParams.DistanceType.EUCLIDEAN,
        KMeansParams.DistanceType.COSINE,
        KMeansParams.DistanceType.L1
    };

    for (KMeansParams.DistanceType distanceType : distanceTypes) {
        // Only the distance type varies; iterations/centroids come from baseParameters.
        KMeansParams parameters = baseParameters.toBuilder().distanceType(distanceType).build();
        KMeans kMeans = new KMeans(parameters);
        MLPredictionOutput output = (MLPredictionOutput) kMeans.trainAndPredict(trainDataFrame);
        DataFrame predictions = output.getPredictionResult();

        Assert.assertEquals("prediction count for " + distanceType, trainSize, predictions.size());
        predictions.forEach(row -> {
            int cluster = row.getValue(0).intValue();
            Assert.assertTrue(
                "cluster id " + cluster + " out of range for " + distanceType,
                cluster >= 0 && cluster < centroids);
        });
    }
}
Aggregations