use of org.apache.thrift.protocol.TMessage in project dubbo by alibaba.
the class ThriftCodecTest method testDecodeExceptionResponse.
/**
 * Hand-builds a frame carrying a thrift EXCEPTION message and checks that the
 * codec decodes it into a {@code Response} whose result holds an
 * {@code RpcException}.
 */
@Test
public void testDecodeExceptionResponse() throws Exception {
    int freePort = NetUtils.getAvailablePort();
    String address = ThriftProtocol.NAME + "://127.0.0.1:" + freePort + "/" + Demo.class.getName();
    Channel mockedChannel = new MockedChannel(URL.valueOf(address));

    RandomAccessByteArrayOutputStream buffer = new RandomAccessByteArrayOutputStream(128);
    Request request = createRequest();
    // The returned future is not used afterwards; only the call itself is kept.
    DefaultFuture.newFuture(mockedChannel, request, 10, null);

    TMessage exceptionMessage = new TMessage("echoString", TMessageType.EXCEPTION, ThriftCodec.getSeqId());
    TBinaryProtocol out = new TBinaryProtocol(new TIOStreamTransport(buffer));

    // ---- header: the two length fields are placeholders, patched below ----
    out.writeI16(ThriftCodec.MAGIC);
    out.writeI32(Integer.MAX_VALUE);
    out.writeI16(Short.MAX_VALUE);
    out.writeByte(ThriftCodec.VERSION);
    // service name
    out.writeString(Demo.class.getName());
    // path
    out.writeString(Demo.class.getName());
    out.writeI64(request.getId());
    out.getTransport().flush();
    final int headerLength = buffer.size();

    // ---- body: message begin, serialized exception, message end ----
    out.writeMessageBegin(exceptionMessage);
    new TApplicationException().write(out);
    out.writeMessageEnd();
    out.getTransport().flush();
    final int messageLength = buffer.size();

    // ---- patch the real lengths in, then put the write cursor back at the end ----
    try {
        buffer.setWriteIndex(ThriftCodec.MESSAGE_LENGTH_INDEX);
        out.writeI32(messageLength);
        buffer.setWriteIndex(ThriftCodec.MESSAGE_HEADER_LENGTH_INDEX);
        out.writeI16((short) (0xffff & headerLength));
    } finally {
        buffer.setWriteIndex(messageLength);
    }

    // ---- decode and verify ----
    ChannelBuffer framed = ChannelBuffers.wrappedBuffer(encodeFrame(buffer.toByteArray()));
    Object decoded = codec.decode((Channel) null, framed);

    Assertions.assertNotNull(decoded);
    Assertions.assertTrue(decoded instanceof Response);
    Object payload = ((Response) decoded).getResult();
    Assertions.assertTrue(payload instanceof AppResponse);
    AppResponse appResponse = (AppResponse) payload;
    Assertions.assertTrue(appResponse.hasException());
    Assertions.assertTrue(appResponse.getException() instanceof RpcException);
}
use of org.apache.thrift.protocol.TMessage in project dubbo by alibaba.
the class ThriftCodecTest method testDecodeReplyResponse.
/**
 * Hand-builds a frame carrying a thrift REPLY message for {@code echoString}
 * and checks that the codec decodes it into a {@code Response} with the same
 * request id and the expected string value.
 */
