Search in sources:

Example 6 with MLInputDataset

Use of org.opensearch.ml.common.dataset.MLInputDataset in the ml-commons project by opensearch-project.

Source: class MLEngineTest, method predictWithoutAlgoName.

@Test
public void predictWithoutAlgoName() {
    // The input below never sets an algorithm; predict is expected to reject it.
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("algorithm can't be null");

    MLInputDataset dataset = DataFrameInputDataset.builder()
        .dataFrame(constructKMeansDataFrame(10))
        .build();
    Input inputWithoutAlgorithm = MLInput.builder()
        .inputDataset(dataset)
        .build();

    MLEngine.predict(inputWithoutAlgorithm, null);
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) Input(org.opensearch.ml.common.parameter.Input) LocalSampleCalculatorInput(org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) Test(org.junit.Test)

Example 7 with MLInputDataset

Use of org.opensearch.ml.common.dataset.MLInputDataset in the ml-commons project by opensearch-project.

Source: class MLEngineTest, method train_EmptyDataFrame.

@Test
public void train_EmptyDataFrame() {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Input data frame should not be null or empty");
    FunctionName algorithm = FunctionName.LINEAR_REGRESSION;
    try (MockedStatic<MLEngineClassLoader> mockedClassLoader = Mockito.mockStatic(MLEngineClassLoader.class)) {
        // Stub the static loader so initInstance returns null for this algorithm.
        mockedClassLoader.when(() -> MLEngineClassLoader.initInstance(algorithm, null, MLAlgoParams.class)).thenReturn(null);
        // constructKMeansDataFrame(0) is expected to yield an empty data frame, which train must reject.
        MLInputDataset emptyDataset = DataFrameInputDataset.builder().dataFrame(constructKMeansDataFrame(0)).build();
        MLInput trainInput = MLInput.builder().algorithm(algorithm).inputDataset(emptyDataset).build();
        MLEngine.train(trainInput);
    }
}
Also used : FunctionName(org.opensearch.ml.common.parameter.FunctionName) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) Test(org.junit.Test)

Example 8 with MLInputDataset

Use of org.opensearch.ml.common.dataset.MLInputDataset in the ml-commons project by opensearch-project.

Source: class MLEngineTest, method predictUnsupportedAlgorithm.

@Test
public void predictUnsupportedAlgorithm() {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Unsupported algorithm: LINEAR_REGRESSION");
    FunctionName algorithm = FunctionName.LINEAR_REGRESSION;
    try (MockedStatic<MLEngineClassLoader> mockedClassLoader = Mockito.mockStatic(MLEngineClassLoader.class)) {
        // With initInstance stubbed to return null, predict is expected to report the algorithm as unsupported.
        mockedClassLoader.when(() -> MLEngineClassLoader.initInstance(algorithm, null, MLAlgoParams.class)).thenReturn(null);
        MLInputDataset predictionDataset = DataFrameInputDataset.builder()
            .dataFrame(constructLinearRegressionPredictionDataFrame())
            .build();
        Input predictInput = MLInput.builder()
            .algorithm(algorithm)
            .inputDataset(predictionDataset)
            .build();
        MLEngine.predict(predictInput, null);
    }
}
Also used : FunctionName(org.opensearch.ml.common.parameter.FunctionName) MLInput(org.opensearch.ml.common.parameter.MLInput) Input(org.opensearch.ml.common.parameter.Input) LocalSampleCalculatorInput(org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) Test(org.junit.Test)

Example 9 with MLInputDataset

Use of org.opensearch.ml.common.dataset.MLInputDataset in the ml-commons project by opensearch-project.

Source: class MLEngineTest, method trainKMeansModel.

/**
 * Trains a KMEANS model through MLEngine on the data frame returned by
 * {@code constructKMeansDataFrame(100)}, using 2 centroids, 10 iterations and
 * Euclidean distance.
 */
private Model trainKMeansModel() {
    KMeansParams kMeansParams = KMeansParams.builder()
        .centroids(2)
        .iterations(10)
        .distanceType(KMeansParams.DistanceType.EUCLIDEAN)
        .build();
    MLInputDataset trainingDataset = DataFrameInputDataset.builder()
        .dataFrame(constructKMeansDataFrame(100))
        .build();
    Input trainInput = MLInput.builder()
        .algorithm(FunctionName.KMEANS)
        .parameters(kMeansParams)
        .inputDataset(trainingDataset)
        .build();
    return MLEngine.train(trainInput);
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) Input(org.opensearch.ml.common.parameter.Input) LocalSampleCalculatorInput(org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) KMeansParams(org.opensearch.ml.common.parameter.KMeansParams) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) LinearRegressionHelper.constructLinearRegressionPredictionDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame) KMeansHelper.constructKMeansDataFrame(org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame) LinearRegressionHelper.constructLinearRegressionTrainDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame) DataFrame(org.opensearch.ml.common.dataframe.DataFrame)

