Usage of edu.stanford.nlp.trees.Tree in the CoreNLP project (stanfordnlp): class SplittingGrammarExtractor, method countMergeEffects.
/**
 * Recursively accumulates, into {@code deltaAnnotations}, the estimated change in
 * log-likelihood that would result from re-merging each pair of split substates
 * (2i, 2i+1) of each label, using this tree's inside scores ({@code probIn}) and
 * outside scores ({@code probOut}).  All scores are in log space.
 *
 * @param tree             current (sub)tree; leaves and boundary tags are skipped
 * @param totalStateMass   per-label mass of each substate, used to weight the merge
 * @param deltaAnnotations per-label accumulator of merge deltas; entries are created
 *                         lazily (length = half the number of current substates)
 * @param probIn           per-node inside log probabilities, one entry per substate
 * @param probOut          per-node outside log probabilities, one entry per substate
 */
public void countMergeEffects(Tree tree, Map<String, double[]> totalStateMass, Map<String, double[]> deltaAnnotations, IdentityHashMap<Tree, double[]> probIn, IdentityHashMap<Tree, double[]> probOut) {
  if (tree.isLeaf()) {
    return;
  }
  // Boundary symbols are never split, so merging does not apply to them.
  if (tree.label().value().equals(Lexicon.BOUNDARY_TAG)) {
    return;
  }
  String label = tree.label().value();
  double totalMass = 0.0;
  // NOTE(review): assumes totalStateMass has an entry for every non-boundary
  // label seen in the trees — a missing label would NPE here.  Also assumes
  // totalMass > 0; otherwise the divisions below produce NaN/-Inf weights.
  double[] stateMass = totalStateMass.get(label);
  for (double mass : stateMass) {
    totalMass += mass;
  }
  double[] nodeProbIn = probIn.get(tree);
  double[] nodeProbOut = probOut.get(tree);
  // Lazily create the delta accumulator: one slot per merged pair of substates.
  double[] nodeDelta = deltaAnnotations.get(label);
  if (nodeDelta == null) {
    nodeDelta = new double[nodeProbIn.length / 2];
    deltaAnnotations.put(label, nodeDelta);
  }
  for (int i = 0; i < nodeProbIn.length / 2; ++i) {
    // Inside score of the merged state: mass-weighted log-sum of the two
    // substates' inside scores (weights are each substate's share of the mass).
    double probInMerged = SloppyMath.logAdd(Math.log(stateMass[i * 2] / totalMass) + nodeProbIn[i * 2], Math.log(stateMass[i * 2 + 1] / totalMass) + nodeProbIn[i * 2 + 1]);
    // Outside score of the merged state: plain log-sum of the substates.
    double probOutMerged = SloppyMath.logAdd(nodeProbOut[i * 2], nodeProbOut[i * 2 + 1]);
    double probMerged = probInMerged + probOutMerged;
    // Likelihood of keeping the split: log-sum of each substate's inside*outside.
    double probUnmerged = SloppyMath.logAdd(nodeProbIn[i * 2] + nodeProbOut[i * 2], nodeProbIn[i * 2 + 1] + nodeProbOut[i * 2 + 1]);
    // Accumulate the (log) likelihood loss/gain of merging this pair at this node.
    nodeDelta[i] = nodeDelta[i] + probMerged - probUnmerged;
  }
  // Preterminal children are leaves, which would return immediately anyway;
  // this just short-circuits the recursion.
  if (tree.isPreTerminal()) {
    return;
  }
  for (Tree child : tree.children()) {
    countMergeEffects(child, totalStateMass, deltaAnnotations, probIn, probOut);
  }
}
Usage of edu.stanford.nlp.trees.Tree in the CoreNLP project (stanfordnlp): class SplittingGrammarExtractor, method mergeStates.
/**
 * Re-merges a fraction of the most recently split annotation states, chosen to
 * minimize the estimated likelihood loss.  The fraction is controlled by
 * {@code op.trainOptions.splitRecombineRate}; a rate of 0 (or less) disables merging.
 * Side effects: rebuilds the merged betas and decrements the per-state split counts.
 */
