Use of edu.stanford.nlp.semgraph.SemanticGraph in project CoreNLP by stanfordnlp: class ClauseSplitterSearchProblem, method search.
/**
* The core implementation of the search.
*
* @param root The root word to search from. Traditionally, this is the root of the sentence.
* @param candidateFragments The callback for the resulting sentence fragments.
* This is a predicate of a triple of values.
* The return value of the predicate determines whether we should continue searching.
* The triple consists of:
* <ol>
* <li>The log probability of the sentence fragment, according to the featurizer and the weights</li>
* <li>The features along the path to this fragment. The last element of this is the features from the most recent step.</li>
* <li>The sentence fragment. Because it is relatively expensive to compute the resulting tree, this is returned as a lazy {@link Supplier}.</li>
* </ol>
* @param classifier The classifier for whether an arc should be on the path to a clause split, a clause split itself, or neither.
* @param featurizer The featurizer to use. Make sure this matches the weights!
* @param actionSpace The action space we are allowed to take. Each action defines a means of splitting a clause on a dependency boundary.
*/
protected void search(
    // The root to search from
    IndexedWord root,
    // The output specs
    final Predicate<Triple<Double, List<Counter<String>>, Supplier<SentenceFragment>>> candidateFragments,
    // The learning specs
    final Classifier<ClauseSplitter.ClauseClassifierLabel, String> classifier,
    Map<String, ? extends List<String>> hardCodedSplits,
    final Function<Triple<State, Action, State>, Counter<String>> featurizer,
    final Collection<Action> actionSpace,
    final int maxTicks) {
// (the fringe)
PriorityQueue<Pair<State, List<Counter<String>>>> fringe = new FixedPrioritiesPriorityQueue<>();
// (avoid duplicate work)
Set<IndexedWord> seenWords = new HashSet<>();
// The first state is implicitly "done"
State firstState = new State(null, null, -9000, null, x -> { }, true);
fringe.add(Pair.makePair(firstState, new ArrayList<>(0)), -0.0);
int ticks = 0;
while (!fringe.isEmpty()) {
if (++ticks > maxTicks) {
// log.info("WARNING! Timed out on search with " + ticks + " ticks");
return;
}
// Useful variables
double logProbSoFar = fringe.getPriority();
assert logProbSoFar <= 0.0;
Pair<State, List<Counter<String>>> lastStatePair = fringe.removeFirst();
State lastState = lastStatePair.first;
List<Counter<String>> featuresSoFar = lastStatePair.second;
IndexedWord rootWord = lastState.edge == null ? root : lastState.edge.getDependent();
// Register thunk
if (lastState.isDone) {
if (!candidateFragments.test(Triple.makeTriple(logProbSoFar, featuresSoFar, () -> {
SemanticGraph copy = new SemanticGraph(tree);
lastState.thunk.andThen(x -> {
for (IndexedWord newTreeRoot : x.getRoots()) {
if (newTreeRoot != null) {
for (SemanticGraphEdge extraEdge : extraEdgesByGovernor.get(newTreeRoot)) {
assert Util.isTree(x);
addSubtree(x, newTreeRoot, extraEdge.getRelation().toString(), tree, extraEdge.getDependent(), tree.getIncomingEdgesSorted(newTreeRoot));
assert Util.isTree(x);
}
}
}
}).accept(copy);
return new SentenceFragment(copy, assumedTruth, false);
}))) {
break;
}
}
// Find relevant auxiliary terms
SemanticGraphEdge subjOrNull = null;
SemanticGraphEdge objOrNull = null;
for (SemanticGraphEdge auxEdge : tree.outgoingEdgeIterable(rootWord)) {
String relString = auxEdge.getRelation().toString();
if (relString.contains("obj")) {
objOrNull = auxEdge;
} else if (relString.contains("subj")) {
subjOrNull = auxEdge;
}
}
// For each outgoing edge...
for (SemanticGraphEdge outgoingEdge : tree.outgoingEdgeIterable(rootWord)) {
// This fires if the governor is an indirect speech verb, and the outgoing edge is a ccomp
if (outgoingEdge.getRelation().toString().equals("ccomp")
    && ((outgoingEdge.getGovernor().lemma() != null && INDIRECT_SPEECH_LEMMAS.contains(outgoingEdge.getGovernor().lemma()))
        || INDIRECT_SPEECH_LEMMAS.contains(outgoingEdge.getGovernor().word()))) {
continue;
}
// Get some variables
String outgoingEdgeRelation = outgoingEdge.getRelation().toString();
List<String> forcedArcOrder = hardCodedSplits.get(outgoingEdgeRelation);
if (forcedArcOrder == null && outgoingEdgeRelation.contains(":")) {
forcedArcOrder = hardCodedSplits.get(outgoingEdgeRelation.substring(0, outgoingEdgeRelation.indexOf(":")) + ":*");
}
boolean doneForcedArc = false;
// For each action...
for (Action action : (forcedArcOrder == null ? actionSpace : orderActions(actionSpace, forcedArcOrder))) {
// Check the prerequisite
if (!action.prerequisitesMet(tree, outgoingEdge)) {
continue;
}
if (forcedArcOrder != null && doneForcedArc) {
break;
}
// 1. Compute the child state
Optional<State> candidate = action.applyTo(tree, lastState, outgoingEdge, subjOrNull, objOrNull);
if (candidate.isPresent()) {
double logProbability;
ClauseClassifierLabel bestLabel;
Counter<String> features = featurizer.apply(Triple.makeTriple(lastState, action, candidate.get()));
if (forcedArcOrder != null && !doneForcedArc) {
logProbability = 0.0;
bestLabel = ClauseClassifierLabel.CLAUSE_SPLIT;
doneForcedArc = true;
} else if (features.containsKey("__undocumented_junit_no_classifier")) {
logProbability = Double.NEGATIVE_INFINITY;
bestLabel = ClauseClassifierLabel.CLAUSE_INTERM;
} else {
Counter<ClauseClassifierLabel> scores = classifier.scoresOf(new RVFDatum<>(features));
if (scores.size() > 0) {
Counters.logNormalizeInPlace(scores);
}
String rel = outgoingEdge.getRelation().toString();
if ("nsubj".equals(rel) || "dobj".equals(rel)) {
// Always at least yield on nsubj and dobj
scores.remove(ClauseClassifierLabel.NOT_A_CLAUSE);
}
logProbability = Counters.max(scores, Double.NEGATIVE_INFINITY);
bestLabel = Counters.argmax(scores, (x, y) -> 0, ClauseClassifierLabel.CLAUSE_SPLIT);
}
if (bestLabel != ClauseClassifierLabel.NOT_A_CLAUSE) {
Pair<State, List<Counter<String>>> childState = Pair.makePair(
    candidate.get().withIsDone(bestLabel),
    new ArrayList<Counter<String>>(featuresSoFar) {{ add(features); }});
// 2. Register the child state
if (!seenWords.contains(childState.first.edge.getDependent())) {
// log.info(" pushing " + action.signature() + " with " + argmax.first.edge);
fringe.add(childState, logProbability);
}
}
}
}
}
seenWords.add(rootWord);
}
// log.info("Search finished in " + ticks + " ticks and " + classifierEvals + " classifier evaluations.");
}
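For orientation, here is a hedged sketch of a caller, modeled on the five-argument overload invoked in ClauseSplitter.train later in this listing; the tree, classifier, and featurizer variables are assumed to be in scope.
// Sketch only: drive the clause search and collect fragments above a probability threshold.
// `tree`, `classifier`, and `featurizer` are assumed to be available; the overload mirrors
// the call made in ClauseSplitter.train below.
ClauseSplitterSearchProblem problem = new ClauseSplitterSearchProblem(tree, true);
List<SentenceFragment> clauses = new ArrayList<>();
problem.search(fragmentAndScore -> {
  double logProb = fragmentAndScore.first;                        // log probability of this fragment
  Supplier<SentenceFragment> fragment = fragmentAndScore.third;   // clause tree, built lazily
  if (logProb > Math.log(0.5)) {
    clauses.add(fragment.get());
  }
  return true;                                                    // true = keep searching
}, classifier, Collections.emptyMap(), featurizer, 10000);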
Use of edu.stanford.nlp.semgraph.SemanticGraph in project CoreNLP by stanfordnlp: class CreateClauseDataset, method process.
@Override
public void process(long id, Annotation doc) {
CoreMap sentence = doc.get(CoreAnnotations.SentencesAnnotation.class).get(0);
SemanticGraph depparse = sentence.get(SemanticGraphCoreAnnotations.BasicDependenciesAnnotation.class);
log.info("| " + sentence.get(CoreAnnotations.TextAnnotation.class));
// Get all valid subject spans
BitSet consumedAsSubjects = new BitSet();
@SuppressWarnings("MismatchedQueryAndUpdateOfCollection") List<Span> subjectSpans = new ArrayList<>();
NEXTNODE: for (IndexedWord head : depparse.topologicalSort()) {
// Check if the node is a noun/pronoun
if (head.tag().startsWith("N") || head.tag().equals("PRP")) {
// Try to get the NP chunk
Optional<List<IndexedWord>> subjectChunk = segmenter.getValidChunk(depparse, head, segmenter.VALID_SUBJECT_ARCS, Optional.empty(), true);
if (subjectChunk.isPresent()) {
// Make sure it's not already a member of a larger NP
for (IndexedWord tok : subjectChunk.get()) {
if (consumedAsSubjects.get(tok.index())) {
// Already considered. Continue to the next node.
continue NEXTNODE;
}
}
// Register it as an NP
for (IndexedWord tok : subjectChunk.get()) {
consumedAsSubjects.set(tok.index());
}
// Add it as a subject
subjectSpans.add(toSpan(subjectChunk.get()));
}
}
}
}
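The toSpan helper called above is not part of this excerpt; a minimal sketch of what it presumably does, assuming the 0-indexed, end-exclusive token convention that Span.fromValues is used with elsewhere in this listing:
// Hypothetical reconstruction of the toSpan helper used above: collapse an NP chunk
// (a list of IndexedWords with 1-based sentence indices) into a 0-indexed,
// end-exclusive token Span.
private static Span toSpan(List<IndexedWord> chunk) {
  int begin = Integer.MAX_VALUE;
  int end = Integer.MIN_VALUE;
  for (IndexedWord word : chunk) {
    begin = Math.min(begin, word.index() - 1);
    end = Math.max(end, word.index());
  }
  return Span.fromValues(begin, end);
}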
Use of edu.stanford.nlp.semgraph.SemanticGraph in project CoreNLP by stanfordnlp: class ForwardEntailerSearchProblem, method searchImplementation.
/**
* The search algorithm, starting with a full sentence and iteratively shortening it to its entailed sentences.
*
* @return A list of search results, corresponding to shortenings of the sentence.
*/
@SuppressWarnings("unchecked")
private List<SearchResult> searchImplementation() {
// Pre-process the tree
SemanticGraph parseTree = new SemanticGraph(this.parseTree);
assert Util.isTree(parseTree);
// (remove common determiners)
List<String> determinerRemovals = new ArrayList<>();
parseTree.getLeafVertices().stream()
    .filter(vertex -> "the".equalsIgnoreCase(vertex.word()) || "a".equalsIgnoreCase(vertex.word())
        || "an".equalsIgnoreCase(vertex.word()) || "this".equalsIgnoreCase(vertex.word())
        || "that".equalsIgnoreCase(vertex.word()) || "those".equalsIgnoreCase(vertex.word())
        || "these".equalsIgnoreCase(vertex.word()))
    .forEach(vertex -> {
      parseTree.removeVertex(vertex);
      assert Util.isTree(parseTree);
      determinerRemovals.add("det");
    });
// (cut conj_and nodes)
Set<SemanticGraphEdge> andsToAdd = new HashSet<>();
for (IndexedWord vertex : parseTree.vertexSet()) {
if (parseTree.inDegree(vertex) > 1) {
SemanticGraphEdge conjAnd = null;
for (SemanticGraphEdge edge : parseTree.incomingEdgeIterable(vertex)) {
if ("conj:and".equals(edge.getRelation().toString())) {
conjAnd = edge;
}
}
if (conjAnd != null) {
parseTree.removeEdge(conjAnd);
assert Util.isTree(parseTree);
andsToAdd.add(conjAnd);
}
}
}
// Clean the tree
Util.cleanTree(parseTree);
assert Util.isTree(parseTree);
// Find the subject / object split
// This takes max O(n^2) time, expected O(n*log(n)) time.
// Optimal is O(n), but I'm too lazy to implement it.
BitSet isSubject = new BitSet(256);
for (IndexedWord vertex : parseTree.vertexSet()) {
// Search up the tree for a subj node; if found, mark that vertex as a subject.
Iterator<SemanticGraphEdge> incomingEdges = parseTree.incomingEdgeIterator(vertex);
SemanticGraphEdge edge = null;
if (incomingEdges.hasNext()) {
edge = incomingEdges.next();
}
int numIters = 0;
while (edge != null) {
if (edge.getRelation().toString().endsWith("subj")) {
assert vertex.index() > 0;
isSubject.set(vertex.index() - 1);
break;
}
incomingEdges = parseTree.incomingEdgeIterator(edge.getGovernor());
if (incomingEdges.hasNext()) {
edge = incomingEdges.next();
} else {
edge = null;
}
numIters += 1;
if (numIters > 100) {
// log.error("tree has apparent depth > 100");
return Collections.EMPTY_LIST;
}
}
}
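// At this point isSubject.get(i) is true exactly when the token at (1-based) index i + 1
// sits somewhere below a *subj arc. This mask is consulted later, when an edge is deleted,
// to pick the right NaturalLogicRelation for a subject-side versus object-side deletion.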
// Outputs
List<SearchResult> results = new ArrayList<>();
if (!determinerRemovals.isEmpty()) {
if (andsToAdd.isEmpty()) {
double score = Math.pow(weights.deletionProbability("det"), (double) determinerRemovals.size());
assert !Double.isNaN(score);
assert !Double.isInfinite(score);
results.add(new SearchResult(parseTree, determinerRemovals, score));
} else {
SemanticGraph treeWithAnds = new SemanticGraph(parseTree);
assert Util.isTree(treeWithAnds);
for (SemanticGraphEdge and : andsToAdd) {
treeWithAnds.addEdge(and.getGovernor(), and.getDependent(), and.getRelation(), Double.NEGATIVE_INFINITY, false);
}
assert Util.isTree(treeWithAnds);
results.add(new SearchResult(treeWithAnds, determinerRemovals, Math.pow(weights.deletionProbability("det"), (double) determinerRemovals.size())));
}
}
// Initialize the search
assert Util.isTree(parseTree);
List<IndexedWord> topologicalVertices;
try {
topologicalVertices = parseTree.topologicalSort();
} catch (IllegalStateException e) {
// log.info("Could not topologically sort the vertices! Using left-to-right traversal.");
topologicalVertices = parseTree.vertexListSorted();
}
if (topologicalVertices.isEmpty()) {
return results;
}
Stack<SearchState> fringe = new Stack<>();
fringe.push(new SearchState(new BitSet(256), 0, parseTree, null, null, 1.0));
// Start the search
int numTicks = 0;
while (!fringe.isEmpty()) {
// Overhead with popping a node.
if (numTicks >= maxTicks) {
return results;
}
numTicks += 1;
if (results.size() >= maxResults) {
return results;
}
SearchState state = fringe.pop();
assert state.score > 0.0;
IndexedWord currentWord = topologicalVertices.get(state.currentIndex);
// Push the case where we don't delete
int nextIndex = state.currentIndex + 1;
int numIters = 0;
while (nextIndex < topologicalVertices.size()) {
IndexedWord nextWord = topologicalVertices.get(nextIndex);
assert nextWord.index() > 0;
if (!state.deletionMask.get(nextWord.index() - 1)) {
fringe.push(new SearchState(state.deletionMask, nextIndex, state.tree, null, state, state.score));
break;
} else {
nextIndex += 1;
}
numIters += 1;
if (numIters > 10000) {
// log.error("logic error (apparent infinite loop); returning");
return results;
}
}
// Check if we can delete this subtree
boolean canDelete = !state.tree.getFirstRoot().equals(currentWord);
for (SemanticGraphEdge edge : state.tree.incomingEdgeIterable(currentWord)) {
if ("CD".equals(edge.getGovernor().tag())) {
canDelete = false;
} else {
// Get token information
CoreLabel token = edge.getDependent().backingLabel();
OperatorSpec operator;
NaturalLogicRelation lexicalRelation;
Polarity tokenPolarity = token.get(NaturalLogicAnnotations.PolarityAnnotation.class);
if (tokenPolarity == null) {
tokenPolarity = Polarity.DEFAULT;
}
// Get the relation for this deletion
if ((operator = token.get(NaturalLogicAnnotations.OperatorAnnotation.class)) != null) {
lexicalRelation = operator.instance.deleteRelation;
} else {
assert edge.getDependent().index() > 0;
lexicalRelation = NaturalLogicRelation.forDependencyDeletion(edge.getRelation().toString(), isSubject.get(edge.getDependent().index() - 1));
}
NaturalLogicRelation projectedRelation = tokenPolarity.projectLexicalRelation(lexicalRelation);
// Make sure this is a valid entailment
if (!projectedRelation.applyToTruthValue(truthOfPremise).isTrue()) {
canDelete = false;
}
}
}
if (canDelete) {
// Register the deletion
Lazy<Pair<SemanticGraph, BitSet>> treeWithDeletionsAndNewMask = Lazy.of(() -> {
SemanticGraph impl = new SemanticGraph(state.tree);
BitSet newMask = state.deletionMask;
for (IndexedWord vertex : state.tree.descendants(currentWord)) {
impl.removeVertex(vertex);
assert vertex.index() > 0;
newMask.set(vertex.index() - 1);
assert newMask.get(vertex.index() - 1);
}
return Pair.makePair(impl, newMask);
});
// Compute the score of the sentence
double newScore = state.score;
for (SemanticGraphEdge edge : state.tree.incomingEdgeIterable(currentWord)) {
double multiplier = weights.deletionProbability(edge, state.tree.outgoingEdgeIterable(edge.getGovernor()));
assert !Double.isNaN(multiplier);
assert !Double.isInfinite(multiplier);
newScore *= multiplier;
}
// Register the result
if (newScore > 0.0) {
SemanticGraph resultTree = new SemanticGraph(treeWithDeletionsAndNewMask.get().first);
andsToAdd.stream()
    .filter(edge -> resultTree.containsVertex(edge.getGovernor()) && resultTree.containsVertex(edge.getDependent()))
    .forEach(edge -> resultTree.addEdge(edge.getGovernor(), edge.getDependent(), edge.getRelation(), Double.NEGATIVE_INFINITY, false));
results.add(new SearchResult(resultTree, aggregateDeletedEdges(state, state.tree.incomingEdgeIterable(currentWord), determinerRemovals), newScore));
// Push the state with this subtree deleted
nextIndex = state.currentIndex + 1;
numIters = 0;
while (nextIndex < topologicalVertices.size()) {
IndexedWord nextWord = topologicalVertices.get(nextIndex);
BitSet newMask = treeWithDeletionsAndNewMask.get().second;
SemanticGraph treeWithDeletions = treeWithDeletionsAndNewMask.get().first;
if (!newMask.get(nextWord.index() - 1)) {
assert treeWithDeletions.containsVertex(topologicalVertices.get(nextIndex));
fringe.push(new SearchState(newMask, nextIndex, treeWithDeletions, null, state, newScore));
break;
} else {
nextIndex += 1;
}
numIters += 1;
if (numIters > 10000) {
// log.error("logic error (apparent infinite loop); returning");
return results;
}
}
}
}
}
// Return
return results;
}
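The score attached to each SearchResult above is simply a product of per-deletion probabilities; here is a compact restatement of that bookkeeping, as a sketch over assumed weights, determinerRemovals, deletedEdges, and tree variables rather than a separate API.
// Sketch: how a shortening's score is assembled in the method above.
// Each removed determiner multiplies in weights.deletionProbability("det"); each pruned
// subtree multiplies in the deletion probability of its incoming edge, conditioned on its
// governor's other outgoing edges. All variables here are assumed to be in scope.
double score = Math.pow(weights.deletionProbability("det"), determinerRemovals.size());
for (SemanticGraphEdge deleted : deletedEdges) {
  score *= weights.deletionProbability(deleted, tree.outgoingEdgeIterable(deleted.getGovernor()));
}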
Use of edu.stanford.nlp.semgraph.SemanticGraph in project CoreNLP by stanfordnlp: class NaturalLogicAnnotator, method getGeneralizedSubtreeSpan.
/** A helper method for
* {@link NaturalLogicAnnotator#getModifierSubtreeSpan(edu.stanford.nlp.semgraph.SemanticGraph, edu.stanford.nlp.ling.IndexedWord)} and
* {@link NaturalLogicAnnotator#getSubtreeSpan(edu.stanford.nlp.semgraph.SemanticGraph, edu.stanford.nlp.ling.IndexedWord)}.
*/
private static Pair<Integer, Integer> getGeneralizedSubtreeSpan(SemanticGraph tree, IndexedWord root, Set<String> validArcs) {
int min = root.index();
int max = root.index();
Queue<IndexedWord> fringe = new LinkedList<>();
for (SemanticGraphEdge edge : tree.outgoingEdgeIterable(root)) {
String edgeLabel = edge.getRelation().getShortName();
if ((validArcs == null || validArcs.contains(edgeLabel)) && !"punct".equals(edgeLabel)) {
fringe.add(edge.getDependent());
}
}
while (!fringe.isEmpty()) {
IndexedWord node = fringe.poll();
min = Math.min(node.index(), min);
max = Math.max(node.index(), max);
// ignore punctuation
fringe.addAll(tree.getOutEdgesSorted(node).stream()
    .filter(edge -> edge.getGovernor().equals(node)
        && !edge.getGovernor().equals(edge.getDependent())
        && !"punct".equals(edge.getRelation().getShortName()))
    .map(SemanticGraphEdge::getDependent)
    .collect(Collectors.toList()));
}
return Pair.makePair(min, max + 1);
}
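The Javadoc names two thin wrappers around this helper; a hedged sketch of how they presumably delegate follows. Only the delegation pattern comes from the Javadoc; the specific modifier arc set is an assumption for illustration.
// Sketch of the wrappers mentioned in the Javadoc; the modifier arc set below is illustrative only.
private static Pair<Integer, Integer> getSubtreeSpan(SemanticGraph tree, IndexedWord root) {
  return getGeneralizedSubtreeSpan(tree, root, null);  // null = follow every non-punct arc
}

private static Pair<Integer, Integer> getModifierSubtreeSpan(SemanticGraph tree, IndexedWord root) {
  return getGeneralizedSubtreeSpan(tree, root,
      new HashSet<>(Arrays.asList("nmod", "obl", "advmod", "amod")));  // assumed arc set
}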
Use of edu.stanford.nlp.semgraph.SemanticGraph in project CoreNLP by stanfordnlp: class ClauseSplitter, method train.
/**
* Train a clause searcher factory; that is, train a classifier that decides which arcs should
* begin new clauses.
*
* @param trainingData The training data. This is a stream of triples of:
* <ol>
* <li>The sentence containing a known extraction.</li>
* <li>The span of the subject in the sentence, as a token span.</li>
* <li>The span of the object in the sentence, as a token span.</li>
* </ol>
* @param modelPath The path to save the model to. This is useful for {@link ClauseSplitter#load(String)}.
* @param trainingDataDump The path to save the training data, as a set of labeled featurized datums.
* @param featurizer The featurizer to use for this classifier.
*
* @return A factory for creating searchers from a given dependency tree.
*/
static ClauseSplitter train(
    Stream<Pair<CoreMap, Collection<Pair<Span, Span>>>> trainingData,
    Optional<File> modelPath,
    Optional<File> trainingDataDump,
    Featurizer featurizer) {
// Parse options
LinearClassifierFactory<ClauseClassifierLabel, String> factory = new LinearClassifierFactory<>();
// Generally useful objects
OpenIE openie = new OpenIE(PropertiesUtils.asProperties("splitter.nomodel", "true", "optimizefor", "GENERAL"));
WeightedDataset<ClauseClassifierLabel, String> dataset = new WeightedDataset<>();
AtomicInteger numExamplesProcessed = new AtomicInteger(0);
final Optional<PrintWriter> datasetDumpWriter = trainingDataDump.map(file -> {
try {
return new PrintWriter(new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(trainingDataDump.get()))));
} catch (IOException e) {
throw new RuntimeIOException(e);
}
});
// Step 1: Loop over data
forceTrack("Training inference");
trainingData.forEach(rawExample -> {
CoreMap sentence = rawExample.first;
Collection<Pair<Span, Span>> spans = rawExample.second;
List<CoreLabel> tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
SemanticGraph tree = sentence.get(SemanticGraphCoreAnnotations.EnhancedDependenciesAnnotation.class);
ClauseSplitterSearchProblem problem = new ClauseSplitterSearchProblem(tree, true);
problem.search(fragmentAndScore -> {
List<Counter<String>> features = fragmentAndScore.second;
SentenceFragment fragment = fragmentAndScore.third.get();
Set<RelationTriple> extractions = new HashSet<>(openie.relationsInFragments(openie.entailmentsFromClause(fragment)));
Trilean correct = Trilean.FALSE;
RELATION_TRIPLE_LOOP: for (RelationTriple extraction : extractions) {
Span subjectGuess = Span.fromValues(extraction.subject.get(0).index() - 1, extraction.subject.get(extraction.subject.size() - 1).index());
Span objectGuess = Span.fromValues(extraction.object.get(0).index() - 1, extraction.object.get(extraction.object.size() - 1).index());
for (Pair<Span, Span> candidateGold : spans) {
Span subjectSpan = candidateGold.first;
Span objectSpan = candidateGold.second;
if ((subjectGuess.equals(subjectSpan) && objectGuess.equals(objectSpan)) || (subjectGuess.equals(objectSpan) && objectGuess.equals(subjectSpan))) {
correct = Trilean.TRUE;
break RELATION_TRIPLE_LOOP;
} else if ((Util.nerOverlap(tokens, subjectSpan, subjectGuess) && Util.nerOverlap(tokens, objectSpan, objectGuess))
    || (Util.nerOverlap(tokens, subjectSpan, objectGuess) && Util.nerOverlap(tokens, objectSpan, subjectGuess))) {
if (!correct.isTrue()) {
correct = Trilean.TRUE;
break RELATION_TRIPLE_LOOP;
}
} else {
if (!correct.isTrue()) {
correct = Trilean.UNKNOWN;
break RELATION_TRIPLE_LOOP;
}
}
}
}
if (!features.isEmpty()) {
List<Pair<Counter<String>, ClauseClassifierLabel>> decisionsToAddAsDatums = new ArrayList<>();
if (correct.isTrue()) {
for (int i = 0; i < features.size(); ++i) {
if (i == features.size() - 1) {
decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
} else {
decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
}
}
} else if (correct.isFalse()) {
decisionsToAddAsDatums.add(Pair.makePair(features.get(features.size() - 1), ClauseClassifierLabel.NOT_A_CLAUSE));
} else if (correct.isUnknown()) {
boolean isSimpleSplit = false;
for (Counter<String> feats : features) {
if (featurizer.isSimpleSplit(feats)) {
isSimpleSplit = true;
break;
}
}
if (isSimpleSplit) {
for (int i = 0; i < features.size(); ++i) {
if (i == features.size() - 1) {
decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
} else {
decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
}
}
}
}
for (Pair<Counter<String>, ClauseClassifierLabel> decision : decisionsToAddAsDatums) {
RVFDatum<ClauseClassifierLabel, String> datum = new RVFDatum<>(decision.first);
datum.setLabel(decision.second);
if (datasetDumpWriter.isPresent()) {
datasetDumpWriter.get().println(decision.second + "\t" + StringUtils.join(decision.first.entrySet().stream().map(entry -> entry.getKey() + "->" + entry.getValue()), ";"));
}
dataset.add(datum);
}
}
return true;
}, new LinearClassifier<>(new ClassicCounter<>()), Collections.emptyMap(), featurizer, 10000);
if (numExamplesProcessed.incrementAndGet() % 100 == 0) {
log("processed " + numExamplesProcessed + " training sentences: " + dataset.size() + " datums");
}
});
endTrack("Training inference");
// Close the file
if (datasetDumpWriter.isPresent()) {
datasetDumpWriter.get().close();
}
// Step 2: Train classifier
forceTrack("Training");
Classifier<ClauseClassifierLabel, String> fullClassifier = factory.trainClassifier(dataset);
endTrack("Training");
if (modelPath.isPresent()) {
Pair<Classifier<ClauseClassifierLabel, String>, Featurizer> toSave = Pair.makePair(fullClassifier, featurizer);
try {
IOUtils.writeObjectToFile(toSave, modelPath.get());
log("SUCCESS: wrote model to " + modelPath.get().getPath());
} catch (IOException e) {
log("ERROR: failed to save model to path: " + modelPath.get().getPath());
err(e);
}
}
// Step 3: Check accuracy of classifier
forceTrack("Training accuracy");
dataset.randomize(42L);
Util.dumpAccuracy(fullClassifier, dataset);
endTrack("Training accuracy");
int numFolds = 5;
forceTrack(numFolds + " fold cross-validation");
for (int fold = 0; fold < numFolds; ++fold) {
forceTrack("Fold " + (fold + 1));
forceTrack("Training");
Pair<GeneralDataset<ClauseClassifierLabel, String>, GeneralDataset<ClauseClassifierLabel, String>> foldData = dataset.splitOutFold(fold, numFolds);
Classifier<ClauseClassifierLabel, String> classifier = factory.trainClassifier(foldData.first);
endTrack("Training");
forceTrack("Test");
Util.dumpAccuracy(classifier, foldData.second);
endTrack("Test");
endTrack("Fold " + (fold + 1));
}
endTrack(numFolds + " fold cross-validation");
// Step 5: return factory
return (tree, truth) -> new ClauseSplitterSearchProblem(tree, truth, Optional.of(fullClassifier), Optional.of(featurizer));
}
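Finally, a hedged end-to-end sketch of calling train and then using the returned factory. The input stream, file path, and featurizer are placeholders, and applying the factory as a two-argument apply is an assumption about the ClauseSplitter interface rather than something shown in this excerpt.
// Illustrative only: the annotated sentences, gold spans, output path, and featurizer
// are placeholders, not values taken from CoreNLP.
Stream<Pair<CoreMap, Collection<Pair<Span, Span>>>> trainingData =
    annotatedSentences.stream().map(sentence -> Pair.makePair(sentence, goldSpansFor(sentence)));
ClauseSplitter splitter = ClauseSplitter.train(
    trainingData,
    Optional.of(new File("clauseSplitter.ser.gz")),   // where to save the model
    Optional.empty(),                                 // skip the training-data dump
    featurizer);
// The returned factory builds a search problem per dependency tree
// (assuming ClauseSplitter exposes its lambda as a two-argument apply):
ClauseSplitterSearchProblem problem = splitter.apply(dependencyTree, true);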