
Example 1 with KernelType

Use of org.tribuo.common.libsvm.KernelType in the project ml-commons by opensearch-project.

Source: the train method of class AnomalyDetectionLibSVM.

@Override
public Model train(DataFrame dataFrame) {
    // Resolve the kernel from the request parameters (parseKernelType() is defined elsewhere in this class).
    KernelType kernelType = parseKernelType();
    // ONE_CLASS mode: a one-class SVM that learns a boundary around the training data.
    SVMParameters<Event> params = new SVMParameters<>(new SVMAnomalyType(SVMAnomalyType.SVMMode.ONE_CLASS), kernelType);
    // gamma and nu always receive a value, falling back to the class defaults.
    Double gamma = Optional.ofNullable(parameters.getGamma()).orElse(DEFAULT_GAMMA);
    Double nu = Optional.ofNullable(parameters.getNu()).orElse(DEFAULT_NU);
    params.setGamma(gamma);
    params.setNu(nu);
    // The remaining settings are applied only when supplied; otherwise the SVMParameters defaults stay in place.
    if (parameters.getCost() != null) {
        params.setCost(parameters.getCost());
    }
    if (parameters.getCoeff() != null) {
        params.setCoeff(parameters.getCoeff());
    }
    if (parameters.getEpsilon() != null) {
        params.setEpsilon(parameters.getEpsilon());
    }
    if (parameters.getDegree() != null) {
        params.setDegree(parameters.getDegree());
    }
    // Convert the DataFrame into a Tribuo dataset of anomaly Events.
    MutableDataset<Event> data = TribuoUtil.generateDataset(dataFrame, new AnomalyFactory(),
            "Anomaly detection LibSVM training data from OpenSearch", TribuoOutputType.ANOMALY_DETECTION_LIBSVM);
    LibSVMAnomalyTrainer trainer = new LibSVMAnomalyTrainer(params);
    LibSVMModel<Event> libSVMModel = trainer.train(data);
    // Wrap the serialized Tribuo model in an ml-commons Model.
    Model model = new Model();
    model.setName(FunctionName.AD_LIBSVM.name());
    model.setVersion(VERSION);
    model.setContent(ModelSerDeSer.serialize(libSVMModel));
    return model;
}
Also used: LibSVMModel (org.tribuo.common.libsvm.LibSVMModel), LibSVMAnomalyModel (org.tribuo.anomaly.libsvm.LibSVMAnomalyModel), SVMAnomalyType (org.tribuo.anomaly.libsvm.SVMAnomalyType), SVMParameters (org.tribuo.common.libsvm.SVMParameters), Model (org.opensearch.ml.common.parameter.Model), Event (org.tribuo.anomaly.Event), KernelType (org.tribuo.common.libsvm.KernelType), LibSVMAnomalyTrainer (org.tribuo.anomaly.libsvm.LibSVMAnomalyTrainer), AnomalyFactory (org.tribuo.anomaly.AnomalyFactory)
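
The setGamma, setCoeff and setDegree calls in train only influence some kernels. LibSVM defines its four kernel types as linear `u'v`, polynomial `(gamma * u'v + coef0)^degree`, RBF `exp(-gamma * |u - v|^2)` and sigmoid `tanh(gamma * u'v + coef0)`, with Tribuo's coeff setting corresponding to LibSVM's coef0. The stand-alone sketch below evaluates those formulas; `SketchKernel` and `LibSvmKernels` are hypothetical names, not part of Tribuo or ml-commons, both of which leave kernel evaluation to LibSVM itself.

```java
// Hypothetical enum mirroring the four LibSVM kernel names;
// it is NOT Tribuo's org.tribuo.common.libsvm.KernelType.
enum SketchKernel { LINEAR, POLY, RBF, SIGMOID }

// Hypothetical helper that evaluates the LibSVM kernel formulas directly.
final class LibSvmKernels {
    private LibSvmKernels() {}

    // Dot product u'v of two equal-length vectors.
    static double dot(double[] u, double[] v) {
        if (u.length != v.length) {
            throw new IllegalArgumentException("Vectors must have the same length");
        }
        double sum = 0.0;
        for (int i = 0; i < u.length; i++) {
            sum += u[i] * v[i];
        }
        return sum;
    }

    // Squared Euclidean distance |u - v|^2.
    static double squaredDistance(double[] u, double[] v) {
        if (u.length != v.length) {
            throw new IllegalArgumentException("Vectors must have the same length");
        }
        double sum = 0.0;
        for (int i = 0; i < u.length; i++) {
            double d = u[i] - v[i];
            sum += d * d;
        }
        return sum;
    }

    // K(u, v) for the chosen kernel; each kernel reads only the parameters it uses.
    static double kernel(SketchKernel type, double[] u, double[] v,
                         double gamma, double coef0, int degree) {
        switch (type) {
            case LINEAR:
                return dot(u, v);
            case POLY:
                return Math.pow(gamma * dot(u, v) + coef0, degree);
            case RBF:
                return Math.exp(-gamma * squaredDistance(u, v));
            case SIGMOID:
                return Math.tanh(gamma * dot(u, v) + coef0);
            default:
                throw new IllegalArgumentException("Unknown kernel: " + type);
        }
    }
}
```

For example, with u = (1, 2) and v = (3, 4), the linear kernel gives 11 and the polynomial kernel with gamma = 1, coef0 = 1 and degree = 2 gives (11 + 1)^2 = 144. The RBF kernel of any vector with itself is 1, so gamma only controls how quickly similarity decays with distance.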

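
A one-class model trained this way scores a point with LibSVM's one-class decision function f(x) = sum_i alpha_i * K(x_i, x) - rho over its support vectors x_i, predicting the point as normal when f(x) > 0 and as an outlier otherwise. The nu value set in train bounds this from both sides: it is an upper bound on the fraction of training points treated as outliers and a lower bound on the fraction that become support vectors. The sketch below evaluates that rule for an RBF kernel; `OneClassRbfSketch` is a hypothetical class, not Tribuo's LibSVMAnomalyModel, and the values in the usage note are toy numbers rather than the output of a real training run.

```java
// Hypothetical stand-in for a trained one-class SVM with an RBF kernel.
final class OneClassRbfSketch {
    private final double[][] supportVectors; // x_i, each assumed to have the same length as the query
    private final double[] coefficients;     // alpha_i, one per support vector
    private final double rho;                // offset learned during training
    private final double gamma;              // RBF kernel width parameter

    OneClassRbfSketch(double[][] supportVectors, double[] coefficients, double rho, double gamma) {
        if (supportVectors.length != coefficients.length) {
            throw new IllegalArgumentException("One coefficient is needed per support vector");
        }
        this.supportVectors = supportVectors;
        this.coefficients = coefficients;
        this.rho = rho;
        this.gamma = gamma;
    }

    int numberOfSupportVectors() {
        return supportVectors.length;
    }

    // f(x) = sum_i alpha_i * exp(-gamma * |x_i - x|^2) - rho
    double decisionValue(double[] x) {
        double sum = 0.0;
        for (int i = 0; i < supportVectors.length; i++) {
            double squaredDistance = 0.0;
            for (int j = 0; j < x.length; j++) {
                double d = supportVectors[i][j] - x[j];
                squaredDistance += d * d;
            }
            sum += coefficients[i] * Math.exp(-gamma * squaredDistance);
        }
        return sum - rho;
    }

    // LibSVM's one-class rule: normal when f(x) > 0, outlier otherwise.
    boolean isAnomalous(double[] x) {
        return decisionValue(x) <= 0.0;
    }
}
```

With a single support vector at the origin, coefficient 1.0, rho = 0.5 and gamma = 1.0, the origin gets decision value 1.0 - 0.5 = 0.5 and is treated as normal, while (3, 0) gets exp(-9) - 0.5, which is negative, and is flagged as an outlier.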