public void mergeStates() {
  if (op.trainOptions.splitRecombineRate <= 0.0) {
    return;
  }
  // We go through the machinery to sum up the temporary betas,
  // counting the total mass of each substate.
  TwoDimensionalMap<String, String, double[][]> tempUnaryBetas = new TwoDimensionalMap<>();
  ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas = new ThreeDimensionalMap<>();
  Map<String, double[]> totalStateMass = Generics.newHashMap();
  recalculateTemporaryBetas(false, totalStateMass, tempUnaryBetas, tempBinaryBetas);
  // Next, for each tree we count the effect of merging its
  // annotations.  We only consider the most recently split
  // annotations as candidates for merging.
  Map<String, double[]> deltaAnnotations = Generics.newHashMap();
  for (Tree tree : trees) {
    countMergeEffects(tree, totalStateMass, deltaAnnotations);
  }
  // Now we have a map of the (approximate) likelihood loss from
  // merging each state.  We merge the ones that provide the least
  // benefit, up to the splitRecombineRate.
  List<Triple<String, Integer, Double>> sortedDeltas = new ArrayList<>();
  for (Map.Entry<String, double[]> entry : deltaAnnotations.entrySet()) {
    double[] scores = entry.getValue();
    for (int i = 0; i < scores.length; ++i) {
      sortedDeltas.add(new Triple<>(entry.getKey(), i * 2, scores[i]));
    }
  }
  // "backwards", sorting from high to low: the highest deltas are the
  // cheapest merges, so they come first.
  sortedDeltas.sort((first, second) -> Double.compare(second.third(), first.third()));
  // Only merge a fraction of the splits based on what the user
  // originally asked for.  Clamp to [0, size - 1] so we never merge
  // everything; the max(size - 1, 0) guard also keeps subList from
  // throwing when there are no candidate deltas at all (the original
  // code produced subList(0, -1) in that case).
  int splitsToMerge = (int) (sortedDeltas.size() * op.trainOptions.splitRecombineRate);
  splitsToMerge = Math.max(0, splitsToMerge);
  splitsToMerge = Math.min(Math.max(sortedDeltas.size() - 1, 0), splitsToMerge);
  sortedDeltas = sortedDeltas.subList(0, splitsToMerge);
  System.out.println();
  System.out.println(sortedDeltas);
  Map<String, int[]> mergeCorrespondence = buildMergeCorrespondence(sortedDeltas);
  recalculateMergedBetas(mergeCorrespondence);
  // Each merged pair means one fewer split for that state.
  for (Triple<String, Integer, Double> delta : sortedDeltas) {
    stateSplitCounts.decrementCount(delta.first(), 1);
  }
}
Usage of edu.stanford.nlp.trees.Tree in the CoreNLP project (stanfordnlp): class SplittingGrammarExtractor, method outputTransitions.
/**
 * Debug-prints the tree structure along with the unary/binary transition
 * matrices attached to each internal node.  Scores are shown both in log
 * space and as probabilities.
 *
 * @param tree              node to print (recurses over all descendants)
 * @param depth             current depth, used only for indentation
 * @param unaryTransitions  per-node unary transition log scores
 * @param binaryTransitions per-node binary transition log scores
 */
