Search in sources:

Example 1 with ClusterID

Use of org.tribuo.clustering.ClusterID in project ml-commons by opensearch-project.

Class TribuoUtil, method generateDataset.

/**
 * Builds a Tribuo {@link MutableDataset} from the rows of a data frame.
 * <p>
 * Each row becomes one {@link ArrayExample} that shares the feature names extracted from the data
 * frame. The examples carry no real label. Every one gets a placeholder output whose concrete type
 * is selected by {@code outputType}.
 *
 * @param dataFrame features data; one example is produced per row
 * @param outputFactory the tribuo output factory, also recorded in the data source provenance
 * @param desc description for tribuo provenance
 * @param outputType the tribuo output type that decides which placeholder output is attached
 * @param <T> the tribuo output type
 * @return tribuo dataset containing one example per data frame row
 * @throws IllegalArgumentException if {@code outputType} is not supported (raised from inside the
 *     row loop, so only when the data frame has at least one row)
 */
public static <T extends Output<T>> MutableDataset<T> generateDataset(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType) {
    Tuple<String[], double[][]> namesAndRows = transformDataFrame(dataFrame);
    String[] featureNames = namesAndRows.v1();
    double[][] featureRows = namesAndRows.v2();
    List<Example<T>> examples = new ArrayList<>();
    for (int row = 0; row < dataFrame.size(); row++) {
        T placeholder;
        switch (outputType) {
            case CLUSTERID:
                placeholder = (T) new ClusterID(ClusterID.UNASSIGNED);
                break;
            case REGRESSOR:
                // One-dimensional regressor named "DIM-0" whose value is not known (NaN).
                placeholder = (T) new Regressor("DIM-0", Double.NaN);
                break;
            case ANOMALY_DETECTION_LIBSVM:
                // EXPECTED (non-anomalous) is used as the default event type because:
                // 1. at training time, Tribuo's LibSVMAnomalyTrainer only supports EXPECTED events;
                // 2. at prediction time, the Tribuo library does not accept the UNKNOWN type.
                // TODO: support anomaly labels to evaluate prediction result
                placeholder = (T) new Event(Event.EventType.EXPECTED);
                break;
            default:
                throw new IllegalArgumentException("unknown type:" + outputType);
        }
        examples.add(new ArrayExample<>(placeholder, featureNames, featureRows[row]));
    }
    return new MutableDataset<>(new ListDataSource<>(examples, outputFactory, new SimpleDataSourceProvenance(desc, outputFactory)));
}
Also used: ClusterID(org.tribuo.clustering.ClusterID) SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) ArrayList(java.util.ArrayList) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) Event(org.tribuo.anomaly.Event) Regressor(org.tribuo.regression.Regressor) MutableDataset(org.tribuo.MutableDataset)

Example 2 with ClusterID

Use of org.tribuo.clustering.ClusterID in project ml-commons by opensearch-project.

Class KMeans, method trainAndPredict.

/**
 * Trains a KMeans model on the given data and immediately assigns every row to a cluster.
 * <p>
 * Hyper-parameters that the caller left unset fall back to the class defaults.
 *
 * @param dataFrame the rows to cluster; they serve as both training and prediction data
 * @return a prediction output whose data frame has one {@code "ClusterID"} entry per input row
 */
@Override
public MLOutput trainAndPredict(DataFrame dataFrame) {
    MutableDataset<ClusterID> dataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "KMeans training and predicting data from opensearch", TribuoOutputType.CLUSTERID);
    Integer centroidCount = Optional.ofNullable(parameters.getCentroids()).orElse(DEFAULT_CENTROIDS);
    Integer iterationCount = Optional.ofNullable(parameters.getIterations()).orElse(DEFAULT_ITERATIONS);
    KMeansTrainer kMeansTrainer = new KMeansTrainer(centroidCount, iterationCount, distance, numThreads, seed);
    // The trained model only lives for this call; it is never stored in the index.
    KMeansModel transientModel = kMeansTrainer.train(dataset);
    List<Map<String, Object>> rows = new ArrayList<>();
    for (Prediction<ClusterID> prediction : transientModel.predict(dataset)) {
        rows.add(Collections.singletonMap("ClusterID", prediction.getOutput().getID()));
    }
    return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(rows)).build();
}
Also used: KMeansModel(org.tribuo.clustering.kmeans.KMeansModel) ClusterID(org.tribuo.clustering.ClusterID) Prediction(org.tribuo.Prediction) ArrayList(java.util.ArrayList) KMeansTrainer(org.tribuo.clustering.kmeans.KMeansTrainer) ClusteringFactory(org.tribuo.clustering.ClusteringFactory) Map(java.util.Map)

Example 3 with ClusterID

Use of org.tribuo.clustering.ClusterID in project ml-commons by opensearch-project.

Class KMeans, method predict.

/**
 * Assigns every row of the data frame to a cluster using a previously trained KMeans model.
 *
 * @param dataFrame the rows to cluster
 * @param model the stored model; its content is deserialized and cast to a Tribuo KMeansModel
 * @return a prediction output whose data frame has one {@code "ClusterID"} entry per input row
 * @throws IllegalArgumentException if {@code model} is null
 */
