use of com.oracle.labs.mlrg.olcut.config.UsageException in project tribuo by oracle.
the class CompletelyConfigurableTrainTest method main.
/**
 * Command line entry point which trains a model on one data source, evaluates it
 * on another, prints the evaluation and optionally serializes the trained model.
 * <p>
 * The options are populated by a {@code ConfigurationManager} from the supplied
 * arguments. If the train source, test source or trainer is missing the usage
 * string is logged and the JVM exits with status 1.
 * @param args the command line arguments
 * @param <T> The {@link Output} subclass.
 */
@SuppressWarnings("unchecked")
public static <T extends Output<T>> void main(String[] args) {
    // Install the labs log format on all loggers before anything is logged.
    LabsLogFormatter.setAllLogFormatters();

    ConfigurableTrainTestOptions options = new ConfigurableTrainTestOptions();
    ConfigurationManager manager;
    try {
        manager = new ConfigurationManager(args, options);
    } catch (UsageException usage) {
        // The exception message is reported to the user; nothing else to do.
        logger.info(usage.getMessage());
        return;
    }

    // Validate the mandatory options, sources first, then the trainer.
    boolean sourcesMissing = (options.trainSource == null) || (options.testSource == null);
    if (sourcesMissing) {
        logger.info(manager.usage());
        System.exit(1);
    }
    if (!sourcesMissing && options.trainer == null) {
        logger.warning("No trainer supplied");
        logger.info(manager.usage());
        System.exit(1);
    }

    // Build the training data, optionally wrapping it to drop infrequent features.
    DataSource<T> trainingSource = (DataSource<T>) options.trainSource;
    Dataset<T> train = new MutableDataset<>(trainingSource);
    if (options.minCount > 0) {
        logger.info("Removing features which occur fewer than " + options.minCount + " times.");
        train = new MinimumCardinalityDataset<>(train, options.minCount);
    }

    DataSource<T> testingSource = (DataSource<T>) options.testSource;
    Dataset<T> test = new MutableDataset<>(testingSource);

    // Wrap the configured trainer when a transformation map was supplied.
    if (options.transformationMap != null) {
        options.trainer = new TransformTrainer<>(options.trainer, options.transformationMap);
    }

    logger.info("Trainer is " + options.trainer.getProvenance().toString());
    logger.info("Outputs are " + train.getOutputInfo().toReadableString());
    logger.info("Number of features: " + train.getFeatureMap().size());

    // Train, timing the call.
    Trainer<T> typedTrainer = (Trainer<T>) options.trainer;
    final long trainingBegan = System.currentTimeMillis();
    Model<T> model = typedTrainer.train(train);
    final long trainingEnded = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainingBegan, trainingEnded));

    // Evaluate on the test data with the evaluator supplied by the output factory.
    Evaluator<T, ? extends Evaluation<T>> evaluator = train.getOutputFactory().getEvaluator();
    final long evaluationBegan = System.currentTimeMillis();
    Evaluation<T> evaluation = evaluator.evaluate(model, test);
    final long evaluationEnded = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(evaluationBegan, evaluationEnded));
    System.out.println(evaluation.toString());

    // Nothing left to do unless an output path was requested.
    if (options.outputPath == null) {
        return;
    }
    try (FileOutputStream fileStream = new FileOutputStream(options.outputPath.toFile());
         ObjectOutputStream objectStream = new ObjectOutputStream(fileStream)) {
        objectStream.writeObject(model);
        logger.info("Serialized model to file: " + options.outputPath);
    } catch (IOException ioe) {
        // A failed write is logged rather than propagated.
        logger.log(Level.SEVERE, "Error writing model", ioe);
    }
}
use of com.oracle.labs.mlrg.olcut.config.UsageException in project tribuo by oracle.
the class SeqTest method main.
/**
 * Command line entry point which trains a CRF sequence model, evaluates it on a
 * test set, prints the evaluation and confusion matrix, and optionally serializes
 * the trained model.
 * <p>
 * The data is either the generated "gorilla" dataset or a pair of serialized
 * {@code SequenceDataset}s read from the configured train and test paths. If
 * neither is available the usage string is logged and the method returns.
 * @param args the command line arguments
 * @throws ClassNotFoundException if it failed to load the model.
 * @throws IOException if there is any error reading the examples.
 */
