use of org.tribuo.Prediction in project ml-commons by opensearch-project.
the class LinearRegression method predict.
/**
 * Runs linear regression prediction over the supplied data frame.
 *
 * <p>The model content is deserialized into a Tribuo regression model, the data frame is
 * converted into a Tribuo dataset, and each prediction becomes a single-entry row that maps
 * the first output name to the first output value.
 *
 * @param dataFrame input rows to predict on
 * @param model     stored model whose content is deserialized; must not be {@code null}
 * @return prediction output wrapping a data frame with one row per prediction
 * @throws IllegalArgumentException if {@code model} is {@code null}
 */
@Override
public MLOutput predict(DataFrame dataFrame, Model model) {
    if (model == null) {
        throw new IllegalArgumentException("No model found for linear regression prediction.");
    }

    // Unchecked cast: the deserialized object is taken to be a Tribuo regression model.
    org.tribuo.Model<Regressor> tribuoModel =
        (org.tribuo.Model<Regressor>) ModelSerDeSer.deserialize(model.getContent());
    MutableDataset<Regressor> inputData = TribuoUtil.generateDataset(
        dataFrame,
        new RegressionFactory(),
        "Linear regression prediction data from opensearch",
        TribuoOutputType.REGRESSOR
    );
    List<Prediction<Regressor>> predicted = tribuoModel.predict(inputData);

    List<Map<String, Object>> rows = new ArrayList<>();
    for (Prediction<Regressor> prediction : predicted) {
        Regressor output = prediction.getOutput();
        // Only the first regression dimension is reported.
        rows.add(Collections.singletonMap(output.getNames()[0], output.getValues()[0]));
    }

    return MLPredictionOutput.builder()
        .predictionResult(DataFrameBuilder.load(rows))
        .build();
}
use of org.tribuo.Prediction in project ml-commons by opensearch-project.
the class KMeans method trainAndPredict.
/**
 * Trains a KMeans model on the supplied data frame and immediately predicts on the same data.
 *
 * <p>The centroid and iteration counts come from {@code parameters}, falling back to
 * {@code DEFAULT_CENTROIDS} and {@code DEFAULT_ITERATIONS} when unset. The trained model is
 * used only within this call and is not returned.
 *
 * @param dataFrame rows used both for training and for prediction
 * @return prediction output wrapping a data frame with one {@code "ClusterID"} row per input
 */
@Override
public MLOutput trainAndPredict(DataFrame dataFrame) {
    MutableDataset<ClusterID> dataset = TribuoUtil.generateDataset(
        dataFrame,
        new ClusteringFactory(),
        "KMeans training and predicting data from opensearch",
        TribuoOutputType.CLUSTERID
    );

    // Fall back to the defaults when a parameter was not supplied.
    Integer centroids = parameters.getCentroids();
    if (centroids == null) {
        centroids = DEFAULT_CENTROIDS;
    }
    Integer iterations = parameters.getIterations();
    if (iterations == null) {
        iterations = DEFAULT_ITERATIONS;
    }

    KMeansTrainer trainer = new KMeansTrainer(centroids, iterations, distance, numThreads, seed);
    // The model is transient: it won't be stored in an index.
    KMeansModel trainedModel = trainer.train(dataset);
    List<Prediction<ClusterID>> predicted = trainedModel.predict(dataset);

    List<Map<String, Object>> rows = new ArrayList<>();
    for (Prediction<ClusterID> prediction : predicted) {
        rows.add(Collections.singletonMap("ClusterID", prediction.getOutput().getID()));
    }

    return MLPredictionOutput.builder()
        .predictionResult(DataFrameBuilder.load(rows))
        .build();
}
use of org.tribuo.Prediction in project ml-commons by opensearch-project.
the class AnomalyDetectionLibSVM method predict.
/**
 * Runs LibSVM-based anomaly detection over the supplied data frame.
 *
 * <p>The data frame is converted into a Tribuo dataset, the model content is deserialized
 * into a {@code LibSVMModel}, and each prediction becomes a row with two entries:
 * {@code "score"} (the event score) and {@code "anomaly_type"} (the event type name).
 *
 * @param dataFrame input rows to evaluate
 * @param model     stored model whose content is deserialized; must not be {@code null}
 * @return prediction output wrapping a data frame with one row per prediction
 * @throws IllegalArgumentException if {@code model} is {@code null}
 */
@Override
public MLOutput predict(DataFrame dataFrame, Model model) {
    if (model == null) {
        // The message names this algorithm; it previously said "KMeans" (copy-paste slip).
        throw new IllegalArgumentException("No model found for anomaly detection LibSVM prediction.");
    }

    MutableDataset<Event> predictionDataset = TribuoUtil.generateDataset(
        dataFrame,
        new AnomalyFactory(),
        "Anomaly detection LibSVM prediction data from OpenSearch",
        TribuoOutputType.ANOMALY_DETECTION_LIBSVM
    );
    LibSVMModel libSVMAnomalyModel = (LibSVMModel) ModelSerDeSer.deserialize(model.getContent());
    List<Prediction<Event>> predictions = libSVMAnomalyModel.predict(predictionDataset);

    List<Map<String, Object>> adResults = new ArrayList<>();
    for (Prediction<Event> prediction : predictions) {
        Event event = prediction.getOutput();
        Map<String, Object> result = new HashMap<>();
        result.put("score", event.getScore());
        result.put("anomaly_type", event.getType().name());
        adResults.add(result);
    }

    return MLPredictionOutput.builder()
        .predictionResult(DataFrameBuilder.load(adResults))
        .build();
}
use of org.tribuo.Prediction in project ml-commons by opensearch-project.
the class KMeans method predict.
/**
 * Assigns each row of the supplied data frame to a cluster using a stored KMeans model.
 *
 * <p>The data frame is converted into a Tribuo dataset, the model content is deserialized
 * into a {@code KMeansModel}, and each prediction becomes a single-entry row keyed by
 * {@code "ClusterID"}.
 *
 * @param dataFrame input rows to cluster
 * @param model     stored model whose content is deserialized; must not be {@code null}
 * @return prediction output wrapping a data frame with one {@code "ClusterID"} row per input
 * @throws IllegalArgumentException if {@code model} is {@code null}
 */
@Override
public MLOutput predict(DataFrame dataFrame, Model model) {
    if (model == null) {
        throw new IllegalArgumentException("No model found for KMeans prediction.");
    }

    MutableDataset<ClusterID> inputData = TribuoUtil.generateDataset(
        dataFrame,
        new ClusteringFactory(),
        "KMeans prediction data from opensearch",
        TribuoOutputType.CLUSTERID
    );
    KMeansModel storedModel = (KMeansModel) ModelSerDeSer.deserialize(model.getContent());
    List<Prediction<ClusterID>> predicted = storedModel.predict(inputData);

    List<Map<String, Object>> rows = new ArrayList<>();
    for (Prediction<ClusterID> prediction : predicted) {
        rows.add(Collections.singletonMap("ClusterID", prediction.getOutput().getID()));
    }

    return MLPredictionOutput.builder()
        .predictionResult(DataFrameBuilder.load(rows))
        .build();
}
Aggregations