Use of com.oracle.labs.mlrg.olcut.config.ArgumentException in project Tribuo by Oracle.
Class TrainTest, method main.
/**
 * Command line entry point for TrainTest.
 * <p>
 * Builds the trainer from {@code libsvmOptions} and passes it through
 * {@code ensembleOptions.wrapTrainer}. It then delegates training and evaluation to
 * {@code TrainTestHelper.run}. If the returned model is a
 * {@code LibSVMClassificationModel}, the number of support vectors it uses is logged.
 * <p>
 * A {@code UsageException} causes the usage text to be printed to standard out.
 * An {@code ArgumentException} causes a short "Invalid argument" message to be printed instead.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    TrainTestOptions options = new TrainTestOptions();
    try (ConfigurationManager manager = new ConfigurationManager(args, options)) {
        Trainer<Label> baseTrainer = options.libsvmOptions.getTrainer();
        Trainer<Label> wrappedTrainer = options.ensembleOptions.wrapTrainer(baseTrainer);
        Model<Label> trained = TrainTestHelper.run(manager, options.general, wrappedTrainer);
        if (trained instanceof LibSVMClassificationModel) {
            LibSVMClassificationModel svmModel = (LibSVMClassificationModel) trained;
            logger.info("Used " + svmModel.getNumberOfSupportVectors() + " support vectors");
        }
    } catch (UsageException usage) {
        System.out.println(usage.getUsage());
    } catch (ArgumentException badArgument) {
        System.out.println("Invalid argument: " + badArgument.getMessage());
    }
}
Use of com.oracle.labs.mlrg.olcut.config.ArgumentException in project Tribuo by Oracle.
Class TrainTestHelper, method run.
/**
 * Trains a model on the configured training data and evaluates it on the
 * configured test data.
 * <p>
 * Training and evaluation durations are written to this class's logger. The
 * evaluation summary and confusion matrix are printed to standard out. If the
 * model generates probabilities, the average and weighted-average AUC are logged
 * as well. When an output path is set, the trained model is saved via
 * {@code dataOptions.saveModel}.
 * @param cm The configuration manager which knows the arguments.
 * @param dataOptions The data options which specify the training and test data.
 * @param trainer The trainer to use.
 * @return The trained model.
 * @throws IOException If the data failed to load.
 * @throws ArgumentException If either the training path or the testing path is null.
 */
public static Model<Label> run(ConfigurationManager cm, DataOptions dataOptions, Trainer<Label> trainer) throws IOException {
    LabsLogFormatter.setAllLogFormatters();

    // Both datasets are mandatory: log the usage text and the offending paths before failing.
    boolean missingPath = dataOptions.trainingPath == null || dataOptions.testingPath == null;
    if (missingPath) {
        logger.info(cm.usage());
        logger.info("Training Path = " + dataOptions.trainingPath + ", Testing Path = " + dataOptions.testingPath);
        throw new ArgumentException("training-file", "test-file", "Must supply both training and testing data.");
    }

    // Load phase.
    Pair<Dataset<Label>, Dataset<Label>> splits = dataOptions.load(factory);
    Dataset<Label> trainingData = splits.getA();
    logger.info("Training data has " + trainingData.getFeatureIDMap().size() + " features.");
    Dataset<Label> testingData = splits.getB();

    // Training phase, timed with wall-clock milliseconds.
    logger.info("Training using " + trainer.toString());
    final long trainingBegan = System.currentTimeMillis();
    Model<Label> trainedModel = trainer.train(trainingData);
    final long trainingEnded = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainingBegan, trainingEnded));

    // Evaluation phase, timed the same way.
    final long evaluationBegan = System.currentTimeMillis();
    LabelEvaluation labelEvaluation = factory.getEvaluator().evaluate(trainedModel, testingData);
    final long evaluationEnded = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(evaluationBegan, evaluationEnded));

    // AUC is only reported for models that produce probabilities.
    if (trainedModel.generatesProbabilities()) {
        logger.info("Average AUC = " + labelEvaluation.averageAUCROC(false));
        logger.info("Average weighted AUC = " + labelEvaluation.averageAUCROC(true));
    }

    // Reporting phase.
    System.out.println(labelEvaluation.toString());
    System.out.println(labelEvaluation.getConfusionMatrix().toString());

    boolean shouldSave = dataOptions.outputPath != null;
    if (shouldSave) {
        dataOptions.saveModel(trainedModel);
    }
    return trainedModel;
}
Use of com.oracle.labs.mlrg.olcut.config.ArgumentException in project Tribuo by Oracle.
Class TrainTest, method main.
/**
 * Command line entry point for TrainTest.
 * <p>
 * Takes the trainer produced by {@code libLinearOptions} and passes it through
 * {@code ensembleOptions.wrapTrainer}. It then hands it to
 * {@code TrainTestHelper.run}; the returned model is not used here.
 * <p>
 * A {@code UsageException} causes the usage text to be printed to standard out.
 * An {@code ArgumentException} causes a short "Invalid argument" message to be printed instead.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    TrainTestOptions options = new TrainTestOptions();
    try (ConfigurationManager manager = new ConfigurationManager(args, options)) {
        Trainer<Label> configuredTrainer = options.ensembleOptions.wrapTrainer(options.libLinearOptions.getTrainer());
        TrainTestHelper.run(manager, options.general, configuredTrainer);
    } catch (UsageException usage) {
        System.out.println(usage.getUsage());
    } catch (ArgumentException badArgument) {
        System.out.println("Invalid argument: " + badArgument.getMessage());
    }
}
Use of com.oracle.labs.mlrg.olcut.config.ArgumentException in project Tribuo by Oracle.
Class KernelSVMOptions, method getTrainer.
/**
 * Builds a {@code KernelSVMTrainer} from the configured kernel type and hyperparameters.
 * <p>
 * The chosen kernel is one of linear, polynomial, RBF or sigmoid. Its configuration
 * is logged before the trainer is created with the configured lambda, epoch count,
 * logging interval and seed.
 * @return a trainer using the selected kernel.
 * @throws ArgumentException if {@code kernelKernel} is not one of the handled kernel types.
 */
@Override
public KernelSVMTrainer getTrainer() {
    logger.info("Configuring Kernel SVM Trainer");
    // Assigned exactly once in every non-throwing branch, so no placeholder initialiser is needed.
    final Kernel selectedKernel;
    switch (kernelKernel) {
        case LINEAR: {
            logger.info("Using a linear kernel");
            selectedKernel = new Linear();
            break;
        }
        case POLYNOMIAL: {
            logger.info("Using a Polynomial kernel with gamma " + kernelGamma + ", intercept " + kernelIntercept + ", and degree " + kernelDegree);
            selectedKernel = new Polynomial(kernelGamma, kernelIntercept, kernelDegree);
            break;
        }
        case RBF: {
            logger.info("Using an RBF kernel with gamma " + kernelGamma);
            selectedKernel = new RBF(kernelGamma);
            break;
        }
        case SIGMOID: {
            logger.info("Using a tanh kernel with gamma " + kernelGamma + ", and intercept " + kernelIntercept);
            selectedKernel = new Sigmoid(kernelGamma, kernelIntercept);
            break;
        }
        default: {
            logger.warning("Unknown kernel function " + kernelKernel);
            throw new ArgumentException("kernel-kernel", "Unknown kernel function " + kernelKernel);
        }
    }
    logger.info(String.format("Set logging interval to %d", kernelLoggingInterval));
    return new KernelSVMTrainer(selectedKernel,
                                kernelLambda,
                                kernelEpochs,
                                kernelLoggingInterval,
                                kernelSeed);
}
Aggregations