Search in sources:

Example 1 with ClusteringEvaluation

Use of org.tribuo.clustering.evaluation.ClusteringEvaluation in project Tribuo by Oracle.

The class TestHdbscan, method runEvaluation.

/**
 * Trains an HDBSCAN model on part of a synthetic Gaussian cluster dataset and sanity-checks it.
 *
 * <p>The data comes from {@code GaussianClusterDataSource(1000, 1L)} and is split with
 * {@code TrainTestSplitter(source, 0.7f, 2L)} (presumably a 70% training share). The trained model
 * is checked for a serialization round trip. Then, on both the training and the test portion, the
 * adjusted and normalized mutual information scores must not be NaN.
 *
 * @param trainer the HDBSCAN trainer under test
 */
public static void runEvaluation(HdbscanTrainer trainer) {
    TrainTestSplitter<ClusterID> split =
            new TrainTestSplitter<>(new GaussianClusterDataSource(1000, 1L), 0.7f, 2L);
    Dataset<ClusterID> trainSet = new MutableDataset<>(split.getTrain());
    Dataset<ClusterID> testSet = new MutableDataset<>(split.getTest());

    ClusteringEvaluator evaluator = new ClusteringEvaluator();
    HdbscanModel hdbscanModel = trainer.train(trainSet);

    // The freshly trained model must survive a serialization round trip.
    Helpers.testModelSerialization(hdbscanModel, ClusterID.class);

    // Scores on the data the model was fitted to.
    ClusteringEvaluation onTrain = evaluator.evaluate(hdbscanModel, trainSet);
    assertFalse(Double.isNaN(onTrain.adjustedMI()));
    assertFalse(Double.isNaN(onTrain.normalizedMI()));

    // Scores on the held-out portion of the split.
    ClusteringEvaluation onTest = evaluator.evaluate(hdbscanModel, testSet);
    assertFalse(Double.isNaN(onTest.adjustedMI()));
    assertFalse(Double.isNaN(onTest.normalizedMI()));
}
Also used : TrainTestSplitter(org.tribuo.evaluation.TrainTestSplitter) ClusterID(org.tribuo.clustering.ClusterID) ClusteringEvaluator(org.tribuo.clustering.evaluation.ClusteringEvaluator) GaussianClusterDataSource(org.tribuo.clustering.example.GaussianClusterDataSource) MutableDataset(org.tribuo.MutableDataset) ClusteringEvaluation(org.tribuo.clustering.evaluation.ClusteringEvaluation)

Example 2 with ClusteringEvaluation

Use of org.tribuo.clustering.evaluation.ClusteringEvaluation in project Tribuo by Oracle.

The class TrainTest, method main.

/**
 * Runs a TrainTest CLI for K-Means clustering.
 *
 * <p>The method proceeds as follows:
 * <ol>
 *   <li>Parses {@code KMeansOptions} from the command line.</li>
 *   <li>Loads the datasets described by the general options.</li>
 *   <li>Trains a {@code KMeansTrainer} on the first (training) dataset.</li>
 *   <li>Prints the normalized and adjusted mutual information of the model, measured on that same
 *       training data.</li>
 *   <li>Saves the model when an output path was supplied.</li>
 * </ol>
 *
 * <p>If argument parsing throws a {@code UsageException}, its message is logged and the method
 * returns. If no training path was given, the usage text is logged and the method returns without
 * training.
 *
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    KMeansOptions o = new KMeansOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }
    if (o.general.trainingPath == null) {
        // Nothing to train on: show how the tool should be invoked instead.
        logger.info(cm.usage());
        return;
    }
    ClusteringFactory factory = new ClusteringFactory();
    Pair<Dataset<ClusterID>, Dataset<ClusterID>> data = o.general.load(factory);
    Dataset<ClusterID> train = data.getA();
    // Trainer arguments, in order: number of centroids, iteration count, distance type,
    // initialisation scheme, thread count and seed, all taken from the parsed CLI options.
    KMeansTrainer trainer = new KMeansTrainer(o.centroids, o.iterations, o.distType, o.initialisation, o.numThreads, o.general.seed);
    Model<ClusterID> model = trainer.train(train);
    logger.info("Finished training model");
    // NOTE(review): the model is evaluated on the training data; the second dataset returned by
    // load (data.getB()) is not used by this method.
    ClusteringEvaluation evaluation = factory.getEvaluator().evaluate(model, train);
    logger.info("Finished evaluating model");
    System.out.println("Normalized MI = " + evaluation.normalizedMI());
    System.out.println("Adjusted MI = " + evaluation.adjustedMI());
    if (o.general.outputPath != null) {
        o.general.saveModel(model);
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) ClusterID(org.tribuo.clustering.ClusterID) Dataset(org.tribuo.Dataset) ClusteringFactory(org.tribuo.clustering.ClusteringFactory) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) ClusteringEvaluation(org.tribuo.clustering.evaluation.ClusteringEvaluation)

Example 3 with ClusteringEvaluation

Use of org.tribuo.clustering.evaluation.ClusteringEvaluation in project Tribuo by Oracle.

The class TestKMeans, method runEvaluation.

/**
 * Trains a K-Means model on synthetic Gaussian cluster data and sanity-checks its scores.
 *
 * <p>Training data comes from {@code GaussianClusterDataSource(500, 1L)}. A separate evaluation
 * set comes from {@code ClusteringDataGenerator.gaussianClusters(500, 2L)}. On both, the adjusted
 * and normalized mutual information scores must not be NaN.
 *
 * @param trainer the K-Means trainer under test
 */
