Search in sources:

Example 1 with ClassificationObserver

Use of com.joliciel.talismane.machineLearning.ClassificationObserver in the project talismane by joliciel-informatique.

The class TransitionBasedParser, method parseSentence.

/**
 * Parses a sentence using a beam search over {@link ParseConfiguration}s.
 * <p>
 * One initial configuration is created per candidate pos-tag sequence. Only the
 * first sequence is used unless {@code propagatePosTaggerBeam} is set. Heaps are
 * kept in a {@link TreeMap} keyed by an integer index, and the heap with the lowest
 * index is always polled next. From each polled heap, configurations are expanded
 * until {@code beamWidth} of them (or the heap size, if smaller) have had at least
 * one transition applied.
 * <p>
 * A configuration is expanded in one of two ways:
 * <ul>
 * <li>by the transition of the first positive rule whose condition holds, or</li>
 * <li>otherwise, by the decisions returned by the decision maker. Decisions with
 * probability not above {@code MIN_PROB_TO_STORE} are dropped, and decisions whose
 * transition is eliminated by a matching negative rule are removed, unless that
 * would remove all of them.</li>
 * </ul>
 * Each new configuration goes to the heap at
 * {@code parseComparisonStrategy.getComparisonIndex(configuration) * 1000}, bumped
 * above the current heap index if needed. Terminal configurations go to a
 * dedicated terminal heap stored at index {@link Integer#MAX_VALUE}.
 * <p>
 * The search stops when any of the following holds:
 * <ul>
 * <li>the head of the polled heap is terminal;</li>
 * <li>early stop is enabled and the terminal heap holds at least {@code beamWidth}
 * configurations;</li>
 * <li>more than {@code maxAnalysisTimePerSentence} seconds have elapsed (if that
 * limit is positive);</li>
 * <li>free memory falls below {@code minFreeMemory * KILOBYTE} bytes (if
 * {@code minFreeMemory} is positive);</li>
 * <li>no heaps remain.</li>
 * </ul>
 *
 * @param input the candidate pos-tag sequences for the sentence; must contain at
 *        least one element (its first element is read unconditionally)
 * @return up to {@code beamWidth} configurations polled from the final heap, or from
 *         the backup heap if the final heap is empty
 * @throws TalismaneException declared here; presumably propagated from feature/rule
 *         evaluation or the decision maker (TODO confirm)
 * @throws IOException declared here; origin not visible in this method (TODO confirm)
 */
