Usage of `org.opensearch.ml.common.parameter.Model` in the ml-commons project (opensearch-project).
Class `AnomalyDetectionLibSVM`, method `train`:
/**
 * Trains a one-class LibSVM anomaly-detection model on the given data frame.
 *
 * <p>The kernel type comes from {@code parseKernelType()}. Gamma and nu fall back to
 * {@code DEFAULT_GAMMA} / {@code DEFAULT_NU} when the caller did not supply them; cost,
 * coeff, epsilon and degree are only set on the SVM parameters when explicitly provided.
 *
 * @param dataFrame the training data
 * @return a {@link Model} named {@code AD_LIBSVM}, versioned with {@code VERSION}, whose
 *     content is the serialized LibSVM model
 */
@Override
public Model train(DataFrame dataFrame) {
    KernelType kernelType = parseKernelType();
    // Parameterized (was a raw type) so the trainer/model generics are checked by the compiler.
    SVMParameters<Event> params =
            new SVMParameters<>(new SVMAnomalyType(SVMAnomalyType.SVMMode.ONE_CLASS), kernelType);
    params.setGamma(Optional.ofNullable(parameters.getGamma()).orElse(DEFAULT_GAMMA));
    params.setNu(Optional.ofNullable(parameters.getNu()).orElse(DEFAULT_NU));
    // Remaining hyper-parameters are only applied when the caller set them explicitly.
    if (parameters.getCost() != null) {
        params.setCost(parameters.getCost());
    }
    if (parameters.getCoeff() != null) {
        params.setCoeff(parameters.getCoeff());
    }
    if (parameters.getEpsilon() != null) {
        params.setEpsilon(parameters.getEpsilon());
    }
    if (parameters.getDegree() != null) {
        params.setDegree(parameters.getDegree());
    }
    MutableDataset<Event> data = TribuoUtil.generateDataset(dataFrame, new AnomalyFactory(),
            "Anomaly detection LibSVM training data from OpenSearch", TribuoOutputType.ANOMALY_DETECTION_LIBSVM);
    LibSVMAnomalyTrainer trainer = new LibSVMAnomalyTrainer(params);
    // The previous cast-and-discard call to getNumberOfSupportVectors() had no effect and was removed.
    LibSVMModel<Event> libSVMModel = trainer.train(data);

    Model model = new Model();
    model.setName(FunctionName.AD_LIBSVM.name());
    model.setVersion(VERSION);
    model.setContent(ModelSerDeSer.serialize(libSVMModel));
    return model;
}
Usage of `org.opensearch.ml.common.parameter.Model` in the ml-commons project (opensearch-project).
Class `MLTrainingTaskRunner`, method `train`:
/**
 * Trains a model for {@code mlInput}, stores it in the ML model index and reports the new model id.
 *
 * <p>Steps:
 * <ol>
 *   <li>marks the task RUNNING;</li>
 *   <li>trains synchronously with {@code MLEngine.train};</li>
 *   <li>ensures the model index exists;</li>
 *   <li>indexes the model with an IMMEDIATE refresh policy;</li>
 *   <li>answers with an {@code MLTrainingOutput} carrying the model id, the task id (async tasks
 *       only) and the COMPLETED state.</li>
 * </ol>
 * Every failure path increments the per-function TRAIN failure counter and
 * {@code ML_TOTAL_FAILURE_COUNT} before being forwarded to {@code actionListener}.
 *
 * @param mlTask the task being executed
 * @param mlInput the training input (algorithm, parameters, data)
 * @param actionListener receives the training output or the failure
 */
