use of org.tribuo.classification.evaluation.LabelEvaluator in project tribuo by oracle.
the class TestXGBoost method testXGBoost.
/**
 * Trains a model with the supplied XGBoost trainer, runs a smoke-test evaluation and
 * checks that the trained model reports a non-empty set of top features.
 *
 * @param trainer the XGBoost classification trainer under test
 * @param p the data pair; {@code p.getA()} is trained on and {@code p.getB()} is evaluated on
 * @return the trained model
 */
public static Model<Label> testXGBoost(XGBoostClassificationTrainer trainer, Pair<Dataset<Label>, Dataset<Label>> p) {
    Model<Label> model = trainer.train(p.getA());

    // Smoke test only: the evaluation has to complete, its contents are not inspected.
    LabelEvaluation smokeEvaluation = new LabelEvaluator().evaluate(model, p.getB());

    // Feature ranking requested with a positive limit.
    Map<String, List<Pair<String, Double>>> topThree = model.getTopFeatures(3);
    Assertions.assertNotNull(topThree);
    Assertions.assertFalse(topThree.isEmpty());

    // Feature ranking requested with a negative limit
    // (presumably "no limit" - confirm against Model.getTopFeatures).
    Map<String, List<Pair<String, Double>>> negativeLimitFeatures = model.getTopFeatures(-1);
    Assertions.assertNotNull(negativeLimitFeatures);
    Assertions.assertFalse(negativeLimitFeatures.isEmpty());

    return model;
}
use of org.tribuo.classification.evaluation.LabelEvaluator in project tribuo by oracle.
the class TestFMClassification method testFMClassification.
/**
 * Trains a model with the trainer {@code t} declared outside this method, runs a
 * smoke-test evaluation and checks that the model reports a non-empty set of top features.
 *
 * @param p the data pair; {@code p.getA()} is trained on and {@code p.getB()} is evaluated on
 * @return the trained model
 */
public static Model<Label> testFMClassification(Pair<Dataset<Label>, Dataset<Label>> p) {
    Model<Label> model = t.train(p.getA());

    // Smoke test only: the evaluation has to complete, its contents are not inspected.
    LabelEvaluation smokeEvaluation = new LabelEvaluator().evaluate(model, p.getB());

    // Feature ranking requested with a positive limit.
    Map<String, List<Pair<String, Double>>> topThree = model.getTopFeatures(3);
    assertNotNull(topThree);
    Assertions.assertFalse(topThree.isEmpty());

    // Feature ranking requested with a negative limit
    // (presumably "no limit" - confirm against Model.getTopFeatures).
    Map<String, List<Pair<String, Double>>> negativeLimitFeatures = model.getTopFeatures(-1);
    assertNotNull(negativeLimitFeatures);
    Assertions.assertFalse(negativeLimitFeatures.isEmpty());

    return model;
}
use of org.tribuo.classification.evaluation.LabelEvaluator in project tribuo by oracle.
the class TestDummyClassifier method testDummyClassifier.
/**
 * Trains and evaluates a model for every trainer in {@code trainers} (declared outside
 * this method), optionally checking that each trained model survives serialization.
 *
 * @param p the data pair; {@code p.getA()} is trained on and {@code p.getB()} is evaluated on
 * @param testModelSave when {@code true}, each model is also passed to
 *                      {@code Helpers.testModelSerialization}
 */
public static void testDummyClassifier(Pair<Dataset<Label>, Dataset<Label>> p, boolean testModelSave) {
    for (Trainer<Label> currentTrainer : trainers) {
        Model<Label> model = currentTrainer.train(p.getA());

        // Smoke test only: the evaluation has to complete, its contents are not inspected.
        LabelEvaluation smokeEvaluation = new LabelEvaluator().evaluate(model, p.getB());

        if (!testModelSave) {
            continue;
        }
        Helpers.testModelSerialization(model, Label.class);
    }
}
use of org.tribuo.classification.evaluation.LabelEvaluator in project tribuo by oracle.
the class ClassificationTest method classificationCNNTest.
/**
 * End-to-end smoke test for CNN classification with the TensorFlow trainer.
 * <p>
 * The test trains a LeNet graph on generated image data, evaluates the predictions,
 * checks Tribuo serialization, exports a saved model bundle, reloads that bundle as an
 * external model and verifies both models produce equal predictions. The temporary
 * export directory is removed at the end.
 *
 * @throws IOException if the temporary directory cannot be created, listed or walked
 */
