Search in sources:

Example 1 with Event

Use of `org.tribuo.anomaly.Event` in the project ml-commons by opensearch-project.

Method `train` of the class `AnomalyDetectionLibSVM`.

/**
 * Trains a one-class LibSVM anomaly detection model on the given data frame.
 *
 * <p>{@code gamma} and {@code nu} fall back to {@code DEFAULT_GAMMA} and {@code DEFAULT_NU}
 * when they are not supplied. {@code cost}, {@code coeff}, {@code epsilon} and {@code degree}
 * are only set on the SVM parameters when explicitly provided.
 *
 * @param dataFrame training data; every row is turned into an EXPECTED (non-anomalous) event
 * @return a model named {@code AD_LIBSVM} whose content is the serialized Tribuo LibSVM model
 */
@Override
public Model train(DataFrame dataFrame) {
    KernelType kernelType = parseKernelType();
    // Typed (not raw) so the trainer and the trained model are checked against Event outputs.
    SVMParameters<Event> params = new SVMParameters<>(new SVMAnomalyType(SVMAnomalyType.SVMMode.ONE_CLASS), kernelType);
    Double gamma = Optional.ofNullable(parameters.getGamma()).orElse(DEFAULT_GAMMA);
    Double nu = Optional.ofNullable(parameters.getNu()).orElse(DEFAULT_NU);
    params.setGamma(gamma);
    params.setNu(nu);
    // Optional hyper-parameters: only override the defaults when the caller supplied a value.
    if (parameters.getCost() != null) {
        params.setCost(parameters.getCost());
    }
    if (parameters.getCoeff() != null) {
        params.setCoeff(parameters.getCoeff());
    }
    if (parameters.getEpsilon() != null) {
        params.setEpsilon(parameters.getEpsilon());
    }
    if (parameters.getDegree() != null) {
        params.setDegree(parameters.getDegree());
    }
    MutableDataset<Event> data = TribuoUtil.generateDataset(dataFrame, new AnomalyFactory(), "Anomaly detection LibSVM training data from OpenSearch", TribuoOutputType.ANOMALY_DETECTION_LIBSVM);
    LibSVMAnomalyTrainer trainer = new LibSVMAnomalyTrainer(params);
    LibSVMModel<Event> libSVMModel = trainer.train(data);
    Model model = new Model();
    model.setName(FunctionName.AD_LIBSVM.name());
    model.setVersion(VERSION);
    model.setContent(ModelSerDeSer.serialize(libSVMModel));
    return model;
}
Also used: LibSVMModel(org.tribuo.common.libsvm.LibSVMModel) LibSVMAnomalyModel(org.tribuo.anomaly.libsvm.LibSVMAnomalyModel) SVMAnomalyType(org.tribuo.anomaly.libsvm.SVMAnomalyType) SVMParameters(org.tribuo.common.libsvm.SVMParameters) Model(org.opensearch.ml.common.parameter.Model) Event(org.tribuo.anomaly.Event) KernelType(org.tribuo.common.libsvm.KernelType) LibSVMAnomalyTrainer(org.tribuo.anomaly.libsvm.LibSVMAnomalyTrainer) AnomalyFactory(org.tribuo.anomaly.AnomalyFactory)

Example 2 with Event

Use of `org.tribuo.anomaly.Event` in the project ml-commons by opensearch-project.

Method `generateDataset` of the class `TribuoUtil`.

/**
 * Builds a Tribuo dataset from a data frame, producing one example per data frame row.
 *
 * <p>Each row is paired with a placeholder output selected by {@code outputType}:
 * <ul>
 *   <li>{@code CLUSTERID}: an unassigned {@link ClusterID};</li>
 *   <li>{@code REGRESSOR}: a single-dimension {@link Regressor} named {@code DIM-0} holding NaN;</li>
 *   <li>{@code ANOMALY_DETECTION_LIBSVM}: an {@link Event} of type EXPECTED.</li>
 * </ul>
 *
 * @param dataFrame features data
 * @param outputFactory the tribuo output factory, also recorded in the data source provenance
 * @param desc description for tribuo provenance
 * @param outputType the tribuo output type
 * @param <T> the tribuo output type parameter
 * @return tribuo dataset
 * @throws IllegalArgumentException if the data frame has at least one row and
 *     {@code outputType} is not one of the handled cases
 */
