Search in sources:

Example 1 with LabelEvaluator

Use of org.tribuo.classification.evaluation.LabelEvaluator in project Tribuo by Oracle.

From class TestXGBoost, method testXGBoost.

/**
 * Trains an XGBoost classifier on the training half of {@code p}, runs a smoke-test
 * evaluation on the test half, and checks that feature rankings are reported.
 *
 * @param trainer the XGBoost classification trainer to exercise
 * @param p       a (train, test) dataset pair; {@code getA()} is trained on, {@code getB()} evaluated
 * @return the trained model, so callers can run further checks on it
 */
public static Model<Label> testXGBoost(XGBoostClassificationTrainer trainer, Pair<Dataset<Label>, Dataset<Label>> p) {
    Model<Label> m = trainer.train(p.getA());
    // Smoke test: evaluation must complete without throwing; its metrics are not inspected here.
    new LabelEvaluator().evaluate(m, p.getB());
    // Top 3 features first, then -1 (presumably "all features" — confirm against Model#getTopFeatures).
    Map<String, List<Pair<String, Double>>> features = m.getTopFeatures(3);
    Assertions.assertNotNull(features);
    Assertions.assertFalse(features.isEmpty());
    features = m.getTopFeatures(-1);
    Assertions.assertNotNull(features);
    Assertions.assertFalse(features.isEmpty());
    return m;
}
Also used: LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation), Label (org.tribuo.classification.Label), List (java.util.List), LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator)

Example 2 with LabelEvaluator

Use of org.tribuo.classification.evaluation.LabelEvaluator in project Tribuo by Oracle.

From class TestFMClassification, method testFMClassification.

/**
 * Trains a classifier with the trainer {@code t} on the training half of {@code p}, runs a
 * smoke-test evaluation on the test half, and checks that feature rankings are reported.
 *
 * <p>{@code t} is not declared in this method; since the method is static it must be a static
 * member (or static import) of the enclosing test class.
 *
 * @param p a (train, test) dataset pair; {@code getA()} is trained on, {@code getB()} evaluated
 * @return the trained model, so callers can run further checks on it
 */
public static Model<Label> testFMClassification(Pair<Dataset<Label>, Dataset<Label>> p) {
    Model<Label> m = t.train(p.getA());
    // Smoke test: evaluation must complete without throwing; its metrics are not inspected here.
    new LabelEvaluator().evaluate(m, p.getB());
    // Top 3 features first, then -1 (presumably "all features" — confirm against Model#getTopFeatures).
    Map<String, List<Pair<String, Double>>> features = m.getTopFeatures(3);
    Assertions.assertNotNull(features);
    Assertions.assertFalse(features.isEmpty());
    features = m.getTopFeatures(-1);
    Assertions.assertNotNull(features);
    Assertions.assertFalse(features.isEmpty());
    return m;
}
Also used: LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation), Label (org.tribuo.classification.Label), List (java.util.List), LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator)

Example 3 with LabelEvaluator

Use of org.tribuo.classification.evaluation.LabelEvaluator in project Tribuo by Oracle.

From class TestDummyClassifier, method testDummyClassifier.

/**
 * Trains every trainer in {@code trainers} on the training half of {@code p} and runs a
 * smoke-test evaluation of each resulting model on the test half.
 *
 * <p>{@code trainers} is not declared in this method; since the method is static it must be a
 * static member of the enclosing test class.
 *
 * @param p             a (train, test) dataset pair; {@code getA()} is trained on, {@code getB()} evaluated
 * @param testModelSave when {@code true}, each trained model is also round-tripped through
 *                      {@code Helpers.testModelSerialization}
 */
public static void testDummyClassifier(Pair<Dataset<Label>, Dataset<Label>> p, boolean testModelSave) {
    // The evaluator does not depend on the trainer, so build it once outside the loop.
    Evaluator<Label, LabelEvaluation> evaluator = new LabelEvaluator();
    for (Trainer<Label> t : trainers) {
        Model<Label> m = t.train(p.getA());
        // Smoke test: evaluation must complete without throwing; its metrics are not inspected here.
        evaluator.evaluate(m, p.getB());
        if (testModelSave) {
            Helpers.testModelSerialization(m, Label.class);
        }
    }
}
Also used: LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation), Label (org.tribuo.classification.Label), LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator)

