Usage of `org.opensearch.ml.common.parameter.KMeansParams` in the ml-commons project (opensearch-project).
Class `MLTrainingTaskRequestTest`, method `writeTo`.
/**
 * Serializes an {@code MLTrainingTaskRequest} built from the {@code mlInput} fixture into a
 * {@code BytesStreamOutput}, reads it back, and checks that the algorithm, the KMeans centroid
 * count (expected to be 1) and the input data type survive the round trip.
 */
@Test
public void writeTo() throws IOException {
    MLTrainingTaskRequest original = MLTrainingTaskRequest.builder().mlInput(mlInput).build();
    BytesStreamOutput out = new BytesStreamOutput();
    original.writeTo(out);

    // Rebuild the request from the serialized bytes via the stream-input constructor.
    MLTrainingTaskRequest decoded = new MLTrainingTaskRequest(out.bytes().streamInput());

    assertEquals(FunctionName.KMEANS, decoded.getMlInput().getAlgorithm());
    KMeansParams decodedParams = (KMeansParams) decoded.getMlInput().getParameters();
    assertEquals(1, decodedParams.getCentroids().intValue());
    assertEquals(MLInputDataType.DATA_FRAME, decoded.getMlInput().getInputDataset().getInputDataType());
}
Usage of `org.opensearch.ml.common.parameter.KMeansParams` in the ml-commons project (opensearch-project).
Class `SecureMLRestIT`, method `testReadOnlyUser_CanGetModel_CanNotDeleteModel`.
/**
 * Trains a KMeans model with the full-access client, then checks that the read-only client can
 * fetch that model but gets a permission error when trying to delete it.
 */
public void testReadOnlyUser_CanGetModel_CanNotDeleteModel() throws IOException {
    KMeansParams kMeansParams = KMeansParams.builder().build();
    // train model with full access client
    train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> {
        String modelId = (String) trainResult.get("model_id");
        assertNotNull(modelId);
        String status = (String) trainResult.get("status");
        assertEquals(MLTaskState.COMPLETED.name(), status);
        try {
            // get model with readonly client
            getModel(mlReadOnlyClient, modelId, model -> {
                String algorithm = (String) model.get("algorithm");
                assertEquals(FunctionName.KMEANS.name(), algorithm);
            });
        } catch (IOException e) {
            // Presumably this callback cannot declare checked exceptions; rethrow unchecked so the
            // test fails with the original cause and stack trace instead of a bare assertNull message.
            throw new java.io.UncheckedIOException(e);
        }

        // Deleting with the read-only client must fail. Capture the failure rather than throwing a
        // "did not fail" exception inside the try block, where the catch would swallow it.
        Exception deleteFailure = null;
        try {
            deleteModel(mlReadOnlyClient, modelId, null);
        } catch (Exception e) {
            deleteFailure = e;
        }
        assertNotNull("Delete model for readonly user does not fail", deleteFailure);
        assertEquals(ResponseException.class, deleteFailure.getClass());
        assertTrue(
            Throwables.getStackTraceAsString(deleteFailure).contains("no permissions for [cluster:admin/opensearch/ml/models/delete]")
        );
    }, false);
}
Usage of `org.opensearch.ml.common.parameter.KMeansParams` in the ml-commons project (opensearch-project).
Class `SecureMLRestIT`, method `testReadOnlyUser_CanSearchTasks`.
/**
 * Starts KMeans training with the full-access client, then checks that the read-only client can
 * search tasks by algorithm name and gets at least one hit.
 */
public void testReadOnlyUser_CanSearchTasks() throws IOException {
    KMeansParams kMeansParams = KMeansParams.builder().build();
    // train model with full access client
    train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> {
        // This train call's result carries a task id in CREATED state and no model id.
        assertFalse(trainResult.containsKey("model_id"));
        String taskId = (String) trainResult.get("task_id");
        assertNotNull(taskId);
        String status = (String) trainResult.get("status");
        assertEquals(MLTaskState.CREATED.name(), status);
        try {
            // search tasks with readonly client
            searchTasksWithAlgoName(mlReadOnlyClient, FunctionName.KMEANS.name(), tasks -> {
                // Wildcard casts perform the same runtime type checks without unchecked/raw-type warnings.
                Map<?, ?> hitsWrapper = (Map<?, ?>) tasks.get("hits");
                ArrayList<?> hits = (ArrayList<?>) hitsWrapper.get("hits");
                assertTrue(hits.size() > 0);
            });
        } catch (IOException e) {
            // Presumably this callback cannot declare checked exceptions; rethrow unchecked so the
            // test fails with the original cause and stack trace.
            throw new java.io.UncheckedIOException(e);
        }
    }, true);
}
Usage of `org.opensearch.ml.common.parameter.KMeansParams` in the ml-commons project (opensearch-project).
Class `SecureMLRestIT`, method `testTrainModelWithFullAccessThenPredict`.
/**
 * Trains a KMeans model with the full-access client, fetches it, and runs a prediction with it,
 * checking that the prediction completes and returns more than one row.
 */
public void testTrainModelWithFullAccessThenPredict() throws IOException {
    KMeansParams kMeansParams = KMeansParams.builder().build();
    // train model
    train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> {
        String modelId = (String) trainResult.get("model_id");
        assertNotNull(modelId);
        String status = (String) trainResult.get("status");
        assertEquals(MLTaskState.COMPLETED.name(), status);
        try {
            getModel(mlFullAccessClient, modelId, model -> {
                String algorithm = (String) model.get("algorithm");
                assertEquals(FunctionName.KMEANS.name(), algorithm);
            });
        } catch (IOException e) {
            // Presumably this callback cannot declare checked exceptions; rethrow unchecked so the
            // test fails with the original cause and stack trace.
            throw new java.io.UncheckedIOException(e);
        }
        try {
            // predict with trained model
            predict(mlFullAccessClient, FunctionName.KMEANS, modelId, irisIndex, kMeansParams, searchSourceBuilder, predictResult -> {
                String predictStatus = (String) predictResult.get("status");
                assertEquals(MLTaskState.COMPLETED.name(), predictStatus);
                // Wildcard casts perform the same runtime type checks without unchecked/raw-type warnings.
                Map<?, ?> predictionResult = (Map<?, ?>) predictResult.get("prediction_result");
                ArrayList<?> rows = (ArrayList<?>) predictionResult.get("rows");
                assertTrue(rows.size() > 1);
            });
        } catch (IOException e) {
            throw new java.io.UncheckedIOException(e);
        }
    }, false);
}
Usage of `org.opensearch.ml.common.parameter.KMeansParams` in the ml-commons project (opensearch-project).
Class `ModelSerDeSerTest`, method `testModelSerDeSerKMeans`.
/**
 * Trains a KMeans model with default parameters on {@code constructKMeansDataFrame(100)},
 * deserializes the trained model's content into a {@code KMeansModel}, serializes it again, and
 * asserts that the re-serialized bytes are NOT equal to the original content.
 * NOTE(review): the reason the bytes are expected to differ is not visible here; confirm it is intentional.
 */
@Test
public void testModelSerDeSerKMeans() {
    KMeans trainer = new KMeans(KMeansParams.builder().build());
    Model trained = trainer.train(constructKMeansDataFrame(100));

    KMeansModel restored = (KMeansModel) ModelSerDeSer.deserialize(trained.getContent());
    byte[] reserialized = ModelSerDeSer.serialize(restored);

    assertFalse(Arrays.equals(reserialized, trained.getContent()));
}
Aggregations