@Override
public List<ParseConfiguration> parseSentence(List<PosTagSequence> input) throws TalismaneException, IOException {
    List<PosTagSequence> posTagSequences = null;
    // unless the pos-tagger beam is propagated, only the first pos-tag sequence is parsed
    if (this.propagatePosTaggerBeam) {
        posTagSequences = input;
    } else {
        posTagSequences = new ArrayList<>(1);
        posTagSequences.add(input.get(0));
    }
    long startTime = System.currentTimeMillis();
    int maxAnalysisTimeMilliseconds = maxAnalysisTimePerSentence * 1000;
    // NOTE(review): computed in int arithmetic; if KILOBYTE is 1024 this overflows
    // once minFreeMemory exceeds ~2,097,151 (i.e. ~2 GB) - consider a long
    int minFreeMemoryBytes = minFreeMemory * KILOBYTE;
    TokenSequence tokenSequence = posTagSequences.get(0).getTokenSequence();
    // heaps keyed by index; pollFirstEntry() below always returns the lowest index
    TreeMap<Integer, PriorityQueue<ParseConfiguration>> heaps = new TreeMap<>();
    PriorityQueue<ParseConfiguration> heap0 = new PriorityQueue<>();
    for (PosTagSequence posTagSequence : posTagSequences) {
        // add an initial ParseConfiguration for each postag sequence
        ParseConfiguration initialConfiguration = new ParseConfiguration(posTagSequence);
        initialConfiguration.setScoringStrategy(decisionMaker.getDefaultScoringStrategy());
        heap0.add(initialConfiguration);
        if (LOG.isDebugEnabled()) {
            LOG.debug("Adding initial posTagSequence: " + posTagSequence);
        }
    }
    heaps.put(0, heap0);
    // histories for which some decision's transition failed its preconditions;
    // used as a fallback result if the final heap ends up empty
    PriorityQueue<ParseConfiguration> backupHeap = null;
    PriorityQueue<ParseConfiguration> finalHeap = null;
    // shared heap for all terminal configurations; registered in heaps at Integer.MAX_VALUE
    PriorityQueue<ParseConfiguration> terminalHeap = new PriorityQueue<>();
    while (heaps.size() > 0) {
        Entry<Integer, PriorityQueue<ParseConfiguration>> heapEntry = heaps.pollFirstEntry();
        PriorityQueue<ParseConfiguration> currentHeap = heapEntry.getValue();
        int currentHeapIndex = heapEntry.getKey();
        if (LOG.isTraceEnabled()) {
            LOG.trace("##### Polling next heap: " + heapEntry.getKey() + ", size: " + heapEntry.getValue().size());
        }
        boolean finished = false;
        // systematically set the final heap here, just in case we exit
        // "naturally" with no more heaps
        // NOTE(review): if we break out below because of the time or memory limit,
        // this (possibly non-terminal) heap is what gets returned
        finalHeap = heapEntry.getValue();
        // the backup heap only ever holds histories from the most recently polled heap
        backupHeap = new PriorityQueue<>();
        // we jump out when (a) the best configuration is terminal,
        // (b) early stop is on and the terminal heap is full,
        // (c) we go over the max allotted time, or (d) free memory runs low
        ParseConfiguration topConf = currentHeap.peek();
        if (topConf.isTerminal()) {
            LOG.trace("Exiting with terminal heap: " + heapEntry.getKey() + ", size: " + heapEntry.getValue().size());
            finished = true;
        }
        if (earlyStop && terminalHeap.size() >= beamWidth) {
            LOG.debug("Early stop activated and terminal heap contains " + beamWidth + " entries. Exiting.");
            finalHeap = terminalHeap;
            finished = true;
        }
        long analysisTime = System.currentTimeMillis() - startTime;
        if (maxAnalysisTimePerSentence > 0 && analysisTime > maxAnalysisTimeMilliseconds) {
            LOG.info("Parse tree analysis took too long for sentence: " + tokenSequence.getSentence().getText());
            LOG.info("Breaking out after " + maxAnalysisTimePerSentence + " seconds.");
            finished = true;
        }
        if (minFreeMemory > 0) {
            long freeMemory = Runtime.getRuntime().freeMemory();
            if (freeMemory < minFreeMemoryBytes) {
                LOG.info("Not enough memory left to parse sentence: " + tokenSequence.getSentence().getText());
                LOG.info("Min free memory (bytes):" + minFreeMemoryBytes);
                LOG.info("Current free memory (bytes): " + freeMemory);
                finished = true;
            }
        }
        if (finished) {
            break;
        }
        // limit the breadth to K
        int maxSequences = currentHeap.size() > this.beamWidth ? this.beamWidth : currentHeap.size();
        // number of histories expanded so far (only those where some transition applied)
        int j = 0;
        while (currentHeap.size() > 0) {
            ParseConfiguration history = currentHeap.poll();
            if (LOG.isTraceEnabled()) {
                LOG.trace("### Next configuration on heap " + heapEntry.getKey() + ":");
                LOG.trace(history.toString());
                LOG.trace("Score: " + df.format(history.getScore()));
                LOG.trace(history.getPosTagSequence().toString());
            }
            List<Decision> decisions = new ArrayList<>();
            // test the positive rules on the current configuration;
            // the first matching rule alone determines the transition
            boolean ruleApplied = false;
            if (parserPositiveRules != null) {
                for (ParserRule rule : parserPositiveRules) {
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Checking rule: " + rule.toString());
                    }
                    RuntimeEnvironment env = new RuntimeEnvironment();
                    FeatureResult<Boolean> ruleResult = rule.getCondition().check(history, env);
                    if (ruleResult != null && ruleResult.getOutcome()) {
                        Decision positiveRuleDecision = new Decision(rule.getTransition().getCode());
                        decisions.add(positiveRuleDecision);
                        positiveRuleDecision.addAuthority(rule.getCondition().getName());
                        ruleApplied = true;
                        if (LOG.isTraceEnabled()) {
                            LOG.trace("Rule applies. Setting transition to: " + rule.getTransition().getCode());
                        }
                        break;
                    }
                }
            }
            if (!ruleApplied) {
                // test the features on the current configuration
                List<FeatureResult<?>> parseFeatureResults = new ArrayList<>();
                for (ParseConfigurationFeature<?> feature : this.parseFeatures) {
                    RuntimeEnvironment env = new RuntimeEnvironment();
                    FeatureResult<?> featureResult = feature.check(history, env);
                    if (featureResult != null)
                        parseFeatureResults.add(featureResult);
                }
                if (LOG_FEATURES.isTraceEnabled()) {
                    SortedSet<String> featureResultSet = parseFeatureResults.stream().map(f -> f.toString()).collect(Collectors.toCollection(() -> new TreeSet<>()));
                    for (String featureResultString : featureResultSet) {
                        LOG_FEATURES.trace(featureResultString);
                    }
                }
                // evaluate the feature results using the decision maker
                decisions = this.decisionMaker.decide(parseFeatureResults);
                for (ClassificationObserver observer : this.observers) {
                    observer.onAnalyse(history, parseFeatureResults, decisions);
                }
                // drop decisions whose probability is not strictly above MIN_PROB_TO_STORE
                // (unlike the negative rules below, this may leave the list empty)
                List<Decision> decisionShortList = new ArrayList<>(decisions.size());
                for (Decision decision : decisions) {
                    if (decision.getProbability() > MIN_PROB_TO_STORE)
                        decisionShortList.add(decision);
                }
                decisions = decisionShortList;
                // apply the negative rules
                Set<String> eliminatedTransitions = new HashSet<>();
                if (parserNegativeRules != null) {
                    for (ParserRule rule : parserNegativeRules) {
                        if (LOG.isTraceEnabled()) {
                            LOG.trace("Checking negative rule: " + rule.toString());
                        }
                        RuntimeEnvironment env = new RuntimeEnvironment();
                        FeatureResult<Boolean> ruleResult = rule.getCondition().check(history, env);
                        if (ruleResult != null && ruleResult.getOutcome()) {
                            for (Transition transition : rule.getTransitions()) {
                                eliminatedTransitions.add(transition.getCode());
                                if (LOG.isTraceEnabled())
                                    LOG.trace("Rule applies. Eliminating transition: " + transition.getCode());
                            }
                        }
                    }
                    if (eliminatedTransitions.size() > 0) {
                        decisionShortList = new ArrayList<>();
                        for (Decision decision : decisions) {
                            if (!eliminatedTransitions.contains(decision.getOutcome())) {
                                decisionShortList.add(decision);
                            } else {
                                LOG.trace("Eliminating decision: " + decision.toString());
                            }
                        }
                        // never let the negative rules eliminate every decision
                        if (decisionShortList.size() > 0) {
                            decisions = decisionShortList;
                        } else {
                            LOG.debug("All decisions eliminated! Restoring original decisions.");
                        }
                    }
                }
            }
            // has at least one transition been applied to this history?
            boolean transitionApplied = false;
            TransitionSystem transitionSystem = TalismaneSession.get(sessionId).getTransitionSystem();
            // try every remaining decision: each one whose transition meets its
            // preconditions spawns a new configuration on the appropriate heap
            for (Decision decision : decisions) {
                Transition transition = transitionSystem.getTransitionForCode(decision.getOutcome());
                if (LOG.isTraceEnabled())
                    LOG.trace("Outcome: " + transition.getCode() + ", " + decision.getProbability());
                if (transition.checkPreconditions(history)) {
                    transitionApplied = true;
                    ParseConfiguration configuration = new ParseConfiguration(history);
                    // only statistical decisions are recorded on the configuration
                    if (decision.isStatistical())
                        configuration.addDecision(decision);
                    transition.apply(configuration);
                    int nextHeapIndex = parseComparisonStrategy.getComparisonIndex(configuration) * 1000;
                    if (configuration.isTerminal()) {
                        nextHeapIndex = Integer.MAX_VALUE;
                    } else {
                        // a non-terminal configuration must go to a heap after the current one,
                        // otherwise it would never be polled (equivalent to currentHeapIndex + 1
                        // when the computed index is too low)
                        while (nextHeapIndex <= currentHeapIndex) nextHeapIndex++;
                    }
                    PriorityQueue<ParseConfiguration> nextHeap = heaps.get(nextHeapIndex);
                    if (nextHeap == null) {
                        // terminal configurations all share the single terminalHeap instance
                        if (configuration.isTerminal())
                            nextHeap = terminalHeap;
                        else
                            nextHeap = new PriorityQueue<>();
                        heaps.put(nextHeapIndex, nextHeap);
                        if (LOG.isTraceEnabled())
                            LOG.trace("Created heap with index: " + nextHeapIndex);
                    }
                    nextHeap.add(configuration);
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Added configuration with score " + configuration.getScore() + " to heap: " + nextHeapIndex + ", total size: " + nextHeap.size());
                    }
                    configuration.clearMemory();
                } else {
                    if (LOG.isTraceEnabled())
                        LOG.trace("Cannot apply transition: doesn't meet pre-conditions");
                    // just in case the we run out of both heaps and
                    // analyses, we build this backup heap
                    // NOTE(review): the same history is added once per failing decision,
                    // so the backup heap may contain duplicates
                    backupHeap.add(history);
                }
            // next decision
            }
            if (transitionApplied) {
                j++;
            } else {
                LOG.trace("No transitions could be applied: not counting this history as part of the beam");
            }
            // beam width test
            if (j == maxSequences)
                break;
        }
    // next heap
    }
    // search finished:
    // return the best sequences on the heap
    List<ParseConfiguration> bestConfigurations = new ArrayList<>();
    int i = 0;
    // fall back to histories that could not be extended if nothing else is available
    if (finalHeap.isEmpty())
        finalHeap = backupHeap;
    while (!finalHeap.isEmpty()) {
        bestConfigurations.add(finalHeap.poll());
        i++;
        if (i >= this.getBeamWidth())
            break;
    }
    if (LOG.isDebugEnabled()) {
        for (ParseConfiguration finalConfiguration : bestConfigurations) {
            LOG.debug(df.format(finalConfiguration.getScore()) + ": " + finalConfiguration.toString());
            LOG.debug("Pos tag sequence: " + finalConfiguration.getPosTagSequence());
            LOG.debug("Transitions: " + finalConfiguration.getTransitions());
            LOG.debug("Decisions: " + finalConfiguration.getDecisions());
            if (LOG.isTraceEnabled()) {
                StringBuilder sb = new StringBuilder();
                for (Decision decision : finalConfiguration.getDecisions()) {
                    sb.append(" * ");
                    sb.append(df.format(decision.getProbability()));
                }
                sb.append(" root ");
                sb.append(finalConfiguration.getTransitions().size());
                LOG.trace(sb.toString());
                sb = new StringBuilder();
                sb.append(" * PosTag sequence score ");
                sb.append(df.format(finalConfiguration.getPosTagSequence().getScore()));
                sb.append(" = ");
                for (PosTaggedToken posTaggedToken : finalConfiguration.getPosTagSequence()) {
                    sb.append(" * ");
                    sb.append(df.format(posTaggedToken.getDecision().getProbability()));
                }
                sb.append(" root ");
                sb.append(finalConfiguration.getPosTagSequence().size());
                LOG.trace(sb.toString());
                sb = new StringBuilder();
                sb.append(" * Token sequence score = ");
                sb.append(df.format(finalConfiguration.getPosTagSequence().getTokenSequence().getScore()));
                LOG.trace(sb.toString());
            }
        }
    }
    return bestConfigurations;
}
Also used : ClassificationObserver(com.joliciel.talismane.machineLearning.ClassificationObserver) ZipInputStream(java.util.zip.ZipInputStream) SortedSet(java.util.SortedSet) ParserRule(com.joliciel.talismane.parser.features.ParserRule) PriorityQueue(java.util.PriorityQueue) LoggerFactory(org.slf4j.LoggerFactory) Scanner(java.util.Scanner) HashMap(java.util.HashMap) TokenSequence(com.joliciel.talismane.tokeniser.TokenSequence) MachineLearningModelFactory(com.joliciel.talismane.machineLearning.MachineLearningModelFactory) TreeSet(java.util.TreeSet) TalismaneException(com.joliciel.talismane.TalismaneException) TalismaneSession(com.joliciel.talismane.TalismaneSession) ParseConfigurationFeature(com.joliciel.talismane.parser.features.ParseConfigurationFeature) ArrayList(java.util.ArrayList) ClassificationModel(com.joliciel.talismane.machineLearning.ClassificationModel) HashSet(java.util.HashSet) RuntimeEnvironment(com.joliciel.talismane.machineLearning.features.RuntimeEnvironment) FeatureResult(com.joliciel.talismane.machineLearning.features.FeatureResult) PosTaggedToken(com.joliciel.talismane.posTagger.PosTaggedToken) Map(java.util.Map) ConfigUtils(com.joliciel.talismane.utils.ConfigUtils) ConfigFactory(com.typesafe.config.ConfigFactory) ArrayListNoNulls(com.joliciel.talismane.utils.ArrayListNoNulls) ExternalResource(com.joliciel.talismane.machineLearning.ExternalResource) DecisionMaker(com.joliciel.talismane.machineLearning.DecisionMaker) Logger(org.slf4j.Logger) PosTagSequence(com.joliciel.talismane.posTagger.PosTagSequence) Config(com.typesafe.config.Config) Collection(java.util.Collection) DecimalFormat(java.text.DecimalFormat) Set(java.util.Set) IOException(java.io.IOException) Decision(com.joliciel.talismane.machineLearning.Decision) Collectors(java.util.stream.Collectors) File(java.io.File) List(java.util.List) TreeMap(java.util.TreeMap) Entry(java.util.Map.Entry) InputStream(java.io.InputStream) 
ParserFeatureParser(com.joliciel.talismane.parser.features.ParserFeatureParser) ParserRule(com.joliciel.talismane.parser.features.ParserRule) ArrayList(java.util.ArrayList) TreeSet(java.util.TreeSet) HashSet(java.util.HashSet) RuntimeEnvironment(com.joliciel.talismane.machineLearning.features.RuntimeEnvironment) PosTaggedToken(com.joliciel.talismane.posTagger.PosTaggedToken) PriorityQueue(java.util.PriorityQueue) TreeMap(java.util.TreeMap) Decision(com.joliciel.talismane.machineLearning.Decision) ClassificationObserver(com.joliciel.talismane.machineLearning.ClassificationObserver) PosTagSequence(com.joliciel.talismane.posTagger.PosTagSequence) TokenSequence(com.joliciel.talismane.tokeniser.TokenSequence) FeatureResult(com.joliciel.talismane.machineLearning.features.FeatureResult)

