Search in sources:

Example 1 with LabelEvaluation

Use of org.tribuo.classification.evaluation.LabelEvaluation in project Tribuo by Oracle.

In class TestXGBoost, method testSingleClassTraining.

@Test
public void testSingleClassTraining() {
    // Build a training view that keeps only the examples labelled "Foo".
    Pair<Dataset<Label>, Dataset<Label>> trainTest = LabelledDataGenerator.denseTrainTest();
    DatasetView<Label> fooOnly = DatasetView.createView(
            trainTest.getA(),
            (Example<Label> ex) -> ex.getOutput().getLabel().equals("Foo"),
            "Foo selector");
    Model<Label> singleClassModel = t.train(fooOnly);

    // Score the single-class model against the untouched test split.
    LabelEvaluation result = (LabelEvaluation) fooOnly.getOutputFactory()
            .getEvaluator()
            .evaluate(singleClassModel, trainTest.getB());

    // Labels absent from the training view must report a per-label accuracy of exactly 0.0 ...
    for (String unseen : new String[] {"Bar", "Baz", "Quux"}) {
        assertEquals(0.0, result.accuracy(new Label(unseen)));
    }
    // ... while recall on the only trained label must be exactly 1.0.
    assertEquals(1.0, result.recall(new Label("Foo")));
}
Also used : LabelEvaluation(org.tribuo.classification.evaluation.LabelEvaluation) ImmutableDataset(org.tribuo.ImmutableDataset) Dataset(org.tribuo.Dataset) Example(org.tribuo.Example) Label(org.tribuo.classification.Label) Test(org.junit.jupiter.api.Test)

Example 2 with LabelEvaluation

Use of org.tribuo.classification.evaluation.LabelEvaluation in project Tribuo by Oracle.

In class TestXGBoost, method testXGBoost.

/**
 * Trains a model with the supplied XGBoost trainer and sanity-checks it.
 * The model is trained on the first dataset of {@code p}, run through a
 * {@code LabelEvaluator} on the second dataset, and checked for non-empty
 * top-feature output.
 *
 * @param trainer the XGBoost classification trainer to exercise.
 * @param p a pair whose first element is used for training and whose second is used for evaluation.
 * @return the trained model, so callers can run further checks on it.
 */
public static Model<Label> testXGBoost(XGBoostClassificationTrainer trainer, Pair<Dataset<Label>, Dataset<Label>> p) {
    Model<Label> m = trainer.train(p.getA());
    // Run evaluation on the held-out split so the evaluation path is exercised.
    // The returned LabelEvaluation is not inspected here, so it is not bound to a local.
    new LabelEvaluator().evaluate(m, p.getB());
    // Top-feature extraction must return something for a bounded request (3) ...
    Map<String, List<Pair<String, Double>>> features = m.getTopFeatures(3);
    Assertions.assertNotNull(features);
    Assertions.assertFalse(features.isEmpty());
    // ... and for -1 (presumably "all features" — confirm against Model#getTopFeatures).
    features = m.getTopFeatures(-1);
    Assertions.assertNotNull(features);
    Assertions.assertFalse(features.isEmpty());
    return m;
}
Also used : LabelEvaluation(org.tribuo.classification.evaluation.LabelEvaluation) Label(org.tribuo.classification.Label) List(java.util.List) LabelEvaluator(org.tribuo.classification.evaluation.LabelEvaluator)

Example 3 with LabelEvaluation

Use of org.tribuo.classification.evaluation.LabelEvaluation in project Tribuo by Oracle.

In class TestXGBoostExternalModel, method testMNIST.

@Test
public void testMNIST() throws IOException, URISyntaxException {
    final int featureCount = 784;
    LabelFactory factory = new LabelFactory();

    // Load the head of the MNIST test set from a bundled LibSVM file.
    URL libsvmUrl = TestXGBoostExternalModel.class.getResource("/org/tribuo/classification/xgboost/mnist_test_head.libsvm");
    DataSource<Label> transposedMNIST = new LibSVMDataSource<>(libsvmUrl, factory, false, featureCount);
    Dataset<Label> dataset = new MutableDataset<>(transposedMNIST);

    // The stored XGBoost model expects reversed feature indices (name i -> index 783 - i),
    // which exercises a non-identity feature mapping.
    // NOTE(review): String.format uses the default locale; confirm feature names always
    // render with ASCII digits so they match the names produced by the data source.
    Map<String, Integer> featureMapping = new HashMap<>();
    for (int i = 0; i < featureCount; i++) {
        featureMapping.put(String.format("%03d", i), featureCount - 1 - i);
    }

    // Each label's text is its digit, which doubles as the model's output index.
    Map<Label, Integer> outputMapping = new HashMap<>();
    for (Label digit : dataset.getOutputInfo().getDomain()) {
        outputMapping.put(digit, Integer.parseInt(digit.getLabel()));
    }

    XGBoostClassificationConverter converter = new XGBoostClassificationConverter();
    Path modelPath = Paths.get(TestXGBoostExternalModel.class.getResource("/org/tribuo/classification/xgboost/xgb_mnist.xgb").toURI());
    XGBoostExternalModel<Label> transposedMNISTXGB =
            XGBoostExternalModel.createXGBoostModel(factory, featureMapping, outputMapping, converter, modelPath);

    // The external model must classify every example correctly and survive serialization.
    LabelEvaluation evaluation = factory.getEvaluator().evaluate(transposedMNISTXGB, transposedMNIST);
    assertEquals(1.0, evaluation.accuracy(), 1e-6);
    assertEquals(0.0, evaluation.balancedErrorRate(), 1e-6);
    Helpers.testModelSerialization(transposedMNISTXGB, Label.class);
}
Also used : Path(java.nio.file.Path) HashMap(java.util.HashMap) Label(org.tribuo.classification.Label) URL(java.net.URL) LabelFactory(org.tribuo.classification.LabelFactory) LabelEvaluation(org.tribuo.classification.evaluation.LabelEvaluation) LibSVMDataSource(org.tribuo.datasource.LibSVMDataSource) MutableDataset(org.tribuo.MutableDataset) Test(org.junit.jupiter.api.Test)

