Search in sources :

Example 1 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

the class CompletelyConfigurableTrainTest method main.

/**
 * Command line entry point. Builds a training and a testing dataset from the
 * configured data sources, trains the configured trainer on the training data,
 * evaluates the resulting model on the testing data, prints the evaluation and,
 * when an output path is set, serializes the model to it.
 *
 * @param args the command line arguments
 * @param <T> The {@link Output} subclass.
 */
@SuppressWarnings("unchecked")
public static <T extends Output<T>> void main(String[] args) {
    // Switch all loggers to the labs log format before anything is logged.
    LabsLogFormatter.setAllLogFormatters();

    ConfigurableTrainTestOptions options = new ConfigurableTrainTestOptions();
    ConfigurationManager manager;
    try {
        manager = new ConfigurationManager(args, options);
    } catch (UsageException usage) {
        // The manager raised a usage exception: log its message and stop.
        logger.info(usage.getMessage());
        return;
    }

    // Both data sources and a trainer are required; otherwise print the usage and exit.
    final boolean sourceMissing = options.trainSource == null || options.testSource == null;
    if (sourceMissing || options.trainer == null) {
        if (!sourceMissing) {
            logger.warning("No trainer supplied");
        }
        logger.info(manager.usage());
        System.exit(1);
    }

    // Training data, optionally wrapped to drop features below the minimum count.
    Dataset<T> trainingData = new MutableDataset<>((DataSource<T>) options.trainSource);
    if (options.minCount > 0) {
        logger.info("Removing features which occur fewer than " + options.minCount + " times.");
        trainingData = new MinimumCardinalityDataset<>(trainingData, options.minCount);
    }
    Dataset<T> testingData = new MutableDataset<>((DataSource<T>) options.testSource);

    // When a transformation map is configured, wrap the trainer so it is applied.
    if (options.transformationMap != null) {
        options.trainer = new TransformTrainer<>(options.trainer, options.transformationMap);
    }

    logger.info("Trainer is " + options.trainer.getProvenance().toString());
    logger.info("Outputs are " + trainingData.getOutputInfo().toReadableString());
    logger.info("Number of features: " + trainingData.getFeatureMap().size());

    // Train, timing the call.
    final long trainingBegan = System.currentTimeMillis();
    Model<T> trainedModel = ((Trainer<T>) options.trainer).train(trainingData);
    final long trainingEnded = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainingBegan, trainingEnded));

    // Evaluate on the testing data, timing the call.
    Evaluator<T, ? extends Evaluation<T>> outputEvaluator = trainingData.getOutputFactory().getEvaluator();
    final long evaluationBegan = System.currentTimeMillis();
    Evaluation<T> result = outputEvaluator.evaluate(trainedModel, testingData);
    final long evaluationEnded = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(evaluationBegan, evaluationEnded));
    System.out.println(result.toString());

    // Nothing to persist when no output path was supplied.
    if (o(options) == null) {
        return;
    }
    try (ObjectOutputStream modelStream = new ObjectOutputStream(new FileOutputStream(options.outputPath.toFile()))) {
        modelStream.writeObject(trainedModel);
        logger.info("Serialized model to file: " + options.outputPath);
    } catch (IOException writeFailure) {
        // A failed write is logged rather than propagated.
        logger.log(Level.SEVERE, "Error writing model", writeFailure);
    }
}

/** Returns the configured output path of the supplied options (may be null). */
private static java.nio.file.Path o(ConfigurableTrainTestOptions options) {
    return options.outputPath;
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) TransformTrainer(org.tribuo.transform.TransformTrainer) Trainer(org.tribuo.Trainer) IOException(java.io.IOException) ObjectOutputStream(java.io.ObjectOutputStream) FileOutputStream(java.io.FileOutputStream) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) MutableDataset(org.tribuo.MutableDataset)

Example 2 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

the class DatasetExplorer method main.

/**
 * Starts an interactive dataset explorer shell, first loading the dataset
 * named by the options when a filename was supplied.
 * @param args CLI arguments.
 */
public static void main(String[] args) {
    DatasetExplorerOptions explorerOptions = new DatasetExplorerOptions();
    // The manager instance itself is not used afterwards; it is constructed with
    // the options object (presumably to populate it from args - confirm).
    new ConfigurationManager(args, explorerOptions, false);

    DatasetExplorer explorer = new DatasetExplorer();
    if (explorerOptions.modelFilename != null) {
        // Load the requested file and log whatever message the load returns.
        File datasetFile = new File(explorerOptions.modelFilename);
        logger.log(Level.INFO, explorer.loadDataset(explorer.shell, datasetFile));
    }
    explorer.startShell();
}
Also used : ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) File(java.io.File)

