Search in sources:

Example 36 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

Method main of the class TrainTest.

/**
 * Command-line entry point: trains a LibLinear classifier and evaluates it.
 * <p>
 * The base trainer comes from {@code libLinearOptions} and is passed through
 * {@code ensembleOptions.wrapTrainer} before being handed to
 * {@code TrainTestHelper.run}. A usage request prints the usage text to stdout;
 * a bad argument prints a one-line diagnostic to stdout.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    final TrainTestOptions options = new TrainTestOptions();
    try (ConfigurationManager manager = new ConfigurationManager(args, options)) {
        final Trainer<Label> baseTrainer = options.libLinearOptions.getTrainer();
        final Trainer<Label> wrappedTrainer = options.ensembleOptions.wrapTrainer(baseTrainer);
        TrainTestHelper.run(manager, options.general, wrappedTrainer);
    } catch (UsageException usageEx) {
        System.out.println(usageEx.getUsage());
    } catch (ArgumentException argEx) {
        System.out.println("Invalid argument: " + argEx.getMessage());
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) Label(org.tribuo.classification.Label) ArgumentException(com.oracle.labs.mlrg.olcut.config.ArgumentException) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Example 37 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

Method main of the class Test.

/**
 * Runs the Test CLI: loads a classification model and a test dataset, then evaluates it.
 * <p>
 * The evaluation summary and confusion matrix are printed to stdout. If the model
 * generates probabilities, the average and weighted-average AUC are printed too.
 * When {@code predictionPath} is set, a CSV is written with a header of
 * {@code Label,} followed by the sorted label names. It then has one row per
 * prediction: the example's true label, then the predicted score for each label
 * in header order.
 * <p>
 * The process exits with status 1 if {@code modelPath} or {@code testingPath} is
 * missing, or if loading the model/data fails.
 * @param args the command line arguments
 */