Example 2 with ClassificationObserver

Use of com.joliciel.talismane.machineLearning.ClassificationObserver in the project talismane by joliciel-informatique.

The class ForwardStatisticalPosTagger, method tagSentence.

/**
 * Pos-tags a sentence using a beam search over {@link PosTagSequence}s.
 * <p>
 * One empty pos-tag sequence is created per candidate token sequence. Only the
 * first token sequence is used unless {@code propagateTokeniserBeam} is set.
 * Heaps are kept in a {@link TreeMap} keyed by the character end index of the most
 * recently tagged token:
 * <ul>
 * <li>empty tokens (start index == end index) get an extra 0.5;</li>
 * <li>the last token of a sequence is always keyed at the sentence text length.</li>
 * </ul>
 * The heap with the lowest key is polled at each step and up to {@code beamWidth}
 * of its sequences are extended. The search ends as soon as the polled heap's key
 * equals the sentence text length.
 * <p>
 * The pos-tag for a token comes from the first applicable of the following:
 * <ul>
 * <li>the {@code PosTagger.POS_TAG_ATTRIBUTE} token attribute;</li>
 * <li>the first positive rule whose condition holds;</li>
 * <li>the decision maker's decisions, filtered first by the negative rules and then
 * by {@code MIN_PROB_TO_STORE}. Neither filter is allowed to empty the list.</li>
 * </ul>
 *
 * @param input the candidate token sequences for the sentence; must contain at
 *        least one element (its first element is read unconditionally)
 * @return up to {@code beamWidth} sequences polled from the final heap, each cloned
 *         so they don't share underlying data such as token sequences
 * @throws TalismaneException declared here; presumably propagated from feature/rule
 *         evaluation or the decision maker (TODO confirm)
 * @throws IOException declared here; origin not visible in this method (TODO confirm)
 */
