
Example 11 with MLInputDataset

Usage of org.opensearch.ml.common.dataset.MLInputDataset in the project ml-commons by opensearch-project.

Class MLPredictionTaskRequestTest, method writeTo_Success.

@Test
public void writeTo_Success() throws IOException {
    // mlInput is a test fixture defined elsewhere in the test class (not shown).
    MLPredictionTaskRequest request = MLPredictionTaskRequest.builder().mlInput(mlInput).build();
    // Serialize the request, then rebuild it from the written bytes.
    BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
    request.writeTo(bytesStreamOutput);
    request = new MLPredictionTaskRequest(bytesStreamOutput.bytes().streamInput());
    // The algorithm and its parameters survive the round trip.
    assertEquals(FunctionName.KMEANS, request.getMlInput().getAlgorithm());
    KMeansParams params = (KMeansParams) request.getMlInput().getParameters();
    assertEquals(1, params.getCentroids().intValue());
    // The input dataset is still a one-row, one-column DOUBLE data frame.
    MLInputDataset inputDataset = request.getMlInput().getInputDataset();
    assertEquals(MLInputDataType.DATA_FRAME, inputDataset.getInputDataType());
    DataFrame dataFrame = ((DataFrameInputDataset) inputDataset).getDataFrame();
    assertEquals(1, dataFrame.size());
    assertEquals(1, dataFrame.columnMetas().length);
    assertEquals("key1", dataFrame.columnMetas()[0].getName());
    assertEquals(ColumnType.DOUBLE, dataFrame.columnMetas()[0].getColumnType());
    assertEquals(1, dataFrame.getRow(0).size());
    assertEquals(2.00, dataFrame.getRow(0).getValue(0).getValue());
    // No model ID was set, so none is deserialized.
    assertNull(request.getModelId());
}
Also used : KMeansParams(org.opensearch.ml.common.parameter.KMeansParams) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) BytesStreamOutput(org.opensearch.common.io.stream.BytesStreamOutput) Test(org.junit.Test)
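The test above checks that `writeTo` followed by the `StreamInput` constructor reproduces the request. Below is a minimal, stdlib-only sketch of that serialization round-trip pattern. `SampleRequest`, its fields, and the use of `DataOutputStream`/`DataInputStream` in place of OpenSearch's `StreamOutput`/`StreamInput` are assumptions for illustration, not the ml-commons implementation.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

// Hypothetical request type standing in for a Writeable such as MLPredictionTaskRequest.
final class SampleRequest {
    final String algorithm; // e.g. "KMEANS"
    final int centroids;    // a single algorithm parameter
    final String modelId;   // optional; may be null

    SampleRequest(String algorithm, int centroids, String modelId) {
        this.algorithm = algorithm;
        this.centroids = centroids;
        this.modelId = modelId;
    }

    // Deserialize in exactly the order writeTo wrote the fields (the "StreamInput constructor" idea).
    SampleRequest(DataInputStream in) throws IOException {
        this.algorithm = in.readUTF();
        this.centroids = in.readInt();
        this.modelId = in.readBoolean() ? in.readUTF() : null;
    }

    // Serialize fields in a fixed order (the "writeTo(StreamOutput)" idea).
    void writeTo(DataOutputStream out) throws IOException {
        out.writeUTF(algorithm);
        out.writeInt(centroids);
        out.writeBoolean(modelId != null); // presence flag for the optional field
        if (modelId != null) {
            out.writeUTF(modelId);
        }
    }

    // Write to an in-memory buffer and read it back, like BytesStreamOutput.bytes().streamInput().
    static SampleRequest roundTrip(SampleRequest request) {
        try {
            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            try (DataOutputStream out = new DataOutputStream(buffer)) {
                request.writeTo(out);
            }
            try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()))) {
                return new SampleRequest(in);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

The presence flag before the optional field is what lets a `null` model ID survive the round trip, which is the property the final `assertNull` in the test relies on.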

Example 12 with MLInputDataset

Usage of org.opensearch.ml.common.dataset.MLInputDataset in the project ml-commons by opensearch-project.

Class MLEngineTest, method predictWithoutModel.

@Test
public void predictWithoutModel() {
    // exceptionRule is a JUnit ExpectedException @Rule declared in the test class (not shown).
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("No model found for linear regression prediction.");
    MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructLinearRegressionPredictionDataFrame()).build();
    Input mlInput = MLInput.builder().algorithm(FunctionName.LINEAR_REGRESSION).inputDataset(inputDataset).build();
    // Passing a null model should trigger the expected exception.
    MLEngine.predict(mlInput, null);
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) Input(org.opensearch.ml.common.parameter.Input) LocalSampleCalculatorInput(org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) Test(org.junit.Test)
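The test expects `MLEngine.predict` to reject a `null` model with an `IllegalArgumentException`. Assuming the guard is a plain null check before prediction (the ml-commons source is not shown here), the pattern and a JUnit-free way to check it look roughly like this. `PredictGuardSketch`, `predictLinear`, and the `{intercept, slope}` model encoding are hypothetical.

```java
import java.util.Objects;

final class PredictGuardSketch {
    // Hypothetical predictor: "model" is a coefficient array {intercept, slope}.
    static double[] predictLinear(double[] features, double[] model) {
        if (model == null) {
            // Fail fast with the same message the test asserts on.
            throw new IllegalArgumentException("No model found for linear regression prediction.");
        }
        Objects.requireNonNull(features, "features");
        double[] predictions = new double[features.length];
        for (int i = 0; i < features.length; i++) {
            predictions[i] = model[0] + model[1] * features[i];
        }
        return predictions;
    }

    // Plain try/catch equivalent of ExpectedException.expect(...) plus expectMessage(...).
    static String expectIllegalArgument(Runnable action) {
        try {
            action.run();
        } catch (IllegalArgumentException e) {
            return e.getMessage();
        }
        throw new AssertionError("expected IllegalArgumentException");
    }
}
```

With JUnit 4.13 or later, `org.junit.Assert.assertThrows` performs the same check without an `ExpectedException` rule and returns the thrown exception for message assertions.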

Example 13 with MLInputDataset

Usage of org.opensearch.ml.common.dataset.MLInputDataset in the project ml-commons by opensearch-project.

Class MLEngineTest, method train_UnsupportedAlgorithm.

@Test
public void train_UnsupportedAlgorithm() {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Unsupported algorithm: LINEAR_REGRESSION");
    FunctionName algoName = FunctionName.LINEAR_REGRESSION;
    // Stub the static class loader so no implementation is returned for the algorithm.
    // The static stub is active only inside this try-with-resources block.
    try (MockedStatic<MLEngineClassLoader> loader = Mockito.mockStatic(MLEngineClassLoader.class)) {
        loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null);
        MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructKMeansDataFrame(10)).build();
        MLEngine.train(MLInput.builder().algorithm(algoName).inputDataset(inputDataset).build());
    }
}
Also used : FunctionName(org.opensearch.ml.common.parameter.FunctionName) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) Test(org.junit.Test)
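This test stubs `MLEngineClassLoader.initInstance` to return `null` and expects `MLEngine.train` to fail with "Unsupported algorithm: LINEAR_REGRESSION", which suggests a lookup-then-null-check pattern. The stdlib-only sketch below illustrates that pattern under stated assumptions: `Algo`, `Trainer`, and the registry map are hypothetical stand-ins rather than ml-commons types, and a `Map` lookup replaces the reflective class loader.

```java
import java.util.EnumMap;
import java.util.Map;
import java.util.function.Supplier;

final class AlgorithmRegistrySketch {
    // Hypothetical algorithm names, loosely modeled on FunctionName.
    enum Algo { KMEANS, LINEAR_REGRESSION }

    // Hypothetical trainer interface: returns a description of the trained model.
    interface Trainer { String train(double[][] data); }

    private final Map<Algo, Supplier<Trainer>> registry = new EnumMap<>(Algo.class);

    void register(Algo algo, Supplier<Trainer> factory) {
        registry.put(algo, factory);
    }

    // A missing entry yields null, mirroring the stubbed initInstance in the test.
    private Trainer initInstance(Algo algo) {
        Supplier<Trainer> factory = registry.get(algo);
        return factory == null ? null : factory.get();
    }

    String train(Algo algo, double[][] data) {
        Trainer trainer = initInstance(algo);
        if (trainer == null) {
            // Enum toString() is the constant name, e.g. "LINEAR_REGRESSION".
            throw new IllegalArgumentException("Unsupported algorithm: " + algo);
        }
        return trainer.train(data);
    }
}
```

In the real test, `Mockito.mockStatic` (available since Mockito 3.4 with the inline mock maker) forces the `null` lookup; the sketch reaches the same state simply by never registering a factory for `LINEAR_REGRESSION`.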

Example 14 with MLInputDataset

Usage of org.opensearch.ml.common.dataset.MLInputDataset in the project ml-commons by opensearch-project.

Class MLEngineTest, method predictLinearRegression.

@Test
public void predictLinearRegression() {
    Model model = trainLinearRegressionModel();
    DataFrame predictionDataFrame = constructLinearRegressionPredictionDataFrame();
    MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(predictionDataFrame).build();
    Input mlInput = MLInput.builder().algorithm(FunctionName.LINEAR_REGRESSION).inputDataset(inputDataset).build();
    MLPredictionOutput output = (MLPredictionOutput) MLEngine.predict(mlInput, model);
    DataFrame predictions = output.getPredictionResult();
    Assert.assertEquals(2, predictions.size());
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) Input(org.opensearch.ml.common.parameter.Input) LocalSampleCalculatorInput(org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) Model(org.opensearch.ml.common.parameter.Model) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput) LinearRegressionHelper.constructLinearRegressionPredictionDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame) KMeansHelper.constructKMeansDataFrame(org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame) LinearRegressionHelper.constructLinearRegressionTrainDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) Test(org.junit.Test)
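`predictLinearRegression` trains a model and then predicts one value per row of the prediction data frame, so the output size equals the input row count. As a hedged illustration of that train-then-predict contract, the sketch fits a one-feature ordinary least-squares line in closed form. `SimpleLinearRegression` and the sample data in the usage check are hypothetical, and the linear regression used by ml-commons (not shown here) may use a different solver and more features.

```java
final class SimpleLinearRegression {
    final double intercept;
    final double slope;

    private SimpleLinearRegression(double intercept, double slope) {
        this.intercept = intercept;
        this.slope = slope;
    }

    // Closed-form ordinary least squares for y = intercept + slope * x.
    static SimpleLinearRegression train(double[] x, double[] y) {
        if (x.length != y.length || x.length < 2) {
            throw new IllegalArgumentException("need at least two paired samples");
        }
        double meanX = 0;
        double meanY = 0;
        for (int i = 0; i < x.length; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= x.length;
        meanY /= y.length;
        double sxy = 0;
        double sxx = 0;
        for (int i = 0; i < x.length; i++) {
            sxy += (x[i] - meanX) * (y[i] - meanY);
            sxx += (x[i] - meanX) * (x[i] - meanX);
        }
        if (sxx == 0) {
            throw new IllegalArgumentException("x values must not all be equal");
        }
        double slope = sxy / sxx;
        return new SimpleLinearRegression(meanY - slope * meanX, slope);
    }

    // One prediction per input row, so the result length equals rows.length.
    double[] predict(double[] rows) {
        double[] out = new double[rows.length];
        for (int i = 0; i < rows.length; i++) {
            out[i] = intercept + slope * rows[i];
        }
        return out;
    }
}
```

Training on points that lie exactly on y = 2x + 1 and predicting two new rows returns two values, matching the `assertEquals(2, predictions.size())` shape of the original test.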

Example 15 with MLInputDataset

Usage of org.opensearch.ml.common.dataset.MLInputDataset in the project ml-commons by opensearch-project.

Class MLEngineTest, method trainAndPredictWithKmeans.

@Test
public void trainAndPredictWithKmeans() {
    int dataSize = 100;
    MLAlgoParams parameters = KMeansParams.builder().build();
    DataFrame dataFrame = constructKMeansDataFrame(dataSize);
    MLInputDataset inputData = new DataFrameInputDataset(dataFrame);
    Input input = new MLInput(FunctionName.KMEANS, parameters, inputData);
    MLPredictionOutput output = (MLPredictionOutput) MLEngine.trainAndPredict(input);
    Assert.assertEquals(dataSize, output.getPredictionResult().size());
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) Input(org.opensearch.ml.common.parameter.Input) LocalSampleCalculatorInput(org.opensearch.ml.common.parameter.LocalSampleCalculatorInput) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput) LinearRegressionHelper.constructLinearRegressionPredictionDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame) KMeansHelper.constructKMeansDataFrame(org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame) LinearRegressionHelper.constructLinearRegressionTrainDataFrame(org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) MLAlgoParams(org.opensearch.ml.common.parameter.MLAlgoParams) Test(org.junit.Test)
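`trainAndPredict` fits a model and labels the same rows in one call, which is why the test expects exactly `dataSize` predictions. A minimal Lloyd's-algorithm k-means sketch over 1-D points illustrates that contract. `KMeansSketch`, the deterministic initialization from the first k points, and the fixed iteration count are assumptions for illustration, not the defaults behind `KMeansParams.builder().build()` or the ml-commons implementation.

```java
import java.util.Arrays;

final class KMeansSketch {
    // Train on the points and return one cluster label per point (train + predict in one step).
    static int[] trainAndPredict(double[] points, int k, int iterations) {
        if (k < 1 || points.length < k) {
            throw new IllegalArgumentException("need k >= 1 and at least k points");
        }
        // Deterministic initialization: the first k points become the initial centroids.
        double[] centroids = Arrays.copyOf(points, k);
        int[] labels = new int[points.length];
        for (int iter = 0; iter < iterations; iter++) {
            // Assignment step: each point goes to its nearest centroid.
            for (int i = 0; i < points.length; i++) {
                labels[i] = nearest(points[i], centroids);
            }
            // Update step: move each centroid to the mean of its assigned points.
            double[] sums = new double[k];
            int[] counts = new int[k];
            for (int i = 0; i < points.length; i++) {
                sums[labels[i]] += points[i];
                counts[labels[i]]++;
            }
            for (int c = 0; c < k; c++) {
                if (counts[c] > 0) {
                    centroids[c] = sums[c] / counts[c]; // empty clusters keep their old centroid
                }
            }
        }
        // Final assignment against the last centroids: one label per input row.
        for (int i = 0; i < points.length; i++) {
            labels[i] = nearest(points[i], centroids);
        }
        return labels;
    }

    // Index of the centroid closest to the point (ties go to the lower index).
    private static int nearest(double point, double[] centroids) {
        int best = 0;
        for (int c = 1; c < centroids.length; c++) {
            if (Math.abs(point - centroids[c]) < Math.abs(point - centroids[best])) {
                best = c;
            }
        }
        return best;
    }
}
```

Whatever the clustering quality, the output array always has one entry per input row, which is the only property the original test asserts.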

Aggregations

MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset): 21
MLInput (org.opensearch.ml.common.parameter.MLInput): 13
Test (org.junit.Test): 9
Input (org.opensearch.ml.common.parameter.Input): 8
LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput): 8
DataFrame (org.opensearch.ml.common.dataframe.DataFrame): 7
SearchQueryInputDataset (org.opensearch.ml.common.dataset.SearchQueryInputDataset): 5
KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame): 5
LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame): 5
LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame): 5
SearchSourceBuilder (org.opensearch.search.builder.SearchSourceBuilder): 5
IntegTestUtils.generateSearchSourceBuilder (org.opensearch.ml.utils.IntegTestUtils.generateSearchSourceBuilder): 4
Response (org.opensearch.client.Response): 3
DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset): 3
FunctionName (org.opensearch.ml.common.parameter.FunctionName): 3
MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput): 3
Map (java.util.Map): 2
KMeansParams (org.opensearch.ml.common.parameter.KMeansParams): 2
Model (org.opensearch.ml.common.parameter.Model): 2
ImmutableMap (com.google.common.collect.ImmutableMap): 1