Search in sources:

Example 1 with KMeansParams

Use of org.opensearch.ml.common.parameter.KMeansParams in project ml-commons by opensearch-project.

Class MLTrainingTaskRequestTest, method writeTo.

/**
 * Round-trips an {@link MLTrainingTaskRequest} built from the {@code mlInput} fixture through a
 * {@link BytesStreamOutput} and checks that the algorithm, the KMeans centroid count and the
 * input data type are preserved by serialization.
 */
@Test
public void writeTo() throws IOException {
    MLTrainingTaskRequest original = MLTrainingTaskRequest.builder().mlInput(mlInput).build();
    BytesStreamOutput out = new BytesStreamOutput();
    original.writeTo(out);

    // Read the written bytes back through the stream constructor to complete the round trip.
    MLTrainingTaskRequest roundTripped = new MLTrainingTaskRequest(out.bytes().streamInput());

    assertEquals(FunctionName.KMEANS, roundTripped.getMlInput().getAlgorithm());
    KMeansParams params = (KMeansParams) roundTripped.getMlInput().getParameters();
    assertEquals(1, params.getCentroids().intValue());
    assertEquals(MLInputDataType.DATA_FRAME, roundTripped.getMlInput().getInputDataset().getInputDataType());
}
Also used: KMeansParams (org.opensearch.ml.common.parameter.KMeansParams), BytesStreamOutput (org.opensearch.common.io.stream.BytesStreamOutput), Test (org.junit.Test)

Example 2 with KMeansParams

Use of org.opensearch.ml.common.parameter.KMeansParams in project ml-commons by opensearch-project.

Class SecureMLRestIT, method testReadOnlyUser_CanGetModel_CanNotDeleteModel.

/**
 * Trains a KMeans model with the full-access client, then verifies that the read-only client can
 * fetch that model but is refused permission to delete it.
 *
 * @throws IOException if the training request fails
 */
public void testReadOnlyUser_CanGetModel_CanNotDeleteModel() throws IOException {
    KMeansParams kMeansParams = KMeansParams.builder().build();
    // train model with full access client
    train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> {
        String modelId = (String) trainResult.get("model_id");
        assertNotNull(modelId);
        String status = (String) trainResult.get("status");
        assertEquals(MLTaskState.COMPLETED.name(), status);
        try {
            // get model with readonly client
            getModel(mlReadOnlyClient, modelId, model -> {
                String algorithm = (String) model.get("algorithm");
                assertEquals(FunctionName.KMEANS.name(), algorithm);
            });
        } catch (IOException e) {
            // Fail with the original exception attached as the cause so its stack trace is reported.
            throw new AssertionError("Failed to get model " + modelId + " with read-only client", e);
        }
        try {
            // Deleting the model with the read-only client is expected to fail.
            deleteModel(mlReadOnlyClient, modelId, null);
            // AssertionError is an Error, not an Exception, so the catch below cannot swallow it and
            // turn an unexpectedly successful delete into a misleading class-mismatch failure.
            throw new AssertionError("Delete model for readonly user does not fail");
        } catch (Exception e) {
            assertEquals(ResponseException.class, e.getClass());
            assertTrue(Throwables.getStackTraceAsString(e).contains("no permissions for [cluster:admin/opensearch/ml/models/delete]"));
        }
    }, false);
}
Also used: KMeansParams (org.opensearch.ml.common.parameter.KMeansParams), ResponseException (org.opensearch.client.ResponseException), IOException (java.io.IOException), ExpectedException (org.junit.rules.ExpectedException)

Example 3 with KMeansParams

Use of org.opensearch.ml.common.parameter.KMeansParams in project ml-commons by opensearch-project.

Class SecureMLRestIT, method testReadOnlyUser_CanSearchTasks.

/**
 * Submits a KMeans training request with the full-access client, then verifies that the
 * read-only client can find at least one KMEANS task when searching tasks by algorithm name.
 *
 * @throws IOException if the training request fails
 */
public void testReadOnlyUser_CanSearchTasks() throws IOException {
    KMeansParams kMeansParams = KMeansParams.builder().build();
    // train model with full access client
    train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> {
        // The train response is asserted to carry a task id (no model id) with status CREATED.
        assertFalse(trainResult.containsKey("model_id"));
        String taskId = (String) trainResult.get("task_id");
        assertNotNull(taskId);
        String status = (String) trainResult.get("status");
        assertEquals(MLTaskState.CREATED.name(), status);
        try {
            // search tasks with readonly client
            searchTasksWithAlgoName(mlReadOnlyClient, FunctionName.KMEANS.name(), tasks -> {
                // Wildcard casts instead of raw/unchecked ones: only size() is needed.
                ArrayList<?> hits = (ArrayList<?>) ((Map<?, ?>) tasks.get("hits")).get("hits");
                assertTrue(hits.size() > 0);
            });
        } catch (IOException e) {
            // Fail with the original exception attached as the cause so its stack trace is reported.
            throw new AssertionError("Failed to search tasks with read-only client", e);
        }
    }, true);
}
Also used: KMeansParams (org.opensearch.ml.common.parameter.KMeansParams), ArrayList (java.util.ArrayList), IOException (java.io.IOException)

