Use of org.apache.lucene.search.grouping.GroupingSearch in project lucene-solr by apache.
Class DatasetSplitter, method split.
/**
* Split a given index into 3 indexes for training, test and cross validation tasks respectively
*
* @param originalIndex an {@link org.apache.lucene.index.IndexReader} on the source index
* @param trainingIndex a {@link Directory} used to write the training index
* @param testIndex a {@link Directory} used to write the test index
* @param crossValidationIndex a {@link Directory} used to write the cross validation index
* @param analyzer {@link Analyzer} used to create the new docs
* @param termVectors {@code true} if term vectors should be kept
* @param classFieldName name of the field used as the label for classification; this must be indexed with sorted doc values
* @param fieldNames names of fields that need to be put in the new indexes or <code>null</code> if all should be used
* @throws IOException if any writing operation fails on any of the indexes
*/
public void split(IndexReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex, Analyzer analyzer, boolean termVectors, String classFieldName, String... fieldNames) throws IOException {
  // create IWs for train / test / cv IDXs
  IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(analyzer));
  IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(analyzer));
  IndexWriter trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(analyzer));

  // get the exact no. of existing classes
  int noOfClasses = 0;
  for (LeafReaderContext leave : originalIndex.leaves()) {
    long valueCount = 0;
    SortedDocValues classValues = leave.reader().getSortedDocValues(classFieldName);
    if (classValues != null) {
      valueCount = classValues.getValueCount();
    } else {
      SortedSetDocValues sortedSetDocValues = leave.reader().getSortedSetDocValues(classFieldName);
      if (sortedSetDocValues != null) {
        valueCount = sortedSetDocValues.getValueCount();
      }
    }
    if (classValues == null) {
      // approximate with no. of terms
      noOfClasses += leave.reader().terms(classFieldName).size();
    }
    noOfClasses += valueCount;
  }
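  // noOfClasses is later passed to GroupingSearch as the group limit, so each class value gets its own group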
  try {
    IndexSearcher indexSearcher = new IndexSearcher(originalIndex);
    GroupingSearch gs = new GroupingSearch(classFieldName);
    gs.setGroupSort(Sort.INDEXORDER);
    gs.setSortWithinGroup(Sort.INDEXORDER);
    gs.setAllGroups(true);
    gs.setGroupDocsLimit(originalIndex.maxDoc());
    TopGroups<Object> topGroups = gs.search(indexSearcher, new MatchAllDocsQuery(), 0, noOfClasses);
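    // each returned group holds all documents sharing one value of classFieldName,
    // so the assignment below splits every class separately (a stratified split)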
    // set the type to be indexed, stored, with term vectors
    FieldType ft = new FieldType(TextField.TYPE_STORED);
    if (termVectors) {
      ft.setStoreTermVectors(true);
      ft.setStoreTermVectorOffsets(true);
      ft.setStoreTermVectorPositions(true);
    }

    int b = 0;

    // iterate over existing documents
    for (GroupDocs<Object> group : topGroups.groups) {
      int totalHits = group.totalHits;
      double testSize = totalHits * testRatio;
      int tc = 0;
      double cvSize = totalHits * crossValidationRatio;
      int cvc = 0;
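      // tc and cvc count how many docs of the current class have already gone to the test and cross validation indexes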
      for (ScoreDoc scoreDoc : group.scoreDocs) {
        // create a new document for indexing
        Document doc = createNewDoc(originalIndex, ft, scoreDoc, fieldNames);
        // add it to one of the IDXs
        if (b % 2 == 0 && tc < testSize) {
          testWriter.addDocument(doc);
          tc++;
        } else if (cvc < cvSize) {
          cvWriter.addDocument(doc);
          cvc++;
        } else {
          trainingWriter.addDocument(doc);
        }
        b++;
      }
    }

    // commit
    testWriter.commit();
    cvWriter.commit();
    trainingWriter.commit();

    // merge
    testWriter.forceMerge(3);
    cvWriter.forceMerge(3);
    trainingWriter.forceMerge(3);
  } catch (Exception e) {
    throw new IOException(e);
  } finally {
    // close IWs
    testWriter.close();
    cvWriter.close();
    trainingWriter.close();
    originalIndex.close();
  }
}
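A minimal usage sketch follows. It assumes DatasetSplitter is constructed with the test and cross validation ratios referenced above as testRatio and crossValidationRatio; the field names, analyzer and filesystem paths are placeholders, not part of the original source.

import java.nio.file.Paths;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.classification.utils.DatasetSplitter;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;

public class SplitExample {
  public static void main(String[] args) throws Exception {
    // open the source index and the three target directories (paths are placeholders)
    IndexReader originalIndex = DirectoryReader.open(FSDirectory.open(Paths.get("/data/source-index")));
    Directory trainingIndex = FSDirectory.open(Paths.get("/data/train-index"));
    Directory testIndex = FSDirectory.open(Paths.get("/data/test-index"));
    Directory crossValidationIndex = FSDirectory.open(Paths.get("/data/cv-index"));

    // assumed constructor: roughly 20% of each class goes to the test index, 10% to the cross validation index
    DatasetSplitter splitter = new DatasetSplitter(0.2d, 0.1d);

    // "category" must be indexed with sorted doc values; only "title" and "body" are copied into the new indexes
    splitter.split(originalIndex, trainingIndex, testIndex, crossValidationIndex,
        new StandardAnalyzer(), true, "category", "title", "body");

    // split() closes originalIndex in its finally block, so only the output directories are closed here
    trainingIndex.close();
    testIndex.close();
    crossValidationIndex.close();
  }
}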