Use of com.oracle.labs.mlrg.olcut.config.UsageException in the Tribuo project by Oracle.
The class SeqTrainTest, method main.
/**
 * Trains a sequence classifier and evaluates it on a held-out sequence dataset.
 * <p>
 * When the dataset name is "Gorilla" (or "gorilla") a synthetic dataset is generated for
 * both training and testing; otherwise a serialized train and test {@link SequenceDataset}
 * are read from the configured paths. The evaluation and its confusion matrix are printed
 * to standard out, and the trained model is serialized if an output path was supplied.
 * @param args the command line arguments
 * @throws ClassNotFoundException if it failed to load the model.
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws ClassNotFoundException, IOException {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();

    SeqTrainTestOptions o = new SeqTrainTestOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }

    SequenceDataset<Label> train;
    SequenceDataset<Label> test;
    switch (o.datasetName) {
        case "Gorilla":
        case "gorilla":
            logger.info("Generating gorilla dataset");
            train = SequenceDataGenerator.generateGorillaDataset(1);
            test = SequenceDataGenerator.generateGorillaDataset(1);
            break;
        default:
            // Anything other than the built-in dataset needs both serialized datasets.
            if (o.trainDataset == null || o.testDataset == null) {
                logger.warning("Unknown dataset " + o.datasetName);
                logger.info(cm.usage());
                return;
            }
            logger.info("Loading training data from " + o.trainDataset);
            try (ObjectInputStream trainIn = new ObjectInputStream(new BufferedInputStream(new FileInputStream(o.trainDataset.toFile())));
                 ObjectInputStream testIn = new ObjectInputStream(new BufferedInputStream(new FileInputStream(o.testDataset.toFile())))) {
                // The dataset is serialized generically, so this cast cannot be checked.
                @SuppressWarnings("unchecked")
                SequenceDataset<Label> loadedTrain = (SequenceDataset<Label>) trainIn.readObject();
                train = loadedTrain;
                logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                logger.info("Found " + train.getFeatureIDMap().size() + " features");

                logger.info("Loading testing data from " + o.testDataset);
                // Same unchecked generic cast as for the training data.
                @SuppressWarnings("unchecked")
                SequenceDataset<Label> loadedTest = (SequenceDataset<Label>) testIn.readObject();
                test = loadedTest;
                logger.info(String.format("Loaded %d testing examples", test.size()));
            }
            break;
    }

    logger.info("Training using " + o.trainer.toString());
    final long trainStart = System.currentTimeMillis();
    SequenceModel<Label> model = o.trainer.train(train);
    final long trainStop = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));

    LabelSequenceEvaluator evaluator = new LabelSequenceEvaluator();
    final long testStart = System.currentTimeMillis();
    LabelSequenceEvaluation evaluation = evaluator.evaluate(model, test);
    final long testStop = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));

    System.out.println(evaluation.toString());
    System.out.println();
    System.out.println(evaluation.getConfusionMatrix().toString());

    if (o.outputPath == null) {
        return;
    }
    try (ObjectOutputStream modelOut = new ObjectOutputStream(new FileOutputStream(o.outputPath.toFile()))) {
        modelOut.writeObject(model);
        logger.info("Serialized model to file: " + o.outputPath);
    }
}
Use of com.oracle.labs.mlrg.olcut.config.UsageException in the Tribuo project by Oracle.
The class ConfigurableTrainTest, method main.
/**
 * Loads a train/test pair of classification datasets, trains the configured trainer and
 * evaluates the resulting model on the test data.
 * <p>
 * If label weights are supplied they are applied either through {@link WeightedLabels} or,
 * for {@link WeightedExamples} trainers, by weighting the training examples (which requires
 * the training data to be a {@link MutableDataset}). The evaluation and confusion matrix are
 * printed to standard out. Optionally, per-example label scores are written as CSV, and the
 * model is saved. Exits with status 1 when the configuration is incomplete or invalid, or
 * when the data fails to load.
 * @param args the command line arguments
 */
