Search in sources :

Example 1 with ArgumentException

use of com.oracle.labs.mlrg.olcut.config.ArgumentException in project tribuo by oracle.

In the class TrainTest, method main.

/**
 * Command line entry point that trains and tests a LibSVM classifier.
 * <p>
 * The arguments are parsed into a {@link TrainTestOptions} instance. A trainer is obtained
 * from the LibSVM options and passed through the ensemble options' {@code wrapTrainer}.
 * Training and evaluation are then delegated to {@link TrainTestHelper#run}. When the
 * returned model is a {@link LibSVMClassificationModel}, its support vector count is logged.
 * Usage text and argument errors are printed to standard out instead of being rethrown.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    final TrainTestOptions options = new TrainTestOptions();
    try (ConfigurationManager manager = new ConfigurationManager(args, options)) {
        final Trainer<Label> svmTrainer = options.libsvmOptions.getTrainer();
        final Trainer<Label> effectiveTrainer = options.ensembleOptions.wrapTrainer(svmTrainer);
        final Model<Label> trainedModel = TrainTestHelper.run(manager, options.general, effectiveTrainer);
        if (trainedModel instanceof LibSVMClassificationModel) {
            final LibSVMClassificationModel svmModel = (LibSVMClassificationModel) trainedModel;
            logger.info("Used " + svmModel.getNumberOfSupportVectors() + " support vectors");
        }
    } catch (UsageException usage) {
        // Print the usage text carried by the exception.
        System.out.println(usage.getUsage());
    } catch (ArgumentException badArgument) {
        System.out.println("Invalid argument: " + badArgument.getMessage());
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) Label(org.tribuo.classification.Label) ArgumentException(com.oracle.labs.mlrg.olcut.config.ArgumentException) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Example 2 with ArgumentException

use of com.oracle.labs.mlrg.olcut.config.ArgumentException in project tribuo by oracle.

In the class TrainTestHelper, method run.

/**
 * Trains a model on the configured training data and evaluates it on the configured test data.
 * <p>
 * Timing information is written to this class's logger. The evaluation summary and the
 * confusion matrix are printed to standard out. If the model generates probabilities, the
 * average and weighted-average AUC are also logged. If an output path is configured, the
 * trained model is saved with {@link DataOptions#saveModel}.
 * @param cm The configuration manager which knows the arguments; its usage text is logged
 *           when the data paths are missing.
 * @param dataOptions The data options which specify the training and test data.
 * @param trainer The trainer to use.
 * @return The trained model.
 * @throws IOException If the data failed to load.
 * @throws ArgumentException If either the training path or the testing path is not set.
 */
public static Model<Label> run(ConfigurationManager cm, DataOptions dataOptions, Trainer<Label> trainer) throws IOException {
    LabsLogFormatter.setAllLogFormatters();
    final boolean pathsMissing = (dataOptions.trainingPath == null) || (dataOptions.testingPath == null);
    if (pathsMissing) {
        logger.info(cm.usage());
        logger.info("Training Path = " + dataOptions.trainingPath + ", Testing Path = " + dataOptions.testingPath);
        throw new ArgumentException("training-file", "test-file", "Must supply both training and testing data.");
    }

    final Pair<Dataset<Label>, Dataset<Label>> splits = dataOptions.load(factory);
    final Dataset<Label> trainData = splits.getA();
    final int featureCount = trainData.getFeatureIDMap().size();
    logger.info("Training data has " + featureCount + " features.");
    final Dataset<Label> testData = splits.getB();
    logger.info("Training using " + trainer.toString());

    // Time the training phase.
    final long trainBegin = System.currentTimeMillis();
    final Model<Label> trained = trainer.train(trainData);
    final long trainEnd = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainBegin, trainEnd));

    // Time the evaluation phase.
    final long evalBegin = System.currentTimeMillis();
    final LabelEvaluation result = factory.getEvaluator().evaluate(trained, testData);
    final long evalEnd = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(evalBegin, evalEnd));

    if (trained.generatesProbabilities()) {
        logger.info("Average AUC = " + result.averageAUCROC(false));
        logger.info("Average weighted AUC = " + result.averageAUCROC(true));
    }
    System.out.println(result.toString());
    System.out.println(result.getConfusionMatrix().toString());

    final boolean shouldSave = dataOptions.outputPath != null;
    if (shouldSave) {
        dataOptions.saveModel(trained);
    }
    return trained;
}
Also used : LabelEvaluation(org.tribuo.classification.evaluation.LabelEvaluation) Dataset(org.tribuo.Dataset) ArgumentException(com.oracle.labs.mlrg.olcut.config.ArgumentException)

