use of org.opensearch.ml.common.dataset.MLInputDataset in project ml-commons by opensearch-project.
the class MLPredictionTaskRequestTest method writeTo_Success.
/**
 * Round-trips a prediction request through a byte stream and checks the
 * deserialized copy: algorithm, parameters, input data frame contents, and
 * a null model id.
 *
 * @throws IOException if writing to or reading from the stream fails
 */
@Test
public void writeTo_Success() throws IOException {
    // Serialize a request built from the mlInput defined outside this method.
    MLPredictionTaskRequest original = MLPredictionTaskRequest.builder()
        .mlInput(mlInput)
        .build();
    BytesStreamOutput serialized = new BytesStreamOutput();
    original.writeTo(serialized);

    // Reconstruct a request from the bytes that were just written.
    MLPredictionTaskRequest restored = new MLPredictionTaskRequest(serialized.bytes().streamInput());

    // Algorithm and its parameters.
    assertEquals(FunctionName.KMEANS, restored.getMlInput().getAlgorithm());
    KMeansParams kMeansParams = (KMeansParams) restored.getMlInput().getParameters();
    assertEquals(1, kMeansParams.getCentroids().intValue());

    // Input dataset type and the data frame it carries.
    MLInputDataset restoredDataset = restored.getMlInput().getInputDataset();
    assertEquals(MLInputDataType.DATA_FRAME, restoredDataset.getInputDataType());
    DataFrame restoredFrame = ((DataFrameInputDataset) restoredDataset).getDataFrame();
    assertEquals(1, restoredFrame.size());
    assertEquals(1, restoredFrame.columnMetas().length);
    assertEquals("key1", restoredFrame.columnMetas()[0].getName());
    assertEquals(ColumnType.DOUBLE, restoredFrame.columnMetas()[0].getColumnType());
    assertEquals(1, restoredFrame.getRow(0).size());
    assertEquals(2.00, restoredFrame.getRow(0).getValue(0).getValue());

    // No model id was set on the builder, so none is expected after the round trip.
    assertNull(restored.getModelId());
}
use of org.opensearch.ml.common.dataset.MLInputDataset in project ml-commons by opensearch-project.
the class MLEngineTest method predictWithoutModel.
/**
 * Calls {@code MLEngine.predict} with a null model for linear regression and
 * expects an {@link IllegalArgumentException} carrying the
 * "No model found for linear regression prediction." message.
 */
@Test
public void predictWithoutModel() {
    // Register the expected failure before triggering it.
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("No model found for linear regression prediction.");

    MLInputDataset predictionDataset = DataFrameInputDataset.builder()
        .dataFrame(constructLinearRegressionPredictionDataFrame())
        .build();
    Input predictionInput = MLInput.builder()
        .algorithm(FunctionName.LINEAR_REGRESSION)
        .inputDataset(predictionDataset)
        .build();

    // The model argument is deliberately null.
    MLEngine.predict(predictionInput, null);
}
use of org.opensearch.ml.common.dataset.MLInputDataset in project ml-commons by opensearch-project.
the class MLEngineTest method train_UnsupportedAlgorithm.
/**
 * Stubs {@code MLEngineClassLoader.initInstance} to return null for
 * LINEAR_REGRESSION and expects {@code MLEngine.train} to fail with an
 * {@link IllegalArgumentException} whose message is
 * "Unsupported algorithm: LINEAR_REGRESSION".
 */
@Test
public void train_UnsupportedAlgorithm() {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Unsupported algorithm: LINEAR_REGRESSION");

    FunctionName algorithm = FunctionName.LINEAR_REGRESSION;
    // The static mock is only active inside this try block and is closed on exit.
    try (MockedStatic<MLEngineClassLoader> classLoaderMock = Mockito.mockStatic(MLEngineClassLoader.class)) {
        classLoaderMock
            .when(() -> MLEngineClassLoader.initInstance(algorithm, null, MLAlgoParams.class))
            .thenReturn(null);

        MLInputDataset trainingDataset = DataFrameInputDataset.builder()
            .dataFrame(constructKMeansDataFrame(10))
            .build();
        MLInput trainingInput = MLInput.builder()
            .algorithm(algorithm)
            .inputDataset(trainingDataset)
            .build();
        MLEngine.train(trainingInput);
    }
}
use of org.opensearch.ml.common.dataset.MLInputDataset in project ml-commons by opensearch-project.
the class MLEngineTest method predictLinearRegression.
/**
 * Trains a linear regression model via {@code trainLinearRegressionModel()},
 * runs a prediction over the prediction data frame, and checks that the
 * prediction result has size 2.
 */
@Test
public void predictLinearRegression() {
    Model trainedModel = trainLinearRegressionModel();

    DataFrame inputFrame = constructLinearRegressionPredictionDataFrame();
    MLInputDataset predictionDataset = DataFrameInputDataset.builder()
        .dataFrame(inputFrame)
        .build();
    Input predictionInput = MLInput.builder()
        .algorithm(FunctionName.LINEAR_REGRESSION)
        .inputDataset(predictionDataset)
        .build();

    MLPredictionOutput predictionOutput = (MLPredictionOutput) MLEngine.predict(predictionInput, trainedModel);
    Assert.assertEquals(2, predictionOutput.getPredictionResult().size());
}
use of org.opensearch.ml.common.dataset.MLInputDataset in project ml-commons by opensearch-project.
the class MLEngineTest method trainAndPredictWithKmeans.
/**
 * Runs {@code MLEngine.trainAndPredict} with KMEANS over a data frame built by
 * {@code constructKMeansDataFrame(100)} and checks that the prediction result
 * has the same size that was requested.
 */
@Test
public void trainAndPredictWithKmeans() {
    final int expectedSize = 100;

    // Parameters come straight from the builder with nothing set explicitly.
    MLAlgoParams kMeansParams = KMeansParams.builder().build();
    DataFrame trainingFrame = constructKMeansDataFrame(expectedSize);
    MLInputDataset trainingDataset = new DataFrameInputDataset(trainingFrame);
    Input kMeansInput = new MLInput(FunctionName.KMEANS, kMeansParams, trainingDataset);

    MLPredictionOutput predictionOutput = (MLPredictionOutput) MLEngine.trainAndPredict(kMeansInput);
    DataFrame predictionResult = predictionOutput.getPredictionResult();
    Assert.assertEquals(expectedSize, predictionResult.size());
}
Aggregations