private void train(MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
    // Wrap the caller's listener so that all failures are counted in the stats.
    ActionListener<MLTaskResponse> listener = ActionListener.wrap(actionListener::onResponse, e -> {
        mlStats.createCounterStatIfAbsent(failureCountStat(mlTask.getFunctionName(), ActionName.TRAIN)).increment();
        mlStats.getStat(ML_TOTAL_FAILURE_COUNT).increment();
        actionListener.onFailure(e);
    });
    try {
        // run training
        mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
        Model model = MLEngine.train(mlInput);
        mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(indexCreated -> {
            if (!indexCreated) {
                // This is the *model* index (initModelIndexIfAbsent), not the task index.
                listener.onFailure(new RuntimeException("No response to create ML model index"));
                return;
            }
            // TODO: put the user into model for backend role based access control.
            MLModel mlModel = new MLModel(mlInput.getAlgorithm(), model);
            // Stash the thread context for the index call; it is restored before the response
            // listener runs (runBefore), or by try-with-resources if building the request throws.
            try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
                ActionListener<IndexResponse> indexResponseListener = ActionListener.wrap(r -> {
                    log.info("Model data indexing done, result:{}, model id: {}", r.getResult(), r.getId());
                    mlStats.getStat(ML_TOTAL_MODEL_COUNT).increment();
                    mlStats.createCounterStatIfAbsent(modelCountStat(mlTask.getFunctionName())).increment();
                    // Only async callers need the task id to poll; sync callers get null.
                    String returnedTaskId = mlTask.isAsync() ? mlTask.getTaskId() : null;
                    MLTrainingOutput output = new MLTrainingOutput(r.getId(), returnedTaskId, MLTaskState.COMPLETED.name());
                    listener.onResponse(MLTaskResponse.builder().output(output).build());
                }, e -> {
                    log.error("Failed to index ML model", e);
                    listener.onFailure(e);
                });
                IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX);
                indexRequest.source(mlModel.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS));
                indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
                client.index(indexRequest, ActionListener.runBefore(indexResponseListener, () -> context.restore()));
            } catch (Exception e) {
                log.error("Failed to save ML model", e);
                listener.onFailure(e);
            }
        }, e -> {
            log.error("Failed to init ML model index", e);
            listener.onFailure(e);
        }));
    } catch (Exception e) {
        // todo need to specify what exception
        log.error("Failed to train {}", mlInput.getAlgorithm(), e);
        listener.onFailure(e);
    }
}
Usage of `org.opensearch.ml.common.parameter.Model` in the ml-commons project (opensearch-project).
Class `MLPredictTaskRunner`, method `predict`:
/**
 * Runs a prediction for {@code mlTask} using a model previously stored in {@code ML_MODEL_INDEX}.
 *
 * <p>Before anything else it increments the executing/request stats and registers the task with
 * {@code mlTaskManager}. All outcomes go through {@code wrappedCleanupListener(listener, taskId)}.
 * NOTE(review): that wrapper presumably undoes the task registration/stat bump; confirm against
 * its implementation.
 *
 * @param mlTask the task being executed; its model id is used to fetch the model document
 * @param inputDataFrame data to predict on; substituted into the request's {@code MLInput} as a
 *     {@code DataFrameInputDataset}
 * @param request the prediction request; supplies the {@code MLInput} and the model id checked
 *     for null
 * @param listener receives the prediction output or the failure
 */
