Use of org.opensearch.knn.indices.ModelMetadata in project k-NN by opensearch-project.
In class TrainingModelRequestTests, method testValidation_invalid_descriptionToLong.
/**
 * Verifies that a description one character longer than
 * {@code KNNConstants.MAX_MODEL_DESCRIPTION_LENGTH} makes validation report exactly one error,
 * and that the error mentions "Description exceeds limit".
 */
public void testValidation_invalid_descriptionToLong() {
    // Setup the training request with an over-long description
    String modelId = "test-model-id";
    KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
    when(knnMethodContext.validate()).thenReturn(null);
    when(knnMethodContext.isTrainingRequired()).thenReturn(true);
    int dimension = 10;
    String trainingIndex = "test-training-index";
    String trainingField = "test-training-field";
    // One character past the maximum allowed description length
    char[] chars = new char[KNNConstants.MAX_MODEL_DESCRIPTION_LENGTH + 1];
    Arrays.fill(chars, 'a');
    String description = new String(chars);
    TrainingModelRequest trainingModelRequest = new TrainingModelRequest(modelId, knnMethodContext, dimension, trainingIndex, trainingField, null, description);
    // No model is stored under modelId, so the duplicate-model-id check does not fire
    ModelDao modelDao = mock(ModelDao.class);
    when(modelDao.getMetadata(modelId)).thenReturn(null);
    // Cluster service that will not produce a validation exception on its own
    ClusterService clusterService = getClusterServiceForValidReturns(trainingIndex, trainingField, dimension);
    // Initialize static components with the mocks
    TrainingModelRequest.initialize(modelDao, clusterService);
    // Test that validation produces the description-too-long error message
    ActionRequestValidationException exception = trainingModelRequest.validate();
    assertNotNull(exception);
    List<String> validationErrors = exception.validationErrors();
    assertEquals(1, validationErrors.size());
    assertTrue(validationErrors.get(0).contains("Description exceeds limit"));
}
Use of org.opensearch.knn.indices.ModelMetadata in project k-NN by opensearch-project.
In class TrainingModelRequestTests, method testValidation_valid_trainingIndexBuiltFromModel.
/**
 * Verifies that validation passes when the training field's mapping does not declare a dimension
 * directly. Instead, the mapping references a model (via {@code KNNConstants.MODEL_ID}) whose
 * stored metadata reports the same dimension as the request.
 */
public void testValidation_valid_trainingIndexBuiltFromModel() {
    // Setup the training request
    String modelId = "test-model-id";
    KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
    when(knnMethodContext.validate()).thenReturn(null);
    when(knnMethodContext.isTrainingRequired()).thenReturn(true);
    int dimension = 10;
    String trainingIndex = "test-training-index";
    String trainingField = "test-training-field";
    String trainingFieldModelId = "training-field-model-id";
    TrainingModelRequest trainingModelRequest = new TrainingModelRequest(modelId, knnMethodContext, dimension, trainingIndex, trainingField, null, null);
    // The requested modelId is free, while the model referenced by the training field
    // reports the same dimension as the request
    ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class);
    when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
    ModelDao modelDao = mock(ModelDao.class);
    when(modelDao.getMetadata(modelId)).thenReturn(null);
    when(modelDao.getMetadata(trainingFieldModelId)).thenReturn(trainingFieldModelMetadata);
    // The field mapping references a model id instead of specifying the dimension directly
    Map<String, Object> mappingMap = ImmutableMap.of("properties", ImmutableMap.of(trainingField, ImmutableMap.of("type", KNNVectorFieldMapper.CONTENT_TYPE, KNNConstants.MODEL_ID, trainingFieldModelId)));
    MappingMetadata mappingMetadata = mock(MappingMetadata.class);
    when(mappingMetadata.getSourceAsMap()).thenReturn(mappingMap);
    IndexMetadata indexMetadata = mock(IndexMetadata.class);
    when(indexMetadata.mapping()).thenReturn(mappingMetadata);
    Metadata metadata = mock(Metadata.class);
    when(metadata.index(trainingIndex)).thenReturn(indexMetadata);
    // The mocked cluster reports no data nodes
    DiscoveryNodes discoveryNodes = mock(DiscoveryNodes.class);
    when(discoveryNodes.getDataNodes()).thenReturn(ImmutableOpenMap.of());
    ClusterState clusterState = mock(ClusterState.class);
    when(clusterState.metadata()).thenReturn(metadata);
    when(clusterState.nodes()).thenReturn(discoveryNodes);
    ClusterService clusterService = mock(ClusterService.class);
    when(clusterService.state()).thenReturn(clusterState);
    // Initialize static components with the mocks
    TrainingModelRequest.initialize(modelDao, clusterService);
    // Validation should succeed without any errors
    ActionRequestValidationException exception = trainingModelRequest.validate();
    assertNull(exception);
}
Use of org.opensearch.knn.indices.ModelMetadata in project k-NN by opensearch-project.
In class TrainingModelRequestTests, method testValidation_invalid_modelIdAlreadyExists.
/**
 * Checks that requesting training under a model id that already has stored metadata makes
 * validation report a single error containing "already exists".
 */
