Use of edu.stanford.nlp.util.Lazy in the project CoreNLP by stanfordnlp.
The example is the method searchImplementation from the class ForwardEntailerSearchProblem.
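Lazy wraps a java.util.function.Supplier and defers running it until the first call to get(). Below is a minimal sketch of that behavior, assuming (as the repeated get() calls in the method rely on) that Lazy.of caches its result after the first evaluation; the class name and printed strings are illustrative only:

import edu.stanford.nlp.util.Lazy;

public class LazyDemo {
  public static void main(String[] args) {
    // Lazy.of only stores the supplier; nothing is computed yet.
    Lazy<String> expensive = Lazy.of(() -> {
      System.out.println("computing...");
      return "value";
    });
    System.out.println(expensive.get());  // runs the supplier: prints "computing..." then "value"
    System.out.println(expensive.get());  // prints "value" only; the cached result is reused
  }
}

In searchImplementation, the same pattern postpones copying a SemanticGraph until the search actually commits to a deletion.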
/**
* The search algorithm, starting with a full sentence and iteratively shortening it to its entailed sentences.
*
* @return A list of search results, corresponding to shortenings of the sentence.
*/
@SuppressWarnings("unchecked")
private List<SearchResult> searchImplementation() {
  // Pre-process the tree
  SemanticGraph parseTree = new SemanticGraph(this.parseTree);
  assert Util.isTree(parseTree);
  // (remove common determiners)
  List<String> determinerRemovals = new ArrayList<>();
  parseTree.getLeafVertices().stream()
      .filter(vertex -> "the".equalsIgnoreCase(vertex.word()) ||
          "a".equalsIgnoreCase(vertex.word()) ||
          "an".equalsIgnoreCase(vertex.word()) ||
          "this".equalsIgnoreCase(vertex.word()) ||
          "that".equalsIgnoreCase(vertex.word()) ||
          "those".equalsIgnoreCase(vertex.word()) ||
          "these".equalsIgnoreCase(vertex.word()))
      .forEach(vertex -> {
        parseTree.removeVertex(vertex);
        assert Util.isTree(parseTree);
        determinerRemovals.add("det");
      });
  // (cut conj:and edges)
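  // (the removed edges are remembered in andsToAdd and restored on the result trees below)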
  Set<SemanticGraphEdge> andsToAdd = new HashSet<>();
  for (IndexedWord vertex : parseTree.vertexSet()) {
    if (parseTree.inDegree(vertex) > 1) {
      SemanticGraphEdge conjAnd = null;
      for (SemanticGraphEdge edge : parseTree.incomingEdgeIterable(vertex)) {
        if ("conj:and".equals(edge.getRelation().toString())) {
          conjAnd = edge;
        }
      }
      if (conjAnd != null) {
        parseTree.removeEdge(conjAnd);
        assert Util.isTree(parseTree);
        andsToAdd.add(conjAnd);
      }
    }
  }
  // Clean the tree
  Util.cleanTree(parseTree);
  assert Util.isTree(parseTree);
  // Find the subject / object split
  // This takes max O(n^2) time, expected O(n*log(n)) time.
  // Optimal is O(n), but I'm too lazy to implement it.
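  // (bit i of isSubject is set when the token with 1-based index i+1 lies inside a subject subtree)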
  BitSet isSubject = new BitSet(256);
  for (IndexedWord vertex : parseTree.vertexSet()) {
    // Search up the tree for a subj node; if found, mark that vertex as a subject.
    Iterator<SemanticGraphEdge> incomingEdges = parseTree.incomingEdgeIterator(vertex);
    SemanticGraphEdge edge = null;
    if (incomingEdges.hasNext()) {
      edge = incomingEdges.next();
    }
    int numIters = 0;
    while (edge != null) {
      if (edge.getRelation().toString().endsWith("subj")) {
        assert vertex.index() > 0;
        isSubject.set(vertex.index() - 1);
        break;
      }
      incomingEdges = parseTree.incomingEdgeIterator(edge.getGovernor());
      if (incomingEdges.hasNext()) {
        edge = incomingEdges.next();
      } else {
        edge = null;
      }
      numIters += 1;
      if (numIters > 100) {
        // log.error("tree has apparent depth > 100");
        return Collections.EMPTY_LIST;
      }
    }
  }
  // Outputs
  List<SearchResult> results = new ArrayList<>();
  if (!determinerRemovals.isEmpty()) {
    if (andsToAdd.isEmpty()) {
      double score = Math.pow(weights.deletionProbability("det"), (double) determinerRemovals.size());
      assert !Double.isNaN(score);
      assert !Double.isInfinite(score);
      results.add(new SearchResult(parseTree, determinerRemovals, score));
    } else {
      SemanticGraph treeWithAnds = new SemanticGraph(parseTree);
      assert Util.isTree(treeWithAnds);
      for (SemanticGraphEdge and : andsToAdd) {
        treeWithAnds.addEdge(and.getGovernor(), and.getDependent(), and.getRelation(), Double.NEGATIVE_INFINITY, false);
      }
      assert Util.isTree(treeWithAnds);
      results.add(new SearchResult(treeWithAnds, determinerRemovals,
          Math.pow(weights.deletionProbability("det"), (double) determinerRemovals.size())));
    }
  }
  // Initialize the search
  assert Util.isTree(parseTree);
  List<IndexedWord> topologicalVertices;
  try {
    topologicalVertices = parseTree.topologicalSort();
  } catch (IllegalStateException e) {
    // log.info("Could not topologically sort the vertices! Using left-to-right traversal.");
    topologicalVertices = parseTree.vertexListSorted();
  }
  if (topologicalVertices.isEmpty()) {
    return results;
  }
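  // (depth-first search: the Stack fringe holds states, each carrying its own tree, deletion mask, and score)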
  Stack<SearchState> fringe = new Stack<>();
  fringe.push(new SearchState(new BitSet(256), 0, parseTree, null, null, 1.0));
  // Start the search
  int numTicks = 0;
  while (!fringe.isEmpty()) {
    // Overhead with popping a node.
    if (numTicks >= maxTicks) {
      return results;
    }
    numTicks += 1;
    if (results.size() >= maxResults) {
      return results;
    }
    SearchState state = fringe.pop();
    assert state.score > 0.0;
    IndexedWord currentWord = topologicalVertices.get(state.currentIndex);
    // Push the case where we don't delete
    int nextIndex = state.currentIndex + 1;
    int numIters = 0;
    while (nextIndex < topologicalVertices.size()) {
      IndexedWord nextWord = topologicalVertices.get(nextIndex);
      assert nextWord.index() > 0;
      if (!state.deletionMask.get(nextWord.index() - 1)) {
        fringe.push(new SearchState(state.deletionMask, nextIndex, state.tree, null, state, state.score));
        break;
      } else {
        nextIndex += 1;
      }
      numIters += 1;
      if (numIters > 10000) {
        // log.error("logic error (apparent infinite loop); returning");
        return results;
      }
    }
    // Check if we can delete this subtree
    boolean canDelete = !state.tree.getFirstRoot().equals(currentWord);
    for (SemanticGraphEdge edge : state.tree.incomingEdgeIterable(currentWord)) {
      if ("CD".equals(edge.getGovernor().tag())) {
        canDelete = false;
      } else {
        // Get token information
        CoreLabel token = edge.getDependent().backingLabel();
        OperatorSpec operator;
        NaturalLogicRelation lexicalRelation;
        Polarity tokenPolarity = token.get(NaturalLogicAnnotations.PolarityAnnotation.class);
        if (tokenPolarity == null) {
          tokenPolarity = Polarity.DEFAULT;
        }
        // Get the relation for this deletion
        if ((operator = token.get(NaturalLogicAnnotations.OperatorAnnotation.class)) != null) {
          lexicalRelation = operator.instance.deleteRelation;
        } else {
          assert edge.getDependent().index() > 0;
          lexicalRelation = NaturalLogicRelation.forDependencyDeletion(edge.getRelation().toString(),
              isSubject.get(edge.getDependent().index() - 1));
        }
        NaturalLogicRelation projectedRelation = tokenPolarity.projectLexicalRelation(lexicalRelation);
        // Make sure this is a valid entailment
        if (!projectedRelation.applyToTruthValue(truthOfPremise).isTrue()) {
          canDelete = false;
        }
      }
    }
    if (canDelete) {
      // Register the deletion
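      // (the tree copy and updated mask are built lazily, so the work below is only
      //  materialized if the deletion survives the newScore > 0 check before the first get())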
      Lazy<Pair<SemanticGraph, BitSet>> treeWithDeletionsAndNewMask = Lazy.of(() -> {
        SemanticGraph impl = new SemanticGraph(state.tree);
        BitSet newMask = state.deletionMask;
        for (IndexedWord vertex : state.tree.descendants(currentWord)) {
          impl.removeVertex(vertex);
          assert vertex.index() > 0;
          newMask.set(vertex.index() - 1);
          assert newMask.get(vertex.index() - 1);
        }
        return Pair.makePair(impl, newMask);
      });
      // Compute the score of the sentence
      double newScore = state.score;
      for (SemanticGraphEdge edge : state.tree.incomingEdgeIterable(currentWord)) {
        double multiplier = weights.deletionProbability(edge, state.tree.outgoingEdgeIterable(edge.getGovernor()));
        assert !Double.isNaN(multiplier);
        assert !Double.isInfinite(multiplier);
        newScore *= multiplier;
      }
      // Register the result
      if (newScore > 0.0) {
        SemanticGraph resultTree = new SemanticGraph(treeWithDeletionsAndNewMask.get().first);
        andsToAdd.stream()
            .filter(edge -> resultTree.containsVertex(edge.getGovernor()) &&
                resultTree.containsVertex(edge.getDependent()))
            .forEach(edge -> resultTree.addEdge(edge.getGovernor(), edge.getDependent(), edge.getRelation(),
                Double.NEGATIVE_INFINITY, false));
        results.add(new SearchResult(resultTree,
            aggregateDeletedEdges(state, state.tree.incomingEdgeIterable(currentWord), determinerRemovals),
            newScore));
        // Push the state with this subtree deleted
        nextIndex = state.currentIndex + 1;
        numIters = 0;
        while (nextIndex < topologicalVertices.size()) {
          IndexedWord nextWord = topologicalVertices.get(nextIndex);
          BitSet newMask = treeWithDeletionsAndNewMask.get().second;
          SemanticGraph treeWithDeletions = treeWithDeletionsAndNewMask.get().first;
          if (!newMask.get(nextWord.index() - 1)) {
            assert treeWithDeletions.containsVertex(topologicalVertices.get(nextIndex));
            fringe.push(new SearchState(newMask, nextIndex, treeWithDeletions, null, state, newScore));
            break;
          } else {
            nextIndex += 1;
          }
          numIters += 1;
          if (numIters > 10000) {
            // log.error("logic error (apparent infinite loop); returning");
            return results;
          }
        }
      }
    }
  }
  // Return
  return results;
}