Search in sources:

Example 11 with UsageException

use of com.oracle.labs.mlrg.olcut.config.UsageException in project tribuo by oracle.

From the class TrainTest, method main.

/**
 * Entry point for the CART train/test command line tool.
 * <p>
 * Builds the trainer from the CART options and trains and evaluates it via
 * {@code TrainTestHelper.run}. When tree printing is requested, the resulting
 * model's string form is logged. Usage errors from option parsing are logged
 * rather than rethrown.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    TrainTestOptions options = new TrainTestOptions();
    try (ConfigurationManager manager = new ConfigurationManager(args, options)) {
        Trainer<Label> cartTrainer = options.cartOptions.getTrainer();
        Model<Label> trained = TrainTestHelper.run(manager, options.generalOptions, cartTrainer);
        if (!options.cartOptions.cartPrintTree) {
            return;
        }
        logger.info(trained.toString());
    } catch (UsageException usage) {
        logger.info(usage.getMessage());
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) Label(org.tribuo.classification.Label) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Example 12 with UsageException

use of com.oracle.labs.mlrg.olcut.config.UsageException in project tribuo by oracle.

From the class LIMETextCLI, method main.

/**
 * Runs a LIMETextCLI.
 * <p>
 * Parses the options. If a model file was supplied, it is loaded into the driver's
 * shell and the result of the load is logged. The driver's command shell is then
 * started. On a usage error the usage text is printed to standard output.
 * <p>
 * The {@link ConfigurationManager} is held in a try-with-resources block. This
 * releases it once the shell returns, matching how the other CLIs in this project
 * manage it.
 * @param args The CLI arguments.
 */
public static void main(String[] args) {
    LIMETextCLI.LIMETextCLIOptions options = new LIMETextCLI.LIMETextCLIOptions();
    try (ConfigurationManager cm = new ConfigurationManager(args, options, false)) {
        LIMETextCLI driver = new LIMETextCLI();
        if (options.modelFilename != null) {
            logger.log(Level.INFO, driver.loadModel(driver.shell, new File(options.modelFilename)));
        }
        driver.startShell();
    } catch (UsageException e) {
        System.out.println("Usage: " + e.getUsage());
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) File(java.io.File)

Example 13 with UsageException

use of com.oracle.labs.mlrg.olcut.config.UsageException in project tribuo by oracle.

From the class TrainTest, method main.

/**
 * Entry point for the LibSVM classification train/test command line tool.
 * <p>
 * The LibSVM trainer is built from the options, wrapped by the ensemble options,
 * then trained and evaluated via {@code TrainTestHelper.run}. If the result is a
 * {@code LibSVMClassificationModel}, its support-vector count is logged. Usage and
 * argument errors are reported on standard output.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    TrainTestOptions options = new TrainTestOptions();
    try (ConfigurationManager manager = new ConfigurationManager(args, options)) {
        Trainer<Label> baseTrainer = options.libsvmOptions.getTrainer();
        Trainer<Label> finalTrainer = options.ensembleOptions.wrapTrainer(baseTrainer);
        Model<Label> result = TrainTestHelper.run(manager, options.general, finalTrainer);
        if (result instanceof LibSVMClassificationModel) {
            LibSVMClassificationModel svmModel = (LibSVMClassificationModel) result;
            logger.info("Used " + svmModel.getNumberOfSupportVectors() + " support vectors");
        }
    } catch (UsageException usage) {
        System.out.println(usage.getUsage());
    } catch (ArgumentException badArg) {
        System.out.println("Invalid argument: " + badArg.getMessage());
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) Label(org.tribuo.classification.Label) ArgumentException(com.oracle.labs.mlrg.olcut.config.ArgumentException) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Example 14 with UsageException

use of com.oracle.labs.mlrg.olcut.config.UsageException in project tribuo by oracle.

From the class RunAll, method main.

/**
 * Runs the RunAll CLI.
 * <p>
 * Every {@link Trainer} found in the configuration is trained on the training set and
 * evaluated on the test set. For each trainer, the serialized model
 * ({@code <name>.model}) and a text report ({@code <name>.output}) are written into the
 * output directory. The micro-averaged F1 of each trainer is logged at the end.
 * <p>
 * Trainers are keyed by simple class name, so two trainers of the same class overwrite
 * each other's results and files; this is logged.
 * @param args The CLI arguments.
 * @throws IOException If it failed to write a model or its evaluation report.
 */
public static void main(String[] args) throws IOException {
    LabsLogFormatter.setAllLogFormatters();
    RunAllOptions o = new RunAllOptions();
    // try-with-resources so the configuration manager is released, as in the other CLIs.
    try (ConfigurationManager cm = new ConfigurationManager(args, o)) {
        if (o.general.trainingPath == null || o.general.testingPath == null || o.directory == null) {
            logger.info(cm.usage());
            System.exit(1);
        }
        Pair<Dataset<Label>, Dataset<Label>> data;
        try {
            data = o.general.load(new LabelFactory());
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Failed to load data", e);
            System.exit(1);
            // Not reached; makes 'data' definitely assigned for the compiler.
            return;
        }
        Dataset<Label> train = data.getA();
        Dataset<Label> test = data.getB();
        logger.info("Creating directory - " + o.directory.toString());
        // Fail fast: without a usable directory, every model would be trained and then
        // the first write would fail. isDirectory() also rejects a regular file at that path.
        if (!o.directory.isDirectory() && !o.directory.mkdirs()) {
            logger.severe("Failed to create directory - " + o.directory.toString());
            System.exit(1);
        }
        Map<String, Double> performances = new HashMap<>();
        for (Trainer<?> t : cm.lookupAll(Trainer.class)) {
            String name = t.getClass().getSimpleName();
            logger.info("Training model using " + t.toString());
            // configuration system cast.
            @SuppressWarnings("unchecked") Model<Label> curModel = ((Trainer<Label>) t).train(train);
            LabelEvaluator evaluator = new LabelEvaluator();
            LabelEvaluation evaluation = evaluator.evaluate(curModel, test);
            Double old = performances.put(name, evaluation.microAveragedF1());
            if (old != null) {
                logger.info("Found two trainers with the name " + name);
            }
            String outputPath = o.directory.toString() + "/" + name;
            try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(outputPath + ".model"))) {
                oos.writeObject(curModel);
            }
            try (PrintWriter writer = new PrintWriter(new OutputStreamWriter(new FileOutputStream(outputPath + ".output"), StandardCharsets.UTF_8))) {
                writer.println("Model = " + name);
                writer.println("Provenance = " + curModel.toString());
                writer.println();
                ConfusionMatrix<Label> matrix = evaluation.getConfusionMatrix();
                writer.println("ConfusionMatrix:\n" + matrix.toString());
                writer.println();
                writer.println("Evaluation:\n" + evaluation.toString());
            }
        }
        for (Map.Entry<String, Double> e : performances.entrySet()) {
            logger.info("Trainer = " + e.getKey() + ", F1 = " + e.getValue());
        }
    } catch (UsageException e) {
        logger.info(e.getMessage());
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) HashMap(java.util.HashMap) Label(org.tribuo.classification.Label) Trainer(org.tribuo.Trainer) ObjectOutputStream(java.io.ObjectOutputStream) LabelEvaluation(org.tribuo.classification.evaluation.LabelEvaluation) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) PrintWriter(java.io.PrintWriter) Dataset(org.tribuo.Dataset) IOException(java.io.IOException) LabelEvaluator(org.tribuo.classification.evaluation.LabelEvaluator) LabelFactory(org.tribuo.classification.LabelFactory) FileOutputStream(java.io.FileOutputStream) OutputStreamWriter(java.io.OutputStreamWriter) Map(java.util.Map)