Example 4 with LabelEvaluation

Use of org.tribuo.classification.evaluation.LabelEvaluation in project Tribuo by Oracle.

In class TestMNB, method testSingleClassTraining.

@Test
public void testSingleClassTraining() {
    Pair<Dataset<Label>, Dataset<Label>> splits = LabelledDataGenerator.denseTrainTest(1.0);

    // Restrict training to the "Foo" examples only.
    DatasetView<Label> fooView = DatasetView.createView(
            splits.getA(),
            (Example<Label> example) -> example.getOutput().getLabel().equals("Foo"),
            "Foo selector");
    Model<Label> fooModel = t.train(fooView);

    // Evaluate on the complete test split, which still contains every label.
    LabelEvaluation eval = (LabelEvaluation) fooView.getOutputFactory()
            .getEvaluator()
            .evaluate(fooModel, splits.getB());

    // Untrained labels report per-label accuracy of 0.0; "Foo" reports recall of 1.0.
    String[] untrainedLabels = {"Bar", "Baz", "Quux"};
    for (String name : untrainedLabels) {
        assertEquals(0.0, eval.accuracy(new Label(name)));
    }
    assertEquals(1.0, eval.recall(new Label("Foo")));
}
Also used : LabelEvaluation(org.tribuo.classification.evaluation.LabelEvaluation) Dataset(org.tribuo.Dataset) Example(org.tribuo.Example) Label(org.tribuo.classification.Label) Test(org.junit.jupiter.api.Test)

Example 5 with LabelEvaluation

Use of org.tribuo.classification.evaluation.LabelEvaluation in project Tribuo by Oracle.

In class TestFMClassification, method testFMClassification.

/**
 * Trains a model with the shared trainer {@code t} and sanity-checks it.
 * The model is trained on the first dataset of {@code p}, run through a
 * {@code LabelEvaluator} on the second dataset, and checked for non-empty
 * top-feature output.
 *
 * @param p a pair whose first element is used for training and whose second is used for evaluation.
 * @return the trained model, so callers can run further checks on it.
 */
public static Model<Label> testFMClassification(Pair<Dataset<Label>, Dataset<Label>> p) {
    Model<Label> m = t.train(p.getA());
    // Run evaluation on the held-out split so the evaluation path is exercised.
    // The returned LabelEvaluation is not inspected here, so it is not bound to a local.
    new LabelEvaluator().evaluate(m, p.getB());
    // Top-feature extraction must return something for a bounded request (3) ...
    Map<String, List<Pair<String, Double>>> features = m.getTopFeatures(3);
    Assertions.assertNotNull(features);
    Assertions.assertFalse(features.isEmpty());
    // ... and for -1 (presumably "all features" — confirm against Model#getTopFeatures).
    features = m.getTopFeatures(-1);
    Assertions.assertNotNull(features);
    Assertions.assertFalse(features.isEmpty());
    return m;
}
Also used : LabelEvaluation(org.tribuo.classification.evaluation.LabelEvaluation) Label(org.tribuo.classification.Label) List(java.util.List) LabelEvaluator(org.tribuo.classification.evaluation.LabelEvaluator)

Aggregations

LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation)38 Label (org.tribuo.classification.Label)35 LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator)24 Dataset (org.tribuo.Dataset)14 List (java.util.List)13 Test (org.junit.jupiter.api.Test)12 MutableDataset (org.tribuo.MutableDataset)12 LabelFactory (org.tribuo.classification.LabelFactory)11 HashMap (java.util.HashMap)9 Path (java.nio.file.Path)7 Example (org.tribuo.Example)7 URL (java.net.URL)6 ImmutableDataset (org.tribuo.ImmutableDataset)6 LibSVMDataSource (org.tribuo.datasource.LibSVMDataSource)6 ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager)4 UsageException (com.oracle.labs.mlrg.olcut.config.UsageException)4 OrtEnvironment (ai.onnxruntime.OrtEnvironment)3 OrtSession (ai.onnxruntime.OrtSession)3 File (java.io.File)3 FileOutputStream (java.io.FileOutputStream)3