Example 3 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

the class CoreTokenizerOptionsTest method createTokenizer.

/**
 * Builds a {@link Tokenizer} from command line style arguments.
 * <p>
 * A {@link ConfigurationManager} is created over a fresh
 * {@link CoreTokenizerOptions}, the tokenizer is obtained from those options,
 * and the manager is closed before returning.
 *
 * @param args the command line arguments passed to the configuration manager.
 * @return the tokenizer returned by {@code CoreTokenizerOptions.getTokenizer()}.
 */
private Tokenizer createTokenizer(String[] args) {
    CoreTokenizerOptions options = new CoreTokenizerOptions();
    // try-with-resources guarantees the manager is closed even when
    // getTokenizer() throws; the tokenizer is still obtained before close().
    try (ConfigurationManager cm = new ConfigurationManager(args, options)) {
        return options.getTokenizer();
    }
}
Also used : ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) Tokenizer(org.tribuo.util.tokens.Tokenizer) NonTokenizer(org.tribuo.util.tokens.impl.NonTokenizer) BreakIteratorTokenizer(org.tribuo.util.tokens.impl.BreakIteratorTokenizer) SplitPatternTokenizer(org.tribuo.util.tokens.impl.SplitPatternTokenizer) UniversalTokenizer(org.tribuo.util.tokens.universal.UniversalTokenizer) ShapeTokenizer(org.tribuo.util.tokens.impl.ShapeTokenizer) SplitCharactersTokenizer(org.tribuo.util.tokens.impl.SplitCharactersTokenizer)

Example 4 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

the class SeqTest method main.

/**
 * Command line entry point which trains a CRF sequence model, evaluates it on
 * a test set, prints the evaluation and confusion matrix, and optionally
 * serializes the trained model.
 * <p>
 * The data is either the generated "gorilla" dataset or a pair of serialized
 * {@code SequenceDataset<Label>} files given by the train and test dataset options.
 *
 * @param args the command line arguments
 * @throws ClassNotFoundException if it failed to load the serialized datasets.
 * @throws IOException            if there is any error reading the examples or writing the model.
 */
