Use of org.opensearch.ml.common.parameter.Input in project ml-commons by opensearch-project.
Class AnomalyLocalizerImpl, method newSearchRequestForEntry.
/**
 * Builds the search request that fetches one page of per-entity aggregates for a single time bucket.
 *
 * <p>Documents are restricted to the closed interval [bucket start time, bucket end time] on the
 * input's time field, AND-ed with the input's optional filter query. Hits are not returned
 * (size 0); instead a composite aggregation with one terms source per attribute field groups the
 * matches, and {@code agg} is nested beneath it as a sub-aggregation.
 *
 * @param input    supplies the index, time field, attribute fields and optional filter query
 * @param agg      aggregation computed per composite bucket; its name is reused for the composite
 * @param bucket   time bucket whose start/end bound the range filter (both ends inclusive)
 * @param afterKey when present, passed to {@code aggregateAfter} so the composite continues from a
 *                 previous page; empty for the first page
 * @return a search request against {@code input.getIndexName()}
 */
private SearchRequest newSearchRequestForEntry(
    AnomalyLocalizationInput input,
    AggregationBuilder agg,
    AnomalyLocalizationOutput.Bucket bucket,
    Optional<Map<String, Object>> afterKey
) {
    RangeQueryBuilder withinBucket = new RangeQueryBuilder(input.getTimeFieldName())
        .from(bucket.getStartTime(), true)
        .to(bucket.getEndTime(), true);
    BoolQueryBuilder query = QueryBuilders.boolQuery().filter(withinBucket);
    input.getFilterQuery().ifPresent(query::filter);

    // One terms source per attribute field; page size comes from the MAX_BUCKET_SETTING cluster setting.
    CompositeAggregationBuilder composite = new CompositeAggregationBuilder(
        agg.getName(),
        input.getAttributeFieldNames()
            .stream()
            .map(field -> new TermsValuesSourceBuilder(field).field(field))
            .collect(Collectors.toList())
    ).size(MAX_BUCKET_SETTING.get(this.settings));
    composite.subAggregation(agg);
    afterKey.ifPresent(composite::aggregateAfter);

    SearchSourceBuilder source = new SearchSourceBuilder()
        .size(0)
        .query(query)
        .aggregation(composite);
    return new SearchRequest(new String[] { input.getIndexName() }, source);
}
Use of org.opensearch.ml.common.parameter.Input in project ml-commons by opensearch-project.
Class AnomalyLocalizerImpl, method processNewEntry.
/**
 * Runs one per-entity search page for {@code bucket} and forwards the outcome.
 *
 * <p>The request is built by {@code newSearchRequestForEntry} (using {@code afterKey} for paging).
 * A successful response is handed to {@code onNewEntryResponse} together with the remaining
 * arguments; a failure is routed to {@code listener.onFailure}.
 *
 * @param input    localization input used to build the request
 * @param agg      aggregation nested under the composite aggregation
 * @param result   passed through to {@code onNewEntryResponse}
 * @param bucket   time bucket being searched
 * @param afterKey composite paging key for this page; empty for the first page
 * @param queue    passed through to {@code onNewEntryResponse}
 * @param output   passed through to {@code onNewEntryResponse}
 * @param listener receives failures from the search
 */
private void processNewEntry(
    AnomalyLocalizationInput input,
    AggregationBuilder agg,
    AnomalyLocalizationOutput.Result result,
    AnomalyLocalizationOutput.Bucket bucket,
    Optional<Map<String, Object>> afterKey,
    PriorityQueue<AnomalyLocalizationOutput.Entity> queue,
    AnomalyLocalizationOutput output,
    ActionListener<AnomalyLocalizationOutput> listener
) {
    client.search(
        newSearchRequestForEntry(input, agg, bucket, afterKey),
        wrap(
            response -> onNewEntryResponse(response, input, agg, result, bucket, queue, output, listener),
            listener::onFailure
        )
    );
}
Use of org.opensearch.ml.common.parameter.Input in project ml-commons by opensearch-project.
Class AnomalyLocalizerImpl, method newSearchRequestForOverallAggregates.
/**
 * Builds a multi-search with one request per time interval, each computing {@code agg} overall.
 *
 * <p>For every interval returned by {@code timeBuckets.getAllIntervals()}, a size-0 search is
 * added whose query is a range filter on the input's time field from the interval's key to its
 * value (both inclusive), AND-ed with the input's optional filter query. Requests are added in the
 * order the intervals are iterated.
 *
 * @param input       supplies the index, time field and optional filter query
 * @param agg         aggregation attached to every per-interval request
 * @param timeBuckets source of the (start, end) intervals
 * @return the assembled multi-search request
 */
private MultiSearchRequest newSearchRequestForOverallAggregates(
    AnomalyLocalizationInput input,
    AggregationBuilder agg,
    LocalizationTimeBuckets timeBuckets
) {
    MultiSearchRequest batch = new MultiSearchRequest();
    timeBuckets.getAllIntervals().forEach(interval -> {
        RangeQueryBuilder withinInterval = new RangeQueryBuilder(input.getTimeFieldName())
            .from(interval.getKey(), true)
            .to(interval.getValue(), true);
        BoolQueryBuilder query = QueryBuilders.boolQuery().filter(withinInterval);
        input.getFilterQuery().ifPresent(query::filter);

        SearchSourceBuilder source = new SearchSourceBuilder()
            .size(0)
            .query(query)
            .aggregation(agg);
        batch.add(new SearchRequest(new String[] { input.getIndexName() }, source));
    });
    return batch;
}
Use of org.opensearch.ml.common.parameter.Input in project ml-commons by opensearch-project.
Class MLEngineTest, method predictWithoutModel.
/**
 * Verifies that a linear-regression prediction with a {@code null} model fails with an
 * {@link IllegalArgumentException} carrying the "no model found" message.
 */
@Test
public void predictWithoutModel() {
    // Expectations must be registered before the call that is supposed to throw.
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("No model found for linear regression prediction.");

    MLInputDataset dataset = DataFrameInputDataset.builder()
        .dataFrame(constructLinearRegressionPredictionDataFrame())
        .build();
    Input input = MLInput.builder()
        .algorithm(FunctionName.LINEAR_REGRESSION)
        .inputDataset(dataset)
        .build();

    MLEngine.predict(input, null);
}
Use of org.opensearch.ml.common.parameter.Input in project ml-commons by opensearch-project.
Class MLEngineTest, method predictLinearRegression.
/**
 * Trains a linear-regression model, predicts on the canned prediction frame, and checks that the
 * prediction result has two rows.
 */
@Test
public void predictLinearRegression() {
    Model trained = trainLinearRegressionModel();
    DataFrame frame = constructLinearRegressionPredictionDataFrame();

    Input input = MLInput.builder()
        .algorithm(FunctionName.LINEAR_REGRESSION)
        .inputDataset(DataFrameInputDataset.builder().dataFrame(frame).build())
        .build();

    MLPredictionOutput prediction = (MLPredictionOutput) MLEngine.predict(input, trained);
    Assert.assertEquals(2, prediction.getPredictionResult().size());
}
Aggregations