Example 1 with MetricID

Use of org.tribuo.evaluation.metrics.MetricID in project tribuo by oracle.

From the class EvaluationAggregationTests, method summarizeF1AcrossDatasets_v2.

@Test
public void summarizeF1AcrossDatasets_v2() {
    Pair<Dataset<Label>, Dataset<Label>> pair = LabelledDataGenerator.denseTrainTest(-0.3);
    Model<Label> model = DummyClassifierTrainer.createMostFrequentTrainer().train(pair.getA());
    List<Dataset<Label>> datasets = Arrays.asList(LabelledDataGenerator.denseTrainTest(-1.0).getB(), LabelledDataGenerator.denseTrainTest(-0.5).getB(), LabelledDataGenerator.denseTrainTest(-0.1).getB());
    Evaluator<Label, LabelEvaluation> evaluator = factory.getEvaluator();
    Map<MetricID<Label>, DescriptiveStats> summaries = EvaluationAggregator.summarize(evaluator, model, datasets);
    MetricID<Label> macroF1 = LabelMetrics.F1.forTarget(MetricTarget.macroAverageTarget()).getID();
    DescriptiveStats summary = summaries.get(macroF1);
    // Can also do this:
    List<LabelEvaluation> evals = datasets.stream().map(dataset -> evaluator.evaluate(model, dataset)).collect(Collectors.toList());
    Map<MetricID<Label>, DescriptiveStats> summaries2 = EvaluationAggregator.summarize(evals);
    assertEquals(summaries, summaries2);
}
Also used : MetricTarget(org.tribuo.evaluation.metrics.MetricTarget) Arrays(java.util.Arrays) Evaluator(org.tribuo.evaluation.Evaluator) Prediction(org.tribuo.Prediction) Model(org.tribuo.Model) EvaluationAggregator(org.tribuo.evaluation.EvaluationAggregator) Pair(com.oracle.labs.mlrg.olcut.util.Pair) Collectors(java.util.stream.Collectors) MetricID(org.tribuo.evaluation.metrics.MetricID) System.out(java.lang.System.out) ArrayList(java.util.ArrayList) Dataset(org.tribuo.Dataset) Test(org.junit.jupiter.api.Test) Trainer(org.tribuo.Trainer) DummyClassifierTrainer(org.tribuo.classification.baseline.DummyClassifierTrainer) List(java.util.List) LabelFactory(org.tribuo.classification.LabelFactory) Map(java.util.Map) DescriptiveStats(org.tribuo.evaluation.DescriptiveStats) LabelledDataGenerator(org.tribuo.classification.example.LabelledDataGenerator) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) Comparator(java.util.Comparator) Label(org.tribuo.classification.Label) CrossValidation(org.tribuo.evaluation.CrossValidation)
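A hedged follow-on sketch using the same summaries map from the test above: per-class statistics can be retrieved the same way as the macro-averaged ones, assuming MetricTarget offers a per-output constructor and that the generated data contains a label named "Foo" (both are assumptions for illustration, not taken from the example).

// A minimal sketch reusing 'summaries' from the test above; the per-output
// MetricTarget constructor and the "Foo" label name are assumptions.
MetricID<Label> fooF1 = LabelMetrics.F1.forTarget(new MetricTarget<>(new Label("Foo"))).getID();
DescriptiveStats fooSummary = summaries.get(fooF1);
if (fooSummary != null) {
    System.out.printf("F1(Foo): %.5f (%.5f)%n", fooSummary.getMean(), fooSummary.getStandardDeviation());
}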

Example 2 with MetricID

Use of org.tribuo.evaluation.metrics.MetricID in project tribuo by oracle.

From the class AbstractEvaluator, method evaluate.

// "template method"
/**
 * Produces an evaluation for the supplied model and predictions by aggregating the appropriate statistics.
 * <p>
 * Warning, this method cannot validate that the predictions were returned by the model in question.
 * @param model The model to use.
 * @param predictions The predictions to use.
 * @param dataProvenance The provenance of the test data.
 * @return An evaluation of the predictions.
 */
@Override
public final E evaluate(Model<T> model, List<Prediction<T>> predictions, DataProvenance dataProvenance) {
    // 
    // Create the provenance for the model and dataset
    EvaluationProvenance provenance = new EvaluationProvenance(model.getProvenance(), dataProvenance);
    // 
    // Create an evaluation context. The context stores all the information needed by the list of metrics plus might
    // cache intermediate computation relevant to multiple metrics (e.g., a pre-computed confusion matrix might be stored in 'context')
    C context = createContext(model, predictions);
    // 
    // "MODEL": Build the list of metrics to compute.
    Set<? extends EvaluationMetric<T, C>> metrics = createMetrics(model);
    // 
    // "CONTROLLER": For each metric in the list, compute the result.
    Map<MetricID<T>, Double> results = computeResults(context, metrics);
    // "VIEW": Create an evaluation to store the results and provide a "view" of the results to users
    return createEvaluation(context, results, provenance);
}
Also used : MetricID(org.tribuo.evaluation.metrics.MetricID) EvaluationProvenance(org.tribuo.provenance.EvaluationProvenance)
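A hedged consumer-side sketch, reusing the classification setup from Example 1: once createEvaluation has produced a concrete evaluation, the same MetricID-keyed results are available to callers, here assuming the asMap() accessor on the Evaluation interface.

// A minimal sketch (assumes Evaluation.asMap(); model, evaluator and pair are
// the Label-typed objects from Example 1).
LabelEvaluation labelEval = evaluator.evaluate(model, pair.getB());
for (Map.Entry<MetricID<Label>, Double> entry : labelEval.asMap().entrySet()) {
    System.out.printf("%-40s %.5f%n", entry.getKey(), entry.getValue());
}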

Example 3 with MetricID

Use of org.tribuo.evaluation.metrics.MetricID in project tribuo by oracle.

From the class AbstractSequenceEvaluator, method evaluate.

// "template method"
/**
 * Produces an evaluation for the supplied model and predictions by aggregating the appropriate statistics.
 * <p>
 * Warning, this method cannot validate that the predictions were returned by the model in question.
 * @param model The model to use.
 * @param predictions The predictions to use.
 * @param dataProvenance The provenance of the test data.
 * @return An evaluation of the predictions.
 */
@Override
public final E evaluate(SequenceModel<T> model, List<List<Prediction<T>>> predictions, DataProvenance dataProvenance) {
    // 
    // Create the provenance for the model and dataset
    EvaluationProvenance provenance = new EvaluationProvenance(model.getProvenance(), dataProvenance);
    // 
    // Create an evaluation context. The context stores all the information needed by the list of metrics plus might
    // cache intermediate computation relevant to multiple metrics (e.g., a pre-computed confusion matrix might be stored in 'context')
    C context = createContext(model, predictions);
    // 
    // "MODEL": Build the list of metrics to compute.
    Set<? extends EvaluationMetric<T, C>> metrics = createMetrics(model);
    // 
    // "CONTROLLER": For each metric in the list, compute the result.
    Map<MetricID<T>, Double> results = computeResults(context, metrics);
    // "VIEW": Create an evaluation to store the results and provide a "view" of the results to users
    return createEvaluation(context, results, provenance);
}
Also used : MetricID(org.tribuo.evaluation.metrics.MetricID) EvaluationProvenance(org.tribuo.provenance.EvaluationProvenance)
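For context, a minimal sketch of what the MetricID keys in these results maps look like. It relies on MetricID behaving as a Pair of MetricTarget and metric name, which is consistent with the Comparator.comparing(Pair::getB) sorting used in the examples below; the exact metric-name string is an assumption.

// A minimal sketch: a MetricID pairs a MetricTarget with a metric name.
MetricTarget<Label> macro = MetricTarget.macroAverageTarget();
MetricID<Label> id = LabelMetrics.F1.forTarget(macro).getID();
System.out.println(id.getA()); // the MetricTarget (macro average)
System.out.println(id.getB()); // the metric name, presumably "F1"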

Example 4 with MetricID

Use of org.tribuo.evaluation.metrics.MetricID in project tribuo by oracle.

From the class ConfigurableTrainTest, method main.

/**
 * @param args the command line arguments
 * @param <T> The {@link Output} subclass.
 */
@SuppressWarnings("unchecked")
public static <T extends Output<T>> void main(String[] args) {
    // 
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    ConfigurableTrainTestOptions o = new ConfigurableTrainTestOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }
    if (o.general.trainingPath == null || o.general.testingPath == null || o.outputFactory == null) {
        logger.info(cm.usage());
        System.exit(1);
    }
    Pair<Dataset<T>, Dataset<T>> data = null;
    try {
        data = o.general.load((OutputFactory<T>) o.outputFactory);
    } catch (IOException e) {
        logger.log(Level.SEVERE, "Failed to load data", e);
        System.exit(1);
    }
    Dataset<T> train = data.getA();
    Dataset<T> test = data.getB();
    if (o.trainer == null) {
        logger.warning("No trainer supplied");
        logger.info(cm.usage());
        System.exit(1);
    }
    if (o.transformationMap != null) {
        o.trainer = new TransformTrainer<>(o.trainer, o.transformationMap);
    }
    logger.info("Trainer is " + o.trainer.getProvenance().toString());
    logger.info("Outputs are " + train.getOutputInfo().toReadableString());
    logger.info("Number of features: " + train.getFeatureMap().size());
    final long trainStart = System.currentTimeMillis();
    Model<T> model = ((Trainer<T>) o.trainer).train(train);
    final long trainStop = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));
    Evaluator<T, ? extends Evaluation<T>> evaluator = train.getOutputFactory().getEvaluator();
    final long testStart = System.currentTimeMillis();
    Evaluation<T> evaluation = evaluator.evaluate(model, test);
    final long testStop = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
    System.out.println(evaluation.toString());
    if (o.general.outputPath != null) {
        try {
            o.general.saveModel(model);
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Error writing model", e);
        }
    }
    if (o.crossValidation) {
        if (o.numFolds > 1) {
            logger.info("Running " + o.numFolds + " fold cross-validation");
            CrossValidation<T, ? extends Evaluation<T>> cv = new CrossValidation<>((Trainer<T>) o.trainer, train, evaluator, o.numFolds, o.general.seed);
            List<? extends Pair<? extends Evaluation<T>, Model<T>>> evaluations = cv.evaluate();
            List<Evaluation<T>> evals = evaluations.stream().map(Pair::getA).collect(Collectors.toList());
            // Summarize across everything
            Map<MetricID<T>, DescriptiveStats> summary = EvaluationAggregator.summarize(evals);
            List<MetricID<T>> keys = new ArrayList<>(summary.keySet()).stream().sorted(Comparator.comparing(Pair::getB)).collect(Collectors.toList());
            System.out.println("Summary across the folds:");
            for (MetricID<T> key : keys) {
                DescriptiveStats stats = summary.get(key);
                System.out.printf("%-10s  %.5f (%.5f)%n", key, stats.getMean(), stats.getStandardDeviation());
            }
        } else {
            logger.warning("The number of cross-validation folds must be greater than 1, found " + o.numFolds);
        }
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) TransformTrainer(org.tribuo.transform.TransformTrainer) Trainer(org.tribuo.Trainer) MetricID(org.tribuo.evaluation.metrics.MetricID) DescriptiveStats(org.tribuo.evaluation.DescriptiveStats) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) Pair(com.oracle.labs.mlrg.olcut.util.Pair) Evaluation(org.tribuo.evaluation.Evaluation) Dataset(org.tribuo.Dataset) IOException(java.io.IOException) Model(org.tribuo.Model) CrossValidation(org.tribuo.evaluation.CrossValidation) OutputFactory(org.tribuo.OutputFactory)
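When the output type is known to be Label, a hedged sketch of a direct lookup in the cross-validation summary is possible instead of iterating every key, reusing the MetricID construction from Example 1; labelSummary below is a hypothetical Map<MetricID<Label>, DescriptiveStats> built the same way as summary above.

// A minimal sketch; labelSummary is a hypothetical Label-typed equivalent of 'summary' above.
MetricID<Label> macroF1 = LabelMetrics.F1.forTarget(MetricTarget.macroAverageTarget()).getID();
DescriptiveStats macroF1Stats = labelSummary.get(macroF1);
System.out.printf("macro F1 across folds: %.5f (%.5f)%n",
        macroF1Stats.getMean(), macroF1Stats.getStandardDeviation());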

Example 5 with MetricID

Use of org.tribuo.evaluation.metrics.MetricID in project tribuo by oracle.

From the class EvaluationAggregationTests, method xval.

public static void xval() {
    Trainer<Label> trainer = DummyClassifierTrainer.createUniformTrainer(1L);
    Pair<Dataset<Label>, Dataset<Label>> datasets = LabelledDataGenerator.denseTrainTest();
    Dataset<Label> trainData = datasets.getA();
    Evaluator<Label, LabelEvaluation> evaluator = factory.getEvaluator();
    CrossValidation<Label, LabelEvaluation> xval = new CrossValidation<>(trainer, trainData, evaluator, 5);
    List<Pair<LabelEvaluation, Model<Label>>> results = xval.evaluate();
    List<LabelEvaluation> evals = results.stream().map(Pair::getA).collect(Collectors.toList());
    // Summarize across everything
    Map<MetricID<Label>, DescriptiveStats> summary = EvaluationAggregator.summarize(evals);
    List<MetricID<Label>> keys = new ArrayList<>(summary.keySet()).stream().sorted(Comparator.comparing(Pair::getB)).collect(Collectors.toList());
    for (MetricID<Label> key : keys) {
        DescriptiveStats stats = summary.get(key);
        out.printf("%-10s  %.5f (%.5f)%n", key, stats.getMean(), stats.getStandardDeviation());
    }
    // Summarize across macro F1s only
    DescriptiveStats macroF1Summary = EvaluationAggregator.summarize(evals, LabelEvaluation::macroAveragedF1);
    out.println(macroF1Summary);
    Pair<Integer, Double> argmax = EvaluationAggregator.argmax(evals, LabelEvaluation::macroAveragedF1);
    Model<Label> bestF1 = results.get(argmax.getA()).getB();
    LabelEvaluation testEval = evaluator.evaluate(bestF1, datasets.getB());
    System.out.println(testEval);
}
Also used : Dataset(org.tribuo.Dataset) Label(org.tribuo.classification.Label) MetricID(org.tribuo.evaluation.metrics.MetricID) DescriptiveStats(org.tribuo.evaluation.DescriptiveStats) CrossValidation(org.tribuo.evaluation.CrossValidation) Pair(com.oracle.labs.mlrg.olcut.util.Pair)
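A hedged variation on the model selection above: the same argmax helper can rank the folds by a different LabelEvaluation accessor, here overall accuracy (assuming LabelEvaluation::accuracy), reusing evals, results, evaluator and datasets from the method above.

// A minimal sketch: pick the fold model with the best overall accuracy instead of macro F1.
Pair<Integer, Double> bestByAccuracy = EvaluationAggregator.argmax(evals, LabelEvaluation::accuracy);
Model<Label> bestAccuracyModel = results.get(bestByAccuracy.getA()).getB();
System.out.println(evaluator.evaluate(bestAccuracyModel, datasets.getB()));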

Aggregations

MetricID (org.tribuo.evaluation.metrics.MetricID): 6
Dataset (org.tribuo.Dataset): 4
DescriptiveStats (org.tribuo.evaluation.DescriptiveStats): 4
Pair (com.oracle.labs.mlrg.olcut.util.Pair): 3
Model (org.tribuo.Model): 3
Label (org.tribuo.classification.Label): 3
CrossValidation (org.tribuo.evaluation.CrossValidation): 3
Trainer (org.tribuo.Trainer): 2
EvaluationProvenance (org.tribuo.provenance.EvaluationProvenance): 2
ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager): 1
UsageException (com.oracle.labs.mlrg.olcut.config.UsageException): 1
IOException (java.io.IOException): 1
System.out (java.lang.System.out): 1
ArrayList (java.util.ArrayList): 1
Arrays (java.util.Arrays): 1
Comparator (java.util.Comparator): 1
List (java.util.List): 1
Map (java.util.Map): 1
Collectors (java.util.stream.Collectors): 1
Assertions.assertEquals (org.junit.jupiter.api.Assertions.assertEquals): 1