public static void main(String[] args) {
    //
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    ConfigurableTrainTestOptions o = new ConfigurableTrainTestOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }
    if (o.general.trainingPath == null || o.general.testingPath == null) {
        logger.info(cm.usage());
        System.exit(1);
    }
    Pair<Dataset<Label>, Dataset<Label>> data = null;
    try {
        data = o.general.load(new LabelFactory());
    } catch (IOException e) {
        logger.log(Level.SEVERE, "Failed to load data", e);
        System.exit(1);
    }
    Dataset<Label> train = data.getA();
    Dataset<Label> test = data.getB();
    if (o.trainer == null) {
        logger.warning("No trainer supplied");
        logger.info(cm.usage());
        System.exit(1);
    }
    logger.info("Trainer is " + o.trainer.toString());
    if (o.weights != null) {
        Map<Label, Float> weightsMap = processWeights(o.weights);
        if (o.trainer instanceof WeightedLabels) {
            ((WeightedLabels) o.trainer).setLabelWeights(weightsMap);
            logger.info("Setting label weights using " + weightsMap.toString());
        } else if (o.trainer instanceof WeightedExamples) {
            // Example weights are set on the dataset itself, which needs a MutableDataset.
            // The loader is only typed as returning Dataset, so check before casting rather
            // than failing with a ClassCastException.
            if (!(train instanceof MutableDataset)) {
                logger.warning("Example weighting requires a MutableDataset, but the training data is a " + train.getClass().getName());
                logger.info(cm.usage());
                System.exit(1);
            }
            ((MutableDataset<Label>) train).setWeights(weightsMap);
            logger.info("Setting example weights using " + weightsMap.toString());
        } else {
            logger.warning("The selected trainer does not support weighted training. The chosen trainer is " + o.trainer.toString());
            logger.info(cm.usage());
            System.exit(1);
        }
    }
    logger.info("Labels are " + train.getOutputInfo().toReadableString());
    final long trainStart = System.currentTimeMillis();
    Model<Label> model = o.trainer.train(train);
    final long trainStop = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));
    LabelEvaluator labelEvaluator = new LabelEvaluator();
    final long testStart = System.currentTimeMillis();
    List<Prediction<Label>> predictions = model.predict(test);
    LabelEvaluation labelEvaluation = labelEvaluator.evaluate(model, predictions, test.getProvenance());
    final long testStop = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
    System.out.println(labelEvaluation.toString());
    ConfusionMatrix<Label> matrix = labelEvaluation.getConfusionMatrix();
    System.out.println(matrix.toString());
    if (model.generatesProbabilities()) {
        System.out.println("Average AUC = " + labelEvaluation.averageAUCROC(false));
        System.out.println("Average weighted AUC = " + labelEvaluation.averageAUCROC(true));
    }
    if (o.predictionPath != null) {
        // CSV layout: the true label, then one score column per label in the model's
        // domain, sorted by label name.
        try (BufferedWriter wrt = Files.newBufferedWriter(o.predictionPath)) {
            List<String> labels = model.getOutputIDInfo().getDomain().stream().map(Label::getLabel).sorted().collect(Collectors.toList());
            wrt.write("Label,");
            wrt.write(String.join(",", labels));
            wrt.newLine();
            for (Prediction<Label> pred : predictions) {
                Example<Label> ex = pred.getExample();
                wrt.write(ex.getOutput().getLabel() + ",");
                // A prediction's score map may not contain every label; write NaN for
                // missing entries instead of throwing a NullPointerException, so that
                // every row keeps the same number of columns.
                wrt.write(labels.stream().map(l -> {
                    Label scored = pred.getOutputScores().get(l);
                    return scored == null ? "NaN" : Double.toString(scored.getScore());
                }).collect(Collectors.joining(",")));
                wrt.newLine();
            }
            wrt.flush();
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Error writing predictions", e);
        }
    }
    if (o.general.outputPath != null) {
        try {
            o.general.saveModel(model);
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Error writing model", e);
        }
    }
}
Use of com.oracle.labs.mlrg.olcut.config.UsageException in the Tribuo project by Oracle.
The class TrainTest (TensorFlow interface), method main.
/**
 * CLI entry point.
 * <p>
 * Trains a TensorFlow classifier on the training data, evaluates it on the test data, prints
 * the evaluation and confusion matrix, and optionally serializes the model. The model's
 * {@code close()} is called in a {@code finally} block, so a failure during evaluation or
 * serialization does not skip it.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples or writing the model.
 * @throws IllegalArgumentException if either the input or the output name is missing.
 */