@Override
public List<PosTagSequence> tagSentence(List<TokenSequence> input) throws TalismaneException, IOException {
    List<TokenSequence> tokenSequences = null;
    // unless the tokeniser beam is propagated, only the first token sequence is tagged
    if (this.propagateTokeniserBeam) {
        tokenSequences = input;
    } else {
        tokenSequences = new ArrayList<>(1);
        tokenSequences.add(input.get(0));
    }
    int sentenceLength = tokenSequences.get(0).getSentence().getText().length();
    // heaps keyed by character index; pollFirstEntry() below always returns the lowest key
    TreeMap<Double, PriorityQueue<PosTagSequence>> heaps = new TreeMap<Double, PriorityQueue<PosTagSequence>>();
    PriorityQueue<PosTagSequence> heap0 = new PriorityQueue<PosTagSequence>();
    for (TokenSequence tokenSequence : tokenSequences) {
        // add an empty PosTagSequence for each token sequence
        PosTagSequence emptySequence = new PosTagSequence(tokenSequence);
        emptySequence.setScoringStrategy(decisionMaker.getDefaultScoringStrategy());
        heap0.add(emptySequence);
    }
    heaps.put(0.0, heap0);
    // NOTE(review): stays null if the heaps run out before one keyed at sentenceLength
    // is polled, in which case the finalHeap.isEmpty() call after the loop throws an NPE
    PriorityQueue<PosTagSequence> finalHeap = null;
    while (heaps.size() > 0) {
        Entry<Double, PriorityQueue<PosTagSequence>> heapEntry = heaps.pollFirstEntry();
        if (LOG.isTraceEnabled()) {
            LOG.trace("heap key: " + heapEntry.getKey() + ", sentence length: " + sentenceLength);
        }
        // the Double key is unboxed here, so this is a numeric comparison
        if (heapEntry.getKey() == sentenceLength) {
            finalHeap = heapEntry.getValue();
            break;
        }
        PriorityQueue<PosTagSequence> previousHeap = heapEntry.getValue();
        // limit the breadth to K
        int maxSequences = previousHeap.size() > this.beamWidth ? this.beamWidth : previousHeap.size();
        for (int j = 0; j < maxSequences; j++) {
            PosTagSequence history = previousHeap.poll();
            Token token = history.getNextToken();
            if (LOG.isTraceEnabled()) {
                LOG.trace("#### Next history ( " + heapEntry.getKey() + "): " + history.toString());
                LOG.trace("Prob: " + df.format(history.getScore()));
                LOG.trace("Token: " + token.getText());
                StringBuilder sb = new StringBuilder();
                for (Token oneToken : history.getTokenSequence().listWithWhiteSpace()) {
                    if (oneToken.equals(token))
                        sb.append("[" + oneToken + "]");
                    else
                        sb.append(oneToken);
                }
                LOG.trace(sb.toString());
            }
            PosTaggerContext context = new PosTaggerContextImpl(token, history);
            List<Decision> decisions = new ArrayList<Decision>();
            boolean ruleApplied = false;
            // has a pos-tag been explicitly assigned to this token via a token attribute?
            if (token.getAttributes().containsKey(PosTagger.POS_TAG_ATTRIBUTE)) {
                StringAttribute posTagCodeAttribute = (StringAttribute) token.getAttributes().get(PosTagger.POS_TAG_ATTRIBUTE);
                String posTagCode = posTagCodeAttribute.getValue();
                Decision positiveRuleDecision = new Decision(posTagCode);
                decisions.add(positiveRuleDecision);
                positiveRuleDecision.addAuthority("tokenAttribute");
                ruleApplied = true;
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Token has attribute \"" + PosTagger.POS_TAG_ATTRIBUTE + "\". Setting posTag to: " + posTagCode);
                }
            }
            // test the positive rules on the current token;
            // the first matching rule alone determines the pos-tag
            if (!ruleApplied) {
                if (posTaggerPositiveRules != null) {
                    for (PosTaggerRule rule : posTaggerPositiveRules) {
                        if (LOG.isTraceEnabled()) {
                            LOG.trace("Checking rule: " + rule.getCondition().getName());
                        }
                        RuntimeEnvironment env = new RuntimeEnvironment();
                        FeatureResult<Boolean> ruleResult = rule.getCondition().check(context, env);
                        if (ruleResult != null && ruleResult.getOutcome()) {
                            Decision positiveRuleDecision = new Decision(rule.getTag().getCode());
                            decisions.add(positiveRuleDecision);
                            positiveRuleDecision.addAuthority(rule.getCondition().getName());
                            ruleApplied = true;
                            if (LOG.isTraceEnabled()) {
                                LOG.trace("Rule applies. Setting posTag to: " + rule.getTag().getCode());
                            }
                            break;
                        }
                    }
                }
            }
            if (!ruleApplied) {
                // test the features on the current token
                List<FeatureResult<?>> featureResults = new ArrayList<FeatureResult<?>>();
                for (PosTaggerFeature<?> posTaggerFeature : posTaggerFeatures) {
                    RuntimeEnvironment env = new RuntimeEnvironment();
                    FeatureResult<?> featureResult = posTaggerFeature.check(context, env);
                    if (featureResult != null)
                        featureResults.add(featureResult);
                }
                if (LOG.isTraceEnabled()) {
                    SortedSet<String> featureResultSet = featureResults.stream().map(f -> f.toString()).collect(Collectors.toCollection(() -> new TreeSet<String>()));
                    for (String featureResultString : featureResultSet) {
                        LOG.trace(featureResultString);
                    }
                }
                // evaluate the feature results using the decision maker
                decisions = this.decisionMaker.decide(featureResults);
                for (ClassificationObserver observer : this.observers) {
                    observer.onAnalyse(token, featureResults, decisions);
                }
                // apply the negative rules
                Set<String> eliminatedPosTags = new TreeSet<String>();
                if (posTaggerNegativeRules != null) {
                    for (PosTaggerRule rule : posTaggerNegativeRules) {
                        if (LOG.isTraceEnabled()) {
                            LOG.trace("Checking negative rule: " + rule.getCondition().getName());
                        }
                        RuntimeEnvironment env = new RuntimeEnvironment();
                        FeatureResult<Boolean> ruleResult = rule.getCondition().check(context, env);
                        if (ruleResult != null && ruleResult.getOutcome()) {
                            eliminatedPosTags.add(rule.getTag().getCode());
                            if (LOG.isTraceEnabled()) {
                                LOG.trace("Rule applies. Eliminating posTag: " + rule.getTag().getCode());
                            }
                        }
                    }
                    if (eliminatedPosTags.size() > 0) {
                        List<Decision> decisionShortList = new ArrayList<Decision>();
                        for (Decision decision : decisions) {
                            if (!eliminatedPosTags.contains(decision.getOutcome())) {
                                decisionShortList.add(decision);
                            } else {
                                LOG.trace("Eliminating decision: " + decision.toString());
                            }
                        }
                        // never let the negative rules eliminate every decision
                        if (decisionShortList.size() > 0) {
                            decisions = decisionShortList;
                        } else {
                            LOG.debug("All decisions eliminated! Restoring original decisions.");
                        }
                    }
                }
                // trace the pos-tags the token reports as possible (e.g. from the lexicon)
                if (LOG.isTraceEnabled()) {
                    String posTags = "";
                    for (PosTag onePosTag : token.getPossiblePosTags()) {
                        posTags += onePosTag.getCode() + ",";
                    }
                    LOG.trace("Token: " + token.getText() + ". PosTags: " + posTags);
                }
                // keep only decisions with probability >= MIN_PROB_TO_STORE, unless that
                // would leave none. NOTE(review): the parser uses a strict '>' here.
                List<Decision> decisionShortList = new ArrayList<Decision>();
                for (Decision decision : decisions) {
                    if (decision.getProbability() >= MIN_PROB_TO_STORE) {
                        decisionShortList.add(decision);
                    }
                }
                if (decisionShortList.size() > 0) {
                    decisions = decisionShortList;
                }
            }
            // extend the history with one new sequence per remaining decision
            for (Decision decision : decisions) {
                if (LOG.isTraceEnabled())
                    LOG.trace("Outcome: " + decision.getOutcome() + ", " + decision.getProbability());
                PosTaggedToken posTaggedToken = new PosTaggedToken(token, decision, this.sessionId);
                PosTagSequence sequence = new PosTagSequence(history);
                sequence.addPosTaggedToken(posTaggedToken);
                // only statistical decisions are recorded on the sequence
                if (decision.isStatistical())
                    sequence.addDecision(decision);
                double heapIndex = token.getEndIndex();
                // empty tokens (start == end) get a half-index offset to distinguish
                // them from regular tokens ending at the same position
                if (token.getStartIndex() == token.getEndIndex())
                    heapIndex += 0.5;
                // if it's the last token, make sure we end
                if (token.getIndex() == sequence.getTokenSequence().size() - 1)
                    heapIndex = sentenceLength;
                if (LOG.isTraceEnabled())
                    LOG.trace("Heap index: " + heapIndex);
                PriorityQueue<PosTagSequence> heap = heaps.get(heapIndex);
                if (heap == null) {
                    heap = new PriorityQueue<PosTagSequence>();
                    heaps.put(heapIndex, heap);
                }
                heap.add(sequence);
            }
        // next history
        }
    // next heap
    }
    // search finished:
    // return the best sequence on the heap
    List<PosTagSequence> sequences = new ArrayList<PosTagSequence>();
    int i = 0;
    while (!finalHeap.isEmpty()) {
        // clone the pos tag sequences to ensure they don't share any underlying
        // data (e.g. token sequences)
        sequences.add(finalHeap.poll().clonePosTagSequence());
        i++;
        if (i >= this.getBeamWidth())
            break;
    }
    // log the final sequences at debug level (no filtering is applied here)
    if (LOG.isDebugEnabled()) {
        LOG.debug("####Final postag sequences:");
        int j = 1;
        for (PosTagSequence sequence : sequences) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Sequence " + (j++) + ", score=" + df.format(sequence.getScore()));
                LOG.debug("Sequence: " + sequence);
            }
        }
    }
    return sequences;
}
Also used : ClassificationObserver(com.joliciel.talismane.machineLearning.ClassificationObserver) ZipInputStream(java.util.zip.ZipInputStream) SortedSet(java.util.SortedSet) PriorityQueue(java.util.PriorityQueue) LoggerFactory(org.slf4j.LoggerFactory) Scanner(java.util.Scanner) HashMap(java.util.HashMap) TokenSequence(com.joliciel.talismane.tokeniser.TokenSequence) MachineLearningModelFactory(com.joliciel.talismane.machineLearning.MachineLearningModelFactory) TreeSet(java.util.TreeSet) TalismaneException(com.joliciel.talismane.TalismaneException) TalismaneSession(com.joliciel.talismane.TalismaneSession) ArrayList(java.util.ArrayList) ClassificationModel(com.joliciel.talismane.machineLearning.ClassificationModel) PosTaggerRule(com.joliciel.talismane.posTagger.features.PosTaggerRule) HashSet(java.util.HashSet) RuntimeEnvironment(com.joliciel.talismane.machineLearning.features.RuntimeEnvironment) PosTaggerFeature(com.joliciel.talismane.posTagger.features.PosTaggerFeature) FeatureResult(com.joliciel.talismane.machineLearning.features.FeatureResult) Map(java.util.Map) ConfigUtils(com.joliciel.talismane.utils.ConfigUtils) ConfigFactory(com.typesafe.config.ConfigFactory) ArrayListNoNulls(com.joliciel.talismane.utils.ArrayListNoNulls) ExternalResource(com.joliciel.talismane.machineLearning.ExternalResource) DecisionMaker(com.joliciel.talismane.machineLearning.DecisionMaker) StringAttribute(com.joliciel.talismane.tokeniser.StringAttribute) Logger(org.slf4j.Logger) Config(com.typesafe.config.Config) Collection(java.util.Collection) DecimalFormat(java.text.DecimalFormat) Set(java.util.Set) IOException(java.io.IOException) Decision(com.joliciel.talismane.machineLearning.Decision) Collectors(java.util.stream.Collectors) File(java.io.File) List(java.util.List) TreeMap(java.util.TreeMap) PosTaggerFeatureParser(com.joliciel.talismane.posTagger.features.PosTaggerFeatureParser) Token(com.joliciel.talismane.tokeniser.Token) Entry(java.util.Map.Entry) InputStream(java.io.InputStream) 
ArrayList(java.util.ArrayList) StringAttribute(com.joliciel.talismane.tokeniser.StringAttribute) Token(com.joliciel.talismane.tokeniser.Token) PosTaggerRule(com.joliciel.talismane.posTagger.features.PosTaggerRule) TreeSet(java.util.TreeSet) RuntimeEnvironment(com.joliciel.talismane.machineLearning.features.RuntimeEnvironment) PriorityQueue(java.util.PriorityQueue) TreeMap(java.util.TreeMap) Decision(com.joliciel.talismane.machineLearning.Decision) ClassificationObserver(com.joliciel.talismane.machineLearning.ClassificationObserver) TokenSequence(com.joliciel.talismane.tokeniser.TokenSequence) FeatureResult(com.joliciel.talismane.machineLearning.features.FeatureResult)