public void outputTransitions(Tree tree, int depth, IdentityHashMap<Tree, double[][]> unaryTransitions, IdentityHashMap<Tree, double[][][]> binaryTransitions) {
  // Indent the node line to reflect its depth.
  for (int pad = 0; pad < depth; ++pad) {
    System.out.print(" ");
  }
  if (tree.isLeaf()) {
    System.out.println(tree.label().value());
    return;
  }
  Tree[] kids = tree.children();
  if (kids.length == 1) {
    System.out.println(tree.label().value() + " -> " + kids[0].label().value());
    // Preterminals carry no unary transition matrix worth printing.
    if (!tree.isPreTerminal()) {
      double[][] scores = unaryTransitions.get(tree);
      for (int parent = 0; parent < scores.length; ++parent) {
        for (int child = 0; child < scores[0].length; ++child) {
          for (int pad = 0; pad < depth; ++pad) {
            System.out.print(" ");
          }
          System.out.println(" " + parent + "," + child + ": " + scores[parent][child] + " | " + Math.exp(scores[parent][child]));
        }
      }
    }
  } else {
    System.out.println(tree.label().value() + " -> " + kids[0].label().value() + " " + kids[1].label().value());
    double[][][] scores = binaryTransitions.get(tree);
    for (int parent = 0; parent < scores.length; ++parent) {
      for (int left = 0; left < scores[0].length; ++left) {
        for (int right = 0; right < scores[0][0].length; ++right) {
          for (int pad = 0; pad < depth; ++pad) {
            System.out.print(" ");
          }
          System.out.println(" " + parent + "," + left + "," + right + ": " + scores[parent][left][right] + " | " + Math.exp(scores[parent][left][right]));
        }
      }
    }
  }
  // Children of a preterminal are leaves; stop here as the original does.
  if (tree.isPreTerminal()) {
    return;
  }
  for (Tree kid : kids) {
    outputTransitions(kid, depth + 1, unaryTransitions, binaryTransitions);
  }
}
Usage of edu.stanford.nlp.trees.Tree in the CoreNLP project (stanfordnlp): class SplittingGrammarExtractor, method saveTrees.
/**
 * Replaces the stored training trees with the given collections, recording a
 * per-tree weight and the total training size.  The second collection is
 * optional: it is only used when non-null and its weight is non-negative.
 *
 * @param trees1  primary training trees (always added)
 * @param weight1 weight assigned to each tree in {@code trees1}
 * @param trees2  optional secondary trees; may be null
 * @param weight2 weight for {@code trees2}; negative values skip the collection
 */
public void saveTrees(Collection<Tree> trees1, double weight1, Collection<Tree> trees2, double weight2) {
  trainSize = 0.0;
  trees.clear();
  treeWeights.clear();
  int treeCount = trees1.size();
  for (Tree tree : trees1) {
    trees.add(tree);
    treeWeights.incrementCount(tree, weight1);
    trainSize += weight1;
  }
  if (trees2 != null && weight2 >= 0.0) {
    treeCount += trees2.size();
    for (Tree tree : trees2) {
      trees.add(tree);
      treeWeights.incrementCount(tree, weight2);
      trainSize += weight2;
    }
  }
  log.info("Found " + treeCount + " trees with total weight " + trainSize);
}
Usage of edu.stanford.nlp.trees.Tree in the CoreNLP project (stanfordnlp): class SplittingGrammarExtractor, method countMergeEffects (three-argument overload).
/**
 * Entry point for counting merge effects on one tree: recomputes the tree's
 * inside/outside scores and transition tables, then accumulates the per-label
 * merge deltas over the tree's children.  The root itself is skipped since it
 * is never a merge candidate.
 *
 * @param tree             tree to analyze
 * @param totalStateMass   per-label substate mass (read by the recursive pass)
 * @param deltaAnnotations per-label accumulator of merge deltas (written to)
 */
public void countMergeEffects(Tree tree, Map<String, double[]> totalStateMass, Map<String, double[]> deltaAnnotations) {
  IdentityHashMap<Tree, double[]> insideScores = new IdentityHashMap<>();
  IdentityHashMap<Tree, double[]> outsideScores = new IdentityHashMap<>();
  IdentityHashMap<Tree, double[][]> unaryScores = new IdentityHashMap<>();
  IdentityHashMap<Tree, double[][][]> binaryScores = new IdentityHashMap<>();
  recountTree(tree, false, insideScores, outsideScores, unaryScores, binaryScores);
  // No need to count the root; start the recursion at its children.
  for (Tree child : tree.children()) {
    countMergeEffects(child, totalStateMass, deltaAnnotations, insideScores, outsideScores);
  }
}
End of aggregated usage examples.