Use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.
From the class MLPredictTaskRunner, method startPredictionTask.
/**
 * Creates a prediction task for the request and runs the prediction on the
 * task thread pool.
 *
 * <p>The task is built with state {@code CREATED}, marked as non-async, and
 * bound to the local node. If the input dataset is a search query, the query
 * is first resolved into a {@link DataFrame} through a listener; a failure
 * there is logged, passed to {@code handleAsyncMLTaskFailure} and reported to
 * {@code listener}. For any other input type the {@link DataFrame} is parsed
 * from the input dataset on the calling thread before the prediction is
 * submitted to the task thread pool.
 *
 * @param request  prediction request carrying the model id and the ML input
 * @param listener listener handed to {@code predict}, and notified directly
 *                 when the search query cannot be turned into a DataFrame
 */
public void startPredictionTask(MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener) {
    MLInput mlInput = request.getMlInput();
    MLInputDataType inputDataType = mlInput.getInputDataset().getInputDataType();
    Instant createdAt = Instant.now();
    MLTask mlTask = MLTask
        .builder()
        .taskId(UUID.randomUUID().toString())
        .modelId(request.getModelId())
        .taskType(MLTaskType.PREDICTION)
        .inputType(inputDataType)
        .functionName(mlInput.getFunctionName())
        .state(MLTaskState.CREATED)
        .workerNode(clusterService.localNode().getId())
        .createTime(createdAt)
        .lastUpdateTime(createdAt)
        .async(false)
        .build();

    // Any input other than a search query already carries its data frame.
    if (!inputDataType.equals(MLInputDataType.SEARCH_QUERY)) {
        DataFrame inputDataFrame = mlInputDatasetHandler.parseDataFrameInput(mlInput.getInputDataset());
        threadPool.executor(TASK_THREAD_POOL).execute(() -> predict(mlTask, inputDataFrame, request, listener));
        return;
    }

    // Search query input: the data frame arrives through a listener.
    ActionListener<DataFrame> onDataFrameReady = ActionListener.wrap(
        resolvedDataFrame -> predict(mlTask, resolvedDataFrame, request, listener),
        failure -> {
            log.error("Failed to generate DataFrame from search query", failure);
            handleAsyncMLTaskFailure(mlTask, failure);
            listener.onFailure(failure);
        }
    );
    mlInputDatasetHandler.parseSearchQueryInput(
        mlInput.getInputDataset(),
        new ThreadedActionListener<>(log, threadPool, TASK_THREAD_POOL, onDataFrameReady, false)
    );
}
Use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.
From the class MLEngineTest, method predictLinearRegression.
/**
 * Trains a linear regression model, runs a prediction with it on the
 * prediction data frame, and checks that the result contains two rows.
 */
@Test
public void predictLinearRegression() {
    Model trainedModel = trainLinearRegressionModel();

    MLInputDataset predictionDataset = DataFrameInputDataset
        .builder()
        .dataFrame(constructLinearRegressionPredictionDataFrame())
        .build();
    Input predictionInput = MLInput
        .builder()
        .algorithm(FunctionName.LINEAR_REGRESSION)
        .inputDataset(predictionDataset)
        .build();

    MLPredictionOutput predictionOutput = (MLPredictionOutput) MLEngine.predict(predictionInput, trainedModel);
    DataFrame predictedRows = predictionOutput.getPredictionResult();
    Assert.assertEquals(2, predictedRows.size());
}
Use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.
From the class MLEngineTest, method trainAndPredictWithKmeans.
/**
 * Runs train-and-predict with KMeans using default-built parameters on a
 * 100-row data frame and checks that the prediction result has one row per
 * input row.
 */
@Test
public void trainAndPredictWithKmeans() {
    final int rowCount = 100;

    MLAlgoParams kmeansParams = KMeansParams.builder().build();
    DataFrame kmeansDataFrame = constructKMeansDataFrame(rowCount);
    Input trainAndPredictInput = new MLInput(
        FunctionName.KMEANS,
        kmeansParams,
        new DataFrameInputDataset(kmeansDataFrame)
    );

    MLPredictionOutput predictionOutput = (MLPredictionOutput) MLEngine.trainAndPredict(trainAndPredictInput);
    Assert.assertEquals(rowCount, predictionOutput.getPredictionResult().size());
}
Use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.
From the class MLEngineTest, method trainLinearRegressionModel.
/**
 * Trains a linear regression model on the training data frame, using
 * squared loss, the ADAM optimizer and {@code "price"} as the target column.
 *
 * @return the model returned by {@code MLEngine.train}
 */
private Model trainLinearRegressionModel() {
    LinearRegressionParams regressionParams = LinearRegressionParams
        .builder()
        .objectiveType(LinearRegressionParams.ObjectiveType.SQUARED_LOSS)
        .optimizerType(LinearRegressionParams.OptimizerType.ADAM)
        .learningRate(0.01)
        .epsilon(1e-6)
        .beta1(0.9)
        .beta2(0.99)
        .target("price")
        .build();

    DataFrame trainingRows = constructLinearRegressionTrainDataFrame();
    MLInputDataset trainingDataset = DataFrameInputDataset.builder().dataFrame(trainingRows).build();
    Input trainingInput = MLInput
        .builder()
        .algorithm(FunctionName.LINEAR_REGRESSION)
        .parameters(regressionParams)
        .inputDataset(trainingDataset)
        .build();

    return MLEngine.train(trainingInput);
}
Use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.
From the class MLEngineTest, method predictKMeans.
/**
 * Trains a KMeans model, predicts on a 10-row data frame, and checks that
 * ten predictions come back and that the first value of every predicted row
 * is either 0 or 1.
 */
@Test
public void predictKMeans() {
    Model trainedModel = trainKMeansModel();

    DataFrame kmeansRows = constructKMeansDataFrame(10);
    MLInputDataset predictionDataset = DataFrameInputDataset.builder().dataFrame(kmeansRows).build();
    Input predictionInput = MLInput
        .builder()
        .algorithm(FunctionName.KMEANS)
        .inputDataset(predictionDataset)
        .build();

    DataFrame predictedRows = ((MLPredictionOutput) MLEngine.predict(predictionInput, trainedModel)).getPredictionResult();
    Assert.assertEquals(10, predictedRows.size());
    predictedRows.forEach(predictedRow -> {
        int clusterId = predictedRow.getValue(0).intValue();
        Assert.assertTrue(clusterId == 0 || clusterId == 1);
    });
}
Aggregations