Use of com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation in project aws-athena-query-federation by awslabs.
The class S3BlockSpillerTest, method spillTest.
@Test
public void spillTest() throws IOException {
    logger.info("spillTest: enter");
    logger.info("spillTest: starting write test");
    final ByteHolder byteHolder = new ByteHolder();
    // Capture whatever the spiller uploads so the read half of the test can serve it back.
    when(mockS3.putObject(eq(bucket), anyString(), anyObject(), anyObject())).thenAnswer(new Answer<Object>() {

        @Override
        public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
            InputStream inputStream = (InputStream) invocationOnMock.getArguments()[2];
            byteHolder.setBytes(ByteStreams.toByteArray(inputStream));
            return mock(PutObjectResult.class);
        }
    });
    SpillLocation blockLocation = blockWriter.write(expected);
    if (blockLocation instanceof S3SpillLocation) {
        assertEquals(bucket, ((S3SpillLocation) blockLocation).getBucket());
        assertEquals(prefix + "/" + requestId + "/" + splitId + ".0", ((S3SpillLocation) blockLocation).getKey());
    }
    SpillLocation blockLocation2 = blockWriter.write(expected);
    if (blockLocation2 instanceof S3SpillLocation) {
        assertEquals(bucket, ((S3SpillLocation) blockLocation2).getBucket());
        assertEquals(prefix + "/" + requestId + "/" + splitId + ".1", ((S3SpillLocation) blockLocation2).getKey());
    }
    // Each write should produce its own numbered object under the spill prefix.
    verify(mockS3, times(1)).putObject(eq(bucket), eq(prefix + "/" + requestId + "/" + splitId + ".0"), anyObject(), anyObject());
    verify(mockS3, times(1)).putObject(eq(bucket), eq(prefix + "/" + requestId + "/" + splitId + ".1"), anyObject(), anyObject());
    verifyNoMoreInteractions(mockS3);
    reset(mockS3);
    logger.info("spillTest: starting read test");
    // Serve the captured bytes back when the second spilled object is requested.
    when(mockS3.getObject(eq(bucket), eq(prefix + "/" + requestId + "/" + splitId + ".1"))).thenAnswer(new Answer<Object>() {

        @Override
        public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
            S3Object mockObject = mock(S3Object.class);
            when(mockObject.getObjectContent()).thenReturn(new S3ObjectInputStream(new ByteArrayInputStream(byteHolder.getBytes()), null));
            return mockObject;
        }
    });
    Block block = blockWriter.read((S3SpillLocation) blockLocation2, spillConfig.getEncryptionKey(), expected.getSchema());
    assertEquals(expected, block);
    verify(mockS3, times(1)).getObject(eq(bucket), eq(prefix + "/" + requestId + "/" + splitId + ".1"));
    verifyNoMoreInteractions(mockS3);
    logger.info("spillTest: exit");
}
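The ByteHolder used above is a small test helper that is not part of this excerpt. A minimal sketch of its presumable shape: a mutable container that lets the anonymous Answer smuggle the captured upload bytes out to the enclosing test, since the Answer can only close over effectively-final locals.

// Assumed shape of the ByteHolder helper referenced in spillTest; not shown in
// the excerpt above, so treat this as an illustrative reconstruction.
private static class ByteHolder {

    private byte[] bytes;

    public void setBytes(byte[] bytes) {
        this.bytes = bytes;
    }

    public byte[] getBytes() {
        return bytes;
    }
}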
Use of com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation in project aws-athena-query-federation by awslabs.
The class GetSplitsResponseSerDeTest, method beforeTest.
@Before
public void beforeTest() throws IOException {
    String yearCol = "year";
    String monthCol = "month";
    String dayCol = "day";
    SpillLocation spillLocation1 = S3SpillLocation.newBuilder()
            .withBucket("athena-virtuoso-test")
            .withPrefix("lambda-spill")
            .withQueryId("test-query-id")
            .withSplitId("test-split-id-1")
            .withIsDirectory(true)
            .build();
    EncryptionKey encryptionKey1 = new EncryptionKey("test-key-1".getBytes(), "test-nonce-1".getBytes());
    Split split1 = Split.newBuilder(spillLocation1, encryptionKey1).add(yearCol, "2017").add(monthCol, "11").add(dayCol, "1").build();
    SpillLocation spillLocation2 = S3SpillLocation.newBuilder()
            .withBucket("athena-virtuoso-test")
            .withPrefix("lambda-spill")
            .withQueryId("test-query-id")
            .withSplitId("test-split-id-2")
            .withIsDirectory(true)
            .build();
    EncryptionKey encryptionKey2 = new EncryptionKey("test-key-2".getBytes(), "test-nonce-2".getBytes());
    Split split2 = Split.newBuilder(spillLocation2, encryptionKey2).add(yearCol, "2017").add(monthCol, "11").add(dayCol, "2").build();
    expected = new GetSplitsResponse("test-catalog", ImmutableSet.of(split1, split2), "test-continuation-token");
    String expectedSerDeFile = utils.getResourceOrFail("serde/v2", "GetSplitsResponse.json");
    expectedSerDeText = utils.readAllAsString(expectedSerDeFile).trim();
}
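The fixture above only builds the expected response object and loads the expected JSON text; the assertions live in the test methods, which this excerpt omits. A minimal sketch of the usual round trip, assuming a Jackson ObjectMapper (named mapper here) already configured with the SDK's v2 SerDe support; the field name and exact deserialization target are assumptions, not taken from the project.

// Hypothetical round-trip check against the fixtures built in beforeTest();
// `mapper` stands in for however the real test obtains its configured ObjectMapper.
@Test
public void serializeRoundTrip() throws IOException {
    String actualSerDeText = mapper.writeValueAsString(expected);
    assertEquals(expectedSerDeText, actualSerDeText);
    GetSplitsResponse actual = mapper.readValue(actualSerDeText, GetSplitsResponse.class);
    assertEquals(expected, actual);
}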
Use of com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation in project aws-athena-query-federation by awslabs.
The class ReadRecordsRequestSerDeTest, method beforeTest.
@Before
public void beforeTest() throws IOException {
    String yearCol = "year";
    String monthCol = "month";
    String dayCol = "day";
    Schema schema = SchemaBuilder.newBuilder()
            .addField(yearCol, new ArrowType.Int(32, true))
            .addField(monthCol, new ArrowType.Int(32, true))
            .addField(dayCol, new ArrowType.Int(32, true))
            .addField("col2", new ArrowType.Utf8())
            .addField("col3", Types.MinorType.FLOAT8.getType())
            .addField("col4", Types.MinorType.FLOAT8.getType())
            .addField("col5", Types.MinorType.FLOAT8.getType())
            .build();
    Map<String, ValueSet> constraintsMap = new HashMap<>();
    constraintsMap.put("col3", SortedRangeSet.copyOf(Types.MinorType.FLOAT8.getType(), ImmutableList.of(Range.greaterThan(allocator, Types.MinorType.FLOAT8.getType(), -10000D)), false));
    constraintsMap.put("col4", EquatableValueSet.newBuilder(allocator, Types.MinorType.FLOAT8.getType(), false, true).add(1.1D).build());
    constraintsMap.put("col5", new AllOrNoneValueSet(Types.MinorType.FLOAT8.getType(), false, true));
    Constraints constraints = new Constraints(constraintsMap);
    Block partitions = allocator.createBlock(schema);
    int numPartitions = 10;
    for (int i = 0; i < numPartitions; i++) {
        BlockUtils.setValue(partitions.getFieldVector(yearCol), i, 2016 + i);
        BlockUtils.setValue(partitions.getFieldVector(monthCol), i, (i % 12) + 1);
        BlockUtils.setValue(partitions.getFieldVector(dayCol), i, (i % 28) + 1);
    }
    partitions.setRowCount(numPartitions);
    SpillLocation spillLocation = S3SpillLocation.newBuilder()
            .withBucket("athena-virtuoso-test")
            .withPrefix("lambda-spill")
            .withQueryId("test-query-id")
            .withSplitId("test-split-id")
            .withIsDirectory(true)
            .build();
    EncryptionKey encryptionKey = new EncryptionKey("test-key".getBytes(), "test-nonce".getBytes());
    Split split = Split.newBuilder(spillLocation, encryptionKey).add(yearCol, "2017").add(monthCol, "11").add(dayCol, "1").build();
    expected = new ReadRecordsRequest(federatedIdentity, "test-query-id", "test-catalog", new TableName("test-schema", "test-table"), schema, split, constraints, 100_000_000_000L, 100_000_000_000L);
    String expectedSerDeFile = utils.getResourceOrFail("serde/v2", "ReadRecordsRequest.json");
    expectedSerDeText = utils.readAllAsString(expectedSerDeFile).trim();
}
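For readers unfamiliar with the SDK's ValueSet types, the three constraints above can be read roughly as follows; this interpretation of the boolean flags (whiteList, all, nullAllowed) follows the constructor and builder parameter order and is an annotation added here, not part of the original test.

// col3: SortedRangeSet, nullAllowed=false          -> col3 > -10000.0, NULL excluded
// col4: EquatableValueSet, whiteList=false,
//       nullAllowed=true                           -> col4 <> 1.1, or col4 IS NULL (a deny-list)
// col5: AllOrNoneValueSet, all=false,
//       nullAllowed=true                           -> no non-null value matches; only NULL does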
Use of com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation in project aws-athena-query-federation by awslabs.
The class RemoteReadRecordsResponseSerDeTest, method beforeTest.
@Before
public void beforeTest() throws IOException {
    String yearCol = "year";
    String monthCol = "month";
    String dayCol = "day";
    Schema schema = SchemaBuilder.newBuilder()
            .addField(yearCol, new ArrowType.Int(32, true))
            .addField(monthCol, new ArrowType.Int(32, true))
            .addField(dayCol, new ArrowType.Int(32, true))
            .build();
    EncryptionKey encryptionKey = new EncryptionKey("test-key".getBytes(), "test-nonce".getBytes());
    SpillLocation spillLocation1 = S3SpillLocation.newBuilder().withBucket("athena-virtuoso-test").withPrefix("lambda-spill").withQueryId("test-query-id").withSplitId("test-split-id-1").withIsDirectory(true).build();
    SpillLocation spillLocation2 = S3SpillLocation.newBuilder().withBucket("athena-virtuoso-test").withPrefix("lambda-spill").withQueryId("test-query-id").withSplitId("test-split-id-2").withIsDirectory(true).build();
    expected = new RemoteReadRecordsResponse("test-catalog", schema, ImmutableList.of(spillLocation1, spillLocation2), encryptionKey);
    String expectedSerDeFile = utils.getResourceOrFail("serde/v2", "RemoteReadRecordsResponse.json");
    expectedSerDeText = utils.readAllAsString(expectedSerDeFile).trim();
}
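A RemoteReadRecordsResponse carries no row data itself, only pointers to spilled blocks plus the key needed to decrypt them. A minimal sketch of how a consumer might rehydrate those blocks, assuming an S3BlockSpillReader instance named reader is available (the reader variable is an assumption of this sketch, not part of the fixture above):

// Hypothetical consumption of the response built above; `reader` is an assumed
// S3BlockSpillReader wired to the same bucket the locations point at.
for (SpillLocation next : expected.getRemoteBlocks()) {
    Block block = reader.read((S3SpillLocation) next, expected.getEncryptionKey(), expected.getSchema());
    // ... scan rows out of `block`, then release it ...
}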
Use of com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation in project aws-athena-query-federation by awslabs.
The class HiveMetadataHandler, method doGetSplits.
/**
 * Used to split up the reads required to scan the requested batch of partition(s).
 *
 * @param blockAllocator Tool for creating and managing Apache Arrow Blocks.
 * @param getSplitsRequest Provides details of the Hive catalog, database, table, and partition(s) being queried, as well as
 * any filter predicate.
 * @return A GetSplitsResponse which primarily contains:
 * 1. A Set of Splits which represent read operations Amazon Athena must perform by calling your read function.
 * 2. (Optional) A continuation token which allows you to paginate the generation of splits for large queries.
 */
@Override
public GetSplitsResponse doGetSplits(BlockAllocator blockAllocator, GetSplitsRequest getSplitsRequest) {
    LOGGER.info("{}: Catalog {}, table {}", getSplitsRequest.getQueryId(), getSplitsRequest.getTableName().getSchemaName(), getSplitsRequest.getTableName().getTableName());
    // Resume from the partition indicated by the continuation token, if any.
    int partitionContd = decodeContinuationToken(getSplitsRequest);
    Set<Split> splits = new HashSet<>();
    Block partitions = getSplitsRequest.getPartitions();
    for (int curPartition = partitionContd; curPartition < partitions.getRowCount(); curPartition++) {
        FieldReader locationReader = partitions.getFieldReader(HiveConstants.BLOCK_PARTITION_COLUMN_NAME);
        locationReader.setPosition(curPartition);
        SpillLocation spillLocation = makeSpillLocation(getSplitsRequest);
        // One Split per partition row, carrying the partition name as a split property.
        Split.Builder splitBuilder = Split.newBuilder(spillLocation, makeEncryptionKey()).add(HiveConstants.BLOCK_PARTITION_COLUMN_NAME, String.valueOf(locationReader.readText()));
        splits.add(splitBuilder.build());
        if (splits.size() >= HiveConstants.MAX_SPLITS_PER_REQUEST) {
            // Page out: return what we have plus a token that lets Athena resume from here.
            return new GetSplitsResponse(getSplitsRequest.getCatalogName(), splits, encodeContinuationToken(curPartition));
        }
    }
    return new GetSplitsResponse(getSplitsRequest.getCatalogName(), splits, null);
}
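The decodeContinuationToken and encodeContinuationToken helpers are not shown in this excerpt. In the awslabs connectors the token is typically just the row index of the last partition that produced a split; a sketch along those lines (the +1 resumes after the partition the previous page already emitted):

// Assumed shape of the continuation-token helpers used in doGetSplits above;
// an illustrative reconstruction, not the verbatim HiveMetadataHandler code.
private int decodeContinuationToken(GetSplitsRequest request) {
    if (request.hasContinuationToken()) {
        // Resume at the partition after the one the previous page ended on.
        return Integer.parseInt(request.getContinuationToken()) + 1;
    }
    // No token: start from the first partition row.
    return 0;
}

private String encodeContinuationToken(int partition) {
    return String.valueOf(partition);
}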