private void predict(MLTask mlTask, DataFrame inputDataFrame, MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener) {
ActionListener<MLTaskResponse> internalListener = wrappedCleanupListener(listener, mlTask.getTaskId());
// track ML task count and add ML task into cache
mlStats.getStat(ML_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(ML_TOTAL_REQUEST_COUNT).increment();
mlStats.createCounterStatIfAbsent(requestCountStat(mlTask.getFunctionName(), ActionName.PREDICT)).increment();
mlTaskManager.add(mlTask);
// run predict
// NOTE(review): the null check uses request.getModelId(), but the GET below and the log messages
// use mlTask.getModelId(); this assumes both carry the same id. Confirm where mlTask is built.
if (request.getModelId() != null) {
// search model by model id.
// The thread context is stashed for the GET and restored (via runBefore) before the response
// listener runs; if building the request throws, try-with-resources restores it instead.
try (ThreadContext.StoredContext context = threadPool.getThreadContext().stashContext()) {
MLInput mlInput = request.getMlInput();
ActionListener<GetResponse> getResponseListener = ActionListener.wrap(r -> {
// A missing model document is reported as ResourceNotFoundException.
if (r == null || !r.isExists()) {
internalListener.onFailure(new ResourceNotFoundException("No model found, please check the modelId."));
return;
}
Map<String, Object> source = r.getSourceAsMap();
User requestUser = getUserContext(client);
User resourceUser = User.parse((String) source.get(USER));
// NOTE(review): if checkUserPermissions can return false while requestUser is null,
// requestUser.getName() below would throw NPE; depends on checkUserPermissions, confirm.
if (!checkUserPermissions(requestUser, resourceUser, request.getModelId())) {
// The backend roles of request user and resource user doesn't have intersection
OpenSearchException e = new OpenSearchException("User: " + requestUser.getName() + " does not have permissions to run predict by model: " + request.getModelId());
handlePredictFailure(mlTask, internalListener, e, false);
return;
}
// Rebuild the Model from the stored document: name, version and base64-encoded content.
Model model = new Model();
model.setName((String) source.get(MLModel.MODEL_NAME));
model.setVersion((Integer) source.get(MLModel.MODEL_VERSION));
byte[] decoded = Base64.getDecoder().decode((String) source.get(MLModel.MODEL_CONTENT));
model.setContent(decoded);
// run predict
mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
MLOutput output = MLEngine.predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model);
if (output instanceof MLPredictionOutput) {
((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
}
// Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state
handleAsyncMLTaskComplete(mlTask);
MLTaskResponse response = MLTaskResponse.builder().output(output).build();
internalListener.onResponse(response);
}, e -> {
// NOTE(review): ActionListener.wrap is expected to route exceptions thrown by the response
// lambda above (decode, cast, predict) here as well as GET failures; confirm.
log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + mlTask.getModelId(), e);
// NOTE(review): the last argument is true for runtime failures and false for validation
// failures (permissions, invalid id); its exact meaning is defined by handlePredictFailure.
handlePredictFailure(mlTask, internalListener, e, true);
});
GetRequest getRequest = new GetRequest(ML_MODEL_INDEX, mlTask.getModelId());
client.get(getRequest, ActionListener.runBefore(getResponseListener, () -> context.restore()));
} catch (Exception e) {
log.error("Failed to get model " + mlTask.getModelId(), e);
handlePredictFailure(mlTask, internalListener, e, true);
}
} else {
// Prediction without a stored model id is rejected.
IllegalArgumentException e = new IllegalArgumentException("ModelId is invalid");
log.error("ModelId is invalid", e);
handlePredictFailure(mlTask, internalListener, e, false);
}
}
Usage of `org.opensearch.ml.common.parameter.Model` in the ml-commons project (opensearch-project).
Class `ModelSerDeSerTest`, method `testModelSerDeSerKMeans`:
/**
 * Trains a KMeans model, deserializes its stored bytes back into a {@code KMeansModel},
 * re-serializes it, and checks that the re-serialized bytes differ from the original content.
 * NOTE(review): asserting inequality is surprising for a round-trip test; presumably
 * serialization is not byte-stable for this model type. Confirm the intent.
 */
@Test
public void testModelSerDeSerKMeans() {
    KMeans trainer = new KMeans(KMeansParams.builder().build());
    Model trained = trainer.train(constructKMeansDataFrame(100));

    byte[] originalBytes = trained.getContent();
    KMeansModel restored = (KMeansModel) ModelSerDeSer.deserialize(originalBytes);
    byte[] reserializedBytes = ModelSerDeSer.serialize(restored);

    assertFalse(Arrays.equals(reserializedBytes, trained.getContent()));
}
Usage of `org.opensearch.ml.common.parameter.Model` in the ml-commons project (opensearch-project).
Class `KMeansTest`, method `train`:
/**
 * Verifies that KMeans training yields a model named {@code KMEANS}, at version 1,
 * with non-null serialized content.
 */
@Test
public void train() {
    Model trained = kMeans.train(trainDataFrame);
    String expectedName = FunctionName.KMEANS.name();

    Assert.assertEquals(expectedName, trained.getName());
    Assert.assertEquals(1, trained.getVersion());
    Assert.assertNotNull(trained.getContent());
}
Aggregations