Usage of org.apache.drill.exec.physical.impl.BatchCreator in the Apache Drill project:
the go() method of class OperatorTestBuilder.
/**
 * Builds the configured physical operator, drains every batch it produces, and checks the
 * output against the expected results.
 *
 * <p>Batches returned with {@code OK_NEW_SCHEMA} that carry no rows are skipped and do not count
 * toward the expected batch count. When no expected results were supplied, only the total row
 * count is checked (against {@code expectedTotalRowsOpt}, defaulting to 0). When
 * {@code combineOutputBatches} is set, all output batches are concatenated into one row set and
 * compared once against the first expected result. Otherwise output batch {@code i} is compared
 * against expected batch {@code i}.
 *
 * <p>The operator, all captured output row sets, the expected row sets and the upstream batches
 * are released before returning, whether or not verification succeeds.
 *
 * @throws UnsupportedOperationException if an output batch uses a selection vector (DRILL-6698)
 *     or the operator returns an outcome other than NONE, OK or OK_NEW_SCHEMA
 * @throws Exception if operator setup, execution or cleanup fails
 */
@SuppressWarnings("unchecked")
public void go() throws Exception {
  final List<RowSet> actualResults = new ArrayList<>();
  CloseableRecordBatch testOperator = null;

  try {
    validate();
    final int expectedNumBatches = expectedNumBatchesOpt.orElse(expectedResults.size());
    physicalOpUnitTestBase.mockOpContext(physicalOperator, initReservation, maxAllocation);

    final BatchCreator<PhysicalOperator> opCreator = (BatchCreator<PhysicalOperator>)
        physicalOpUnitTestBase.opCreatorReg.getOperatorCreator(physicalOperator.getClass());
    testOperator = opCreator.getBatch(physicalOpUnitTestBase.fragContext, physicalOperator, (List) upstreamBatches);

    collectOutputBatches(testOperator, expectedNumBatches, actualResults);

    final int actualTotalRows = actualResults.stream().mapToInt(RowSet::rowCount).sum();

    if (expectedResults.isEmpty()) {
      // No expected row sets to compare against; only the total row count is checked.
      Assert.assertEquals((int) expectedTotalRowsOpt.orElse(0), actualTotalRows);
      return;
    }

    if (combineOutputBatches) {
      verifyCombinedResults(actualResults, actualTotalRows);
    } else {
      // collectOutputBatches has already asserted that exactly expectedNumBatches batches were captured.
      for (int batchIndex = 0; batchIndex < expectedNumBatches; batchIndex++) {
        verify(expectedResults.get(batchIndex), actualResults.get(batchIndex));
      }
    }
  } finally {
    try {
      if (testOperator != null) {
        testOperator.close();
      }
    } finally {
      // Release every row set even if closing the operator failed.
      actualResults.forEach(RowSet::clear);
      if (expectedResults != null) {
        expectedResults.forEach(RowSet::clear);
      }
      upstreamBatches.forEach(rowSetBatch -> {
        try {
          rowSetBatch.close();
        } catch (Exception e) {
          logger.error("Error while closing RowSetBatch", e);
        }
      });
    }
  }
}

/**
 * Pulls batches from {@code operator} until it returns {@code NONE}, wrapping each non-empty
 * batch as a row set and appending it to {@code actualResults}.
 *
 * <p>{@code OK_NEW_SCHEMA} batches with no rows are skipped. Unless {@code combineOutputBatches}
 * is set, this fails as soon as more than {@code expectedNumBatches} batches arrive. Once the
 * operator is exhausted, it asserts that exactly {@code expectedNumBatches} were captured.
 *
 * @throws UnsupportedOperationException for batches with a selection vector, or for any outcome
 *     other than NONE, OK or OK_NEW_SCHEMA
 */
private void collectOutputBatches(final CloseableRecordBatch operator, final int expectedNumBatches,
                                  final List<RowSet> actualResults) {
  int capturedBatches = 0;

  while (true) {
    final RecordBatch.IterOutcome outcome = operator.next();

    switch (outcome) {
      case NONE:
        if (!combineOutputBatches) {
          Assert.assertEquals(expectedNumBatches, capturedBatches);
        }
        // The operator is exhausted; comparison happens in the caller.
        return;
      case OK_NEW_SCHEMA:
        if (hasNoRecords(operator)) {
          // Schema-only batch: nothing to compare, and it does not count as an output batch.
          break;
        }
        // Fall through: a new-schema batch that carries rows is captured like an OK batch.
      case OK:
        if (!combineOutputBatches && capturedBatches >= expectedNumBatches) {
          operator.getContainer().clear();
          Assert.fail("More batches received than expected.");
        }
        if (operator.getSchema().getSelectionVectorMode().hasSelectionVector) {
          throw new UnsupportedOperationException("Implement DRILL-6698");
        }
        actualResults.add(DirectRowSet.fromContainer(operator.getContainer()));
        capturedBatches++;
        break;
      default:
        throw new UnsupportedOperationException("Can't handle this yet");
    }
  }
}