@Test
public void classificationCNNTest() throws IOException {
    // Create the train and test data
    Pair<Dataset<Label>, Dataset<Label>> data = generateImageData(512, 10, 128, 5, 42);
    Dataset<Label> trainData = data.getA();
    Dataset<Label> testData = data.getB();
    // Build the CNN graph
    GraphDefTuple graphDefTuple = CNNExamples.buildLeNetGraph(INPUT_NAME, 10, 255, trainData.getOutputs().size());
    // Configure the trainer
    Map<String, Float> gradientParams = new HashMap<>();
    gradientParams.put("learningRate", 0.01f);
    gradientParams.put("initialAccumulatorValue", 0.1f);
    FeatureConverter imageConverter = new ImageConverter(INPUT_NAME, 10, 10, 1);
    OutputConverter<Label> outputConverter = new LabelConverter();
    TensorFlowTrainer<Label> trainer = new TensorFlowTrainer<>(graphDefTuple.graphDef, graphDefTuple.outputName, GradientOptimiser.ADAGRAD, gradientParams, imageConverter, outputConverter, 16, 5, 16, -1);
    // Train the model
    TensorFlowModel<Label> model = trainer.train(trainData);
    // Make some predictions
    List<Prediction<Label>> predictions = model.predict(testData);
    // Run smoke test evaluation
    LabelEvaluation eval = new LabelEvaluator().evaluate(model, predictions, testData.getProvenance());
    Assertions.assertTrue(eval.accuracy() > 0.0);
    // Check Tribuo serialization
    Helpers.testModelSerialization(model, Label.class);
    // Check saved model bundle export
    Path outputPath = Files.createTempDirectory("tf-classification-cnn-test");
    model.exportModel(outputPath.toString());
    try (Stream<Path> f = Files.list(outputPath)) {
        List<Path> files = f.collect(Collectors.toList());
        Assertions.assertNotEquals(0, files.size());
    }
    // Create external model from bundle
    Map<Label, Integer> outputMapping = new HashMap<>();
    for (Pair<Integer, Label> p : model.getOutputIDInfo()) {
        outputMapping.put(p.getB(), p.getA());
    }
    Map<String, Integer> featureMapping = new HashMap<>();
    ImmutableFeatureMap featureIDMap = model.getFeatureIDMap();
    for (VariableInfo info : featureIDMap) {
        featureMapping.put(info.getName(), featureIDMap.getID(info.getName()));
    }
    TensorFlowSavedModelExternalModel<Label> externalModel = TensorFlowSavedModelExternalModel.createTensorflowModel(trainData.getOutputFactory(), featureMapping, outputMapping, model.getOutputName(), imageConverter, outputConverter, outputPath.toString());
    // Check predictions are equal
    List<Prediction<Label>> externalPredictions = externalModel.predict(testData);
    checkPredictionEquality(predictions, externalPredictions);
    // Cleanup saved model bundle
    externalModel.close();
    // Files.walk returns a stream that must be closed, so it is scoped with
    // try-with-resources. Reverse ordering deletes children before their parent directory.
    try (Stream<Path> bundlePaths = Files.walk(outputPath)) {
        bundlePaths.sorted(Comparator.reverseOrder()).map(Path::toFile).forEach(File::delete);
    }
    Assertions.assertFalse(Files.exists(outputPath));
    // Cleanup created model
    model.close();
}
use of org.tribuo.classification.evaluation.LabelEvaluator in project tribuo by oracle.
the class ClassificationTest method testTrainer.
/**
 * Trains a model with the supplied trainer and runs a set of smoke checks on it:
 * evaluation, Tribuo serialization and saved model bundle export. The temporary export
 * directory is removed before returning.
 * <p>
 * The returned model is not closed here; that is left to the caller.
 *
 * @param trainer the TensorFlow trainer under test
 * @param trainData the dataset the model is trained on
 * @param testData the dataset the model is evaluated on
 * @return the trained model
 * @throws IOException if the temporary directory cannot be created, listed or walked
 */
private static TensorFlowModel<Label> testTrainer(TensorFlowTrainer<Label> trainer, Dataset<Label> trainData, Dataset<Label> testData) throws IOException {
    // Train the model
    TensorFlowModel<Label> model = trainer.train(trainData);
    // Run smoke test evaluation
    LabelEvaluation eval = new LabelEvaluator().evaluate(model, testData);
    Assertions.assertTrue(eval.averageAUCROC(false) > 0.0);
    // Check Tribuo serialization
    Helpers.testModelSerialization(model, Label.class);
    // Check saved model bundle export
    Path outputPath = Files.createTempDirectory("tf-classification-test");
    model.exportModel(outputPath.toString());
    // Files.list returns a stream that must be closed, so it is scoped with
    // try-with-resources.
    try (Stream<Path> exported = Files.list(outputPath)) {
        List<Path> files = exported.collect(Collectors.toList());
        Assertions.assertNotEquals(0, files.size());
    }
    // Cleanup saved model bundle. Files.walk also returns a closeable stream; reverse
    // ordering deletes children before their parent directory.
    try (Stream<Path> bundlePaths = Files.walk(outputPath)) {
        bundlePaths.sorted(Comparator.reverseOrder()).map(Path::toFile).forEach(File::delete);
    }
    Assertions.assertFalse(Files.exists(outputPath));
    return model;
}
Aggregations