Use of org.apache.beam.sdk.transforms.SerializableFunction in project beam by apache.
From the class SparkCoGroupByKeyStreamingTest, method testInStreamingMode.
@Category(StreamingTest.class)
@Test
public void testInStreamingMode() throws Exception {
Instant instant = new Instant(0);
CreateStream<KV<Integer, Integer>> source1 =
    CreateStream.of(KvCoder.of(VarIntCoder.of(), VarIntCoder.of()), batchDuration())
        .emptyBatch()
        .advanceWatermarkForNextBatch(instant)
        .nextBatch(
            TimestampedValue.of(KV.of(1, 1), instant),
            TimestampedValue.of(KV.of(1, 2), instant),
            TimestampedValue.of(KV.of(1, 3), instant))
        .advanceWatermarkForNextBatch(instant.plus(Duration.standardSeconds(1L)))
        .nextBatch(
            TimestampedValue.of(KV.of(2, 4), instant.plus(Duration.standardSeconds(1L))),
            TimestampedValue.of(KV.of(2, 5), instant.plus(Duration.standardSeconds(1L))),
            TimestampedValue.of(KV.of(2, 6), instant.plus(Duration.standardSeconds(1L))))
        .advanceNextBatchWatermarkToInfinity();
CreateStream<KV<Integer, Integer>> source2 =
    CreateStream.of(KvCoder.of(VarIntCoder.of(), VarIntCoder.of()), batchDuration())
        .emptyBatch()
        .advanceWatermarkForNextBatch(instant)
        .nextBatch(
            TimestampedValue.of(KV.of(1, 11), instant),
            TimestampedValue.of(KV.of(1, 12), instant),
            TimestampedValue.of(KV.of(1, 13), instant))
        .advanceWatermarkForNextBatch(instant.plus(Duration.standardSeconds(1L)))
        .nextBatch(
            TimestampedValue.of(KV.of(2, 14), instant.plus(Duration.standardSeconds(1L))),
            TimestampedValue.of(KV.of(2, 15), instant.plus(Duration.standardSeconds(1L))),
            TimestampedValue.of(KV.of(2, 16), instant.plus(Duration.standardSeconds(1L))))
        .advanceNextBatchWatermarkToInfinity();
PCollection<KV<Integer, Integer>> input1 =
    pipeline
        .apply("create source1", source1)
        .apply(
            "window input1",
            Window.<KV<Integer, Integer>>into(FixedWindows.of(Duration.standardSeconds(3L)))
                .withAllowedLateness(Duration.ZERO));
PCollection<KV<Integer, Integer>> input2 =
    pipeline
        .apply("create source2", source2)
        .apply(
            "window input2",
            Window.<KV<Integer, Integer>>into(FixedWindows.of(Duration.standardSeconds(3L)))
                .withAllowedLateness(Duration.ZERO));
PCollection<KV<Integer, CoGbkResult>> output =
    KeyedPCollectionTuple.of(INPUT1_TAG, input1)
        .and(INPUT2_TAG, input2)
        .apply(CoGroupByKey.create());
PAssert.that("Wrong output of the join using CoGroupByKey in streaming mode", output)
    .satisfies(
        (SerializableFunction<Iterable<KV<Integer, CoGbkResult>>, Void>) input -> {
assertEquals("Wrong size of the output PCollection", 2, Iterables.size(input));
for (KV<Integer, CoGbkResult> element : input) {
if (element.getKey() == 1) {
Iterable<Integer> input1Elements = element.getValue().getAll(INPUT1_TAG);
assertEquals("Wrong number of values for output elements for tag input1 and key 1", 3, Iterables.size(input1Elements));
assertThat("Elements of PCollection input1 for key \"1\" are not present in the output PCollection", input1Elements, containsInAnyOrder(1, 2, 3));
Iterable<Integer> input2Elements = element.getValue().getAll(INPUT2_TAG);
assertEquals("Wrong number of values for output elements for tag input2 and key 1", 3, Iterables.size(input2Elements));
assertThat("Elements of PCollection input2 for key \"1\" are not present in the output PCollection", input2Elements, containsInAnyOrder(11, 12, 13));
} else if (element.getKey() == 2) {
Iterable<Integer> input1Elements = element.getValue().getAll(INPUT1_TAG);
assertEquals("Wrong number of values for output elements for tag input1 and key 2", 3, Iterables.size(input1Elements));
assertThat("Elements of PCollection input1 for key \"2\" are not present in the output PCollection", input1Elements, containsInAnyOrder(4, 5, 6));
Iterable<Integer> input2Elements = element.getValue().getAll(INPUT2_TAG);
assertEquals("Wrong number of values for output elements for tag input2 and key 2", 3, Iterables.size(input2Elements));
assertThat("Elements of PCollection input2 for key \"2\" are not present in the output PCollection", input2Elements, containsInAnyOrder(14, 15, 16));
} else {
fail("Unknown key in the output PCollection");
}
}
return null;
});
pipeline.run();
}
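The assertion above relies on PAssert.satisfies, which takes a SerializableFunction over the entire contents of a PCollection and signals success by returning null. A minimal sketch of that shape, against a hypothetical PCollection<KV<Integer, Integer>> named counts, could look like this:
// Sketch only: "counts" is a hypothetical PCollection<KV<Integer, Integer>>.
// The function is shipped to the workers, hence SerializableFunction, and its
// declared return type is Void, so it returns null when the check passes.
PAssert.that(counts)
    .satisfies(
        (SerializableFunction<Iterable<KV<Integer, Integer>>, Void>)
            elements -> {
              assertEquals("Unexpected number of elements", 2, Iterables.size(elements));
              return null;
            });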
Use of org.apache.beam.sdk.transforms.SerializableFunction in project beam by apache.
From the class BigQueryIOWriteTest, method testWriteAvroWithCustomWriter.
@Test
public void testWriteAvroWithCustomWriter() throws Exception {
if (useStorageApi || useStreaming) {
return;
}
SerializableFunction<AvroWriteRequest<InputRecord>, GenericRecord> formatFunction = r -> {
GenericRecord rec = new GenericData.Record(r.getSchema());
InputRecord i = r.getElement();
rec.put("strVal", i.strVal());
rec.put("longVal", i.longVal());
rec.put("doubleVal", i.doubleVal());
rec.put("instantVal", i.instantVal().getMillis() * 1000);
return rec;
};
SerializableFunction<org.apache.avro.Schema, DatumWriter<GenericRecord>> customWriterFactory = s -> new GenericDatumWriter<GenericRecord>() {
@Override
protected void writeString(org.apache.avro.Schema schema, Object datum, Encoder out) throws IOException {
super.writeString(schema, datum.toString() + "_custom", out);
}
};
p.apply(
        Create.of(
                InputRecord.create("test", 1, 1.0, Instant.parse("2019-01-01T00:00:00Z")),
                InputRecord.create("test2", 2, 2.0, Instant.parse("2019-02-01T00:00:00Z")))
            .withCoder(INPUT_RECORD_CODER))
    .apply(
        BigQueryIO.<InputRecord>write()
            .to("dataset-id.table-id")
            .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED)
            .withSchema(
                new TableSchema()
                    .setFields(
                        ImmutableList.of(
                            new TableFieldSchema().setName("strVal").setType("STRING"),
                            new TableFieldSchema().setName("longVal").setType("INTEGER"),
                            new TableFieldSchema().setName("doubleVal").setType("FLOAT"),
                            new TableFieldSchema().setName("instantVal").setType("TIMESTAMP"))))
            .withTestServices(fakeBqServices)
            .withAvroWriter(formatFunction, customWriterFactory)
            .withoutValidation());
p.run();
assertThat(
    fakeDatasetService.getAllRows("project-id", "dataset-id", "table-id"),
    containsInAnyOrder(
        new TableRow()
            .set("strVal", "test_custom")
            .set("longVal", "1")
            .set("doubleVal", 1.0D)
            .set("instantVal", "2019-01-01 00:00:00 UTC"),
        new TableRow()
            .set("strVal", "test2_custom")
            .set("longVal", "2")
            .set("doubleVal", 2.0D)
            .set("instantVal", "2019-02-01 00:00:00 UTC")));
}
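For reference, the two SerializableFunctions handed to withAvroWriter play distinct roles: the first maps each pipeline element to a GenericRecord written into the Avro load files, and the second builds the DatumWriter that serializes those records (overridden above to append "_custom" to every string). A stripped-down sketch of that pairing, with a hypothetical element type MyRow, might look like:
// Sketch only: MyRow and its accessor strVal() are made-up names for illustration.
SerializableFunction<AvroWriteRequest<MyRow>, GenericRecord> toRecord =
    request -> {
      GenericRecord record = new GenericData.Record(request.getSchema());
      record.put("strVal", request.getElement().strVal());
      return record;
    };
// A plain GenericDatumWriter bound to the supplied schema is the simplest factory.
SerializableFunction<org.apache.avro.Schema, DatumWriter<GenericRecord>> writerFactory =
    schema -> new GenericDatumWriter<GenericRecord>(schema);
// Both would then be wired in via BigQueryIO.<MyRow>write().withAvroWriter(toRecord, writerFactory).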
Use of org.apache.beam.sdk.transforms.SerializableFunction in project beam by apache.
From the class ConvertHelpers, method getConvertPrimitive.
/**
* Returns a function to convert a Row into a primitive type. This only works when the row schema
* contains a single field, and that field is convertible to the primitive type.
*/
@SuppressWarnings("unchecked")
public static <OutputT> SerializableFunction<?, OutputT> getConvertPrimitive(
    FieldType fieldType,
    TypeDescriptor<?> outputTypeDescriptor,
    TypeConversionsFactory typeConversionsFactory) {
FieldType expectedFieldType = StaticSchemaInference.fieldFromType(outputTypeDescriptor, JavaFieldTypeSupplier.INSTANCE);
if (!expectedFieldType.equals(fieldType)) {
throw new IllegalArgumentException("Element argument type " + outputTypeDescriptor + " does not work with expected schema field type " + fieldType);
}
Type expectedInputType = typeConversionsFactory.createTypeConversion(false).convert(outputTypeDescriptor);
TypeDescriptor<?> outputType = outputTypeDescriptor;
if (outputType.getRawType().isPrimitive()) {
// A SerializableFunction can only return an Object type, so if the DoFn parameter is a
// primitive type, then box it for the return. The return type will be unboxed before being
// forwarded to the DoFn parameter.
outputType = TypeDescriptor.of(Primitives.wrap(outputType.getRawType()));
}
TypeDescription.Generic genericType =
    TypeDescription.Generic.Builder.parameterizedType(
            SerializableFunction.class, expectedInputType, outputType.getType())
        .build();
DynamicType.Builder<SerializableFunction> builder = (DynamicType.Builder<SerializableFunction>) new ByteBuddy().subclass(genericType);
try {
return builder
    .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES))
    .method(ElementMatchers.named("apply"))
    .intercept(new ConvertPrimitiveInstruction(outputType, typeConversionsFactory))
    .make()
    .load(ReflectHelpers.findClassLoader(), ClassLoadingStrategy.Default.INJECTION)
    .getLoaded()
    .getDeclaredConstructor()
    .newInstance();
} catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) {
throw new RuntimeException(e);
}
}
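As the Javadoc notes, the generated function only bridges a single-field row schema to a compatible primitive type. A hedged usage sketch, assuming a suitable TypeConversionsFactory instance (here just called factory) is already available, could look like:
// Sketch only: "factory" stands in for whatever TypeConversionsFactory implementation
// the caller already has; obtaining one is outside the scope of this snippet.
SerializableFunction<?, Long> toLong =
    ConvertHelpers.<Long>getConvertPrimitive(
        FieldType.INT64,               // schema type of the row's single field
        TypeDescriptor.of(Long.class), // desired output type
        factory);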
Use of org.apache.beam.sdk.transforms.SerializableFunction in project beam by apache.
From the class BigQueryHllSketchCompatibilityIT, method writeSketchToBigQuery.
private void writeSketchToBigQuery(List<String> testData, String expectedChecksum) {
String tableSpec = String.format("%s.%s", DATASET_ID, SKETCH_TABLE_ID);
String query = String.format("SELECT HLL_COUNT.EXTRACT(%s) FROM %s", SKETCH_FIELD_NAME, tableSpec);
TableSchema tableSchema =
    new TableSchema()
        .setFields(
            Collections.singletonList(
                new TableFieldSchema().setName(SKETCH_FIELD_NAME).setType(SKETCH_FIELD_TYPE)));
TestPipelineOptions options = TestPipeline.testingPipelineOptions().as(TestPipelineOptions.class);
Pipeline p = Pipeline.create(options);
// Nullness checking is suppressed here until we have a stub class for BigQuery TableRow.
@SuppressWarnings("nullness")
SerializableFunction<byte[], TableRow> formatFn =
    sketch -> new TableRow().set(SKETCH_FIELD_NAME, sketch.length == 0 ? null : sketch);
p.apply(Create.of(testData).withType(TypeDescriptor.of(String.class)))
    .apply(HllCount.Init.forStrings().globally())
    .apply(
        BigQueryIO.<byte[]>write()
            .to(tableSpec)
            .withSchema(tableSchema)
            .withFormatFunction(formatFn)
            .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_TRUNCATE));
p.run().waitUntilFinish();
// BigqueryMatcher will send a query to retrieve the estimated count and verifies its
// correctness using checksum.
assertThat(createQueryUsingStandardSql(APP_NAME, PROJECT_ID, query), queryResultHasChecksum(expectedChecksum));
}
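The formatFn above follows the general withFormatFunction contract: a SerializableFunction from the pipeline's element type to a TableRow. For comparison, a minimal sketch for a hypothetical word-count table could be:
// Sketch only: the "word"/"count" column names are made up for illustration.
SerializableFunction<KV<String, Long>, TableRow> countFormatFn =
    kv -> new TableRow().set("word", kv.getKey()).set("count", kv.getValue());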
Use of org.apache.beam.sdk.transforms.SerializableFunction in project beam by apache.
From the class Neo4jIOIT, method testLargeWriteUnwind.
@Test
public void testLargeWriteUnwind() throws Exception {
final int startId = 5000;
final int endId = 6000;
// Create 1000 IDs
List<Integer> idList = new ArrayList<>();
for (int id = startId; id < endId; id++) {
idList.add(id);
}
PCollection<Integer> idCollection = largeWriteUnwindPipeline.apply(Create.of(idList));
// Every row is represented by a Map<String, Object> in the parameters map.
// We accumulate the rows and 'unwind' those to Neo4j for performance reasons.
//
SerializableFunction<Integer, Map<String, Object>> parametersFunction =
    id -> ImmutableMap.of("id", id, "name", "Casters", "firstName", "Matt");
// 1000 rows with a batch size of 123 should trigger most scenarios we can think of
// We've put a unique constraint on Something.id
//
Neo4jIO.WriteUnwind<Integer> read =
    Neo4jIO.<Integer>writeUnwind()
        .withDriverConfiguration(
            Neo4jTestUtil.getDriverConfiguration(containerHostname, containerPort))
        .withSessionConfig(SessionConfig.forDatabase(Neo4jTestUtil.NEO4J_DATABASE))
        .withBatchSize(123)
        .withUnwindMapName("rows")
        .withCypher("UNWIND $rows AS row CREATE(n:Something { id : row.id })")
        .withParametersFunction(parametersFunction)
        .withCypherLogging();
idCollection.apply(read);
// Now run this pipeline
//
PipelineResult pipelineResult = largeWriteUnwindPipeline.run();
Assert.assertEquals(PipelineResult.State.DONE, pipelineResult.getState());
//
try (Driver driver = Neo4jTestUtil.getDriver(containerHostname, containerPort)) {
try (Session session = Neo4jTestUtil.getSession(driver, true)) {
List<Integer> values = session.readTransaction(tx -> {
List<Integer> v = null;
int nrRows = 0;
Result result = tx.run("MATCH(n:Something) RETURN count(n), min(n.id), max(n.id)");
while (result.hasNext()) {
Record record = result.next();
v = Arrays.asList(record.get(0).asInt(), record.get(1).asInt(), record.get(2).asInt(), ++nrRows);
}
return v;
});
Assert.assertNotNull(values);
assertThat(values, contains(endId - startId, startId, endId - 1, 1));
}
}
}
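Given the parametersFunction above, every element contributes one map to the accumulated $rows parameter, and the connector flushes a batch once 123 maps have been collected. The batching can be pictured with the following sketch, which is illustrative only and not the actual Neo4jIO internals:
// Sketch only: manually building the parameter map the way a small batch would look.
List<Map<String, Object>> rows = new ArrayList<>();
for (int id = startId; id < startId + 2; id++) {
  rows.add(parametersFunction.apply(id)); // {id=<id>, name=Casters, firstName=Matt}
}
Map<String, Object> parameters = ImmutableMap.of("rows", rows);
// For each batch the connector then runs:
// UNWIND $rows AS row CREATE(n:Something { id : row.id })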