Example 6 with KMeansParams

Use of org.opensearch.ml.common.parameter.KMeansParams in project ml-commons by opensearch-project.

In class MLEngineTest, method trainKMeansModel.

private Model trainKMeansModel() {
    // Configure K-Means: 2 centroids, 10 iterations, Euclidean distance.
    KMeansParams parameters = KMeansParams.builder()
        .centroids(2)
        .iterations(10)
        .distanceType(KMeansParams.DistanceType.EUCLIDEAN)
        .build();
    // Build the training data with the test helper and wrap it as an input dataset.
    DataFrame trainDataFrame = constructKMeansDataFrame(100);
    MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(trainDataFrame).build();
    Input mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).parameters(parameters).inputDataset(inputDataset).build();
    return MLEngine.train(mlInput);
}
Also used:
    MLInput (org.opensearch.ml.common.parameter.MLInput)
    Input (org.opensearch.ml.common.parameter.Input)
    LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput)
    KMeansParams (org.opensearch.ml.common.parameter.KMeansParams)
    MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset)
    LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame)
    KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame)
    LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame)
    DataFrame (org.opensearch.ml.common.dataframe.DataFrame)

Example 7 with KMeansParams

Use of org.opensearch.ml.common.parameter.KMeansParams in project ml-commons by opensearch-project.

In class MLPredictionTaskRequestTest, method writeTo_Success.

@Test
public void writeTo_Success() throws IOException {
    // mlInput is a fixture defined elsewhere in the test class (not shown here).
    MLPredictionTaskRequest request = MLPredictionTaskRequest.builder().mlInput(mlInput).build();
    // Serialize the request, then rebuild it from the serialized bytes.
    BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
    request.writeTo(bytesStreamOutput);
    request = new MLPredictionTaskRequest(bytesStreamOutput.bytes().streamInput());
    // The algorithm and the K-Means parameters must survive the round trip.
    assertEquals(FunctionName.KMEANS, request.getMlInput().getAlgorithm());
    KMeansParams params = (KMeansParams) request.getMlInput().getParameters();
    assertEquals(1, params.getCentroids().intValue());
    // The dataset must still be a DataFrame with one row and one DOUBLE column named "key1".
    MLInputDataset inputDataset = request.getMlInput().getInputDataset();
    assertEquals(MLInputDataType.DATA_FRAME, inputDataset.getInputDataType());
    DataFrame dataFrame = ((DataFrameInputDataset) inputDataset).getDataFrame();
    assertEquals(1, dataFrame.size());
    assertEquals(1, dataFrame.columnMetas().length);
    assertEquals("key1", dataFrame.columnMetas()[0].getName());
    assertEquals(ColumnType.DOUBLE, dataFrame.columnMetas()[0].getColumnType());
    assertEquals(1, dataFrame.getRow(0).size());
    assertEquals(2.00, dataFrame.getRow(0).getValue(0).getValue());
    // No model ID was set on the request.
    assertNull(request.getModelId());
}
Also used:
    KMeansParams (org.opensearch.ml.common.parameter.KMeansParams)
    DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset)
    MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset)
    DataFrame (org.opensearch.ml.common.dataframe.DataFrame)
    BytesStreamOutput (org.opensearch.common.io.stream.BytesStreamOutput)
    Test (org.junit.Test)
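
The write-then-read pattern in the test above can be tried without an OpenSearch dependency. The following is a minimal sketch only: it uses java.io data streams as a stand-in for the OpenSearch stream classes, and RoundTripSketch, SketchRequest and roundTrip are hypothetical names that do not exist in ml-commons. The presence flag written before the optional field is an assumption about how a nullable value could be encoded, not a description of the library's wire format.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

final class RoundTripSketch {
    // Hypothetical stand-in for a request: an algorithm name plus an optional centroid count.
    static final class SketchRequest {
        final String algorithm;
        final Integer centroids; // null means "not set"

        SketchRequest(String algorithm, Integer centroids) {
            this.algorithm = algorithm;
            this.centroids = centroids;
        }

        // Deserializing constructor: reads fields in the same order writeTo wrote them.
        SketchRequest(DataInputStream in) throws IOException {
            this.algorithm = in.readUTF();
            this.centroids = in.readBoolean() ? in.readInt() : null;
        }

        void writeTo(DataOutputStream out) throws IOException {
            out.writeUTF(algorithm);
            // Assumed encoding: a presence flag, then the value only if present.
            out.writeBoolean(centroids != null);
            if (centroids != null) {
                out.writeInt(centroids);
            }
        }
    }

    // Serializes the request to bytes and rebuilds a fresh copy from them.
    static SketchRequest roundTrip(SketchRequest request) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (DataOutputStream out = new DataOutputStream(bytes)) {
                request.writeTo(out);
            }
            try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return new SketchRequest(in);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        SketchRequest copy = roundTrip(new SketchRequest("KMEANS", 1));
        System.out.println(copy.algorithm + " centroids=" + copy.centroids);
    }
}
```

As in the test, the check that matters is that every field read back equals the field written, including the case where an optional field was never set.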

Example 8 with KMeansParams

Use of org.opensearch.ml.common.parameter.KMeansParams in project ml-commons by opensearch-project.

In class KMeansTest, method trainAndPredict.

@Test
public void trainAndPredict() {
    // Train and predict with Euclidean distance, 10 iterations, and 2 centroids.
    // trainDataFrame and trainSize are fixtures defined elsewhere in the test class.
    KMeansParams parameters = KMeansParams.builder()
        .distanceType(KMeansParams.DistanceType.EUCLIDEAN)
        .iterations(10)
        .centroids(2)
        .build();
    KMeans kMeans = new KMeans(parameters);
    MLPredictionOutput output = (MLPredictionOutput) kMeans.trainAndPredict(trainDataFrame);
    DataFrame predictions = output.getPredictionResult();
    // Expect one prediction per training row, each labeled cluster 0 or 1.
    Assert.assertEquals(trainSize, predictions.size());
    predictions.forEach(row -> Assert.assertTrue(row.getValue(0).intValue() == 0 || row.getValue(0).intValue() == 1));
    // Repeat with cosine distance, keeping the other parameters.
    parameters = parameters.toBuilder().distanceType(KMeansParams.DistanceType.COSINE).build();
    kMeans = new KMeans(parameters);
    output = (MLPredictionOutput) kMeans.trainAndPredict(trainDataFrame);
    predictions = output.getPredictionResult();
    Assert.assertEquals(trainSize, predictions.size());
    // Repeat with L1 distance.
    parameters = parameters.toBuilder().distanceType(KMeansParams.DistanceType.L1).build();
    kMeans = new KMeans(parameters);
    output = (MLPredictionOutput) kMeans.trainAndPredict(trainDataFrame);
    predictions = output.getPredictionResult();
    Assert.assertEquals(trainSize, predictions.size());
}
Also used:
    KMeansParams (org.opensearch.ml.common.parameter.KMeansParams)
    MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput)
    DataFrame (org.opensearch.ml.common.dataframe.DataFrame)
    KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame)
    Test (org.junit.Test)
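
The test above switches between three distance types. To make the difference concrete, here is a rough, self-contained sketch of how each metric could be computed and used to assign a point to its nearest centroid. DistanceSketch, its nested DistanceType enum, and both methods are hypothetical and are not the ml-commons implementation. In particular, cosine distance is assumed here to be one minus cosine similarity, which may differ from what the library does.

```java
final class DistanceSketch {
    // Hypothetical enum mirroring the three distance types named in the test above.
    enum DistanceType { EUCLIDEAN, COSINE, L1 }

    // Distance between two equal-length vectors under the chosen metric.
    static double distance(DistanceType type, double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("dimension mismatch");
        }
        double sumSquares = 0.0, sumAbs = 0.0, dot = 0.0, normA = 0.0, normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sumSquares += diff * diff;
            sumAbs += Math.abs(diff);
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        switch (type) {
            case EUCLIDEAN:
                return Math.sqrt(sumSquares);
            case L1:
                return sumAbs;
            case COSINE:
                // Assumption: cosine distance = 1 - cosine similarity; undefined for zero vectors.
                return 1.0 - dot / (Math.sqrt(normA) * Math.sqrt(normB));
            default:
                throw new IllegalArgumentException("unknown distance type: " + type);
        }
    }

    // Index of the centroid closest to the point: the assignment step of K-Means.
    static int nearestCentroid(DistanceType type, double[] point, double[][] centroids) {
        int best = 0;
        double bestDistance = Double.POSITIVE_INFINITY;
        for (int c = 0; c < centroids.length; c++) {
            double d = distance(type, point, centroids[c]);
            if (d < bestDistance) {
                bestDistance = d;
                best = c;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[][] centroids = { {0.0, 1.0}, {10.0, 1.0} };
        double[] point = {3.0, 4.0};
        for (DistanceType type : DistanceType.values()) {
            System.out.println(type + " -> cluster " + nearestCentroid(type, point, centroids));
        }
    }
}
```

Because the metric alone decides which centroid counts as nearest, the same data can cluster differently under each type. That is presumably why the test only checks cluster labels for the Euclidean run and checks just the prediction count for the other two.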

Example 9 with KMeansParams

Use of org.opensearch.ml.common.parameter.KMeansParams in project ml-commons by opensearch-project.

In class RestMLTrainAndPredictIT, method trainAndPredictKmeansWithEmptyParam.

private void trainAndPredictKmeansWithEmptyParam() throws IOException {
    KMeansParams params = KMeansParams.builder().build();
    trainAndPredictKmeansWithParmas(params, clusterCount -> assertEquals(2, clusterCount.size()));
}
Also used: KMeansParams (org.opensearch.ml.common.parameter.KMeansParams)

Example 10 with KMeansParams

Use of org.opensearch.ml.common.parameter.KMeansParams in project ml-commons by opensearch-project.

In class RestMLTrainAndPredictIT, method trainAndPredictKmeansWithCustomParam.

private void trainAndPredictKmeansWithCustomParam() throws IOException {
    KMeansParams params = KMeansParams.builder().centroids(3).build();
    trainAndPredictKmeansWithParmas(params, clusterCount -> assertTrue(clusterCount.size() >= 2));
}
Also used: KMeansParams (org.opensearch.ml.common.parameter.KMeansParams)
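
Examples 9 and 10 contrast an empty builder with one that sets centroids explicitly. One way a builder could fall back to defaults for unset fields is sketched below. SketchParams, effectiveCentroids and DEFAULT_CENTROIDS are hypothetical names. The default of 2 is an assumption inferred from the empty-parameter test, which expects two clusters, and the positivity check is added for illustration only. None of this is taken from the ml-commons source.

```java
final class SketchParams {
    // Assumption: inferred from the empty-builder test above, which expects two clusters.
    static final int DEFAULT_CENTROIDS = 2;

    private final Integer centroids; // null means "not set by the caller"

    private SketchParams(Integer centroids) {
        this.centroids = centroids;
    }

    static Builder builder() {
        return new Builder();
    }

    // Effective value: the explicit setting if present, otherwise the default.
    int effectiveCentroids() {
        return centroids != null ? centroids : DEFAULT_CENTROIDS;
    }

    static final class Builder {
        private Integer centroids;

        Builder centroids(int value) {
            // Illustrative validation; not taken from the library.
            if (value <= 0) {
                throw new IllegalArgumentException("centroids must be positive");
            }
            this.centroids = value;
            return this;
        }

        SketchParams build() {
            return new SketchParams(centroids);
        }
    }

    public static void main(String[] args) {
        System.out.println(SketchParams.builder().build().effectiveCentroids());
        System.out.println(SketchParams.builder().centroids(3).build().effectiveCentroids());
    }
}
```

Keeping the field nullable and resolving the default at read time lets a consumer tell "not set" apart from "set to the default value", which matters when parameters are serialized and sent to another node.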

Aggregations

KMeansParams (org.opensearch.ml.common.parameter.KMeansParams): 14
IOException (java.io.IOException): 5
Test (org.junit.Test): 4
ArrayList (java.util.ArrayList): 3
DataFrame (org.opensearch.ml.common.dataframe.DataFrame): 3
ExpectedException (org.junit.rules.ExpectedException): 2
ResponseException (org.opensearch.client.ResponseException): 2
BytesStreamOutput (org.opensearch.common.io.stream.BytesStreamOutput): 2
MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset): 2
KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame): 2
Map (java.util.Map): 1
DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset): 1
Input (org.opensearch.ml.common.parameter.Input): 1
LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput): 1
MLInput (org.opensearch.ml.common.parameter.MLInput): 1
MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput): 1
Model (org.opensearch.ml.common.parameter.Model): 1
KMeans (org.opensearch.ml.engine.algorithms.clustering.KMeans): 1
LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame): 1
LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame): 1