Use of com.oracle.labs.mlrg.olcut.config.UsageException in the Tribuo project by Oracle.
Class TrainTest, method main.
/**
 * Entry point for the LibLinear TrainTest CLI.
 * <p>
 * Builds a LibLinear classification trainer from the parsed options, passes it through
 * {@code ensembleOptions.wrapTrainer}, and hands the result to {@code TrainTestHelper.run}.
 * Usage requests print the usage text; invalid arguments print an error message.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    TrainTestOptions options = new TrainTestOptions();
    try (ConfigurationManager manager = new ConfigurationManager(args, options)) {
        Trainer<Label> baseTrainer = options.libLinearOptions.getTrainer();
        Trainer<Label> wrappedTrainer = options.ensembleOptions.wrapTrainer(baseTrainer);
        TrainTestHelper.run(manager, options.general, wrappedTrainer);
    } catch (UsageException usage) {
        System.out.println(usage.getUsage());
    } catch (ArgumentException badArgument) {
        System.out.println("Invalid argument: " + badArgument.getMessage());
    }
}
Use of com.oracle.labs.mlrg.olcut.config.UsageException in the Tribuo project by Oracle.
Class Test, method main.
/**
 * Runs the Test CLI.
 * <p>
 * Loads a model and a test dataset via {@code load}, evaluates the model on the dataset, and prints
 * the evaluation and confusion matrix to standard output. When the model generates probabilities,
 * the unweighted and weighted average AUC are printed as well. If a prediction path is configured,
 * a CSV is written whose rows hold the example's true label followed by one score per label
 * (columns sorted by label name); a label with no score in the prediction is written as {@code NaN}.
 * @param args the command line arguments
 */