Example 3 with ArgumentException

use of com.oracle.labs.mlrg.olcut.config.ArgumentException in project tribuo by oracle.

In the class TrainTest, method main.

/**
 * Command line entry point that trains and tests a LibLinear classifier.
 * <p>
 * The arguments are parsed into a {@link TrainTestOptions} instance. A trainer is obtained
 * from the LibLinear options and passed through the ensemble options' {@code wrapTrainer}.
 * Training and evaluation are then delegated to {@link TrainTestHelper#run}, and the
 * returned model is discarded. Usage text and argument errors are printed to standard out
 * instead of being rethrown.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    final TrainTestOptions options = new TrainTestOptions();
    try (ConfigurationManager manager = new ConfigurationManager(args, options)) {
        final Trainer<Label> linearTrainer = options.libLinearOptions.getTrainer();
        final Trainer<Label> effectiveTrainer = options.ensembleOptions.wrapTrainer(linearTrainer);
        TrainTestHelper.run(manager, options.general, effectiveTrainer);
    } catch (UsageException usage) {
        // Print the usage text carried by the exception.
        System.out.println(usage.getUsage());
    } catch (ArgumentException badArgument) {
        System.out.println("Invalid argument: " + badArgument.getMessage());
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) Label(org.tribuo.classification.Label) ArgumentException(com.oracle.labs.mlrg.olcut.config.ArgumentException) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Example 4 with ArgumentException

use of com.oracle.labs.mlrg.olcut.config.ArgumentException in project tribuo by oracle.

In the class KernelSVMOptions, method getTrainer.

/**
 * Builds a {@link KernelSVMTrainer} from the configured kernel type and hyperparameters.
 * @return a trainer using the selected kernel together with the configured lambda,
 *         epoch count, logging interval and seed.
 * @throws ArgumentException if the configured kernel type is not one of the handled cases.
 */
@Override
public KernelSVMTrainer getTrainer() {
    logger.info("Configuring Kernel SVM Trainer");
    final Kernel selectedKernel = configureKernel();
    logger.info(String.format("Set logging interval to %d", kernelLoggingInterval));
    return new KernelSVMTrainer(selectedKernel, kernelLambda, kernelEpochs, kernelLoggingInterval, kernelSeed);
}

/**
 * Instantiates the kernel selected by {@code kernelKernel} and logs the choice with its parameters.
 * @return the configured kernel.
 * @throws ArgumentException if {@code kernelKernel} is not a handled kernel type.
 */
private Kernel configureKernel() {
    switch (kernelKernel) {
        case LINEAR:
            logger.info("Using a linear kernel");
            return new Linear();
        case POLYNOMIAL:
            logger.info("Using a Polynomial kernel with gamma " + kernelGamma + ", intercept " + kernelIntercept + ", and degree " + kernelDegree);
            return new Polynomial(kernelGamma, kernelIntercept, kernelDegree);
        case RBF:
            logger.info("Using an RBF kernel with gamma " + kernelGamma);
            return new RBF(kernelGamma);
        case SIGMOID:
            logger.info("Using a tanh kernel with gamma " + kernelGamma + ", and intercept " + kernelIntercept);
            return new Sigmoid(kernelGamma, kernelIntercept);
        default:
            logger.warning("Unknown kernel function " + kernelKernel);
            throw new ArgumentException("kernel-kernel", "Unknown kernel function " + kernelKernel);
    }
}
Also used : Polynomial(org.tribuo.math.kernel.Polynomial) RBF(org.tribuo.math.kernel.RBF) ArgumentException(com.oracle.labs.mlrg.olcut.config.ArgumentException) Kernel(org.tribuo.math.kernel.Kernel) Linear(org.tribuo.math.kernel.Linear) Sigmoid(org.tribuo.math.kernel.Sigmoid)

Aggregations

ArgumentException (com.oracle.labs.mlrg.olcut.config.ArgumentException): 4; ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager): 2; UsageException (com.oracle.labs.mlrg.olcut.config.UsageException): 2; Label (org.tribuo.classification.Label): 2; Dataset (org.tribuo.Dataset): 1; LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation): 1; Kernel (org.tribuo.math.kernel.Kernel): 1; Linear (org.tribuo.math.kernel.Linear): 1; Polynomial (org.tribuo.math.kernel.Polynomial): 1; RBF (org.tribuo.math.kernel.RBF): 1; Sigmoid (org.tribuo.math.kernel.Sigmoid): 1