@Test
public void testDecodeReplyResponse() throws Exception {
    int freePort = NetUtils.getAvailablePort();
    String address = ThriftProtocol.NAME + "://127.0.0.1:" + freePort + "/" + Demo.Iface.class.getName();
    Channel mockedChannel = new MockedChannel(URL.valueOf(address));

    RandomAccessByteArrayOutputStream buffer = new RandomAccessByteArrayOutputStream(128);
    Request request = createRequest();
    // The returned future is not used afterwards; only the call itself is kept.
    DefaultFuture.newFuture(mockedChannel, request, 10, null);

    TMessage replyMessage = new TMessage("echoString", TMessageType.REPLY, ThriftCodec.getSeqId());
    Demo.echoString_result expected = new Demo.echoString_result();
    expected.success = "Hello, World!";

    TBinaryProtocol out = new TBinaryProtocol(new TIOStreamTransport(buffer));

    // ---- header: the two length fields are placeholders, patched below ----
    out.writeI16(ThriftCodec.MAGIC);
    out.writeI32(Integer.MAX_VALUE);
    out.writeI16(Short.MAX_VALUE);
    out.writeByte(ThriftCodec.VERSION);
    // service name
    out.writeString(Demo.Iface.class.getName());
    // path
    out.writeString(Demo.Iface.class.getName());
    out.writeI64(request.getId());
    out.getTransport().flush();
    final int headerLength = buffer.size();

    // ---- body: message begin, serialized result struct, message end ----
    out.writeMessageBegin(replyMessage);
    expected.write(out);
    out.writeMessageEnd();
    out.getTransport().flush();
    final int messageLength = buffer.size();

    // ---- patch the real lengths in, then put the write cursor back at the end ----
    try {
        buffer.setWriteIndex(ThriftCodec.MESSAGE_LENGTH_INDEX);
        out.writeI32(messageLength);
        buffer.setWriteIndex(ThriftCodec.MESSAGE_HEADER_LENGTH_INDEX);
        out.writeI16((short) (0xffff & headerLength));
    } finally {
        buffer.setWriteIndex(messageLength);
    }

    // ---- prepend four zero bytes in the frame-length slot ----
    int payloadSize = buffer.size();
    byte[] framedBytes = new byte[4 + payloadSize];
    System.arraycopy(buffer.toByteArray(), 0, framedBytes, 4, payloadSize);

    // ---- decode and verify ----
    Object decoded = codec.decode((Channel) null, ChannelBuffers.wrappedBuffer(framedBytes));

    Assertions.assertNotNull(decoded);
    Assertions.assertTrue(decoded instanceof Response);
    Response response = (Response) decoded;
    Assertions.assertEquals(request.getId(), response.getId());
    Assertions.assertTrue(response.getResult() instanceof AppResponse);
    AppResponse appResponse = (AppResponse) response.getResult();
    Assertions.assertTrue(appResponse.getValue() instanceof String);
    Assertions.assertEquals(expected.success, appResponse.getValue());
}
use of org.apache.thrift.protocol.TMessage in project dubbo by alibaba.
the class ThriftCodecTest method testEncodeRequest.
/**
 * Encodes a request through the codec and reads the produced bytes back field
 * by field: frame length, magic, message length, header length, version,
 * service name, path, request id, and finally the thrift message with its
 * {@code echoString} arguments.
 */