Example 4 with LabelEvaluator

Use of org.tribuo.classification.evaluation.LabelEvaluator in project Tribuo by Oracle.

From class ClassificationTest, method classificationCNNTest.

/**
 * End-to-end test of a LeNet-style CNN trained with {@link TensorFlowTrainer}.
 *
 * <p>The test runs these steps in order:
 * <ol>
 *   <li>Trains the CNN on generated image data.</li>
 *   <li>Checks that evaluation accuracy is positive.</li>
 *   <li>Round-trips the model through Tribuo serialization.</li>
 *   <li>Exports the model as a TensorFlow saved model bundle.</li>
 *   <li>Reloads the bundle as a {@link TensorFlowSavedModelExternalModel}.</li>
 *   <li>Checks that both models make the same predictions.</li>
 * </ol>
 *
 * <p>The exported bundle directory, the external model and the trained model are cleaned up
 * in {@code finally} blocks, so they are released even when a check fails.
 *
 * @throws IOException if the temporary export directory cannot be created, listed or walked
 */
@Test
public void classificationCNNTest() throws IOException {
    // Create the train and test data
    Pair<Dataset<Label>, Dataset<Label>> data = generateImageData(512, 10, 128, 5, 42);
    Dataset<Label> trainData = data.getA();
    Dataset<Label> testData = data.getB();
    // Build the CNN graph
    GraphDefTuple graphDefTuple = CNNExamples.buildLeNetGraph(INPUT_NAME, 10, 255, trainData.getOutputs().size());
    // Configure the trainer
    Map<String, Float> gradientParams = new HashMap<>();
    gradientParams.put("learningRate", 0.01f);
    gradientParams.put("initialAccumulatorValue", 0.1f);
    FeatureConverter imageConverter = new ImageConverter(INPUT_NAME, 10, 10, 1);
    OutputConverter<Label> outputConverter = new LabelConverter();
    TensorFlowTrainer<Label> trainer = new TensorFlowTrainer<>(graphDefTuple.graphDef, graphDefTuple.outputName, GradientOptimiser.ADAGRAD, gradientParams, imageConverter, outputConverter, 16, 5, 16, -1);
    // Train the model
    TensorFlowModel<Label> model = trainer.train(trainData);
    try {
        // Make some predictions
        List<Prediction<Label>> predictions = model.predict(testData);
        // Run smoke test evaluation
        LabelEvaluation eval = new LabelEvaluator().evaluate(model, predictions, testData.getProvenance());
        Assertions.assertTrue(eval.accuracy() > 0.0);
        // Check Tribuo serialization
        Helpers.testModelSerialization(model, Label.class);
        // Check saved model bundle export
        Path outputPath = Files.createTempDirectory("tf-classification-cnn-test");
        try {
            model.exportModel(outputPath.toString());
            try (Stream<Path> f = Files.list(outputPath)) {
                List<Path> files = f.collect(Collectors.toList());
                Assertions.assertNotEquals(0, files.size());
            }
            // Create external model from bundle
            Map<Label, Integer> outputMapping = new HashMap<>();
            for (Pair<Integer, Label> p : model.getOutputIDInfo()) {
                outputMapping.put(p.getB(), p.getA());
            }
            Map<String, Integer> featureMapping = new HashMap<>();
            ImmutableFeatureMap featureIDMap = model.getFeatureIDMap();
            for (VariableInfo info : featureIDMap) {
                featureMapping.put(info.getName(), featureIDMap.getID(info.getName()));
            }
            TensorFlowSavedModelExternalModel<Label> externalModel = TensorFlowSavedModelExternalModel.createTensorflowModel(trainData.getOutputFactory(), featureMapping, outputMapping, model.getOutputName(), imageConverter, outputConverter, outputPath.toString());
            try {
                // Check predictions are equal
                List<Prediction<Label>> externalPredictions = externalModel.predict(testData);
                checkPredictionEquality(predictions, externalPredictions);
            } finally {
                // Close the external model before the bundle directory it was loaded from is deleted.
                externalModel.close();
            }
        } finally {
            // Cleanup saved model bundle. Reverse order visits children before their parent
            // directory, so each directory is empty when it is deleted. Files.walk keeps
            // directories open while iterating, so the stream must be closed.
            try (Stream<Path> walk = Files.walk(outputPath)) {
                walk.sorted(Comparator.reverseOrder()).map(Path::toFile).forEach(File::delete);
            }
        }
        Assertions.assertFalse(Files.exists(outputPath));
    } finally {
        // Cleanup created model
        model.close();
    }
}
Also used: GraphDefTuple (org.tribuo.interop.tensorflow.example.GraphDefTuple), HashMap (java.util.HashMap), Label (org.tribuo.classification.Label), LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation), ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap), Path (java.nio.file.Path), Dataset (org.tribuo.Dataset), MutableDataset (org.tribuo.MutableDataset), Prediction (org.tribuo.Prediction), VariableInfo (org.tribuo.VariableInfo), LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator), File (java.io.File), Test (org.junit.jupiter.api.Test)

