Use of `org.opensearch.ml.common.parameter.Input` in the ml-commons project by opensearch-project.
Class `MLEngineTest`, method `predictUnsupportedAlgorithm`.
/**
 * Verifies that {@code MLEngine.predict} rejects an algorithm for which the class loader
 * yields no implementation, failing with an IllegalArgumentException that names the algorithm.
 */
@Test
public void predictUnsupportedAlgorithm() {
    FunctionName unsupportedAlgo = FunctionName.LINEAR_REGRESSION;
    // exceptionRule is presumably a JUnit ExpectedException rule declared on the test class.
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Unsupported algorithm: LINEAR_REGRESSION");
    try (MockedStatic<MLEngineClassLoader> mockedClassLoader = Mockito.mockStatic(MLEngineClassLoader.class)) {
        // Stub the class loader so that no instance is produced for this algorithm.
        mockedClassLoader
            .when(() -> MLEngineClassLoader.initInstance(unsupportedAlgo, null, MLAlgoParams.class))
            .thenReturn(null);
        MLInputDataset predictionDataset = DataFrameInputDataset
            .builder()
            .dataFrame(constructLinearRegressionPredictionDataFrame())
            .build();
        Input predictionInput = MLInput
            .builder()
            .algorithm(unsupportedAlgo)
            .inputDataset(predictionDataset)
            .build();
        // No model is passed; the call is expected to throw before one would be needed.
        MLEngine.predict(predictionInput, null);
    }
}
Use of `org.opensearch.ml.common.parameter.Input` in the ml-commons project by opensearch-project.
Class `MLEngineTest`, method `trainAndPredictWithInvalidInput`.
/**
 * Verifies that {@code MLEngine.trainAndPredict} refuses an {@code Input} that is not an
 * {@code MLInput}, failing with an IllegalArgumentException.
 */
@Test
public void trainAndPredictWithInvalidInput() {
    // Register the expectation before building the input, exactly as before.
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Input should be MLInput");
    // A LocalSampleCalculatorInput is an Input, but not an MLInput.
    Input calculatorInput = new LocalSampleCalculatorInput(
        "sum",
        Arrays.asList(1.0, 2.0)
    );
    MLEngine.trainAndPredict(calculatorInput);
}
Use of `org.opensearch.ml.common.parameter.Input` in the ml-commons project by opensearch-project.
Class `MLEngineTest`, method `executeWithInvalidInput`.
/**
 * Verifies that {@code MLEngine.execute} rejects an {@code Input} whose function name is null,
 * failing with an IllegalArgumentException.
 */
@Test
public void executeWithInvalidInput() {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Function name should not be null");
    // Minimal Input stub: only getFunctionName() matters here, and it returns null.
    Input namelessInput = new Input() {
        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params xContentParams) throws IOException {
            // Stub: produces no content.
            return null;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            // Stub: writes nothing.
        }

        @Override
        public FunctionName getFunctionName() {
            return null;
        }
    };
    MLEngine.execute(namelessInput);
}
Use of `org.opensearch.ml.common.parameter.Input` in the ml-commons project by opensearch-project.
Class `MLEngineTest`, method `trainKMeansModel`.
/**
 * Trains a KMeans model through {@code MLEngine.train} for use by other tests.
 *
 * <p>Configuration: 2 centroids, 10 iterations, Euclidean distance, trained on the data frame
 * returned by {@code constructKMeansDataFrame(100)}.
 *
 * @return the model produced by {@code MLEngine.train}
 */
private Model trainKMeansModel() {
    KMeansParams kMeansParams = KMeansParams
        .builder()
        .centroids(2)
        .iterations(10)
        .distanceType(KMeansParams.DistanceType.EUCLIDEAN)
        .build();
    MLInputDataset trainingDataset = DataFrameInputDataset
        .builder()
        .dataFrame(constructKMeansDataFrame(100))
        .build();
    Input trainingInput = MLInput
        .builder()
        .algorithm(FunctionName.KMEANS)
        .parameters(kMeansParams)
        .inputDataset(trainingDataset)
        .build();
    return MLEngine.train(trainingInput);
}
Use of `org.opensearch.ml.common.parameter.Input` in the ml-commons project by opensearch-project.
Class `AnomalyLocalizerImpl`, method `newSearchRequestForEntityKeys`.
/**
 * Builds a search request that evaluates {@code agg} once per entity key.
 *
 * <p>The query is limited to the time range of {@code bucket}'s base bucket, with both ends
 * inclusive. If the input carries a filter query, that filter is also applied. Each entry of
 * {@code keys} becomes one keyed filter, named by its index in {@code keys}. Each keyed filter
 * is a bucket of a filters aggregation that shares {@code agg}'s name and nests {@code agg}
 * beneath it.
 *
 * @param input  supplies the time field, optional filter query and index name
 * @param agg    aggregation computed within each per-key filter
 * @param bucket bucket whose base provides the time range. {@code getBase().get()} is called
 *               unconditionally; presumably {@code getBase()} returns an Optional, so an absent
 *               base would throw (TODO confirm)
 * @param keys   entity keys; each is turned into a query via {@code newQueryByKey}
 * @return a search request against {@code input.getIndexName()} that requests size 0 (no hits)
 *         plus the filters aggregation
 */
private SearchRequest newSearchRequestForEntityKeys(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Bucket bucket, List<List<String>> keys) {
    RangeQueryBuilder baseTimeRange = new RangeQueryBuilder(input.getTimeFieldName())
        .from(bucket.getBase().get().getStartTime(), true)
        .to(bucket.getBase().get().getEndTime(), true);
    BoolQueryBuilder scopeQuery = QueryBuilders.boolQuery().filter(baseTimeRange);
    input.getFilterQuery().ifPresent(scopeQuery::filter);

    // One keyed filter per entity key, keyed by the key's position in the list.
    KeyedFilter[] perKeyFilters = new KeyedFilter[keys.size()];
    for (int idx = 0; idx < perKeyFilters.length; idx++) {
        perKeyFilters[idx] = new KeyedFilter(Integer.toString(idx), newQueryByKey(keys.get(idx), input));
    }
    FiltersAggregationBuilder perKeyAgg = AggregationBuilders.filters(agg.getName(), perKeyFilters);
    perKeyAgg.subAggregation(agg);

    SearchSourceBuilder source = new SearchSourceBuilder()
        .size(0)
        .query(scopeQuery)
        .aggregation(perKeyAgg);
    return new SearchRequest(new String[] { input.getIndexName() }, source);
}
Aggregations