Use of org.infinispan.protostream.impl.SerializationContextImpl in project kogito-apps by kiegroup.
Class ProtoIndexParserTest, method testConfigureBuilderWithInvalidFile.
/**
 * Registering a proto source whose text is not valid .proto syntax must be rejected
 * with a {@link DescriptorParserException}; if registration succeeds, the test fails.
 */
@Test
void testConfigureBuilderWithInvalidFile() {
    SerializationContext context = new SerializationContextImpl(configureBuilder().build());
    FileDescriptorSource garbageSource = FileDescriptorSource.fromString("invalid", "invalid");
    boolean rejected;
    try {
        context.registerProtoFiles(garbageSource);
        rejected = false;
    } catch (DescriptorParserException expected) {
        // This is the outcome under test.
        rejected = true;
    }
    if (!rejected) {
        fail("Failed to process invalid proto file");
    }
}
Use of org.infinispan.protostream.impl.SerializationContextImpl in project kogito-apps by kiegroup.
Class TestUtils, method getTestFileDescriptor.
/**
 * Parses the test domain-model proto text into a fresh serialization context and returns
 * the file descriptor that the context exposes under {@code DOMAIN_MODEL_PROTO_NAME}.
 *
 * @return the descriptor looked up by {@code DOMAIN_MODEL_PROTO_NAME} after registration
 */
static FileDescriptor getTestFileDescriptor() {
    String protoText = getTestFileContent();
    SerializationContext freshContext = new SerializationContextImpl(Configuration.builder().build());
    FileDescriptorSource domainModelSource = FileDescriptorSource.fromString(DOMAIN_MODEL_PROTO_NAME, protoText);
    freshContext.registerProtoFiles(domainModelSource);
    return freshContext.getFileDescriptors().get(DOMAIN_MODEL_PROTO_NAME);
}
Use of org.infinispan.protostream.impl.SerializationContextImpl in project protostream by infinispan.
Class WrappedMessage, method readMessage.
/**
 * Decodes a single value written in the WrappedMessage envelope format.
 * <p>
 * Tags are read from {@code in} until a 0 tag is returned or until a case below ends the loop with
 * {@code break out}. Two shapes are accepted:
 * <ul>
 * <li>single-field: a container, the null marker ({@code WRAPPED_EMPTY}), or one scalar/primitive wrapper field;</li>
 * <li>two-field: a type name or a type id together with either message bytes or an enum value, or the
 * seconds/nanos halves of an {@code Instant}.</li>
 * </ul>
 * For a decoded value, the number of fields read must equal the count expected for the last tag seen.
 * For the type name/type id shape, exactly one of name/id must be present and exactly two fields must have been read.
 *
 * @param ctx   context used to resolve type ids, enum/message marshallers and nested readers; on the type name/id
 *              path it is cast to {@code SerializationContextImpl}, so any other implementation fails with a
 *              ClassCastException there
 * @param in    reader positioned at the start of the wrapped message
 * @param nulls whether the null marker ({@code WRAPPED_EMPTY}) is accepted
 * @param <T>   expected result type; the result is returned through an unchecked cast
 * @return the decoded value, or {@code null} when no value, type name, type id or message bytes were read
 *         (e.g. the null marker, or an input that ends immediately)
 * @throws IOException           if the field combination is invalid, the enum value is unknown to its marshaller,
 *                               or the reader fails
 * @throws IllegalStateException if the null marker is read while {@code nulls} is false, or an unexpected tag is read
 */
