Usage of org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest in the ml-commons project (opensearch-project).
Class RestMLPredictionAction, method getRequest:
/**
 * Creates an {@link MLPredictionTaskRequest} from a {@link RestRequest}.
 *
 * <p>The algorithm name comes from {@code getAlgorithm(request)} and the model id from the
 * {@code PARAMETER_MODEL_ID} parameter. The request body is parsed into an {@link MLInput} for
 * that algorithm.
 *
 * @param request the incoming REST request; its body must start with an object
 * @return the prediction task request combining the model id and the parsed input
 * @throws IOException propagated from reading/parsing the request body
 */
@VisibleForTesting
MLPredictionTaskRequest getRequest(RestRequest request) throws IOException {
    String algorithm = getAlgorithm(request);
    String modelId = getParameterId(request, PARAMETER_MODEL_ID);
    // XContentParser is Closeable; release it once the body has been fully consumed,
    // including when parsing fails part-way through.
    try (XContentParser parser = request.contentParser()) {
        // Reject bodies that do not begin with an object before handing off to MLInput.parse.
        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
        MLInput mlInput = MLInput.parse(parser, algorithm);
        return new MLPredictionTaskRequest(modelId, mlInput);
    }
}
Usage of org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest in the ml-commons project (opensearch-project).
Class TransportPredictionTaskAction, method doExecute:
/**
 * Converts the generic transport request into an {@link MLPredictionTaskRequest} and delegates
 * it to {@code mlPredictTaskRunner}, passing along the transport service and {@code listener}.
 *
 * @param task the task associated with this action (not used here)
 * @param request the incoming transport request to convert
 * @param listener the listener handed to the prediction task runner
 */
@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<MLTaskResponse> listener) {
    mlPredictTaskRunner.run(MLPredictionTaskRequest.fromActionRequest(request), transportService, listener);
}
Usage of org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest in the ml-commons project (opensearch-project).
Class IntegTestUtils, method predictAndVerifyResult:
// Predict with the model generated, and verify the prediction result.
public static void predictAndVerifyResult(String taskId, MLInputDataset inputDataset) throws IOException {
MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(inputDataset).build();
MLPredictionTaskRequest predictionRequest = new MLPredictionTaskRequest(taskId, mlInput);
ActionFuture<MLTaskResponse> predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest);
MLTaskResponse predictionResponse = predictionFuture.actionGet();
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
builder.startObject();
MLPredictionOutput mlPredictionOutput = (MLPredictionOutput) predictionResponse.getOutput();
mlPredictionOutput.getPredictionResult().toXContent(builder, ToXContent.EMPTY_PARAMS);
builder.endObject();
String jsonStr = Strings.toString(builder);
String expectedStr1 = "{\"column_metas\":[{\"name\":\"ClusterID\",\"column_type\":\"INTEGER\"}]," + "\"rows\":[{\"values\":[{\"column_type\":\"INTEGER\",\"value\":0}]}]}";
String expectedStr2 = "{\"column_metas\":[{\"name\":\"ClusterID\",\"column_type\":\"INTEGER\"}]," + "\"rows\":[{\"values\":[{\"column_type\":\"INTEGER\",\"value\":1}]}]}";
// The prediction result would not be a fixed value.
assertTrue(expectedStr1.equals(jsonStr) || expectedStr2.equals(jsonStr));
}
Aggregations