Use of java.util.PriorityQueue in project CoreNLP by stanfordnlp:
the class FindNearestNeighbors, method main.
/**
 * Parses a test treebank with a DVModel-reranked LexicalizedParser, writes each
 * sentence's tokens, best parse, and root node vector to the output file, and
 * logs, for every subtree (up to {@code maxLength} leaves), its 100 nearest
 * neighbor subtrees by Frobenius distance between node vectors.
 *
 * Required arguments: {@code -model <path>}, {@code -testTreebank <path> [filter]},
 * {@code -output <path>}.  All other arguments are passed through to
 * {@link LexicalizedParser#loadModel}.
 *
 * @throws IllegalArgumentException if a required argument is missing, or the
 *         loaded parser has no DVModel reranker attached
 */
public static void main(String[] args) throws Exception {
  String modelPath = null;
  String outputPath = null;
  String testTreebankPath = null;
  FileFilter testTreebankFilter = null;
  List<String> unusedArgs = new ArrayList<>();
  for (int argIndex = 0; argIndex < args.length; ) {
    if (args[argIndex].equalsIgnoreCase("-model")) {
      modelPath = args[argIndex + 1];
      argIndex += 2;
    } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
      Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
      argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
      testTreebankPath = treebankDescription.first();
      testTreebankFilter = treebankDescription.second();
    } else if (args[argIndex].equalsIgnoreCase("-output")) {
      outputPath = args[argIndex + 1];
      argIndex += 2;
    } else {
      unusedArgs.add(args[argIndex++]);
    }
  }
  if (modelPath == null) {
    throw new IllegalArgumentException("Need to specify -model");
  }
  if (testTreebankPath == null) {
    throw new IllegalArgumentException("Need to specify -testTreebank");
  }
  if (outputPath == null) {
    throw new IllegalArgumentException("Need to specify -output");
  }
  String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
  LexicalizedParser lexparser = LexicalizedParser.loadModel(modelPath, newArgs);
  // testTreebankPath is guaranteed non-null by the guard above, so load the
  // treebank unconditionally (the original re-checked for null here).
  log.info("Reading in trees from " + testTreebankPath);
  if (testTreebankFilter != null) {
    log.info("Filtering on " + testTreebankFilter);
  }
  Treebank testTreebank = lexparser.getOp().tlpParams.memoryTreebank();
  testTreebank.loadPath(testTreebankPath, testTreebankFilter);
  log.info("Read in " + testTreebank.size() + " trees for testing");
  // try-with-resources flushes and closes the writer even on error.  The
  // original created a BufferedWriter but wrote through the raw FileWriter
  // and never closed the buffered wrapper; all writes now go through bout.
  try (BufferedWriter bout = new BufferedWriter(new FileWriter(outputPath))) {
    log.info("Parsing " + testTreebank.size() + " trees");
    int count = 0;
    List<ParseRecord> records = Generics.newArrayList();
    for (Tree goldTree : testTreebank) {
      List<Word> tokens = goldTree.yieldWords();
      ParserQuery parserQuery = lexparser.parserQuery();
      if (!parserQuery.parse(tokens)) {
        throw new AssertionError("Could not parse: " + tokens);
      }
      if (!(parserQuery instanceof RerankingParserQuery)) {
        throw new IllegalArgumentException("Expected a LexicalizedParser with a Reranker attached");
      }
      RerankingParserQuery rpq = (RerankingParserQuery) parserQuery;
      if (!(rpq.rerankerQuery() instanceof DVModelReranker.Query)) {
        throw new IllegalArgumentException("Expected a LexicalizedParser with a DVModel attached");
      }
      DeepTree tree = ((DVModelReranker.Query) rpq.rerankerQuery()).getDeepTrees().get(0);
      // Find the vector attached to the ROOT node of this parse.
      SimpleMatrix rootVector = null;
      for (Map.Entry<Tree, SimpleMatrix> entry : tree.getVectors().entrySet()) {
        if (entry.getKey().label().value().equals("ROOT")) {
          rootVector = entry.getValue();
          break;
        }
      }
      if (rootVector == null) {
        throw new AssertionError("Could not find root nodevector");
      }
      bout.write(tokens + "\n");
      bout.write(tree.getTree() + "\n");
      for (int i = 0; i < rootVector.getNumElements(); ++i) {
        bout.write(" " + rootVector.get(i));
      }
      bout.write("\n\n\n");
      count++;
      if (count % 10 == 0) {
        log.info(" " + count);
      }
      records.add(new ParseRecord(tokens, goldTree, tree.getTree(), rootVector, tree.getVectors()));
    }
    log.info(" done parsing");
    // Collect every subtree with at most maxLength leaves, paired with its vector.
    List<Pair<Tree, SimpleMatrix>> subtrees = Generics.newArrayList();
    for (ParseRecord record : records) {
      for (Map.Entry<Tree, SimpleMatrix> entry : record.nodeVectors.entrySet()) {
        if (entry.getKey().getLeaves().size() <= maxLength) {
          subtrees.add(Pair.makePair(entry.getKey(), entry.getValue()));
        }
      }
    }
    log.info("There are " + subtrees.size() + " subtrees in the set of trees");
    // DESCENDING_COMPARATOR keeps the largest distance at the queue head, so
    // polling whenever the size exceeds 100 discards the worst match and
    // retains the 100 nearest neighbors.
    PriorityQueue<ScoredObject<Pair<Tree, Tree>>> bestmatches = new PriorityQueue<>(101, ScoredComparator.DESCENDING_COMPARATOR);
    for (int i = 0; i < subtrees.size(); ++i) {
      log.info(subtrees.get(i).first().yieldWords());
      log.info(subtrees.get(i).first());
      for (int j = 0; j < subtrees.size(); ++j) {
        if (i == j) {
          continue;
        }
        // TODO: look at basic category?
        double normF = subtrees.get(i).second().minus(subtrees.get(j).second()).normF();
        bestmatches.add(new ScoredObject<>(Pair.makePair(subtrees.get(i).first(), subtrees.get(j).first()), normF));
        if (bestmatches.size() > 100) {
          bestmatches.poll();
        }
      }
      // Drain the queue (worst match first) and reverse so the closest
      // neighbor is logged first.
      List<ScoredObject<Pair<Tree, Tree>>> ordered = Generics.newArrayList();
      while (bestmatches.size() > 0) {
        ordered.add(bestmatches.poll());
      }
      Collections.reverse(ordered);
      for (ScoredObject<Pair<Tree, Tree>> pair : ordered) {
        log.info(" MATCHED " + pair.object().second.yieldWords() + " ... " + pair.object().second() + " with a score of " + pair.score());
      }
      log.info();
      log.info();
      bestmatches.clear();
    }
  }
}
Use of java.util.PriorityQueue in project CoreNLP by stanfordnlp:
the class PerceptronModel, method trainTree.
/**
 * Trains the perceptron on a single gold tree, appending the resulting weight
 * updates to {@code updates}.  The training strategy is selected by
 * {@code op.trainOptions().trainingMethod}: ORACLE asks the oracle for the gold
 * transition at each predicted state; BEAM / REORDER_BEAM run beam search and
 * update when the gold state falls behind; REORDER_ORACLE / EARLY_TERMINATION /
 * GOLD step through the precomputed gold transition sequence.
 *
 * @param index index of the tree in {@code binarizedTrees} (and of its
 *              transition sequence in {@code transitionLists})
 * @param binarizedTrees gold binarized training trees
 * @param transitionLists gold transition sequences, parallel to the trees
 * @param updates output list; weight updates for this tree are appended here
 * @param oracle oracle for gold transitions (used only in ORACLE mode)
 * @return pair of (number of correct transitions, number of wrong transitions)
 */
