Search in sources :

Example 1 with GenericCopier

Use of org.apache.drill.exec.physical.impl.svremover.GenericCopier in the Apache Drill project.

The method go of the class OperatorTestBuilder.

/**
 * Runs the operator under test and compares the batches it produces with the expected results.
 *
 * <p>The method validates the builder, creates the operator through the operator creator
 * registry, and pulls batches from it until it reports {@code NONE}. An {@code OK_NEW_SCHEMA}
 * batch with no rows (or whose container cannot report a row count yet) is skipped and does not
 * count as a batch. The collected batches are then checked in one of three ways:
 * <ul>
 *   <li>no expected results: only the total row count is compared with the expected total;</li>
 *   <li>{@code combineOutputBatches}: the actual batches are copied into one batch that is
 *       compared with the first expected batch;</li>
 *   <li>otherwise: actual and expected batches are compared pairwise by index.</li>
 * </ul>
 *
 * <p>The operator, the actual and expected row sets, and the upstream batches are released
 * before the method returns, whether or not it succeeds.
 *
 * @throws UnsupportedOperationException if a batch carries a selection vector, if the operator
 *     returns an outcome other than {@code NONE}, {@code OK_NEW_SCHEMA} or {@code OK}, or if a
 *     column is neither fixed-width nor variable-width when combining batches
 * @throws Exception if validation, operator creation, iteration, or closing the operator fails
 */