/**
 * Reports whether the operator's current output container holds no rows. An
 * {@link IllegalStateException} from the container is treated as "schema with no data".
 * Unlike a {@code break} inside {@code finally}, this lets every other exception propagate.
 */
private static boolean hasNoRecords(final CloseableRecordBatch operator) {
  try {
    return operator.getContainer().getRecordCount() == 0;
  } catch (IllegalStateException e) {
    // No data was included with the new schema; treat the batch as empty.
    return true;
  }
}

/**
 * Concatenates all captured output batches into one row set and verifies it against the first
 * expected result. The combined row set is always cleared before returning.
 *
 * @param actualResults captured output batches, in arrival order
 * @param actualTotalRows total number of rows across {@code actualResults}
 * @throws UnsupportedOperationException if a batch uses a selection vector or a column is neither
 *     fixed- nor variable-width
 */
private void verifyCombinedResults(final List<RowSet> actualResults, final int actualTotalRows) throws Exception {
  Assert.assertFalse("Expected output batches to combine, but the operator produced none.",
      actualResults.isEmpty());

  final RowSet expectedBatch = expectedResults.get(0);
  final RowSet actualBatch = DirectRowSet.fromSchema(
      physicalOpUnitTestBase.operatorFixture.allocator, actualResults.get(0).container().getSchema());

  try {
    final VectorContainer actualBatchContainer = actualBatch.container();
    actualBatchContainer.setRecordCount(0);

    final int numColumns = expectedBatch.schema().size();
    final List<MutableInt> totalBytesPerColumn = new ArrayList<>();
    for (int columnIndex = 0; columnIndex < numColumns; columnIndex++) {
      totalBytesPerColumn.add(new MutableInt());
    }

    // Sum each column's data size across all batches so variable-width vectors are sized up front.
    final List<List<RecordBatchSizer.ColumnSize>> columnSizesPerBatch = actualResults.stream().map(rowSet -> {
      switch (rowSet.indirectionType()) {
        case NONE:
          return new RecordBatchSizer(rowSet.container()).columnsList();
        default:
          throw new UnsupportedOperationException("Implement DRILL-6698");
      }
    }).collect(Collectors.toList());

    for (List<RecordBatchSizer.ColumnSize> columnSizes : columnSizesPerBatch) {
      for (int columnIndex = 0; columnIndex < numColumns; columnIndex++) {
        totalBytesPerColumn.get(columnIndex).add(columnSizes.get(columnIndex).getTotalDataSize());
      }
    }

    for (int columnIndex = 0; columnIndex < numColumns; columnIndex++) {
      final ValueVector valueVector = actualBatchContainer.getValueVector(columnIndex).getValueVector();
      if (valueVector instanceof FixedWidthVector) {
        ((FixedWidthVector) valueVector).allocateNew(actualTotalRows);
      } else if (valueVector instanceof VariableWidthVector) {
        final MutableInt totalBytes = totalBytesPerColumn.get(columnIndex);
        ((VariableWidthVector) valueVector).allocateNew(totalBytes.getValue(), actualTotalRows);
      } else {
        throw new UnsupportedOperationException();
      }
    }

    for (RowSet actualRowSet : actualResults) {
      final VectorContainer rowSetContainer = actualRowSet.container();
      rowSetContainer.setRecordCount(actualRowSet.rowCount());

      final Copier copier;
      switch (actualRowSet.indirectionType()) {
        case NONE:
          copier = new GenericCopier();
          break;
        default:
          throw new UnsupportedOperationException("Implement DRILL-6698");
      }

      copier.setup(rowSetContainer, actualBatchContainer);
      // appendRecords' index is the first row of the *source* batch to copy. The copier writes after
      // the rows already in actualBatchContainer, so each batch is copied starting at its row 0.
      copier.appendRecords(0, actualRowSet.rowCount());
    }

    // Compare only after every batch has been appended. Comparing inside the loop would check the
    // full expected result against a partially combined batch.
    verify(expectedBatch, actualBatch);
  } finally {
    actualBatch.clear();
  }
}
Aggregations