private Pair<Integer, Integer> trainTree(int index, List<Tree> binarizedTrees, List<List<Transition>> transitionLists, List<Update> updates, Oracle oracle) {
int numCorrect = 0;
int numWrong = 0;
Tree tree = binarizedTrees.get(index);
// The reorderer rewrites the remaining gold transition sequence when the
// parser diverges, instead of forcing it back onto the original sequence.
ReorderingOracle reorderer = null;
if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
reorderer = new ReorderingOracle(op);
}
// NOTE(review): this comment was truncated in the original source;
// presumably it described keeping the number of updates under control.
if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.ORACLE) {
// ORACLE mode: follow the parser's own predictions, asking the oracle
// at each state whether the predicted transition was acceptable.
State state = ShiftReduceParser.initialStateFromGoldTagTree(tree);
while (!state.isFinished()) {
List<String> features = featureFactory.featurize(state);
ScoredObject<Integer> prediction = findHighestScoringTransition(state, features, true);
if (prediction == null) {
throw new AssertionError("Did not find a legal transition");
}
int predictedNum = prediction.object();
Transition predicted = transitionIndex.get(predictedNum);
OracleTransition gold = oracle.goldTransition(index, state);
if (gold.isCorrect(predicted)) {
numCorrect++;
// The prediction was acceptable but not the exact gold transition:
// still boost the gold transition's weights (no penalty applied).
if (gold.transition != null && !gold.transition.equals(predicted)) {
int transitionNum = transitionIndex.indexOf(gold.transition);
if (transitionNum < 0) {
// only possible when the parser has gone off the rails?
continue;
}
updates.add(new Update(features, transitionNum, -1, 1.0f));
}
} else {
numWrong++;
int transitionNum = -1;
if (gold.transition != null) {
transitionNum = transitionIndex.indexOf(gold.transition);
// TODO: this can theoretically result in a -1 gold
// transition if the transition exists, but is a
// CompoundUnaryTransition which only exists because the
// parser is wrong. Do we want to add those transitions?
}
// Boost the gold transition and penalize the wrong prediction.
updates.add(new Update(features, transitionNum, predictedNum, 1.0f));
}
// Continue from the predicted (possibly wrong) state.
state = predicted.apply(state);
}
} else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.BEAM || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
if (op.trainOptions().beamSize <= 0) {
throw new IllegalArgumentException("Illegal beam size " + op.trainOptions().beamSize);
}
// Copy the gold transitions so they can be consumed (and, in
// REORDER_BEAM mode, rewritten) as training progresses.
List<Transition> transitions = Generics.newLinkedList(transitionLists.get(index));
// ASCENDING comparator: the head is the lowest-scoring state, so polling
// when the agenda exceeds beamSize evicts the weakest state.
PriorityQueue<State> agenda = new PriorityQueue<>(op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
State goldState = ShiftReduceParser.initialStateFromGoldTagTree(tree);
agenda.add(goldState);
int transitionCount = 0;
while (transitions.size() > 0) {
Transition goldTransition = transitions.get(0);
Transition highestScoringTransitionFromGoldState = null;
double highestScoreFromGoldState = 0.0;
PriorityQueue<State> newAgenda = new PriorityQueue<>(op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
State highestScoringState = null;
State highestCurrentState = null;
// Expand every state on the agenda by its top-scoring transitions.
for (State currentState : agenda) {
boolean isGoldState = (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM && goldState.areTransitionsEqual(currentState));
List<String> features = featureFactory.featurize(currentState);
Collection<ScoredObject<Integer>> stateTransitions = findHighestScoringTransitions(currentState, features, true, op.trainOptions().beamSize, null);
for (ScoredObject<Integer> transition : stateTransitions) {
State newState = transitionIndex.get(transition.object()).apply(currentState, transition.score());
newAgenda.add(newState);
if (newAgenda.size() > op.trainOptions().beamSize) {
newAgenda.poll();
}
// Track the overall best successor and its parent state.
if (highestScoringState == null || highestScoringState.score() < newState.score()) {
highestScoringState = newState;
highestCurrentState = currentState;
}
// In REORDER_BEAM mode, also track the best transition leaving
// the gold state, for use by the reorderer below.
if (isGoldState && (highestScoringTransitionFromGoldState == null || transition.score() > highestScoreFromGoldState)) {
highestScoringTransitionFromGoldState = transitionIndex.get(transition.object());
highestScoreFromGoldState = transition.score();
}
}
}
// state (eg one with ROOT) isn't on the agenda so it stops.
if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM && highestScoringTransitionFromGoldState == null) {
break;
}
State newGoldState = goldTransition.apply(goldState, 0.0);
// otherwise, down the last transition, up the correct
if (!newGoldState.areTransitionsEqual(highestScoringState)) {
++numWrong;
List<String> goldFeatures = featureFactory.featurize(goldState);
int lastTransition = transitionIndex.indexOf(highestScoringState.transitions.peek());
// Penalize the transition that produced the (wrong) best state and
// boost the gold transition from the gold state.
updates.add(new Update(featureFactory.featurize(highestCurrentState), -1, lastTransition, 1.0f));
updates.add(new Update(goldFeatures, transitionIndex.indexOf(goldTransition), -1, 1.0f));
if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.BEAM) {
// If the correct state has fallen off the agenda, break
if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
break;
} else {
transitions.remove(0);
}
} else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
// Gold fell off the beam: try reordering the remaining gold
// sequence to follow the parser's best move from the gold state.
if (!reorderer.reorder(goldState, highestScoringTransitionFromGoldState, transitions)) {
break;
}
newGoldState = highestScoringTransitionFromGoldState.apply(goldState);
if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
break;
}
} else {
transitions.remove(0);
}
}
} else {
++numCorrect;
transitions.remove(0);
}
goldState = newGoldState;
agenda = newAgenda;
}
} else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.EARLY_TERMINATION || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.GOLD) {
// Greedy modes: walk the gold transition sequence, comparing the
// parser's best prediction against each gold transition in turn.
State state = ShiftReduceParser.initialStateFromGoldTagTree(tree);
List<Transition> transitions = transitionLists.get(index);
transitions = Generics.newLinkedList(transitions);
boolean keepGoing = true;
while (transitions.size() > 0 && keepGoing) {
Transition transition = transitions.get(0);
int transitionNum = transitionIndex.indexOf(transition);
List<String> features = featureFactory.featurize(state);
int predictedNum = findHighestScoringTransition(state, features, false).object();
Transition predicted = transitionIndex.get(predictedNum);
if (transitionNum == predictedNum) {
transitions.remove(0);
state = transition.apply(state);
numCorrect++;
} else {
numWrong++;
// TODO: allow weighted features, weighted training, etc
updates.add(new Update(features, transitionNum, predictedNum, 1.0f));
// The modes differ only in how they recover from a wrong prediction.
switch(op.trainOptions().trainingMethod) {
case EARLY_TERMINATION:
// Stop training on this tree at the first mistake.
keepGoing = false;
break;
case GOLD:
// Force the gold transition and continue.
transitions.remove(0);
state = transition.apply(state);
break;
case REORDER_ORACLE:
// Follow the prediction if the remaining gold sequence can be
// reordered to accommodate it; otherwise stop.
keepGoing = reorderer.reorder(state, predicted, transitions);
if (keepGoing) {
state = predicted.apply(state);
}
break;
default:
throw new IllegalArgumentException("Unexpected method " + op.trainOptions().trainingMethod);
}
}
}
}
return Pair.makePair(numCorrect, numWrong);
}
Use of java.util.PriorityQueue in project CoreNLP by stanfordnlp:
the class PerceptronModel, method findHighestScoringTransitions.
/**
 * Scores every known transition for the given state by summing the weight
 * contributions of the active features, then returns up to
 * {@code numTransitions} of the highest-scoring candidates.
 *
 * @param state the parser state to score transitions for
 * @param features active feature strings for this state
 * @param requireLegal if true, only transitions legal in {@code state} are kept
 * @param numTransitions maximum number of candidates to return
 * @param constraints parser constraints consulted by the legality check (may be null)
 * @return the top-scoring transitions as (index, score) pairs, in heap order
 */