public void testValidation_invalid_modelIdAlreadyExists() {
    final String existingModelId = "test-model-id";
    final String indexName = "test-training-index";
    final String fieldName = "test-training-field";
    final int requestDimension = 10;

    // Method context that validates cleanly and requires training
    KNNMethodContext methodContext = mock(KNNMethodContext.class);
    when(methodContext.validate()).thenReturn(null);
    when(methodContext.isTrainingRequired()).thenReturn(true);

    TrainingModelRequest request = new TrainingModelRequest(existingModelId, methodContext, requestDimension, indexName, fieldName, null, null);

    // Pretend a model is already stored under the requested id
    ModelDao modelDao = mock(ModelDao.class);
    String createdAt = ZonedDateTime.now(ZoneOffset.UTC).toString();
    ModelMetadata existingMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 128, ModelState.CREATED, createdAt, "", "");
    when(modelDao.getMetadata(existingModelId)).thenReturn(existingMetadata);

    // Cluster state that would not, by itself, trigger any validation error
    TrainingModelRequest.initialize(modelDao, getClusterServiceForValidReturns(indexName, fieldName, requestDimension));

    ActionRequestValidationException validationException = request.validate();
    assertNotNull(validationException);
    List<String> errors = validationException.validationErrors();
    assertEquals(1, errors.size());
    assertTrue(errors.get(0).contains("already exists"));
}
Use of org.opensearch.knn.indices.ModelMetadata in project k-NN by opensearch-project.
In class RemoveModelFromCacheTransportActionTests, method testNodeOperation_modelInCache.
/**
 * Loads a model into {@code ModelCache}, then removes it through
 * {@code RemoveModelFromCacheTransportAction#nodeOperation}, and expects the cache's total
 * weight to be zero afterwards.
 */
@Ignore
public void testNodeOperation_modelInCache() throws ExecutionException, InterruptedException {
    // Settings carrying only the model cache size limit
    Settings cacheSettings = Settings.builder().put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), "10%").build();
    ClusterService clusterService = mock(ClusterService.class);
    when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(cacheSettings, ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)));
    when(clusterService.getSettings()).thenReturn(cacheSettings);

    // DAO that serves exactly one model
    String cachedModelId = "test-model-id";
    ModelMetadata metadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 16, ModelState.CREATED, "timestamp", "description", "");
    ModelDao modelDao = mock(ModelDao.class);
    when(modelDao.get(cachedModelId)).thenReturn(new Model(metadata, new byte[128], cachedModelId));

    ModelCache.initialize(modelDao, clusterService);
    ModelCache cache = ModelCache.getInstance();
    cache.get(cachedModelId); // populate the cache with the model

    RemoveModelFromCacheTransportAction removeAction = node().injector().getInstance(RemoveModelFromCacheTransportAction.class);
    removeAction.nodeOperation(new RemoveModelFromCacheNodeRequest(cachedModelId));

    assertEquals(0L, cache.getTotalWeightInKB());
}
Use of org.opensearch.knn.indices.ModelMetadata in project k-NN by opensearch-project.
In class RestSearchModelHandlerIT, method testSearchModelWithSourceFilteringExcludes.
/**
 * Searches the model index with a {@code _source.excludes} filter on {@code model_blob}, using both
 * GET and POST. Verifies that only the stored models are returned, and that each hit's source omits
 * {@code model_blob} while still containing {@code state}, {@code timestamp} and {@code description}.
 */
public void testSearchModelWithSourceFilteringExcludes() throws IOException {
    createModelSystemIndex();
    // A document in an unrelated index that must not appear in model search results
    createIndex("irrelevant-index", Settings.EMPTY);
    addDocWithBinaryField("irrelevant-index", "id1", "field-name", "value");
    List<String> testModelID = Arrays.asList("test-modelid1", "test-modelid2");
    // Explicit charset so the blob bytes do not depend on the platform default encoding
    byte[] testModelBlob = "hello".getBytes(java.nio.charset.StandardCharsets.UTF_8);
    ModelMetadata testModelMetadata = getModelMetadata();
    for (String modelID : testModelID) {
        addModelToSystemIndex(modelID, testModelMetadata, testModelBlob);
    }
    String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search");
    for (String method : Arrays.asList("GET", "POST")) {
        Request request = new Request(method, restURI);
        request.setJsonEntity("{\n" + " \"_source\": {\n" + " \"excludes\": [\"model_blob\" ]\n" + " }, " + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}");
        Response response = client().performRequest(request);
        assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
        String responseBody = EntityUtils.toString(response.getEntity());
        assertNotNull(responseBody);
        XContentParser parser = createParser(XContentType.JSON.xContent(), responseBody);
        SearchResponse searchResponse = SearchResponse.fromXContent(parser);
        assertNotNull(searchResponse);
        // Only the models from the model system index are returned; the irrelevant-index document is not
        assertEquals(testModelID.size(), searchResponse.getHits().getHits().length);
        for (SearchHit hit : searchResponse.getHits().getHits()) {
            assertTrue(testModelID.contains(hit.getId()));
            Map<String, Object> sourceAsMap = hit.getSourceAsMap();
            // The excluded field is filtered out while the other metadata fields remain
            assertFalse(sourceAsMap.containsKey("model_blob"));
            assertTrue(sourceAsMap.containsKey("state"));
            assertTrue(sourceAsMap.containsKey("timestamp"));
            assertTrue(sourceAsMap.containsKey("description"));
        }
    }
}
Aggregations