Use of edu.illinois.cs.cogcomp.core.stats.Counter in the project cogcomp-nlp by CogComp:
the method getBestSplit of the class CreateTrainDevTestSplit.
/**
 * Iterates over candidate sets of documents and finds the combination whose label counts
 * have the smallest difference from the target counts implied by {@code frac}.
 * Beam search: each round extends the best size-(k-1) combinations by one document,
 * keeping at most BEAM_SIZE candidates per round.
 *
 * @param frac fraction of the corpus label counts this split should cover
 * @param availIds document ids still available for inclusion in the split
 * @return pair of (chosen document-id set, its label counter); the counter is {@code null}
 *         when {@code frac} is effectively zero or no candidate split could be built
 */
private Pair<Set<String>, Counter<String>> getBestSplit(double frac, Set<String> availIds) {
    Set<String> split = new HashSet<>();
    Counter<String> splitCount = null;
    // a vanishingly small fraction gets an empty split; no search needed
    if (frac < 0.01)
        return new Pair<>(split, splitCount);
    Map<String, Double> targetCounts = findTargetCounts(frac);
    double bestDiff = LARGE_DIFF;
    /*
     * fill in a table of partial counts. Naive, so size is approx 2 * (n choose k)
     * as we keep the last row to save some computation.
     * stop as soon as we have a round where we don't improve the bestRoundDiff, as adding more documents
     * will not reduce the count differences.
     */
    PriorityQueue<QueueElement> oldBestSplitsOfSizeK = new PriorityQueue<>(BEAM_SIZE);
    PriorityQueue<QueueElement> bestSplits = new PriorityQueue<>(BEAM_SIZE);
    // number of documents in the sets considered
    for (int num = 1; num <= availIds.size(); ++num) {
        logger.info("Round {}...", num);
        double bestRoundDiff = LARGE_DIFF;
        // store new combinations generated this round
        boolean isBetterRound = false;
        // each document to that of each previously existing id combination
        // todo: move dcc into olddcc; populate newdcc with dcc counts plus doc counts for each doc
        // make sure to copy counters to avoid shared references across combinations (will corrupt counts)
        Map<Set<String>, Counter<String>> oldCombCounts = initializeCurrentRoundCounts(oldBestSplitsOfSizeK);
        /*
         * compute NUM_DOCS * BEAM_SIZE possible splits.
         */
        Map<Set<String>, Counter<String>> docCombinationCounts = new HashMap<>();
        for (Set<String> keyComb : oldCombCounts.keySet()) {
            Counter<String> keyCount = oldCombCounts.get(keyComb);
            for (String docId : availIds) {
                Set<String> newComb = new HashSet<>();
                newComb.addAll(keyComb);
                newComb.add(docId);
                // naive implementation does not consider order, so avoid duplication
                if (!oldCombCounts.containsKey(newComb)) {
                    // the counts for the current docId
                    Counter<String> docLabelCount = labelCounts.get(docId);
                    Counter<String> newCombLabelCount = new Counter<>();
                    // initialize newCombLabelCount with count from base id combination;
                    // a fresh Counter per combination avoids shared mutable state
                    for (String label : keyCount.keySet())
                        newCombLabelCount.incrementCount(label, keyCount.getCount(label));
                    // add current docId label counts
                    for (String label : docLabelCount.items()) {
                        newCombLabelCount.incrementCount(label, docLabelCount.getCount(label));
                    }
                    docCombinationCounts.put(newComb, newCombLabelCount);
                }
            }
        }
        // sized like the sibling queues for consistency (was default capacity)
        PriorityQueue<QueueElement> bestSplitsOfSizeK = new PriorityQueue<>(BEAM_SIZE);
        // want explicit generation because we will use these as seeds in the next round
        for (Set<String> docidComb : docCombinationCounts.keySet()) {
            double diff = computeCountDiff(docCombinationCounts.get(docidComb), targetCounts);
            bestSplitsOfSizeK.add(new QueueElement(diff, docidComb, docCombinationCounts.get(docidComb)));
            if (diff < bestRoundDiff) {
                bestRoundDiff = diff;
                if (bestRoundDiff < bestDiff) {
                    isBetterRound = true;
                    bestDiff = bestRoundDiff;
                }
            }
        }
        logger.info("current round best diff is {}", bestRoundDiff);
        if (stopEarly && !isBetterRound) {
            logger.warn("Stopping after round {}", num);
            logger.warn("current round best diff is {}", bestRoundDiff);
            break;
        }
        // store best fixed-size splits
        oldBestSplitsOfSizeK = bestSplitsOfSizeK;
        // track best splits overall
        bestSplits.addAll(bestSplitsOfSizeK);
        oldBestSplitsOfSizeK = trimQueue(oldBestSplitsOfSizeK);
        bestSplits = trimQueue(bestSplits);
    }
    QueueElement bestSplit = bestSplits.poll();
    // poll() returns null when no candidate was ever generated (e.g. availIds was empty);
    // guard against NPE and fall back to the empty split
    if (bestSplit == null)
        return new Pair<>(split, splitCount);
    return new Pair<>(bestSplit.docIdSet, bestSplit.labelCounter);
}
Aggregations