Use of edu.stanford.nlp.util.ThreeDimensionalMap in project CoreNLP by stanfordnlp.
In the class SplittingGrammarExtractor, method mergeStates:
public void mergeStates() {
  if (op.trainOptions.splitRecombineRate <= 0.0) {
    return;
  }

  // we go through the machinery to sum up the temporary betas,
  // counting the total mass
  TwoDimensionalMap<String, String, double[][]> tempUnaryBetas = new TwoDimensionalMap<>();
  ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas = new ThreeDimensionalMap<>();
  Map<String, double[]> totalStateMass = Generics.newHashMap();
  recalculateTemporaryBetas(false, totalStateMass, tempUnaryBetas, tempBinaryBetas);

  // Next, for each tree we count the effect of merging its
  // annotations.  We only consider the most recently split
  // annotations as candidates for merging.
  Map<String, double[]> deltaAnnotations = Generics.newHashMap();
  for (Tree tree : trees) {
    countMergeEffects(tree, totalStateMass, deltaAnnotations);
  }

  // Now we have a map of the (approximate) likelihood loss from
  // merging each state.  We merge the ones that provide the least
  // benefit, up to the splitRecombineRate
  List<Triple<String, Integer, Double>> sortedDeltas = new ArrayList<>();
  for (String state : deltaAnnotations.keySet()) {
    double[] scores = deltaAnnotations.get(state);
    for (int i = 0; i < scores.length; ++i) {
      sortedDeltas.add(new Triple<>(state, i * 2, scores[i]));
    }
  }
  Collections.sort(sortedDeltas, new Comparator<Triple<String, Integer, Double>>() {
    public int compare(Triple<String, Integer, Double> first, Triple<String, Integer, Double> second) {
      // "backwards", sorting from high to low.
      return Double.compare(second.third(), first.third());
    }

    public boolean equals(Object o) {
      return o == this;
    }
  });

  // for (Triple<String, Integer, Double> delta : sortedDeltas) {
  //   System.out.println(delta.first() + "-" + delta.second() + ": " + delta.third());
  // }
  // System.out.println("-------------");

  // Only merge a fraction of the splits based on what the user
  // originally asked for
  int splitsToMerge = (int) (sortedDeltas.size() * op.trainOptions.splitRecombineRate);
  splitsToMerge = Math.max(0, splitsToMerge);
  splitsToMerge = Math.min(sortedDeltas.size() - 1, splitsToMerge);
  sortedDeltas = sortedDeltas.subList(0, splitsToMerge);

  System.out.println();
  System.out.println(sortedDeltas);

  Map<String, int[]> mergeCorrespondence = buildMergeCorrespondence(sortedDeltas);
  recalculateMergedBetas(mergeCorrespondence);

  for (Triple<String, Integer, Double> delta : sortedDeltas) {
    stateSplitCounts.decrementCount(delta.first(), 1);
  }
}
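In the method above, the ThreeDimensionalMap only appears as a container that is created and handed to recalculateTemporaryBetas: it holds the binary-rule beta arrays keyed by parent, left-child, and right-child state labels. As a minimal, self-contained sketch of that keying scheme (the labels, array sizes, and score values here are hypothetical, not taken from the extractor), the three-key put/get accessors work like this:

import edu.stanford.nlp.util.ThreeDimensionalMap;

public class ThreeDimensionalMapSketch {
  public static void main(String[] args) {
    // Betas for binary rules, keyed by parent, left-child, and right-child labels;
    // each value is a score array indexed [parentSubstate][leftSubstate][rightSubstate].
    ThreeDimensionalMap<String, String, String, double[][][]> binaryBetas = new ThreeDimensionalMap<>();

    // Hypothetical entry for NP -> DT NN with two substates per symbol.
    double[][][] npScores = new double[2][2][2];
    npScores[0][1][1] = Math.log(0.25);
    binaryBetas.put("NP", "DT", "NN", npScores);

    // The same three keys retrieve the stored array.
    double[][][] scores = binaryBetas.get("NP", "DT", "NN");
    System.out.println(scores[0][1][1]);
  }
}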
Use of edu.stanford.nlp.util.ThreeDimensionalMap in project CoreNLP by stanfordnlp.
In the class SplittingGrammarExtractor, method recalculateMergedBetas:
public void recalculateMergedBetas(Map<String, int[]> mergeCorrespondence) {
  TwoDimensionalMap<String, String, double[][]> tempUnaryBetas = new TwoDimensionalMap<>();
  ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas = new ThreeDimensionalMap<>();

  // Train a temporary lexicon alongside the recomputed betas
  tempWordIndex = new HashIndex<>();
  tempTagIndex = new HashIndex<>();
  tempLex = op.tlpParams.lex(op, tempWordIndex, tempTagIndex);
  tempLex.initializeTraining(trainSize);

  for (Tree tree : trees) {
    double treeWeight = treeWeights.getCount(tree);
    double[] stateWeights = { Math.log(treeWeight) };
    tempLex.incrementTreesRead(treeWeight);

    // Recompute the transition scores for this tree under the current
    // (unmerged) grammar ...
    IdentityHashMap<Tree, double[][]> oldUnaryTransitions = new IdentityHashMap<>();
    IdentityHashMap<Tree, double[][][]> oldBinaryTransitions = new IdentityHashMap<>();
    recountTree(tree, false, oldUnaryTransitions, oldBinaryTransitions);

    // ... then collapse them according to the merge correspondence ...
    IdentityHashMap<Tree, double[][]> unaryTransitions = new IdentityHashMap<>();
    IdentityHashMap<Tree, double[][][]> binaryTransitions = new IdentityHashMap<>();
    mergeTransitions(tree, oldUnaryTransitions, oldBinaryTransitions, unaryTransitions, binaryTransitions, stateWeights, mergeCorrespondence);

    // ... and accumulate the merged transitions into the temporary betas
    recalculateTemporaryBetas(tree, stateWeights, 0, unaryTransitions, binaryTransitions, null, tempUnaryBetas, tempBinaryBetas);
  }

  tempLex.finishTraining();
  useNewBetas(false, tempUnaryBetas, tempBinaryBetas);
}
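The mergeCorrespondence argument is an ordinary Map from state label to an int[]; judging from how buildMergeCorrespondence and mergeTransitions are used above, each array appears to map an old substate index to the substate it collapses into after merging. A small hypothetical sketch of building such a map (the state names and index mappings are illustrative only):

import java.util.Map;
import edu.stanford.nlp.util.Generics;

public class MergeCorrespondenceSketch {
  public static void main(String[] args) {
    Map<String, int[]> mergeCorrespondence = Generics.newHashMap();
    // Hypothetical: "NP" had four substates and the pair (2, 3) is merged,
    // so old indices {0, 1, 2, 3} map to {0, 1, 2, 2}.
    mergeCorrespondence.put("NP", new int[] { 0, 1, 2, 2 });
    // "VP" keeps both of its substates, so the mapping is the identity.
    mergeCorrespondence.put("VP", new int[] { 0, 1 });
  }
}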