Use of org.opensearch.ml.common.parameter.FunctionName in the ml-commons project by opensearch-project.
From the class MLCommonsClassLoader, method loadMLAlgoParameterClassMapping.
/**
 * Populates {@code parameterClassMap} from annotated classes found under the
 * {@code org.opensearch.ml.common.parameter} package prefix.
 *
 * <p>Classes annotated with {@code MLAlgoParameter} are registered once per
 * algorithm listed in the annotation; classes annotated with
 * {@code MLAlgoOutput} are registered under the annotation's output type.
 */
private static void loadMLAlgoParameterClassMapping() {
    Reflections scanner = new Reflections("org.opensearch.ml.common.parameter");

    // Parameter classes: one map entry per algorithm named by the annotation.
    for (Class<?> parameterClass : scanner.getTypesAnnotatedWith(MLAlgoParameter.class)) {
        FunctionName[] supported = parameterClass.getAnnotation(MLAlgoParameter.class).algorithms();
        if (supported == null) {
            continue;
        }
        for (FunctionName functionName : supported) {
            parameterClassMap.put(functionName, parameterClass);
        }
    }

    // Output classes: keyed by the output type declared on the annotation.
    for (Class<?> outputClass : scanner.getTypesAnnotatedWith(MLAlgoOutput.class)) {
        MLOutputType outputType = outputClass.getAnnotation(MLAlgoOutput.class).value();
        if (outputType == null) {
            continue;
        }
        parameterClassMap.put(outputType, outputClass);
    }
}
Use of org.opensearch.ml.common.parameter.FunctionName in the ml-commons project by opensearch-project.
From the class MLEngineTest, method train_NullInput.
/**
 * Expects {@code MLEngine.train(null)} to fail with an
 * {@link IllegalArgumentException} carrying "Input should not be null".
 */
@Test
public void train_NullInput() {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Input should not be null");

    final FunctionName algorithm = FunctionName.LINEAR_REGRESSION;
    // Static mock is scoped to the try block and released on exit.
    try (MockedStatic<MLEngineClassLoader> classLoaderMock = Mockito.mockStatic(MLEngineClassLoader.class)) {
        classLoaderMock
            .when(() -> MLEngineClassLoader.initInstance(algorithm, null, MLAlgoParams.class))
            .thenReturn(null);

        MLEngine.train(null);
    }
}
Use of org.opensearch.ml.common.parameter.FunctionName in the ml-commons project by opensearch-project.
From the class MLEngineTest, method train_NullDataFrame.
/**
 * Expects training with an input that has an algorithm but no dataset to fail
 * with an {@link IllegalArgumentException} carrying
 * "Input data frame should not be null or empty".
 */
@Test
public void train_NullDataFrame() {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Input data frame should not be null or empty");

    final FunctionName algorithm = FunctionName.LINEAR_REGRESSION;
    // Static mock is scoped to the try block and released on exit.
    try (MockedStatic<MLEngineClassLoader> classLoaderMock = Mockito.mockStatic(MLEngineClassLoader.class)) {
        classLoaderMock
            .when(() -> MLEngineClassLoader.initInstance(algorithm, null, MLAlgoParams.class))
            .thenReturn(null);

        // Input deliberately built without an input dataset.
        MLInput inputWithoutDataset = MLInput.builder().algorithm(algorithm).build();
        MLEngine.train(inputWithoutDataset);
    }
}
Use of org.opensearch.ml.common.parameter.FunctionName in the ml-commons project by opensearch-project.
From the class MLEngineTest, method train_EmptyDataFrame.
/**
 * Expects training on the data frame produced by
 * {@code constructKMeansDataFrame(0)} to fail with an
 * {@link IllegalArgumentException} carrying
 * "Input data frame should not be null or empty".
 */
@Test
public void train_EmptyDataFrame() {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Input data frame should not be null or empty");

    final FunctionName algorithm = FunctionName.LINEAR_REGRESSION;
    // Static mock is scoped to the try block and released on exit.
    try (MockedStatic<MLEngineClassLoader> classLoaderMock = Mockito.mockStatic(MLEngineClassLoader.class)) {
        classLoaderMock
            .when(() -> MLEngineClassLoader.initInstance(algorithm, null, MLAlgoParams.class))
            .thenReturn(null);

        MLInputDataset dataset = DataFrameInputDataset.builder()
            .dataFrame(constructKMeansDataFrame(0))
            .build();
        MLInput trainingInput = MLInput.builder()
            .algorithm(algorithm)
            .inputDataset(dataset)
            .build();
        MLEngine.train(trainingInput);
    }
}
Use of org.opensearch.ml.common.parameter.FunctionName in the ml-commons project by opensearch-project.
From the class MLEngineTest, method predictUnsupportedAlgorithm.
/**
 * Expects {@code MLEngine.predict} to fail with an
 * {@link IllegalArgumentException} carrying
 * "Unsupported algorithm: LINEAR_REGRESSION" when the class loader is stubbed
 * to return {@code null} and no model is supplied.
 */
@Test
public void predictUnsupportedAlgorithm() {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Unsupported algorithm: LINEAR_REGRESSION");

    final FunctionName algorithm = FunctionName.LINEAR_REGRESSION;
    // Static mock is scoped to the try block and released on exit.
    try (MockedStatic<MLEngineClassLoader> classLoaderMock = Mockito.mockStatic(MLEngineClassLoader.class)) {
        classLoaderMock
            .when(() -> MLEngineClassLoader.initInstance(algorithm, null, MLAlgoParams.class))
            .thenReturn(null);

        MLInputDataset dataset = DataFrameInputDataset.builder()
            .dataFrame(constructLinearRegressionPredictionDataFrame())
            .build();
        Input predictionInput = MLInput.builder()
            .algorithm(algorithm)
            .inputDataset(dataset)
            .build();
        MLEngine.predict(predictionInput, null);
    }
}
Aggregations