Search in sources :

Example 1 with FunctionName

use of org.opensearch.ml.common.parameter.FunctionName in project ml-commons by opensearch-project.

the class MLCommonsClassLoader method loadMLAlgoParameterClassMapping.

/**
 * Populates {@code parameterClassMap} from the annotated types that Reflections finds
 * under {@code org.opensearch.ml.common.parameter}.
 *
 * <p>Each type carrying {@link MLAlgoParameter} is registered once per algorithm listed
 * in the annotation; each type carrying {@link MLAlgoOutput} is registered under the
 * output type given by the annotation's value.
 */
private static void loadMLAlgoParameterClassMapping() {
    Reflections scanner = new Reflections("org.opensearch.ml.common.parameter");

    // Parameter classes: one map entry for every algorithm the annotation names.
    for (Class<?> parameterClass : scanner.getTypesAnnotatedWith(MLAlgoParameter.class)) {
        FunctionName[] supportedAlgorithms = parameterClass.getAnnotation(MLAlgoParameter.class).algorithms();
        if (supportedAlgorithms == null) {
            continue;
        }
        for (FunctionName functionName : supportedAlgorithms) {
            parameterClassMap.put(functionName, parameterClass);
        }
    }

    // Output classes: keyed by the output type declared on the annotation.
    for (Class<?> outputClass : scanner.getTypesAnnotatedWith(MLAlgoOutput.class)) {
        MLOutputType outputType = outputClass.getAnnotation(MLAlgoOutput.class).value();
        if (outputType == null) {
            continue;
        }
        parameterClassMap.put(outputType, outputClass);
    }
}
Also used : FunctionName(org.opensearch.ml.common.parameter.FunctionName) MLOutputType(org.opensearch.ml.common.parameter.MLOutputType) MLAlgoOutput(org.opensearch.ml.common.annotation.MLAlgoOutput) MLAlgoParameter(org.opensearch.ml.common.annotation.MLAlgoParameter) Reflections(org.reflections.Reflections)

Example 2 with FunctionName

use of org.opensearch.ml.common.parameter.FunctionName in project ml-commons by opensearch-project.

the class MLEngineTest method train_NullInput.

/**
 * Expects {@code MLEngine.train(null)} to throw an {@link IllegalArgumentException}
 * whose message contains "Input should not be null".
 */
@Test
public void train_NullInput() {
    final FunctionName linearRegression = FunctionName.LINEAR_REGRESSION;

    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Input should not be null");

    try (MockedStatic<MLEngineClassLoader> classLoaderMock = Mockito.mockStatic(MLEngineClassLoader.class)) {
        // Stub the static factory so it returns null for this algorithm.
        classLoaderMock
            .when(() -> MLEngineClassLoader.initInstance(linearRegression, null, MLAlgoParams.class))
            .thenReturn(null);

        MLEngine.train(null);
    }
}
Also used : FunctionName(org.opensearch.ml.common.parameter.FunctionName) Test(org.junit.Test)

Example 3 with FunctionName

use of org.opensearch.ml.common.parameter.FunctionName in project ml-commons by opensearch-project.

the class MLEngineTest method train_NullDataFrame.

/**
 * Expects {@code MLEngine.train} to throw an {@link IllegalArgumentException} whose message
 * contains "Input data frame should not be null or empty" when the input is built with an
 * algorithm but no input dataset.
 */
@Test
public void train_NullDataFrame() {
    final FunctionName linearRegression = FunctionName.LINEAR_REGRESSION;

    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Input data frame should not be null or empty");

    try (MockedStatic<MLEngineClassLoader> classLoaderMock = Mockito.mockStatic(MLEngineClassLoader.class)) {
        // Stub the static factory so it returns null for this algorithm.
        classLoaderMock
            .when(() -> MLEngineClassLoader.initInstance(linearRegression, null, MLAlgoParams.class))
            .thenReturn(null);

        // Only the algorithm is set on the builder; no dataset is supplied.
        MLEngine.train(
            MLInput.builder()
                .algorithm(linearRegression)
                .build());
    }
}
Also used : FunctionName(org.opensearch.ml.common.parameter.FunctionName) Test(org.junit.Test)