private static <T> T readMessage(ImmutableSerializationContext ctx, TagReader in, boolean nulls) throws IOException {
String typeName = null;
Integer typeId = null;
// -1 is only a placeholder; it is overwritten when a WRAPPED_ENUM field is read.
int enumValue = -1;
byte[] messageBytes = null;
Object value = null;
int fieldCount = 0;
// Updated by every recognised tag; compared against fieldCount after the loop.
int expectedFieldCount = 1;
int tag;
// Single-field forms leave the loop with "break out"; two-field forms use plain "break" to keep reading.
out: while ((tag = in.readTag()) != 0) {
fieldCount++;
switch(tag) {
// Container wrappers are decoded entirely by readContainer.
case WRAPPED_CONTAINER_SIZE << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT:
case WRAPPED_CONTAINER_TYPE_ID << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT:
case WRAPPED_CONTAINER_TYPE_NAME << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_LENGTH_DELIMITED:
case WRAPPED_CONTAINER_MESSAGE << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_LENGTH_DELIMITED:
{
expectedFieldCount = 1;
value = readContainer(ctx, in, tag);
break out;
}
// Null marker: rejected unless the caller allows nulls.
case WRAPPED_EMPTY << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT:
{
if (!nulls) {
throw new IllegalStateException("Encountered a null message but nulls are not accepted");
}
expectedFieldCount = 1;
// We ignore the actual boolean value! Will be returning null anyway.
in.readBool();
break out;
}
// The next four tags belong to the two-field typed shape: a type (name or id) plus a payload (enum or message).
case WRAPPED_TYPE_NAME << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_LENGTH_DELIMITED:
{
expectedFieldCount = 2;
typeName = in.readString();
break;
}
case WRAPPED_TYPE_ID << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT:
{
expectedFieldCount = 2;
// NOTE(review): mapTypeIdIn presumably translates the on-wire id into the id the context knows; confirm.
typeId = mapTypeIdIn(in.readInt32(), ctx);
break;
}
case WRAPPED_ENUM << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT:
{
expectedFieldCount = 2;
enumValue = in.readEnum();
break;
}
case WRAPPED_MESSAGE << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_LENGTH_DELIMITED:
{
expectedFieldCount = 2;
messageBytes = in.readByteArray();
break;
}
// Scalar wrappers: each carries its whole value in a single field.
case WRAPPED_STRING << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_LENGTH_DELIMITED:
{
expectedFieldCount = 1;
value = in.readString();
break out;
}
case WRAPPED_CHAR << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT:
{
expectedFieldCount = 1;
value = (char) in.readInt32();
break out;
}
case WRAPPED_SHORT << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT:
{
expectedFieldCount = 1;
value = (short) in.readInt32();
break out;
}
case WRAPPED_BYTE << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT:
{
expectedFieldCount = 1;
value = (byte) in.readInt32();
break out;
}
case WRAPPED_DATE_MILLIS << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT:
{
expectedFieldCount = 1;
value = new Date(in.readInt64());
break out;
}
// Instant is split over two fields. Either may come first: each case keeps the component
// already decoded by the other, and the loop continues so the companion field can be read.
case WRAPPED_INSTANT_SECONDS << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT:
{
expectedFieldCount = 2;
long seconds = in.readInt64();
value = value == null ? Instant.ofEpochSecond(seconds, 0) : Instant.ofEpochSecond(seconds, ((Instant) value).getNano());
break;
}
case WRAPPED_INSTANT_NANOS << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT:
{
expectedFieldCount = 2;
int nanos = in.readInt32();
value = value == null ? Instant.ofEpochSecond(0, nanos) : Instant.ofEpochSecond(((Instant) value).getEpochSecond(), nanos);
break;
}
case WRAPPED_BYTES << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_LENGTH_DELIMITED:
{
expectedFieldCount = 1;
value = in.readByteArray();
break out;
}
case WRAPPED_BOOL << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT:
{
expectedFieldCount = 1;
value = in.readBool();
break out;
}
case WRAPPED_DOUBLE << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_FIXED64:
{
expectedFieldCount = 1;
value = in.readDouble();
break out;
}
case WRAPPED_FLOAT << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_FIXED32:
{
expectedFieldCount = 1;
value = in.readFloat();
break out;
}
case WRAPPED_FIXED32 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_FIXED32:
{
expectedFieldCount = 1;
value = in.readFixed32();
break out;
}
case WRAPPED_SFIXED32 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_FIXED32:
{
expectedFieldCount = 1;
value = in.readSFixed32();
break out;
}
case WRAPPED_FIXED64 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_FIXED64:
{
expectedFieldCount = 1;
value = in.readFixed64();
break out;
}
case WRAPPED_SFIXED64 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_FIXED64:
{
expectedFieldCount = 1;
value = in.readSFixed64();
break out;
}
case WRAPPED_INT64 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT:
{
expectedFieldCount = 1;
value = in.readInt64();
break out;
}
case WRAPPED_UINT64 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT:
{
expectedFieldCount = 1;
value = in.readUInt64();
break out;
}
case WRAPPED_SINT64 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT:
{
expectedFieldCount = 1;
value = in.readSInt64();
break out;
}
case WRAPPED_INT32 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT:
{
expectedFieldCount = 1;
value = in.readInt32();
break out;
}
case WRAPPED_UINT32 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT:
{
expectedFieldCount = 1;
value = in.readUInt32();
break out;
}
case WRAPPED_SINT32 << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_VARINT:
{
expectedFieldCount = 1;
value = in.readSInt32();
break out;
}
default:
throw new IllegalStateException("Unexpected tag : " + tag + " (Field number : " + WireType.getTagFieldNumber(tag) + ", Wire type : " + WireType.getTagWireType(tag) + ")");
}
}
// Nothing that carries a result was read (null marker or empty input): the result is null.
if (value == null && typeName == null && typeId == null && messageBytes == null) {
return null;
}
// Directly decoded value (scalar, container or Instant): validate the field count, then return it.
if (value != null) {
if (fieldCount != expectedFieldCount) {
throw new IOException("Invalid WrappedMessage encoding.");
}
return (T) value;
}
// Typed shape: exactly one of type name / type id, and exactly two fields in total.
if (typeName == null && typeId == null || typeName != null && typeId != null || fieldCount != 2) {
throw new IOException("Invalid WrappedMessage encoding.");
}
// Normalise a type id to the descriptor's full name so the marshaller can be looked up by name.
if (typeId != null) {
typeName = ctx.getDescriptorByTypeId(typeId).getFullName();
}
// Relies on ctx being the SerializationContextImpl implementation (unchecked downcast).
BaseMarshallerDelegate marshallerDelegate = ((SerializationContextImpl) ctx).getMarshallerDelegate(typeName);
if (messageBytes != null) {
// it's a Message type
TagReaderImpl nestedInput = TagReaderImpl.newInstance(ctx, messageBytes);
return (T) marshallerDelegate.unmarshall(nestedInput, null);
} else {
// it's an Enum
EnumMarshaller marshaller = (EnumMarshaller) marshallerDelegate.getMarshaller();
T e = (T) marshaller.decode(enumValue);
if (e == null) {
// Unknown enum value cause by schema evolution. We cannot handle data loss here so we throw!
throw new IOException("Unknown enum value " + enumValue + " for Protobuf enum type " + typeName);
}
return e;
}
}
Use of org.infinispan.protostream.impl.SerializationContextImpl in project protostream by infinispan.
Class ProtobufUtil, method newSerializationContext.
/**
 * Creates a serialization context with WrappedMessage support already installed.
 * The {@code WrappedMessage.PROTO_FILE} schema is registered, followed by
 * {@code WrappedMessage.MARSHALLER}.
 *
 * @param configuration configuration passed to the new context
 * @return the newly created context
 * @throws RuntimeException if the bundled schema cannot be loaded or parsed (the cause is preserved)
 */
