Usage of org.tribuo.classification.sequence.LabelSequenceEvaluation in the Tribuo project by Oracle.
Class SeqTest, method main.
/**
 * Trains a {@code CRFTrainer} sequence model, evaluates it on a test set, and optionally
 * serializes the trained model.
 *
 * <p>The data is generated from the synthetic gorilla generator when the dataset name is
 * "Gorilla" or "gorilla". Otherwise it is read from a pair of Java-serialized
 * {@code SequenceDataset}s given by the train/test dataset options. The test dataset is
 * re-indexed against the training feature and output domains before evaluation.
 *
 * @param args the command line arguments
 * @throws ClassNotFoundException if the class of a serialized dataset cannot be found.
 * @throws IOException if there is any error reading the examples or writing the model.
 */
// Unchecked: the datasets are deserialised as plain Objects and cast to SequenceDataset<Label>.
@SuppressWarnings("unchecked")
public static void main(String[] args) throws ClassNotFoundException, IOException {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    CRFOptions o = new CRFOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        // Argument parsing failed: log the message and exit without training.
        logger.info(e.getMessage());
        return;
    }
    logger.info("Configuring gradient optimiser");
    StochasticGradientOptimiser grad = o.gradientOptions.getOptimiser();
    logger.info(String.format("Set logging interval to %d", o.loggingInterval));
    SequenceDataset<Label> train;
    SequenceDataset<Label> test;
    // A null dataset name would throw an NPE in the String switch; treat it as "not a built-in
    // dataset" so the file-loading branch (or the usage message) handles it.
    String datasetName = (o.datasetName == null) ? "" : o.datasetName;
    switch (datasetName) {
        case "Gorilla":
        case "gorilla":
            logger.info("Generating gorilla dataset");
            train = SequenceDataGenerator.generateGorillaDataset(1);
            test = SequenceDataGenerator.generateGorillaDataset(1);
            break;
        default:
            if ((o.trainDataset != null) && (o.testDataset != null)) {
                logger.info("Loading training data from " + o.trainDataset);
                // NOTE(review): Java native deserialisation can execute code embedded in the
                // stream; only load dataset files from trusted sources (or add an ObjectInputFilter).
                // Each file stream is its own resource so it is closed even if the
                // ObjectInputStream constructor fails while reading the stream header.
                try (FileInputStream trainIn = new FileInputStream(o.trainDataset.toFile());
                     ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(trainIn));
                     FileInputStream testIn = new FileInputStream(o.testDataset.toFile());
                     ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(testIn))) {
                    train = (SequenceDataset<Label>) ois.readObject();
                    logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                    logger.info("Found " + train.getFeatureIDMap().size() + " features");
                    logger.info("Loading testing data from " + o.testDataset);
                    SequenceDataset<Label> deserTest = (SequenceDataset<Label>) oits.readObject();
                    // Re-index the test data so its features/outputs use the training domains.
                    test = ImmutableSequenceDataset.copyDataset(deserTest, train.getFeatureIDMap(), train.getOutputIDInfo());
                    logger.info(String.format("Loaded %d testing examples", test.size()));
                }
            } else {
                logger.warning("Unknown dataset " + o.datasetName);
                logger.info(cm.usage());
                return;
            }
    }
    // Configure the concrete trainer before it is (optionally) wrapped by a hashing trainer.
    CRFTrainer crfTrainer = new CRFTrainer(grad, o.epochs, o.loggingInterval, o.seed);
    crfTrainer.setShuffle(o.shuffle);
    SequenceTrainer<Label> trainer = crfTrainer;
    switch (o.modelHashingAlgorithm) {
        case NONE:
            break;
        case HC:
            trainer = new HashingSequenceTrainer<>(trainer, new HashCodeHasher(o.modelHashingSalt));
            break;
        case SHA1:
            trainer = new HashingSequenceTrainer<>(trainer, new MessageDigestHasher("SHA1", o.modelHashingSalt));
            break;
        case SHA256:
            trainer = new HashingSequenceTrainer<>(trainer, new MessageDigestHasher("SHA-256", o.modelHashingSalt));
            break;
        default:
            // Training continues without feature hashing.
            logger.warning("Unknown hasher " + o.modelHashingAlgorithm);
    }
    logger.info("Training using " + trainer.toString());
    final long trainStart = System.currentTimeMillis();
    // NOTE(review): assumes the (possibly hashing) trainer always returns a CRFModel.
    CRFModel model = (CRFModel) trainer.train(train);
    final long trainStop = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));
    if (o.logModel) {
        System.out.println("FeatureMap = " + model.getFeatureIDMap().toString());
        System.out.println("LabelMap = " + model.getOutputIDInfo().toString());
        System.out.println("Features - " + model.generateWeightsString());
    }
    LabelSequenceEvaluator labelEvaluator = new LabelSequenceEvaluator();
    final long testStart = System.currentTimeMillis();
    LabelSequenceEvaluation evaluation = labelEvaluator.evaluate(model, test);
    final long testStop = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
    System.out.println(evaluation.toString());
    System.out.println();
    System.out.println(evaluation.getConfusionMatrix().toString());
    if (o.outputPath != null) {
        // try-with-resources closes both streams even if the header write or writeObject fails.
        try (FileOutputStream fout = new FileOutputStream(o.outputPath.toFile());
             ObjectOutputStream oout = new ObjectOutputStream(fout)) {
            oout.writeObject(model);
        }
        logger.info("Serialized model to file: " + o.outputPath);
    }
}
Aggregations