Example 15 with UsageException

use of com.oracle.labs.mlrg.olcut.config.UsageException in project tribuo by oracle.

From the class TrainTest, method main.

/**
 * Runs a TrainTest CLI.
 * <p>
 * Trains a LibSVM regression model on the training set using the SVM type, kernel,
 * gamma, coefficient, degree and standardization options. It logs the training time
 * and the support-vector count. The model is then evaluated on the test set, the
 * evaluation is printed to standard output, and the model is saved when an output
 * path is configured.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    LibSVMOptions o = new LibSVMOptions();
    // try-with-resources so the configuration manager is released, as in the other CLIs.
    try (ConfigurationManager cm = new ConfigurationManager(args, o)) {
        if (o.general.trainingPath == null || o.general.testingPath == null) {
            logger.info(cm.usage());
            return;
        }
        RegressionFactory factory = new RegressionFactory();
        Pair<Dataset<Regressor>, Dataset<Regressor>> data = o.general.load(factory);
        Dataset<Regressor> train = data.getA();
        Dataset<Regressor> test = data.getB();
        SVMParameters<Regressor> parameters = new SVMParameters<>(new SVMRegressionType(o.svmType), o.kernelType);
        parameters.setGamma(o.gamma);
        parameters.setCoeff(o.coeff);
        parameters.setDegree(o.degree);
        Trainer<Regressor> trainer = new LibSVMRegressionTrainer(parameters, o.standardize);
        logger.info("Training using " + trainer.toString());
        final long trainStart = System.currentTimeMillis();
        Model<Regressor> model = trainer.train(train);
        final long trainStop = System.currentTimeMillis();
        logger.info("Finished training regressor " + Util.formatDuration(trainStart, trainStop));
        // Guard the downcast, matching the classification CLI, instead of risking a ClassCastException.
        if (model instanceof LibSVMRegressionModel) {
            logger.info("Support vectors - " + ((LibSVMRegressionModel) model).getNumberOfSupportVectors());
        }
        final long testStart = System.currentTimeMillis();
        RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model, test);
        final long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
        System.out.println(evaluation.toString());
        if (o.general.outputPath != null) {
            o.general.saveModel(model);
        }
    } catch (UsageException e) {
        logger.info(e.getMessage());
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) RegressionFactory(org.tribuo.regression.RegressionFactory) Dataset(org.tribuo.Dataset) SVMParameters(org.tribuo.common.libsvm.SVMParameters) RegressionEvaluation(org.tribuo.regression.evaluation.RegressionEvaluation) Regressor(org.tribuo.regression.Regressor) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Aggregations

ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager)32 UsageException (com.oracle.labs.mlrg.olcut.config.UsageException)32 Label (org.tribuo.classification.Label)15 Dataset (org.tribuo.Dataset)14 IOException (java.io.IOException)8 ObjectOutputStream (java.io.ObjectOutputStream)7 RegressionFactory (org.tribuo.regression.RegressionFactory)7 Regressor (org.tribuo.regression.Regressor)7 RegressionEvaluation (org.tribuo.regression.evaluation.RegressionEvaluation)7 FileOutputStream (java.io.FileOutputStream)6 MutableDataset (org.tribuo.MutableDataset)4 LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation)4 LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator)4 BufferedWriter (java.io.BufferedWriter)3 ObjectInputStream (java.io.ObjectInputStream)3 OutputStreamWriter (java.io.OutputStreamWriter)3 Model (org.tribuo.Model)3 LabelFactory (org.tribuo.classification.LabelFactory)3 ArgumentException (com.oracle.labs.mlrg.olcut.config.ArgumentException)2 BufferedInputStream (java.io.BufferedInputStream)2