
Example 1 with ClusteringFactory

Use of org.tribuo.clustering.ClusteringFactory in project ml-commons by opensearch-project.

Class KMeans, method trainAndPredict.

@Override
public MLOutput trainAndPredict(DataFrame dataFrame) {
    MutableDataset<ClusterID> trainDataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "KMeans training and predicting data from opensearch", TribuoOutputType.CLUSTERID);
    Integer centroids = Optional.ofNullable(parameters.getCentroids()).orElse(DEFAULT_CENTROIDS);
    Integer iterations = Optional.ofNullable(parameters.getIterations()).orElse(DEFAULT_ITERATIONS);
    KMeansTrainer trainer = new KMeansTrainer(centroids, iterations, distance, numThreads, seed);
    // The trained model is used only for this call; it is not stored in the index.
    KMeansModel kMeansModel = trainer.train(trainDataset);
    List<Prediction<ClusterID>> predictions = kMeansModel.predict(trainDataset);
    List<Map<String, Object>> listClusterID = new ArrayList<>();
    predictions.forEach(e -> listClusterID.add(Collections.singletonMap("ClusterID", e.getOutput().getID())));
    return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(listClusterID)).build();
}
Also used : KMeansModel(org.tribuo.clustering.kmeans.KMeansModel) ClusterID(org.tribuo.clustering.ClusterID) Prediction(org.tribuo.Prediction) ArrayList(java.util.ArrayList) KMeansTrainer(org.tribuo.clustering.kmeans.KMeansTrainer) ClusteringFactory(org.tribuo.clustering.ClusteringFactory) Map(java.util.Map)
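
Two idioms frame the Tribuo calls in trainAndPredict: falling back to a default when a parameter is null, and wrapping each cluster ID in a single-entry row. Both can be tried without Tribuo or OpenSearch on the classpath. The following is a minimal sketch using only the JDK. The class name ClusterRowsSketch, the default value 2, and the sample IDs are placeholders for illustration, not values taken from ml-commons.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;

// Hypothetical stand-in for the non-Tribuo parts of KMeans.trainAndPredict.
class ClusterRowsSketch {
    // Placeholder default; the real DEFAULT_CENTROIDS value is not shown in the source.
    static final int DEFAULT_CENTROIDS = 2;

    // Returns the configured value, or the default when the parameter is null.
    static int resolveCentroids(Integer configured) {
        return Optional.ofNullable(configured).orElse(DEFAULT_CENTROIDS);
    }

    // Wraps each cluster ID in a one-entry map keyed by "ClusterID", one map per input row.
    static List<Map<String, Object>> toRows(List<Integer> clusterIds) {
        List<Map<String, Object>> rows = new ArrayList<>();
        clusterIds.forEach(id -> rows.add(Collections.singletonMap("ClusterID", id)));
        return rows;
    }

    public static void main(String[] args) {
        System.out.println(resolveCentroids(null));
        System.out.println(toRows(Arrays.asList(0, 1, 0)));
    }
}
```

Because each row is an immutable single-entry map, the resulting list has exactly one "ClusterID" column, which is the shape the snippet above passes to DataFrameBuilder.load.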

Example 2 with ClusteringFactory

Use of org.tribuo.clustering.ClusteringFactory in project ml-commons by opensearch-project.

Class KMeans, method predict.

@Override
public MLOutput predict(DataFrame dataFrame, Model model) {
    if (model == null) {
        throw new IllegalArgumentException("No model found for KMeans prediction.");
    }
    MutableDataset<ClusterID> predictionDataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "KMeans prediction data from opensearch", TribuoOutputType.CLUSTERID);
    // Restore the trained model from the serialized bytes stored in the Model wrapper.
    KMeansModel kMeansModel = (KMeansModel) ModelSerDeSer.deserialize(model.getContent());
    List<Prediction<ClusterID>> predictions = kMeansModel.predict(predictionDataset);
    List<Map<String, Object>> listClusterID = new ArrayList<>();
    predictions.forEach(e -> listClusterID.add(Collections.singletonMap("ClusterID", e.getOutput().getID())));
    return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(listClusterID)).build();
}
Also used : KMeansModel(org.tribuo.clustering.kmeans.KMeansModel) ClusterID(org.tribuo.clustering.ClusterID) Prediction(org.tribuo.Prediction) ArrayList(java.util.ArrayList) ClusteringFactory(org.tribuo.clustering.ClusteringFactory) Map(java.util.Map)

Example 3 with ClusteringFactory

Use of org.tribuo.clustering.ClusteringFactory in project ml-commons by opensearch-project.

Class KMeans, method train.

@Override
public Model train(DataFrame dataFrame) {
    MutableDataset<ClusterID> trainDataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "KMeans training data from opensearch", TribuoOutputType.CLUSTERID);
    Integer centroids = Optional.ofNullable(parameters.getCentroids()).orElse(DEFAULT_CENTROIDS);
    Integer iterations = Optional.ofNullable(parameters.getIterations()).orElse(DEFAULT_ITERATIONS);
    KMeansTrainer trainer = new KMeansTrainer(centroids, iterations, distance, numThreads, seed);
    KMeansModel kMeansModel = trainer.train(trainDataset);
    Model model = new Model();
    model.setName(FunctionName.KMEANS.name());
    model.setVersion(1);
    model.setContent(ModelSerDeSer.serialize(kMeansModel));
    return model;
}
Also used : KMeansModel(org.tribuo.clustering.kmeans.KMeansModel) ClusterID(org.tribuo.clustering.ClusterID) Model(org.opensearch.ml.common.parameter.Model) KMeansTrainer(org.tribuo.clustering.kmeans.KMeansTrainer) ClusteringFactory(org.tribuo.clustering.ClusteringFactory)
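
The train and predict methods hand the model across as bytes: train calls ModelSerDeSer.serialize and predict calls ModelSerDeSer.deserialize. The internals of ModelSerDeSer are not shown here, so the sketch below only illustrates the general round trip under the assumption of plain Java serialization. SerDeSketch and ToyModel are hypothetical names, not ml-commons or Tribuo classes.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;

// Hypothetical round trip; assumes standard Java serialization, which may differ from ModelSerDeSer.
class SerDeSketch {
    // Toy stand-in for a trained model.
    static class ToyModel implements Serializable {
        private static final long serialVersionUID = 1L;
        final double[] centroid;

        ToyModel(double[] centroid) {
            this.centroid = centroid;
        }
    }

    // Writes the object graph into an in-memory byte array.
    static byte[] serialize(Serializable model) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(model);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    // Reads the object back; the caller casts to the expected type.
    static Object deserialize(byte[] content) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(content))) {
            return in.readObject();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        byte[] stored = serialize(new ToyModel(new double[] {1.5, 2.5}));
        ToyModel restored = (ToyModel) deserialize(stored);
        System.out.println(restored.centroid.length);
    }
}
```

A plain ObjectInputStream will instantiate any serializable class found in the byte stream, so production code that reads stored models usually restricts which classes may be deserialized.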

Example 4 with ClusteringFactory

Use of org.tribuo.clustering.ClusteringFactory in project ml-commons by opensearch-project.

Class TribuoUtilTest, method generateDataset.

@SuppressWarnings("unchecked")
@Test
public void generateDataset() {
    MutableDataset<ClusterID> dataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "test", TribuoOutputType.CLUSTERID);
    List<Example<ClusterID>> examples = dataset.getData();
    Assert.assertEquals(rawData.length, examples.size());
    for (int i = 0; i < rawData.length; ++i) {
        ArrayExample arrayExample = (ArrayExample) examples.get(i);
        Iterator<Feature> iterator = arrayExample.iterator();
        int idx = 1;
        while (iterator.hasNext()) {
            Feature feature = iterator.next();
            Assert.assertEquals("f" + idx, feature.getName());
            Assert.assertEquals(i + idx / 10.0, feature.getValue(), 0.01);
            ++idx;
        }
    }
}
Also used : ArrayExample(org.tribuo.impl.ArrayExample) ClusterID(org.tribuo.clustering.ClusterID) Example(org.tribuo.Example) ClusteringFactory(org.tribuo.clustering.ClusteringFactory) Feature(org.tribuo.Feature) Test(org.junit.Test)
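
The assertions in generateDataset imply a particular fixture: feature idx (counted from 1) of row i is named "f" + idx and holds roughly i + idx / 10.0, where the division binds before the addition. The test's actual rawData setup is not shown, so the sketch below is only an assumed fixture that would satisfy those assertions. FixtureSketch and its methods are hypothetical names.

```java
// Hypothetical fixture builder consistent with the assertions in generateDataset.
class FixtureSketch {
    // Builds rows where column j (0-based) of row i holds i + (j + 1) / 10.0.
    static double[][] buildRawData(int rows, int cols) {
        double[][] data = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                data[i][j] = i + (j + 1) / 10.0;
            }
        }
        return data;
    }

    // Feature name the test expects for a 1-based feature index.
    static String featureName(int idx) {
        return "f" + idx;
    }

    public static void main(String[] args) {
        double[][] data = buildRawData(3, 2);
        for (int i = 0; i < data.length; i++) {
            for (int j = 0; j < data[i].length; j++) {
                System.out.println(featureName(j + 1) + " = " + data[i][j]);
            }
        }
    }
}
```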

Aggregations

ClusterID (org.tribuo.clustering.ClusterID): 4
ClusteringFactory (org.tribuo.clustering.ClusteringFactory): 4
KMeansModel (org.tribuo.clustering.kmeans.KMeansModel): 3
ArrayList (java.util.ArrayList): 2
Map (java.util.Map): 2
Prediction (org.tribuo.Prediction): 2
KMeansTrainer (org.tribuo.clustering.kmeans.KMeansTrainer): 2
Test (org.junit.Test): 1
Model (org.opensearch.ml.common.parameter.Model): 1
Example (org.tribuo.Example): 1
Feature (org.tribuo.Feature): 1
ArrayExample (org.tribuo.impl.ArrayExample): 1