private Collection<ScoredObject<Integer>> findHighestScoringTransitions(State state, List<String> features, boolean requireLegal, int numTransitions, List<ParserConstraint> constraints) {
  // Accumulate per-transition scores; features without a weight contribute nothing.
  float[] transitionScores = new float[transitionIndex.size()];
  for (String feature : features) {
    Weight featureWeight = featureWeights.get(feature);
    if (featureWeight != null) {
      featureWeight.score(transitionScores);
    }
  }
  // Bounded min-heap: the head is the weakest kept candidate, so polling
  // after each insertion beyond numTransitions keeps only the best ones.
  PriorityQueue<ScoredObject<Integer>> best = new PriorityQueue<>(numTransitions + 1, ScoredComparator.ASCENDING_COMPARATOR);
  for (int transition = 0; transition < transitionScores.length; ++transition) {
    if (requireLegal && !transitionIndex.get(transition).isLegal(state, constraints)) {
      continue;
    }
    best.add(new ScoredObject<>(transition, transitionScores[transition]));
    if (best.size() > numTransitions) {
      best.poll();
    }
  }
  return best;
}
Use of java.util.PriorityQueue in project HanLP by hankcs:
the class Dijkstra, method compute.
/**
 * Computes a shortest path over the graph with Dijkstra's algorithm.  The
 * search runs backwards from the terminal vertex (the last one) over the
 * incoming-edge lists, recording for each vertex its best successor; the
 * result is then read off by walking forward from vertex 0.
 *
 * @param graph the word graph to search
 * @return the vertices on the shortest path from vertex 0 to the last vertex
 */