@SuppressWarnings("unchecked")
public void go() throws Exception {
    final List<RowSet> actualResults = new ArrayList<>();
    CloseableRecordBatch testOperator = null;
    try {
        validate();
        int expectedNumBatches = expectedNumBatchesOpt.orElse(expectedResults.size());
        physicalOpUnitTestBase.mockOpContext(physicalOperator, initReservation, maxAllocation);
        final BatchCreator<PhysicalOperator> opCreator = (BatchCreator<PhysicalOperator>) physicalOpUnitTestBase.opCreatorReg.getOperatorCreator(physicalOperator.getClass());
        testOperator = opCreator.getBatch(physicalOpUnitTestBase.fragContext, physicalOperator, (List) upstreamBatches);
        batchIterator: for (int batchIndex = 0; ; batchIndex++) {
            final RecordBatch.IterOutcome outcome = testOperator.next();
            switch(outcome) {
                case NONE:
                    if (!combineOutputBatches) {
                        Assert.assertEquals(expectedNumBatches, batchIndex);
                    }
                    // We are done iterating over batches. Now we need to compare them.
                    break batchIterator;
                case OK_NEW_SCHEMA:
                    boolean skip = true;
                    try {
                        skip = testOperator.getContainer().getRecordCount() == 0;
                    } catch (IllegalStateException ignored) {
                        // We should skip this batch in this case. It means no data was included with the okay schema
                    }
                    // The skip decision is taken outside of a finally block on purpose: a break
                    // inside finally would discard any other exception thrown while reading the
                    // record count and silently treat the batch as empty.
                    if (skip) {
                        // Undo the loop increment so that the skipped batch is not counted.
                        batchIndex--;
                        break;
                    }
                    // Intentional fall through: a new-schema batch with data is handled like OK.
                case OK:
                    if (!combineOutputBatches && batchIndex >= expectedNumBatches) {
                        testOperator.getContainer().clear();
                        Assert.fail("More batches received than expected.");
                    } else {
                        final boolean hasSelectionVector = testOperator.getSchema().getSelectionVectorMode().hasSelectionVector;
                        final VectorContainer container = testOperator.getContainer();
                        if (hasSelectionVector) {
                            throw new UnsupportedOperationException("Implement DRILL-6698");
                        } else {
                            actualResults.add(DirectRowSet.fromContainer(container));
                        }
                        break;
                    }
                default:
                    throw new UnsupportedOperationException("Can't handle this yet");
            }
        }
        int actualTotalRows = actualResults.stream().mapToInt(RowSet::rowCount).reduce(Integer::sum).orElse(0);
        if (expectedResults.isEmpty()) {
            Assert.assertEquals((int) expectedTotalRowsOpt.orElse(0), actualTotalRows);
            // We are done, we don't have any expected result to compare
            return;
        }
        if (combineOutputBatches) {
            final RowSet expectedBatch = expectedResults.get(0);
            // Build an empty batch with the schema of the first actual batch; all actual batches
            // are copied into it below.
            final RowSet actualBatch = DirectRowSet.fromSchema(physicalOpUnitTestBase.operatorFixture.allocator, actualResults.get(0).container().getSchema());
            final VectorContainer actualBatchContainer = actualBatch.container();
            actualBatchContainer.setRecordCount(0);
            final int numColumns = expectedBatch.schema().size();
            List<MutableInt> totalBytesPerColumn = new ArrayList<>();
            for (int columnIndex = 0; columnIndex < numColumns; columnIndex++) {
                totalBytesPerColumn.add(new MutableInt());
            }
            // Get column sizes for each result batch
            final List<List<RecordBatchSizer.ColumnSize>> columnSizesPerBatch = actualResults.stream().map(rowSet -> {
                switch(rowSet.indirectionType()) {
                    case NONE:
                        return new RecordBatchSizer(rowSet.container()).columnsList();
                    default:
                        throw new UnsupportedOperationException("Implement DRILL-6698");
                }
            }).collect(Collectors.toList());
            // Sum the data size of every column across all actual batches.
            for (List<RecordBatchSizer.ColumnSize> columnSizes : columnSizesPerBatch) {
                for (int columnIndex = 0; columnIndex < numColumns; columnIndex++) {
                    final MutableInt totalBytes = totalBytesPerColumn.get(columnIndex);
                    final RecordBatchSizer.ColumnSize columnSize = columnSizes.get(columnIndex);
                    totalBytes.add(columnSize.getTotalDataSize());
                }
            }
            // Allocate each vector of the combined batch for the total row count (and, for
            // variable-width vectors, the summed data size).
            for (int columnIndex = 0; columnIndex < numColumns; columnIndex++) {
                final ValueVector valueVector = actualBatchContainer.getValueVector(columnIndex).getValueVector();
                if (valueVector instanceof FixedWidthVector) {
                    ((FixedWidthVector) valueVector).allocateNew(actualTotalRows);
                } else if (valueVector instanceof VariableWidthVector) {
                    final MutableInt totalBytes = totalBytesPerColumn.get(columnIndex);
                    ((VariableWidthVector) valueVector).allocateNew(totalBytes.getValue(), actualTotalRows);
                } else {
                    throw new UnsupportedOperationException();
                }
            }
            try {
                int currentIndex = 0;
                for (RowSet actualRowSet : actualResults) {
                    final Copier copier;
                    final VectorContainer rowSetContainer = actualRowSet.container();
                    rowSetContainer.setRecordCount(actualRowSet.rowCount());
                    switch(actualRowSet.indirectionType()) {
                        case NONE:
                            copier = new GenericCopier();
                            break;
                        default:
                            throw new UnsupportedOperationException("Implement DRILL-6698");
                    }
                    copier.setup(rowSetContainer, actualBatchContainer);
                    // NOTE(review): currentIndex accumulates across row sets; if the first
                    // argument of appendRecords is an index into the source container, it
                    // should presumably be 0 for every row set - confirm against Copier.
                    copier.appendRecords(currentIndex, actualRowSet.rowCount());
                    currentIndex += actualRowSet.rowCount();
                    // NOTE(review): verify runs after every append, so with more than one
                    // actual batch a partially filled batch is compared with the full expected
                    // batch - confirm whether this call belongs after the loop.
                    verify(expectedBatch, actualBatch);
                }
            } finally {
                actualBatch.clear();
            }
        } else {
            // Compare expected and actual results
            for (int batchIndex = 0; batchIndex < expectedNumBatches; batchIndex++) {
                final RowSet expectedBatch = expectedResults.get(batchIndex);
                final RowSet actualBatch = actualResults.get(batchIndex);
                verify(expectedBatch, actualBatch);
            }
        }
    } finally {
        try {
            if (testOperator != null) {
                testOperator.close();
            }
        } finally {
            // Release the remaining resources even when closing the operator throws.
            actualResults.forEach(rowSet -> rowSet.clear());
            if (expectedResults != null) {
                expectedResults.forEach(rowSet -> rowSet.clear());
            }
            upstreamBatches.forEach(rowSetBatch -> {
                try {
                    rowSetBatch.close();
                } catch (Exception e) {
                    logger.error("Error while closing RowSetBatch", e);
                }
            });
        }
    }
}
Also used : BatchCreator(org.apache.drill.exec.physical.impl.BatchCreator) AbstractBase(org.apache.drill.exec.physical.base.AbstractBase) MockRecordBatch(org.apache.drill.exec.physical.impl.MockRecordBatch) ValueVector(org.apache.drill.exec.vector.ValueVector) MutableInt(org.apache.commons.lang3.mutable.MutableInt) FixedWidthVector(org.apache.drill.exec.vector.FixedWidthVector) RecordBatch(org.apache.drill.exec.record.RecordBatch) Copier(org.apache.drill.exec.physical.impl.svremover.Copier) VectorContainer(org.apache.drill.exec.record.VectorContainer) Collectors(java.util.stream.Collectors) GenericCopier(org.apache.drill.exec.physical.impl.svremover.GenericCopier) ArrayList(java.util.ArrayList) PhysicalOperator(org.apache.drill.exec.physical.base.PhysicalOperator) CloseableRecordBatch(org.apache.drill.exec.record.CloseableRecordBatch) VariableWidthVector(org.apache.drill.exec.vector.VariableWidthVector) DirectRowSet(org.apache.drill.exec.physical.rowSet.DirectRowSet) RowSetComparison(org.apache.drill.test.rowSet.RowSetComparison) List(java.util.List) Preconditions(org.apache.drill.shaded.guava.com.google.common.base.Preconditions) RecordBatchSizer(org.apache.drill.exec.record.RecordBatchSizer) Optional(java.util.Optional) Assert(org.junit.Assert) RowSet(org.apache.drill.exec.physical.rowSet.RowSet) BatchCreator(org.apache.drill.exec.physical.impl.BatchCreator) GenericCopier(org.apache.drill.exec.physical.impl.svremover.GenericCopier) DirectRowSet(org.apache.drill.exec.physical.rowSet.DirectRowSet) RowSet(org.apache.drill.exec.physical.rowSet.RowSet) ArrayList(java.util.ArrayList) PhysicalOperator(org.apache.drill.exec.physical.base.PhysicalOperator) Copier(org.apache.drill.exec.physical.impl.svremover.Copier) GenericCopier(org.apache.drill.exec.physical.impl.svremover.GenericCopier) CloseableRecordBatch(org.apache.drill.exec.record.CloseableRecordBatch) ArrayList(java.util.ArrayList) List(java.util.List) 
FixedWidthVector(org.apache.drill.exec.vector.FixedWidthVector) VariableWidthVector(org.apache.drill.exec.vector.VariableWidthVector) VectorContainer(org.apache.drill.exec.record.VectorContainer) ValueVector(org.apache.drill.exec.vector.ValueVector) RecordBatchSizer(org.apache.drill.exec.record.RecordBatchSizer) MutableInt(org.apache.commons.lang3.mutable.MutableInt)

