
Example 1 with DummyClassifierTrainer

Use of org.tribuo.classification.baseline.DummyClassifierTrainer in the project tribuo by oracle.

From the class ClassifierChainTest, method testInvalidChain.

@Test
public void testInvalidChain() {
    // Chains can be invalid in several ways
    // incorrect number of labels, duplicate labels, or labels not in the training data
    // Generate data
    Dataset<MultiLabel> train = MultiLabelDataGenerator.generateTrainData();
    DummyClassifierTrainer trainer = DummyClassifierTrainer.createConstantTrainer("MONKEY");
    List<String> labelOrder;
    // Too many labels
    labelOrder = Arrays.asList("MONKEY", "PUZZLE", "TREE", "PINE");
    ClassifierChainTrainer tooMany = new ClassifierChainTrainer(trainer, labelOrder);
    assertThrows(IllegalArgumentException.class, () -> tooMany.train(train));
    // Too few labels
    labelOrder = Arrays.asList("MONKEY", "PUZZLE");
    ClassifierChainTrainer tooFew = new ClassifierChainTrainer(trainer, labelOrder);
    assertThrows(IllegalArgumentException.class, () -> tooFew.train(train));
    // Duplicate valid labels
    labelOrder = Arrays.asList("MONKEY", "PUZZLE", "PUZZLE");
    ClassifierChainTrainer duplicate = new ClassifierChainTrainer(trainer, labelOrder);
    assertThrows(IllegalArgumentException.class, () -> duplicate.train(train));
    // Labels not in the training data
    labelOrder = Arrays.asList("MONKEY", "PUZZLE", "PINE");
    ClassifierChainTrainer invalidLabels = new ClassifierChainTrainer(trainer, labelOrder);
    assertThrows(IllegalArgumentException.class, () -> invalidLabels.train(train));
}
Also used: MultiLabel (org.tribuo.multilabel.MultiLabel), DummyClassifierTrainer (org.tribuo.classification.baseline.DummyClassifierTrainer), Test (org.junit.jupiter.api.Test)
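
The test above exercises three ways a label order can be invalid: the wrong number of labels, a duplicated label, and a label that does not occur in the training data. The following is a minimal sketch of that kind of check in plain Java. It is not Tribuo's implementation: the class LabelOrderCheck and the methods validateLabelOrder and isValid are hypothetical names introduced for illustration, and the label set MONKEY, PUZZLE, TREE is only inferred from the test's comments, not taken from MultiLabelDataGenerator.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical illustration only; Tribuo's ClassifierChainTrainer may validate differently.
class LabelOrderCheck {

    // Throws IllegalArgumentException if labelOrder is not a permutation of knownLabels.
    static void validateLabelOrder(List<String> labelOrder, Set<String> knownLabels) {
        // Too many or too few labels.
        if (labelOrder.size() != knownLabels.size()) {
            throw new IllegalArgumentException("Expected " + knownLabels.size()
                    + " labels, found " + labelOrder.size());
        }
        Set<String> seen = new HashSet<>();
        for (String label : labelOrder) {
            // Label that is not in the training data.
            if (!knownLabels.contains(label)) {
                throw new IllegalArgumentException("Unknown label: " + label);
            }
            // Duplicate label (Set.add returns false if the element was already present).
            if (!seen.add(label)) {
                throw new IllegalArgumentException("Duplicate label: " + label);
            }
        }
    }

    // Convenience wrapper that reports validity as a boolean instead of throwing.
    static boolean isValid(List<String> labelOrder, Set<String> knownLabels) {
        try {
            validateLabelOrder(labelOrder, knownLabels);
            return true;
        } catch (IllegalArgumentException e) {
            return false;
        }
    }

    public static void main(String[] args) {
        // Assumed label set, inferred from the test cases above.
        Set<String> known = new HashSet<>(Arrays.asList("MONKEY", "PUZZLE", "TREE"));
        System.out.println(isValid(Arrays.asList("TREE", "MONKEY", "PUZZLE"), known));         // valid permutation
        System.out.println(isValid(Arrays.asList("MONKEY", "PUZZLE", "TREE", "PINE"), known)); // too many
        System.out.println(isValid(Arrays.asList("MONKEY", "PUZZLE"), known));                 // too few
        System.out.println(isValid(Arrays.asList("MONKEY", "PUZZLE", "PUZZLE"), known));       // duplicate
        System.out.println(isValid(Arrays.asList("MONKEY", "PUZZLE", "PINE"), known));         // unknown label
    }
}
```

In this sketch the size check runs first, so the duplicate and unknown-label checks only need to handle orders of the right length. In the real test each invalid order is rejected when train is called, not when the trainer is constructed, which is why every assertThrows wraps the train call.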

Aggregations

Test (org.junit.jupiter.api.Test): 1
DummyClassifierTrainer (org.tribuo.classification.baseline.DummyClassifierTrainer): 1
MultiLabel (org.tribuo.multilabel.MultiLabel): 1