Example 11 with Input

Use of org.opensearch.ml.common.parameter.Input in project ml-commons by opensearch-project.

From class MLEngineTest, method predictUnsupportedAlgorithm.

@Test
public void predictUnsupportedAlgorithm() {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Unsupported algorithm: LINEAR_REGRESSION");
    FunctionName algoName = FunctionName.LINEAR_REGRESSION;
    try (MockedStatic<MLEngineClassLoader> loader = Mockito.mockStatic(MLEngineClassLoader.class)) {
        // Stub the class loader so that no implementation is found for the algorithm.
        loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null);
        MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructLinearRegressionPredictionDataFrame()).build();
        Input mlInput = MLInput.builder().algorithm(algoName).inputDataset(inputDataset).build();
        MLEngine.predict(mlInput, null);
    }
}
Also used : FunctionName(org.opensearch.ml.common.parameter.FunctionName) MLInput(org.opensearch.ml.common.parameter.MLInput) Input(org.opensearch.ml.common.parameter.Input) LocalSampleCalculatorInput(org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) Test(org.junit.Test)
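The test expects MLEngine.predict to fail with "Unsupported algorithm: LINEAR_REGRESSION" once the static mock makes MLEngineClassLoader.initInstance return null. Below is a minimal plain-Java sketch of that null-check dispatch, assuming a hypothetical enum-keyed registry in place of the real class loader; AlgoName, Predictor and SketchEngine are illustrative stand-ins, not ml-commons types.

```java
import java.util.EnumMap;
import java.util.Map;

// Hypothetical stand-in for the real FunctionName enum.
enum AlgoName { LINEAR_REGRESSION, KMEANS }

// Hypothetical predictor contract: one prediction per input row.
interface Predictor {
    double[] predict(double[][] rows);
}

class SketchEngine {
    // Hypothetical registry; a missing entry plays the role of
    // initInstance(...) returning null in the mocked test.
    private final Map<AlgoName, Predictor> registry = new EnumMap<>(AlgoName.class);

    void register(AlgoName name, Predictor predictor) {
        registry.put(name, predictor);
    }

    double[] predict(AlgoName name, double[][] rows) {
        Predictor predictor = registry.get(name); // null when not registered
        if (predictor == null) {
            throw new IllegalArgumentException("Unsupported algorithm: " + name);
        }
        return predictor.predict(rows);
    }
}
```

Keeping the lookup and the null check in one place means every unregistered algorithm fails with the same message, so a single expectMessage assertion covers the case.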

Example 12 with Input

Use of org.opensearch.ml.common.parameter.Input in project ml-commons by opensearch-project.

From class MLEngineTest, method trainAndPredictWithInvalidInput.

@Test
public void trainAndPredictWithInvalidInput() {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Input should be MLInput");
    Input input = new LocalSampleCalculatorInput("sum", Arrays.asList(1.0, 2.0));
    MLEngine.trainAndPredict(input);
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) Input(org.opensearch.ml.common.parameter.Input) LocalSampleCalculatorInput(org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) Test(org.junit.Test)
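Here trainAndPredict receives a LocalSampleCalculatorInput, which implements Input but is not an MLInput, so the engine is expected to reject it with "Input should be MLInput". A sketch of that type guard, assuming it is a plain instanceof check; the Demo* types and TrainAndPredictGuard are hypothetical stand-ins for the ml-commons classes.

```java
import java.util.List;

// Hypothetical stand-in for the Input interface.
interface DemoInput {
    String functionName();
}

// Hypothetical stand-in for MLInput: the only type trainAndPredict accepts.
final class DemoMLInput implements DemoInput {
    public String functionName() { return "KMEANS"; }
}

// Hypothetical stand-in for LocalSampleCalculatorInput: an operation plus operands.
final class DemoCalculatorInput implements DemoInput {
    final String operation;
    final List<Double> operands;

    DemoCalculatorInput(String operation, List<Double> operands) {
        this.operation = operation;
        this.operands = operands;
    }

    public String functionName() { return "LOCAL_SAMPLE_CALCULATOR"; }
}

class TrainAndPredictGuard {
    // Rejects any input that is not the ML variant (including null),
    // using the message the test expects.
    static DemoMLInput requireMLInput(DemoInput input) {
        if (!(input instanceof DemoMLInput)) {
            throw new IllegalArgumentException("Input should be MLInput");
        }
        return (DemoMLInput) input;
    }
}
```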

Example 13 with Input

Use of org.opensearch.ml.common.parameter.Input in project ml-commons by opensearch-project.

From class MLEngineTest, method executeWithInvalidInput.

@Test
public void executeWithInvalidInput() {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Function name should not be null");
    Input input = new Input() {

        @Override
        public FunctionName getFunctionName() {
            return null;
        }

        @Override
        public void writeTo(StreamOutput streamOutput) throws IOException {
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
            return null;
        }
    };
    MLEngine.execute(input);
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) Input(org.opensearch.ml.common.parameter.Input) LocalSampleCalculatorInput(org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) MLAlgoParams(org.opensearch.ml.common.parameter.MLAlgoParams) LinearRegressionParams(org.opensearch.ml.common.parameter.LinearRegressionParams) KMeansParams(org.opensearch.ml.common.parameter.KMeansParams) StreamOutput(org.opensearch.common.io.stream.StreamOutput) XContentBuilder(org.opensearch.common.xcontent.XContentBuilder) Test(org.junit.Test)
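The anonymous Input in this test only needs getFunctionName() to return null; writeTo and toXContent are stubbed out. A sketch of the validation step, assuming the engine checks the function name before dispatching; NamedInput and ExecuteSketch are hypothetical and narrow the contract to that single method.

```java
// Hypothetical minimal Input contract: only the function name is needed to dispatch.
interface NamedInput {
    String getFunctionName();
}

class ExecuteSketch {
    // Validates the function name first, using the message the test expects.
    static String execute(NamedInput input) {
        String functionName = input.getFunctionName();
        if (functionName == null) {
            throw new IllegalArgumentException("Function name should not be null");
        }
        // A real engine would look up and run an executor here; the sketch just echoes the name.
        return "executed " + functionName;
    }
}
```

Because NamedInput has a single abstract method, a lambda such as () -> "KMEANS" can stand in for it. The real Input also carries writeTo and toXContent, which is presumably why the test uses an anonymous class instead.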

Example 14 with Input

Use of org.opensearch.ml.common.parameter.Input in project ml-commons by opensearch-project.

From class MLEngineTest, method trainKMeansModel.

private Model trainKMeansModel() {
    KMeansParams parameters = KMeansParams.builder().centroids(2).iterations(10).distanceType(KMeansParams.DistanceType.EUCLIDEAN).build();
    DataFrame trainDataFrame = constructKMeansDataFrame(100);
    MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(trainDataFrame).build();
    Input mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).parameters(parameters).inputDataset(inputDataset).build();
    return MLEngine.train(mlInput);
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) Input(org.opensearch.ml.common.parameter.Input) LocalSampleCalculatorInput(org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) KMeansParams(org.opensearch.ml.common.parameter.KMeansParams) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) LinearRegressionHelper.constructLinearRegressionPredictionDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame) KMeansHelper.constructKMeansDataFrame(org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame) LinearRegressionHelper.constructLinearRegressionTrainDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame) DataFrame(org.opensearch.ml.common.dataframe.DataFrame)
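trainKMeansModel asks for 2 centroids, at most 10 iterations and Euclidean distance over a 100-row data frame. To show what those three parameters control, here is a self-contained sketch of Lloyd's k-means algorithm. It is not the ml-commons implementation, and seeding the centroids with the first k points is a simplifying assumption (implementations usually seed randomly or with k-means++).

```java
class KMeansSketch {
    // Squared Euclidean distance; the square root is not needed to pick the nearest centroid.
    static double squaredDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    // Index of the centroid nearest to the point.
    static int predict(double[][] centroids, double[] point) {
        int best = 0;
        for (int c = 1; c < centroids.length; c++) {
            if (squaredDistance(point, centroids[c]) < squaredDistance(point, centroids[best])) {
                best = c;
            }
        }
        return best;
    }

    // Runs at most `iterations` assignment/update rounds and returns the centroids.
    // Assumes points.length >= k; seeds with the first k points.
    static double[][] train(double[][] points, int k, int iterations) {
        int dims = points[0].length;
        double[][] centroids = new double[k][];
        for (int c = 0; c < k; c++) {
            centroids[c] = points[c].clone();
        }
        int[] assignment = new int[points.length];
        for (int iter = 0; iter < iterations; iter++) {
            // Assignment step: attach every point to its nearest centroid.
            boolean changed = false;
            for (int p = 0; p < points.length; p++) {
                int best = predict(centroids, points[p]);
                if (assignment[p] != best) {
                    assignment[p] = best;
                    changed = true;
                }
            }
            if (iter > 0 && !changed) {
                break; // converged before reaching the iteration cap
            }
            // Update step: move each centroid to the mean of its points.
            double[][] sums = new double[k][dims];
            int[] counts = new int[k];
            for (int p = 0; p < points.length; p++) {
                counts[assignment[p]]++;
                for (int d = 0; d < dims; d++) {
                    sums[assignment[p]][d] += points[p][d];
                }
            }
            for (int c = 0; c < k; c++) {
                if (counts[c] > 0) { // an empty cluster keeps its previous centroid
                    for (int d = 0; d < dims; d++) {
                        centroids[c][d] = sums[c][d] / counts[c];
                    }
                }
            }
        }
        return centroids;
    }
}
```

Calling train(points, 2, 10) and then predict(centroids, point) gives the index of the nearest centroid, which is the cluster label a trained model would assign to that point.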

Example 15 with Input

Use of org.opensearch.ml.common.parameter.Input in project ml-commons by opensearch-project.

From class AnomalyLocalizerImpl, method newSearchRequestForEntityKeys.

private SearchRequest newSearchRequestForEntityKeys(AnomalyLocalizationInput input, AggregationBuilder agg, AnomalyLocalizationOutput.Bucket bucket, List<List<String>> keys) {
    // Restrict hits to the base interval of this bucket, both endpoints inclusive.
    RangeQueryBuilder timeRangeFilter = new RangeQueryBuilder(input.getTimeFieldName()).from(bucket.getBase().get().getStartTime(), true).to(bucket.getBase().get().getEndTime(), true);
    BoolQueryBuilder filter = QueryBuilders.boolQuery().filter(timeRangeFilter);
    // AND in the caller's optional filter query.
    input.getFilterQuery().ifPresent(q -> filter.filter(q));
    // One keyed filter per entity key, named by its position in keys ("0", "1", ...).
    KeyedFilter[] filters = IntStream.range(0, keys.size()).mapToObj(i -> new KeyedFilter(Integer.toString(i), newQueryByKey(keys.get(i), input))).toArray(KeyedFilter[]::new);
    // Run the same sub-aggregation inside every keyed filter bucket.
    FiltersAggregationBuilder filtersAgg = AggregationBuilders.filters(agg.getName(), filters);
    filtersAgg.subAggregation(agg);
    // size(0): return aggregation results only, no hits.
    SearchSourceBuilder search = new SearchSourceBuilder().size(0).query(filter).aggregation(filtersAgg);
    SearchRequest searchRequest = new SearchRequest(new String[] { input.getIndexName() }, search);
    return searchRequest;
}
Also used : KeyedFilter(org.opensearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter) IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) MultiSearchRequest(org.opensearch.action.search.MultiSearchRequest) SneakyThrows(lombok.SneakyThrows) PriorityQueue(java.util.PriorityQueue) NotifyOnceListener(org.opensearch.action.NotifyOnceListener) Executable(org.opensearch.ml.engine.Executable) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) AtomicReference(java.util.concurrent.atomic.AtomicReference) Input(org.opensearch.ml.common.parameter.Input) ArrayList(java.util.ArrayList) CompositeAggregationBuilder(org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder) Output(org.opensearch.ml.common.parameter.Output) AggregationBuilder(org.opensearch.search.aggregations.AggregationBuilder) LatchedActionListener(org.opensearch.action.LatchedActionListener) Map(java.util.Map) SearchRequest(org.opensearch.action.search.SearchRequest) ActionListener(org.opensearch.action.ActionListener) SearchResponse(org.opensearch.action.search.SearchResponse) SimpleEntry(java.util.AbstractMap.SimpleEntry) QueryBuilders(org.opensearch.index.query.QueryBuilders) Client(org.opensearch.client.Client) MAX_BUCKET_SETTING(org.opensearch.search.aggregations.MultiBucketConsumerService.MAX_BUCKET_SETTING) RangeQueryBuilder(org.opensearch.index.query.RangeQueryBuilder) Filters(org.opensearch.search.aggregations.bucket.filter.Filters) TermQueryBuilder(org.opensearch.index.query.TermQueryBuilder) Settings(org.opensearch.common.settings.Settings) FiltersAggregationBuilder(org.opensearch.search.aggregations.bucket.filter.FiltersAggregationBuilder) Collectors(java.util.stream.Collectors) CompositeAggregation(org.opensearch.search.aggregations.bucket.composite.CompositeAggregation) MultiSearchResponse(org.opensearch.action.search.MultiSearchResponse) CountDownLatch(java.util.concurrent.CountDownLatch) AggregationBuilders(org.opensearch.search.aggregations.AggregationBuilders) List(java.util.List) SingleValue(org.opensearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue) SearchSourceBuilder(org.opensearch.search.builder.SearchSourceBuilder) Data(lombok.Data) Log4j2(lombok.extern.log4j.Log4j2) Optional(java.util.Optional) ActionListener.wrap(org.opensearch.action.ActionListener.wrap) Collections(java.util.Collections) TermsValuesSourceBuilder(org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder) BoolQueryBuilder(org.opensearch.index.query.BoolQueryBuilder)
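newSearchRequestForEntityKeys restricts documents to the base bucket's time window (both ends inclusive), ANDs in the optional filter query, and wraps the caller's aggregation in a filters aggregation with one keyed filter per entity key, named "0", "1", and so on; size(0) suppresses the hits themselves. The sketch below evaluates the same shape over an in-memory list instead of issuing a search. It assumes, hypothetically, that newQueryByKey (not shown here) matches documents whose attribute fields equal the key values one-to-one, and that the sub-aggregation is a sum; Doc and KeyedFilterSketch are illustrative names.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;

// Hypothetical document: a timestamp, entity attribute values, and a numeric metric.
final class Doc {
    final long time;
    final Map<String, String> attributes;
    final double value;

    Doc(long time, Map<String, String> attributes, double value) {
        this.time = time;
        this.attributes = attributes;
        this.value = value;
    }
}

class KeyedFilterSketch {
    // Returns bucket name ("0", "1", ...) -> sum of value over the matching docs.
    static Map<String, Double> aggregate(List<Doc> docs, long start, long end,
                                         Optional<Predicate<Doc>> filterQuery,
                                         List<String> attributeFields, List<List<String>> keys) {
        Map<String, Double> buckets = new LinkedHashMap<>();
        for (int i = 0; i < keys.size(); i++) {
            List<String> key = keys.get(i);
            double sum = 0.0;
            for (Doc doc : docs) {
                // Top-level query: inclusive time range AND the optional extra filter.
                boolean inRange = doc.time >= start && doc.time <= end;
                boolean passesFilter = filterQuery.map(f -> f.test(doc)).orElse(true);
                if (inRange && passesFilter && matchesKey(doc, attributeFields, key)) {
                    sum += doc.value; // assumed sub-aggregation: sum
                }
            }
            // Bucket names follow Integer.toString(i), as in the original method.
            buckets.put(Integer.toString(i), sum);
        }
        return buckets;
    }

    // Assumed semantics of newQueryByKey: each attribute field equals the matching key value.
    static boolean matchesKey(Doc doc, List<String> attributeFields, List<String> key) {
        for (int j = 0; j < attributeFields.size(); j++) {
            if (!key.get(j).equals(doc.attributes.get(attributeFields.get(j)))) {
                return false;
            }
        }
        return true;
    }
}
```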

Aggregations

Input (org.opensearch.ml.common.parameter.Input): 23
LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput): 13
MLInput (org.opensearch.ml.common.parameter.MLInput): 12
Test (org.junit.Test): 10
Output (org.opensearch.ml.common.parameter.Output): 10
Client (org.opensearch.client.Client): 9
Settings (org.opensearch.common.settings.Settings): 9
ArrayList (java.util.ArrayList): 8
List (java.util.List): 8
Data (lombok.Data): 8
ActionListener (org.opensearch.action.ActionListener): 8
SearchRequest (org.opensearch.action.search.SearchRequest): 8
SearchResponse (org.opensearch.action.search.SearchResponse): 8
MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset): 8
Executable (org.opensearch.ml.engine.Executable): 8
SimpleEntry (java.util.AbstractMap.SimpleEntry): 7
Arrays (java.util.Arrays): 7
Collections (java.util.Collections): 7
Map (java.util.Map): 7
Optional (java.util.Optional): 7