public static SerializationContext newSerializationContext(Configuration configuration) {
    SerializationContextImpl ctx = new SerializationContextImpl(configuration);
    // message-wrapping.proto is registered unconditionally on every context created here.
    try {
        FileDescriptorSource wrappingSchema = FileDescriptorSource.fromResources(WrappedMessage.PROTO_FILE);
        ctx.registerProtoFiles(wrappingSchema);
    } catch (IOException | DescriptorParserException cause) {
        throw new RuntimeException("Failed to initialize serialization context", cause);
    }
    ctx.registerMarshaller(WrappedMessage.MARSHALLER);
    return ctx;
}
Use of org.infinispan.protostream.impl.SerializationContextImpl in project kogito-runtimes by kiegroup.
Class AbstractMarshallerGenerator, method generate.
/**
 * Generates protostream marshaller compilation units for the types in {@code proto}, together with the
 * bundled {@code META-INF/kogito-types.proto}.
 * <p>
 * For every registered file descriptor:
 * <ul>
 * <li>each message type that passes the filters gets a marshaller cloned from the {@code MessageMarshaller}
 * template. A message is skipped when its file's package is {@code kogito}, or when its Java type is matched by
 * {@code ExclusionTypeUtils.createTypeExclusions()};</li>
 * <li>each enum type (not filtered) gets an {@code EnumMarshaller} built from scratch.</li>
 * </ul>
 * Within each file, message marshallers are added to the result before that file's enum marshallers.
 *
 * @param proto proto source(s) to generate marshallers for
 * @return the generated compilation units
 * @throws IOException if {@code META-INF/kogito-types.proto} is not found through the build context's class
 *                     loader, or cannot be read
 */