public static <T extends Output<T>> MutableDataset<T> generateDataset(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType) {
    Tuple<String[], double[][]> namesAndValues = transformDataFrame(dataFrame);
    String[] featureNames = namesAndValues.v1();
    double[][] featureRows = namesAndValues.v2();
    List<Example<T>> examples = new ArrayList<>();
    int row = 0;
    while (row < dataFrame.size()) {
        T placeholder;
        switch (outputType) {
            case CLUSTERID:
                placeholder = (T) new ClusterID(ClusterID.UNASSIGNED);
                break;
            case REGRESSOR:
                // Single-dimension regressor named DIM-0 whose value is NaN.
                placeholder = (T) new Regressor("DIM-0", Double.NaN);
                break;
            case ANOMALY_DETECTION_LIBSVM:
                // EXPECTED (non-anomalous) is used as the default event type because:
                // 1. at training time, Tribuo's LibSVMAnomalyTrainer only supports EXPECTED events;
                // 2. at prediction time, the Tribuo library does not accept the UNKNOWN type,
                //    so the data is treated as non-anomalous by default.
                // TODO: support anomaly labels to evaluate prediction result
                placeholder = (T) new Event(Event.EventType.EXPECTED);
                break;
            default:
                throw new IllegalArgumentException("unknown type:" + outputType);
        }
        examples.add(new ArrayExample<>(placeholder, featureNames, featureRows[row]));
        row++;
    }
    SimpleDataSourceProvenance sourceProvenance = new SimpleDataSourceProvenance(desc, outputFactory);
    return new MutableDataset<>(new ListDataSource<>(examples, outputFactory, sourceProvenance));
}
Also used: ClusterID(org.tribuo.clustering.ClusterID) SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) ArrayList(java.util.ArrayList) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) Event(org.tribuo.anomaly.Event) Regressor(org.tribuo.regression.Regressor) MutableDataset(org.tribuo.MutableDataset)

Example 3 with Event

Use of `org.tribuo.anomaly.Event` in the project ml-commons by opensearch-project.

Method `constructDataFrame` of the class `AnomalyDetectionLibSVMTest`.

/**
 * Converts a Tribuo event dataset into a {@link DataFrame} with one DOUBLE column per feature.
 *
 * <p>The column schema is taken from the feature names of the first example. Every example,
 * including the first, is then handed to {@code addRow} together with its feature values,
 * the {@code training} flag and {@code labels}; what {@code addRow} does with the flag and
 * labels is defined elsewhere in the test class.
 *
 * @param data source examples
 * @param training forwarded unchanged to {@code addRow}
 * @param labels forwarded unchanged to {@code addRow}
 * @return the populated data frame, or {@code null} if {@code data} has no examples
 */
private DataFrame constructDataFrame(Dataset<Event> data, boolean training, List<Event.EventType> labels) {
    DataFrame frame = null;
    for (Iterator<Example<Event>> it = data.iterator(); it.hasNext(); ) {
        Example<Event> current = it.next();
        if (frame == null) {
            // The first example defines the column schema.
            List<ColumnMeta> schema = new ArrayList<>();
            for (Feature f : current) {
                schema.add(new ColumnMeta(f.getName(), ColumnType.DOUBLE));
            }
            frame = new DefaultDataFrame(schema.toArray(new ColumnMeta[schema.size()]));
        }
        List<ColumnValue> rowValues = new ArrayList<>();
        for (Feature f : current) {
            rowValues.add(new DoubleValue(f.getValue()));
        }
        addRow(rowValues, training, current, frame, labels);
    }
    return frame;
}
Also used: ArrayList(java.util.ArrayList) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) DefaultDataFrame(org.opensearch.ml.common.dataframe.DefaultDataFrame) Feature(org.tribuo.Feature) ColumnMeta(org.opensearch.ml.common.dataframe.ColumnMeta) DoubleValue(org.opensearch.ml.common.dataframe.DoubleValue) Example(org.tribuo.Example) Event(org.tribuo.anomaly.Event) ColumnValue(org.opensearch.ml.common.dataframe.ColumnValue)

