
Example 1 with LibSVMModel

Use of org.tribuo.common.libsvm.LibSVMModel in project ml-commons by opensearch-project.

From class AnomalyDetectionLibSVM, method train.

@Override
public Model train(DataFrame dataFrame) {
    KernelType kernelType = parseKernelType();
    // One-class SVM: trained on unlabeled data; points outside the learned region are treated as anomalous
    SVMParameters<Event> params = new SVMParameters<>(new SVMAnomalyType(SVMAnomalyType.SVMMode.ONE_CLASS), kernelType);
    // gamma and nu always receive a value: the caller's, or the class default
    Double gamma = Optional.ofNullable(parameters.getGamma()).orElse(DEFAULT_GAMMA);
    Double nu = Optional.ofNullable(parameters.getNu()).orElse(DEFAULT_NU);
    params.setGamma(gamma);
    params.setNu(nu);
    // The remaining hyperparameters are only overridden when the caller supplies them
    if (parameters.getCost() != null) {
        params.setCost(parameters.getCost());
    }
    if (parameters.getCoeff() != null) {
        params.setCoeff(parameters.getCoeff());
    }
    if (parameters.getEpsilon() != null) {
        params.setEpsilon(parameters.getEpsilon());
    }
    if (parameters.getDegree() != null) {
        params.setDegree(parameters.getDegree());
    }
    // Convert the OpenSearch DataFrame into a Tribuo dataset of anomaly Events
    MutableDataset<Event> data = TribuoUtil.generateDataset(dataFrame, new AnomalyFactory(), "Anomaly detection LibSVM training data from OpenSearch", TribuoOutputType.ANOMALY_DETECTION_LIBSVM);
    LibSVMAnomalyTrainer trainer = new LibSVMAnomalyTrainer(params);
    LibSVMModel<Event> libSVMModel = trainer.train(data);
    // Wrap the serialized Tribuo model in an ml-commons Model
    Model model = new Model();
    model.setName(FunctionName.AD_LIBSVM.name());
    model.setVersion(VERSION);
    model.setContent(ModelSerDeSer.serialize(libSVMModel));
    return model;
}
Also used: LibSVMModel (org.tribuo.common.libsvm.LibSVMModel), LibSVMAnomalyModel (org.tribuo.anomaly.libsvm.LibSVMAnomalyModel), SVMAnomalyType (org.tribuo.anomaly.libsvm.SVMAnomalyType), SVMParameters (org.tribuo.common.libsvm.SVMParameters), Model (org.opensearch.ml.common.parameter.Model), Event (org.tribuo.anomaly.Event), KernelType (org.tribuo.common.libsvm.KernelType), LibSVMAnomalyTrainer (org.tribuo.anomaly.libsvm.LibSVMAnomalyTrainer), AnomalyFactory (org.tribuo.anomaly.AnomalyFactory)
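
The hyperparameter handling in train follows a simple pattern: gamma and nu always receive a value (the caller's or a default), while cost, coeff, epsilon and degree are only overridden when the caller supplies them. The sketch below isolates that pattern in plain Java so it runs without Tribuo. AdLibSvmParams, SvmConfig and HyperparameterDefaults are hypothetical stand-ins, not ml-commons or Tribuo classes, and every numeric value in it is a placeholder rather than a default that ml-commons or Tribuo is known to use.

```java
import java.util.Optional;

// Hypothetical stand-in for the request parameters; any field may be null (meaning "not supplied").
record AdLibSvmParams(Double gamma, Double nu, Double cost, Double coeff, Double epsilon, Integer degree) {}

// Hypothetical stand-in for SVMParameters. The starting values are placeholders
// representing "whatever the trainer uses when a setting is not overridden".
final class SvmConfig {
    double gamma;
    double nu;
    double cost = 1.0;
    double coeff = 0.0;
    double epsilon = 0.001;
    int degree = 3;
}

final class HyperparameterDefaults {
    // Placeholder defaults for illustration only; ml-commons defines its own DEFAULT_GAMMA and DEFAULT_NU.
    static final double DEFAULT_GAMMA = 1.0;
    static final double DEFAULT_NU = 0.01;

    static SvmConfig resolve(AdLibSvmParams p) {
        SvmConfig c = new SvmConfig();
        // gamma and nu always get a value: the supplied one, or the default
        c.gamma = Optional.ofNullable(p.gamma()).orElse(DEFAULT_GAMMA);
        c.nu = Optional.ofNullable(p.nu()).orElse(DEFAULT_NU);
        // The other settings keep their starting values unless explicitly supplied
        if (p.cost() != null) c.cost = p.cost();
        if (p.coeff() != null) c.coeff = p.coeff();
        if (p.epsilon() != null) c.epsilon = p.epsilon();
        if (p.degree() != null) c.degree = p.degree();
        return c;
    }
}
```

Leaving unsupplied settings untouched, rather than writing nulls or zeros through, lets callers pass only the hyperparameters they actually want to change.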

Example 2 with LibSVMModel

Use of org.tribuo.common.libsvm.LibSVMModel in project ml-commons by opensearch-project.

From class AnomalyDetectionLibSVM, method predict.

@Override
public MLOutput predict(DataFrame dataFrame, Model model) {
    if (model == null) {
        throw new IllegalArgumentException("No model found for anomaly detection LibSVM prediction.");
    }
    MutableDataset<Event> predictionDataset = TribuoUtil.generateDataset(dataFrame, new AnomalyFactory(), "Anomaly detection LibSVM prediction data from OpenSearch", TribuoOutputType.ANOMALY_DETECTION_LIBSVM);
    // Restore the Tribuo model that train() serialized into the Model content
    @SuppressWarnings("unchecked")
    LibSVMModel<Event> libSVMAnomalyModel = (LibSVMModel<Event>) ModelSerDeSer.deserialize(model.getContent());
    List<Prediction<Event>> predictions = libSVMAnomalyModel.predict(predictionDataset);
    // One result row per prediction: the score and the predicted event type
    List<Map<String, Object>> adResults = new ArrayList<>();
    predictions.forEach(e -> {
        Map<String, Object> result = new HashMap<>();
        result.put("score", e.getOutput().getScore());
        result.put("anomaly_type", e.getOutput().getType().name());
        adResults.add(result);
    });
    return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(adResults)).build();
}
Also used: LibSVMModel (org.tribuo.common.libsvm.LibSVMModel), HashMap (java.util.HashMap), Prediction (org.tribuo.Prediction), ArrayList (java.util.ArrayList), AnomalyFactory (org.tribuo.anomaly.AnomalyFactory), Event (org.tribuo.anomaly.Event), Map (java.util.Map)
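
The predict method turns each Tribuo Prediction<Event> into a map row with two keys, score and anomaly_type, before DataFrameBuilder.load builds the output table. The sketch reproduces only that mapping step in plain Java so it runs without Tribuo. ScoredEvent and PredictionRows are hypothetical, and the EventType enum is an assumed stand-in modeled on the anomaly labels of Tribuo's Event, not the Tribuo type itself.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Assumed stand-in for the anomaly labels of Tribuo's Event; not the Tribuo enum itself.
enum EventType { EXPECTED, ANOMALOUS, UNKNOWN }

// Hypothetical, simplified prediction holding only the two values the predict method reads.
record ScoredEvent(double score, EventType type) {}

final class PredictionRows {
    // Builds one row per prediction, keyed by "score" and "anomaly_type" as in predict()
    static List<Map<String, Object>> toRows(List<ScoredEvent> predictions) {
        List<Map<String, Object>> rows = new ArrayList<>(predictions.size());
        for (ScoredEvent e : predictions) {
            Map<String, Object> row = new HashMap<>();
            row.put("score", e.score());               // primitive double boxed to Double
            row.put("anomaly_type", e.type().name());  // enum constant name stored as a String
            rows.add(row);
        }
        return rows;
    }
}
```

Storing the enum's name() rather than the enum value keeps each row made of plain strings and numbers, which is convenient for building a generic data frame.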

Aggregations (number of examples above that use each class):

AnomalyFactory (org.tribuo.anomaly.AnomalyFactory): 2
Event (org.tribuo.anomaly.Event): 2
LibSVMModel (org.tribuo.common.libsvm.LibSVMModel): 2
ArrayList (java.util.ArrayList): 1
HashMap (java.util.HashMap): 1
Map (java.util.Map): 1
Model (org.opensearch.ml.common.parameter.Model): 1
Prediction (org.tribuo.Prediction): 1
LibSVMAnomalyModel (org.tribuo.anomaly.libsvm.LibSVMAnomalyModel): 1
LibSVMAnomalyTrainer (org.tribuo.anomaly.libsvm.LibSVMAnomalyTrainer): 1
SVMAnomalyType (org.tribuo.anomaly.libsvm.SVMAnomalyType): 1
KernelType (org.tribuo.common.libsvm.KernelType): 1
SVMParameters (org.tribuo.common.libsvm.SVMParameters): 1