public static void main(String[] args) {
    //
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    ConfigurableTestOptions o = new ConfigurableTestOptions();
    ConfigurationManager parsed;
    try {
        parsed = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }
    // ConfigurationManager is AutoCloseable; release it on every normal exit path.
    try (ConfigurationManager cm = parsed) {
        if (o.modelPath == null || o.testingPath == null) {
            logger.info(cm.usage());
            System.exit(1);
        }
        Pair<Model<Label>, Dataset<Label>> loaded;
        try {
            loaded = load(o);
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Failed to load model/data", e);
            System.exit(1);
            return; // not reached; keeps 'loaded' definitely assigned for the compiler
        }
        Model<Label> model = loaded.getA();
        Dataset<Label> test = loaded.getB();
        logger.info("Model is " + model.toString());
        logger.info("Labels are " + model.getOutputIDInfo().toReadableString());
        LabelEvaluator labelEvaluator = new LabelEvaluator();
        final long testStart = System.currentTimeMillis();
        List<Prediction<Label>> predictions = model.predict(test);
        LabelEvaluation evaluation = labelEvaluator.evaluate(model, predictions, test.getProvenance());
        final long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
        System.out.println(evaluation.toString());
        System.out.println(evaluation.getConfusionMatrix().toString());
        if (model.generatesProbabilities()) {
            System.out.println("Average AUC = " + evaluation.averageAUCROC(false));
            System.out.println("Average weighted AUC = " + evaluation.averageAUCROC(true));
        }
        if (o.predictionPath != null) {
            try (BufferedWriter wrt = Files.newBufferedWriter(o.predictionPath)) {
                // Sort label names so the column order is deterministic.
                List<String> labels = model.getOutputIDInfo().getDomain().stream().map(Label::getLabel).sorted().collect(Collectors.toList());
                wrt.write("Label,");
                wrt.write(String.join(",", labels));
                wrt.newLine();
                for (Prediction<Label> pred : predictions) {
                    Example<Label> ex = pred.getExample();
                    wrt.write(ex.getOutput().getLabel() + ",");
                    wrt.write(labels.stream().map(l -> Double.toString(pred.getOutputScores().get(l).getScore())).collect(Collectors.joining(",")));
                    wrt.newLine();
                }
                wrt.flush();
            } catch (IOException e) {
                logger.log(Level.SEVERE, "Error writing predictions", e);
            }
        }
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) ImmutableDataset(org.tribuo.ImmutableDataset) Dataset(org.tribuo.Dataset) Prediction(org.tribuo.Prediction) Label(org.tribuo.classification.Label) IOException(java.io.IOException) LabelEvaluator(org.tribuo.classification.evaluation.LabelEvaluator) BufferedWriter(java.io.BufferedWriter) LabelEvaluation(org.tribuo.classification.evaluation.LabelEvaluation) Model(org.tribuo.Model) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Example 38 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

Method main of the class TrainTest.

/**
 * Command-line entry point: builds a classification trainer and runs train/test.
 * <p>
 * The trainer is obtained from {@code trainerOptions} and passed, together with the
 * general options, to {@code TrainTestHelper.run}. A {@code UsageException} is
 * logged at INFO and the method returns normally.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    final AllClassificationOptions options = new AllClassificationOptions();
    try (ConfigurationManager manager = new ConfigurationManager(args, options)) {
        final Trainer<Label> selectedTrainer = options.trainerOptions.getTrainer();
        TrainTestHelper.run(manager, options.general, selectedTrainer);
    } catch (UsageException usageEx) {
        logger.info(usageEx.getMessage());
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) Label(org.tribuo.classification.Label) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Example 39 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

Method main of the class TrainTest.

/**
 * Runs a TrainTest CLI for LibLinear regression.
 * <p>
 * Parses {@code LibLinearOptions} from {@code args} and loads the training and
 * testing datasets with a {@code RegressionFactory}. It then trains a
 * {@code LibLinearRegressionTrainer} configured from the options, prints the
 * regression evaluation on the test set to stdout and saves the model when an
 * output path was supplied. If either the training or testing path is missing,
 * the usage text is logged and the method returns.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    //
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    LibLinearOptions o = new LibLinearOptions();
    ConfigurationManager parsed;
    try {
        parsed = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }
    // ConfigurationManager is AutoCloseable; release it on every exit path.
    try (ConfigurationManager cm = parsed) {
        if (o.general.trainingPath == null || o.general.testingPath == null) {
            logger.info(cm.usage());
            return;
        }
        RegressionFactory factory = new RegressionFactory();
        Pair<Dataset<Regressor>, Dataset<Regressor>> data = o.general.load(factory);
        Dataset<Regressor> train = data.getA();
        Dataset<Regressor> test = data.getB();
        Trainer<Regressor> trainer = new LibLinearRegressionTrainer(new LinearRegressionType(o.algorithm), o.cost, o.maxIterations, o.terminationCriterion, o.epsilon);
        logger.info("Training using " + trainer.toString());
        final long trainStart = System.currentTimeMillis();
        Model<Regressor> model = trainer.train(train);
        final long trainStop = System.currentTimeMillis();
        logger.info("Finished training regressor " + Util.formatDuration(trainStart, trainStop));
        final long testStart = System.currentTimeMillis();
        RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model, test);
        final long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
        System.out.println(evaluation.toString());
        if (o.general.outputPath != null) {
            o.general.saveModel(model);
        }
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) RegressionFactory(org.tribuo.regression.RegressionFactory) Dataset(org.tribuo.Dataset) RegressionEvaluation(org.tribuo.regression.evaluation.RegressionEvaluation) Regressor(org.tribuo.regression.Regressor) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Example 40 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

Method main of the class TrainTest.

/**
 * Runs a TrainTest CLI for factorization-machine regression.
 * <p>
 * Parses {@code FMRegressionOptions} from {@code args} and selects the regression
 * objective (absolute, squared or Huber loss) and the gradient optimiser. It then
 * loads the training and testing datasets and trains an {@code FMRegressionTrainer}.
 * The regression evaluation on the test set is printed to stdout, and the model is
 * saved when an output path was supplied. If either data path is missing, or the
 * loss is not recognised, the usage text is logged and the method returns.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    //
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    FMRegressionOptions o = new FMRegressionOptions();
    ConfigurationManager parsed;
    try {
        parsed = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }
    // ConfigurationManager is AutoCloseable; release it on every exit path.
    try (ConfigurationManager cm = parsed) {
        if (o.general.trainingPath == null || o.general.testingPath == null) {
            logger.info(cm.usage());
            return;
        }
        logger.info("Configuring gradient optimiser");
        // Every case assigns obj and the default returns, so obj is definitely assigned below.
        RegressionObjective obj;
        switch(o.loss) {
            case ABSOLUTE:
                obj = new AbsoluteLoss();
                break;
            case SQUARED:
                obj = new SquaredLoss();
                break;
            case HUBER:
                obj = new Huber();
                break;
            default:
                logger.warning("Unknown objective function " + o.loss);
                logger.info(cm.usage());
                return;
        }
        StochasticGradientOptimiser grad = o.gradientOptions.getOptimiser();
        logger.info(String.format("Set logging interval to %d", o.loggingInterval));
        RegressionFactory factory = new RegressionFactory();
        Pair<Dataset<Regressor>, Dataset<Regressor>> data = o.general.load(factory);
        Dataset<Regressor> train = data.getA();
        Dataset<Regressor> test = data.getB();
        logger.info("Feature domain - " + train.getFeatureIDMap());
        Trainer<Regressor> trainer = new FMRegressionTrainer(obj, grad, o.epochs, o.loggingInterval, o.minibatchSize, o.general.seed, o.factorSize, o.variance, o.standardise);
        logger.info("Training using " + trainer.toString());
        final long trainStart = System.currentTimeMillis();
        Model<Regressor> model = trainer.train(train);
        final long trainStop = System.currentTimeMillis();
        logger.info("Finished training regressor " + Util.formatDuration(trainStart, trainStop));
        final long testStart = System.currentTimeMillis();
        RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model, test);
        final long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
        System.out.println(evaluation.toString());
        if (o.general.outputPath != null) {
            o.general.saveModel(model);
        }
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) RegressionObjective(org.tribuo.regression.sgd.RegressionObjective) SquaredLoss(org.tribuo.regression.sgd.objectives.SquaredLoss) RegressionFactory(org.tribuo.regression.RegressionFactory) Dataset(org.tribuo.Dataset) Huber(org.tribuo.regression.sgd.objectives.Huber) AbsoluteLoss(org.tribuo.regression.sgd.objectives.AbsoluteLoss) RegressionEvaluation(org.tribuo.regression.evaluation.RegressionEvaluation) StochasticGradientOptimiser(org.tribuo.math.StochasticGradientOptimiser) Regressor(org.tribuo.regression.Regressor) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Aggregations

ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager)42 UsageException (com.oracle.labs.mlrg.olcut.config.UsageException)32 Label (org.tribuo.classification.Label)16 Dataset (org.tribuo.Dataset)15 IOException (java.io.IOException)8 FileOutputStream (java.io.FileOutputStream)7 ObjectOutputStream (java.io.ObjectOutputStream)7 RegressionFactory (org.tribuo.regression.RegressionFactory)7 Regressor (org.tribuo.regression.Regressor)7 RegressionEvaluation (org.tribuo.regression.evaluation.RegressionEvaluation)7 BufferedWriter (java.io.BufferedWriter)4 File (java.io.File)4 OutputStreamWriter (java.io.OutputStreamWriter)4 MutableDataset (org.tribuo.MutableDataset)4 LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation)4 LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator)4 FileInputStream (java.io.FileInputStream)3 ObjectInputStream (java.io.ObjectInputStream)3 PrintWriter (java.io.PrintWriter)3 ArrayList (java.util.ArrayList)3