Example 1 with Model

Use of org.opensearch.ml.common.parameter.Model in project ml-commons by opensearch-project.

From class AnomalyDetectionLibSVM, method train.

@Override
public Model train(DataFrame dataFrame) {
    KernelType kernelType = parseKernelType();
    // One-class SVM: learns a boundary around the training data and flags points outside it as anomalies.
    SVMParameters<Event> params = new SVMParameters<>(new SVMAnomalyType(SVMAnomalyType.SVMMode.ONE_CLASS), kernelType);
    // gamma and nu are always set, falling back to the class defaults when not supplied.
    Double gamma = Optional.ofNullable(parameters.getGamma()).orElse(DEFAULT_GAMMA);
    Double nu = Optional.ofNullable(parameters.getNu()).orElse(DEFAULT_NU);
    params.setGamma(gamma);
    params.setNu(nu);
    // The remaining hyperparameters are only overridden when the caller supplied them.
    if (parameters.getCost() != null) {
        params.setCost(parameters.getCost());
    }
    if (parameters.getCoeff() != null) {
        params.setCoeff(parameters.getCoeff());
    }
    if (parameters.getEpsilon() != null) {
        params.setEpsilon(parameters.getEpsilon());
    }
    if (parameters.getDegree() != null) {
        params.setDegree(parameters.getDegree());
    }
    // Convert the DataFrame into a Tribuo dataset of anomaly-detection Events.
    MutableDataset<Event> data = TribuoUtil.generateDataset(dataFrame, new AnomalyFactory(), "Anomaly detection LibSVM training data from OpenSearch", TribuoOutputType.ANOMALY_DETECTION_LIBSVM);
    LibSVMAnomalyTrainer trainer = new LibSVMAnomalyTrainer(params);
    LibSVMModel<Event> libSVMModel = trainer.train(data);
    // Store the serialized Tribuo model in the ml-commons Model holder.
    Model model = new Model();
    model.setName(FunctionName.AD_LIBSVM.name());
    model.setVersion(VERSION);
    model.setContent(ModelSerDeSer.serialize(libSVMModel));
    return model;
}
Also used : LibSVMModel(org.tribuo.common.libsvm.LibSVMModel) LibSVMAnomalyModel(org.tribuo.anomaly.libsvm.LibSVMAnomalyModel) SVMAnomalyType(org.tribuo.anomaly.libsvm.SVMAnomalyType) SVMParameters(org.tribuo.common.libsvm.SVMParameters) Model(org.opensearch.ml.common.parameter.Model) Event(org.tribuo.anomaly.Event) KernelType(org.tribuo.common.libsvm.KernelType) LibSVMAnomalyTrainer(org.tribuo.anomaly.libsvm.LibSVMAnomalyTrainer) AnomalyFactory(org.tribuo.anomaly.AnomalyFactory)
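The parameter handling in train() follows two patterns: gamma and nu always receive a value (the caller's or a class default), while cost, coeff, epsilon and degree are only overridden when the caller supplied them. A minimal stdlib-only sketch of that pattern, assuming hypothetical UserParams and SvmParams classes and made-up default values; none of these are the real ml-commons or Tribuo types.

```java
import java.util.Optional;

final class ParamDefaults {
    // Hypothetical defaults, for illustration only.
    static final double DEFAULT_GAMMA = 1.0;
    static final double DEFAULT_NU = 0.1;

    static SvmParams resolve(UserParams user) {
        SvmParams p = new SvmParams();
        // Always set: use the caller's value, or fall back to our own default when it is null.
        p.gamma = Optional.ofNullable(user.gamma).orElse(DEFAULT_GAMMA);
        p.nu = Optional.ofNullable(user.nu).orElse(DEFAULT_NU);
        // Only set when supplied: otherwise keep the default SvmParams already carries.
        if (user.cost != null) {
            p.cost = user.cost;
        }
        return p;
    }

    public static void main(String[] args) {
        UserParams user = new UserParams();
        user.nu = 0.05;
        SvmParams p = resolve(user);
        System.out.println(p.gamma + " " + p.nu + " " + p.cost); // 1.0 0.05 1.0
    }
}

// Hypothetical request-side parameters: null means "not supplied by the caller".
final class UserParams {
    Double gamma;
    Double nu;
    Double cost;
}

// Hypothetical trainer-side parameters with a pre-filled cost default.
final class SvmParams {
    double gamma;
    double nu;
    double cost = 1.0;
}
```

Keeping the two styles separate means the "always set" values are controlled by the caller of the sketch, while the "only if supplied" values defer to whatever the parameter object already defaults to.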

Example 2 with Model

Use of org.opensearch.ml.common.parameter.Model in project ml-commons by opensearch-project.

From class MLTrainingTaskRunner, method train.

private void train(MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
    ActionListener<MLTaskResponse> listener = ActionListener.wrap(r -> actionListener.onResponse(r), e -> {
        mlStats.createCounterStatIfAbsent(failureCountStat(mlTask.getFunctionName(), ActionName.TRAIN)).increment();
        mlStats.getStat(ML_TOTAL_FAILURE_COUNT).increment();
        actionListener.onFailure(e);
    });
    try {
        // run training
        mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
        Model model = MLEngine.train(mlInput);
        mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(indexCreated -> {
            if (!indexCreated) {
                listener.onFailure(new RuntimeException("No response to create ML model index"));
                return;
            }
            // TODO: put the user into model for backend role based access control.
            MLModel mlModel = new MLModel(mlInput.getAlgorithm(), model);
            try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
                ActionListener<IndexResponse> indexResponseListener = ActionListener.wrap(r -> {
                    log.info("Model data indexing done, result:{}, model id: {}", r.getResult(), r.getId());
                    mlStats.getStat(ML_TOTAL_MODEL_COUNT).increment();
                    mlStats.createCounterStatIfAbsent(modelCountStat(mlTask.getFunctionName())).increment();
                    String returnedTaskId = mlTask.isAsync() ? mlTask.getTaskId() : null;
                    MLTrainingOutput output = new MLTrainingOutput(r.getId(), returnedTaskId, MLTaskState.COMPLETED.name());
                    listener.onResponse(MLTaskResponse.builder().output(output).build());
                }, e -> {
                    listener.onFailure(e);
                });
                IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX);
                indexRequest.source(mlModel.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS));
                indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
                client.index(indexRequest, ActionListener.runBefore(indexResponseListener, () -> context.restore()));
            } catch (Exception e) {
                log.error("Failed to save ML model", e);
                listener.onFailure(e);
            }
        }, e -> {
            log.error("Failed to init ML model index", e);
            listener.onFailure(e);
        }));
    } catch (Exception e) {
        // TODO: catch more specific exception types instead of Exception
        log.error("Failed to train " + mlInput.getAlgorithm(), e);
        listener.onFailure(e);
    }
}
Also used : IndexResponse(org.opensearch.action.index.IndexResponse) ToXContent(org.opensearch.common.xcontent.ToXContent) ThreadPool(org.opensearch.threadpool.ThreadPool) StatNames.modelCountStat(org.opensearch.ml.stats.StatNames.modelCountStat) MLTaskState(org.opensearch.ml.common.parameter.MLTaskState) MLInput(org.opensearch.ml.common.parameter.MLInput) MLInputDatasetHandler(org.opensearch.ml.indices.MLInputDatasetHandler) ThreadedActionListener(org.opensearch.action.support.ThreadedActionListener) ThreadContext(org.opensearch.common.util.concurrent.ThreadContext) MLTask(org.opensearch.ml.common.parameter.MLTask) ML_EXECUTING_TASK_COUNT(org.opensearch.ml.stats.StatNames.ML_EXECUTING_TASK_COUNT) WriteRequest(org.opensearch.action.support.WriteRequest) ActionListener(org.opensearch.action.ActionListener) MLModel(org.opensearch.ml.common.parameter.MLModel) MLIndicesHandler(org.opensearch.ml.indices.MLIndicesHandler) ML_TOTAL_MODEL_COUNT(org.opensearch.ml.stats.StatNames.ML_TOTAL_MODEL_COUNT) MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) MLInputDataType(org.opensearch.ml.common.dataset.MLInputDataType) Client(org.opensearch.client.Client) MLStats(org.opensearch.ml.stats.MLStats) ActionName(org.opensearch.ml.stats.ActionName) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) UUID(java.util.UUID) Instant(java.time.Instant) TransportService(org.opensearch.transport.TransportService) StatNames.requestCountStat(org.opensearch.ml.stats.StatNames.requestCountStat) MLTrainingTaskAction(org.opensearch.ml.common.transport.training.MLTrainingTaskAction) XContentBuilder(org.opensearch.common.xcontent.XContentBuilder) MLEngine(org.opensearch.ml.engine.MLEngine) StatNames.failureCountStat(org.opensearch.ml.stats.StatNames.failureCountStat) MLTrainingOutput(org.opensearch.ml.common.parameter.MLTrainingOutput) ML_MODEL_INDEX(org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX) Model(org.opensearch.ml.common.parameter.Model) ML_TOTAL_FAILURE_COUNT(org.opensearch.ml.stats.StatNames.ML_TOTAL_FAILURE_COUNT) MLTaskType(org.opensearch.ml.common.parameter.MLTaskType) MLCircuitBreakerService(org.opensearch.ml.common.breaker.MLCircuitBreakerService) Log4j2(lombok.extern.log4j.Log4j2) ActionListenerResponseHandler(org.opensearch.action.ActionListenerResponseHandler) ClusterService(org.opensearch.cluster.service.ClusterService) TASK_THREAD_POOL(org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL) ML_TOTAL_REQUEST_COUNT(org.opensearch.ml.stats.StatNames.ML_TOTAL_REQUEST_COUNT) XContentType(org.opensearch.common.xcontent.XContentType) IndexRequest(org.opensearch.action.index.IndexRequest) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) MLTrainingTaskRequest(org.opensearch.ml.common.transport.training.MLTrainingTaskRequest)
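train() layers its listeners: an outer wrapper bumps failure counters before delegating, and ActionListener.runBefore restores the stashed thread context before the index response reaches the wrapped listener. The sketch below reproduces that layering in plain Java with a hypothetical Listener interface modeled on the calls above. It is not the OpenSearch ActionListener source, and the exact behavior of its wrap and runBefore helpers is an assumption made for illustration.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

final class ListenerSketch {
    // Hypothetical callback interface; not org.opensearch.action.ActionListener.
    interface Listener<T> {
        void onResponse(T response);

        void onFailure(Exception e);

        // Build a listener from two callbacks; in this sketch an exception thrown
        // by the success callback is routed to the failure callback.
        static <T> Listener<T> wrap(Consumer<T> onSuccess, Consumer<Exception> onError) {
            return new Listener<T>() {
                @Override
                public void onResponse(T response) {
                    try {
                        onSuccess.accept(response);
                    } catch (Exception e) {
                        onFailure(e);
                    }
                }

                @Override
                public void onFailure(Exception e) {
                    onError.accept(e);
                }
            };
        }

        // Run `before` (for example, restoring a stashed context) ahead of either callback.
        static <T> Listener<T> runBefore(Listener<T> delegate, Runnable before) {
            return new Listener<T>() {
                @Override
                public void onResponse(T response) {
                    before.run();
                    delegate.onResponse(response);
                }

                @Override
                public void onFailure(Exception e) {
                    before.run();
                    delegate.onFailure(e);
                }
            };
        }
    }

    // Drives one success and one failure through the layered listeners and records the order.
    static String demo() {
        StringBuilder trace = new StringBuilder();
        AtomicInteger failures = new AtomicInteger();
        Listener<String> outer = Listener.wrap(
                r -> trace.append("ok:").append(r).append(';'),
                e -> trace.append("fail:").append(e.getMessage()).append(';'));
        // Failure path increments a counter before delegating, like the failure stats in train().
        Listener<String> counting = Listener.wrap(outer::onResponse, e -> {
            failures.incrementAndGet();
            outer.onFailure(e);
        });
        Listener<String> listener = Listener.runBefore(counting, () -> trace.append("restore;"));
        listener.onResponse("model-1");
        listener.onFailure(new RuntimeException("boom"));
        return trace + "failures=" + failures.get();
    }

    public static void main(String[] args) {
        System.out.println(demo()); // restore;ok:model-1;restore;fail:boom;failures=1
    }
}
```

The trace shows the restore step running before both the success and the failure callback, and the counter moving only on the failure path.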

Example 3 with Model

Use of org.opensearch.ml.common.parameter.Model in project ml-commons by opensearch-project.

From class MLPredictTaskRunner, method predict.

private void predict(MLTask mlTask, DataFrame inputDataFrame, MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener) {
    ActionListener<MLTaskResponse> internalListener = wrappedCleanupListener(listener, mlTask.getTaskId());
    // track ML task count and add ML task into cache
    mlStats.getStat(ML_EXECUTING_TASK_COUNT).increment();
    mlStats.getStat(ML_TOTAL_REQUEST_COUNT).increment();
    mlStats.createCounterStatIfAbsent(requestCountStat(mlTask.getFunctionName(), ActionName.PREDICT)).increment();
    mlTaskManager.add(mlTask);
    // run predict
    if (request.getModelId() != null) {
        // Fetch the model document by model id.
        try (ThreadContext.StoredContext context = threadPool.getThreadContext().stashContext()) {
            MLInput mlInput = request.getMlInput();
            ActionListener<GetResponse> getResponseListener = ActionListener.wrap(r -> {
                if (r == null || !r.isExists()) {
                    internalListener.onFailure(new ResourceNotFoundException("No model found, please check the modelId."));
                    return;
                }
                Map<String, Object> source = r.getSourceAsMap();
                User requestUser = getUserContext(client);
                User resourceUser = User.parse((String) source.get(USER));
                if (!checkUserPermissions(requestUser, resourceUser, request.getModelId())) {
                    // The request user and the resource user share no backend roles
                    OpenSearchException e = new OpenSearchException("User: " + requestUser.getName() + " does not have permissions to run predict by model: " + request.getModelId());
                    handlePredictFailure(mlTask, internalListener, e, false);
                    return;
                }
                Model model = new Model();
                model.setName((String) source.get(MLModel.MODEL_NAME));
                model.setVersion((Integer) source.get(MLModel.MODEL_VERSION));
                byte[] decoded = Base64.getDecoder().decode((String) source.get(MLModel.MODEL_CONTENT));
                model.setContent(decoded);
                // run predict
                mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
                MLOutput output = MLEngine.predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model);
                if (output instanceof MLPredictionOutput) {
                    ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
                }
                // Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state
                handleAsyncMLTaskComplete(mlTask);
                MLTaskResponse response = MLTaskResponse.builder().output(output).build();
                internalListener.onResponse(response);
            }, e -> {
                log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + mlTask.getModelId(), e);
                handlePredictFailure(mlTask, internalListener, e, true);
            });
            GetRequest getRequest = new GetRequest(ML_MODEL_INDEX, mlTask.getModelId());
            client.get(getRequest, ActionListener.runBefore(getResponseListener, () -> context.restore()));
        } catch (Exception e) {
            log.error("Failed to get model " + mlTask.getModelId(), e);
            handlePredictFailure(mlTask, internalListener, e, true);
        }
    } else {
        IllegalArgumentException e = new IllegalArgumentException("ModelId is invalid");
        log.error("ModelId is invalid", e);
        handlePredictFailure(mlTask, internalListener, e, false);
    }
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) User(org.opensearch.commons.authuser.User) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) ThreadContext(org.opensearch.common.util.concurrent.ThreadContext) MLOutput(org.opensearch.ml.common.parameter.MLOutput) GetResponse(org.opensearch.action.get.GetResponse) OpenSearchException(org.opensearch.OpenSearchException) ResourceNotFoundException(org.opensearch.ResourceNotFoundException) MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) GetRequest(org.opensearch.action.get.GetRequest) MLModel(org.opensearch.ml.common.parameter.MLModel) Model(org.opensearch.ml.common.parameter.Model) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput)
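predict() rebuilds the Model from the stored document: the content field holds a Base64 string that is decoded back into bytes, and the request is refused when the request user and the model owner share no backend role. The sketch below shows both steps in plain Java. The field keys are hypothetical (the real keys are the MLModel.MODEL_* constants, whose values are not shown), and the intersection rule is an assumption drawn from the code comment; the actual checkUserPermissions logic is not shown here.

```java
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class ModelSourceSketch {
    // Hypothetical document keys, for illustration only.
    static final String NAME = "name";
    static final String VERSION = "version";
    static final String CONTENT = "content";

    // Encode raw model bytes as Base64 so they can be stored in a JSON document source.
    static Map<String, Object> toSource(String name, int version, byte[] content) {
        Map<String, Object> source = new HashMap<>();
        source.put(NAME, name);
        source.put(VERSION, version);
        source.put(CONTENT, Base64.getEncoder().encodeToString(content));
        return source;
    }

    // Same decode step as predict(): Base64 string back to the original bytes.
    static byte[] contentFrom(Map<String, Object> source) {
        return Base64.getDecoder().decode((String) source.get(CONTENT));
    }

    // Assumed rule: access is allowed when the two users share at least one backend role.
    // The real check may add other cases (for example, security disabled or admin users).
    static boolean sharesBackendRole(List<String> requestRoles, List<String> resourceRoles) {
        return !Collections.disjoint(requestRoles, resourceRoles);
    }

    public static void main(String[] args) {
        byte[] bytes = "serialized-model".getBytes(StandardCharsets.UTF_8);
        Map<String, Object> source = toSource("KMEANS", 1, bytes);
        System.out.println(new String(contentFrom(source), StandardCharsets.UTF_8)); // serialized-model
        System.out.println(sharesBackendRole(List.of("analyst", "ops"), List.of("ops"))); // true
        System.out.println(sharesBackendRole(List.of("analyst"), List.of("ops")));        // false
    }
}
```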

Example 4 with Model

Use of org.opensearch.ml.common.parameter.Model in project ml-commons by opensearch-project.

From class ModelSerDeSerTest, method testModelSerDeSerKMeans.

@Test
public void testModelSerDeSerKMeans() {
    KMeansParams params = KMeansParams.builder().build();
    KMeans kMeans = new KMeans(params);
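    // Train on test data; the returned Model holds the serialized KMeansModel bytes.
    Model model = kMeans.train(constructKMeansDataFrame(100));
    // Round trip: deserialize the stored bytes, then serialize the resulting KMeansModel again.
    KMeansModel kMeansModel = (KMeansModel) ModelSerDeSer.deserialize(model.getContent());
    byte[] serializedModel = ModelSerDeSer.serialize(kMeansModel);
    // The test expects the re-serialized bytes to differ from the original content.
    assertFalse(Arrays.equals(serializedModel, model.getContent()));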
}
Also used : KMeansParams(org.opensearch.ml.common.parameter.KMeansParams) KMeansModel(org.tribuo.clustering.kmeans.KMeansModel) KMeans(org.opensearch.ml.engine.algorithms.clustering.KMeans) Model(org.opensearch.ml.common.parameter.Model) Test(org.junit.Test)
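The test relies on ModelSerDeSer to turn a trained model into bytes and back. Its implementation is not shown here; the sketch below is a stdlib analogue that assumes plain Java object serialization, with a hypothetical Centroids class standing in for a Tribuo KMeansModel. It also shows why byte arrays are compared with Arrays.equals rather than ==.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;
import java.util.Arrays;

final class SerDeSketch {
    // Hypothetical serializable model standing in for a trained clustering model.
    static final class Centroids implements Serializable {
        private static final long serialVersionUID = 1L;
        final double[][] points;

        Centroids(double[][] points) {
            this.points = points;
        }
    }

    // Object -> bytes using Java object serialization.
    static byte[] serialize(Serializable obj) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(obj);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        // The stream was flushed on close, so the buffer is complete here.
        return bytes.toByteArray();
    }

    // Bytes -> object; the caller casts to the expected type, as the test does with KMeansModel.
    static Object deserialize(byte[] data) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return in.readObject();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        Centroids original = new Centroids(new double[][] {{0.0, 1.0}, {5.0, 5.0}});
        byte[] first = serialize(original);
        Centroids copy = (Centroids) deserialize(first);
        System.out.println(Arrays.deepEquals(original.points, copy.points)); // true
        // Arrays.equals compares contents; == only compares array references.
        byte[] clone = first.clone();
        System.out.println(Arrays.equals(first, clone) + " " + (first == clone)); // true false
    }
}
```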

Example 5 with Model

Use of org.opensearch.ml.common.parameter.Model in project ml-commons by opensearch-project.

From class KMeansTest, method train.

@Test
public void train() {
    Model model = kMeans.train(trainDataFrame);
    Assert.assertEquals(FunctionName.KMEANS.name(), model.getName());
    Assert.assertEquals(1, model.getVersion());
    Assert.assertNotNull(model.getContent());
}
Also used : Model(org.opensearch.ml.common.parameter.Model) Test(org.junit.Test)
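These examples treat Model as a small holder with a name, a version and serialized content. The sketch below is a plain-Java stand-in limited to the accessors the examples call (setName/getName, setVersion/getVersion, setContent/getContent). The field types (int version, byte[] content) are assumptions consistent with the calls in Examples 1 and 3; the real class may be Lombok-generated or carry more fields.

```java
final class ModelHolderSketch {
    // Hypothetical stand-in for org.opensearch.ml.common.parameter.Model.
    static final class Model {
        private String name;
        private int version;
        private byte[] content;

        String getName() { return name; }
        void setName(String name) { this.name = name; }
        int getVersion() { return version; }
        void setVersion(int version) { this.version = version; }
        byte[] getContent() { return content; }
        void setContent(byte[] content) { this.content = content; }
    }

    // Mirrors the tail of train(): stamp name and version, then attach the serialized bytes.
    static Model wrap(String functionName, int version, byte[] serialized) {
        Model model = new Model();
        model.setName(functionName);
        model.setVersion(version);
        model.setContent(serialized);
        return model;
    }

    public static void main(String[] args) {
        Model m = wrap("KMEANS", 1, new byte[] {1, 2, 3});
        // The same three checks as the KMeansTest.train() example.
        if (!"KMEANS".equals(m.getName())) throw new AssertionError("name");
        if (m.getVersion() != 1) throw new AssertionError("version");
        if (m.getContent() == null) throw new AssertionError("content");
        System.out.println("ok");
    }
}
```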

Aggregations

Model (org.opensearch.ml.common.parameter.Model) 29
Test (org.junit.Test) 18
DataFrame (org.opensearch.ml.common.dataframe.DataFrame) 8
MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput) 8
MLInput (org.opensearch.ml.common.parameter.MLInput) 5
ThreadContext (org.opensearch.common.util.concurrent.ThreadContext) 3
DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset) 3
MLModel (org.opensearch.ml.common.parameter.MLModel) 3
LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame) 3
LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame) 3
Instant (java.time.Instant) 2
UUID (java.util.UUID) 2
Log4j2 (lombok.extern.log4j.Log4j2) 2
ActionListener (org.opensearch.action.ActionListener) 2
ActionListenerResponseHandler (org.opensearch.action.ActionListenerResponseHandler) 2
IndexRequest (org.opensearch.action.index.IndexRequest) 2
IndexResponse (org.opensearch.action.index.IndexResponse) 2
ThreadedActionListener (org.opensearch.action.support.ThreadedActionListener) 2
WriteRequest (org.opensearch.action.support.WriteRequest) 2
Client (org.opensearch.client.Client) 2