Example 4 with Event

Use of `org.tribuo.anomaly.Event` in the project ml-commons by opensearch-project.

Method `predict` of the class `AnomalyDetectionLibSVM`.

/**
 * Scores each row of the data frame with a previously trained LibSVM anomaly model.
 *
 * @param dataFrame rows to score; each row is wrapped as an EXPECTED event placeholder
 * @param model the trained model; its content is deserialized with {@code ModelSerDeSer}
 * @return a prediction output with one row per input row, holding {@code score} and
 *     {@code anomaly_type} (the name of the predicted Tribuo event type)
 * @throws IllegalArgumentException if {@code model} is null
 */
@Override
public MLOutput predict(DataFrame dataFrame, Model model) {
    if (model == null) {
        // Message previously referred to KMeans (copy-paste); this is the AD LibSVM algorithm.
        throw new IllegalArgumentException("No model found for AD LibSVM prediction.");
    }
    MutableDataset<Event> predictionDataset = TribuoUtil.generateDataset(dataFrame, new AnomalyFactory(), "Anomaly detection LibSVM prediction data from OpenSearch", TribuoOutputType.ANOMALY_DETECTION_LIBSVM);
    // NOTE(review): assumes the content is a LibSVMModel<Event> serialized by train(); the cast is unchecked.
    @SuppressWarnings("unchecked")
    LibSVMModel<Event> libSVMAnomalyModel = (LibSVMModel<Event>) ModelSerDeSer.deserialize(model.getContent());
    List<Prediction<Event>> predictions = libSVMAnomalyModel.predict(predictionDataset);
    List<Map<String, Object>> adResults = new ArrayList<>(predictions.size());
    for (Prediction<Event> prediction : predictions) {
        Event predicted = prediction.getOutput();
        Map<String, Object> result = new HashMap<>();
        result.put("score", predicted.getScore());
        result.put("anomaly_type", predicted.getType().name());
        adResults.add(result);
    }
    return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(adResults)).build();
}
Also used: LibSVMModel(org.tribuo.common.libsvm.LibSVMModel) HashMap(java.util.HashMap) Prediction(org.tribuo.Prediction) ArrayList(java.util.ArrayList) AnomalyFactory(org.tribuo.anomaly.AnomalyFactory) Event(org.tribuo.anomaly.Event) Map(java.util.Map)

Aggregations

Event (org.tribuo.anomaly.Event): 4; ArrayList (java.util.ArrayList): 3; Example (org.tribuo.Example): 2; AnomalyFactory (org.tribuo.anomaly.AnomalyFactory): 2; LibSVMModel (org.tribuo.common.libsvm.LibSVMModel): 2; HashMap (java.util.HashMap): 1; Map (java.util.Map): 1; ColumnMeta (org.opensearch.ml.common.dataframe.ColumnMeta): 1; ColumnValue (org.opensearch.ml.common.dataframe.ColumnValue): 1; DataFrame (org.opensearch.ml.common.dataframe.DataFrame): 1; DefaultDataFrame (org.opensearch.ml.common.dataframe.DefaultDataFrame): 1; DoubleValue (org.opensearch.ml.common.dataframe.DoubleValue): 1; Model (org.opensearch.ml.common.parameter.Model): 1; Feature (org.tribuo.Feature): 1; MutableDataset (org.tribuo.MutableDataset): 1; Prediction (org.tribuo.Prediction): 1; LibSVMAnomalyModel (org.tribuo.anomaly.libsvm.LibSVMAnomalyModel): 1; LibSVMAnomalyTrainer (org.tribuo.anomaly.libsvm.LibSVMAnomalyTrainer): 1; SVMAnomalyType (org.tribuo.anomaly.libsvm.SVMAnomalyType): 1; ClusterID (org.tribuo.clustering.ClusterID): 1