@Override
public MLOutput predict(DataFrame dataFrame, Model model) {
    if (model == null) {
        throw new IllegalArgumentException("No model found for KMeans prediction.");
    }
    MutableDataset<ClusterID> dataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "KMeans prediction data from opensearch", TribuoOutputType.CLUSTERID);
    // Assumes the content was produced by serializing a KMeansModel (as train() does).
    KMeansModel restored = (KMeansModel) ModelSerDeSer.deserialize(model.getContent());
    List<Map<String, Object>> rows = new ArrayList<>();
    for (Prediction<ClusterID> prediction : restored.predict(dataset)) {
        rows.add(Collections.singletonMap("ClusterID", prediction.getOutput().getID()));
    }
    return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(rows)).build();
}
Also used: KMeansModel(org.tribuo.clustering.kmeans.KMeansModel) ClusterID(org.tribuo.clustering.ClusterID) Prediction(org.tribuo.Prediction) ArrayList(java.util.ArrayList) ClusteringFactory(org.tribuo.clustering.ClusteringFactory) Map(java.util.Map)

Example 4 with ClusterID

Use of org.tribuo.clustering.ClusterID in project ml-commons by opensearch-project.

Class KMeans, method train.

/**
 * Trains a KMeans model on the given data and wraps it in a serializable {@link Model}.
 * <p>
 * Hyper-parameters that the caller left unset fall back to the class defaults.
 *
 * @param dataFrame the training rows
 * @return a model named after {@code FunctionName.KMEANS}, at version 1, whose content is the
 *     serialized Tribuo KMeansModel
 */
@Override
public Model train(DataFrame dataFrame) {
    MutableDataset<ClusterID> dataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "KMeans training data from opensearch", TribuoOutputType.CLUSTERID);
    Integer centroidCount = Optional.ofNullable(parameters.getCentroids()).orElse(DEFAULT_CENTROIDS);
    Integer iterationCount = Optional.ofNullable(parameters.getIterations()).orElse(DEFAULT_ITERATIONS);
    KMeansModel trained = new KMeansTrainer(centroidCount, iterationCount, distance, numThreads, seed).train(dataset);
    Model result = new Model();
    result.setName(FunctionName.KMEANS.name());
    result.setVersion(1);
    result.setContent(ModelSerDeSer.serialize(trained));
    return result;
}
Also used: KMeansModel(org.tribuo.clustering.kmeans.KMeansModel) ClusterID(org.tribuo.clustering.ClusterID) Model(org.opensearch.ml.common.parameter.Model) KMeansTrainer(org.tribuo.clustering.kmeans.KMeansTrainer) ClusteringFactory(org.tribuo.clustering.ClusteringFactory)

Example 5 with ClusterID

Use of org.tribuo.clustering.ClusterID in project ml-commons by opensearch-project.

Class TribuoUtilTest, method generateDataset.

/**
 * Verifies that {@code TribuoUtil.generateDataset} produces one example per raw data row.
 * <p>
 * It also checks that every example carries the unassigned cluster ID placeholder. Finally, it
 * checks that each example's features are named {@code f1, f2, ...} and hold the values the
 * fixture encodes.
 */
@Test
public void generateDataset() {
    MutableDataset<ClusterID> dataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "test", TribuoOutputType.CLUSTERID);
    List<Example<ClusterID>> examples = dataset.getData();
    Assert.assertEquals(rawData.length, examples.size());
    for (int i = 0; i < rawData.length; ++i) {
        // Parameterized cast instead of a raw type, so no unchecked warning needs suppressing.
        ArrayExample<ClusterID> arrayExample = (ArrayExample<ClusterID>) examples.get(i);
        // CLUSTERID output type attaches an UNASSIGNED placeholder to every example.
        Assert.assertEquals(ClusterID.UNASSIGNED, arrayExample.getOutput().getID());
        int idx = 1;
        for (Feature feature : arrayExample) {
            Assert.assertEquals("f" + idx, feature.getName());
            Assert.assertEquals(i + idx / 10.0, feature.getValue(), 0.01);
            ++idx;
        }
        // Without this, an example with no features would pass the loop above vacuously.
        Assert.assertTrue("example " + i + " has no features", idx > 1);
    }
}
Also used: ArrayExample(org.tribuo.impl.ArrayExample) ClusterID(org.tribuo.clustering.ClusterID) Example(org.tribuo.Example) ClusteringFactory(org.tribuo.clustering.ClusteringFactory) Feature(org.tribuo.Feature) Test(org.junit.Test)

Aggregations

ClusterID (org.tribuo.clustering.ClusterID): 5; ClusteringFactory (org.tribuo.clustering.ClusteringFactory): 4; ArrayList (java.util.ArrayList): 3; KMeansModel (org.tribuo.clustering.kmeans.KMeansModel): 3; Map (java.util.Map): 2; Example (org.tribuo.Example): 2; Prediction (org.tribuo.Prediction): 2; KMeansTrainer (org.tribuo.clustering.kmeans.KMeansTrainer): 2; ArrayExample (org.tribuo.impl.ArrayExample): 2; Test (org.junit.Test): 1; Model (org.opensearch.ml.common.parameter.Model): 1; Feature (org.tribuo.Feature): 1; MutableDataset (org.tribuo.MutableDataset): 1; Event (org.tribuo.anomaly.Event): 1; SimpleDataSourceProvenance (org.tribuo.provenance.SimpleDataSourceProvenance): 1; Regressor (org.tribuo.regression.Regressor): 1