public static void main(String[] args) throws IOException {
    //
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    TensorflowOptions o = new TensorflowOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }
    if (o.trainingPath == null || o.testingPath == null) {
        logger.info(cm.usage());
        return;
    }
    // Validate the names before loading the data, so a misconfiguration fails fast.
    if ((o.inputName == null || o.inputName.isEmpty()) || (o.outputName == null || o.outputName.isEmpty())) {
        throw new IllegalArgumentException("Must specify both 'input-name' and 'output-name'");
    }
    Pair<Dataset<Label>, Dataset<Label>> data = load(o.trainingPath, o.testingPath, new LabelFactory());
    Dataset<Label> train = data.getA();
    Dataset<Label> test = data.getB();
    FeatureConverter inputConverter;
    switch (o.inputType) {
        case IMAGE:
            // The image format is "width,height,channels".
            String[] splitFormat = o.imageFormat == null ? new String[0] : o.imageFormat.split(",");
            if (splitFormat.length != 3) {
                logger.info(cm.usage());
                logger.info("Invalid image format specified. Found " + o.imageFormat);
                return;
            }
            int width;
            int height;
            int channels;
            try {
                width = Integer.parseInt(splitFormat[0].trim());
                height = Integer.parseInt(splitFormat[1].trim());
                channels = Integer.parseInt(splitFormat[2].trim());
            } catch (NumberFormatException e) {
                // Report a non-numeric dimension like any other malformed format, rather than crashing.
                logger.info(cm.usage());
                logger.info("Invalid image format specified. Found " + o.imageFormat);
                return;
            }
            inputConverter = new ImageConverter(o.inputName, width, height, channels);
            break;
        case DENSE:
            inputConverter = new DenseFeatureConverter(o.inputName);
            break;
        default:
            logger.info(cm.usage());
            logger.info("Unknown input type. Found " + o.inputType);
            return;
    }
    OutputConverter<Label> labelConverter = new LabelConverter();
    Trainer<Label> trainer;
    if (o.checkpointPath == null) {
        logger.info("Using TensorflowTrainer");
        trainer = new TensorFlowTrainer<>(o.protobufPath, o.outputName, o.optimiser, o.getGradientParams(), inputConverter, labelConverter, o.batchSize, o.epochs, o.testBatchSize, o.loggingInterval);
    } else {
        logger.info("Using TensorflowCheckpointTrainer, writing to path " + o.checkpointPath);
        trainer = new TensorFlowTrainer<>(o.protobufPath, o.outputName, o.optimiser, o.getGradientParams(), inputConverter, labelConverter, o.batchSize, o.epochs, o.testBatchSize, o.loggingInterval, o.checkpointPath);
    }
    logger.info("Training using " + trainer.toString());
    final long trainStart = System.currentTimeMillis();
    Model<Label> model = trainer.train(train);
    final long trainStop = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));
    try {
        final long testStart = System.currentTimeMillis();
        LabelEvaluator evaluator = new LabelEvaluator();
        LabelEvaluation evaluation = evaluator.evaluate(model, test);
        final long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
        if (model.generatesProbabilities()) {
            logger.info("Average AUC = " + evaluation.averageAUCROC(false));
            logger.info("Average weighted AUC = " + evaluation.averageAUCROC(true));
        }
        System.out.println(evaluation.toString());
        System.out.println(evaluation.getConfusionMatrix().toString());
        if (o.outputPath != null) {
            try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(o.outputPath.toFile()))) {
                oos.writeObject(model);
            }
            logger.info("Serialized model to file: " + o.outputPath);
        }
    } finally {
        // Always release the model, even if evaluation or serialization threw.
        // (Presumably it holds native TensorFlow resources.) The concrete model class
        // follows the trainer choice made above.
        if (o.checkpointPath == null) {
            ((TensorFlowNativeModel<?>) model).close();
        } else {
            ((TensorFlowCheckpointModel<?>) model).close();
        }
    }
}
Use of com.oracle.labs.mlrg.olcut.config.UsageException in the Tribuo project by Oracle.
The class StripProvenance, method main.
/**
 * Runs StripProvenance.
 * <p>
 * The method proceeds in these steps:
 * <ol>
 * <li>Loads a serialized model and marshals its provenance to JSON.</li>
 * <li>Hashes that JSON with the configured digest, and optionally writes the JSON to the
 * provenance file.</li>
 * <li>Converts the model with {@code convertModel} and serializes the result to the output
 * path.</li>
 * </ol>
 * The output file is only opened once the converted model exists. A failed load or
 * conversion therefore leaves no truncated output behind, and the input and output may be
 * the same path. Errors are logged rather than rethrown.
 * @param args the command line arguments
 * @param <T> The {@link Output} subclass.
 */
