Usage of edu.stanford.nlp.patterns.dep.DataInstanceDep in the CoreNLP project (stanfordnlp).
Class ScorePhrasesLearnFeatWt, method chooseUnknownPhrases:
/**
 * Samples candidate phrases from {@code sent} that do not contain any token labeled {@code label},
 * to be used as "unknown" examples.
 *
 * <p>A phrase length between 1 and {@code PatternFactory.numWordsCompoundMapped.get(label)} is
 * sampled once per call (from {@code random}), then phrases are collected according to
 * {@code constVars.patternType}:
 *
 * <ul>
 *   <li>{@code DEP}: up to {@code min(maxNum, perSelect * #tokens)} tokens are sampled without
 *       replacement as heads; for each accepted head, phrases are extracted from the dependency
 *       graph with {@code ExtractPhraseFromPattern.printSubGraph}.
 *   <li>{@code SURFACE}: each token position is selected with probability {@code perSelect}; the
 *       window of up to {@code length} tokens around it becomes a candidate if none of its tokens
 *       carries {@code label}, it is non-empty, and it is not a function word.
 * </ul>
 *
 * @param sent the sentence to sample from; in DEP mode it must be a {@code DataInstanceDep}
 * @param random source of randomness for every sampling decision (including the phrase length)
 * @param perSelect fraction of tokens sampled as heads (DEP) or per-token selection probability
 *     (SURFACE)
 * @param positiveClass annotation key whose value is compared against {@code label}
 * @param label the label being learned; tokens carrying it never appear in a sample
 * @param maxNum maximum number of heads sampled in DEP mode; {@code 0} returns an empty set.
 *     NOTE(review): the SURFACE branch does not apply this cap — confirm whether it should.
 * @return the sampled candidate phrases (possibly empty, never {@code null})
 * @throws ClassCastException in DEP mode if {@code sent} is not a {@code DataInstanceDep}
 * @throws RuntimeException if the pattern type is neither DEP nor SURFACE
 */
Set<CandidatePhrase> chooseUnknownPhrases(DataInstance sent, Random random, double perSelect, Class positiveClass, String label, int maxNum) {
  Set<CandidatePhrase> unknownSamples = new HashSet<>();
  if (maxNum == 0)
    return unknownSamples;

  // A missing or non-positive entry would otherwise NPE on unboxing or leave nothing to sample,
  // so fall back to single-token phrases.
  Integer maxCompoundLength = PatternFactory.numWordsCompoundMapped.get(label);
  int maxLength = (maxCompoundLength == null || maxCompoundLength < 1) ? 1 : maxCompoundLength;
  List<Integer> lengths = new ArrayList<>();
  for (int i = 1; i <= maxLength; i++)
    lengths.add(i);
  // Use the caller's RNG: a fresh new Random(0) here would yield the same "random" length on
  // every call, defeating the sampling.
  int length = CollectionUtils.sample(lengths, random);

  if (constVars.patternType.equals(PatternFactory.PatternType.DEP)) {
    sampleUnknownDependencyPhrases(sent, random, perSelect, positiveClass, label, maxNum, length, unknownSamples);
  } else if (constVars.patternType.equals(PatternFactory.PatternType.SURFACE)) {
    sampleUnknownSurfacePhrases(sent, random, perSelect, positiveClass, label, length, unknownSamples);
  } else {
    throw new RuntimeException("not yet implemented");
  }
  return unknownSamples;
}

/** DEP mode: samples head tokens and adds the phrases extracted from the graph under each to {@code out}. */
private void sampleUnknownDependencyPhrases(DataInstance sent, Random random, double perSelect, Class positiveClass, String label, int maxNum, int length, Set<CandidatePhrase> out) {
  // A token may appear in a sample only if it does not carry the target label and is not a
  // function word. label.equals(...) avoids an NPE when the annotation is absent.
  Function<CoreLabel, Boolean> acceptWord = coreLabel ->
      !label.equals(coreLabel.get(positiveClass)) && !constVars.functionWords.contains(coreLabel.word());

  ExtractPhraseFromPattern extract = new ExtractPhraseFromPattern(true, length);
  SemanticGraph g = ((DataInstanceDep) sent).getGraph();

  int numTokens = sent.getTokens().size();
  // Clamp to the sentence length so perSelect > 1 cannot request more heads than there are tokens.
  int numHeads = Math.min(maxNum, Math.min(numTokens, (int) (perSelect * numTokens)));
  Collection<CoreLabel> sampledHeads = CollectionUtils.sampleWithoutReplacement(sent.getTokens(), numHeads, random);

  //TODO: change this for more efficient implementation
  List<String> textTokens = sent.getTokens().stream().map(CoreLabel::word).collect(Collectors.toList());
  for (CoreLabel head : sampledHeads) {
    if (!acceptWord.apply(head))
      continue;
    // A token is not guaranteed to be a vertex of the graph; getNodeByIndex would throw for it.
    IndexedWord w = g.getNodeByIndexSafe(head.index());
    if (w == null)
      continue;
    List<String> outputPhrases = new ArrayList<>();
    List<ExtractedPhrase> extractedPhrases = new ArrayList<>();
    List<IntPair> outputIndices = new ArrayList<>();
    extract.printSubGraph(g, w, new ArrayList<>(), textTokens, outputPhrases, outputIndices, new ArrayList<>(), new ArrayList<>(), false, extractedPhrases, null, acceptWord);
    for (ExtractedPhrase p : extractedPhrases) {
      out.add(CandidatePhrase.createOrGet(p.getValue(), null, p.getFeatures()));
    }
  }
}

/** SURFACE mode: adds token windows of up to {@code length} tokens that contain no {@code label} token to {@code out}. */
private void sampleUnknownSurfacePhrases(DataInstance sent, Random random, double perSelect, Class positiveClass, String label, int length, Set<CandidatePhrase> out) {
  CoreLabel[] tokens = sent.getTokens().toArray(new CoreLabel[0]);
  // Window around the selected token: floor((length - 1) / 2) tokens to the left, the rest to the right.
  int left = (length - 1) / 2;
  int right = length - 1 - left;
  for (int i = 0; i < tokens.length; i++) {
    if (random.nextDouble() < perSelect) {
      StringBuilder phrase = new StringBuilder();
      boolean hasPositive = false;
      for (int j = Math.max(0, i - left); j < tokens.length && j <= i + right; j++) {
        if (label.equals(tokens[j].get(positiveClass))) {
          hasPositive = true;
          break;
        }
        phrase.append(' ').append(tokens[j].word());
      }
      String ph = phrase.toString().trim();
      if (!hasPositive && !ph.isEmpty() && !constVars.functionWords.contains(ph)) {
        out.add(CandidatePhrase.createOrGet(ph));
      }
    }
  }
}
Aggregations