Usage of org.tribuo.classification.baseline.DummyClassifierTrainer in the Tribuo project by Oracle.
From the ClassifierChainTest class, method testInvalidChain.
@Test
public void testInvalidChain() {
    // A chain's label ordering can be invalid in several ways: it can have
    // the wrong number of labels, repeat a label, or name a label that does
    // not occur in the training data. Each case must make train() throw.
    Dataset<MultiLabel> train = MultiLabelDataGenerator.generateTrainData();
    DummyClassifierTrainer trainer = DummyClassifierTrainer.createConstantTrainer("MONKEY");

    List<List<String>> badOrderings = Arrays.asList(
            // Too many labels
            Arrays.asList("MONKEY", "PUZZLE", "TREE", "PINE"),
            // Too few labels
            Arrays.asList("MONKEY", "PUZZLE"),
            // Duplicate valid labels
            Arrays.asList("MONKEY", "PUZZLE", "PUZZLE"),
            // Labels not in the training data
            Arrays.asList("MONKEY", "PUZZLE", "PINE"));

    for (List<String> ordering : badOrderings) {
        // Constructing the chain trainer is not expected to fail here;
        // the ordering is only validated when train() is invoked.
        ClassifierChainTrainer chain = new ClassifierChainTrainer(trainer, ordering);
        assertThrows(IllegalArgumentException.class, () -> chain.train(train));
    }
}
Aggregations