Usage of org.tribuo.anomaly.libsvm.LibSVMAnomalyTrainer in the ml-commons project (opensearch-project):
the train method of the AnomalyDetectionLibSVM class.
/**
 * Trains a LibSVM anomaly-detection model in one-class mode on the given data frame.
 *
 * <p>Gamma and nu fall back to {@code DEFAULT_GAMMA} and {@code DEFAULT_NU} when the
 * corresponding parameter is {@code null}. Cost, coeff, epsilon and degree are applied only
 * when explicitly set; otherwise whatever {@code SVMParameters} already holds is kept.
 *
 * @param dataFrame the training data, converted to a Tribuo dataset of {@code Event} outputs
 * @return a {@link Model} named {@code FunctionName.AD_LIBSVM}, versioned with {@code VERSION},
 *     whose content is the serialized trained LibSVM model
 */
@Override
public Model train(DataFrame dataFrame) {
    KernelType kernelType = parseKernelType();
    // Parameterized with Event (not raw) so the trainer/model types line up without unchecked warnings.
    SVMParameters<Event> params =
            new SVMParameters<>(new SVMAnomalyType(SVMAnomalyType.SVMMode.ONE_CLASS), kernelType);

    // gamma and nu always get a value: user-supplied if present, otherwise the class defaults.
    Double gamma = Optional.ofNullable(parameters.getGamma()).orElse(DEFAULT_GAMMA);
    Double nu = Optional.ofNullable(parameters.getNu()).orElse(DEFAULT_NU);
    params.setGamma(gamma);
    params.setNu(nu);

    // Optional hyper-parameters: only override when the caller actually supplied them.
    if (parameters.getCost() != null) {
        params.setCost(parameters.getCost());
    }
    if (parameters.getCoeff() != null) {
        params.setCoeff(parameters.getCoeff());
    }
    if (parameters.getEpsilon() != null) {
        params.setEpsilon(parameters.getEpsilon());
    }
    if (parameters.getDegree() != null) {
        params.setDegree(parameters.getDegree());
    }

    MutableDataset<Event> data = TribuoUtil.generateDataset(
            dataFrame,
            new AnomalyFactory(),
            "Anomaly detection LibSVM training data from OpenSearch",
            TribuoOutputType.ANOMALY_DETECTION_LIBSVM);

    LibSVMAnomalyTrainer trainer = new LibSVMAnomalyTrainer(params);
    LibSVMModel<Event> libSVMModel = trainer.train(data);

    Model model = new Model();
    model.setName(FunctionName.AD_LIBSVM.name());
    model.setVersion(VERSION);
    model.setContent(ModelSerDeSer.serialize(libSVMModel));
    return model;
}
Aggregations