public static void main(String[] args) {
    //
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    ConfigurableTestOptions o = new ConfigurableTestOptions();
    // try-with-resources closes the configuration manager, as the TrainTest entry points do.
    try (ConfigurationManager cm = new ConfigurationManager(args, o)) {
        if (o.modelPath == null || o.testingPath == null) {
            logger.info(cm.usage());
            System.exit(1);
        }
        Pair<Model<Label>, Dataset<Label>> loaded;
        try {
            loaded = load(o);
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Failed to load model/data", e);
            System.exit(1);
            // Never reached; lets the compiler prove 'loaded' is assigned instead of relying on null.
            return;
        }
        Model<Label> model = loaded.getA();
        Dataset<Label> test = loaded.getB();
        logger.info("Model is " + model.toString());
        logger.info("Labels are " + model.getOutputIDInfo().toReadableString());
        LabelEvaluator labelEvaluator = new LabelEvaluator();
        final long testStart = System.currentTimeMillis();
        List<Prediction<Label>> predictions = model.predict(test);
        LabelEvaluation evaluation = labelEvaluator.evaluate(model, predictions, test.getProvenance());
        final long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
        System.out.println(evaluation.toString());
        System.out.println(evaluation.getConfusionMatrix().toString());
        if (model.generatesProbabilities()) {
            System.out.println("Average AUC = " + evaluation.averageAUCROC(false));
            System.out.println("Average weighted AUC = " + evaluation.averageAUCROC(true));
        }
        if (o.predictionPath != null) {
            try (BufferedWriter wrt = Files.newBufferedWriter(o.predictionPath)) {
                List<String> labels = model.getOutputIDInfo().getDomain().stream().map(Label::getLabel).sorted().collect(Collectors.toList());
                wrt.write("Label,");
                wrt.write(String.join(",", labels));
                wrt.newLine();
                for (Prediction<Label> pred : predictions) {
                    Example<Label> ex = pred.getExample();
                    wrt.write(ex.getOutput().getLabel() + ",");
                    wrt.write(labels.stream().map(l -> {
                        Label score = pred.getOutputScores().get(l);
                        // The score map may be empty for models that don't generate scores;
                        // write NaN rather than throwing an NPE halfway through the file.
                        return score == null ? "NaN" : Double.toString(score.getScore());
                    }).collect(Collectors.joining(",")));
                    wrt.newLine();
                }
                wrt.flush();
            } catch (IOException e) {
                logger.log(Level.SEVERE, "Error writing predictions", e);
            }
        }
    } catch (UsageException e) {
        logger.info(e.getMessage());
    }
}
Use of com.oracle.labs.mlrg.olcut.config.UsageException in the Tribuo project by Oracle.
Class TrainTest, method main.
/**
 * Entry point for the generic classification TrainTest CLI.
 * <p>
 * Obtains the trainer selected by {@code trainerOptions} and delegates training and evaluation
 * to {@code TrainTestHelper.run}. If a usage exception is raised, its message is logged.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    AllClassificationOptions options = new AllClassificationOptions();
    try (ConfigurationManager manager = new ConfigurationManager(args, options)) {
        Trainer<Label> selectedTrainer = options.trainerOptions.getTrainer();
        TrainTestHelper.run(manager, options.general, selectedTrainer);
    } catch (UsageException usage) {
        logger.info(usage.getMessage());
    }
}
Use of com.oracle.labs.mlrg.olcut.config.UsageException in the Tribuo project by Oracle.
Class TrainTest, method main.
/**
 * Runs a TrainTest CLI.
 * <p>
 * Loads training and test datasets, trains a {@code LibLinearRegressionTrainer} configured from
 * the options, evaluates the resulting model on the test data, prints the evaluation to standard
 * output, and calls {@code saveModel} when an output path is configured. If the training or
 * testing path is missing, the usage text is logged and the method returns.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    //
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    LibLinearOptions o = new LibLinearOptions();
    // try-with-resources closes the configuration manager on every exit path,
    // matching the other TrainTest entry points.
    try (ConfigurationManager cm = new ConfigurationManager(args, o)) {
        if (o.general.trainingPath == null || o.general.testingPath == null) {
            logger.info(cm.usage());
            return;
        }
        RegressionFactory factory = new RegressionFactory();
        Pair<Dataset<Regressor>, Dataset<Regressor>> data = o.general.load(factory);
        Dataset<Regressor> train = data.getA();
        Dataset<Regressor> test = data.getB();
        Trainer<Regressor> trainer = new LibLinearRegressionTrainer(new LinearRegressionType(o.algorithm), o.cost, o.maxIterations, o.terminationCriterion, o.epsilon);
        logger.info("Training using " + trainer.toString());
        final long trainStart = System.currentTimeMillis();
        Model<Regressor> model = trainer.train(train);
        final long trainStop = System.currentTimeMillis();
        logger.info("Finished training regressor " + Util.formatDuration(trainStart, trainStop));
        final long testStart = System.currentTimeMillis();
        RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model, test);
        final long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
        System.out.println(evaluation.toString());
        if (o.general.outputPath != null) {
            o.general.saveModel(model);
        }
    } catch (UsageException e) {
        logger.info(e.getMessage());
    }
}
Use of com.oracle.labs.mlrg.olcut.config.UsageException in the Tribuo project by Oracle.
Class TrainTest, method main.
/**
 * Runs a TrainTest CLI for {@code FMRegressionTrainer}.
 * <p>
 * Selects the regression objective from {@code o.loss}, builds the gradient optimiser, loads
 * training and test datasets, trains an {@code FMRegressionTrainer}, evaluates the model on the
 * test data, prints the evaluation to standard output, and calls {@code saveModel} when an output
 * path is configured. If the training or testing path is missing, or the loss is not recognised,
 * the usage text is logged and the method returns.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    //
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    FMRegressionOptions o = new FMRegressionOptions();
    // try-with-resources closes the configuration manager on every exit path,
    // matching the other TrainTest entry points.
    try (ConfigurationManager cm = new ConfigurationManager(args, o)) {
        if (o.general.trainingPath == null || o.general.testingPath == null) {
            logger.info(cm.usage());
            return;
        }
        logger.info("Configuring gradient optimiser");
        // No null initialiser needed: every case assigns obj and the default branch returns.
        RegressionObjective obj;
        switch (o.loss) {
            case ABSOLUTE:
                obj = new AbsoluteLoss();
                break;
            case SQUARED:
                obj = new SquaredLoss();
                break;
            case HUBER:
                obj = new Huber();
                break;
            default:
                logger.warning("Unknown objective function " + o.loss);
                logger.info(cm.usage());
                return;
        }
        StochasticGradientOptimiser grad = o.gradientOptions.getOptimiser();
        logger.info(String.format("Set logging interval to %d", o.loggingInterval));
        RegressionFactory factory = new RegressionFactory();
        Pair<Dataset<Regressor>, Dataset<Regressor>> data = o.general.load(factory);
        Dataset<Regressor> train = data.getA();
        Dataset<Regressor> test = data.getB();
        logger.info("Feature domain - " + train.getFeatureIDMap());
        Trainer<Regressor> trainer = new FMRegressionTrainer(obj, grad, o.epochs, o.loggingInterval, o.minibatchSize, o.general.seed, o.factorSize, o.variance, o.standardise);
        logger.info("Training using " + trainer.toString());
        final long trainStart = System.currentTimeMillis();
        Model<Regressor> model = trainer.train(train);
        final long trainStop = System.currentTimeMillis();
        logger.info("Finished training regressor " + Util.formatDuration(trainStart, trainStop));
        final long testStart = System.currentTimeMillis();
        RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model, test);
        final long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
        System.out.println(evaluation.toString());
        if (o.general.outputPath != null) {
            o.general.saveModel(model);
        }
    } catch (UsageException e) {
        logger.info(e.getMessage());
    }
}
Aggregations