Example 4 with KMeansParams

Use of org.opensearch.ml.common.parameter.KMeansParams in project ml-commons by opensearch-project.

Class SecureMLRestIT, method testTrainModelWithFullAccessThenPredict.

/**
 * Trains a KMeans model with the full-access client, fetches it back, and runs a prediction with
 * it, asserting that the prediction completes and returns more than one row.
 *
 * @throws IOException if the training request fails
 */
public void testTrainModelWithFullAccessThenPredict() throws IOException {
    KMeansParams kMeansParams = KMeansParams.builder().build();
    // train model
    train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> {
        String modelId = (String) trainResult.get("model_id");
        assertNotNull(modelId);
        String status = (String) trainResult.get("status");
        assertEquals(MLTaskState.COMPLETED.name(), status);
        try {
            getModel(mlFullAccessClient, modelId, model -> {
                String algorithm = (String) model.get("algorithm");
                assertEquals(FunctionName.KMEANS.name(), algorithm);
            });
        } catch (IOException e) {
            // Fail with the original exception attached as the cause so its stack trace is reported.
            throw new AssertionError("Failed to get model " + modelId, e);
        }
        try {
            // predict with trained model
            predict(mlFullAccessClient, FunctionName.KMEANS, modelId, irisIndex, kMeansParams, searchSourceBuilder, predictResult -> {
                String predictStatus = (String) predictResult.get("status");
                assertEquals(MLTaskState.COMPLETED.name(), predictStatus);
                // Wildcard casts instead of raw/unchecked ones: only get() and size() are needed.
                Map<?, ?> predictionResult = (Map<?, ?>) predictResult.get("prediction_result");
                ArrayList<?> rows = (ArrayList<?>) predictionResult.get("rows");
                assertTrue(rows.size() > 1);
            });
        } catch (IOException e) {
            throw new AssertionError("Failed to predict with model " + modelId, e);
        }
    }, false);
}
Also used: KMeansParams (org.opensearch.ml.common.parameter.KMeansParams), ArrayList (java.util.ArrayList), IOException (java.io.IOException), Map (java.util.Map)

Example 5 with KMeansParams

Use of org.opensearch.ml.common.parameter.KMeansParams in project ml-commons by opensearch-project.

Class ModelSerDeSerTest, method testModelSerDeSerKMeans.

/**
 * Trains a default-configured {@link KMeans} on a generated data frame, deserializes the trained
 * model's content into a Tribuo {@link KMeansModel}, serializes it again, and asserts that the
 * re-serialized bytes differ from the original content.
 */
@Test
public void testModelSerDeSerKMeans() {
    KMeans trainer = new KMeans(KMeansParams.builder().build());
    Model trained = trainer.train(constructKMeansDataFrame(100));
    byte[] originalBytes = trained.getContent();

    KMeansModel restored = (KMeansModel) ModelSerDeSer.deserialize(originalBytes);
    byte[] reserialized = ModelSerDeSer.serialize(restored);

    // NOTE(review): asserts the bytes are NOT equal; why re-serialization differs is not shown here — confirm intent.
    assertFalse(Arrays.equals(reserialized, originalBytes));
}
Also used: KMeansParams (org.opensearch.ml.common.parameter.KMeansParams), KMeansModel (org.tribuo.clustering.kmeans.KMeansModel), KMeans (org.opensearch.ml.engine.algorithms.clustering.KMeans), Model (org.opensearch.ml.common.parameter.Model), Test (org.junit.Test)

Aggregations (count per type):

- KMeansParams (org.opensearch.ml.common.parameter.KMeansParams): 14
- IOException (java.io.IOException): 5
- Test (org.junit.Test): 4
- ArrayList (java.util.ArrayList): 3
- DataFrame (org.opensearch.ml.common.dataframe.DataFrame): 3
- ExpectedException (org.junit.rules.ExpectedException): 2
- ResponseException (org.opensearch.client.ResponseException): 2
- BytesStreamOutput (org.opensearch.common.io.stream.BytesStreamOutput): 2
- MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset): 2
- KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame): 2
- Map (java.util.Map): 1
- DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset): 1
- Input (org.opensearch.ml.common.parameter.Input): 1
- LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput): 1
- MLInput (org.opensearch.ml.common.parameter.MLInput): 1
- MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput): 1
- Model (org.opensearch.ml.common.parameter.Model): 1
- KMeans (org.opensearch.ml.engine.algorithms.clustering.KMeans): 1
- LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame): 1
- LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame): 1