@SuppressWarnings("unchecked")
public static <T extends Output<T>> void main(String[] args) {
    //
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    StripProvenanceOptions o = new StripProvenanceOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }
    if (o.inputModel == null || o.outputModel == null) {
        logger.info(cm.usage());
        System.exit(1);
    }
    try {
        Model<T> input;
        // NOTE(review): this is Java-native deserialization of a user-supplied file; it should
        // only be run on trusted model files.
        try (ObjectInputStream ois = IOUtil.getObjectInputStream(o.inputModel)) {
            logger.info("Loading model from " + o.inputModel);
            input = (Model<T>) ois.readObject();
        }
        ModelProvenance oldProvenance = input.getProvenance();
        logger.info("Marshalling provenance and creating JSON.");
        JsonProvenanceSerialization jsonProvenanceSerialization = new JsonProvenanceSerialization(true);
        String jsonResult = jsonProvenanceSerialization.marshalAndSerialize(oldProvenance);
        logger.info("Hashing JSON file");
        MessageDigest digest = o.hashType.getDigest();
        byte[] digestBytes = digest.digest(jsonResult.getBytes(StandardCharsets.UTF_8));
        String provenanceHash = ProvenanceUtil.bytesToHexString(digestBytes);
        logger.info("Provenance hash = " + provenanceHash);
        if (o.provenanceFile != null) {
            logger.info("Writing JSON provenance to " + o.provenanceFile.toString());
            try (PrintWriter writer = new PrintWriter(new OutputStreamWriter(new FileOutputStream(o.provenanceFile), StandardCharsets.UTF_8))) {
                writer.println(jsonResult);
            }
        }
        ModelTuple<T> tuple = convertModel(input, provenanceHash, o);
        logger.info("Writing model to " + o.outputModel);
        // Open (and truncate) the output only now that the converted model is available
        // and the input stream has been closed.
        try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(o.outputModel))) {
            oos.writeObject(tuple.model);
        }
        ModelProvenance newProvenance = tuple.provenance;
        logger.info("Marshalling provenance and creating JSON.");
        String newJsonResult = jsonProvenanceSerialization.marshalAndSerialize(newProvenance);
        logger.info("Old provenance = \n" + jsonResult);
        logger.info("New provenance = \n" + newJsonResult);
    } catch (NoSuchMethodException e) {
        logger.log(Level.SEVERE, "Model.copy method missing on a class which extends Model.", e);
    } catch (IllegalAccessException e) {
        logger.log(Level.SEVERE, "Failed to modify protection on inner copy method on Model.", e);
    } catch (InvocationTargetException e) {
        logger.log(Level.SEVERE, "Failed to invoke inner copy method on Model.", e);
    } catch (UnsupportedEncodingException e) {
        logger.log(Level.SEVERE, "Unsupported encoding exception.", e);
    } catch (FileNotFoundException e) {
        logger.log(Level.SEVERE, "Failed to find the input file.", e);
    } catch (IOException e) {
        logger.log(Level.SEVERE, "IO error when reading or writing a file.", e);
    } catch (ClassNotFoundException e) {
        logger.log(Level.SEVERE, "The model and/or provenance classes are not on the classpath.", e);
    }
}
Use of com.oracle.labs.mlrg.olcut.config.UsageException in the Tribuo project by Oracle.
The class TrainTest (HDBSCAN clustering), method main.
/**
 * Runs a TrainTest CLI.
 * <p>
 * Fits an HDBSCAN clustering model on the training data. It then reports the normalized and
 * adjusted mutual information computed by evaluating the model on that same training data.
 * The model is saved if an output path was configured.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();

    HdbscanCLIOptions o = new HdbscanCLIOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }
    if (o.general.trainingPath == null) {
        logger.info(cm.usage());
        return;
    }

    ClusteringFactory factory = new ClusteringFactory();
    // Only the training half of the loaded pair is used.
    Dataset<ClusterID> trainingData = o.general.load(factory).getA();

    HdbscanTrainer trainer = o.hdbscanOptions.getTrainer();
    Model<ClusterID> clusterModel = trainer.train(trainingData);
    logger.info("Finished training model");

    // The model is evaluated on the same data it was trained on.
    ClusteringEvaluation evaluation = factory.getEvaluator().evaluate(clusterModel, trainingData);
    logger.info("Finished evaluating model");
    System.out.println("Normalized MI = " + evaluation.normalizedMI());
    System.out.println("Adjusted MI = " + evaluation.adjustedMI());

    if (o.general.outputPath == null) {
        return;
    }
    o.general.saveModel(clusterModel);
}
Aggregations