Example 3 with ClassificationObserver

Use of com.joliciel.talismane.machineLearning.ClassificationObserver in the project talismane by joliciel-informatique.

The class PatternTokeniser, method tokeniseInternal.

/**
 * Tokenises an atomic token sequence by assigning every separator a
 * {@link TokeniserOutcome}.
 *
 * <p>Each token first receives the default outcome supplied by the tokeniser pattern
 * manager. If no decision maker is configured, a single sequence built from these
 * default decisions is returned.
 *
 * <p>Otherwise:
 * <ol>
 * <li>Every test pattern is matched against the sequence.</li>
 * <li>Each match is classified by the decision maker, and the observers are notified.</li>
 * <li>A beam search (width {@link #getBeamWidth()}) combines the per-match probabilities
 * into mutually consistent solutions for tokens that start one or more match
 * sequences.</li>
 * </ol>
 *
 * @param initialSequence the token sequence (iterated with whitespace) to tokenise
 * @param sentence the sentence used to seed each {@link TokenisedAtomicTokenSequence}
 * @return the sequences polled from the final heap, up to the beam width (at least one if
 *     the heap is non-empty); a single default sequence when there is no decision maker
 * @throws TalismaneException declared by the signature; presumably propagated from the
 *     feature checks or the decision maker (TODO confirm)
 * @throws IOException declared by the signature; presumably propagated from the decision
 *     maker (TODO confirm)
 */
