Use of com.joliciel.talismane.machineLearning.ClassificationObserver in project talismane by joliciel-informatique.
From the class TransitionBasedParser, the method parseSentence:
@Override
public List<ParseConfiguration> parseSentence(List<PosTagSequence> input) throws TalismaneException, IOException {
List<PosTagSequence> posTagSequences = null;
if (this.propagatePosTaggerBeam) {
posTagSequences = input;
} else {
posTagSequences = new ArrayList<>(1);
posTagSequences.add(input.get(0));
}
long startTime = System.currentTimeMillis();
int maxAnalysisTimeMilliseconds = maxAnalysisTimePerSentence * 1000;
int minFreeMemoryBytes = minFreeMemory * KILOBYTE;
TokenSequence tokenSequence = posTagSequences.get(0).getTokenSequence();
TreeMap<Integer, PriorityQueue<ParseConfiguration>> heaps = new TreeMap<>();
PriorityQueue<ParseConfiguration> heap0 = new PriorityQueue<>();
for (PosTagSequence posTagSequence : posTagSequences) {
// add an initial ParseConfiguration for each postag sequence
ParseConfiguration initialConfiguration = new ParseConfiguration(posTagSequence);
initialConfiguration.setScoringStrategy(decisionMaker.getDefaultScoringStrategy());
heap0.add(initialConfiguration);
if (LOG.isDebugEnabled()) {
LOG.debug("Adding initial posTagSequence: " + posTagSequence);
}
}
heaps.put(0, heap0);
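// heaps are polled in ascending key order (TreeMap.pollFirstEntry), so the least-advanced configurations are always extended first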
PriorityQueue<ParseConfiguration> backupHeap = null;
PriorityQueue<ParseConfiguration> finalHeap = null;
PriorityQueue<ParseConfiguration> terminalHeap = new PriorityQueue<>();
while (heaps.size() > 0) {
Entry<Integer, PriorityQueue<ParseConfiguration>> heapEntry = heaps.pollFirstEntry();
PriorityQueue<ParseConfiguration> currentHeap = heapEntry.getValue();
int currentHeapIndex = heapEntry.getKey();
if (LOG.isTraceEnabled()) {
LOG.trace("##### Polling next heap: " + heapEntry.getKey() + ", size: " + heapEntry.getValue().size());
}
boolean finished = false;
// systematically set the final heap here, just in case we exit
// "naturally" with no more heaps
finalHeap = heapEntry.getValue();
backupHeap = new PriorityQueue<>();
// we jump out when either (a) all tokens have been attached or
// (b) we go over the max allotted time
ParseConfiguration topConf = currentHeap.peek();
if (topConf.isTerminal()) {
LOG.trace("Exiting with terminal heap: " + heapEntry.getKey() + ", size: " + heapEntry.getValue().size());
finished = true;
}
if (earlyStop && terminalHeap.size() >= beamWidth) {
LOG.debug("Early stop activated and terminal heap contains " + beamWidth + " entries. Exiting.");
finalHeap = terminalHeap;
finished = true;
}
long analysisTime = System.currentTimeMillis() - startTime;
if (maxAnalysisTimePerSentence > 0 && analysisTime > maxAnalysisTimeMilliseconds) {
LOG.info("Parse tree analysis took too long for sentence: " + tokenSequence.getSentence().getText());
LOG.info("Breaking out after " + maxAnalysisTimePerSentence + " seconds.");
finished = true;
}
if (minFreeMemory > 0) {
long freeMemory = Runtime.getRuntime().freeMemory();
if (freeMemory < minFreeMemoryBytes) {
LOG.info("Not enough memory left to parse sentence: " + tokenSequence.getSentence().getText());
LOG.info("Min free memory (bytes):" + minFreeMemoryBytes);
LOG.info("Current free memory (bytes): " + freeMemory);
finished = true;
}
}
if (finished) {
break;
}
// limit the breadth to K
int maxSequences = currentHeap.size() > this.beamWidth ? this.beamWidth : currentHeap.size();
int j = 0;
while (currentHeap.size() > 0) {
ParseConfiguration history = currentHeap.poll();
if (LOG.isTraceEnabled()) {
LOG.trace("### Next configuration on heap " + heapEntry.getKey() + ":");
LOG.trace(history.toString());
LOG.trace("Score: " + df.format(history.getScore()));
LOG.trace(history.getPosTagSequence().toString());
}
List<Decision> decisions = new ArrayList<>();
// test the positive rules on the current configuration
boolean ruleApplied = false;
if (parserPositiveRules != null) {
for (ParserRule rule : parserPositiveRules) {
if (LOG.isTraceEnabled()) {
LOG.trace("Checking rule: " + rule.toString());
}
RuntimeEnvironment env = new RuntimeEnvironment();
FeatureResult<Boolean> ruleResult = rule.getCondition().check(history, env);
if (ruleResult != null && ruleResult.getOutcome()) {
Decision positiveRuleDecision = new Decision(rule.getTransition().getCode());
decisions.add(positiveRuleDecision);
positiveRuleDecision.addAuthority(rule.getCondition().getName());
ruleApplied = true;
if (LOG.isTraceEnabled()) {
LOG.trace("Rule applies. Setting transition to: " + rule.getTransition().getCode());
}
break;
}
}
}
if (!ruleApplied) {
// test the features on the current configuration
List<FeatureResult<?>> parseFeatureResults = new ArrayList<>();
for (ParseConfigurationFeature<?> feature : this.parseFeatures) {
RuntimeEnvironment env = new RuntimeEnvironment();
FeatureResult<?> featureResult = feature.check(history, env);
if (featureResult != null)
parseFeatureResults.add(featureResult);
}
if (LOG_FEATURES.isTraceEnabled()) {
SortedSet<String> featureResultSet = parseFeatureResults.stream().map(f -> f.toString()).collect(Collectors.toCollection(() -> new TreeSet<>()));
for (String featureResultString : featureResultSet) {
LOG_FEATURES.trace(featureResultString);
}
}
// evaluate the feature results using the decision maker
decisions = this.decisionMaker.decide(parseFeatureResults);
for (ClassificationObserver observer : this.observers) {
observer.onAnalyse(history, parseFeatureResults, decisions);
}
List<Decision> decisionShortList = new ArrayList<>(decisions.size());
for (Decision decision : decisions) {
if (decision.getProbability() > MIN_PROB_TO_STORE)
decisionShortList.add(decision);
}
decisions = decisionShortList;
// apply the negative rules
Set<String> eliminatedTransitions = new HashSet<>();
if (parserNegativeRules != null) {
for (ParserRule rule : parserNegativeRules) {
if (LOG.isTraceEnabled()) {
LOG.trace("Checking negative rule: " + rule.toString());
}
RuntimeEnvironment env = new RuntimeEnvironment();
FeatureResult<Boolean> ruleResult = rule.getCondition().check(history, env);
if (ruleResult != null && ruleResult.getOutcome()) {
for (Transition transition : rule.getTransitions()) {
eliminatedTransitions.add(transition.getCode());
if (LOG.isTraceEnabled())
LOG.trace("Rule applies. Eliminating transition: " + transition.getCode());
}
}
}
if (eliminatedTransitions.size() > 0) {
decisionShortList = new ArrayList<>();
for (Decision decision : decisions) {
if (!eliminatedTransitions.contains(decision.getOutcome())) {
decisionShortList.add(decision);
} else {
LOG.trace("Eliminating decision: " + decision.toString());
}
}
if (decisionShortList.size() > 0) {
decisions = decisionShortList;
} else {
LOG.debug("All decisions eliminated! Restoring original decisions.");
}
}
}
}
// has a positive rule been applied?
boolean transitionApplied = false;
TransitionSystem transitionSystem = TalismaneSession.get(sessionId).getTransitionSystem();
// apply the transition for each decision, creating one new configuration per applicable transition
for (Decision decision : decisions) {
Transition transition = transitionSystem.getTransitionForCode(decision.getOutcome());
if (LOG.isTraceEnabled())
LOG.trace("Outcome: " + transition.getCode() + ", " + decision.getProbability());
if (transition.checkPreconditions(history)) {
transitionApplied = true;
ParseConfiguration configuration = new ParseConfiguration(history);
if (decision.isStatistical())
configuration.addDecision(decision);
transition.apply(configuration);
int nextHeapIndex = parseComparisonStrategy.getComparisonIndex(configuration) * 1000;
if (configuration.isTerminal()) {
nextHeapIndex = Integer.MAX_VALUE;
} else {
while (nextHeapIndex <= currentHeapIndex) nextHeapIndex++;
}
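// all terminal configurations share the single heap keyed Integer.MAX_VALUE; non-terminal ones always move to a strictly greater index than the current heap, which guarantees progress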
PriorityQueue<ParseConfiguration> nextHeap = heaps.get(nextHeapIndex);
if (nextHeap == null) {
if (configuration.isTerminal())
nextHeap = terminalHeap;
else
nextHeap = new PriorityQueue<>();
heaps.put(nextHeapIndex, nextHeap);
if (LOG.isTraceEnabled())
LOG.trace("Created heap with index: " + nextHeapIndex);
}
nextHeap.add(configuration);
if (LOG.isTraceEnabled()) {
LOG.trace("Added configuration with score " + configuration.getScore() + " to heap: " + nextHeapIndex + ", total size: " + nextHeap.size());
}
configuration.clearMemory();
} else {
if (LOG.isTraceEnabled())
LOG.trace("Cannot apply transition: doesn't meet pre-conditions");
// just in case we run out of both heaps and
// analyses, we build this backup heap
backupHeap.add(history);
}
// does transition meet pre-conditions?
}
if (transitionApplied) {
j++;
} else {
LOG.trace("No transitions could be applied: not counting this history as part of the beam");
}
// beam width test
if (j == maxSequences)
break;
}
// next history
}
// next atomic index
// return the best sequences on the heap
List<ParseConfiguration> bestConfigurations = new ArrayList<>();
int i = 0;
if (finalHeap.isEmpty())
finalHeap = backupHeap;
while (!finalHeap.isEmpty()) {
bestConfigurations.add(finalHeap.poll());
i++;
if (i >= this.getBeamWidth())
break;
}
if (LOG.isDebugEnabled()) {
for (ParseConfiguration finalConfiguration : bestConfigurations) {
LOG.debug(df.format(finalConfiguration.getScore()) + ": " + finalConfiguration.toString());
LOG.debug("Pos tag sequence: " + finalConfiguration.getPosTagSequence());
LOG.debug("Transitions: " + finalConfiguration.getTransitions());
LOG.debug("Decisions: " + finalConfiguration.getDecisions());
if (LOG.isTraceEnabled()) {
StringBuilder sb = new StringBuilder();
for (Decision decision : finalConfiguration.getDecisions()) {
sb.append(" * ");
sb.append(df.format(decision.getProbability()));
}
sb.append(" root ");
sb.append(finalConfiguration.getTransitions().size());
LOG.trace(sb.toString());
sb = new StringBuilder();
sb.append(" * PosTag sequence score ");
sb.append(df.format(finalConfiguration.getPosTagSequence().getScore()));
sb.append(" = ");
for (PosTaggedToken posTaggedToken : finalConfiguration.getPosTagSequence()) {
sb.append(" * ");
sb.append(df.format(posTaggedToken.getDecision().getProbability()));
}
sb.append(" root ");
sb.append(finalConfiguration.getPosTagSequence().size());
LOG.trace(sb.toString());
sb = new StringBuilder();
sb.append(" * Token sequence score = ");
sb.append(df.format(finalConfiguration.getPosTagSequence().getTokenSequence().getScore()));
LOG.trace(sb.toString());
}
}
}
return bestConfigurations;
}
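A minimal sketch of the ClassificationObserver hook that all three methods on this page call via observer.onAnalyse(event, featureResults, decisions): the observer below simply logs the best-scoring decision for each analysed event. Only the three-argument onAnalyse call is taken from the code on this page; the Object type of the first parameter is an assumption (the calls above pass a ParseConfiguration or a Token), the class name TopDecisionLogger is invented, and any further members of the ClassificationObserver interface would still need to be implemented.
import java.util.List;

import com.joliciel.talismane.machineLearning.ClassificationObserver;
import com.joliciel.talismane.machineLearning.Decision;
import com.joliciel.talismane.machineLearning.features.FeatureResult;

// Hypothetical observer: logs the best-scoring decision for each analysed
// event (a ParseConfiguration, Token or TokenPatternMatch in these methods).
public class TopDecisionLogger implements ClassificationObserver {

    @Override
    public void onAnalyse(Object event, List<FeatureResult<?>> featureResults, List<Decision> decisions) {
        Decision best = null;
        for (Decision decision : decisions) {
            if (best == null || decision.getProbability() > best.getProbability())
                best = decision;
        }
        if (best != null)
            System.out.println(event + " -> " + best.getOutcome() + " (" + best.getProbability() + ")");
    }
}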
Use of com.joliciel.talismane.machineLearning.ClassificationObserver in project talismane by joliciel-informatique.
From the class ForwardStatisticalPosTagger, the method tagSentence:
@Override
public List<PosTagSequence> tagSentence(List<TokenSequence> input) throws TalismaneException, IOException {
List<TokenSequence> tokenSequences = null;
if (this.propagateTokeniserBeam) {
tokenSequences = input;
} else {
tokenSequences = new ArrayList<>(1);
tokenSequences.add(input.get(0));
}
int sentenceLength = tokenSequences.get(0).getSentence().getText().length();
TreeMap<Double, PriorityQueue<PosTagSequence>> heaps = new TreeMap<Double, PriorityQueue<PosTagSequence>>();
PriorityQueue<PosTagSequence> heap0 = new PriorityQueue<PosTagSequence>();
for (TokenSequence tokenSequence : tokenSequences) {
// add an empty PosTagSequence for each token sequence
PosTagSequence emptySequence = new PosTagSequence(tokenSequence);
emptySequence.setScoringStrategy(decisionMaker.getDefaultScoringStrategy());
heap0.add(emptySequence);
}
heaps.put(0.0, heap0);
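// heap keys are character offsets into the sentence text: pollFirstEntry() extends the least-advanced sequences first, and tagging is complete once a heap keyed by the full sentence length is polled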
PriorityQueue<PosTagSequence> finalHeap = null;
while (heaps.size() > 0) {
Entry<Double, PriorityQueue<PosTagSequence>> heapEntry = heaps.pollFirstEntry();
if (LOG.isTraceEnabled()) {
LOG.trace("heap key: " + heapEntry.getKey() + ", sentence length: " + sentenceLength);
}
if (heapEntry.getKey() == sentenceLength) {
finalHeap = heapEntry.getValue();
break;
}
PriorityQueue<PosTagSequence> previousHeap = heapEntry.getValue();
// limit the breadth to K
int maxSequences = previousHeap.size() > this.beamWidth ? this.beamWidth : previousHeap.size();
for (int j = 0; j < maxSequences; j++) {
PosTagSequence history = previousHeap.poll();
Token token = history.getNextToken();
if (LOG.isTraceEnabled()) {
LOG.trace("#### Next history ( " + heapEntry.getKey() + "): " + history.toString());
LOG.trace("Prob: " + df.format(history.getScore()));
LOG.trace("Token: " + token.getText());
StringBuilder sb = new StringBuilder();
for (Token oneToken : history.getTokenSequence().listWithWhiteSpace()) {
if (oneToken.equals(token))
sb.append("[" + oneToken + "]");
else
sb.append(oneToken);
}
LOG.trace(sb.toString());
}
PosTaggerContext context = new PosTaggerContextImpl(token, history);
List<Decision> decisions = new ArrayList<Decision>();
boolean ruleApplied = false;
// has a pos-tag already been assigned to this token?
if (token.getAttributes().containsKey(PosTagger.POS_TAG_ATTRIBUTE)) {
StringAttribute posTagCodeAttribute = (StringAttribute) token.getAttributes().get(PosTagger.POS_TAG_ATTRIBUTE);
String posTagCode = posTagCodeAttribute.getValue();
Decision positiveRuleDecision = new Decision(posTagCode);
decisions.add(positiveRuleDecision);
positiveRuleDecision.addAuthority("tokenAttribute");
ruleApplied = true;
if (LOG.isTraceEnabled()) {
LOG.trace("Token has attribute \"" + PosTagger.POS_TAG_ATTRIBUTE + "\". Setting posTag to: " + posTagCode);
}
}
// test the positive rules on the current token
if (!ruleApplied) {
if (posTaggerPositiveRules != null) {
for (PosTaggerRule rule : posTaggerPositiveRules) {
if (LOG.isTraceEnabled()) {
LOG.trace("Checking rule: " + rule.getCondition().getName());
}
RuntimeEnvironment env = new RuntimeEnvironment();
FeatureResult<Boolean> ruleResult = rule.getCondition().check(context, env);
if (ruleResult != null && ruleResult.getOutcome()) {
Decision positiveRuleDecision = new Decision(rule.getTag().getCode());
decisions.add(positiveRuleDecision);
positiveRuleDecision.addAuthority(rule.getCondition().getName());
ruleApplied = true;
if (LOG.isTraceEnabled()) {
LOG.trace("Rule applies. Setting posTag to: " + rule.getTag().getCode());
}
break;
}
}
}
}
if (!ruleApplied) {
// test the features on the current token
List<FeatureResult<?>> featureResults = new ArrayList<FeatureResult<?>>();
for (PosTaggerFeature<?> posTaggerFeature : posTaggerFeatures) {
RuntimeEnvironment env = new RuntimeEnvironment();
FeatureResult<?> featureResult = posTaggerFeature.check(context, env);
if (featureResult != null)
featureResults.add(featureResult);
}
if (LOG.isTraceEnabled()) {
SortedSet<String> featureResultSet = featureResults.stream().map(f -> f.toString()).collect(Collectors.toCollection(() -> new TreeSet<String>()));
for (String featureResultString : featureResultSet) {
LOG.trace(featureResultString);
}
}
// evaluate the feature results using the maxent model
decisions = this.decisionMaker.decide(featureResults);
for (ClassificationObserver observer : this.observers) {
observer.onAnalyse(token, featureResults, decisions);
}
// apply the negative rules
Set<String> eliminatedPosTags = new TreeSet<String>();
if (posTaggerNegativeRules != null) {
for (PosTaggerRule rule : posTaggerNegativeRules) {
if (LOG.isTraceEnabled()) {
LOG.trace("Checking negative rule: " + rule.getCondition().getName());
}
RuntimeEnvironment env = new RuntimeEnvironment();
FeatureResult<Boolean> ruleResult = rule.getCondition().check(context, env);
if (ruleResult != null && ruleResult.getOutcome()) {
eliminatedPosTags.add(rule.getTag().getCode());
if (LOG.isTraceEnabled()) {
LOG.trace("Rule applies. Eliminating posTag: " + rule.getTag().getCode());
}
}
}
if (eliminatedPosTags.size() > 0) {
List<Decision> decisionShortList = new ArrayList<Decision>();
for (Decision decision : decisions) {
if (!eliminatedPosTags.contains(decision.getOutcome())) {
decisionShortList.add(decision);
} else {
LOG.trace("Eliminating decision: " + decision.toString());
}
}
if (decisionShortList.size() > 0) {
decisions = decisionShortList;
} else {
LOG.debug("All decisions eliminated! Restoring original decisions.");
}
}
}
// is this a known word in the lexicon?
if (LOG.isTraceEnabled()) {
String posTags = "";
for (PosTag onePosTag : token.getPossiblePosTags()) {
posTags += onePosTag.getCode() + ",";
}
LOG.trace("Token: " + token.getText() + ". PosTags: " + posTags);
}
List<Decision> decisionShortList = new ArrayList<Decision>();
for (Decision decision : decisions) {
if (decision.getProbability() >= MIN_PROB_TO_STORE) {
decisionShortList.add(decision);
}
}
if (decisionShortList.size() > 0) {
decisions = decisionShortList;
}
}
// add a new pos-tag sequence to the heap for each outcome provided by MaxEnt
for (Decision decision : decisions) {
if (LOG.isTraceEnabled())
LOG.trace("Outcome: " + decision.getOutcome() + ", " + decision.getProbability());
PosTaggedToken posTaggedToken = new PosTaggedToken(token, decision, this.sessionId);
PosTagSequence sequence = new PosTagSequence(history);
sequence.addPosTaggedToken(posTaggedToken);
if (decision.isStatistical())
sequence.addDecision(decision);
double heapIndex = token.getEndIndex();
// for an empty token, add 0.5 to the heap index to distinguish it from regular ones
if (token.getStartIndex() == token.getEndIndex())
heapIndex += 0.5;
// if it's the last token, make sure we end up on the final heap (keyed by sentence length)
if (token.getIndex() == sequence.getTokenSequence().size() - 1)
heapIndex = sentenceLength;
if (LOG.isTraceEnabled())
LOG.trace("Heap index: " + heapIndex);
PriorityQueue<PosTagSequence> heap = heaps.get(heapIndex);
if (heap == null) {
heap = new PriorityQueue<PosTagSequence>();
heaps.put(heapIndex, heap);
}
heap.add(sequence);
}
// next outcome for this token
}
// next history
}
// next atomic index
// return the best sequence on the heap
List<PosTagSequence> sequences = new ArrayList<PosTagSequence>();
int i = 0;
while (!finalHeap.isEmpty()) {
// clone the pos tag sequences to ensure they don't share any underlying
// data (e.g. token sequences)
sequences.add(finalHeap.poll().clonePosTagSequence());
i++;
if (i >= this.getBeamWidth())
break;
}
// apply post-processing filters
if (LOG.isDebugEnabled()) {
LOG.debug("####Final postag sequences:");
int j = 1;
for (PosTagSequence sequence : sequences) {
if (LOG.isDebugEnabled()) {
LOG.debug("Sequence " + (j++) + ", score=" + df.format(sequence.getScore()));
LOG.debug("Sequence: " + sequence);
}
}
}
return sequences;
}
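The beam structure above is worth isolating: rather than one queue, the tagger keeps a TreeMap of priority queues keyed by character offset, so only sequences that have consumed the same amount of the sentence compete, pollFirstEntry() always extends the least-advanced ones, and tagging stops when a queue keyed by the sentence length is polled. The self-contained sketch below reproduces just that skeleton; Hypothesis, the fixed 5-character step and all scores are invented for illustration.
import java.util.Map;
import java.util.PriorityQueue;
import java.util.TreeMap;

// Stripped-down sketch of the offset-keyed beam used by tagSentence above.
public class OffsetBeamSketch {

    // stands in for PosTagSequence; higher score = better hypothesis
    record Hypothesis(String tags, double score, double offset) implements Comparable<Hypothesis> {
        @Override
        public int compareTo(Hypothesis o) {
            // best-first ordering inside each priority queue
            return Double.compare(o.score, this.score);
        }
    }

    public static void main(String[] args) {
        double sentenceLength = 10;
        int beamWidth = 2;
        TreeMap<Double, PriorityQueue<Hypothesis>> heaps = new TreeMap<>();
        heaps.computeIfAbsent(0.0, k -> new PriorityQueue<>()).add(new Hypothesis("", 1.0, 0));

        while (!heaps.isEmpty()) {
            Map.Entry<Double, PriorityQueue<Hypothesis>> entry = heaps.pollFirstEntry();
            if (entry.getKey() == sentenceLength) {
                // a heap keyed by the sentence length holds complete sequences
                System.out.println("best: " + entry.getValue().peek());
                break;
            }
            PriorityQueue<Hypothesis> heap = entry.getValue();
            for (int j = 0; j < beamWidth && !heap.isEmpty(); j++) {
                Hypothesis h = heap.poll();
                // each decision advances the hypothesis to a new offset;
                // in the real tagger this is the next token's end index
                double next = Math.min(h.offset + 5, sentenceLength);
                heaps.computeIfAbsent(next, k -> new PriorityQueue<>())
                        .add(new Hypothesis(h.tags + "T", h.score * 0.9, next));
            }
        }
    }
}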
Use of com.joliciel.talismane.machineLearning.ClassificationObserver in project talismane by joliciel-informatique.
From the class PatternTokeniser, the method tokeniseInternal:
@Override
protected List<TokenisedAtomicTokenSequence> tokeniseInternal(TokenSequence initialSequence, Sentence sentence) throws TalismaneException, IOException {
List<TokenisedAtomicTokenSequence> sequences;
// Assign each separator its default value
List<TokeniserOutcome> defaultOutcomes = this.tokeniserPatternManager.getDefaultOutcomes(initialSequence);
List<Decision> defaultDecisions = new ArrayList<Decision>(defaultOutcomes.size());
for (TokeniserOutcome outcome : defaultOutcomes) {
Decision tokeniserDecision = new Decision(outcome.name());
tokeniserDecision.addAuthority("_" + this.getClass().getSimpleName());
tokeniserDecision.addAuthority("_" + "DefaultDecision");
defaultDecisions.add(tokeniserDecision);
}
// For each test pattern, see if anything in the sentence matches it
if (this.decisionMaker != null) {
List<TokenPatternMatchSequence> matchingSequences = new ArrayList<TokenPatternMatchSequence>();
Map<Token, Set<TokenPatternMatchSequence>> tokenMatchSequenceMap = new HashMap<Token, Set<TokenPatternMatchSequence>>();
Map<TokenPatternMatchSequence, TokenPatternMatch> primaryMatchMap = new HashMap<TokenPatternMatchSequence, TokenPatternMatch>();
Set<Token> matchedTokens = new HashSet<Token>();
for (TokenPattern parsedPattern : this.getTokeniserPatternManager().getParsedTestPatterns()) {
List<TokenPatternMatchSequence> matchesForThisPattern = parsedPattern.match(initialSequence);
for (TokenPatternMatchSequence matchSequence : matchesForThisPattern) {
if (matchSequence.getTokensToCheck().size() > 0) {
matchingSequences.add(matchSequence);
matchedTokens.addAll(matchSequence.getTokensToCheck());
TokenPatternMatch primaryMatch = null;
Token token = matchSequence.getTokensToCheck().get(0);
Set<TokenPatternMatchSequence> matchSequences = tokenMatchSequenceMap.get(token);
if (matchSequences == null) {
matchSequences = new TreeSet<TokenPatternMatchSequence>();
tokenMatchSequenceMap.put(token, matchSequences);
}
matchSequences.add(matchSequence);
for (TokenPatternMatch patternMatch : matchSequence.getTokenPatternMatches()) {
if (patternMatch.getToken().equals(token)) {
primaryMatch = patternMatch;
break;
}
}
if (LOG.isTraceEnabled()) {
LOG.trace("Found match: " + primaryMatch);
}
primaryMatchMap.put(matchSequence, primaryMatch);
}
}
}
// we want to create the n most likely token sequences;
// each sequence has to correspond to a token pattern
Map<TokenPatternMatchSequence, List<Decision>> matchSequenceDecisionMap = new HashMap<TokenPatternMatchSequence, List<Decision>>();
for (TokenPatternMatchSequence matchSequence : matchingSequences) {
TokenPatternMatch match = primaryMatchMap.get(matchSequence);
LOG.debug("next pattern match: " + match.toString());
List<FeatureResult<?>> tokenFeatureResults = new ArrayList<FeatureResult<?>>();
for (TokenPatternMatchFeature<?> feature : features) {
RuntimeEnvironment env = new RuntimeEnvironment();
FeatureResult<?> featureResult = feature.check(match, env);
if (featureResult != null) {
tokenFeatureResults.add(featureResult);
}
}
if (LOG.isTraceEnabled()) {
SortedSet<String> featureResultSet = tokenFeatureResults.stream().map(f -> f.toString()).collect(Collectors.toCollection(() -> new TreeSet<String>()));
for (String featureResultString : featureResultSet) {
LOG.trace(featureResultString);
}
}
List<Decision> decisions = this.decisionMaker.decide(tokenFeatureResults);
for (ClassificationObserver observer : this.observers) observer.onAnalyse(match.getToken(), tokenFeatureResults, decisions);
for (Decision decision : decisions) {
decision.addAuthority("_" + this.getClass().getSimpleName());
decision.addAuthority("_" + "Patterns");
decision.addAuthority(match.getPattern().getName());
}
matchSequenceDecisionMap.put(matchSequence, decisions);
}
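// the decisions above are computed once per match sequence and re-used below by every beam hypothesis that reaches the sequence's first token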
// initially create a heap with a single, empty sequence
PriorityQueue<TokenisedAtomicTokenSequence> heap = new PriorityQueue<TokenisedAtomicTokenSequence>();
TokenisedAtomicTokenSequence emptySequence = new TokenisedAtomicTokenSequence(sentence, 0, this.getSessionId());
heap.add(emptySequence);
for (int i = 0; i < initialSequence.listWithWhiteSpace().size(); i++) {
Token token = initialSequence.listWithWhiteSpace().get(i);
if (LOG.isTraceEnabled()) {
LOG.trace("Token : \"" + token.getAnalyisText() + "\"");
}
// build a new heap for this iteration
PriorityQueue<TokenisedAtomicTokenSequence> previousHeap = heap;
heap = new PriorityQueue<TokenisedAtomicTokenSequence>();
if (i == 0) {
// first token is always "separate" from the outside world
Decision decision = new Decision(TokeniserOutcome.SEPARATE.name());
decision.addAuthority("_" + this.getClass().getSimpleName());
decision.addAuthority("_" + "DefaultDecision");
TaggedToken<TokeniserOutcome> taggedToken = new TaggedToken<>(token, decision, TokeniserOutcome.valueOf(decision.getOutcome()));
TokenisedAtomicTokenSequence newSequence = new TokenisedAtomicTokenSequence(emptySequence);
newSequence.add(taggedToken);
heap.add(newSequence);
continue;
}
// limit the heap breadth to K
int maxSequences = previousHeap.size() > this.getBeamWidth() ? this.getBeamWidth() : previousHeap.size();
for (int j = 0; j < maxSequences; j++) {
TokenisedAtomicTokenSequence history = previousHeap.poll();
// Find the separating & non-separating decisions
if (history.size() > i) {
// token already added as part of a sequence
// introduced by another token
heap.add(history);
} else if (tokenMatchSequenceMap.containsKey(token)) {
// token begins one or more match sequences
// these are ordered from shortest to longest (via
// TreeSet)
List<TokenPatternMatchSequence> matchSequences = new ArrayList<TokenPatternMatchSequence>(tokenMatchSequenceMap.get(token));
// Since sequences P1..Pn contain each other,
// there can be exactly matchSequences.size()
// consistent solutions
// Assume the default is separate
// 0: all separate
// 1: join P1, separate rest
// 2: join P2, separate rest
// ...
// n: join Pn
// We need to add each of these to the heap
// by taking the product of all probabilities
// consistent with each solution
// The probabilities for each solution are (j=join,
// s=separate)
// All separate: s1 x s2 x ... x sn
// P1: j1 x s2 x ... x sn
// P2: j1 x j2 x ... x sn
// ...
// Pn: j1 x j2 x ... x jn
// Any solution of the form s1 x j2 would be
// inconsistent, and is not considered
// If Pi and Pj start and end on the exact same
// token, then the solution for both is
// Pi: j1 x ... x ji x jj x sj+1 ... x sn
// Pj: j1 x ... x ji x jj x sj+1 ... x sn
// Note of course that we're never likely to have
// more than two Ps here, but we need a solution
// for more just to be sure
TokeniserOutcome defaultOutcome = TokeniserOutcome.valueOf(defaultDecisions.get(token.getIndexWithWhiteSpace()).getOutcome());
TokeniserOutcome otherOutcome = null;
if (defaultOutcome == TokeniserOutcome.SEPARATE)
otherOutcome = TokeniserOutcome.JOIN;
else
otherOutcome = TokeniserOutcome.SEPARATE;
double[] decisionProbs = new double[matchSequences.size() + 1];
for (int k = 0; k < decisionProbs.length; k++) decisionProbs[k] = 1;
// Note: k=0 is the default decision (e.g. separate all),
// k=1 joins the first pattern only, k=2 the first two, etc.;
// p indexes the match sequences from 1
int p = 1;
int prevEndIndex = -1;
for (TokenPatternMatchSequence matchSequence : matchSequences) {
int endIndex = matchSequence.getTokensToCheck().get(matchSequence.getTokensToCheck().size() - 1).getEndIndex();
List<Decision> decisions = matchSequenceDecisionMap.get(matchSequence);
for (Decision decision : decisions) {
for (int k = 0; k < decisionProbs.length; k++) {
if (decision.getOutcome().equals(defaultOutcome.name())) {
// e.g. separate in most cases
if (k < p && endIndex > prevEndIndex)
decisionProbs[k] *= decision.getProbability();
else if (k + 1 < p && endIndex <= prevEndIndex)
decisionProbs[k] *= decision.getProbability();
} else {
// e.g. join in most cases
if (k >= p && endIndex > prevEndIndex)
decisionProbs[k] *= decision.getProbability();
else if (k + 1 >= p && endIndex <= prevEndIndex)
decisionProbs[k] *= decision.getProbability();
}
}
// next k
}
// next decision (only 2 of these)
prevEndIndex = endIndex;
p++;
}
// transform to probability distribution
double sumProbs = 0;
for (int k = 0; k < decisionProbs.length; k++) sumProbs += decisionProbs[k];
if (sumProbs > 0)
for (int k = 0; k < decisionProbs.length; k++) decisionProbs[k] /= sumProbs;
// Apply the default decision.
// Since this is the default decision for all tokens
// in the sequence, we don't add the other tokens for
// now, so as to allow them to get examined one at a
// time, just in case one of them starts its own
// separate sequence
Decision defaultDecision = new Decision(defaultOutcome.name(), decisionProbs[0]);
defaultDecision.addAuthority("_" + this.getClass().getSimpleName());
defaultDecision.addAuthority("_" + "Patterns");
for (TokenPatternMatchSequence matchSequence : matchSequences) {
defaultDecision.addAuthority(matchSequence.getTokenPattern().getName());
}
TaggedToken<TokeniserOutcome> defaultTaggedToken = new TaggedToken<>(token, defaultDecision, TokeniserOutcome.valueOf(defaultDecision.getOutcome()));
TokenisedAtomicTokenSequence defaultSequence = new TokenisedAtomicTokenSequence(history);
defaultSequence.add(defaultTaggedToken);
defaultSequence.addDecision(defaultDecision);
heap.add(defaultSequence);
// Apply one non-default decision per match sequence
for (int k = 0; k < matchSequences.size(); k++) {
TokenPatternMatchSequence matchSequence = matchSequences.get(k);
double prob = decisionProbs[k + 1];
Decision decision = new Decision(otherOutcome.name(), prob);
decision.addAuthority("_" + this.getClass().getSimpleName());
decision.addAuthority("_" + "Patterns");
decision.addAuthority(matchSequence.getTokenPattern().getName());
TaggedToken<TokeniserOutcome> taggedToken = new TaggedToken<>(token, decision, TokeniserOutcome.valueOf(decision.getOutcome()));
TokenisedAtomicTokenSequence newSequence = new TokenisedAtomicTokenSequence(history);
newSequence.add(taggedToken);
newSequence.addDecision(decision);
// add the other tokens in this sequence to the solution
for (Token tokenInSequence : matchSequence.getTokensToCheck()) {
if (tokenInSequence.equals(token)) {
continue;
}
Decision decisionInSequence = new Decision(decision.getOutcome());
decisionInSequence.addAuthority("_" + this.getClass().getSimpleName());
decisionInSequence.addAuthority("_" + "DecisionInSequence");
decisionInSequence.addAuthority("_" + "DecisionInSequence_non_default");
decisionInSequence.addAuthority("_" + "Patterns");
TaggedToken<TokeniserOutcome> taggedTokenInSequence = new TaggedToken<>(tokenInSequence, decisionInSequence, TokeniserOutcome.valueOf(decisionInSequence.getOutcome()));
newSequence.add(taggedTokenInSequence);
}
heap.add(newSequence);
}
// next sequence
} else {
// token doesn't start match sequence, and hasn't
// already been added to the current sequence
Decision decision = defaultDecisions.get(i);
if (matchedTokens.contains(token)) {
decision = new Decision(decision.getOutcome());
decision.addAuthority("_" + this.getClass().getSimpleName());
decision.addAuthority("_" + "DecisionInSequence");
decision.addAuthority("_" + "DecisionInSequence_default");
decision.addAuthority("_" + "Patterns");
}
TaggedToken<TokeniserOutcome> taggedToken = new TaggedToken<>(token, decision, TokeniserOutcome.valueOf(decision.getOutcome()));
TokenisedAtomicTokenSequence newSequence = new TokenisedAtomicTokenSequence(history);
newSequence.add(taggedToken);
heap.add(newSequence);
}
}
// next sequence in the old heap
}
// next token
sequences = new ArrayList<TokenisedAtomicTokenSequence>();
int k = 0;
while (!heap.isEmpty()) {
sequences.add(heap.poll());
k++;
if (k >= this.getBeamWidth())
break;
}
} else {
sequences = new ArrayList<TokenisedAtomicTokenSequence>();
TokenisedAtomicTokenSequence defaultSequence = new TokenisedAtomicTokenSequence(sentence, 0, this.getSessionId());
int i = 0;
for (Token token : initialSequence.listWithWhiteSpace()) {
Decision decision = defaultDecisions.get(i++);
TaggedToken<TokeniserOutcome> taggedToken = new TaggedToken<>(token, decision, TokeniserOutcome.valueOf(decision.getOutcome()));
defaultSequence.add(taggedToken);
}
sequences.add(defaultSequence);
}
// do we have a decision maker?
return sequences;
}
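The decisionProbs table built in tokeniseInternal is the subtle part: with n nested match sequences there are exactly n+1 consistent solutions (all separate; join P1 only; join P1 and P2; ...; join all of Pn), and each join/separate probability is multiplied into every solution it is consistent with before the table is normalised into a distribution. The self-contained sketch below reproduces that computation for two strictly nested sequences with distinct end tokens (the simple k < p case in the code above); the probabilities 0.7/0.3 and 0.6/0.4 are invented.
import java.util.Arrays;

// Numeric sketch of the (n+1)-entry solution table for nested match
// sequences (each contained in the next), as built by tokeniseInternal above.
public class NestedPatternProbs {

    public static void main(String[] args) {
        // {separate, join} probabilities per match sequence (invented values)
        double[][] sepJoin = { { 0.7, 0.3 }, { 0.6, 0.4 } };
        int n = sepJoin.length;
        double[] solution = new double[n + 1];
        Arrays.fill(solution, 1.0);

        for (int p = 1; p <= n; p++) {
            double s = sepJoin[p - 1][0];
            double j = sepJoin[p - 1][1];
            for (int k = 0; k <= n; k++) {
                // solutions 0..p-1 separate sequence p; solutions p..n join it
                solution[k] *= (k < p) ? s : j;
            }
        }
        // normalise into a probability distribution, as the tokeniser does
        double sum = 0;
        for (double v : solution) sum += v;
        for (int k = 0; k <= n; k++) solution[k] /= sum;

        // expected: [s1*s2, j1*s2, j1*j2] / 0.72 = [0.583, 0.25, 0.167]
        System.out.println(Arrays.toString(solution));
    }
}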