
Example 6 with FunctionName

Use of org.opensearch.ml.common.parameter.FunctionName in project ml-commons by opensearch-project.

From the class MLEngineClassLoader, method loadClassMapping.

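// mlAlgoClassMap is a static field of MLEngineClassLoader that maps each
// FunctionName to its implementing class (declaration not shown in this excerpt).
public static void loadClassMapping() {
    // Scan the algorithms package for classes annotated with @Function
    Reflections reflections = new Reflections("org.opensearch.ml.engine.algorithms");
    Set<Class<?>> classes = reflections.getTypesAnnotatedWith(Function.class);
    // Register each ML algorithm implementation class under its FunctionName
    for (Class<?> clazz : classes) {
        Function function = clazz.getAnnotation(Function.class);
        FunctionName functionName = function.value();
        if (functionName != null) {
            mlAlgoClassMap.put(functionName, clazz);
        }
    }
}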
Also used: Function (org.opensearch.ml.engine.annotation.Function), FunctionName (org.opensearch.ml.common.parameter.FunctionName), Reflections (org.reflections.Reflections)
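The registration pattern in loadClassMapping can be sketched without OpenSearch or the Reflections library. This is a minimal sketch under stated assumptions, not the ml-commons code: AlgoName, @Algo, KMeansImpl and Helper are hypothetical stand-ins for FunctionName, @Function and the real algorithm classes, and an explicit list of candidate classes replaces the package scan that Reflections performs. It also shows why an annotation read this way needs RUNTIME retention: getAnnotation() returns null for annotations that are not retained at run time.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;

public class AlgoRegistrySketch {

    // Hypothetical stand-in for FunctionName: the names an algorithm can be registered under.
    enum AlgoName { KMEANS, LINEAR_REGRESSION }

    // Hypothetical stand-in for @Function. RUNTIME retention lets getAnnotation() see it.
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.TYPE)
    @interface Algo {
        AlgoName value();
    }

    // Hypothetical algorithm implementation registered under KMEANS.
    @Algo(AlgoName.KMEANS)
    static class KMeansImpl { }

    // Hypothetical unannotated class: the registry skips it.
    static class Helper { }

    // Plays the role of mlAlgoClassMap (assumed shape: name -> implementing class).
    static final Map<AlgoName, Class<?>> ALGO_CLASS_MAP = new EnumMap<>(AlgoName.class);

    // Mirrors the loop in loadClassMapping(), but over an explicit candidate list
    // instead of the classes found by a Reflections package scan.
    static void loadClassMapping(List<Class<?>> candidates) {
        for (Class<?> clazz : candidates) {
            Algo algo = clazz.getAnnotation(Algo.class);
            if (algo != null) {
                ALGO_CLASS_MAP.put(algo.value(), clazz);
            }
        }
    }

    public static void main(String[] args) {
        loadClassMapping(List.of(KMeansImpl.class, Helper.class));
        System.out.println(ALGO_CLASS_MAP);
    }
}
```

Because the candidate list here may contain unannotated classes, the sketch checks the annotation itself for null; in the original, getTypesAnnotatedWith already guarantees the annotation is present, so only the FunctionName is checked.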

Example 7 with FunctionName

Use of org.opensearch.ml.common.parameter.FunctionName in project ml-commons by opensearch-project.

From the class MLEngineTest, method train_UnsupportedAlgorithm.

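// exceptionRule is a JUnit 4 ExpectedException @Rule field of MLEngineTest;
// constructKMeansDataFrame is a helper defined elsewhere (not shown in this excerpt).
@Test
public void train_UnsupportedAlgorithm() {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Unsupported algorithm: LINEAR_REGRESSION");
    FunctionName algoName = FunctionName.LINEAR_REGRESSION;
    // Mock the static class loader so that no instance is created for the algorithm
    try (MockedStatic<MLEngineClassLoader> loader = Mockito.mockStatic(MLEngineClassLoader.class)) {
        loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null);
        MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructKMeansDataFrame(10)).build();
        // train() is expected to throw IllegalArgumentException because initInstance returns null
        MLEngine.train(MLInput.builder().algorithm(algoName).inputDataset(inputDataset).build());
    }
}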
Also used: FunctionName (org.opensearch.ml.common.parameter.FunctionName), MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset), Test (org.junit.Test)
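The test expects train() to reject an algorithm for which no implementation can be instantiated. The sketch below assumes, as the expected message suggests, that MLEngine.train throws IllegalArgumentException("Unsupported algorithm: " + name) when initInstance returns null; it is an illustration, not the actual ml-commons code. AlgoName, Trainable, KMeans, FACTORIES, initInstance and train are hypothetical names, and a map of constructors takes the place of the Mockito static mock.

```java
import java.util.EnumMap;
import java.util.Map;
import java.util.function.Supplier;

public class UnsupportedAlgoSketch {

    // Hypothetical stand-in for FunctionName.
    enum AlgoName { KMEANS, LINEAR_REGRESSION }

    // Hypothetical algorithm interface.
    interface Trainable { String train(); }

    // Hypothetical implementation; only KMEANS has one.
    static class KMeans implements Trainable {
        public String train() { return "kmeans-model"; }
    }

    // Hypothetical registry of constructors keyed by algorithm name.
    static final Map<AlgoName, Supplier<Trainable>> FACTORIES = new EnumMap<>(AlgoName.class);
    static {
        FACTORIES.put(AlgoName.KMEANS, KMeans::new);
    }

    // Stand-in for MLEngineClassLoader.initInstance: returns null when nothing is registered.
    static Trainable initInstance(AlgoName name) {
        Supplier<Trainable> factory = FACTORIES.get(name);
        return factory == null ? null : factory.get();
    }

    // Stand-in for MLEngine.train (assumed behavior): a null instance means "unsupported".
    static String train(AlgoName name) {
        Trainable trainable = initInstance(name);
        if (trainable == null) {
            throw new IllegalArgumentException("Unsupported algorithm: " + name);
        }
        return trainable.train();
    }

    public static void main(String[] args) {
        System.out.println(train(AlgoName.KMEANS));
        try {
            train(AlgoName.LINEAR_REGRESSION);
        } catch (IllegalArgumentException ex) {
            System.out.println(ex.getMessage());
        }
    }
}
```

Since an enum constant's toString() is its name, the message for LINEAR_REGRESSION matches the string the test passes to expectMessage.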