@Override
protected List<TokenisedAtomicTokenSequence> tokeniseInternal(TokenSequence initialSequence, Sentence sentence) throws TalismaneException, IOException {
    List<TokenisedAtomicTokenSequence> sequences;
    // Assign each separator its default value
    List<TokeniserOutcome> defaultOutcomes = this.tokeniserPatternManager.getDefaultOutcomes(initialSequence);
    List<Decision> defaultDecisions = new ArrayList<>(defaultOutcomes.size());
    for (TokeniserOutcome outcome : defaultOutcomes) {
        Decision tokeniserDecision = new Decision(outcome.name());
        tokeniserDecision.addAuthority("_" + this.getClass().getSimpleName());
        tokeniserDecision.addAuthority("_" + "DefaultDecision");
        defaultDecisions.add(tokeniserDecision);
    }
    // For each test pattern, see if anything in the sentence matches it
    if (this.decisionMaker != null) {
        List<TokenPatternMatchSequence> matchingSequences = new ArrayList<>();
        // first token to check -> every match sequence starting there
        // (ordered by TokenPatternMatchSequence's natural ordering)
        Map<Token, Set<TokenPatternMatchSequence>> tokenMatchSequenceMap = new HashMap<>();
        // match sequence -> the pattern match on its first token to check
        Map<TokenPatternMatchSequence, TokenPatternMatch> primaryMatchMap = new HashMap<>();
        // every token that must be checked by at least one match sequence
        Set<Token> matchedTokens = new HashSet<>();
        for (TokenPattern parsedPattern : this.getTokeniserPatternManager().getParsedTestPatterns()) {
            List<TokenPatternMatchSequence> matchesForThisPattern = parsedPattern.match(initialSequence);
            for (TokenPatternMatchSequence matchSequence : matchesForThisPattern) {
                if (matchSequence.getTokensToCheck().size() > 0) {
                    matchingSequences.add(matchSequence);
                    matchedTokens.addAll(matchSequence.getTokensToCheck());
                    TokenPatternMatch primaryMatch = null;
                    Token token = matchSequence.getTokensToCheck().get(0);
                    Set<TokenPatternMatchSequence> matchSequences = tokenMatchSequenceMap.computeIfAbsent(token, t -> new TreeSet<>());
                    matchSequences.add(matchSequence);
                    for (TokenPatternMatch patternMatch : matchSequence.getTokenPatternMatches()) {
                        if (patternMatch.getToken().equals(token)) {
                            primaryMatch = patternMatch;
                            break;
                        }
                    }
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Found match: " + primaryMatch);
                    }
                    primaryMatchMap.put(matchSequence, primaryMatch);
                }
            }
        }
        // we want to create the n most likely token sequences
        // the sequence has to correspond to a token pattern
        Map<TokenPatternMatchSequence, List<Decision>> matchSequenceDecisionMap = new HashMap<>();
        for (TokenPatternMatchSequence matchSequence : matchingSequences) {
            TokenPatternMatch match = primaryMatchMap.get(matchSequence);
            // guarded like every other log call here: avoids building the message per match
            // when debug logging is off
            if (LOG.isDebugEnabled()) {
                LOG.debug("next pattern match: " + match);
            }
            List<FeatureResult<?>> tokenFeatureResults = new ArrayList<>();
            for (TokenPatternMatchFeature<?> feature : features) {
                RuntimeEnvironment env = new RuntimeEnvironment();
                FeatureResult<?> featureResult = feature.check(match, env);
                if (featureResult != null) {
                    tokenFeatureResults.add(featureResult);
                }
            }
            if (LOG.isTraceEnabled()) {
                SortedSet<String> featureResultSet = tokenFeatureResults.stream().map(Object::toString).collect(Collectors.toCollection(TreeSet::new));
                for (String featureResultString : featureResultSet) {
                    LOG.trace(featureResultString);
                }
            }
            List<Decision> decisions = this.decisionMaker.decide(tokenFeatureResults);
            for (ClassificationObserver observer : this.observers) {
                observer.onAnalyse(match.getToken(), tokenFeatureResults, decisions);
            }
            for (Decision decision : decisions) {
                decision.addAuthority("_" + this.getClass().getSimpleName());
                decision.addAuthority("_" + "Patterns");
                decision.addAuthority(match.getPattern().getName());
            }
            matchSequenceDecisionMap.put(matchSequence, decisions);
        }
        // initially create a heap with a single, empty sequence
        PriorityQueue<TokenisedAtomicTokenSequence> heap = new PriorityQueue<>();
        TokenisedAtomicTokenSequence emptySequence = new TokenisedAtomicTokenSequence(sentence, 0, this.getSessionId());
        heap.add(emptySequence);
        for (int i = 0; i < initialSequence.listWithWhiteSpace().size(); i++) {
            Token token = initialSequence.listWithWhiteSpace().get(i);
            if (LOG.isTraceEnabled()) {
                LOG.trace("Token : \"" + token.getAnalyisText() + "\"");
            }
            // build a new heap for this iteration
            PriorityQueue<TokenisedAtomicTokenSequence> previousHeap = heap;
            heap = new PriorityQueue<>();
            if (i == 0) {
                // first token is always "separate" from the outside world
                Decision decision = new Decision(TokeniserOutcome.SEPARATE.name());
                decision.addAuthority("_" + this.getClass().getSimpleName());
                decision.addAuthority("_" + "DefaultDecision");
                TaggedToken<TokeniserOutcome> taggedToken = new TaggedToken<>(token, decision, TokeniserOutcome.valueOf(decision.getOutcome()));
                TokenisedAtomicTokenSequence newSequence = new TokenisedAtomicTokenSequence(emptySequence);
                newSequence.add(taggedToken);
                heap.add(newSequence);
                continue;
            }
            // limit the heap breadth to K
            int maxSequences = Math.min(previousHeap.size(), this.getBeamWidth());
            for (int j = 0; j < maxSequences; j++) {
                TokenisedAtomicTokenSequence history = previousHeap.poll();
                // Find the separating & non-separating decisions
                if (history.size() > i) {
                    // token already added as part of a sequence
                    // introduced by another token
                    heap.add(history);
                } else if (tokenMatchSequenceMap.containsKey(token)) {
                    // token begins one or more match sequences
                    // these are ordered from shortest to longest (via
                    // TreeSet)
                    List<TokenPatternMatchSequence> matchSequences = new ArrayList<>(tokenMatchSequenceMap.get(token));
                    // Since sequences P1..Pn contain each other,
                    // there can be exactly matchSequences.size() + 1
                    // consistent solutions.
                    // Assume the default is separate:
                    // 0: all separate
                    // 1: join P1, separate rest
                    // 2: join P2, separate rest
                    // ...
                    // n: join Pn
                    // We need to add each of these to the heap
                    // by taking the product of all probabilities
                    // consistent with each solution.
                    // The probabilities for each solution are (j=join,
                    // s=separate):
                    // All separate: s1 x s2 x ... x sn
                    // P1: j1 x s2 x ... x sn
                    // P2: j1 x j2 x ... x sn
                    // ...
                    // Pn: j1 x j2 x ... x jn
                    // Any solution of the form s1 x j2 would be
                    // inconsistent, and is not considered.
                    // If Pi and Pj start and end on the exact same
                    // token, then the solution for both is
                    // Pi: j1 x ... x ji x jj x sj+1 ... x sn
                    // Pj: j1 x ... x ji x jj x sj+1 ... x sn
                    // Note of course that we're never likely to have
                    // more than two Ps here,
                    // but we need a solution for more just to be sure.
                    TokeniserOutcome defaultOutcome = TokeniserOutcome.valueOf(defaultDecisions.get(token.getIndexWithWhiteSpace()).getOutcome());
                    TokeniserOutcome otherOutcome = null;
                    if (defaultOutcome == TokeniserOutcome.SEPARATE) {
                        otherOutcome = TokeniserOutcome.JOIN;
                    } else {
                        otherOutcome = TokeniserOutcome.SEPARATE;
                    }
                    double[] decisionProbs = new double[matchSequences.size() + 1];
                    for (int k = 0; k < decisionProbs.length; k++) {
                        decisionProbs[k] = 1;
                    }
                    // Note: k0 = default decision (e.g. separate all),
                    // k1=first pattern
                    // p1 = first pattern
                    int p = 1;
                    int prevEndIndex = -1;
                    for (TokenPatternMatchSequence matchSequence : matchSequences) {
                        int endIndex = matchSequence.getTokensToCheck().get(matchSequence.getTokensToCheck().size() - 1).getEndIndex();
                        List<Decision> decisions = matchSequenceDecisionMap.get(matchSequence);
                        // one decision per outcome (presumably only 2 of these)
                        for (Decision decision : decisions) {
                            // fold this decision's probability into every solution k it is consistent with
                            for (int k = 0; k < decisionProbs.length; k++) {
                                if (decision.getOutcome().equals(defaultOutcome.name())) {
                                    // e.g. separate in most cases
                                    if (k < p && endIndex > prevEndIndex) {
                                        decisionProbs[k] *= decision.getProbability();
                                    } else if (k + 1 < p && endIndex <= prevEndIndex) {
                                        decisionProbs[k] *= decision.getProbability();
                                    }
                                } else {
                                    // e.g. join in most cases
                                    if (k >= p && endIndex > prevEndIndex) {
                                        decisionProbs[k] *= decision.getProbability();
                                    } else if (k + 1 >= p && endIndex <= prevEndIndex) {
                                        decisionProbs[k] *= decision.getProbability();
                                    }
                                }
                            }
                        }
                        prevEndIndex = endIndex;
                        p++;
                    }
                    // transform to probability distribution
                    double sumProbs = 0;
                    for (int k = 0; k < decisionProbs.length; k++) {
                        sumProbs += decisionProbs[k];
                    }
                    if (sumProbs > 0) {
                        for (int k = 0; k < decisionProbs.length; k++) {
                            decisionProbs[k] /= sumProbs;
                        }
                    }
                    // Apply default decision
                    // Since this is the default decision for all tokens
                    // in the sequence, we don't add the other tokens
                    // for now,
                    // so as to allow them
                    // to get examined one at a time, just in case one
                    // of them starts its own separate sequence
                    Decision defaultDecision = new Decision(defaultOutcome.name(), decisionProbs[0]);
                    defaultDecision.addAuthority("_" + this.getClass().getSimpleName());
                    defaultDecision.addAuthority("_" + "Patterns");
                    for (TokenPatternMatchSequence matchSequence : matchSequences) {
                        defaultDecision.addAuthority(matchSequence.getTokenPattern().getName());
                    }
                    TaggedToken<TokeniserOutcome> defaultTaggedToken = new TaggedToken<>(token, defaultDecision, TokeniserOutcome.valueOf(defaultDecision.getOutcome()));
                    TokenisedAtomicTokenSequence defaultSequence = new TokenisedAtomicTokenSequence(history);
                    defaultSequence.add(defaultTaggedToken);
                    defaultSequence.addDecision(defaultDecision);
                    heap.add(defaultSequence);
                    // Apply one non-default decision per match sequence
                    for (int k = 0; k < matchSequences.size(); k++) {
                        TokenPatternMatchSequence matchSequence = matchSequences.get(k);
                        double prob = decisionProbs[k + 1];
                        Decision decision = new Decision(otherOutcome.name(), prob);
                        decision.addAuthority("_" + this.getClass().getSimpleName());
                        decision.addAuthority("_" + "Patterns");
                        decision.addAuthority(matchSequence.getTokenPattern().getName());
                        TaggedToken<TokeniserOutcome> taggedToken = new TaggedToken<>(token, decision, TokeniserOutcome.valueOf(decision.getOutcome()));
                        TokenisedAtomicTokenSequence newSequence = new TokenisedAtomicTokenSequence(history);
                        newSequence.add(taggedToken);
                        newSequence.addDecision(decision);
                        // add all other tokens in this sequence to the solution,
                        // with the same (non-default) outcome
                        for (Token tokenInSequence : matchSequence.getTokensToCheck()) {
                            if (tokenInSequence.equals(token)) {
                                continue;
                            }
                            Decision decisionInSequence = new Decision(decision.getOutcome());
                            decisionInSequence.addAuthority("_" + this.getClass().getSimpleName());
                            decisionInSequence.addAuthority("_" + "DecisionInSequence");
                            decisionInSequence.addAuthority("_" + "DecisionInSequence_non_default");
                            decisionInSequence.addAuthority("_" + "Patterns");
                            TaggedToken<TokeniserOutcome> taggedTokenInSequence = new TaggedToken<>(tokenInSequence, decisionInSequence, TokeniserOutcome.valueOf(decisionInSequence.getOutcome()));
                            newSequence.add(taggedTokenInSequence);
                        }
                        heap.add(newSequence);
                    }
                } else {
                    // token doesn't start match sequence, and hasn't
                    // already been added to the current sequence
                    Decision decision = defaultDecisions.get(i);
                    if (matchedTokens.contains(token)) {
                        decision = new Decision(decision.getOutcome());
                        decision.addAuthority("_" + this.getClass().getSimpleName());
                        decision.addAuthority("_" + "DecisionInSequence");
                        decision.addAuthority("_" + "DecisionInSequence_default");
                        decision.addAuthority("_" + "Patterns");
                    }
                    TaggedToken<TokeniserOutcome> taggedToken = new TaggedToken<>(token, decision, TokeniserOutcome.valueOf(decision.getOutcome()));
                    TokenisedAtomicTokenSequence newSequence = new TokenisedAtomicTokenSequence(history);
                    newSequence.add(taggedToken);
                    heap.add(newSequence);
                }
            }
        }
        // keep the best sequences from the final heap; note that at least one is kept
        // whenever the heap is non-empty, even if the beam width is not positive
        sequences = new ArrayList<>();
        int k = 0;
        while (!heap.isEmpty()) {
            sequences.add(heap.poll());
            k++;
            if (k >= this.getBeamWidth()) {
                break;
            }
        }
    } else {
        // no decision maker: a single sequence using the default decision for every token
        sequences = new ArrayList<>();
        TokenisedAtomicTokenSequence defaultSequence = new TokenisedAtomicTokenSequence(sentence, 0, this.getSessionId());
        int i = 0;
        for (Token token : initialSequence.listWithWhiteSpace()) {
            Decision decision = defaultDecisions.get(i++);
            TaggedToken<TokeniserOutcome> taggedToken = new TaggedToken<>(token, decision, TokeniserOutcome.valueOf(decision.getOutcome()));
            defaultSequence.add(taggedToken);
        }
        sequences.add(defaultSequence);
    }
    return sequences;
}
Also used : ClassificationObserver(com.joliciel.talismane.machineLearning.ClassificationObserver) ZipInputStream(java.util.zip.ZipInputStream) SortedSet(java.util.SortedSet) PriorityQueue(java.util.PriorityQueue) LoggerFactory(org.slf4j.LoggerFactory) TokenisedAtomicTokenSequence(com.joliciel.talismane.tokeniser.TokenisedAtomicTokenSequence) HashMap(java.util.HashMap) TokenSequence(com.joliciel.talismane.tokeniser.TokenSequence) MachineLearningModelFactory(com.joliciel.talismane.machineLearning.MachineLearningModelFactory) TaggedToken(com.joliciel.talismane.tokeniser.TaggedToken) TreeSet(java.util.TreeSet) TalismaneException(com.joliciel.talismane.TalismaneException) TalismaneSession(com.joliciel.talismane.TalismaneSession) ArrayList(java.util.ArrayList) ClassificationModel(com.joliciel.talismane.machineLearning.ClassificationModel) HashSet(java.util.HashSet) RuntimeEnvironment(com.joliciel.talismane.machineLearning.features.RuntimeEnvironment) TokenPatternMatchFeatureParser(com.joliciel.talismane.tokeniser.features.TokenPatternMatchFeatureParser) TokenPatternMatchFeature(com.joliciel.talismane.tokeniser.features.TokenPatternMatchFeature) FeatureResult(com.joliciel.talismane.machineLearning.features.FeatureResult) Map(java.util.Map) ConfigUtils(com.joliciel.talismane.utils.ConfigUtils) ConfigFactory(com.typesafe.config.ConfigFactory) ExternalResource(com.joliciel.talismane.machineLearning.ExternalResource) DecisionMaker(com.joliciel.talismane.machineLearning.DecisionMaker) Tokeniser(com.joliciel.talismane.tokeniser.Tokeniser) Logger(org.slf4j.Logger) Config(com.typesafe.config.Config) Collection(java.util.Collection) Set(java.util.Set) IOException(java.io.IOException) TokeniserOutcome(com.joliciel.talismane.tokeniser.TokeniserOutcome) Decision(com.joliciel.talismane.machineLearning.Decision) Collectors(java.util.stream.Collectors) File(java.io.File) List(java.util.List) Token(com.joliciel.talismane.tokeniser.Token) Sentence(com.joliciel.talismane.rawText.Sentence) 
InputStream(java.io.InputStream) SortedSet(java.util.SortedSet) TreeSet(java.util.TreeSet) HashSet(java.util.HashSet) Set(java.util.Set) TaggedToken(com.joliciel.talismane.tokeniser.TaggedToken) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) TaggedToken(com.joliciel.talismane.tokeniser.TaggedToken) Token(com.joliciel.talismane.tokeniser.Token) TokeniserOutcome(com.joliciel.talismane.tokeniser.TokeniserOutcome) TreeSet(java.util.TreeSet) ArrayList(java.util.ArrayList) List(java.util.List) TokenisedAtomicTokenSequence(com.joliciel.talismane.tokeniser.TokenisedAtomicTokenSequence) HashSet(java.util.HashSet) RuntimeEnvironment(com.joliciel.talismane.machineLearning.features.RuntimeEnvironment) PriorityQueue(java.util.PriorityQueue) Decision(com.joliciel.talismane.machineLearning.Decision) ClassificationObserver(com.joliciel.talismane.machineLearning.ClassificationObserver) FeatureResult(com.joliciel.talismane.machineLearning.features.FeatureResult)