public static void runEvaluation(KMeansTrainer trainer) {
    Dataset<ClusterID> trainSet = new MutableDataset<>(new GaussianClusterDataSource(500, 1L));
    Dataset<ClusterID> testSet = ClusteringDataGenerator.gaussianClusters(500, 2L);

    ClusteringEvaluator evaluator = new ClusteringEvaluator();
    KMeansModel kMeansModel = trainer.train(trainSet);

    // Scores on the data the model was fitted to.
    ClusteringEvaluation onTrain = evaluator.evaluate(kMeansModel, trainSet);
    assertFalse(Double.isNaN(onTrain.adjustedMI()));
    assertFalse(Double.isNaN(onTrain.normalizedMI()));

    // Scores on the independently generated evaluation set.
    ClusteringEvaluation onTest = evaluator.evaluate(kMeansModel, testSet);
    assertFalse(Double.isNaN(onTest.adjustedMI()));
    assertFalse(Double.isNaN(onTest.normalizedMI()));
}
Also used : ClusterID(org.tribuo.clustering.ClusterID) ClusteringEvaluator(org.tribuo.clustering.evaluation.ClusteringEvaluator) GaussianClusterDataSource(org.tribuo.clustering.example.GaussianClusterDataSource) MutableDataset(org.tribuo.MutableDataset) ClusteringEvaluation(org.tribuo.clustering.evaluation.ClusteringEvaluation)

Example 4 with ClusteringEvaluation

Use of org.tribuo.clustering.evaluation.ClusteringEvaluation in project Tribuo by Oracle.

The class TrainTest, method main.

/**
 * Command-line entry point that trains an HDBSCAN clustering model and reports its scores.
 *
 * <p>The method proceeds as follows:
 * <ol>
 *   <li>Reads {@code HdbscanCLIOptions} from the arguments.</li>
 *   <li>Loads the datasets named by the general options.</li>
 *   <li>Builds the trainer from the HDBSCAN options and fits it on the first (training)
 *       dataset.</li>
 *   <li>Prints the normalized and adjusted mutual information measured on that training
 *       data.</li>
 *   <li>Writes the model out when an output path is configured.</li>
 * </ol>
 *
 * <p>When parsing fails with a {@code UsageException}, its message is logged. When no training
 * path is configured, the usage text is logged. In both cases the method returns early.
 *
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();

    HdbscanCLIOptions options = new HdbscanCLIOptions();
    ConfigurationManager configManager;
    try {
        configManager = new ConfigurationManager(args, options);
    } catch (UsageException usage) {
        logger.info(usage.getMessage());
        return;
    }
    if (options.general.trainingPath == null) {
        logger.info(configManager.usage());
        return;
    }

    ClusteringFactory clusteringFactory = new ClusteringFactory();
    Pair<Dataset<ClusterID>, Dataset<ClusterID>> loaded = options.general.load(clusteringFactory);
    Dataset<ClusterID> trainSet = loaded.getA();

    HdbscanTrainer hdbscanTrainer = options.hdbscanOptions.getTrainer();
    Model<ClusterID> trained = hdbscanTrainer.train(trainSet);
    logger.info("Finished training model");

    // Scores are computed against the same data the model was trained on.
    ClusteringEvaluation scores = clusteringFactory.getEvaluator().evaluate(trained, trainSet);
    logger.info("Finished evaluating model");
    System.out.println("Normalized MI = " + scores.normalizedMI());
    System.out.println("Adjusted MI = " + scores.adjustedMI());

    if (options.general.outputPath != null) {
        options.general.saveModel(trained);
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) ClusterID(org.tribuo.clustering.ClusterID) Dataset(org.tribuo.Dataset) ClusteringFactory(org.tribuo.clustering.ClusteringFactory) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) ClusteringEvaluation(org.tribuo.clustering.evaluation.ClusteringEvaluation)

Aggregations

ClusterID (org.tribuo.clustering.ClusterID): 4; ClusteringEvaluation (org.tribuo.clustering.evaluation.ClusteringEvaluation): 4; ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager): 2; UsageException (com.oracle.labs.mlrg.olcut.config.UsageException): 2; Dataset (org.tribuo.Dataset): 2; MutableDataset (org.tribuo.MutableDataset): 2; ClusteringFactory (org.tribuo.clustering.ClusteringFactory): 2; ClusteringEvaluator (org.tribuo.clustering.evaluation.ClusteringEvaluator): 2; GaussianClusterDataSource (org.tribuo.clustering.example.GaussianClusterDataSource): 2; TrainTestSplitter (org.tribuo.evaluation.TrainTestSplitter): 1