// deserialising a generic dataset.
@SuppressWarnings("unchecked")
public static void main(String[] args) throws ClassNotFoundException, IOException {
    // 
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    CRFOptions o = new CRFOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        // The manager raised a usage exception: log its message and stop.
        logger.info(e.getMessage());
        return;
    }
    logger.info("Configuring gradient optimiser");
    StochasticGradientOptimiser grad = o.gradientOptions.getOptimiser();
    logger.info(String.format("Set logging interval to %d", o.loggingInterval));
    SequenceDataset<Label> train;
    SequenceDataset<Label> test;
    switch(o.datasetName) {
        case "Gorilla":
        case "gorilla":
            // Generated data: no files are read in this branch.
            logger.info("Generating gorilla dataset");
            train = SequenceDataGenerator.generateGorillaDataset(1);
            test = SequenceDataGenerator.generateGorillaDataset(1);
            break;
        default:
            if ((o.trainDataset != null) && (o.testDataset != null)) {
                logger.info("Loading training data from " + o.trainDataset);
                // NOTE(review): this uses Java deserialization on files named on the
                // command line - only run it against trusted dataset files.
                try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(o.trainDataset.toFile())));
                    ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(o.testDataset.toFile())))) {
                    train = (SequenceDataset<Label>) ois.readObject();
                    logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                    logger.info("Found " + train.getFeatureIDMap().size() + " features");
                    logger.info("Loading testing data from " + o.testDataset);
                    SequenceDataset<Label> deserTest = (SequenceDataset<Label>) oits.readObject();
                    // Copy the test data using the training set's feature and output id maps.
                    test = ImmutableSequenceDataset.copyDataset(deserTest, train.getFeatureIDMap(), train.getOutputIDInfo());
                    logger.info(String.format("Loaded %d testing examples", test.size()));
                }
            } else {
                // Neither a known generated dataset nor a train/test file pair was given.
                logger.warning("Unknown dataset " + o.datasetName);
                logger.info(cm.usage());
                return;
            }
    }
    SequenceTrainer<Label> trainer = new CRFTrainer(grad, o.epochs, o.loggingInterval, o.seed);
    ((CRFTrainer) trainer).setShuffle(o.shuffle);
    // Optionally wrap the CRF trainer in a hashing trainer using the selected hasher.
    switch(o.modelHashingAlgorithm) {
        case NONE:
            break;
        case HC:
            trainer = new HashingSequenceTrainer<>(trainer, new HashCodeHasher(o.modelHashingSalt));
            break;
        case SHA1:
            trainer = new HashingSequenceTrainer<>(trainer, new MessageDigestHasher("SHA1", o.modelHashingSalt));
            break;
        case SHA256:
            trainer = new HashingSequenceTrainer<>(trainer, new MessageDigestHasher("SHA-256", o.modelHashingSalt));
            break;
        default:
            // Unrecognised value: log it and carry on with the unwrapped trainer.
            logger.info("Unknown hasher " + o.modelHashingAlgorithm);
    }
    logger.info("Training using " + trainer.toString());
    final long trainStart = System.currentTimeMillis();
    CRFModel model = (CRFModel) trainer.train(train);
    final long trainStop = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));
    if (o.logModel) {
        System.out.println("FeatureMap = " + model.getFeatureIDMap().toString());
        System.out.println("LabelMap = " + model.getOutputIDInfo().toString());
        System.out.println("Features - " + model.generateWeightsString());
    }
    LabelSequenceEvaluator labelEvaluator = new LabelSequenceEvaluator();
    final long testStart = System.currentTimeMillis();
    LabelSequenceEvaluation evaluation = labelEvaluator.evaluate(model, test);
    final long testStop = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
    System.out.println(evaluation.toString());
    System.out.println();
    System.out.println(evaluation.getConfusionMatrix().toString());
    if (o.outputPath != null) {
        // try-with-resources closes both streams even if constructing the
        // ObjectOutputStream or writing the model throws; any IOException
        // still propagates to the caller as before.
        try (FileOutputStream fout = new FileOutputStream(o.outputPath.toFile());
             ObjectOutputStream oout = new ObjectOutputStream(fout)) {
            oout.writeObject(model);
        }
        logger.info("Serialized model to file: " + o.outputPath);
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) HashCodeHasher(org.tribuo.hash.HashCodeHasher) MessageDigestHasher(org.tribuo.hash.MessageDigestHasher) Label(org.tribuo.classification.Label) ImmutableSequenceDataset(org.tribuo.sequence.ImmutableSequenceDataset) SequenceDataset(org.tribuo.sequence.SequenceDataset) ObjectOutputStream(java.io.ObjectOutputStream) FileInputStream(java.io.FileInputStream) LabelSequenceEvaluator(org.tribuo.classification.sequence.LabelSequenceEvaluator) LabelSequenceEvaluation(org.tribuo.classification.sequence.LabelSequenceEvaluation) BufferedInputStream(java.io.BufferedInputStream) StochasticGradientOptimiser(org.tribuo.math.StochasticGradientOptimiser) FileOutputStream(java.io.FileOutputStream) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) ObjectInputStream(java.io.ObjectInputStream)

Example 5 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

the class TrainTest method main.

/**
 * Command line entry point for the TrainTest CLI: obtains a trainer from the
 * XGBoost options and hands it, together with the general options and the
 * configuration manager, to {@code TrainTestHelper.run}.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    TrainTestOptions options = new TrainTestOptions();
    // The manager is closed automatically when the try block exits.
    try (ConfigurationManager manager = new ConfigurationManager(args, options)) {
        Trainer<Label> xgbTrainer = options.xgboostOptions.getTrainer();
        TrainTestHelper.run(manager, options.general, xgbTrainer);
    } catch (UsageException usage) {
        // A usage exception is logged rather than propagated.
        logger.info(usage.getMessage());
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) Label(org.tribuo.classification.Label) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Aggregations

ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager)42 UsageException (com.oracle.labs.mlrg.olcut.config.UsageException)32 Label (org.tribuo.classification.Label)16 Dataset (org.tribuo.Dataset)15 IOException (java.io.IOException)8 FileOutputStream (java.io.FileOutputStream)7 ObjectOutputStream (java.io.ObjectOutputStream)7 RegressionFactory (org.tribuo.regression.RegressionFactory)7 Regressor (org.tribuo.regression.Regressor)7 RegressionEvaluation (org.tribuo.regression.evaluation.RegressionEvaluation)7 BufferedWriter (java.io.BufferedWriter)4 File (java.io.File)4 OutputStreamWriter (java.io.OutputStreamWriter)4 MutableDataset (org.tribuo.MutableDataset)4 LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation)4 LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator)4 FileInputStream (java.io.FileInputStream)3 ObjectInputStream (java.io.ObjectInputStream)3 PrintWriter (java.io.PrintWriter)3 ArrayList (java.util.ArrayList)3