Aggregations

TalismaneException (com.joliciel.talismane.TalismaneException): 3; TalismaneSession (com.joliciel.talismane.TalismaneSession): 3; ClassificationModel (com.joliciel.talismane.machineLearning.ClassificationModel): 3; ClassificationObserver (com.joliciel.talismane.machineLearning.ClassificationObserver): 3; Decision (com.joliciel.talismane.machineLearning.Decision): 3; DecisionMaker (com.joliciel.talismane.machineLearning.DecisionMaker): 3; ExternalResource (com.joliciel.talismane.machineLearning.ExternalResource): 3; MachineLearningModelFactory (com.joliciel.talismane.machineLearning.MachineLearningModelFactory): 3; FeatureResult (com.joliciel.talismane.machineLearning.features.FeatureResult): 3; RuntimeEnvironment (com.joliciel.talismane.machineLearning.features.RuntimeEnvironment): 3; TokenSequence (com.joliciel.talismane.tokeniser.TokenSequence): 3; ConfigUtils (com.joliciel.talismane.utils.ConfigUtils): 3; Config (com.typesafe.config.Config): 3; ConfigFactory (com.typesafe.config.ConfigFactory): 3; File (java.io.File): 3; IOException (java.io.IOException): 3; InputStream (java.io.InputStream): 3; ArrayList (java.util.ArrayList): 3; Collection (java.util.Collection): 3; HashMap (java.util.HashMap): 3