Use of edu.stanford.nlp.ling.RVFDatum in the project CoreNLP by stanfordnlp.
The main method of the class KBPStatisticalExtractor.
/**
 * Command-line entry point: evaluates a relation-extraction model on the test set,
 * training and saving a model first when none exists at {@code MODEL_FILE}.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>Fill this class's command-line options from {@code args}.</li>
 *   <li>Read the evaluation examples from {@code TEST_FILE}.</li>
 *   <li>If {@code MODEL_FILE} is not found on the classpath or filesystem, read
 *       {@code TRAIN_FILE}, featurize it in parallel, train a classifier, and write
 *       the resulting extractor to {@code MODEL_FILE}.</li>
 *   <li>Load the model from {@code MODEL_FILE} and compute accuracy on the test
 *       examples, optionally writing predictions to the target named by
 *       {@code PREDICTIONS}.</li>
 * </ol>
 *
 * @param args command-line arguments, passed to {@code ArgumentParser.fillOptions}
 * @throws IOException if reading or writing a dataset or model file fails
 * @throws ClassNotFoundException if the serialized model cannot be deserialized
 */
public static void main(String[] args) throws IOException, ClassNotFoundException {
  // Apply the standard Redwood logging configuration (silences SLF4J noise).
  RedwoodConfiguration.standard().apply();
  // Fill command-line options
  ArgumentParser.fillOptions(KBPStatisticalExtractor.class, args);

  // Load the test (or dev) data
  forceTrack("Test data");
  List<Pair<KBPInput, String>> testExamples = KBPRelationExtractor.readDataset(TEST_FILE);
  log.info("Read " + testExamples.size() + " examples");
  endTrack("Test data");

  // If we can't find an existing model, train one
  if (!IOUtils.existsInClasspathOrFileSystem(MODEL_FILE)) {
    forceTrack("Training data");
    List<Pair<KBPInput, String>> trainExamples = KBPRelationExtractor.readDataset(TRAIN_FILE);
    log.info("Read " + trainExamples.size() + " examples");
    log.info("" + trainExamples.stream().map(Pair::second).filter(NO_RELATION::equals).count() + " are " + NO_RELATION);
    endTrack("Training data");

    // Featurize + create the dataset
    forceTrack("Creating dataset");
    RVFDataset<String, String> dataset = new RVFDataset<>();
    final AtomicInteger i = new AtomicInteger(0);
    long beginTime = System.currentTimeMillis();
    trainExamples.stream().parallel().forEach(example -> {
      // Capture the incremented value once: re-reading the counter with get() could
      // observe increments made by other worker threads and log a misleading count.
      int featurized = i.incrementAndGet();
      if (featurized % 1000 == 0) {
        log.info("[" + Redwood.formatTimeDifference(System.currentTimeMillis() - beginTime) + "] Featurized " + featurized + " / " + trainExamples.size() + " examples");
      }
      // Featurization runs outside the lock; only the insertion into the shared
      // dataset is serialized.
      Counter<String> features = features(example.first);
      synchronized (dataset) {
        dataset.add(new RVFDatum<>(features, example.second));
      }
    });
    // Free up some memory
    trainExamples.clear();
    endTrack("Creating dataset");

    // Train the classifier
    log.info("Training classifier:");
    Classifier<String, String> classifier = trainMultinomialClassifier(dataset, FEATURE_THRESHOLD, SIGMA);
    // Free up some memory
    dataset.clear();
    // Save the classifier
    IOUtils.writeObjectToFile(new KBPStatisticalExtractor(classifier), MODEL_FILE);
  }

  // Read either a newly-trained or pre-trained model
  Object model = IOUtils.readObjectFromURLOrClasspathOrFileSystem(MODEL_FILE);
  KBPStatisticalExtractor classifier;
  if (model instanceof Classifier) {
    // The file holds a bare classifier rather than a full extractor; wrap it.
    //noinspection unchecked
    classifier = new KBPStatisticalExtractor((Classifier<String, String>) model);
  } else {
    classifier = ((KBPStatisticalExtractor) model);
  }

  // Resolve the predictions sink up front so that it can be closed after evaluation.
  // (Optional is fully qualified so this method needs no additional import.)
  java.util.Optional<PrintStream> predictionsOut = PREDICTIONS.map(x -> {
    try {
      return "stdout".equalsIgnoreCase(x) ? System.out : new PrintStream(new FileOutputStream(x));
    } catch (IOException e) {
      throw new RuntimeIOException(e);
    }
  });

  // Evaluate the model
  try {
    classifier.computeAccuracy(testExamples.stream(), predictionsOut);
  } finally {
    // Close a file-backed stream to release its handle; never close System.out.
    predictionsOut.filter(out -> out != System.out).ifPresent(PrintStream::close);
  }
}
Aggregations