
Example 1 with Prediction

Use of org.tribuo.Prediction in the ml-commons project by opensearch-project.

Class LinearRegression, method predict.

@Override
public MLOutput predict(DataFrame dataFrame, Model model) {
    if (model == null) {
        throw new IllegalArgumentException("No model found for linear regression prediction.");
    }
    org.tribuo.Model<Regressor> regressionModel = (org.tribuo.Model<Regressor>) ModelSerDeSer.deserialize(model.getContent());
    MutableDataset<Regressor> predictionDataset = TribuoUtil.generateDataset(dataFrame, new RegressionFactory(), "Linear regression prediction data from opensearch", TribuoOutputType.REGRESSOR);
    List<Prediction<Regressor>> predictions = regressionModel.predict(predictionDataset);
    List<Map<String, Object>> listPrediction = new ArrayList<>();
    predictions.forEach(e -> listPrediction.add(Collections.singletonMap(e.getOutput().getNames()[0], e.getOutput().getValues()[0])));
    return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(listPrediction)).build();
}
Also used : RegressionFactory(org.tribuo.regression.RegressionFactory) Prediction(org.tribuo.Prediction) ArrayList(java.util.ArrayList) Model(org.opensearch.ml.common.parameter.Model) Regressor(org.tribuo.regression.Regressor) Map(java.util.Map)
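
The predict method above keeps only the first name and value of each regression output (index 0), so any further output dimensions are dropped. The following is a minimal sketch of a row builder that keeps every dimension. It works on plain parallel arrays of the kind returned by getNames() and getValues(); the class name RegressionRows and the helper toRow are hypothetical and are not part of ml-commons or Tribuo.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class RegressionRows {

    // Hypothetical helper: builds one result row from parallel name/value arrays,
    // keeping every dimension instead of only index 0.
    static Map<String, Object> toRow(String[] names, double[] values) {
        if (names.length != values.length) {
            throw new IllegalArgumentException("names and values must have the same length");
        }
        // LinkedHashMap preserves the order in which the dimensions were supplied.
        Map<String, Object> row = new LinkedHashMap<>();
        for (int i = 0; i < names.length; i++) {
            row.put(names[i], values[i]);
        }
        return row;
    }

    public static void main(String[] args) {
        // Illustrative data only; the dimension names are made up.
        List<Map<String, Object>> rows = new ArrayList<>();
        rows.add(toRow(new String[] {"price", "volume"}, new double[] {1.5, 20.0}));
        System.out.println(rows);
    }
}
```

Unlike Collections.singletonMap, the map returned here is mutable and can hold more than one entry, which is the point of the sketch.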

Example 2 with Prediction

Use of org.tribuo.Prediction in the ml-commons project by opensearch-project.

Class KMeans, method trainAndPredict.

@Override
public MLOutput trainAndPredict(DataFrame dataFrame) {
    MutableDataset<ClusterID> trainDataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "KMeans training and predicting data from opensearch", TribuoOutputType.CLUSTERID);
    Integer centroids = Optional.ofNullable(parameters.getCentroids()).orElse(DEFAULT_CENTROIDS);
    Integer iterations = Optional.ofNullable(parameters.getIterations()).orElse(DEFAULT_ITERATIONS);
    KMeansTrainer trainer = new KMeansTrainer(centroids, iterations, distance, numThreads, seed);
    // won't store model in index
    KMeansModel kMeansModel = trainer.train(trainDataset);
    List<Prediction<ClusterID>> predictions = kMeansModel.predict(trainDataset);
    List<Map<String, Object>> listClusterID = new ArrayList<>();
    predictions.forEach(e -> listClusterID.add(Collections.singletonMap("ClusterID", e.getOutput().getID())));
    return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(listClusterID)).build();
}
Also used : KMeansModel(org.tribuo.clustering.kmeans.KMeansModel) ClusterID(org.tribuo.clustering.ClusterID) Prediction(org.tribuo.Prediction) ArrayList(java.util.ArrayList) KMeansTrainer(org.tribuo.clustering.kmeans.KMeansTrainer) ClusteringFactory(org.tribuo.clustering.ClusteringFactory) Map(java.util.Map)
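
trainAndPredict falls back to default values when the centroid or iteration count is null, using Optional.ofNullable(...).orElse(...). The sketch below isolates that pattern and adds a positivity check, which the excerpt does not show. The class name KMeansDefaults, the helper resolve, and the default values 2 and 10 are assumptions chosen for illustration; the real defaults are defined in the KMeans class and are not visible in the excerpt.

```java
import java.util.Optional;

public class KMeansDefaults {

    // Assumed defaults for illustration only; not taken from ml-commons.
    static final int DEFAULT_CENTROIDS = 2;
    static final int DEFAULT_ITERATIONS = 10;

    // Hypothetical helper: returns the requested value, or the fallback when it is null,
    // and rejects non-positive results.
    static int resolve(Integer requested, int fallback, String name) {
        int value = Optional.ofNullable(requested).orElse(fallback);
        if (value <= 0) {
            throw new IllegalArgumentException(name + " must be positive, got " + value);
        }
        return value;
    }

    public static void main(String[] args) {
        System.out.println(resolve(null, DEFAULT_CENTROIDS, "centroids"));
        System.out.println(resolve(5, DEFAULT_ITERATIONS, "iterations"));
    }
}
```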

Example 3 with Prediction

Use of org.tribuo.Prediction in the ml-commons project by opensearch-project.

Class AnomalyDetectionLibSVM, method predict.

@Override
public MLOutput predict(DataFrame dataFrame, Model model) {
    if (model == null) {
        throw new IllegalArgumentException("No model found for anomaly detection LibSVM prediction.");
    }
    List<Prediction<Event>> predictions;
    MutableDataset<Event> predictionDataset = TribuoUtil.generateDataset(dataFrame, new AnomalyFactory(), "Anomaly detection LibSVM prediction data from OpenSearch", TribuoOutputType.ANOMALY_DETECTION_LIBSVM);
    LibSVMModel libSVMAnomalyModel = (LibSVMModel) ModelSerDeSer.deserialize(model.getContent());
    predictions = libSVMAnomalyModel.predict(predictionDataset);
    List<Map<String, Object>> adResults = new ArrayList<>();
    predictions.forEach(e -> {
        Map<String, Object> result = new HashMap<>();
        result.put("score", e.getOutput().getScore());
        result.put("anomaly_type", e.getOutput().getType().name());
        adResults.add(result);
    });
    return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(adResults)).build();
}
Also used : LibSVMModel(org.tribuo.common.libsvm.LibSVMModel) HashMap(java.util.HashMap) Prediction(org.tribuo.Prediction) ArrayList(java.util.ArrayList) AnomalyFactory(org.tribuo.anomaly.AnomalyFactory) Event(org.tribuo.anomaly.Event) Map(java.util.Map)
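
The forEach loop above fills an external list with one two-key map per prediction. The same conversion can be sketched as a stream pipeline that returns the list directly. To stay self-contained, the sketch uses a hypothetical Scored class and Type enum as stand-ins for the anomaly output; they are not Tribuo types, and AnomalyRows and toRows are likewise invented names.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class AnomalyRows {

    // Hypothetical stand-in for the anomaly type of an output.
    enum Type { EXPECTED, ANOMALOUS }

    // Hypothetical stand-in for an anomaly output carrying a score and a type.
    static final class Scored {
        final double score;
        final Type type;

        Scored(double score, Type type) {
            this.score = score;
            this.type = type;
        }
    }

    // Maps each output to a row with the same keys as the excerpt: "score" and "anomaly_type".
    static List<Map<String, Object>> toRows(List<Scored> outputs) {
        return outputs.stream().map(o -> {
            Map<String, Object> row = new LinkedHashMap<>();
            row.put("score", o.score);
            row.put("anomaly_type", o.type.name());
            return row;
        }).collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // Illustrative scores only.
        List<Scored> outputs = Arrays.asList(
                new Scored(0.1, Type.EXPECTED),
                new Scored(2.7, Type.ANOMALOUS));
        System.out.println(toRows(outputs));
    }
}
```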

Example 4 with Prediction

Use of org.tribuo.Prediction in the ml-commons project by opensearch-project.

Class KMeans, method predict.

@Override
public MLOutput predict(DataFrame dataFrame, Model model) {
    if (model == null) {
        throw new IllegalArgumentException("No model found for KMeans prediction.");
    }
    List<Prediction<ClusterID>> predictions;
    MutableDataset<ClusterID> predictionDataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "KMeans prediction data from opensearch", TribuoOutputType.CLUSTERID);
    KMeansModel kMeansModel = (KMeansModel) ModelSerDeSer.deserialize(model.getContent());
    predictions = kMeansModel.predict(predictionDataset);
    List<Map<String, Object>> listClusterID = new ArrayList<>();
    predictions.forEach(e -> listClusterID.add(Collections.singletonMap("ClusterID", e.getOutput().getID())));
    return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(listClusterID)).build();
}
Also used : KMeansModel(org.tribuo.clustering.kmeans.KMeansModel) ClusterID(org.tribuo.clustering.ClusterID) Prediction(org.tribuo.Prediction) ArrayList(java.util.ArrayList) ClusteringFactory(org.tribuo.clustering.ClusteringFactory) Map(java.util.Map)

Aggregations

ArrayList (java.util.ArrayList): 4
Map (java.util.Map): 4
Prediction (org.tribuo.Prediction): 4
ClusterID (org.tribuo.clustering.ClusterID): 2
ClusteringFactory (org.tribuo.clustering.ClusteringFactory): 2
KMeansModel (org.tribuo.clustering.kmeans.KMeansModel): 2
HashMap (java.util.HashMap): 1
Model (org.opensearch.ml.common.parameter.Model): 1
AnomalyFactory (org.tribuo.anomaly.AnomalyFactory): 1
Event (org.tribuo.anomaly.Event): 1
KMeansTrainer (org.tribuo.clustering.kmeans.KMeansTrainer): 1
LibSVMModel (org.tribuo.common.libsvm.LibSVMModel): 1
RegressionFactory (org.tribuo.regression.RegressionFactory): 1
Regressor (org.tribuo.regression.Regressor): 1