Example use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.
From class MachineLearningClientTest, method predict_WithAlgoAndInputDataAndParametersAndListener:
/**
 * Calls the three-argument {@code predict} (null first argument, a KMEANS {@link MLInput},
 * and the mock listener) and checks that the value passed to the listener's
 * {@code onResponse} equals the {@code output} fixture.
 */
@Test
public void predict_WithAlgoAndInputDataAndParametersAndListener() {
    // Request under test: KMEANS with the shared parameters and the data-frame fixture.
    MLInput request = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .parameters(mlParameters)
        .inputDataset(new DataFrameInputDataset(input))
        .build();

    machineLearningClient.predict(null, request, dataFrameActionListener);

    // Capture whatever the client handed to the listener and compare it with the fixture.
    ArgumentCaptor<MLOutput> responseCaptor = ArgumentCaptor.forClass(MLOutput.class);
    verify(dataFrameActionListener).onResponse(responseCaptor.capture());
    assertEquals(output, responseCaptor.getValue());
}
Example use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.
From class MachineLearningClientTest, method predict_WithAlgoAndParametersAndInputData:
/**
 * Calls the two-argument {@code predict} (null first argument and a KMEANS {@link MLInput}),
 * invokes {@code actionGet()} on what it returns, and checks the result equals the
 * {@code output} fixture.
 */
@Test
public void predict_WithAlgoAndParametersAndInputData() {
    // Same request shape as the listener variant: algorithm, parameters, data-frame dataset.
    MLInput request = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .parameters(mlParameters)
        .inputDataset(new DataFrameInputDataset(input))
        .build();

    assertEquals(output, machineLearningClient.predict(null, request).actionGet());
}
Example use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.
From class MLEngineTest, method predictWithoutAlgoName:
/**
 * Registers an expected {@link IllegalArgumentException} with the message
 * "algorithm can't be null", then builds an {@link MLInput} that has a dataset but no
 * algorithm and passes it to {@code MLEngine.predict} with a null second argument.
 *
 * NOTE(review): the visible code does not show whether the exception comes from
 * {@code build()} or from {@code predict}; the rule accepts either.
 */
@Test
public void predictWithoutAlgoName() {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("algorithm can't be null");

    // Dataset comes from the KMeans data-frame helper, called with 10.
    MLInputDataset kmeansDataset = DataFrameInputDataset.builder()
        .dataFrame(constructKMeansDataFrame(10))
        .build();
    // Deliberately no .algorithm(...) call on the builder.
    Input inputWithoutAlgorithm = MLInput.builder()
        .inputDataset(kmeansDataset)
        .build();

    MLEngine.predict(inputWithoutAlgorithm, null);
}
Example use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.
From class MLEngineTest, method predictUnsupportedAlgorithm:
/**
 * Registers an expected {@link IllegalArgumentException} with the message
 * "Unsupported algorithm: LINEAR_REGRESSION", stubs
 * {@code MLEngineClassLoader.initInstance(LINEAR_REGRESSION, null, MLAlgoParams.class)}
 * to return null, and then calls {@code MLEngine.predict} with a LINEAR_REGRESSION input
 * and a null second argument.
 */
@Test
public void predictUnsupportedAlgorithm() {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Unsupported algorithm: LINEAR_REGRESSION");

    // Kept in a local (never reassigned) because the stubbing lambda below captures it.
    FunctionName unsupportedAlgorithm = FunctionName.LINEAR_REGRESSION;

    // The static mock is closed by try-with-resources when the block exits.
    try (MockedStatic<MLEngineClassLoader> classLoaderMock = Mockito.mockStatic(MLEngineClassLoader.class)) {
        classLoaderMock
            .when(() -> MLEngineClassLoader.initInstance(unsupportedAlgorithm, null, MLAlgoParams.class))
            .thenReturn(null);

        MLInputDataset predictionDataset = DataFrameInputDataset.builder()
            .dataFrame(constructLinearRegressionPredictionDataFrame())
            .build();
        Input predictionInput = MLInput.builder()
            .algorithm(unsupportedAlgorithm)
            .inputDataset(predictionDataset)
            .build();

        MLEngine.predict(predictionInput, null);
    }
}
Example use of org.opensearch.ml.common.parameter.MLInput in the ml-commons project by opensearch-project.
From class MLEngineTest, method trainAndPredictWithInvalidInput:
/**
 * Registers an expected {@link IllegalArgumentException} with the message
 * "Input should be MLInput", then passes a {@code LocalSampleCalculatorInput}
 * to {@code MLEngine.trainAndPredict}.
 */
@Test
public void trainAndPredictWithInvalidInput() {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Input should be MLInput");

    // An Input built from a different class than MLInput: a "sum" calculator input over 1.0 and 2.0.
    Input calculatorInput = new LocalSampleCalculatorInput(
        "sum",
        Arrays.asList(1.0, 2.0)
    );

    MLEngine.trainAndPredict(calculatorInput);
}
Aggregations