Aggregations

ArrayList (java.util.ArrayList)1 List (java.util.List)1 Optional (java.util.Optional)1 Collectors (java.util.stream.Collectors)1 MutableInt (org.apache.commons.lang3.mutable.MutableInt)1 AbstractBase (org.apache.drill.exec.physical.base.AbstractBase)1 PhysicalOperator (org.apache.drill.exec.physical.base.PhysicalOperator)1 BatchCreator (org.apache.drill.exec.physical.impl.BatchCreator)1 MockRecordBatch (org.apache.drill.exec.physical.impl.MockRecordBatch)1 Copier (org.apache.drill.exec.physical.impl.svremover.Copier)1 GenericCopier (org.apache.drill.exec.physical.impl.svremover.GenericCopier)1 DirectRowSet (org.apache.drill.exec.physical.rowSet.DirectRowSet)1 RowSet (org.apache.drill.exec.physical.rowSet.RowSet)1 CloseableRecordBatch (org.apache.drill.exec.record.CloseableRecordBatch)1 RecordBatch (org.apache.drill.exec.record.RecordBatch)1 RecordBatchSizer (org.apache.drill.exec.record.RecordBatchSizer)1 VectorContainer (org.apache.drill.exec.record.VectorContainer)1 FixedWidthVector (org.apache.drill.exec.vector.FixedWidthVector)1 ValueVector (org.apache.drill.exec.vector.ValueVector)1 VariableWidthVector (org.apache.drill.exec.vector.VariableWidthVector)1