Example 11 with PriorityQueue

Use of java.util.PriorityQueue in project CoreNLP by stanfordnlp.

The main method of the FindNearestNeighbors class:

public static void main(String[] args) throws Exception {
    String modelPath = null;
    String outputPath = null;
    String testTreebankPath = null;
    FileFilter testTreebankFilter = null;
    List<String> unusedArgs = new ArrayList<>();
    for (int argIndex = 0; argIndex < args.length; ) {
        if (args[argIndex].equalsIgnoreCase("-model")) {
            modelPath = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
            Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
            argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
            testTreebankPath = treebankDescription.first();
            testTreebankFilter = treebankDescription.second();
        } else if (args[argIndex].equalsIgnoreCase("-output")) {
            outputPath = args[argIndex + 1];
            argIndex += 2;
        } else {
            unusedArgs.add(args[argIndex++]);
        }
    }
    if (modelPath == null) {
        throw new IllegalArgumentException("Need to specify -model");
    }
    if (testTreebankPath == null) {
        throw new IllegalArgumentException("Need to specify -testTreebank");
    }
    if (outputPath == null) {
        throw new IllegalArgumentException("Need to specify -output");
    }
    String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
    LexicalizedParser lexparser = LexicalizedParser.loadModel(modelPath, newArgs);
    Treebank testTreebank = null;
    if (testTreebankPath != null) {
        log.info("Reading in trees from " + testTreebankPath);
        if (testTreebankFilter != null) {
            log.info("Filtering on " + testTreebankFilter);
        }
        testTreebank = lexparser.getOp().tlpParams.memoryTreebank();
        testTreebank.loadPath(testTreebankPath, testTreebankFilter);
        log.info("Read in " + testTreebank.size() + " trees for testing");
    }
    FileWriter out = new FileWriter(outputPath);
    BufferedWriter bout = new BufferedWriter(out);
    log.info("Parsing " + testTreebank.size() + " trees");
    int count = 0;
    List<ParseRecord> records = Generics.newArrayList();
    for (Tree goldTree : testTreebank) {
        List<Word> tokens = goldTree.yieldWords();
        ParserQuery parserQuery = lexparser.parserQuery();
        if (!parserQuery.parse(tokens)) {
            throw new AssertionError("Could not parse: " + tokens);
        }
        if (!(parserQuery instanceof RerankingParserQuery)) {
            throw new IllegalArgumentException("Expected a LexicalizedParser with a Reranker attached");
        }
        RerankingParserQuery rpq = (RerankingParserQuery) parserQuery;
        if (!(rpq.rerankerQuery() instanceof DVModelReranker.Query)) {
            throw new IllegalArgumentException("Expected a LexicalizedParser with a DVModel attached");
        }
        DeepTree tree = ((DVModelReranker.Query) rpq.rerankerQuery()).getDeepTrees().get(0);
        SimpleMatrix rootVector = null;
        for (Map.Entry<Tree, SimpleMatrix> entry : tree.getVectors().entrySet()) {
            if (entry.getKey().label().value().equals("ROOT")) {
                rootVector = entry.getValue();
                break;
            }
        }
        if (rootVector == null) {
            throw new AssertionError("Could not find root nodevector");
        }
        out.write(tokens + "\n");
        out.write(tree.getTree() + "\n");
        for (int i = 0; i < rootVector.getNumElements(); ++i) {
            out.write("  " + rootVector.get(i));
        }
        out.write("\n\n\n");
        count++;
        if (count % 10 == 0) {
            log.info("  " + count);
        }
        records.add(new ParseRecord(tokens, goldTree, tree.getTree(), rootVector, tree.getVectors()));
    }
    log.info("  done parsing");
    List<Pair<Tree, SimpleMatrix>> subtrees = Generics.newArrayList();
    for (ParseRecord record : records) {
        for (Map.Entry<Tree, SimpleMatrix> entry : record.nodeVectors.entrySet()) {
            // maxLength is a class-level constant of FindNearestNeighbors limiting subtree size
            if (entry.getKey().getLeaves().size() <= maxLength) {
                subtrees.add(Pair.makePair(entry.getKey(), entry.getValue()));
            }
        }
    }
    log.info("There are " + subtrees.size() + " subtrees in the set of trees");
    PriorityQueue<ScoredObject<Pair<Tree, Tree>>> bestmatches = new PriorityQueue<>(101, ScoredComparator.DESCENDING_COMPARATOR);
    for (int i = 0; i < subtrees.size(); ++i) {
        log.info(subtrees.get(i).first().yieldWords());
        log.info(subtrees.get(i).first());
        for (int j = 0; j < subtrees.size(); ++j) {
            if (i == j) {
                continue;
            }
            // TODO: look at basic category?
            double normF = subtrees.get(i).second().minus(subtrees.get(j).second()).normF();
            bestmatches.add(new ScoredObject<>(Pair.makePair(subtrees.get(i).first(), subtrees.get(j).first()), normF));
            if (bestmatches.size() > 100) {
                bestmatches.poll();
            }
        }
        List<ScoredObject<Pair<Tree, Tree>>> ordered = Generics.newArrayList();
        while (bestmatches.size() > 0) {
            ordered.add(bestmatches.poll());
        }
        Collections.reverse(ordered);
        for (ScoredObject<Pair<Tree, Tree>> pair : ordered) {
            log.info(" MATCHED " + pair.object().second.yieldWords() + " ... " + pair.object().second() + " with a score of " + pair.score());
        }
        log.info();
        log.info();
        bestmatches.clear();
    }
    /*
    for (int i = 0; i < records.size(); ++i) {
      if (i % 10 == 0) {
        log.info("  " + i);
      }
      List<ScoredObject<ParseRecord>> scored = Generics.newArrayList();
      for (int j = 0; j < records.size(); ++j) {
        if (i == j) continue;

        double score = 0.0;
        int matches = 0;
        for (Map.Entry<Tree, SimpleMatrix> first : records.get(i).nodeVectors.entrySet()) {
          for (Map.Entry<Tree, SimpleMatrix> second : records.get(j).nodeVectors.entrySet()) {
            String firstBasic = dvparser.dvModel.basicCategory(first.getKey().label().value());
            String secondBasic = dvparser.dvModel.basicCategory(second.getKey().label().value());
            if (firstBasic.equals(secondBasic)) {
              ++matches;
              double normF = first.getValue().minus(second.getValue()).normF();
              score += normF * normF;
            }
          }
        }
        if (matches == 0) {
          score = Double.POSITIVE_INFINITY;
        } else {
          score = score / matches;
        }
        //double score = records.get(i).vector.minus(records.get(j).vector).normF();
        scored.add(new ScoredObject<ParseRecord>(records.get(j), score));
      }
      Collections.sort(scored, ScoredComparator.ASCENDING_COMPARATOR);

      out.write(records.get(i).sentence.toString() + "\n");
      for (int j = 0; j < numNeighbors; ++j) {
        out.write("   " + scored.get(j).score() + ": " + scored.get(j).object().sentence + "\n");
      }
      out.write("\n\n");
    }
    log.info();
    */
    bout.flush();
    out.flush();
    out.close();
}
Also used : Word(edu.stanford.nlp.ling.Word) RerankingParserQuery(edu.stanford.nlp.parser.lexparser.RerankingParserQuery) ParserQuery(edu.stanford.nlp.parser.common.ParserQuery) Treebank(edu.stanford.nlp.trees.Treebank) LexicalizedParser(edu.stanford.nlp.parser.lexparser.LexicalizedParser) FileWriter(java.io.FileWriter) ArrayList(java.util.ArrayList) BufferedWriter(java.io.BufferedWriter) SimpleMatrix(org.ejml.simple.SimpleMatrix) ScoredObject(edu.stanford.nlp.util.ScoredObject) DeepTree(edu.stanford.nlp.trees.DeepTree) Tree(edu.stanford.nlp.trees.Tree) FileFilter(java.io.FileFilter) Pair(edu.stanford.nlp.util.Pair) PriorityQueue(java.util.PriorityQueue) IdentityHashMap(java.util.IdentityHashMap) Map(java.util.Map)
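
The loop over subtrees above uses a common bounded-queue idiom for nearest-neighbor search: every candidate pair is added to a PriorityQueue ordered so that the worst match (largest Frobenius distance) sits at the head, and whenever the queue exceeds the limit (here 100) that worst match is polled away. Draining the queue and reversing it then lists the matches from closest to farthest. Below is a minimal, self-contained sketch of the same idiom; the names TopKByDistance, kNearest, and distance are illustrative, not CoreNLP API.

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

public class TopKByDistance {

    // Keeps the k candidates closest to the query vector.  The comparator is
    // reversed so the *farthest* surviving candidate is at the head of the
    // queue, making eviction on overflow a single poll(), just like
    // bestmatches in the example above.
    public static List<double[]> kNearest(List<double[]> candidates, double[] query, int k) {
        PriorityQueue<double[]> best = new PriorityQueue<>(k + 1,
                Comparator.comparingDouble((double[] c) -> distance(c, query)).reversed());
        for (double[] candidate : candidates) {
            best.add(candidate);
            if (best.size() > k) {
                best.poll(); // drop the current farthest candidate
            }
        }
        // Draining gives farthest-first order; reverse to get nearest-first.
        List<double[]> ordered = new ArrayList<>();
        while (!best.isEmpty()) {
            ordered.add(best.poll());
        }
        Collections.reverse(ordered);
        return ordered;
    }

    // Euclidean distance, standing in for the Frobenius norm of the
    // difference between two node vectors in the example.
    private static double distance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }
}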

Example 12 with PriorityQueue

Use of java.util.PriorityQueue in project CoreNLP by stanfordnlp.

The trainTree method of the PerceptronModel class:

private Pair<Integer, Integer> trainTree(int index, List<Tree> binarizedTrees, List<List<Transition>> transitionLists, List<Update> updates, Oracle oracle) {
    int numCorrect = 0;
    int numWrong = 0;
    Tree tree = binarizedTrees.get(index);
    ReorderingOracle reorderer = null;
    if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
        reorderer = new ReorderingOracle(op);
    }
    // ORACLE training: compare the model's highest-scoring transition with the
    // oracle's at each state, record weight updates where they disagree, and
    // advance the state with the predicted transition.
    if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.ORACLE) {
        State state = ShiftReduceParser.initialStateFromGoldTagTree(tree);
        while (!state.isFinished()) {
            List<String> features = featureFactory.featurize(state);
            ScoredObject<Integer> prediction = findHighestScoringTransition(state, features, true);
            if (prediction == null) {
                throw new AssertionError("Did not find a legal transition");
            }
            int predictedNum = prediction.object();
            Transition predicted = transitionIndex.get(predictedNum);
            OracleTransition gold = oracle.goldTransition(index, state);
            if (gold.isCorrect(predicted)) {
                numCorrect++;
                if (gold.transition != null && !gold.transition.equals(predicted)) {
                    int transitionNum = transitionIndex.indexOf(gold.transition);
                    if (transitionNum < 0) {
                        // only possible when the parser has gone off the rails?
                        continue;
                    }
                    updates.add(new Update(features, transitionNum, -1, 1.0f));
                }
            } else {
                numWrong++;
                int transitionNum = -1;
                if (gold.transition != null) {
                    transitionNum = transitionIndex.indexOf(gold.transition);
                // TODO: this can theoretically result in a -1 gold
                // transition if the transition exists, but is a
                // CompoundUnaryTransition which only exists because the
                // parser is wrong.  Do we want to add those transitions?
                }
                updates.add(new Update(features, transitionNum, predictedNum, 1.0f));
            }
            state = predicted.apply(state);
        }
    } else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.BEAM || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
        if (op.trainOptions().beamSize <= 0) {
            throw new IllegalArgumentException("Illegal beam size " + op.trainOptions().beamSize);
        }
        List<Transition> transitions = Generics.newLinkedList(transitionLists.get(index));
        PriorityQueue<State> agenda = new PriorityQueue<>(op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
        State goldState = ShiftReduceParser.initialStateFromGoldTagTree(tree);
        agenda.add(goldState);
        int transitionCount = 0;
        while (transitions.size() > 0) {
            Transition goldTransition = transitions.get(0);
            Transition highestScoringTransitionFromGoldState = null;
            double highestScoreFromGoldState = 0.0;
            PriorityQueue<State> newAgenda = new PriorityQueue<>(op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
            State highestScoringState = null;
            State highestCurrentState = null;
            for (State currentState : agenda) {
                boolean isGoldState = (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM && goldState.areTransitionsEqual(currentState));
                List<String> features = featureFactory.featurize(currentState);
                Collection<ScoredObject<Integer>> stateTransitions = findHighestScoringTransitions(currentState, features, true, op.trainOptions().beamSize, null);
                for (ScoredObject<Integer> transition : stateTransitions) {
                    State newState = transitionIndex.get(transition.object()).apply(currentState, transition.score());
                    newAgenda.add(newState);
                    if (newAgenda.size() > op.trainOptions().beamSize) {
                        newAgenda.poll();
                    }
                    if (highestScoringState == null || highestScoringState.score() < newState.score()) {
                        highestScoringState = newState;
                        highestCurrentState = currentState;
                    }
                    if (isGoldState && (highestScoringTransitionFromGoldState == null || transition.score() > highestScoreFromGoldState)) {
                        highestScoringTransitionFromGoldState = transitionIndex.get(transition.object());
                        highestScoreFromGoldState = transition.score();
                    }
                }
            }
            // Under REORDER_BEAM, stop if no transition from the gold state was
            // scored (for example, because the gold state fell off the agenda).
            if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM && highestScoringTransitionFromGoldState == null) {
                break;
            }
            State newGoldState = goldTransition.apply(goldState, 0.0);
            // If the best state did not follow the gold transitions, lower the weight
            // of the transition it took and raise the weight of the gold transition.
            if (!newGoldState.areTransitionsEqual(highestScoringState)) {
                ++numWrong;
                List<String> goldFeatures = featureFactory.featurize(goldState);
                int lastTransition = transitionIndex.indexOf(highestScoringState.transitions.peek());
                updates.add(new Update(featureFactory.featurize(highestCurrentState), -1, lastTransition, 1.0f));
                updates.add(new Update(goldFeatures, transitionIndex.indexOf(goldTransition), -1, 1.0f));
                if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.BEAM) {
                    // If the correct state has fallen off the agenda, break
                    if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
                        break;
                    } else {
                        transitions.remove(0);
                    }
                } else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
                    if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
                        if (!reorderer.reorder(goldState, highestScoringTransitionFromGoldState, transitions)) {
                            break;
                        }
                        newGoldState = highestScoringTransitionFromGoldState.apply(goldState);
                        if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
                            break;
                        }
                    } else {
                        transitions.remove(0);
                    }
                }
            } else {
                ++numCorrect;
                transitions.remove(0);
            }
            goldState = newGoldState;
            agenda = newAgenda;
        }
    } else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.EARLY_TERMINATION || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.GOLD) {
        State state = ShiftReduceParser.initialStateFromGoldTagTree(tree);
        List<Transition> transitions = transitionLists.get(index);
        transitions = Generics.newLinkedList(transitions);
        boolean keepGoing = true;
        while (transitions.size() > 0 && keepGoing) {
            Transition transition = transitions.get(0);
            int transitionNum = transitionIndex.indexOf(transition);
            List<String> features = featureFactory.featurize(state);
            int predictedNum = findHighestScoringTransition(state, features, false).object();
            Transition predicted = transitionIndex.get(predictedNum);
            if (transitionNum == predictedNum) {
                transitions.remove(0);
                state = transition.apply(state);
                numCorrect++;
            } else {
                numWrong++;
                // TODO: allow weighted features, weighted training, etc
                updates.add(new Update(features, transitionNum, predictedNum, 1.0f));
                switch(op.trainOptions().trainingMethod) {
                    case EARLY_TERMINATION:
                        keepGoing = false;
                        break;
                    case GOLD:
                        transitions.remove(0);
                        state = transition.apply(state);
                        break;
                    case REORDER_ORACLE:
                        keepGoing = reorderer.reorder(state, predicted, transitions);
                        if (keepGoing) {
                            state = predicted.apply(state);
                        }
                        break;
                    default:
                        throw new IllegalArgumentException("Unexpected method " + op.trainOptions().trainingMethod);
                }
            }
        }
    }
    return Pair.makePair(numCorrect, numWrong);
}
Also used : PriorityQueue(java.util.PriorityQueue) ParserConstraint(edu.stanford.nlp.parser.common.ParserConstraint) ScoredObject(edu.stanford.nlp.util.ScoredObject) Tree(edu.stanford.nlp.trees.Tree) Collection(java.util.Collection) List(java.util.List)
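
In the beam-training branch above, the PriorityQueue serves as a fixed-width beam: newAgenda is ordered ascending by score, so its head is always the weakest hypothesis and a single poll() evicts it when the beam overflows. Below is a minimal sketch of that pruning step in isolation; Hypothesis and prune are hypothetical names, not part of the CoreNLP shift-reduce parser.

import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

class BeamSketch {

    // A scored hypothesis, standing in for the parser's State.
    static final class Hypothesis {
        final String label;
        final double score;
        Hypothesis(String label, double score) { this.label = label; this.score = score; }
    }

    // Keeps only the beamSize highest-scoring hypotheses.  Because the queue
    // orders ascending by score, the head is always the weakest survivor, so
    // an overflow costs one poll(), exactly how newAgenda is maintained above.
    static PriorityQueue<Hypothesis> prune(List<Hypothesis> candidates, int beamSize) {
        PriorityQueue<Hypothesis> beam = new PriorityQueue<>(beamSize + 1,
                Comparator.comparingDouble((Hypothesis h) -> h.score));
        for (Hypothesis candidate : candidates) {
            beam.add(candidate);
            if (beam.size() > beamSize) {
                beam.poll(); // evict the lowest-scoring hypothesis
            }
        }
        return beam;
    }
}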

Example 13 with PriorityQueue

Use of java.util.PriorityQueue in project CoreNLP by stanfordnlp.

The findHighestScoringTransitions method of the PerceptronModel class:

private Collection<ScoredObject<Integer>> findHighestScoringTransitions(State state, List<String> features, boolean requireLegal, int numTransitions, List<ParserConstraint> constraints) {
    float[] scores = new float[transitionIndex.size()];
    for (String feature : features) {
        Weight weight = featureWeights.get(feature);
        if (weight == null) {
            // Features not in our index are ignored
            continue;
        }
        weight.score(scores);
    }
    PriorityQueue<ScoredObject<Integer>> queue = new PriorityQueue<>(numTransitions + 1, ScoredComparator.ASCENDING_COMPARATOR);
    for (int i = 0; i < scores.length; ++i) {
        if (!requireLegal || transitionIndex.get(i).isLegal(state, constraints)) {
            queue.add(new ScoredObject<>(i, scores[i]));
            if (queue.size() > numTransitions) {
                queue.poll();
            }
        }
    }
    return queue;
}
Also used : ScoredObject(edu.stanford.nlp.util.ScoredObject) PriorityQueue(java.util.PriorityQueue) ParserConstraint(edu.stanford.nlp.parser.common.ParserConstraint)
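
Note that the method returns the PriorityQueue itself as a Collection. Iterating over a PriorityQueue visits elements in internal heap order, not in comparator order, so a caller that needs the best transition first must poll the queue rather than iterate it; the beam-training caller above iterates over all returned transitions and tracks the maximum itself, so order does not matter there. A small stand-alone illustration of the difference, not CoreNLP code:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;

class DrainVersusIterate {
    public static void main(String[] args) {
        PriorityQueue<Integer> queue = new PriorityQueue<>(Arrays.asList(5, 1, 4, 2, 3));

        // Iterating (or copying into a list) reflects the internal heap layout,
        // which is not sorted and is not specified by the PriorityQueue contract.
        System.out.println("iterated: " + new ArrayList<>(queue));

        // Polling repeatedly does return elements in comparator order.
        List<Integer> drained = new ArrayList<>();
        while (!queue.isEmpty()) {
            drained.add(queue.poll());
        }
        System.out.println("drained:  " + drained); // [1, 2, 3, 4, 5]
    }
}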

Example 14 with PriorityQueue

Use of java.util.PriorityQueue in project HanLP by hankcs.

The compute method of the Dijkstra class:

public static List<Vertex> compute(Graph graph) {
    List<Vertex> resultList = new LinkedList<Vertex>();
    Vertex[] vertexes = graph.getVertexes();
    List<EdgeFrom>[] edgesTo = graph.getEdgesTo();
    double[] d = new double[vertexes.length];
    Arrays.fill(d, Double.MAX_VALUE);
    d[d.length - 1] = 0;
    int[] path = new int[vertexes.length];
    Arrays.fill(path, -1);
    PriorityQueue<State> que = new PriorityQueue<State>();
    que.add(new State(0, vertexes.length - 1));
    while (!que.isEmpty()) {
        State p = que.poll();
        // Skip stale queue entries: a shorter path to this vertex was already found.
        if (d[p.vertex] < p.cost)
            continue;
        for (EdgeFrom edgeFrom : edgesTo[p.vertex]) {
            if (d[edgeFrom.from] > d[p.vertex] + edgeFrom.weight) {
                d[edgeFrom.from] = d[p.vertex] + edgeFrom.weight;
                que.add(new State(d[edgeFrom.from], edgeFrom.from));
                path[edgeFrom.from] = p.vertex;
            }
        }
    }
    for (int t = 0; t != -1; t = path[t]) {
        resultList.add(vertexes[t]);
    }
    return resultList;
}
Also used : Vertex(com.hankcs.hanlp.seg.common.Vertex) PriorityQueue(java.util.PriorityQueue) LinkedList(java.util.LinkedList) EdgeFrom(com.hankcs.hanlp.seg.common.EdgeFrom) State(com.hankcs.hanlp.seg.Dijkstra.Path.State) List(java.util.List)
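
This is the lazy-deletion variant of Dijkstra's algorithm: java.util.PriorityQueue has no decrease-key operation, so a new entry is pushed whenever a shorter path is found, and outdated entries are skipped when polled via the d[p.vertex] < p.cost check (HanLP searches backwards from the last vertex so the path can then be read off forward from vertex 0). Below is a self-contained sketch of the same pattern with illustrative names (Edge, Entry, shortestDistances), searching forward from a source node rather than backward.

import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;

class LazyDijkstraSketch {

    // A weighted arc to vertex `to`.
    static final class Edge {
        final int to;
        final double weight;
        Edge(int to, double weight) { this.to = to; this.weight = weight; }
    }

    // Queue entry ordered by path cost, mirroring HanLP's State, which is
    // comparable so that the cheapest entry is polled first.
    static final class Entry implements Comparable<Entry> {
        final double cost;
        final int node;
        Entry(double cost, int node) { this.cost = cost; this.node = node; }
        public int compareTo(Entry other) { return Double.compare(cost, other.cost); }
    }

    // Dijkstra with lazy deletion: a shorter path simply enqueues a new entry,
    // and stale entries are recognized and skipped when they are polled.
    static double[] shortestDistances(List<List<Edge>> adjacency, int source) {
        int n = adjacency.size();
        double[] dist = new double[n];
        Arrays.fill(dist, Double.MAX_VALUE);
        dist[source] = 0;

        PriorityQueue<Entry> queue = new PriorityQueue<>();
        queue.add(new Entry(0, source));
        while (!queue.isEmpty()) {
            Entry current = queue.poll();
            if (dist[current.node] < current.cost) {
                continue; // stale entry: a cheaper path was already recorded
            }
            for (Edge edge : adjacency.get(current.node)) {
                double candidate = dist[current.node] + edge.weight;
                if (candidate < dist[edge.to]) {
                    dist[edge.to] = candidate;
                    queue.add(new Entry(candidate, edge.to));
                }
            }
        }
        return dist;
    }
}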

Example 15 with PriorityQueue

Use of java.util.PriorityQueue in project graphhopper by graphhopper.

The testSize method of the AbstractBinHeapTest class:

@Test
public void testSize() {
    PriorityQueue<SPTEntry> juQueue = new PriorityQueue<SPTEntry>(100);
    BinHeapWrapper<Number, Integer> binHeap = createHeap(100);
    Random rand = new Random(1);
    int N = 1000;
    for (int i = 0; i < N; i++) {
        int val = rand.nextInt();
        binHeap.insert(val, i);
        juQueue.add(new SPTEntry(EdgeIterator.NO_EDGE, i, val));
    }
    assertEquals(juQueue.size(), binHeap.getSize());
    for (int i = 0; i < N; i++) {
        assertEquals(juQueue.poll().adjNode, binHeap.pollElement(), 1e-5);
    }
    assertEquals(binHeap.getSize(), 0);
}
Also used : SPTEntry(com.graphhopper.storage.SPTEntry) Random(java.util.Random) PriorityQueue(java.util.PriorityQueue) Test(org.junit.Test)
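
The test builds the queue with new PriorityQueue<SPTEntry>(100), i.e. with an initial capacity but no Comparator. That only works because the entries define their own ordering (they order by their weight value, which is what the poll-order assertion relies on). For element types without a natural ordering, the ordering has to be supplied explicitly. A short illustration with a hypothetical Item type, not GraphHopper code:

import java.util.Comparator;
import java.util.PriorityQueue;

class ExplicitOrderingSketch {

    // An element type with no natural ordering (it does not implement Comparable).
    static final class Item {
        final int id;
        final double weight;
        Item(int id, double weight) { this.id = id; this.weight = weight; }
    }

    public static void main(String[] args) {
        // Without this Comparator, adds would fail with ClassCastException as
        // soon as the queue needs to compare two Item instances.
        PriorityQueue<Item> byWeight = new PriorityQueue<>(100,
                Comparator.comparingDouble((Item item) -> item.weight));
        byWeight.add(new Item(1, 2.5));
        byWeight.add(new Item(2, 0.5));
        byWeight.add(new Item(3, 1.5));
        while (!byWeight.isEmpty()) {
            Item item = byWeight.poll();
            System.out.println(item.id + " weight=" + item.weight); // ids 2, 3, 1: cheapest first
        }
    }
}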

Aggregations

PriorityQueue (java.util.PriorityQueue): 51
ArrayList (java.util.ArrayList): 16
List (java.util.List): 10
Map (java.util.Map): 9
HashMap (java.util.HashMap): 7
LinkedList (java.util.LinkedList): 5
File (java.io.File): 4
IOException (java.io.IOException): 4
Entry (java.util.Map.Entry): 4
Random (java.util.Random): 4
BytesRef (org.apache.lucene.util.BytesRef): 4
AbstractMapTable (com.ctriposs.sdb.table.AbstractMapTable): 3
ScoredObject (edu.stanford.nlp.util.ScoredObject): 3
Comparator (java.util.Comparator): 3
Set (java.util.Set): 3
FCMapTable (com.ctriposs.sdb.table.FCMapTable): 2
HashMapTable (com.ctriposs.sdb.table.HashMapTable): 2
IMapEntry (com.ctriposs.sdb.table.IMapEntry): 2
MMFMapTable (com.ctriposs.sdb.table.MMFMapTable): 2
Type (com.facebook.presto.spi.type.Type): 2