@Test
public void testEncodeRequest() throws Exception {
    Request request = createRequest();

    ChannelBuffer encoded = ChannelBuffers.dynamicBuffer(1024);
    codec.encode(channel, encoded, request);
    byte[] bytes = new byte[encoded.readableBytes()];
    encoded.readBytes(bytes);

    ByteArrayInputStream in = new ByteArrayInputStream(bytes);
    TTransport transport = new TIOStreamTransport(in);
    TBinaryProtocol reader = new TBinaryProtocol(transport);

    // Consume the 4-byte frame prefix; its content is not checked here.
    transport.read(new byte[4], 0, 4);

    // Remember the position right after the frame prefix so the header-length check can rewind to it.
    final boolean rewindable = in.markSupported();
    if (rewindable) {
        in.mark(0);
    }

    // magic
    Assertions.assertEquals(ThriftCodec.MAGIC, reader.readI16());
    // message length: everything except the 4-byte frame prefix
    int messageLength = reader.readI32();
    Assertions.assertEquals(messageLength + 4, bytes.length);
    // header length, verified further down by skipping exactly that many bytes
    short headerLength = reader.readI16();
    // version
    Assertions.assertEquals(ThriftCodec.VERSION, reader.readByte());
    // service name
    Assertions.assertEquals(Demo.Iface.class.getName(), reader.readString());
    // path
    Assertions.assertEquals(Demo.Iface.class.getName(), reader.readString());
    // dubbo request id
    Assertions.assertEquals(request.getId(), reader.readI64());

    // Rewind to the mark and skip headerLength bytes: the thrift message must start right there.
    if (rewindable) {
        in.reset();
        in.skip(headerLength);
    }

    TMessage message = reader.readMessageBegin();
    Demo.echoString_args decodedArgs = new Demo.echoString_args();
    decodedArgs.read(reader);
    reader.readMessageEnd();

    Assertions.assertEquals("echoString", message.name);
    Assertions.assertEquals(TMessageType.CALL, message.type);
    Assertions.assertEquals("Hello, World!", decodedArgs.getArg());
}
use of org.apache.thrift.protocol.TMessage in project hive by apache.
the class TUGIBasedProcessor method process.
/**
 * Reads one thrift message from {@code in} and dispatches it to the matching
 * process function, running it under the client's UGI when the transport
 * carries one.
 *
 * <ul>
 *   <li>Unknown method name: the arguments are skipped and an
 *       {@code UNKNOWN_METHOD} application exception is written to {@code out}.</li>
 *   <li>{@code set_ugi}: delegated to {@code handleSetUGI}.</li>
 *   <li>No client UGI on the transport: the function is invoked directly.</li>
 *   <li>Otherwise: the function is invoked inside {@code clientUgi.doAs(...)},
 *       and file-system handles for that UGI are closed afterwards.</li>
 * </ul>
 *
 * @param in  protocol to read the request from
 * @param out protocol to write the reply to
 * @throws TException if reading, processing or writing fails
 */
@SuppressWarnings("unchecked")
@Override
public void process(final TProtocol in, final TProtocol out) throws TException {
    setIpAddress(in);
    final TMessage msg = in.readMessageBegin();
    final ProcessFunction<Iface, ? extends TBase> fn = functions.get(msg.name);
    if (fn == null) {
        // Drain the unread arguments so the stream stays aligned, then report the unknown method.
        TProtocolUtil.skip(in, TType.STRUCT);
        in.readMessageEnd();
        TApplicationException x = new TApplicationException(TApplicationException.UNKNOWN_METHOD, "Invalid method name: '" + msg.name + "'");
        out.writeMessageBegin(new TMessage(msg.name, TMessageType.EXCEPTION, msg.seqid));
        x.write(out);
        out.writeMessageEnd();
        out.getTransport().flush();
        return;
    }
    TUGIContainingTransport ugiTrans = (TUGIContainingTransport) in.getTransport();
    // Store ugi in transport if the rpc is set_ugi
    if (msg.name.equalsIgnoreCase("set_ugi")) {
        try {
            handleSetUGI(ugiTrans, (ThriftHiveMetastore.Processor.set_ugi<Iface>) fn, msg, in, out);
        } catch (TException e) {
            throw e;
        } catch (Exception e) {
            // Prefer the underlying cause when there is one, but never drop the
            // exception altogether: many exceptions carry no cause, and wrapping
            // a null cause would leave the caller with no diagnostic at all.
            Throwable cause = e.getCause();
            throw new TException(cause != null ? cause : e);
        }
        return;
    }
    UserGroupInformation clientUgi = ugiTrans.getClientUGI();
    if (null == clientUgi) {
        // At this point, transport must contain client ugi, if it doesn't then its an old client.
        fn.process(msg.seqid, in, out, iface);
        return;
    }
    // Found ugi, perform doAs().
    PrivilegedExceptionAction<Void> pvea = new PrivilegedExceptionAction<Void>() {
        @Override
        public Void run() {
            try {
                fn.process(msg.seqid, in, out, iface);
                return null;
            } catch (TException te) {
                // Tunnel the checked exception out of run(); it is unwrapped below.
                throw new RuntimeException(te);
            }
        }
    };
    try {
        clientUgi.doAs(pvea);
    } catch (RuntimeException rte) {
        // Unwrap a TException tunnelled out of the privileged action.
        if (rte.getCause() instanceof TException) {
            throw (TException) rte.getCause();
        }
        throw rte;
    } catch (InterruptedException ie) {
        // unexpected! Restore the interrupt status so code further up the stack can still observe it.
        Thread.currentThread().interrupt();
        throw new RuntimeException(ie);
    } catch (IOException ioe) {
        // unexpected!
        throw new RuntimeException(ioe);
    } finally {
        try {
            FileSystem.closeAllForUGI(clientUgi);
        } catch (IOException e) {
            LOG.error("Could not clean up file-system handles for UGI: " + clientUgi, e);
        }
    }
}
use of org.apache.thrift.protocol.TMessage in project dubbo by alibaba.
the class ThriftCodec method decode.
/**
 * Decodes one message from {@code protocol}.
 *
 * <p>Two leading header fields (a 16-bit value and a byte) are consumed and
 * discarded, then the service name, the request id and the thrift message
 * header are read. What follows depends on the message type:
 * <ul>
 *   <li>{@code CALL}: the generated args struct is instantiated and read, its
 *       fields are extracted through their getters, and a {@code Request}
 *       carrying an {@code RpcInvocation} is returned.</li>
 *   <li>{@code EXCEPTION}: a {@code TApplicationException} is read and a
 *       {@code Response} wrapping an {@code RpcException} is returned.</li>
 *   <li>{@code REPLY}: the generated result struct is instantiated and read,
 *       and a {@code Response} wrapping its first non-null field is returned.</li>
 * </ul>
 *
 * @param protocol protocol positioned just before the header-length field
 * @return a {@code Request} for a call, a {@code Response} otherwise
 * @throws IOException if the header or an exception body cannot be read, or if
 *         the message type is none of the three above
 */
