Use of org.tribuo.evaluation.metrics.MetricID in project tribuo by oracle.
The class EvaluationAggregationTests, method summarizeF1AcrossDatasets_v2.
@Test
public void summarizeF1AcrossDatasets_v2() {
    Pair<Dataset<Label>, Dataset<Label>> pair = LabelledDataGenerator.denseTrainTest(-0.3);
    Model<Label> model = DummyClassifierTrainer.createMostFrequentTrainer().train(pair.getA());
    List<Dataset<Label>> datasets = Arrays.asList(
            LabelledDataGenerator.denseTrainTest(-1.0).getB(),
            LabelledDataGenerator.denseTrainTest(-0.5).getB(),
            LabelledDataGenerator.denseTrainTest(-0.1).getB());
    Evaluator<Label, LabelEvaluation> evaluator = factory.getEvaluator();
    // Summarize every metric across the three datasets in a single call.
    Map<MetricID<Label>, DescriptiveStats> summaries = EvaluationAggregator.summarize(evaluator, model, datasets);
    // Each metric's statistics are keyed by its MetricID, e.g. the macro-averaged F1.
    MetricID<Label> macroF1 = LabelMetrics.F1.forTarget(MetricTarget.macroAverageTarget()).getID();
    DescriptiveStats summary = summaries.get(macroF1);
    // Can also do this: evaluate each dataset individually, then summarize the evaluations.
    List<LabelEvaluation> evals = datasets.stream().map(dataset -> evaluator.evaluate(model, dataset)).collect(Collectors.toList());
    Map<MetricID<Label>, DescriptiveStats> summaries2 = EvaluationAggregator.summarize(evals);
    assertEquals(summaries, summaries2);
}
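A MetricID is simply the key into these summary maps, so once the aggregation has run, any per-metric statistics can be read back out. The fragment below is a hedged sketch rather than part of the Tribuo test suite; it reuses the summaries, macroF1, and datasets variables from the test above and only the DescriptiveStats accessors that appear elsewhere on this page.
// Hypothetical follow-up to the test above: report the macro-averaged F1
// aggregated over the three synthetic test datasets.
DescriptiveStats macroF1Stats = summaries.get(macroF1);
System.out.printf("macro F1 over %d datasets: %.5f (std dev %.5f)%n",
        datasets.size(), macroF1Stats.getMean(), macroF1Stats.getStandardDeviation());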
Use of org.tribuo.evaluation.metrics.MetricID in project tribuo by oracle.
The class AbstractEvaluator, method evaluate.
// "template method"
/**
* Produces an evaluation for the supplied model and predictions by aggregating the appropriate statistics.
* <p>
* Warning, this method cannot validate that the predictions were returned by the model in question.
* @param model The model to use.
* @param predictions The predictions to use.
* @param dataProvenance The provenance of the test data.
* @return An evaluation of the predictions.
*/
@Override
public final E evaluate(Model<T> model, List<Prediction<T>> predictions, DataProvenance dataProvenance) {
//
// Create the provenance for the model and dataset
EvaluationProvenance provenance = new EvaluationProvenance(model.getProvenance(), dataProvenance);
//
// Create an evaluation context. The context stores all the information needed by the list of metrics plus might
// cache intermediate computation relevant to multiple metrics (e.g., a pre-computed confusion matrix might be stored in 'context')
C context = createContext(model, predictions);
//
// "MODEL": Build the list of metrics to compute.
Set<? extends EvaluationMetric<T, C>> metrics = createMetrics(model);
//
// "CONTROLLER": For each metric in the list, compute the result.
Map<MetricID<T>, Double> results = computeResults(context, metrics);
// "VIEW": Create an evaluation to store the results and provide a "view" of the results to users
return createEvaluation(context, results, provenance);
}
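The results map is keyed by MetricID, so the "CONTROLLER" step amounts to looking up each metric's identifier and its computed value. The fragment below is a hedged sketch of what that computeResults call might look like; it is not copied from the Tribuo sources and assumes each EvaluationMetric exposes getID() and a compute(context) method returning a double, which is consistent with how the results map is typed above.
// Hedged sketch of the "CONTROLLER" step: compute each metric against the shared
// context and key the result by the metric's MetricID (target + metric name).
protected Map<MetricID<T>, Double> computeResults(C context, Set<? extends EvaluationMetric<T, C>> metrics) {
    Map<MetricID<T>, Double> results = new HashMap<>();
    for (EvaluationMetric<T, C> metric : metrics) {
        results.put(metric.getID(), metric.compute(context));
    }
    return results;
}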
Use of org.tribuo.evaluation.metrics.MetricID in project tribuo by oracle.
The class AbstractSequenceEvaluator, method evaluate.
// "template method"
/**
* Produces an evaluation for the supplied model and predictions by aggregating the appropriate statistics.
* <p>
* Warning, this method cannot validate that the predictions were returned by the model in question.
* @param model The model to use.
* @param predictions The predictions to use.
* @param dataProvenance The provenance of the test data.
* @return An evaluation of the predictions.
*/
@Override
public final E evaluate(SequenceModel<T> model, List<List<Prediction<T>>> predictions, DataProvenance dataProvenance) {
//
// Create the provenance for the model and dataset
EvaluationProvenance provenance = new EvaluationProvenance(model.getProvenance(), dataProvenance);
//
// Create an evaluation context. The context stores all the information needed by the list of metrics plus might
// cache intermediate computation relevant to multiple metrics (e.g., a pre-computed confusion matrix might be stored in 'context')
C context = createContext(model, predictions);
//
// "MODEL": Build the list of metrics to compute.
Set<? extends EvaluationMetric<T, C>> metrics = createMetrics(model);
//
// "CONTROLLER": For each metric in the list, compute the result.
Map<MetricID<T>, Double> results = computeResults(context, metrics);
// "VIEW": Create an evaluation to store the results and provide a "view" of the results to users
return createEvaluation(context, results, provenance);
}
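Both template methods (the single-output and the sequence variant) produce an evaluation whose results are keyed by the same MetricID type. As a hedged, caller-side sketch: the variable names evaluator, model, and testData are placeholders, and it assumes the standard Evaluation.asMap() accessor that exposes the computed results as a Map keyed by MetricID.
// Hypothetical caller-side usage: read every computed metric back out of a finished
// evaluation, assuming asMap() exposes the results keyed by MetricID.
LabelEvaluation evaluation = evaluator.evaluate(model, testData);
for (Map.Entry<MetricID<Label>, Double> entry : evaluation.asMap().entrySet()) {
    System.out.printf("%-30s %.5f%n", entry.getKey(), entry.getValue());
}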
Use of org.tribuo.evaluation.metrics.MetricID in project tribuo by oracle.
The class ConfigurableTrainTest, method main.
/**
 * Runs a training and testing experiment configured from the supplied command line options.
 * @param args the command line arguments
 * @param <T> The {@link Output} subclass.
 */
@SuppressWarnings("unchecked")
public static <T extends Output<T>> void main(String[] args) {
    //
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    ConfigurableTrainTestOptions o = new ConfigurableTrainTestOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }
    if (o.general.trainingPath == null || o.general.testingPath == null || o.outputFactory == null) {
        logger.info(cm.usage());
        System.exit(1);
    }
    Pair<Dataset<T>, Dataset<T>> data = null;
    try {
        data = o.general.load((OutputFactory<T>) o.outputFactory);
    } catch (IOException e) {
        logger.log(Level.SEVERE, "Failed to load data", e);
        System.exit(1);
    }
    Dataset<T> train = data.getA();
    Dataset<T> test = data.getB();
    if (o.trainer == null) {
        logger.warning("No trainer supplied");
        logger.info(cm.usage());
        System.exit(1);
    }
    if (o.transformationMap != null) {
        o.trainer = new TransformTrainer<>(o.trainer, o.transformationMap);
    }
    logger.info("Trainer is " + o.trainer.getProvenance().toString());
    logger.info("Outputs are " + train.getOutputInfo().toReadableString());
    logger.info("Number of features: " + train.getFeatureMap().size());
    final long trainStart = System.currentTimeMillis();
    Model<T> model = ((Trainer<T>) o.trainer).train(train);
    final long trainStop = System.currentTimeMillis();
    logger.info("Finished training model " + Util.formatDuration(trainStart, trainStop));
    Evaluator<T, ? extends Evaluation<T>> evaluator = train.getOutputFactory().getEvaluator();
    final long testStart = System.currentTimeMillis();
    Evaluation<T> evaluation = evaluator.evaluate(model, test);
    final long testStop = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
    System.out.println(evaluation.toString());
    if (o.general.outputPath != null) {
        try {
            o.general.saveModel(model);
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Error writing model", e);
        }
    }
    if (o.crossValidation) {
        if (o.numFolds > 1) {
            logger.info("Running " + o.numFolds + "-fold cross-validation");
            CrossValidation<T, ? extends Evaluation<T>> cv = new CrossValidation<>((Trainer<T>) o.trainer, train, evaluator, o.numFolds, o.general.seed);
            List<? extends Pair<? extends Evaluation<T>, Model<T>>> evaluations = cv.evaluate();
            List<Evaluation<T>> evals = evaluations.stream().map(Pair::getA).collect(Collectors.toList());
            // Summarize each metric across the folds.
            Map<MetricID<T>, DescriptiveStats> summary = EvaluationAggregator.summarize(evals);
            // A MetricID is a (target, metric name) pair, so sorting on getB() orders the keys by metric name.
            List<MetricID<T>> keys = new ArrayList<>(summary.keySet()).stream().sorted(Comparator.comparing(Pair::getB)).collect(Collectors.toList());
            System.out.println("Summary across the folds:");
            for (MetricID<T> key : keys) {
                DescriptiveStats stats = summary.get(key);
                System.out.printf("%-10s %.5f (%.5f)%n", key, stats.getMean(), stats.getStandardDeviation());
            }
        } else {
            logger.warning("The number of cross-validation folds must be greater than 1, found " + o.numFolds);
        }
    }
}
Use of org.tribuo.evaluation.metrics.MetricID in project tribuo by oracle.
The class EvaluationAggregationTests, method xval.
public static void xval() {
    Trainer<Label> trainer = DummyClassifierTrainer.createUniformTrainer(1L);
    Pair<Dataset<Label>, Dataset<Label>> datasets = LabelledDataGenerator.denseTrainTest();
    Dataset<Label> trainData = datasets.getA();
    Evaluator<Label, LabelEvaluation> evaluator = factory.getEvaluator();
    CrossValidation<Label, LabelEvaluation> xval = new CrossValidation<>(trainer, trainData, evaluator, 5);
    List<Pair<LabelEvaluation, Model<Label>>> results = xval.evaluate();
    List<LabelEvaluation> evals = results.stream().map(Pair::getA).collect(Collectors.toList());
    // Summarize every metric across all the folds.
    Map<MetricID<Label>, DescriptiveStats> summary = EvaluationAggregator.summarize(evals);
    // A MetricID is a (target, metric name) pair, so sorting on getB() orders the keys by metric name.
    List<MetricID<Label>> keys = new ArrayList<>(summary.keySet()).stream().sorted(Comparator.comparing(Pair::getB)).collect(Collectors.toList());
    for (MetricID<Label> key : keys) {
        DescriptiveStats stats = summary.get(key);
        out.printf("%-10s %.5f (%.5f)%n", key, stats.getMean(), stats.getStandardDeviation());
    }
    // Summarize across the macro-averaged F1s only.
    DescriptiveStats macroF1Summary = EvaluationAggregator.summarize(evals, LabelEvaluation::macroAveragedF1);
    out.println(macroF1Summary);
    // Pick the fold whose model has the best macro F1, then evaluate it on the held-out test set.
    Pair<Integer, Double> argmax = EvaluationAggregator.argmax(evals, LabelEvaluation::macroAveragedF1);
    Model<Label> bestF1 = results.get(argmax.getA()).getB();
    LabelEvaluation testEval = evaluator.evaluate(bestF1, datasets.getB());
    System.out.println(testEval);
}
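The argmax helper is not tied to F1: any per-evaluation statistic with the same shape can be used to select the best fold. Below is a hedged variation of the selection step above, reusing the evals and results lists from the method and assuming LabelEvaluation's accuracy() accessor in place of macroAveragedF1().
// Hypothetical variation: pick the fold whose model has the highest accuracy rather than macro F1.
Pair<Integer, Double> bestByAccuracy = EvaluationAggregator.argmax(evals, LabelEvaluation::accuracy);
Model<Label> bestAccModel = results.get(bestByAccuracy.getA()).getB();
out.printf("best fold by accuracy: %d (accuracy %.5f)%n", bestByAccuracy.getA(), bestByAccuracy.getB());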