// deserialising a generic dataset.
@SuppressWarnings("unchecked")
public static void main(String[] args) throws ClassNotFoundException, IOException {
    //
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();

    CRFOptions o = new CRFOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }

    logger.info("Configuring gradient optimiser");
    StochasticGradientOptimiser grad = o.gradientOptions.getOptimiser();

    logger.info(String.format("Set logging interval to %d", o.loggingInterval));

    SequenceDataset<Label> train;
    SequenceDataset<Label> test;
    switch (o.datasetName) {
        case "Gorilla":
        case "gorilla":
            logger.info("Generating gorilla dataset");
            train = SequenceDataGenerator.generateGorillaDataset(1);
            test = SequenceDataGenerator.generateGorillaDataset(1);
            break;
        default:
            if ((o.trainDataset != null) && (o.testDataset != null)) {
                logger.info("Loading training data from " + o.trainDataset);
                // NOTE(review): this uses native Java deserialization; only point
                // these options at dataset files from a trusted source.
                try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(o.trainDataset.toFile())));
                     ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(o.testDataset.toFile())))) {
                    train = (SequenceDataset<Label>) ois.readObject();
                    logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                    logger.info("Found " + train.getFeatureIDMap().size() + " features");
                    logger.info("Loading testing data from " + o.testDataset);
                    SequenceDataset<Label> deserTest = (SequenceDataset<Label>) oits.readObject();
                    // Copy the test data using the training set's feature id map and output id info.
                    test = ImmutableSequenceDataset.copyDataset(deserTest, train.getFeatureIDMap(), train.getOutputIDInfo());
                    logger.info(String.format("Loaded %d testing examples", test.size()));
                }
            } else {
                logger.warning("Unknown dataset " + o.datasetName);
                logger.info(cm.usage());
                return;
            }
    }

    SequenceTrainer<Label> trainer = new CRFTrainer(grad, o.epochs, o.loggingInterval, o.seed);
    ((CRFTrainer) trainer).setShuffle(o.shuffle);
    // Optionally wrap the CRF trainer in a hashing trainer using the selected hasher.
    switch (o.modelHashingAlgorithm) {
        case NONE:
            break;
        case HC:
            trainer = new HashingSequenceTrainer<>(trainer, new HashCodeHasher(o.modelHashingSalt));
            break;
        case SHA1:
            trainer = new HashingSequenceTrainer<>(trainer, new MessageDigestHasher("SHA1", o.modelHashingSalt));
            break;
        case SHA256:
            trainer = new HashingSequenceTrainer<>(trainer, new MessageDigestHasher("SHA-256", o.modelHashingSalt));
            break;
        default:
            // An unrecognised value leaves the trainer unwrapped.
            logger.info("Unknown hasher " + o.modelHashingAlgorithm);
    }

    logger.info("Training using " + trainer.toString());
    final long trainStart = System.currentTimeMillis();
    CRFModel model = (CRFModel) trainer.train(train);
    final long trainStop = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));

    if (o.logModel) {
        System.out.println("FeatureMap = " + model.getFeatureIDMap().toString());
        System.out.println("LabelMap = " + model.getOutputIDInfo().toString());
        System.out.println("Features - " + model.generateWeightsString());
    }

    LabelSequenceEvaluator labelEvaluator = new LabelSequenceEvaluator();
    final long testStart = System.currentTimeMillis();
    LabelSequenceEvaluation evaluation = labelEvaluator.evaluate(model, test);
    final long testStop = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
    System.out.println(evaluation.toString());
    System.out.println();
    System.out.println(evaluation.getConfusionMatrix().toString());

    if (o.outputPath != null) {
        // try-with-resources guarantees both streams are closed even when
        // writeObject throws; the IOException still propagates to the caller.
        try (FileOutputStream fout = new FileOutputStream(o.outputPath.toFile());
             ObjectOutputStream oout = new ObjectOutputStream(fout)) {
            oout.writeObject(model);
        }
        logger.info("Serialized model to file: " + o.outputPath);
    }
}
use of com.oracle.labs.mlrg.olcut.config.UsageException in project tribuo by oracle.
the class TrainTest method main.
/**
 * Command line entry point for this TrainTest CLI.
 * <p>
 * Parses the arguments into a {@code TrainTestOptions} via a
 * {@code ConfigurationManager}, obtains the trainer from the
 * {@code xgboostOptions} and hands off to {@code TrainTestHelper.run}.
 * A {@code UsageException} is reported by logging its message at info level.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    final TrainTestOptions options = new TrainTestOptions();
    try (ConfigurationManager manager = new ConfigurationManager(args, options)) {
        final Trainer<Label> xgboostTrainer = options.xgboostOptions.getTrainer();
        TrainTestHelper.run(manager, options.general, xgboostTrainer);
    } catch (UsageException usage) {
        logger.info(usage.getMessage());
    }
}
use of com.oracle.labs.mlrg.olcut.config.UsageException in project tribuo by oracle.
the class TrainTest method main.
/**
 * Command line entry point for this TrainTest CLI.
 * <p>
 * Parses the arguments into a {@code TrainTestOptions} via a
 * {@code ConfigurationManager}, obtains the trainer from the
 * {@code mnbOptions}, passes it through {@code ensembleOptions.wrapTrainer}
 * and hands the result to {@code TrainTestHelper.run}.
 * A {@code UsageException} is reported by logging its message at info level.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    final TrainTestOptions options = new TrainTestOptions();
    try (ConfigurationManager manager = new ConfigurationManager(args, options)) {
        final Trainer<Label> baseTrainer = options.mnbOptions.getTrainer();
        final Trainer<Label> wrappedTrainer = options.ensembleOptions.wrapTrainer(baseTrainer);
        TrainTestHelper.run(manager, options.general, wrappedTrainer);
    } catch (UsageException usage) {
        logger.info(usage.getMessage());
    }
}
use of com.oracle.labs.mlrg.olcut.config.UsageException in project tribuo by oracle.
the class TrainTest method main.
/**
 * Command line entry point for this TrainTest CLI.
 * <p>
 * Parses the arguments into a {@code TrainTestOptions} via a
 * {@code ConfigurationManager}, obtains the trainer from the
 * {@code trainerOptions}, passes it through {@code ensembleOptions.wrapTrainer}
 * and hands the result to {@code TrainTestHelper.run}.
 * A {@code UsageException} is reported by logging its message at info level.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    final TrainTestOptions options = new TrainTestOptions();
    try (ConfigurationManager manager = new ConfigurationManager(args, options)) {
        final Trainer<Label> baseTrainer = options.trainerOptions.getTrainer();
        final Trainer<Label> wrappedTrainer = options.ensembleOptions.wrapTrainer(baseTrainer);
        TrainTestHelper.run(manager, options.general, wrappedTrainer);
    } catch (UsageException usage) {
        logger.info(usage.getMessage());
    }
}
Aggregations