
Example 1 with LocalSampleCalculatorInput

Usage of org.opensearch.ml.common.parameter.LocalSampleCalculatorInput in the ml-commons project by opensearch-project.

From class MLEngineClassLoaderTests, method initInstance_LocalSampleCalculator.

@Test
public void initInstance_LocalSampleCalculator() {
    List<Double> inputData = new ArrayList<>();
    double d1 = 10.0;
    double d2 = 20.0;
    inputData.add(d1);
    inputData.add(d2);
    LocalSampleCalculatorInput input = LocalSampleCalculatorInput.builder().operation("sum").inputData(inputData).build();
    Map<String, Object> properties = new HashMap<>();
    // "wrongField" matches no field of LocalSampleCalculator; the test expects it to be ignored
    properties.put("wrongField", "test");
    Client client = mock(Client.class);
    properties.put("client", client);
    Settings settings = Settings.EMPTY;
    properties.put("settings", settings);
    // set properties
    MLEngineClassLoader.deregister(FunctionName.LOCAL_SAMPLE_CALCULATOR);
    LocalSampleCalculator instance = MLEngineClassLoader.initInstance(FunctionName.LOCAL_SAMPLE_CALCULATOR, input, Input.class, properties);
    SampleAlgoOutput output = (SampleAlgoOutput) instance.execute(input);
    assertEquals(d1 + d2, output.getSampleResult(), 1e-6);
    assertEquals(client, instance.getClient());
    assertEquals(settings, instance.getSettings());
    // don't set properties
    instance = MLEngineClassLoader.initInstance(FunctionName.LOCAL_SAMPLE_CALCULATOR, input, Input.class);
    output = (SampleAlgoOutput) instance.execute(input);
    assertEquals(d1 + d2, output.getSampleResult(), 1e-6);
    assertNull(instance.getClient());
    assertNull(instance.getSettings());
}
Also used: HashMap (java.util.HashMap), ArrayList (java.util.ArrayList), LocalSampleCalculator (org.opensearch.ml.engine.algorithms.sample.LocalSampleCalculator), Input (org.opensearch.ml.common.parameter.Input), LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput), Client (org.opensearch.client.Client), SampleAlgoOutput (org.opensearch.ml.common.parameter.SampleAlgoOutput), Settings (org.opensearch.common.settings.Settings), Test (org.junit.Test)
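Example 1 shows that initInstance accepts an optional properties map: entries whose keys match a field of the algorithm (client, settings) end up on the instance, the unknown "wrongField" key does not cause a failure, and the overload without properties leaves both fields null. The Java sketch below reproduces that observable behavior with plain reflection. It is a minimal illustration under assumptions, not the MLEngineClassLoader implementation: matching properties to fields by name is assumed, and SketchAlgorithm, SketchClassLoader and SketchClassLoaderDemo are hypothetical names.

```java
import java.lang.reflect.Field;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

// Hypothetical algorithm class with two injectable dependencies.
// Plain Object fields stand in for org.opensearch.client.Client and Settings.
class SketchAlgorithm {
    private Object client;
    private Object settings;

    Object getClient() { return client; }
    Object getSettings() { return settings; }
}

// Hypothetical helper (assumption, not the ml-commons API): builds an instance,
// then copies each property into the field with the same name.
final class SketchClassLoader {
    private SketchClassLoader() {}

    static <T> T initInstance(Class<T> type, Map<String, Object> properties) {
        try {
            T instance = type.getDeclaredConstructor().newInstance();
            for (Map.Entry<String, Object> entry : properties.entrySet()) {
                try {
                    Field field = type.getDeclaredField(entry.getKey());
                    field.setAccessible(true);
                    field.set(instance, entry.getValue());
                } catch (NoSuchFieldException ignored) {
                    // e.g. "wrongField": no matching field, so the entry is skipped
                }
            }
            return instance;
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("cannot create " + type.getName(), e);
        }
    }

    // Overload without properties: injectable fields keep their default (null).
    static <T> T initInstance(Class<T> type) {
        return initInstance(type, Collections.emptyMap());
    }
}

class SketchClassLoaderDemo {
    public static void main(String[] args) {
        Map<String, Object> properties = new HashMap<>();
        properties.put("wrongField", "test");
        properties.put("client", "client-stub");
        properties.put("settings", "settings-stub");
        SketchAlgorithm withProps = SketchClassLoader.initInstance(SketchAlgorithm.class, properties);
        SketchAlgorithm withoutProps = SketchClassLoader.initInstance(SketchAlgorithm.class);
        System.out.println(withProps.getClient() + " " + withoutProps.getClient());
    }
}
```

Silently skipping unknown keys keeps callers free to pass a shared property map to every algorithm, at the cost of hiding typos in key names.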

Example 2 with LocalSampleCalculatorInput

Usage of org.opensearch.ml.common.parameter.LocalSampleCalculatorInput in the ml-commons project by opensearch-project.

From class LocalSampleCalculatorTest, method executeWithWrongOperation.

@Test
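public void executeWithWrongOperation() {
    // exceptionRule (a JUnit ExpectedException rule), calculator and input are fields of the test class, not shown here
    exceptionRule.expect(IllegalArgumentException.class);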
    exceptionRule.expectMessage("can't support this operation");
    input = new LocalSampleCalculatorInput("wrong_operation", Arrays.asList(1.0, 2.0, 3.0));
    calculator.execute(input);
}
Also used: LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput), Test (org.junit.Test)

Example 3 with LocalSampleCalculatorInput

Usage of org.opensearch.ml.common.parameter.LocalSampleCalculatorInput in the ml-commons project by opensearch-project.

From class LocalSampleCalculator, method execute.

@Override
public Output execute(Input input) {
    // Only LocalSampleCalculatorInput is accepted (instanceof is also false for null)
    if (input == null || !(input instanceof LocalSampleCalculatorInput)) {
        throw new IllegalArgumentException("wrong input");
    }
    LocalSampleCalculatorInput sampleCalculatorInput = (LocalSampleCalculatorInput) input;
    String operation = sampleCalculatorInput.getOperation();
    List<Double> inputData = sampleCalculatorInput.getInputData();
    // Dispatch on the operation name; any other value is rejected
    switch (operation) {
        case "sum":
            double sum = inputData.stream().mapToDouble(Double::doubleValue).sum();
            return new SampleAlgoOutput(sum);
        case "max":
            // Optional.get() throws NoSuchElementException if inputData is empty
            double max = inputData.stream().max(Comparator.naturalOrder()).get();
            return new SampleAlgoOutput(max);
        case "min":
            double min = inputData.stream().min(Comparator.naturalOrder()).get();
            return new SampleAlgoOutput(min);
        default:
            throw new IllegalArgumentException("can't support this operation");
    }
}
Also used: Client (org.opensearch.client.Client), Executable (org.opensearch.ml.engine.Executable), Function (org.opensearch.ml.engine.annotation.Function), Settings (org.opensearch.common.settings.Settings), SampleAlgoOutput (org.opensearch.ml.common.parameter.SampleAlgoOutput), Input (org.opensearch.ml.common.parameter.Input), Output (org.opensearch.ml.common.parameter.Output), LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput), List (java.util.List), FunctionName (org.opensearch.ml.common.parameter.FunctionName), Data (lombok.Data), Comparator (java.util.Comparator), NoArgsConstructor (lombok.NoArgsConstructor)
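The execute method above first validates the input type, then dispatches on the operation string. A minimal standalone sketch of the same dispatch, assuming plain Double values in place of LocalSampleCalculatorInput and SampleAlgoOutput (SketchCalculator is a hypothetical name), might look like this:

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical standalone mirror of the sum/max/min dispatch,
// using plain doubles instead of the ml-commons input and output types.
final class SketchCalculator {
    private SketchCalculator() {}

    static double calculate(String operation, List<Double> inputData) {
        // Reject missing or empty input up front so max/min always have a value.
        if (operation == null || inputData == null || inputData.isEmpty()) {
            throw new IllegalArgumentException("wrong input");
        }
        switch (operation) {
            case "sum":
                return inputData.stream().mapToDouble(Double::doubleValue).sum();
            case "max":
                return inputData.stream().mapToDouble(Double::doubleValue).max().getAsDouble();
            case "min":
                return inputData.stream().mapToDouble(Double::doubleValue).min().getAsDouble();
            default:
                throw new IllegalArgumentException("can't support this operation");
        }
    }

    public static void main(String[] args) {
        List<Double> data = Arrays.asList(1.0, 2.0, 3.0);
        System.out.println(calculate("sum", data)); // prints 6.0
        System.out.println(calculate("max", data)); // prints 3.0
        System.out.println(calculate("min", data)); // prints 1.0
    }
}
```

One design choice differs from the original: the sketch rejects an empty list with IllegalArgumentException for every operation, whereas the original evaluates "sum" of an empty list to 0.0 and throws NoSuchElementException from Optional.get() for "max" and "min".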

Example 4 with LocalSampleCalculatorInput

Usage of org.opensearch.ml.common.parameter.LocalSampleCalculatorInput in the ml-commons project by opensearch-project.

From class MLEngineTest, method executeLocalSampleCalculator.

@Test
public void executeLocalSampleCalculator() {
    Input input = new LocalSampleCalculatorInput("sum", Arrays.asList(1.0, 2.0));
    SampleAlgoOutput output = (SampleAlgoOutput) MLEngine.execute(input);
    Assert.assertEquals(3.0, output.getSampleResult(), 1e-5);
}
Also used: MLInput (org.opensearch.ml.common.parameter.MLInput), Input (org.opensearch.ml.common.parameter.Input), LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput), SampleAlgoOutput (org.opensearch.ml.common.parameter.SampleAlgoOutput), Test (org.junit.Test)
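Example 4 passes a LocalSampleCalculatorInput to MLEngine.execute through the general Input type and still receives a SampleAlgoOutput, so the engine has to select the algorithm from the input itself. The sketch below assumes that selection is keyed on a function name carried by the input, which fits the FunctionName-based initInstance call in Example 1. The enum, registry, and class names are hypothetical, and only summation is wired up for brevity.

```java
import java.util.Arrays;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.function.ToDoubleFunction;

// Hypothetical function names (the real enum in ml-commons is FunctionName).
enum SketchFunction { LOCAL_SAMPLE_CALCULATOR }

// Hypothetical input contract: every input names the function that should run it.
interface SketchEngineInput {
    SketchFunction getFunctionName();
}

final class SketchSumInput implements SketchEngineInput {
    final List<Double> inputData;

    SketchSumInput(List<Double> inputData) {
        this.inputData = inputData;
    }

    @Override
    public SketchFunction getFunctionName() {
        return SketchFunction.LOCAL_SAMPLE_CALCULATOR;
    }
}

final class SketchEngine {
    // Registry from function name to the code that executes it.
    private static final Map<SketchFunction, ToDoubleFunction<SketchEngineInput>> REGISTRY =
            new EnumMap<>(SketchFunction.class);

    static {
        REGISTRY.put(SketchFunction.LOCAL_SAMPLE_CALCULATOR,
                input -> ((SketchSumInput) input).inputData.stream()
                        .mapToDouble(Double::doubleValue).sum());
    }

    private SketchEngine() {}

    // Looks up the executable by the input's function name, then runs it.
    static double execute(SketchEngineInput input) {
        if (input == null) {
            throw new IllegalArgumentException("Input should not be null");
        }
        ToDoubleFunction<SketchEngineInput> function = REGISTRY.get(input.getFunctionName());
        if (function == null) {
            throw new IllegalArgumentException("Unsupported function: " + input.getFunctionName());
        }
        return function.applyAsDouble(input);
    }

    public static void main(String[] args) {
        System.out.println(execute(new SketchSumInput(Arrays.asList(1.0, 2.0)))); // prints 3.0
    }
}
```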

Example 5 with LocalSampleCalculatorInput

Usage of org.opensearch.ml.common.parameter.LocalSampleCalculatorInput in the ml-commons project by opensearch-project.

From class MLEngineTest, method trainAndPredictWithInvalidInput.

@Test
public void trainAndPredictWithInvalidInput() {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Input should be MLInput");
    Input input = new LocalSampleCalculatorInput("sum", Arrays.asList(1.0, 2.0));
    MLEngine.trainAndPredict(input);
}
Also used: MLInput (org.opensearch.ml.common.parameter.MLInput), Input (org.opensearch.ml.common.parameter.Input), LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput), Test (org.junit.Test)
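Example 5 confirms that trainAndPredict refuses any input that is not an MLInput, failing with the message "Input should be MLInput". A type guard of that kind can be sketched as follows. The marker types and SketchTrainingEngine are hypothetical stand-ins for Input, MLInput and MLEngine, and the code after the guard is a placeholder rather than real training logic.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical marker types standing in for Input and MLInput.
interface SketchBaseInput {}

final class SketchMLInput implements SketchBaseInput {}

final class SketchPlainInput implements SketchBaseInput {
    final String operation;
    final List<Double> inputData;

    SketchPlainInput(String operation, List<Double> inputData) {
        this.operation = operation;
        this.inputData = inputData;
    }
}

final class SketchTrainingEngine {
    private SketchTrainingEngine() {}

    // Guard clause: only the ML-specific subtype may reach training code.
    // instanceof is false for null, so a null input is rejected as well.
    static String trainAndPredict(SketchBaseInput input) {
        if (!(input instanceof SketchMLInput)) {
            throw new IllegalArgumentException("Input should be MLInput");
        }
        // Placeholder: real training and prediction would run here.
        return "trained";
    }

    public static void main(String[] args) {
        try {
            trainAndPredict(new SketchPlainInput("sum", Arrays.asList(1.0, 2.0)));
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // prints Input should be MLInput
        }
    }
}
```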

Aggregations

LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput): 7
Test (org.junit.Test): 5
Input (org.opensearch.ml.common.parameter.Input): 4
SampleAlgoOutput (org.opensearch.ml.common.parameter.SampleAlgoOutput): 4
Client (org.opensearch.client.Client): 2
Settings (org.opensearch.common.settings.Settings): 2
MLInput (org.opensearch.ml.common.parameter.MLInput): 2
ArrayList (java.util.ArrayList): 1
Comparator (java.util.Comparator): 1
HashMap (java.util.HashMap): 1
List (java.util.List): 1
Data (lombok.Data): 1
NoArgsConstructor (lombok.NoArgsConstructor): 1
Before (org.junit.Before): 1
FunctionName (org.opensearch.ml.common.parameter.FunctionName): 1
Output (org.opensearch.ml.common.parameter.Output): 1
Executable (org.opensearch.ml.engine.Executable): 1
LocalSampleCalculator (org.opensearch.ml.engine.algorithms.sample.LocalSampleCalculator): 1
Function (org.opensearch.ml.engine.annotation.Function): 1