Usage of `org.tribuo.clustering.ClusteringFactory` in the ml-commons project (opensearch-project).
Class `KMeans`, method `trainAndPredict`:
/**
 * Trains a KMeans model on the supplied data, then assigns every row of that same data to a
 * cluster. The trained model is used only within this call and is not stored in an index.
 *
 * @param dataFrame the data to cluster; it is both the training set and the prediction input
 * @return an {@link MLPredictionOutput} whose result frame holds one {@code "ClusterID"} entry
 *     per prediction
 */
@Override
public MLOutput trainAndPredict(DataFrame dataFrame) {
    MutableDataset<ClusterID> dataset = TribuoUtil.generateDataset(
            dataFrame,
            new ClusteringFactory(),
            "KMeans training and predicting data from opensearch",
            TribuoOutputType.CLUSTERID);

    // Parameters the caller left unset fall back to the class-level defaults.
    Integer requestedCentroids = parameters.getCentroids();
    Integer centroidCount = requestedCentroids == null ? DEFAULT_CENTROIDS : requestedCentroids;
    Integer requestedIterations = parameters.getIterations();
    Integer iterationCount = requestedIterations == null ? DEFAULT_ITERATIONS : requestedIterations;

    KMeansTrainer trainer = new KMeansTrainer(centroidCount, iterationCount, distance, numThreads, seed);
    // The model lives only for the duration of this call; it is never persisted.
    KMeansModel trainedModel = trainer.train(dataset);
    List<Prediction<ClusterID>> clusterAssignments = trainedModel.predict(dataset);

    List<Map<String, Object>> clusterRows = new ArrayList<>();
    for (Prediction<ClusterID> assignment : clusterAssignments) {
        clusterRows.add(Collections.singletonMap("ClusterID", assignment.getOutput().getID()));
    }
    return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(clusterRows)).build();
}
Usage of `org.tribuo.clustering.ClusteringFactory` in the ml-commons project (opensearch-project).
Class `KMeans`, method `predict`:
/**
 * Assigns each row of the given data to a cluster using a previously trained KMeans model.
 *
 * @param dataFrame the rows to assign to clusters
 * @param model the stored model; its content must deserialize to a {@link KMeansModel}
 * @return an {@link MLPredictionOutput} whose result frame holds one {@code "ClusterID"} entry
 *     per prediction
 * @throws IllegalArgumentException if {@code model} is null, or if its content does not
 *     deserialize to a {@link KMeansModel}
 */
@Override
public MLOutput predict(DataFrame dataFrame, Model model) {
    if (model == null) {
        throw new IllegalArgumentException("No model found for KMeans prediction.");
    }
    MutableDataset<ClusterID> predictionDataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "KMeans prediction data from opensearch", TribuoOutputType.CLUSTERID);

    // NOTE(review): if ModelSerDeSer uses Java native deserialization, model content must come
    // from a trusted source — confirm.
    Object deserialized = ModelSerDeSer.deserialize(model.getContent());
    // Without this check, content from a model of another algorithm fails with an opaque
    // ClassCastException, and null content fails later with a NullPointerException.
    if (!(deserialized instanceof KMeansModel)) {
        String actualType = deserialized == null ? "null" : deserialized.getClass().getName();
        throw new IllegalArgumentException("Model content is not a KMeans model: " + actualType);
    }
    KMeansModel kMeansModel = (KMeansModel) deserialized;

    List<Prediction<ClusterID>> predictions = kMeansModel.predict(predictionDataset);
    List<Map<String, Object>> listClusterID = new ArrayList<>(predictions.size());
    for (Prediction<ClusterID> prediction : predictions) {
        listClusterID.add(Collections.singletonMap("ClusterID", prediction.getOutput().getID()));
    }
    return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(listClusterID)).build();
}
Usage of `org.tribuo.clustering.ClusteringFactory` in the ml-commons project (opensearch-project).
Class `KMeans`, method `train`:
/**
 * Trains a KMeans model on the given data and wraps it in a {@link Model} for storage.
 *
 * @param dataFrame the training data
 * @return a Model named {@code FunctionName.KMEANS.name()}, at version 1, whose content is the
 *     serialized {@link KMeansModel}
 */
@Override
public Model train(DataFrame dataFrame) {
    MutableDataset<ClusterID> dataset = TribuoUtil.generateDataset(
            dataFrame,
            new ClusteringFactory(),
            "KMeans training data from opensearch",
            TribuoOutputType.CLUSTERID);

    // Parameters the caller left unset fall back to the class-level defaults.
    Integer requestedCentroids = parameters.getCentroids();
    Integer centroidCount = requestedCentroids == null ? DEFAULT_CENTROIDS : requestedCentroids;
    Integer requestedIterations = parameters.getIterations();
    Integer iterationCount = requestedIterations == null ? DEFAULT_ITERATIONS : requestedIterations;

    KMeansModel trainedModel =
            new KMeansTrainer(centroidCount, iterationCount, distance, numThreads, seed).train(dataset);
    byte[] serializedModel = ModelSerDeSer.serialize(trainedModel);

    Model result = new Model();
    result.setName(FunctionName.KMEANS.name());
    result.setVersion(1);
    result.setContent(serializedModel);
    return result;
}
Usage of `org.tribuo.clustering.ClusteringFactory` in the ml-commons project (opensearch-project).
Class `TribuoUtilTest`, method `generateDataset`:
/**
 * Checks that {@code TribuoUtil.generateDataset} produces one {@link ArrayExample} per row of
 * {@code rawData}. Features must be named {@code "f1"}, {@code "f2"}, ... in order. For row
 * {@code i} and 1-based feature index {@code idx}, the value must be {@code i + idx / 10.0}.
 */
@Test
public void generateDataset() {
    MutableDataset<ClusterID> dataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "test", TribuoOutputType.CLUSTERID);
    List<Example<ClusterID>> examples = dataset.getData();
    Assert.assertEquals(rawData.length, examples.size());
    for (int i = 0; i < rawData.length; ++i) {
        // A parameterized (checked) cast replaces the raw one, so iterating the features needs no
        // unchecked conversion and no @SuppressWarnings.
        ArrayExample<ClusterID> arrayExample = (ArrayExample<ClusterID>) examples.get(i);
        int idx = 1;
        for (Feature feature : arrayExample) {
            Assert.assertEquals("f" + idx, feature.getName());
            Assert.assertEquals(i + idx / 10.0, feature.getValue(), 0.01);
            ++idx;
        }
        // Without this, an example with no features would pass the loop above vacuously.
        Assert.assertTrue("Example " + i + " has no features", idx > 1);
    }
}
Aggregations