Use of org.tribuo.anomaly.AnomalyFactory in the ml-commons project by opensearch-project.
Class AnomalyDetectionLibSVM, method train.
/**
 * Trains a one-class LibSVM anomaly detection model on the supplied data frame.
 *
 * <p>Gamma and nu fall back to {@code DEFAULT_GAMMA} and {@code DEFAULT_NU} when they are
 * not configured. Cost, coeff, epsilon and degree are only applied when configured;
 * otherwise whatever {@code SVMParameters} was constructed with is left untouched.
 *
 * @param dataFrame the training data
 * @return a {@code Model} carrying the algorithm name, the version, and the serialized
 *         trained LibSVM model as its content
 */
@Override
public Model train(DataFrame dataFrame) {
    KernelType kernelType = parseKernelType();
    SVMParameters<Event> params =
        new SVMParameters<>(new SVMAnomalyType(SVMAnomalyType.SVMMode.ONE_CLASS), kernelType);

    Double gamma = Optional.ofNullable(parameters.getGamma()).orElse(DEFAULT_GAMMA);
    Double nu = Optional.ofNullable(parameters.getNu()).orElse(DEFAULT_NU);
    params.setGamma(gamma);
    params.setNu(nu);

    // Optional tuning knobs: only override when the caller supplied a value.
    if (parameters.getCost() != null) {
        params.setCost(parameters.getCost());
    }
    if (parameters.getCoeff() != null) {
        params.setCoeff(parameters.getCoeff());
    }
    if (parameters.getEpsilon() != null) {
        params.setEpsilon(parameters.getEpsilon());
    }
    if (parameters.getDegree() != null) {
        params.setDegree(parameters.getDegree());
    }

    MutableDataset<Event> data = TribuoUtil.generateDataset(
        dataFrame,
        new AnomalyFactory(),
        "Anomaly detection LibSVM training data from OpenSearch",
        TribuoOutputType.ANOMALY_DETECTION_LIBSVM);

    LibSVMAnomalyTrainer trainer = new LibSVMAnomalyTrainer(params);
    LibSVMModel<Event> libSVMModel = trainer.train(data);

    Model model = new Model();
    model.setName(FunctionName.AD_LIBSVM.name());
    model.setVersion(VERSION);
    model.setContent(ModelSerDeSer.serialize(libSVMModel));
    return model;
}
Use of org.tribuo.anomaly.AnomalyFactory in the ml-commons project by opensearch-project.
Class AnomalyDetectionLibSVM, method predict.
/**
 * Runs anomaly detection over the supplied data frame using a previously trained
 * LibSVM model.
 *
 * <p>Each input row produces one output row with two fields: {@code score} (the
 * prediction's score) and {@code anomaly_type} (the name of the predicted event type).
 *
 * @param dataFrame the data to score
 * @param model the trained model whose content holds the serialized LibSVM model
 * @return the prediction output wrapping a data frame of per-row results
 * @throws IllegalArgumentException if {@code model} is {@code null}
 */
@Override
public MLOutput predict(DataFrame dataFrame, Model model) {
    if (model == null) {
        throw new IllegalArgumentException("No model found for anomaly detection LibSVM prediction.");
    }

    MutableDataset<Event> predictionDataset = TribuoUtil.generateDataset(
        dataFrame,
        new AnomalyFactory(),
        "Anomaly detection LibSVM prediction data from OpenSearch",
        TribuoOutputType.ANOMALY_DETECTION_LIBSVM);

    // The serialized content is written by train() from a LibSVMModel<Event>, so the
    // generic cast is safe even though it cannot be checked at runtime.
    @SuppressWarnings("unchecked")
    LibSVMModel<Event> libSVMAnomalyModel = (LibSVMModel<Event>) ModelSerDeSer.deserialize(model.getContent());

    List<Prediction<Event>> predictions = libSVMAnomalyModel.predict(predictionDataset);

    List<Map<String, Object>> adResults = new ArrayList<>(predictions.size());
    for (Prediction<Event> prediction : predictions) {
        Event event = prediction.getOutput();
        Map<String, Object> result = new HashMap<>();
        result.put("score", event.getScore());
        result.put("anomaly_type", event.getType().name());
        adResults.add(result);
    }

    return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(adResults)).build();
}
Aggregations