Example 10 with MLInputDataset

Use of org.opensearch.ml.common.dataset.MLInputDataset in the ml-commons project by opensearch-project.

Source: class MLInputDatasetHandler, method parseSearchQueryInput.

/**
 * Create DataFrame based on given search query.
 *
 * <p>Runs the search described by a {@code SEARCH_QUERY} input dataset against its indices
 * and converts the {@code _source} of every returned hit into one row of a {@link DataFrame}.
 * The search is asynchronous: the DataFrame, or the failure, is delivered to {@code listener}.
 * If the search returns no hits, the listener receives an {@link IllegalArgumentException}
 * with message "No document found".
 *
 * @param mlInputDataset MLInputDataset; must be of type {@link MLInputDataType#SEARCH_QUERY}
 * @param listener ActionListener notified with the built DataFrame or with the failure
 * @throws IllegalArgumentException synchronously, if the dataset is not SEARCH_QUERY type
 */
public void parseSearchQueryInput(MLInputDataset mlInputDataset, ActionListener<DataFrame> listener) {
    if (mlInputDataset.getInputDataType() != MLInputDataType.SEARCH_QUERY) {
        throw new IllegalArgumentException("Input dataset is not SEARCH_QUERY type.");
    }
    SearchQueryInputDataset inputDataset = (SearchQueryInputDataset) mlInputDataset;
    SearchRequest searchRequest = new SearchRequest();
    searchRequest.source(inputDataset.getSearchSourceBuilder());
    searchRequest.indices(inputDataset.getIndices().toArray(new String[0]));
    client.search(searchRequest, ActionListener.wrap(r -> {
        // Decide emptiness from the hits actually returned (the data the DataFrame is built from),
        // not from totalHits: totalHits can be null when total-hit tracking is disabled even
        // though documents were returned.
        SearchHit[] searchHits = (r == null || r.getHits() == null) ? null : r.getHits().getHits();
        if (searchHits == null || searchHits.length == 0) {
            listener.onFailure(new IllegalArgumentException("No document found"));
            return;
        }
        List<Map<String, Object>> input = new ArrayList<>(searchHits.length);
        for (SearchHit hit : searchHits) {
            input.add(hit.getSourceAsMap());
        }
        listener.onResponse(DataFrameBuilder.load(input));
    }, e -> {
        // Pass the exception as the Throwable argument so its stack trace is logged,
        // instead of concatenating its toString() into the message.
        log.error("Failed to search", e);
        listener.onFailure(e);
    }));
}
Also used : FieldDefaults(lombok.experimental.FieldDefaults) Client(org.opensearch.client.Client) SearchHit(org.opensearch.search.SearchHit) RequiredArgsConstructor(lombok.RequiredArgsConstructor) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) SearchHits(org.opensearch.search.SearchHits) ArrayList(java.util.ArrayList) List(java.util.List) AccessLevel(lombok.AccessLevel) Map(java.util.Map) Log4j2(lombok.extern.log4j.Log4j2) SearchRequest(org.opensearch.action.search.SearchRequest) ActionListener(org.opensearch.action.ActionListener) DataFrameBuilder(org.opensearch.ml.common.dataframe.DataFrameBuilder) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) SearchQueryInputDataset(org.opensearch.ml.common.dataset.SearchQueryInputDataset) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) MLInputDataType(org.opensearch.ml.common.dataset.MLInputDataType) SearchRequest(org.opensearch.action.search.SearchRequest) SearchQueryInputDataset(org.opensearch.ml.common.dataset.SearchQueryInputDataset) SearchHit(org.opensearch.search.SearchHit) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) ArrayList(java.util.ArrayList) List(java.util.List) SearchHits(org.opensearch.search.SearchHits)

Aggregations

MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset)21 MLInput (org.opensearch.ml.common.parameter.MLInput)13 Test (org.junit.Test)9 Input (org.opensearch.ml.common.parameter.Input)8 LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput)8 DataFrame (org.opensearch.ml.common.dataframe.DataFrame)7 SearchQueryInputDataset (org.opensearch.ml.common.dataset.SearchQueryInputDataset)5 KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame)5 LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame)5 LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame)5 SearchSourceBuilder (org.opensearch.search.builder.SearchSourceBuilder)5 IntegTestUtils.generateSearchSourceBuilder (org.opensearch.ml.utils.IntegTestUtils.generateSearchSourceBuilder)4 Response (org.opensearch.client.Response)3 DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset)3 FunctionName (org.opensearch.ml.common.parameter.FunctionName)3 MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput)3 Map (java.util.Map)2 KMeansParams (org.opensearch.ml.common.parameter.KMeansParams)2 Model (org.opensearch.ml.common.parameter.Model)2 ImmutableMap (com.google.common.collect.ImmutableMap)1