public List<CompilationUnit> generate(FileDescriptorSource proto) throws IOException {
    List<CompilationUnit> units = new ArrayList<>();
    TemplatedGenerator generator = TemplatedGenerator.builder()
            .withFallbackContext(JavaKogitoBuildContext.CONTEXT_NAME)
            .withTemplateBasePath(TEMPLATE_PERSISTENCE_FOLDER)
            .build(context, "MessageMarshaller");
    Predicate<String> typeExclusions = ExclusionTypeUtils.createTypeExclusions();
    // filter types that don't require to create a marshaller
    Predicate<Descriptor> packagePredicate = (msg) -> !msg.getFileDescriptor().getPackage().equals("kogito");
    Predicate<Descriptor> jacksonPredicate = (msg) -> !typeExclusions.test(packageFromOption(msg.getFileDescriptor(), msg) + "." + msg.getName());
    Predicate<Descriptor> predicate = packagePredicate.and(jacksonPredicate);
    CompilationUnit parsedClazzFile = generator.compilationUnitOrThrow();
    SerializationContext serializationContext = new SerializationContextImpl(Configuration.builder().build());

    // The stream returned by getResourceAsStream belongs to this method, so close it once it has been read.
    // A missing resource yields null, which previously surfaced as an opaque NullPointerException inside
    // addProtoFile; report it as an IOException instead.
    FileDescriptorSource kogitoTypesDescriptor;
    try (java.io.InputStream kogitoTypesStream = context.getClassLoader().getResourceAsStream("META-INF/kogito-types.proto")) {
        if (kogitoTypesStream == null) {
            throw new IOException("Resource META-INF/kogito-types.proto not found on the build context class loader");
        }
        kogitoTypesDescriptor = new FileDescriptorSource().addProtoFile("kogito-types.proto", kogitoTypesStream);
    }
    serializationContext.registerProtoFiles(kogitoTypesDescriptor);
    serializationContext.registerProtoFiles(proto);

    Map<String, FileDescriptor> descriptors = serializationContext.getFileDescriptors();
    for (Entry<String, FileDescriptor> entry : descriptors.entrySet()) {
        FileDescriptor d = entry.getValue();
        List<Descriptor> messages = d.getMessageTypes().stream().filter(predicate).collect(Collectors.toList());
        for (Descriptor msg : messages) {
            // Each message marshaller starts from a fresh copy of the parsed template.
            CompilationUnit clazzFile = parsedClazzFile.clone();
            units.add(clazzFile);
            String javaType = packageFromOption(d, msg) + "." + msg.getName();
            clazzFile.setPackageDeclaration(d.getPackage());
            ClassOrInterfaceDeclaration clazz = clazzFile.findFirst(ClassOrInterfaceDeclaration.class, sl -> true)
                    .orElseThrow(() -> new InvalidTemplateException(generator, "No class found"));
            clazz.setName(msg.getName() + "MessageMarshaller");
            clazz.getImplementedTypes(0).setTypeArguments(NodeList.nodeList(new ClassOrInterfaceType(null, javaType)));

            MethodDeclaration getJavaClassMethod = clazz.findFirst(MethodDeclaration.class, md -> md.getNameAsString().equals("getJavaClass"))
                    .orElseThrow(() -> new InvalidTemplateException(generator, "No getJavaClass method found"));
            getJavaClassMethod.setType(new ClassOrInterfaceType(null, new SimpleName(Class.class.getName()), NodeList.nodeList(new ClassOrInterfaceType(null, javaType))));
            BlockStmt getJavaClassMethodBody = new BlockStmt();
            getJavaClassMethodBody.addStatement(new ReturnStmt(new NameExpr(javaType + ".class")));
            getJavaClassMethod.setBody(getJavaClassMethodBody);

            MethodDeclaration getTypeNameMethod = clazz.findFirst(MethodDeclaration.class, md -> md.getNameAsString().equals("getTypeName"))
                    .orElseThrow(() -> new InvalidTemplateException(generator, "No getTypeName method found"));
            BlockStmt getTypeNameMethodBody = new BlockStmt();
            getTypeNameMethodBody.addStatement(new ReturnStmt(new StringLiteralExpr(msg.getFullName())));
            getTypeNameMethod.setBody(getTypeNameMethodBody);

            MethodDeclaration readFromMethod = clazz.findFirst(MethodDeclaration.class, md -> md.getNameAsString().equals("readFrom"))
                    .orElseThrow(() -> new InvalidTemplateException(generator, "No readFrom method found"));
            readFromMethod.setType(javaType);
            readFromMethod.setBody(new BlockStmt());
            MethodDeclaration writeToMethod = clazz.findFirst(MethodDeclaration.class, md -> md.getNameAsString().equals("writeTo"))
                    .orElseThrow(() -> new InvalidTemplateException(generator, "No writeTo method found"));
            writeToMethod.getParameter(1).setType(javaType);
            writeToMethod.setBody(new BlockStmt());

            ClassOrInterfaceType classType = new ClassOrInterfaceType(null, javaType);
            // read method
            VariableDeclarationExpr instance = new VariableDeclarationExpr(new VariableDeclarator(classType, "value", new ObjectCreationExpr(null, classType, NodeList.nodeList())));
            readFromMethod.getBody().ifPresent(b -> b.addStatement(instance));
            for (FieldDescriptor field : msg.getFields()) {
                String protoStreamMethodType = protoStreamMethodType(field.getTypeName());
                Expression write = null;
                Expression read = null;
                if (protoStreamMethodType != null && !field.isRepeated()) {
                    // has a mapped type
                    read = new MethodCallExpr(new NameExpr("reader"), "read" + protoStreamMethodType).addArgument(new StringLiteralExpr(field.getName()));
                    String accessor = protoStreamMethodType.equals("Boolean") ? "is" : "get";
                    write = new MethodCallExpr(new NameExpr("writer"), "write" + protoStreamMethodType)
                            .addArgument(new StringLiteralExpr(field.getName()))
                            .addArgument(new MethodCallExpr(new NameExpr("t"), accessor + StringUtils.ucFirst(field.getName())));
                } else {
                    // custom types
                    String customTypeName = javaTypeForMessage(d, field.getTypeName(), serializationContext);
                    if (field.isRepeated()) {
                        if (null == customTypeName || customTypeName.isEmpty()) {
                            customTypeName = primaryTypeClassName(field.getTypeName());
                        }
                        String writeMethod;
                        if (isArray(javaType, field)) {
                            writeMethod = "writeArray";
                            read = new MethodCallExpr(new NameExpr("reader"), "readArray")
                                    .addArgument(new StringLiteralExpr(field.getName()))
                                    .addArgument(new NameExpr(customTypeName + ".class"));
                        } else {
                            writeMethod = "writeCollection";
                            read = new MethodCallExpr(new NameExpr("reader"), "readCollection")
                                    .addArgument(new StringLiteralExpr(field.getName()))
                                    .addArgument(new ObjectCreationExpr(null, new ClassOrInterfaceType(null, ArrayList.class.getCanonicalName()), NodeList.nodeList()))
                                    .addArgument(new NameExpr(customTypeName + ".class"));
                        }
                        write = new MethodCallExpr(new NameExpr("writer"), writeMethod)
                                .addArgument(new StringLiteralExpr(field.getName()))
                                .addArgument(new MethodCallExpr(new NameExpr("t"), "get" + StringUtils.ucFirst(field.getName())))
                                .addArgument(new NameExpr(customTypeName + ".class"));
                    } else {
                        read = new MethodCallExpr(new NameExpr("reader"), "readObject")
                                .addArgument(new StringLiteralExpr(field.getName()))
                                .addArgument(new NameExpr(customTypeName + ".class"));
                        write = new MethodCallExpr(new NameExpr("writer"), "writeObject")
                                .addArgument(new StringLiteralExpr(field.getName()))
                                .addArgument(new MethodCallExpr(new NameExpr("t"), "get" + StringUtils.ucFirst(field.getName())))
                                .addArgument(new NameExpr(customTypeName + ".class"));
                    }
                    // Serializable-typed fields are read as Serializable; cast to the concrete class named by the option.
                    if (customTypeName.equals(Serializable.class.getName())) {
                        String fieldClazz = (String) field.getOptionByName(KOGITO_JAVA_CLASS_OPTION);
                        if (fieldClazz == null) {
                            throw new IllegalArgumentException(format("Serializable proto field '%s' is missing value for option %s", field.getName(), KOGITO_JAVA_CLASS_OPTION));
                        } else {
                            read = new CastExpr().setExpression(new EnclosedExpr(read)).setType(fieldClazz);
                        }
                    }
                }
                MethodCallExpr setter = new MethodCallExpr(new NameExpr("value"), "set" + StringUtils.ucFirst(field.getName())).addArgument(read);
                readFromMethod.getBody().ifPresent(b -> b.addStatement(setter));
                // write method
                writeToMethod.getBody().orElseThrow(() -> new NoSuchElementException("A method declaration doesn't contain a body!")).addStatement(write);
            }
            readFromMethod.getBody().ifPresent(b -> b.addStatement(new ReturnStmt(new NameExpr("value"))));
            clazz.getMembers().sort(new BodyDeclarationComparator());
        }
        for (EnumDescriptor msg : d.getEnumTypes()) {
            CompilationUnit compilationUnit = new CompilationUnit();
            units.add(compilationUnit);
            String javaType = packageFromOption(d, msg) + "." + msg.getName();
            ClassOrInterfaceDeclaration classDeclaration = compilationUnit.setPackageDeclaration(d.getPackage()).addClass(msg.getName() + "EnumMarshaller").setPublic(true);
            classDeclaration.addImplementedType(EnumMarshaller.class).getImplementedTypes(0).setTypeArguments(NodeList.nodeList(new ClassOrInterfaceType(null, javaType)));
            classDeclaration.addMethod("getTypeName", PUBLIC).setType(String.class)
                    .setBody(new BlockStmt().addStatement(new ReturnStmt(new StringLiteralExpr(msg.getFullName()))));
            classDeclaration.addMethod("getJavaClass", PUBLIC)
                    .setType(new ClassOrInterfaceType(null, new SimpleName(Class.class.getName()), NodeList.nodeList(new ClassOrInterfaceType(null, javaType))))
                    .setBody(new BlockStmt().addStatement(new ReturnStmt(new ClassExpr(new ClassOrInterfaceType(null, javaType)))));
            // NOTE(review): the generated encode returns the Java ordinal while decode switches on the proto value
            // numbers; a round trip only works when those numbers equal the ordinals. Confirm this always holds.
            BlockStmt encodeBlock = new BlockStmt()
                    .addStatement(new IfStmt(new BinaryExpr(new NullLiteralExpr(), new NameExpr(STATE_PARAM), EQUALS),
                            new ThrowStmt(new ObjectCreationExpr(null, new ClassOrInterfaceType(null, IllegalArgumentException.class.getName()),
                                    NodeList.nodeList(new StringLiteralExpr("Invalid value provided to enum")))),
                            null))
                    .addStatement(new ReturnStmt(new MethodCallExpr(new NameExpr(STATE_PARAM), "ordinal")));
            classDeclaration.addMethod("encode", PUBLIC).setType("int").addParameter(javaType, STATE_PARAM).setBody(encodeBlock);
            MethodDeclaration decode = classDeclaration.addMethod("decode", PUBLIC).setType(javaType).addParameter("int", "value");
            SwitchStmt decodeSwitch = new SwitchStmt().setSelector(new NameExpr("value"));
            msg.getValues().forEach(v -> {
                SwitchEntry dEntry = new SwitchEntry();
                dEntry.getLabels().add(new IntegerLiteralExpr(v.getNumber()));
                dEntry.addStatement(new ReturnStmt(new NameExpr(javaType + "." + v.getName())));
                decodeSwitch.getEntries().add(dEntry);
            });
            // default branch: any number not declared in the proto enum is rejected
            decodeSwitch.getEntries().add(new SwitchEntry().addStatement(new ThrowStmt(new ObjectCreationExpr(null,
                    new ClassOrInterfaceType(null, IllegalArgumentException.class.getName()),
                    NodeList.nodeList(new StringLiteralExpr("Invalid value provided to enum"))))));
            decode.setBody(new BlockStmt().addStatement(decodeSwitch));
        }
    }
    return units;
}
Aggregations