Example 4 with FunctionName

use of org.opensearch.ml.common.parameter.FunctionName in project ml-commons by opensearch-project.

the class MLEngineTest method train_EmptyDataFrame.

/**
 * Expects {@code MLEngine.train} to throw an {@link IllegalArgumentException} whose message
 * contains "Input data frame should not be null or empty" when the input dataset wraps the
 * data frame returned by {@code constructKMeansDataFrame(0)}.
 */
@Test
public void train_EmptyDataFrame() {
    final FunctionName linearRegression = FunctionName.LINEAR_REGRESSION;

    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Input data frame should not be null or empty");

    try (MockedStatic<MLEngineClassLoader> classLoaderMock = Mockito.mockStatic(MLEngineClassLoader.class)) {
        // Stub the static factory so it returns null for this algorithm.
        classLoaderMock
            .when(() -> MLEngineClassLoader.initInstance(linearRegression, null, MLAlgoParams.class))
            .thenReturn(null);

        MLInputDataset emptyDataset = DataFrameInputDataset.builder()
            .dataFrame(constructKMeansDataFrame(0))
            .build();

        MLEngine.train(
            MLInput.builder()
                .algorithm(linearRegression)
                .inputDataset(emptyDataset)
                .build());
    }
}
Also used : FunctionName(org.opensearch.ml.common.parameter.FunctionName) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) Test(org.junit.Test)

Example 5 with FunctionName

use of org.opensearch.ml.common.parameter.FunctionName in project ml-commons by opensearch-project.

the class MLEngineTest method predictUnsupportedAlgorithm.

/**
 * Expects {@code MLEngine.predict} to throw an {@link IllegalArgumentException} whose message
 * contains "Unsupported algorithm: LINEAR_REGRESSION" when called with a linear regression
 * input and a null second argument, while {@code MLEngineClassLoader.initInstance} is stubbed
 * to return null.
 */
@Test
public void predictUnsupportedAlgorithm() {
    final FunctionName linearRegression = FunctionName.LINEAR_REGRESSION;

    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Unsupported algorithm: LINEAR_REGRESSION");

    try (MockedStatic<MLEngineClassLoader> classLoaderMock = Mockito.mockStatic(MLEngineClassLoader.class)) {
        // Stub the static factory so it returns null for this algorithm.
        classLoaderMock
            .when(() -> MLEngineClassLoader.initInstance(linearRegression, null, MLAlgoParams.class))
            .thenReturn(null);

        MLInputDataset predictionDataset = DataFrameInputDataset.builder()
            .dataFrame(constructLinearRegressionPredictionDataFrame())
            .build();
        Input predictionInput = MLInput.builder()
            .algorithm(linearRegression)
            .inputDataset(predictionDataset)
            .build();

        MLEngine.predict(predictionInput, null);
    }
}
Also used : FunctionName(org.opensearch.ml.common.parameter.FunctionName) MLInput(org.opensearch.ml.common.parameter.MLInput) Input(org.opensearch.ml.common.parameter.Input) LocalSampleCalculatorInput(org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) Test(org.junit.Test)

Aggregations

FunctionName (org.opensearch.ml.common.parameter.FunctionName): 7; Test (org.junit.Test): 5; MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset): 3; Reflections (org.reflections.Reflections): 2; MLAlgoOutput (org.opensearch.ml.common.annotation.MLAlgoOutput): 1; MLAlgoParameter (org.opensearch.ml.common.annotation.MLAlgoParameter): 1; Input (org.opensearch.ml.common.parameter.Input): 1; LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput): 1; MLInput (org.opensearch.ml.common.parameter.MLInput): 1; MLOutputType (org.opensearch.ml.common.parameter.MLOutputType): 1; Function (org.opensearch.ml.engine.annotation.Function): 1