Example 5 with LabelEvaluator

Use of org.tribuo.classification.evaluation.LabelEvaluator in project Tribuo by Oracle.

From class ClassificationTest, method testTrainer.

/**
 * Trains a model with {@code trainer} and runs shared smoke checks on it.
 *
 * <p>The checks are: average AUC-ROC on {@code testData} is positive, the model survives Tribuo
 * serialization, and it can be exported as a non-empty TensorFlow saved model bundle. The
 * temporary export directory is deleted in a {@code finally} block, so it does not leak when a
 * check fails.
 *
 * <p>The returned model is not closed here; the caller owns it.
 *
 * @param trainer   the TensorFlow trainer to exercise
 * @param trainData data used for training
 * @param testData  data used for the smoke-test evaluation
 * @return the trained model
 * @throws IOException if the temporary export directory cannot be created, listed or walked
 */
private static TensorFlowModel<Label> testTrainer(TensorFlowTrainer<Label> trainer, Dataset<Label> trainData, Dataset<Label> testData) throws IOException {
    // Train the model
    TensorFlowModel<Label> model = trainer.train(trainData);
    // Run smoke test evaluation
    LabelEvaluation eval = new LabelEvaluator().evaluate(model, testData);
    Assertions.assertTrue(eval.averageAUCROC(false) > 0.0);
    // Check Tribuo serialization
    Helpers.testModelSerialization(model, Label.class);
    // Check saved model bundle export
    Path outputPath = Files.createTempDirectory("tf-classification-test");
    try {
        model.exportModel(outputPath.toString());
        // Files.list holds an open directory handle, so close the stream once it is consumed.
        try (Stream<Path> listing = Files.list(outputPath)) {
            List<Path> files = listing.collect(Collectors.toList());
            Assertions.assertNotEquals(0, files.size());
        }
    } finally {
        // Cleanup saved model bundle. Reverse order visits children before their parent
        // directory, so each directory is empty when it is deleted.
        try (Stream<Path> walk = Files.walk(outputPath)) {
            walk.sorted(Comparator.reverseOrder()).map(Path::toFile).forEach(File::delete);
        }
    }
    Assertions.assertFalse(Files.exists(outputPath));
    return model;
}
Also used: Path (java.nio.file.Path), LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation), Label (org.tribuo.classification.Label), LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator), File (java.io.File)

Aggregations

LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation): 24; LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator): 24; Label (org.tribuo.classification.Label): 22; List (java.util.List): 13; Dataset (org.tribuo.Dataset): 6; MutableDataset (org.tribuo.MutableDataset): 6; LabelFactory (org.tribuo.classification.LabelFactory): 6; ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager): 4; UsageException (com.oracle.labs.mlrg.olcut.config.UsageException): 4; HashMap (java.util.HashMap): 4; File (java.io.File): 3; FileOutputStream (java.io.FileOutputStream): 3; IOException (java.io.IOException): 3; ImmutableDataset (org.tribuo.ImmutableDataset): 3; Prediction (org.tribuo.Prediction): 3; BufferedWriter (java.io.BufferedWriter): 2; ObjectInputStream (java.io.ObjectInputStream): 2; ObjectOutputStream (java.io.ObjectOutputStream): 2; Path (java.nio.file.Path): 2; ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap): 2