
Example 1 with ThreeDimensionalMap

Use of edu.stanford.nlp.util.ThreeDimensionalMap in the project CoreNLP by stanfordnlp.

From the class SplittingGrammarExtractor, method mergeStates.

public void mergeStates() {
    if (op.trainOptions.splitRecombineRate <= 0.0) {
        return;
    }
    // we go through the machinery to sum up the temporary betas,
    // counting the total mass
    TwoDimensionalMap<String, String, double[][]> tempUnaryBetas = new TwoDimensionalMap<>();
    ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas = new ThreeDimensionalMap<>();
    Map<String, double[]> totalStateMass = Generics.newHashMap();
    recalculateTemporaryBetas(false, totalStateMass, tempUnaryBetas, tempBinaryBetas);
    // Next, for each tree we count the effect of merging its
    // annotations.  We only consider the most recently split
    // annotations as candidates for merging.
    Map<String, double[]> deltaAnnotations = Generics.newHashMap();
    for (Tree tree : trees) {
        countMergeEffects(tree, totalStateMass, deltaAnnotations);
    }
    // Now we have a map of the (approximate) likelihood loss from
    // merging each state.  We merge the ones that provide the least
    // benefit, up to the splitRecombineRate.
    List<Triple<String, Integer, Double>> sortedDeltas = new ArrayList<>();
    for (String state : deltaAnnotations.keySet()) {
        double[] scores = deltaAnnotations.get(state);
        for (int i = 0; i < scores.length; ++i) {
            sortedDeltas.add(new Triple<>(state, i * 2, scores[i]));
        }
    }
    Collections.sort(sortedDeltas, new Comparator<Triple<String, Integer, Double>>() {

        public int compare(Triple<String, Integer, Double> first, Triple<String, Integer, Double> second) {
            // "backwards", sorting from high to low.
            return Double.compare(second.third(), first.third());
        }

        public boolean equals(Object o) {
            return o == this;
        }
    });
    // for (Triple<String, Integer, Double> delta : sortedDeltas) {
    //   System.out.println(delta.first() + "-" + delta.second() + ": " + delta.third());
    // }
    // System.out.println("-------------");
    // Only merge a fraction of the splits based on what the user
    // originally asked for
    int splitsToMerge = (int) (sortedDeltas.size() * op.trainOptions.splitRecombineRate);
    splitsToMerge = Math.min(sortedDeltas.size() - 1, splitsToMerge);
    // clamp below last, so an empty sortedDeltas gives 0 rather than -1
    splitsToMerge = Math.max(0, splitsToMerge);
    sortedDeltas = sortedDeltas.subList(0, splitsToMerge);
    System.out.println();
    System.out.println(sortedDeltas);
    Map<String, int[]> mergeCorrespondence = buildMergeCorrespondence(sortedDeltas);
    recalculateMergedBetas(mergeCorrespondence);
    for (Triple<String, Integer, Double> delta : sortedDeltas) {
        stateSplitCounts.decrementCount(delta.first(), 1);
    }
}
Also used: ThreeDimensionalMap (edu.stanford.nlp.util.ThreeDimensionalMap), ArrayList (java.util.ArrayList), MutableDouble (edu.stanford.nlp.util.MutableDouble), Triple (edu.stanford.nlp.util.Triple), Tree (edu.stanford.nlp.trees.Tree), TwoDimensionalMap (edu.stanford.nlp.util.TwoDimensionalMap)
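The candidate-selection step in mergeStates (flatten the per-state deltas into triples, sort them from high to low, keep a fraction set by splitRecombineRate) can be sketched in isolation as follows. This is a minimal, self-contained sketch, not CoreNLP code: the class MergeCandidateSketch, the method selectMerges and the record Candidate (standing in for Triple<String, Integer, Double>) are hypothetical names, and the record requires Java 16 or later. It keeps the upper bound of size() - 1 used above but applies the lower clamp last, so an empty candidate list yields an empty selection instead of a negative subList bound.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class MergeCandidateSketch {

    // Hypothetical stand-in for CoreNLP's Triple<String, Integer, Double>.
    record Candidate(String state, int substate, double delta) {}

    // Mirrors the selection step of mergeStates(): flatten the per-state
    // delta arrays, sort from high to low, and keep a fraction of them.
    static List<Candidate> selectMerges(Map<String, double[]> deltaAnnotations,
                                        double splitRecombineRate) {
        List<Candidate> sorted = new ArrayList<>();
        if (splitRecombineRate <= 0.0) {
            return sorted;
        }
        for (Map.Entry<String, double[]> e : deltaAnnotations.entrySet()) {
            double[] scores = e.getValue();
            for (int i = 0; i < scores.length; ++i) {
                // The i-th score is recorded against substate index 2 * i, as in the source.
                sorted.add(new Candidate(e.getKey(), i * 2, scores[i]));
            }
        }
        // "Backwards" order: highest delta first.
        sorted.sort((a, b) -> Double.compare(b.delta(), a.delta()));
        int n = (int) (sorted.size() * splitRecombineRate);
        // Upper clamp first, lower clamp last, so an empty list gives 0, not -1.
        n = Math.max(0, Math.min(sorted.size() - 1, n));
        return new ArrayList<>(sorted.subList(0, n));
    }

    public static void main(String[] args) {
        Map<String, double[]> deltas = Map.of(
            "NP", new double[] {-0.5, -2.0},
            "VP", new double[] {-1.0});
        System.out.println(selectMerges(deltas, 0.5));
    }
}
```

With the sample deltas in main and a rate of 0.5, one of the three candidates is kept: the NP candidate at substate 0, which has the highest delta.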

Example 2 with ThreeDimensionalMap

Use of edu.stanford.nlp.util.ThreeDimensionalMap in the project CoreNLP by stanfordnlp.

From the class SplittingGrammarExtractor, method recalculateMergedBetas.

public void recalculateMergedBetas(Map<String, int[]> mergeCorrespondence) {
    TwoDimensionalMap<String, String, double[][]> tempUnaryBetas = new TwoDimensionalMap<>();
    ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas = new ThreeDimensionalMap<>();
    tempWordIndex = new HashIndex<>();
    tempTagIndex = new HashIndex<>();
    tempLex = op.tlpParams.lex(op, tempWordIndex, tempTagIndex);
    tempLex.initializeTraining(trainSize);
    for (Tree tree : trees) {
        double treeWeight = treeWeights.getCount(tree);
        double[] stateWeights = { Math.log(treeWeight) };
        tempLex.incrementTreesRead(treeWeight);
        IdentityHashMap<Tree, double[][]> oldUnaryTransitions = new IdentityHashMap<>();
        IdentityHashMap<Tree, double[][][]> oldBinaryTransitions = new IdentityHashMap<>();
        recountTree(tree, false, oldUnaryTransitions, oldBinaryTransitions);
        IdentityHashMap<Tree, double[][]> unaryTransitions = new IdentityHashMap<>();
        IdentityHashMap<Tree, double[][][]> binaryTransitions = new IdentityHashMap<>();
        mergeTransitions(tree, oldUnaryTransitions, oldBinaryTransitions, unaryTransitions, binaryTransitions, stateWeights, mergeCorrespondence);
        recalculateTemporaryBetas(tree, stateWeights, 0, unaryTransitions, binaryTransitions, null, tempUnaryBetas, tempBinaryBetas);
    }
    tempLex.finishTraining();
    useNewBetas(false, tempUnaryBetas, tempBinaryBetas);
}
Also used: ThreeDimensionalMap (edu.stanford.nlp.util.ThreeDimensionalMap), IdentityHashMap (java.util.IdentityHashMap), Tree (edu.stanford.nlp.trees.Tree), TwoDimensionalMap (edu.stanford.nlp.util.TwoDimensionalMap)
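Both methods key tempBinaryBetas by three strings and store a double[][][] per key. Assuming the three keys are the parent, left-child and right-child states of a binary rule, and the array is indexed by their substates (the excerpts do not show this directly), the accumulation pattern can be sketched with a hypothetical nested-HashMap class. ThreeKeyMap and addBinaryCount are illustrative names, not the API of edu.stanford.nlp.util.ThreeDimensionalMap, and for simplicity the sketch adds plain weights rather than the log-space weights (Math.log(treeWeight)) that recalculateMergedBetas builds.

```java
import java.util.HashMap;
import java.util.Map;

public class ThreeKeyMapSketch {

    // Hypothetical, simplified three-key map built from nested HashMaps.
    static class ThreeKeyMap<K1, K2, K3, V> {
        private final Map<K1, Map<K2, Map<K3, V>>> map = new HashMap<>();

        V get(K1 a, K2 b, K3 c) {
            Map<K2, Map<K3, V>> level2 = map.get(a);
            if (level2 == null) return null;
            Map<K3, V> level3 = level2.get(b);
            return level3 == null ? null : level3.get(c);
        }

        void put(K1 a, K2 b, K3 c, V value) {
            map.computeIfAbsent(a, k -> new HashMap<>())
               .computeIfAbsent(b, k -> new HashMap<>())
               .put(c, value);
        }

        // Number of (a, b, c) keys that currently hold a value.
        int size() {
            int n = 0;
            for (Map<K2, Map<K3, V>> level2 : map.values()) {
                for (Map<K3, V> level3 : level2.values()) {
                    n += level3.size();
                }
            }
            return n;
        }
    }

    // Adds a plain (non-log) weight for rule parent -> left right at substates
    // (p, l, r), creating a zeroed array of shape dims on first use.
    static void addBinaryCount(ThreeKeyMap<String, String, String, double[][][]> betas,
                               String parent, String left, String right,
                               int p, int l, int r, int[] dims, double weight) {
        double[][][] beta = betas.get(parent, left, right);
        if (beta == null) {
            beta = new double[dims[0]][dims[1]][dims[2]];
            betas.put(parent, left, right, beta);
        }
        beta[p][l][r] += weight;
    }

    public static void main(String[] args) {
        ThreeKeyMap<String, String, String, double[][][]> betas = new ThreeKeyMap<>();
        int[] dims = {2, 2, 2};
        addBinaryCount(betas, "S", "NP", "VP", 0, 1, 0, dims, 0.5);
        addBinaryCount(betas, "S", "NP", "VP", 0, 1, 0, dims, 0.25);
        System.out.println(betas.size() + " " + betas.get("S", "NP", "VP")[0][1][0]); // prints "1 0.75"
    }
}
```

Nesting one HashMap per key level means a lookup costs three hash probes, and all entries sharing a first key can be reached through a single inner map, at the cost of extra map objects when the keys are sparse.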

Aggregations (classes used by the examples above, with the number of examples that use each)

- Tree (edu.stanford.nlp.trees.Tree): 2
- ThreeDimensionalMap (edu.stanford.nlp.util.ThreeDimensionalMap): 2
- TwoDimensionalMap (edu.stanford.nlp.util.TwoDimensionalMap): 2
- MutableDouble (edu.stanford.nlp.util.MutableDouble): 1
- Triple (edu.stanford.nlp.util.Triple): 1
- ArrayList (java.util.ArrayList): 1
- IdentityHashMap (java.util.IdentityHashMap): 1