private Object decode(TProtocol protocol) throws IOException {
    String serviceName;
    long id;
    TMessage message;
    try {
        // Two header fields that are consumed but not used here.
        protocol.readI16();
        protocol.readByte();
        serviceName = protocol.readString();
        id = protocol.readI64();
        message = protocol.readMessageBegin();
    } catch (TException e) {
        throw new IOException(e.getMessage(), e);
    }
    if (message.type == TMessageType.CALL) {
        RpcInvocation result = new RpcInvocation();
        result.setAttachment(Constants.INTERFACE_KEY, serviceName);
        result.setMethodName(message.name);
        String argsClassName = ExtensionLoader.getExtensionLoader(ClassNameGenerator.class).getExtension(ThriftClassNameGenerator.NAME).generateArgsClassName(serviceName, message.name);
        if (StringUtils.isEmpty(argsClassName)) {
            throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, "The specified interface name incorrect.");
        }
        // Resolve the args class, loading and caching it on first use.
        Class clazz = cachedClass.get(argsClassName);
        if (clazz == null) {
            try {
                clazz = ClassHelper.forNameWithThreadContextClassLoader(argsClassName);
                cachedClass.putIfAbsent(argsClassName, clazz);
            } catch (ClassNotFoundException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            }
        }
        TBase args;
        try {
            args = (TBase) clazz.newInstance();
        } catch (InstantiationException e) {
            throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
        } catch (IllegalAccessException e) {
            throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
        }
        try {
            args.read(protocol);
            protocol.readMessageEnd();
        } catch (TException e) {
            throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
        }
        List<Object> parameters = new ArrayList<Object>();
        List<Class<?>> parameterTypes = new ArrayList<Class<?>>();
        // Walk field ids 1, 2, 3, ... until an id has no field; each field's getter
        // supplies both the argument value and (via its return type) the parameter type.
        int index = 1;
        while (true) {
            TFieldIdEnum fieldIdEnum = args.fieldForId(index++);
            if (fieldIdEnum == null) {
                break;
            }
            String fieldName = fieldIdEnum.getFieldName();
            String getMethodName = ThriftUtils.generateGetMethodName(fieldName);
            Method getMethod;
            try {
                getMethod = clazz.getMethod(getMethodName);
            } catch (NoSuchMethodException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            }
            parameterTypes.add(getMethod.getReturnType());
            try {
                parameters.add(getMethod.invoke(args));
            } catch (IllegalAccessException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            } catch (InvocationTargetException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            }
        }
        result.setArguments(parameters.toArray());
        result.setParameterTypes(parameterTypes.toArray(new Class[parameterTypes.size()]));
        Request request = new Request(id);
        request.setData(result);
        // Record the thrift seqid, service and method names under the request id.
        cachedRequest.putIfAbsent(id, RequestData.create(message.seqid, serviceName, message.name));
        return request;
    } else if (message.type == TMessageType.EXCEPTION) {
        TApplicationException exception;
        try {
            exception = TApplicationException.read(protocol);
            protocol.readMessageEnd();
        } catch (TException e) {
            throw new IOException(e.getMessage(), e);
        }
        RpcResult result = new RpcResult();
        result.setException(new RpcException(exception.getMessage()));
        Response response = new Response();
        response.setResult(result);
        response.setId(id);
        return response;
    } else if (message.type == TMessageType.REPLY) {
        String resultClassName = ExtensionLoader.getExtensionLoader(ClassNameGenerator.class).getExtension(ThriftClassNameGenerator.NAME).generateResultClassName(serviceName, message.name);
        if (StringUtils.isEmpty(resultClassName)) {
            throw new IllegalArgumentException(new StringBuilder(32).append("Could not infer service result class name from service name ").append(serviceName).append(", the service name you specified may not generated by thrift idl compiler").toString());
        }
        // Resolve the result class, loading and caching it on first use.
        Class<?> clazz = cachedClass.get(resultClassName);
        if (clazz == null) {
            try {
                clazz = ClassHelper.forNameWithThreadContextClassLoader(resultClassName);
                cachedClass.putIfAbsent(resultClassName, clazz);
            } catch (ClassNotFoundException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            }
        }
        TBase<?, ? extends TFieldIdEnum> result;
        try {
            result = (TBase<?, ?>) clazz.newInstance();
        } catch (InstantiationException e) {
            throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
        } catch (IllegalAccessException e) {
            throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
        }
        try {
            result.read(protocol);
            protocol.readMessageEnd();
        } catch (TException e) {
            throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
        }
        // Walk field ids 0, 1, 2, ... and keep the first non-null field value.
        // NOTE(review): the scan stops at the first id that has no field, so if a
        // result struct has no field with id 0 (presumably the case for void
        // methods), fields at later ids are never inspected -- confirm.
        Object realResult = null;
        int index = 0;
        while (true) {
            TFieldIdEnum fieldIdEnum = result.fieldForId(index++);
            if (fieldIdEnum == null) {
                break;
            }
            Field field;
            try {
                field = clazz.getDeclaredField(fieldIdEnum.getFieldName());
                field.setAccessible(true);
            } catch (NoSuchFieldException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            }
            try {
                realResult = field.get(result);
            } catch (IllegalAccessException e) {
                throw new RpcException(RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e);
            }
            if (realResult != null) {
                break;
            }
        }
        Response response = new Response();
        response.setId(id);
        RpcResult rpcResult = new RpcResult();
        // A Throwable field value is surfaced as the exception, anything else as the value.
        if (realResult instanceof Throwable) {
            rpcResult.setException((Throwable) realResult);
        } else {
            rpcResult.setValue(realResult);
        }
        response.setResult(rpcResult);
        return response;
    } else {
        // Neither CALL, EXCEPTION nor REPLY: say what was actually received so the
        // failure can be diagnosed from the exception alone.
        throw new IOException("Unsupported thrift message type " + message.type + " for message '" + message.name + "' of service " + serviceName);
    }
}
Aggregations