Use of org.opensearch.ml.common.parameter.SampleAlgoOutput in the ml-commons project by opensearch-project.
From the class MLEngineClassLoaderTests, method initInstance_LocalSampleCalculator:
/**
 * Builds a {@code LocalSampleCalculator} through {@code MLEngineClassLoader.initInstance}
 * twice, once with a properties map and once without, and checks in both cases that a
 * "sum" request yields the sum of the two inputs.
 */
@Test
public void initInstance_LocalSampleCalculator() {
    double first = 10.0;
    double second = 20.0;
    double expectedSum = first + second;

    List<Double> values = new ArrayList<>();
    values.add(first);
    values.add(second);
    LocalSampleCalculatorInput input = LocalSampleCalculatorInput.builder()
        .operation("sum")
        .inputData(values)
        .build();

    // Properties handed to the class loader; "wrongField" is an extra entry
    // besides the two that are asserted on below.
    Client client = mock(Client.class);
    Settings settings = Settings.EMPTY;
    Map<String, Object> properties = new HashMap<>();
    properties.put("wrongField", "test");
    properties.put("client", client);
    properties.put("settings", settings);

    MLEngineClassLoader.deregister(FunctionName.LOCAL_SAMPLE_CALCULATOR);

    // Case 1: properties supplied, so client and settings must be visible on the instance.
    LocalSampleCalculator configured = MLEngineClassLoader.initInstance(
        FunctionName.LOCAL_SAMPLE_CALCULATOR, input, Input.class, properties);
    SampleAlgoOutput configuredResult = (SampleAlgoOutput) configured.execute(input);
    assertEquals(expectedSum, configuredResult.getSampleResult(), 1e-6);
    assertEquals(client, configured.getClient());
    assertEquals(settings, configured.getSettings());

    // Case 2: no properties supplied, so client and settings must be null.
    LocalSampleCalculator bare = MLEngineClassLoader.initInstance(
        FunctionName.LOCAL_SAMPLE_CALCULATOR, input, Input.class);
    SampleAlgoOutput bareResult = (SampleAlgoOutput) bare.execute(input);
    assertEquals(expectedSum, bareResult.getSampleResult(), 1e-6);
    assertNull(bare.getClient());
    assertNull(bare.getSettings());
}
Use of org.opensearch.ml.common.parameter.SampleAlgoOutput in the ml-commons project by opensearch-project.
From the class LocalSampleCalculator, method execute:
/**
 * Applies the aggregate named by the input's operation ("sum", "max" or "min")
 * to the input's list of numbers and wraps the result in a {@link SampleAlgoOutput}.
 *
 * <p>An empty list yields {@code 0.0} for "sum"; "max" and "min" have no defined
 * result on an empty list and are rejected.
 *
 * @param input the request; must be a {@link LocalSampleCalculatorInput}
 * @return a {@link SampleAlgoOutput} holding the computed value
 * @throws IllegalArgumentException if {@code input} is null or of another type,
 *         if the operation is null or not one of the supported names, if the
 *         input data is null, or if "max"/"min" is requested on empty data
 */
@Override
public Output execute(Input input) {
    // instanceof is already false for null, so no separate null check is needed.
    if (!(input instanceof LocalSampleCalculatorInput)) {
        throw new IllegalArgumentException("wrong input");
    }
    LocalSampleCalculatorInput sampleCalculatorInput = (LocalSampleCalculatorInput) input;
    String operation = sampleCalculatorInput.getOperation();
    List<Double> inputData = sampleCalculatorInput.getInputData();
    // Switching on a null String throws NullPointerException; report it as an
    // unsupported operation instead.
    if (operation == null) {
        throw new IllegalArgumentException("can't support this operation");
    }
    switch (operation) {
        case "sum":
            double sum = requireInputData(inputData).stream().mapToDouble(Double::doubleValue).sum();
            return new SampleAlgoOutput(sum);
        case "max":
            double max = requireInputData(inputData).stream()
                .max(Comparator.naturalOrder())
                .orElseThrow(() -> new IllegalArgumentException("input data can't be empty"));
            return new SampleAlgoOutput(max);
        case "min":
            double min = requireInputData(inputData).stream()
                .min(Comparator.naturalOrder())
                .orElseThrow(() -> new IllegalArgumentException("input data can't be empty"));
            return new SampleAlgoOutput(min);
        default:
            throw new IllegalArgumentException("can't support this operation");
    }
}

/** Returns {@code inputData} unchanged, or rejects it with IllegalArgumentException when null. */
private static List<Double> requireInputData(List<Double> inputData) {
    if (inputData == null) {
        throw new IllegalArgumentException("input data can't be null");
    }
    return inputData;
}
Use of org.opensearch.ml.common.parameter.SampleAlgoOutput in the ml-commons project by opensearch-project.
From the class MLEngineTest, method executeLocalSampleCalculator:
/**
 * Sends a "sum" request over {1.0, 2.0} through {@code MLEngine.execute}
 * and expects a result of 3.0.
 */
@Test
public void executeLocalSampleCalculator() {
    final double expectedSum = 3.0;
    Input sumRequest = new LocalSampleCalculatorInput("sum", Arrays.asList(1.0, 2.0));

    SampleAlgoOutput result = (SampleAlgoOutput) MLEngine.execute(sumRequest);

    Assert.assertEquals(expectedSum, result.getSampleResult(), 1e-5);
}
Use of org.opensearch.ml.common.parameter.SampleAlgoOutput in the ml-commons project by opensearch-project.
From the class SampleAlgoTest, method predict:
/**
 * Trains {@code sampleAlgo} on the training frame, predicts on the prediction
 * frame with the resulting model, and expects a sample result of 3.0.
 */
@Test
public void predict() {
    final double expected = 3.0;

    // Train first; the returned model is what predict() is evaluated against.
    Model trainedModel = sampleAlgo.train(trainDataFrame);
    SampleAlgoOutput prediction = (SampleAlgoOutput) sampleAlgo.predict(predictionDataFrame, trainedModel);

    Assert.assertEquals(expected, prediction.getSampleResult().doubleValue(), 1e-5);
}
Use of org.opensearch.ml.common.parameter.SampleAlgoOutput in the ml-commons project by opensearch-project.
From the class LocalSampleCalculatorTest, method execute:
/**
 * Runs the calculator three times: first with the fixture's current {@code input}
 * (expected result 6.0), then with "max" and "min" requests over {1.0, 2.0, 3.0}
 * (expected 3.0 and 1.0 respectively).
 */
@Test
public void execute() {
    final double tolerance = 1e-5;

    // Uses whatever the fixture assigned to the input field before this test ran.
    SampleAlgoOutput fixtureResult = (SampleAlgoOutput) calculator.execute(input);
    Assert.assertEquals(6.0, fixtureResult.getSampleResult().doubleValue(), tolerance);

    input = new LocalSampleCalculatorInput("max", Arrays.asList(1.0, 2.0, 3.0));
    SampleAlgoOutput maxResult = (SampleAlgoOutput) calculator.execute(input);
    Assert.assertEquals(3.0, maxResult.getSampleResult().doubleValue(), tolerance);

    input = new LocalSampleCalculatorInput("min", Arrays.asList(1.0, 2.0, 3.0));
    SampleAlgoOutput minResult = (SampleAlgoOutput) calculator.execute(input);
    Assert.assertEquals(1.0, minResult.getSampleResult().doubleValue(), tolerance);
}
Aggregations