public static List<Vertex> compute(Graph graph) {
    Vertex[] vertexes = graph.getVertexes();
    List<EdgeFrom>[] edgesTo = graph.getEdgesTo();
    int vertexCount = vertexes.length;
    // dist[v] = best known cost from v to the terminal vertex.
    double[] dist = new double[vertexCount];
    Arrays.fill(dist, Double.MAX_VALUE);
    dist[vertexCount - 1] = 0;
    // successor[v] = next vertex after v on the best path found so far.
    int[] successor = new int[vertexCount];
    Arrays.fill(successor, -1);
    PriorityQueue<State> frontier = new PriorityQueue<State>();
    frontier.add(new State(0, vertexCount - 1));
    while (!frontier.isEmpty()) {
        State top = frontier.poll();
        // Lazy deletion: skip entries superseded by a cheaper relaxation.
        if (dist[top.vertex] < top.cost)
            continue;
        for (EdgeFrom edgeFrom : edgesTo[top.vertex]) {
            double candidate = dist[top.vertex] + edgeFrom.weight;
            if (candidate < dist[edgeFrom.from]) {
                dist[edgeFrom.from] = candidate;
                frontier.add(new State(candidate, edgeFrom.from));
                successor[edgeFrom.from] = top.vertex;
            }
        }
    }
    // Walk forward from vertex 0 along the recorded successors.
    List<Vertex> resultList = new LinkedList<Vertex>();
    for (int t = 0; t != -1; t = successor[t]) {
        resultList.add(vertexes[t]);
    }
    return resultList;
}
Use of java.util.PriorityQueue in project graphhopper:
the class AbstractBinHeapTest, method testSize.
/**
 * Inserts the same 1000 seeded-random keys into a reference
 * java.util.PriorityQueue and into the heap under test, then checks that
 * their sizes match, that they poll the elements in the same order, and
 * that the heap under test is empty afterwards.
 */
@Test
public void testSize() {
    PriorityQueue<SPTEntry> referenceQueue = new PriorityQueue<>(100);
    BinHeapWrapper<Number, Integer> binHeap = createHeap(100);
    // Fixed seed keeps the test deterministic.
    Random rand = new Random(1);
    final int count = 1000;
    for (int i = 0; i < count; i++) {
        int key = rand.nextInt();
        binHeap.insert(key, i);
        referenceQueue.add(new SPTEntry(EdgeIterator.NO_EDGE, i, key));
    }
    assertEquals(referenceQueue.size(), binHeap.getSize());
    for (int i = 0; i < count; i++) {
        assertEquals(referenceQueue.poll().adjNode, binHeap.pollElement(), 1e-5);
    }
    assertEquals(binHeap.getSize(), 0);
}
Aggregations