Search in sources :

Example 1 with LabelSequenceEvaluation

use of org.tribuo.classification.sequence.LabelSequenceEvaluation in project tribuo by oracle.

The method main of the class SeqTest.

/**
 * Command line entry point: trains a CRF sequence model, evaluates it on a test set,
 * prints the evaluation and confusion matrix, and optionally serialises the model.
 * <p>
 * The data is either the generated gorilla dataset (when the dataset name is
 * "Gorilla" or "gorilla") or a pair of serialised {@code SequenceDataset}s read from
 * the configured train and test paths.
 *
 * @param args the command line arguments
 * @throws ClassNotFoundException if it failed to load the model.
 * @throws IOException            if there is any error reading the examples.
 */
// The unchecked casts come from deserialising generic datasets.
@SuppressWarnings("unchecked")
public static void main(String[] args) throws ClassNotFoundException, IOException {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    CRFOptions o = new CRFOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        // The usage message is the whole response to a usage request; nothing else to do.
        logger.info(e.getMessage());
        return;
    }
    logger.info("Configuring gradient optimiser");
    StochasticGradientOptimiser grad = o.gradientOptions.getOptimiser();
    logger.info(String.format("Set logging interval to %d", o.loggingInterval));
    SequenceDataset<Label> train;
    SequenceDataset<Label> test;
    switch(o.datasetName) {
        case "Gorilla":
        case "gorilla":
            logger.info("Generating gorilla dataset");
            train = SequenceDataGenerator.generateGorillaDataset(1);
            test = SequenceDataGenerator.generateGorillaDataset(1);
            break;
        default:
            // Any other dataset name falls back to loading serialised datasets,
            // which requires both paths to be supplied.
            if ((o.trainDataset != null) && (o.testDataset != null)) {
                logger.info("Loading training data from " + o.trainDataset);
                try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(o.trainDataset.toFile())));
                    ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(o.testDataset.toFile())))) {
                    train = (SequenceDataset<Label>) ois.readObject();
                    logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                    logger.info("Found " + train.getFeatureIDMap().size() + " features");
                    logger.info("Loading testing data from " + o.testDataset);
                    SequenceDataset<Label> deserTest = (SequenceDataset<Label>) oits.readObject();
                    // Copy the test data using the training feature and output id maps.
                    test = ImmutableSequenceDataset.copyDataset(deserTest, train.getFeatureIDMap(), train.getOutputIDInfo());
                    logger.info(String.format("Loaded %d testing examples", test.size()));
                }
            } else {
                logger.warning("Unknown dataset " + o.datasetName);
                logger.info(cm.usage());
                return;
            }
    }
    // Keep the concrete type while configuring so no downcast is needed for setShuffle.
    CRFTrainer crfTrainer = new CRFTrainer(grad, o.epochs, o.loggingInterval, o.seed);
    crfTrainer.setShuffle(o.shuffle);
    SequenceTrainer<Label> trainer = crfTrainer;
    switch(o.modelHashingAlgorithm) {
        case NONE:
            break;
        case HC:
            trainer = new HashingSequenceTrainer<>(trainer, new HashCodeHasher(o.modelHashingSalt));
            break;
        case SHA1:
            trainer = new HashingSequenceTrainer<>(trainer, new MessageDigestHasher("SHA1", o.modelHashingSalt));
            break;
        case SHA256:
            trainer = new HashingSequenceTrainer<>(trainer, new MessageDigestHasher("SHA-256", o.modelHashingSalt));
            break;
        default:
            // An unrecognised algorithm is only logged; training proceeds with the unwrapped trainer.
            logger.info("Unknown hasher " + o.modelHashingAlgorithm);
    }
    logger.info("Training using " + trainer.toString());
    final long trainStart = System.currentTimeMillis();
    CRFModel model = (CRFModel) trainer.train(train);
    final long trainStop = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));
    if (o.logModel) {
        System.out.println("FeatureMap = " + model.getFeatureIDMap().toString());
        System.out.println("LabelMap = " + model.getOutputIDInfo().toString());
        System.out.println("Features - " + model.generateWeightsString());
    }
    LabelSequenceEvaluator labelEvaluator = new LabelSequenceEvaluator();
    final long testStart = System.currentTimeMillis();
    LabelSequenceEvaluation evaluation = labelEvaluator.evaluate(model, test);
    final long testStop = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
    System.out.println(evaluation.toString());
    System.out.println();
    System.out.println(evaluation.getConfusionMatrix().toString());
    if (o.outputPath != null) {
        // try-with-resources closes both streams even if serialisation throws.
        try (FileOutputStream fout = new FileOutputStream(o.outputPath.toFile());
            ObjectOutputStream oout = new ObjectOutputStream(fout)) {
            oout.writeObject(model);
        }
        logger.info("Serialized model to file: " + o.outputPath);
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) HashCodeHasher(org.tribuo.hash.HashCodeHasher) MessageDigestHasher(org.tribuo.hash.MessageDigestHasher) Label(org.tribuo.classification.Label) ImmutableSequenceDataset(org.tribuo.sequence.ImmutableSequenceDataset) SequenceDataset(org.tribuo.sequence.SequenceDataset) ObjectOutputStream(java.io.ObjectOutputStream) FileInputStream(java.io.FileInputStream) LabelSequenceEvaluator(org.tribuo.classification.sequence.LabelSequenceEvaluator) LabelSequenceEvaluation(org.tribuo.classification.sequence.LabelSequenceEvaluation) BufferedInputStream(java.io.BufferedInputStream) StochasticGradientOptimiser(org.tribuo.math.StochasticGradientOptimiser) FileOutputStream(java.io.FileOutputStream) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) ObjectInputStream(java.io.ObjectInputStream)

Aggregations

ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager)1 UsageException (com.oracle.labs.mlrg.olcut.config.UsageException)1 BufferedInputStream (java.io.BufferedInputStream)1 FileInputStream (java.io.FileInputStream)1 FileOutputStream (java.io.FileOutputStream)1 ObjectInputStream (java.io.ObjectInputStream)1 ObjectOutputStream (java.io.ObjectOutputStream)1 Label (org.tribuo.classification.Label)1 LabelSequenceEvaluation (org.tribuo.classification.sequence.LabelSequenceEvaluation)1 LabelSequenceEvaluator (org.tribuo.classification.sequence.LabelSequenceEvaluator)1 HashCodeHasher (org.tribuo.hash.HashCodeHasher)1 MessageDigestHasher (org.tribuo.hash.MessageDigestHasher)1 StochasticGradientOptimiser (org.tribuo.math.StochasticGradientOptimiser)1 ImmutableSequenceDataset (org.tribuo.sequence.ImmutableSequenceDataset)1 SequenceDataset (org.tribuo.sequence.SequenceDataset)1