Use of org.bytedeco.javacpp.Pointer in project deeplearning4j (by deeplearning4j).
Class CudnnBatchNormalizationHelper, method backpropGradient.
/**
 * Runs the batch-normalization backward pass through cuDNN.
 *
 * <p>{@code input} is read as a 4-d array of size {@code [miniBatch, depth, inH, inW]} (dimensions 0-3).
 * The gradient with respect to the input is written to a newly allocated, uninitialized c-order array.
 * The gamma and beta gradients are written by cuDNN directly into {@code dGammaView} and {@code dBetaView}.
 *
 * @param input      the layer input, passed to cuDNN as {@code x}
 * @param epsilon    the gradient with respect to the layer output, passed to cuDNN as {@code dy}; it is
 *                   duplicated first if its strides fail {@code Shape.strideDescendingCAscendingF}
 * @param shape      dimensions for the gamma/beta tensor descriptor; a missing third or fourth entry is
 *                   treated as 1
 * @param gamma      the scale parameters, passed to cuDNN as {@code bnScale}
 * @param dGammaView destination for the gamma gradient
 * @param dBetaView  destination for the beta gradient
 * @param eps        numerical-stability epsilon; must be at least {@code CUDNN_BN_MIN_EPSILON}
 * @return the parameter gradients (keyed by the GAMMA and BETA parameter names) paired with the gradient
 *         with respect to {@code input}
 * @throws IllegalArgumentException if {@code eps < CUDNN_BN_MIN_EPSILON}
 */
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma, INDArray dGammaView, INDArray dBetaView, double eps) {
    if (eps < CUDNN_BN_MIN_EPSILON) {
        throw new IllegalArgumentException("Error: eps < CUDNN_BN_MIN_EPSILON (" + eps + " < " + CUDNN_BN_MIN_EPSILON + ")");
    }
    int miniBatch = input.size(0);
    int depth = input.size(1);
    int inH = input.size(2);
    int inW = input.size(3);
    Gradient retGradient = new DefaultGradient();
    if (!Shape.strideDescendingCAscendingF(epsilon)) {
        // apparently not supported by cuDNN
        epsilon = epsilon.dup();
    }
    int[] srcStride = input.stride();
    int[] deltaStride = epsilon.stride();
    // Drain any queued grid operations before handing raw device pointers to cuDNN.
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
    checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, depth, inH, inW, srcStride[0], srcStride[1], srcStride[2], srcStride[3]));
    checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, miniBatch, depth, inH, inW, deltaStride[0], deltaStride[1], deltaStride[2], deltaStride[3]));
    INDArray nextEpsilon = Nd4j.createUninitialized(new int[] { miniBatch, depth, inH, inW }, 'c');
    int[] dstStride = nextEpsilon.stride();
    checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW, dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
    // The gamma/beta descriptor is built from 'shape' (padded with 1s), not from gamma's own strides.
    checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, tensorFormat, dataType, shape[0], shape[1], shape.length > 2 ? shape[2] : 1, shape.length > 3 ? shape[3] : 1));
    Allocator allocator = AtomicAllocator.getInstance();
    CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, epsilon, nextEpsilon, gamma, dGammaView, dBetaView);
    Pointer srcData = allocator.getPointer(input, context);
    Pointer epsData = allocator.getPointer(epsilon, context);
    Pointer dstData = allocator.getPointer(nextEpsilon, context);
    Pointer gammaData = allocator.getPointer(gamma, context);
    Pointer dGammaData = allocator.getPointer(dGammaView, context);
    Pointer dBetaData = allocator.getPointer(dBetaView, context);
    checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
    // Scaling arguments map to cuDNN's (alphaDataDiff, betaDataDiff, alphaParamDiff, betaParamDiff).
    // NOTE(review): betaParamDiff is 'alpha'. If alpha is 1, cuDNN blends the new gamma/beta gradients with the
    // existing contents of dGammaView/dBetaView instead of overwriting them. Confirm this is intended.
    // NOTE(review): meanCache/varCache occupy cuDNN's savedMean/savedInvVariance slots. They are presumably
    // filled by the preceding training forward pass; verify.
    checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, beta, alpha, alpha, cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, epsData, cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, dGammaData, dBetaData, eps, meanCache, varCache));
    allocator.getFlowController().registerActionAllWrite(context, input, epsilon, nextEpsilon, gamma, dGammaView, dBetaView);
    retGradient.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView);
    retGradient.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView);
    return new Pair<>(retGradient, nextEpsilon);
}
Use of org.bytedeco.javacpp.Pointer in project deeplearning4j (by deeplearning4j).
Class BaseStatsListener, method iterationDone.
/**
 * Builds and publishes a {@code StatsReport} for {@code model} after a training iteration.
 *
 * <p>When the model's recorded iteration count is 0, the init time is recorded and {@code doInit(model)} runs.
 * If performance stats are enabled, example/minibatch counters are updated on every call. A report is only
 * built when the iteration falls on the configured reporting frequency. The current
 * {@code StatsUpdateConfiguration} decides which sections are included:
 * <ul>
 *   <li>performance</li>
 *   <li>memory</li>
 *   <li>garbage collection</li>
 *   <li>learning rates</li>
 *   <li>histograms</li>
 *   <li>mean / stdev / mean-magnitude summaries</li>
 * </ul>
 * The score is always included. The finished report is handed to {@code router.putUpdate}, and the cached
 * activations map is cleared.
 *
 * @param model     the model that just completed an iteration
 * @param iteration the iteration number supplied by the training loop
 */
@Override
public void iterationDone(Model model, int iteration) {
    StatsUpdateConfiguration config = updateConfig;
    ModelInfo modelInfo = getModelInfo(model);
    boolean backpropParamsOnly = backpropParamsOnly(model);
    long currentTime = getTime();
    // NOTE(review): iterCount is set to 'iteration' at the end of this method and on the early return below.
    // After iteration 0 it therefore remains 0, so this block also runs on the next call (re-running doInit
    // and resetting initTime). Confirm this is intended.
    if (modelInfo.iterCount == 0) {
        modelInfo.initTime = currentTime;
        doInit(model);
    }
    if (config.collectPerformanceStats()) {
        updateExamplesMinibatchesCounts(model);
    }
    // Skip iterations that are not on the reporting frequency. With a frequency above 1, iteration 0 is also skipped.
    if (config.reportingFrequency() > 1 && (iteration == 0 || iteration % config.reportingFrequency() != 0)) {
        modelInfo.iterCount = iteration;
        return;
    }
    StatsReport report = getNewStatsReport();
    //TODO support NTP time
    report.reportIDs(getSessionID(model), TYPE_ID, workerID, System.currentTimeMillis());

    //--- Performance and System Stats ---
    if (config.collectPerformanceStats()) {
        //Stats to collect: total runtime, total examples, total minibatches, iterations/second, examples/second
        double examplesPerSecond;
        double minibatchesPerSecond;
        if (modelInfo.iterCount == 0) {
            //Not possible to work out perf/second: first iteration...
            examplesPerSecond = 0.0;
            minibatchesPerSecond = 0.0;
        } else {
            long deltaTimeMS = currentTime - modelInfo.lastReportTime;
            if (deltaTimeMS > 0) {
                examplesPerSecond = 1000.0 * modelInfo.examplesSinceLastReport / deltaTimeMS;
                minibatchesPerSecond = 1000.0 * modelInfo.minibatchesSinceLastReport / deltaTimeMS;
            } else {
                // No measurable time has elapsed since the last report (or the clock went backwards). Dividing
                // would yield Infinity/NaN, so report 0 as in the first-iteration case.
                examplesPerSecond = 0.0;
                minibatchesPerSecond = 0.0;
            }
        }
        long totalRuntimeMS = currentTime - modelInfo.initTime;
        report.reportPerformance(totalRuntimeMS, modelInfo.totalExamples, modelInfo.totalMinibatches, examplesPerSecond, minibatchesPerSecond);
        modelInfo.examplesSinceLastReport = 0;
        modelInfo.minibatchesSinceLastReport = 0;
    }

    if (config.collectMemoryStats()) {
        Runtime runtime = Runtime.getRuntime();
        long jvmTotal = runtime.totalMemory();
        long jvmMax = runtime.maxMemory();
        //Off-heap memory
        long offheapTotal = Pointer.totalBytes();
        long offheapMax = Pointer.maxBytes();
        //GPU
        long[] gpuCurrentBytes = null;
        long[] gpuMaxBytes = null;
        NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
        int nDevices = nativeOps.getAvailableDevices();
        if (nDevices > 0) {
            gpuCurrentBytes = new long[nDevices];
            gpuMaxBytes = new long[nDevices];
            for (int i = 0; i < nDevices; i++) {
                try {
                    Pointer p = getDevicePointer(i);
                    if (p == null) {
                        gpuMaxBytes[i] = 0;
                        gpuCurrentBytes[i] = 0;
                    } else {
                        gpuMaxBytes[i] = nativeOps.getDeviceTotalMemory(p);
                        gpuCurrentBytes[i] = gpuMaxBytes[i] - nativeOps.getDeviceFreeMemory(p);
                    }
                } catch (Exception e) {
                    // Best effort: a failing device keeps 0 in both arrays (their default) and the report continues.
                    e.printStackTrace();
                }
            }
        }
        report.reportMemoryUse(jvmTotal, jvmMax, offheapTotal, offheapMax, gpuCurrentBytes, gpuMaxBytes);
    }

    if (config.collectGarbageCollectionStats()) {
        if (modelInfo.lastReportIteration == -1 || gcBeans == null) {
            //Haven't reported GC stats before...
            gcBeans = ManagementFactory.getGarbageCollectorMXBeans();
            gcStatsAtLastReport = new HashMap<>();
            for (GarbageCollectorMXBean bean : gcBeans) {
                long count = bean.getCollectionCount();
                long timeMs = bean.getCollectionTime();
                gcStatsAtLastReport.put(bean.getName(), new Pair<>(count, timeMs));
            }
        } else {
            for (GarbageCollectorMXBean bean : gcBeans) {
                long count = bean.getCollectionCount();
                long timeMs = bean.getCollectionTime();
                Pair<Long, Long> lastStats = gcStatsAtLastReport.get(bean.getName());
                long deltaGCCount = count - lastStats.getFirst();
                long deltaGCTime = timeMs - lastStats.getSecond();
                // Update the stored baseline in place so the next report computes deltas from here.
                lastStats.setFirst(count);
                lastStats.setSecond(timeMs);
                report.reportGarbageCollection(bean.getName(), (int) deltaGCCount, (int) deltaGCTime);
            }
        }
    }

    //--- General ---
    //Always report score
    report.reportScore(model.score());

    if (config.collectLearningRates()) {
        Map<String, Double> lrs = new HashMap<>();
        if (model instanceof MultiLayerNetwork) {
            //Need to append "0_", "1_" etc to param names from layers...
            int layerIdx = 0;
            for (Layer l : ((MultiLayerNetwork) model).getLayers()) {
                NeuralNetConfiguration conf = l.conf();
                Map<String, Double> layerLrs = conf.getLearningRateByParam();
                Set<String> backpropParams = l.paramTable(true).keySet();
                for (Map.Entry<String, Double> entry : layerLrs.entrySet()) {
                    if (!backpropParams.contains(entry.getKey())) {
                        //Skip pretrain params
                        continue;
                    }
                    lrs.put(layerIdx + "_" + entry.getKey(), entry.getValue());
                }
                layerIdx++;
            }
        } else if (model instanceof ComputationGraph) {
            for (Layer l : ((ComputationGraph) model).getLayers()) {
                //Need to append layer name
                NeuralNetConfiguration conf = l.conf();
                Map<String, Double> layerLrs = conf.getLearningRateByParam();
                String layerName = conf.getLayer().getLayerName();
                Set<String> backpropParams = l.paramTable(true).keySet();
                for (Map.Entry<String, Double> entry : layerLrs.entrySet()) {
                    if (!backpropParams.contains(entry.getKey())) {
                        //Skip pretrain params
                        continue;
                    }
                    lrs.put(layerName + "_" + entry.getKey(), entry.getValue());
                }
            }
        } else if (model instanceof Layer) {
            Layer l = (Layer) model;
            Map<String, Double> map = l.conf().getLearningRateByParam();
            lrs.putAll(map);
        }
        report.reportLearningRates(lrs);
    }

    if (config.collectHistograms(StatsType.Parameters)) {
        Map<String, Histogram> paramHistograms = getHistograms(model.paramTable(backpropParamsOnly), config.numHistogramBins(StatsType.Parameters));
        report.reportHistograms(StatsType.Parameters, paramHistograms);
    }
    if (config.collectHistograms(StatsType.Gradients)) {
        Map<String, Histogram> gradientHistograms = getHistograms(gradientsPreUpdateMap, config.numHistogramBins(StatsType.Gradients));
        report.reportHistograms(StatsType.Gradients, gradientHistograms);
    }
    if (config.collectHistograms(StatsType.Updates)) {
        Map<String, Histogram> updateHistograms = getHistograms(model.gradient().gradientForVariable(), config.numHistogramBins(StatsType.Updates));
        report.reportHistograms(StatsType.Updates, updateHistograms);
    }
    if (config.collectHistograms(StatsType.Activations)) {
        Map<String, Histogram> activationHistograms = getHistograms(activationsMap, config.numHistogramBins(StatsType.Activations));
        report.reportHistograms(StatsType.Activations, activationHistograms);
    }

    if (config.collectMean(StatsType.Parameters)) {
        Map<String, Double> meanParams = calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.Mean);
        report.reportMean(StatsType.Parameters, meanParams);
    }
    if (config.collectMean(StatsType.Gradients)) {
        Map<String, Double> meanGradients = calculateSummaryStats(gradientsPreUpdateMap, StatType.Mean);
        report.reportMean(StatsType.Gradients, meanGradients);
    }
    if (config.collectMean(StatsType.Updates)) {
        Map<String, Double> meanUpdates = calculateSummaryStats(model.gradient().gradientForVariable(), StatType.Mean);
        report.reportMean(StatsType.Updates, meanUpdates);
    }
    if (config.collectMean(StatsType.Activations)) {
        Map<String, Double> meanActivations = calculateSummaryStats(activationsMap, StatType.Mean);
        report.reportMean(StatsType.Activations, meanActivations);
    }

    if (config.collectStdev(StatsType.Parameters)) {
        Map<String, Double> stdevParams = calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.Stdev);
        report.reportStdev(StatsType.Parameters, stdevParams);
    }
    if (config.collectStdev(StatsType.Gradients)) {
        Map<String, Double> stdevGradient = calculateSummaryStats(gradientsPreUpdateMap, StatType.Stdev);
        report.reportStdev(StatsType.Gradients, stdevGradient);
    }
    if (config.collectStdev(StatsType.Updates)) {
        Map<String, Double> stdevUpdates = calculateSummaryStats(model.gradient().gradientForVariable(), StatType.Stdev);
        report.reportStdev(StatsType.Updates, stdevUpdates);
    }
    if (config.collectStdev(StatsType.Activations)) {
        Map<String, Double> stdevActivations = calculateSummaryStats(activationsMap, StatType.Stdev);
        report.reportStdev(StatsType.Activations, stdevActivations);
    }

    if (config.collectMeanMagnitudes(StatsType.Parameters)) {
        Map<String, Double> meanMagParams = calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.MeanMagnitude);
        report.reportMeanMagnitudes(StatsType.Parameters, meanMagParams);
    }
    if (config.collectMeanMagnitudes(StatsType.Gradients)) {
        Map<String, Double> meanMagGradients = calculateSummaryStats(gradientsPreUpdateMap, StatType.MeanMagnitude);
        report.reportMeanMagnitudes(StatsType.Gradients, meanMagGradients);
    }
    if (config.collectMeanMagnitudes(StatsType.Updates)) {
        Map<String, Double> meanMagUpdates = calculateSummaryStats(model.gradient().gradientForVariable(), StatType.MeanMagnitude);
        report.reportMeanMagnitudes(StatsType.Updates, meanMagUpdates);
    }
    if (config.collectMeanMagnitudes(StatsType.Activations)) {
        Map<String, Double> meanMagActivations = calculateSummaryStats(activationsMap, StatType.MeanMagnitude);
        report.reportMeanMagnitudes(StatsType.Activations, meanMagActivations);
    }

    long endTime = getTime();
    //Amount of time required to calculate all histograms, means etc.
    report.reportStatsCollectionDurationMS((int) (endTime - currentTime));
    modelInfo.lastReportTime = currentTime;
    modelInfo.lastReportIteration = iteration;
    report.reportIterationCount(iteration);
    this.router.putUpdate(report);
    modelInfo.iterCount = iteration;
    // Activations are cleared after every report.
    activationsMap = null;
}
Use of org.bytedeco.javacpp.Pointer in project deeplearning4j (by deeplearning4j).
Class BaseStatsListener, method getDevicePointer.
/**
 * Returns a {@code org.nd4j.jita.allocator.pointers.CudaPointer} for {@code device}, or {@code null} if one
 * cannot be created.
 *
 * <p>The pointer is built reflectively through the class's {@code (long)} constructor, presumably so this
 * listener does not depend on the CUDA backend at compile time. Every outcome is cached in
 * {@code devPointers}, including a {@code null} that records a failure, so each device is attempted at most once.
 *
 * @param device the device index passed to the constructor
 * @return the cached or newly created pointer, or {@code null} if creation failed
 */
private synchronized Pointer getDevicePointer(int device) {
    if (!devPointers.containsKey(device)) {
        Pointer resolved;
        try {
            Class<?> cudaPointerClass = Class.forName("org.nd4j.jita.allocator.pointers.CudaPointer");
            resolved = (Pointer) cudaPointerClass.getConstructor(long.class).newInstance((long) device);
            devPointers.put(device, resolved);
        } catch (Throwable ignored) {
            // Cache the failure as null so later calls do not retry.
            resolved = null;
            devPointers.put(device, null);
        }
        return resolved;
    }
    return devPointers.get(device);
}
Use of org.bytedeco.javacpp.Pointer in project bigbluebutton (by bigbluebutton).
Class Generator, method returnAfter.
/**
 * Emits, through {@code out}, the generated C++/JNI code that follows the native call inside a JNI method body.
 *
 * <p>The emitted code does three things:
 * <ul>
 *   <li>It terminates the call expression: {@code ";"}, or nothing for deallocator methods, plus extra
 *   closing parentheses for adapter/ByVal returns.</li>
 *   <li>For {@code ByPtrPtr} returns, it dereferences {@code rptrptr} after a NULL check.</li>
 *   <li>It converts the native result into the JNI return value {@code rarg}. For void-returning allocators it
 *   instead binds the new native object to {@code obj} via {@code JavaCPP_initPointer}.</li>
 * </ul>
 *
 * <p>Side effects besides output: registers classes in {@code deallocators}, {@code arrayDeallocators} and
 * {@code jclasses}, and sets {@code passesStrings} when a String is returned.
 *
 * @param methodInfo description of the Java native method being generated
 */
void returnAfter(MethodInformation methodInfo) {
// NOTE(review): as it appears here, both branches of this conditional are the same one-space string. The two
// choices presumably had different widths originally (extra nesting when the call is wrapped for exceptions);
// verify against upstream before relying on the generated indentation.
String indent = methodInfo.throwsException != null ? " " : " ";
String[] typeName = methodInfo.returnRaw ? new String[] { "" } : cppCastTypeName(methodInfo.returnType, methodInfo.annotations);
Annotation returnBy = by(methodInfo.annotations);
String valueTypeName = valueTypeName(typeName);
AdapterInformation adapterInfo = adapterInformation(false, valueTypeName, methodInfo.annotations);
// Terminator for the native call expression: ";" normally, empty for deallocator methods.
String suffix = methodInfo.deallocator ? "" : ";";
// A non-primitive return wrapped by an adapter needs one more closing parenthesis.
if (!methodInfo.returnType.isPrimitive() && adapterInfo != null) {
suffix = ")" + suffix;
}
// Pointer, primitive-array and Buffer returns:
// - ByVal without an adapter closes one more parenthesis.
// - ByPtrPtr emits the terminator now, then dereferences rptrptr into rptr. A NULL rptrptr is reported as a Java
//   NullPointerException.
if ((Pointer.class.isAssignableFrom(methodInfo.returnType) || (methodInfo.returnType.isArray() && methodInfo.returnType.getComponentType().isPrimitive()) || Buffer.class.isAssignableFrom(methodInfo.returnType))) {
if (returnBy instanceof ByVal && adapterInfo == null) {
suffix = ")" + suffix;
} else if (returnBy instanceof ByPtrPtr) {
out.println(suffix);
suffix = "";
out.println(indent + "if (rptrptr == NULL) {");
out.println(indent + " env->ThrowNew(JavaCPP_getClass(env, " + jclasses.index(NullPointerException.class) + "), \"Return pointer address is NULL.\");");
out.println(indent + "} else {");
out.println(indent + " rptr = *rptrptr;");
out.println(indent + "}");
}
}
// Close the call expression. The suffix is empty if the ByPtrPtr branch already printed it.
out.println(suffix);
if (methodInfo.returnType == void.class) {
// Void methods only need follow-up code when they are allocators (native constructors).
if (methodInfo.allocator || methodInfo.arrayAllocator) {
// Capacity is the requested array length (arg0) for array allocators, otherwise 1.
out.println(indent + "jlong rcapacity = " + (methodInfo.arrayAllocator ? "arg0;" : "1;"));
// No deallocator for Pointer itself or when @NoDeallocator is on the class or the method.
boolean noDeallocator = methodInfo.cls == Pointer.class || methodInfo.cls.isAnnotationPresent(NoDeallocator.class) || methodInfo.method.isAnnotationPresent(NoDeallocator.class);
out.print(indent + "JavaCPP_initPointer(env, obj, rptr, rcapacity, rptr, ");
if (noDeallocator) {
out.println("NULL);");
} else if (methodInfo.arrayAllocator) {
out.println("&JavaCPP_" + mangle(methodInfo.cls.getName()) + "_deallocateArray);");
// Record the class so its _deallocateArray function gets generated.
arrayDeallocators.index(methodInfo.cls);
} else {
out.println("&JavaCPP_" + mangle(methodInfo.cls.getName()) + "_deallocate);");
// Record the class so its _deallocate function gets generated.
deallocators.index(methodInfo.cls);
}
// For classes with virtual functions, rptr points to the generated JavaCPP_ subclass.
// Store a weak global reference to the Java object in that subclass instance.
if (virtualFunctions.containsKey(methodInfo.cls)) {
typeName = cppTypeName(methodInfo.cls);
valueTypeName = valueTypeName(typeName);
String subType = "JavaCPP_" + mangle(valueTypeName);
out.println(indent + "((" + subType + "*)rptr)->obj = env->NewWeakGlobalRef(obj);");
}
}
} else {
// Non-void: convert the native result into the JNI return value 'rarg'.
if (methodInfo.valueSetter || methodInfo.memberSetter || methodInfo.noReturnGetter) {
// nothing
} else if (methodInfo.returnType.isPrimitive()) {
out.println(indent + "rarg = (" + jniTypeName(methodInfo.returnType) + ")rvalue;");
} else if (methodInfo.returnRaw) {
out.println(indent + "rarg = rptr;");
} else {
// needInit: set when rcapacity/rowner/deallocator locals were emitted. The Pointer branch below then calls
// JavaCPP_initPointer instead of only setting the address field.
boolean needInit = false;
if (adapterInfo != null) {
out.println(indent + "rptr = radapter;");
if (methodInfo.returnType != String.class) {
out.println(indent + "jlong rcapacity = (jlong)radapter.size;");
if (Pointer.class.isAssignableFrom(methodInfo.returnType)) {
out.println(indent + "void* rowner = radapter.owner;");
}
// Constant adapters get no deallocator.
out.println(indent + "void (*deallocator)(void*) = " + (adapterInfo.constant ? "NULL;" : "&" + adapterInfo.name + "::deallocate;"));
}
needInit = true;
} else if (returnBy instanceof ByVal || FunctionPointer.class.isAssignableFrom(methodInfo.returnType)) {
// Returns by value and function pointers own a single object that the generated _deallocate function frees.
out.println(indent + "jlong rcapacity = 1;");
out.println(indent + "void* rowner = (void*)rptr;");
out.println(indent + "void (*deallocator)(void*) = &JavaCPP_" + mangle(methodInfo.returnType.getName()) + "_deallocate;");
deallocators.index(methodInfo.returnType);
needInit = true;
}
if (Pointer.class.isAssignableFrom(methodInfo.returnType)) {
out.print(indent);
if (!(returnBy instanceof ByVal)) {
// check if we can reuse one of the Pointer objects from the arguments
if (Modifier.isStatic(methodInfo.modifiers) && methodInfo.parameterTypes.length > 0) {
// Static method: reuse argument i when its type and annotations match the return exactly and the returned
// address equals that argument's address.
for (int i = 0; i < methodInfo.parameterTypes.length; i++) {
String cast = cast(methodInfo, i);
if (Arrays.equals(methodInfo.parameterAnnotations[i], methodInfo.annotations) && methodInfo.parameterTypes[i] == methodInfo.returnType) {
out.println("if (rptr == " + cast + "ptr" + i + ") {");
out.println(indent + " rarg = arg" + i + ";");
out.print(indent + "} else ");
}
}
} else if (!Modifier.isStatic(methodInfo.modifiers) && methodInfo.cls == methodInfo.returnType) {
// Instance method returning its own type: return 'obj' itself when the address is unchanged.
out.println("if (rptr == ptr) {");
out.println(indent + " rarg = obj;");
out.print(indent + "} else ");
}
}
// Otherwise create a new Java Pointer object for a non-NULL result. If the first parameter is a Class, that
// class (arg0) is also passed to JavaCPP_createPointer.
out.println("if (rptr != NULL) {");
out.println(indent + " rarg = JavaCPP_createPointer(env, " + jclasses.index(methodInfo.returnType) + (methodInfo.parameterTypes.length > 0 && methodInfo.parameterTypes[0] == Class.class ? ", arg0);" : ");"));
out.println(indent + " if (rarg != NULL) {");
if (needInit) {
out.println(indent + " JavaCPP_initPointer(env, rarg, rptr, rcapacity, rowner, deallocator);");
} else {
out.println(indent + " env->SetLongField(rarg, JavaCPP_addressFID, ptr_to_jlong(rptr));");
}
out.println(indent + " }");
out.println(indent + "}");
} else if (methodInfo.returnType == String.class) {
passesStrings = true;
out.println(indent + "if (rptr != NULL) {");
out.println(indent + " rarg = JavaCPP_createString(env, rptr);");
out.println(indent + "}");
} else if (methodInfo.returnType.isArray() && methodInfo.returnType.getComponentType().isPrimitive()) {
// Primitive arrays: the result is copied into a new Java array of at most INT_MAX elements.
if (adapterInfo == null && !(returnBy instanceof ByVal)) {
out.println(indent + "jlong rcapacity = rptr != NULL ? 1 : 0;");
}
String componentName = methodInfo.returnType.getComponentType().getName();
String componentNameUpperCase = Character.toUpperCase(componentName.charAt(0)) + componentName.substring(1);
out.println(indent + "if (rptr != NULL) {");
out.println(indent + " rarg = env->New" + componentNameUpperCase + "Array(rcapacity < INT_MAX ? rcapacity : INT_MAX);");
out.println(indent + " env->Set" + componentNameUpperCase + "ArrayRegion(rarg, 0, rcapacity < INT_MAX ? rcapacity : INT_MAX, (j" + componentName + "*)rptr);");
out.println(indent + "}");
// Adapter-provided memory has been copied, so release it with the adapter's deallocator.
if (adapterInfo != null) {
out.println(indent + "if (deallocator != 0 && rptr != NULL) {");
out.println(indent + " (*(void(*)(void*))jlong_to_ptr(deallocator))((void*)rptr);");
out.println(indent + "}");
}
} else if (Buffer.class.isAssignableFrom(methodInfo.returnType)) {
// Buffers wrap the native memory in a direct ByteBuffer of rcapacity * element-size bytes, capped at INT_MAX.
if (methodInfo.bufferGetter) {
out.println(indent + "jlong rcapacity = size;");
} else if (adapterInfo == null && !(returnBy instanceof ByVal)) {
out.println(indent + "jlong rcapacity = rptr != NULL ? 1 : 0;");
}
out.println(indent + "if (rptr != NULL) {");
out.println(indent + " jlong rcapacityptr = rcapacity * sizeof(rptr[0]);");
out.println(indent + " rarg = env->NewDirectByteBuffer((void*)rptr, rcapacityptr < INT_MAX ? rcapacityptr : INT_MAX);");
out.println(indent + "}");
}
}
}
}
Use of org.bytedeco.javacpp.Pointer in project bigbluebutton (by bigbluebutton).
Class Generator, method callback.
void callback(Class<?> cls, Method callbackMethod, String callbackName, boolean needDefinition, MethodInformation methodInfo) {
Class<?> callbackReturnType = callbackMethod.getReturnType();
Class<?>[] callbackParameterTypes = callbackMethod.getParameterTypes();
Annotation[] callbackAnnotations = callbackMethod.getAnnotations();
Annotation[][] callbackParameterAnnotations = callbackMethod.getParameterAnnotations();
String instanceTypeName = functionClassName(cls);
String[] callbackTypeName = cppFunctionTypeName(callbackMethod);
String[] returnConvention = callbackTypeName[0].split("\\(");
returnConvention[1] = constValueTypeName(returnConvention[1]);
String parameterDeclaration = callbackTypeName[1].substring(1);
String fieldName = mangle(callbackMethod.getName()) + "__" + mangle(signature(callbackMethod.getParameterTypes()));
String firstLine = "";
if (methodInfo != null) {
// stuff from a virtualized class
String nonconstParamDeclaration = parameterDeclaration.endsWith(" const") ? parameterDeclaration.substring(0, parameterDeclaration.length() - 6) : parameterDeclaration;
String[] typeName = methodInfo.returnRaw ? new String[] { "" } : cppTypeName(methodInfo.cls);
String valueTypeName = valueTypeName(typeName);
String subType = "JavaCPP_" + mangle(valueTypeName);
Set<String> memberList = virtualMembers.get(cls);
if (memberList == null) {
virtualMembers.put(cls, memberList = new LinkedHashSet<String>());
}
String member = " ";
if (methodInfo.arrayAllocator) {
return;
} else if (methodInfo.allocator) {
member += subType + nonconstParamDeclaration + " : " + valueTypeName + "(";
for (int j = 0; j < callbackParameterTypes.length; j++) {
member += "arg" + j;
if (j < callbackParameterTypes.length - 1) {
member += ", ";
}
}
member += "), obj(NULL) { }";
} else {
Set<String> functionList = virtualFunctions.get(cls);
if (functionList == null) {
virtualFunctions.put(cls, functionList = new LinkedHashSet<String>());
}
member += "using " + valueTypeName + "::" + methodInfo.memberName[0] + ";\n " + "virtual " + returnConvention[0] + (returnConvention.length > 1 ? returnConvention[1] : "") + methodInfo.memberName[0] + parameterDeclaration + ";\n " + returnConvention[0] + "super_" + methodInfo.memberName[0] + nonconstParamDeclaration + " { ";
if (methodInfo.method.getAnnotation(Virtual.class).value()) {
member += "throw JavaCPP_exception(\"Cannot call a pure virtual function.\"); }";
} else {
member += (callbackReturnType != void.class ? "return " : "") + valueTypeName + "::" + methodInfo.memberName[0] + "(";
for (int j = 0; j < callbackParameterTypes.length; j++) {
member += "arg" + j;
if (j < callbackParameterTypes.length - 1) {
member += ", ";
}
}
member += "); }";
}
firstLine = returnConvention[0] + (returnConvention.length > 1 ? returnConvention[1] : "") + subType + "::" + methodInfo.memberName[0] + parameterDeclaration + " {";
functionList.add(fieldName);
}
memberList.add(member);
} else if (callbackName != null) {
callbacks.index("static " + instanceTypeName + " " + callbackName + "_instance;");
Convention convention = cls.getAnnotation(Convention.class);
if (convention != null && !convention.extern().equals("C")) {
out.println("extern \"" + convention.extern() + "\" {");
if (out2 != null) {
out2.println("extern \"" + convention.extern() + "\" {");
}
}
if (out2 != null) {
out2.println("JNIIMPORT " + returnConvention[0] + (returnConvention.length > 1 ? returnConvention[1] : "") + callbackName + parameterDeclaration + ";");
}
out.println("JNIEXPORT " + returnConvention[0] + (returnConvention.length > 1 ? returnConvention[1] : "") + callbackName + parameterDeclaration + " {");
out.print((callbackReturnType != void.class ? " return " : " ") + callbackName + "_instance(");
for (int j = 0; j < callbackParameterTypes.length; j++) {
out.print("arg" + j);
if (j < callbackParameterTypes.length - 1) {
out.print(", ");
}
}
out.println(");");
out.println("}");
if (convention != null && !convention.extern().equals("C")) {
out.println("}");
if (out2 != null) {
out2.println("}");
}
}
firstLine = returnConvention[0] + instanceTypeName + "::operator()" + parameterDeclaration + " {";
}
if (!needDefinition) {
return;
}
out.println(firstLine);
String returnPrefix = "";
if (callbackReturnType != void.class) {
out.println(" " + jniTypeName(callbackReturnType) + " rarg = 0;");
returnPrefix = "rarg = ";
if (callbackReturnType == String.class) {
returnPrefix += "(jstring)";
}
}
String callbackReturnCast = cast(callbackReturnType, callbackAnnotations);
Annotation returnBy = by(callbackAnnotations);
String[] returnTypeName = cppTypeName(callbackReturnType);
String returnValueTypeName = valueTypeName(returnTypeName);
AdapterInformation returnAdapterInfo = adapterInformation(false, returnValueTypeName, callbackAnnotations);
out.println(" jthrowable exc = NULL;");
out.println(" JNIEnv* env;");
out.println(" bool attached = JavaCPP_getEnv(&env);");
out.println(" if (env == NULL) {");
out.println(" goto end;");
out.println(" }");
out.println("{");
if (callbackParameterTypes.length > 0) {
out.println(" jvalue args[" + callbackParameterTypes.length + "];");
for (int j = 0; j < callbackParameterTypes.length; j++) {
if (callbackParameterTypes[j].isPrimitive()) {
out.println(" args[" + j + "]." + signature(callbackParameterTypes[j]).toLowerCase() + " = (" + jniTypeName(callbackParameterTypes[j]) + ")arg" + j + ";");
} else {
Annotation passBy = by(callbackParameterAnnotations[j]);
String[] typeName = cppTypeName(callbackParameterTypes[j]);
String valueTypeName = valueTypeName(typeName);
AdapterInformation adapterInfo = adapterInformation(false, valueTypeName, callbackParameterAnnotations[j]);
if (adapterInfo != null) {
usesAdapters = true;
out.println(" " + adapterInfo.name + " adapter" + j + "(arg" + j + ");");
}
if (Pointer.class.isAssignableFrom(callbackParameterTypes[j]) || Buffer.class.isAssignableFrom(callbackParameterTypes[j]) || (callbackParameterTypes[j].isArray() && callbackParameterTypes[j].getComponentType().isPrimitive())) {
String cast = "(" + typeName[0] + typeName[1] + ")";
if (FunctionPointer.class.isAssignableFrom(callbackParameterTypes[j])) {
functions.index(callbackParameterTypes[j]);
typeName[0] = functionClassName(callbackParameterTypes[j]) + "*";
typeName[1] = "";
valueTypeName = valueTypeName(typeName);
} else if (virtualFunctions.containsKey(callbackParameterTypes[j])) {
String subType = "JavaCPP_" + mangle(valueTypeName);
valueTypeName = subType;
}
out.println(" " + jniTypeName(callbackParameterTypes[j]) + " obj" + j + " = NULL;");
out.println(" " + typeName[0] + " ptr" + j + typeName[1] + " = NULL;");
if (FunctionPointer.class.isAssignableFrom(callbackParameterTypes[j])) {
out.println(" ptr" + j + " = new (std::nothrow) " + valueTypeName + ";");
out.println(" if (ptr" + j + " != NULL) {");
out.println(" ptr" + j + "->ptr = " + cast + "&arg" + j + ";");
out.println(" }");
} else if (adapterInfo != null) {
out.println(" ptr" + j + " = adapter" + j + ";");
} else if (passBy instanceof ByVal && callbackParameterTypes[j] != Pointer.class) {
out.println(" ptr" + j + (noException(callbackParameterTypes[j], callbackMethod) ? " = new (std::nothrow) " : " = new ") + valueTypeName + typeName[1] + "(*" + cast + "&arg" + j + ");");
} else if (passBy instanceof ByVal || passBy instanceof ByRef) {
out.println(" ptr" + j + " = " + cast + "&arg" + j + ";");
} else if (passBy instanceof ByPtrPtr) {
out.println(" if (arg" + j + " == NULL) {");
out.println(" JavaCPP_log(\"Pointer address of argument " + j + " is NULL in callback for " + cls.getCanonicalName() + ".\");");
out.println(" } else {");
out.println(" ptr" + j + " = " + cast + "*arg" + j + ";");
out.println(" }");
} else {
// ByPtr || ByPtrRef
out.println(" ptr" + j + " = " + cast + "arg" + j + ";");
}
}
boolean needInit = false;
if (adapterInfo != null) {
if (callbackParameterTypes[j] != String.class) {
out.println(" jlong size" + j + " = (jlong)adapter" + j + ".size;");
out.println(" void* owner" + j + " = adapter" + j + ".owner;");
out.println(" void (*deallocator" + j + ")(void*) = &" + adapterInfo.name + "::deallocate;");
}
needInit = true;
} else if ((passBy instanceof ByVal && callbackParameterTypes[j] != Pointer.class) || FunctionPointer.class.isAssignableFrom(callbackParameterTypes[j])) {
out.println(" jlong size" + j + " = 1;");
out.println(" void* owner" + j + " = ptr" + j + ";");
out.println(" void (*deallocator" + j + ")(void*) = &JavaCPP_" + mangle(callbackParameterTypes[j].getName()) + "_deallocate;");
deallocators.index(callbackParameterTypes[j]);
needInit = true;
}
if (Pointer.class.isAssignableFrom(callbackParameterTypes[j])) {
String s = " obj" + j + " = JavaCPP_createPointer(env, " + jclasses.index(callbackParameterTypes[j]) + ");";
adapterInfo = adapterInformation(true, valueTypeName, callbackParameterAnnotations[j]);
if (adapterInfo != null || passBy instanceof ByPtrPtr || passBy instanceof ByPtrRef) {
out.println(s);
} else {
out.println(" if (ptr" + j + " != NULL) { ");
out.println(" " + s);
out.println(" }");
}
out.println(" if (obj" + j + " != NULL) { ");
if (needInit) {
out.println(" JavaCPP_initPointer(env, obj" + j + ", ptr" + j + ", size" + j + ", owner" + j + ", deallocator" + j + ");");
} else {
out.println(" env->SetLongField(obj" + j + ", JavaCPP_addressFID, ptr_to_jlong(ptr" + j + "));");
}
out.println(" }");
out.println(" args[" + j + "].l = obj" + j + ";");
} else if (callbackParameterTypes[j] == String.class) {
passesStrings = true;
out.println(" jstring obj" + j + " = JavaCPP_createString(env, (const char*)" + (adapterInfo != null ? "adapter" : "arg") + j + ");");
out.println(" args[" + j + "].l = obj" + j + ";");
} else if (callbackParameterTypes[j].isArray() && callbackParameterTypes[j].getComponentType().isPrimitive()) {
if (adapterInfo == null) {
out.println(" jlong size" + j + " = ptr" + j + " != NULL ? 1 : 0;");
}
String componentType = callbackParameterTypes[j].getComponentType().getName();
String S = Character.toUpperCase(componentType.charAt(0)) + componentType.substring(1);
out.println(" if (ptr" + j + " != NULL) {");
out.println(" obj" + j + " = env->New" + S + "Array(size" + j + " < INT_MAX ? size" + j + " : INT_MAX);");
out.println(" env->Set" + S + "ArrayRegion(obj" + j + ", 0, size" + j + " < INT_MAX ? size" + j + " : INT_MAX, (j" + componentType + "*)ptr" + j + ");");
out.println(" }");
if (adapterInfo != null) {
out.println(" if (deallocator" + j + " != 0 && ptr" + j + " != NULL) {");
out.println(" (*(void(*)(void*))jlong_to_ptr(deallocator" + j + "))((void*)ptr" + j + ");");
out.println(" }");
}
} else if (Buffer.class.isAssignableFrom(callbackParameterTypes[j])) {
if (adapterInfo == null) {
out.println(" jlong size" + j + " = ptr" + j + " != NULL ? 1 : 0;");
}
out.println(" if (ptr" + j + " != NULL) {");
out.println(" jlong sizeptr = size" + j + " * sizeof(ptr" + j + "[0]);");
out.println(" obj" + j + " = env->NewDirectByteBuffer((void*)ptr" + j + ", sizeptr < INT_MAX ? sizeptr : INT_MAX);");
out.println(" }");
} else {
logger.warn("Callback \"" + callbackMethod + "\" has unsupported parameter type \"" + callbackParameterTypes[j].getCanonicalName() + "\". Compilation will most likely fail.");
}
}
}
}
if (methodInfo != null) {
out.println(" if (" + fieldName + " == NULL) {");
out.println(" " + fieldName + " = JavaCPP_getMethodID(env, " + jclasses.index(cls) + ", \"" + methodInfo.method.getName() + "\", \"(" + signature(methodInfo.method.getParameterTypes()) + ")" + signature(methodInfo.method.getReturnType()) + "\");");
out.println(" }");
out.println(" jmethodID mid = " + fieldName + ";");
} else if (callbackName != null) {
out.println(" if (obj == NULL) {");
out.println(" obj = JavaCPP_createPointer(env, " + jclasses.index(cls) + ");");
out.println(" obj = obj == NULL ? NULL : env->NewGlobalRef(obj);");
out.println(" if (obj == NULL) {");
out.println(" JavaCPP_log(\"Error creating global reference of " + cls.getCanonicalName() + " instance for callback.\");");
out.println(" } else {");
out.println(" env->SetLongField(obj, JavaCPP_addressFID, ptr_to_jlong(this));");
out.println(" }");
out.println(" ptr = &" + callbackName + ";");
out.println(" }");
out.println(" if (mid == NULL) {");
out.println(" mid = JavaCPP_getMethodID(env, " + jclasses.index(cls) + ", \"" + callbackMethod.getName() + "\", \"(" + signature(callbackMethod.getParameterTypes()) + ")" + signature(callbackMethod.getReturnType()) + "\");");
out.println(" }");
}
out.println(" if (env->IsSameObject(obj, NULL)) {");
out.println(" JavaCPP_log(\"Function pointer object is NULL in callback for " + cls.getCanonicalName() + ".\");");
out.println(" } else if (mid == NULL) {");
out.println(" JavaCPP_log(\"Error getting method ID of function caller \\\"" + callbackMethod + "\\\" for callback.\");");
out.println(" } else {");
String s = "Object";
if (callbackReturnType.isPrimitive()) {
s = callbackReturnType.getName();
s = Character.toUpperCase(s.charAt(0)) + s.substring(1);
}
out.println(" " + returnPrefix + "env->Call" + s + "MethodA(obj, mid, " + (callbackParameterTypes.length == 0 ? "NULL);" : "args);"));
out.println(" if ((exc = env->ExceptionOccurred()) != NULL) {");
out.println(" env->ExceptionClear();");
out.println(" }");
out.println(" }");
for (int j = 0; j < callbackParameterTypes.length; j++) {
if (Pointer.class.isAssignableFrom(callbackParameterTypes[j])) {
String[] typeName = cppTypeName(callbackParameterTypes[j]);
Annotation passBy = by(callbackParameterAnnotations[j]);
String cast = cast(callbackParameterTypes[j], callbackParameterAnnotations[j]);
String valueTypeName = valueTypeName(typeName);
AdapterInformation adapterInfo = adapterInformation(true, valueTypeName, callbackParameterAnnotations[j]);
if ("void*".equals(typeName[0]) && !callbackParameterTypes[j].isAnnotationPresent(Opaque.class)) {
typeName[0] = "char*";
}
if (adapterInfo != null || passBy instanceof ByPtrPtr || passBy instanceof ByPtrRef) {
out.println(" " + typeName[0] + " rptr" + j + typeName[1] + " = (" + typeName[0] + typeName[1] + ")jlong_to_ptr(env->GetLongField(obj" + j + ", JavaCPP_addressFID));");
if (adapterInfo != null) {
out.println(" jlong rsize" + j + " = env->GetLongField(obj" + j + ", JavaCPP_limitFID);");
out.println(" void* rowner" + j + " = JavaCPP_getPointerOwner(env, obj" + j + ");");
}
if (!callbackParameterTypes[j].isAnnotationPresent(Opaque.class)) {
out.println(" jlong rposition" + j + " = env->GetLongField(obj" + j + ", JavaCPP_positionFID);");
out.println(" rptr" + j + " += rposition" + j + ";");
if (adapterInfo != null) {
out.println(" rsize" + j + " -= rposition" + j + ";");
}
}
if (adapterInfo != null) {
out.println(" adapter" + j + ".assign(rptr" + j + ", rsize" + j + ", rowner" + j + ");");
} else if (passBy instanceof ByPtrPtr) {
out.println(" if (arg" + j + " != NULL) {");
out.println(" *arg" + j + " = *" + cast + "&rptr" + j + ";");
out.println(" }");
} else if (passBy instanceof ByPtrRef) {
out.println(" arg" + j + " = " + cast + "rptr" + j + ";");
}
}
}
if (!callbackParameterTypes[j].isPrimitive()) {
out.println(" env->DeleteLocalRef(obj" + j + ");");
}
}
out.println("}");
out.println("end:");
if (callbackReturnType != void.class) {
if ("void*".equals(returnTypeName[0]) && !callbackReturnType.isAnnotationPresent(Opaque.class)) {
returnTypeName[0] = "char*";
}
if (Pointer.class.isAssignableFrom(callbackReturnType)) {
out.println(" " + returnTypeName[0] + " rptr" + returnTypeName[1] + " = rarg == NULL ? NULL : (" + returnTypeName[0] + returnTypeName[1] + ")jlong_to_ptr(env->GetLongField(rarg, JavaCPP_addressFID));");
if (returnAdapterInfo != null) {
out.println(" jlong rsize = rarg == NULL ? 0 : env->GetLongField(rarg, JavaCPP_limitFID);");
out.println(" void* rowner = JavaCPP_getPointerOwner(env, rarg);");
}
if (!callbackReturnType.isAnnotationPresent(Opaque.class)) {
out.println(" jlong rposition = rarg == NULL ? 0 : env->GetLongField(rarg, JavaCPP_positionFID);");
out.println(" rptr += rposition;");
if (returnAdapterInfo != null) {
out.println(" rsize -= rposition;");
}
}
} else if (callbackReturnType == String.class) {
passesStrings = true;
out.println(" " + returnTypeName[0] + " rptr" + returnTypeName[1] + " = JavaCPP_getStringBytes(env, rarg);");
if (returnAdapterInfo != null) {
out.println(" jlong rsize = 0;");
out.println(" void* rowner = (void*)rptr");
}
} else if (Buffer.class.isAssignableFrom(callbackReturnType)) {
out.println(" " + returnTypeName[0] + " rptr" + returnTypeName[1] + " = rarg == NULL ? NULL : env->GetDirectBufferAddress(rarg);");
if (returnAdapterInfo != null) {
out.println(" jlong rsize = rarg == NULL ? 0 : env->GetDirectBufferCapacity(rarg);");
out.println(" void* rowner = (void*)rptr;");
}
} else if (!callbackReturnType.isPrimitive()) {
logger.warn("Callback \"" + callbackMethod + "\" has unsupported return type \"" + callbackReturnType.getCanonicalName() + "\". Compilation will most likely fail.");
}
}
out.println(" if (exc != NULL) {");
out.println(" jstring str = (jstring)env->CallObjectMethod(exc, JavaCPP_toStringMID);");
out.println(" env->DeleteLocalRef(exc);");
out.println(" const char *msg = JavaCPP_getStringBytes(env, str);");
out.println(" JavaCPP_exception e(msg);");
out.println(" JavaCPP_releaseStringBytes(env, str, msg);");
out.println(" env->DeleteLocalRef(str);");
out.println(" JavaCPP_detach(attached);");
out.println(" throw e;");
out.println(" } else {");
out.println(" JavaCPP_detach(attached);");
out.println(" }");
if (callbackReturnType != void.class) {
if (callbackReturnType.isPrimitive()) {
out.println(" return " + callbackReturnCast + "rarg;");
} else if (returnAdapterInfo != null) {
usesAdapters = true;
out.println(" return " + returnAdapterInfo.name + "(" + callbackReturnCast + "rptr, rsize, rowner);");
} else if (FunctionPointer.class.isAssignableFrom(callbackReturnType)) {
functions.index(callbackReturnType);
out.println(" return " + callbackReturnCast + "(rptr == NULL ? NULL : rptr->ptr);");
} else if (returnBy instanceof ByVal || returnBy instanceof ByRef) {
out.println(" if (rptr == NULL) {");
out.println(" JavaCPP_log(\"Return pointer address is NULL in callback for " + cls.getCanonicalName() + ".\");");
out.println(" static " + returnConvention[0] + " empty" + returnTypeName[1] + ";");
out.println(" return empty;");
out.println(" } else {");
out.println(" return *" + callbackReturnCast + "rptr;");
out.println(" }");
} else if (returnBy instanceof ByPtrPtr) {
out.println(" return " + callbackReturnCast + "&rptr;");
} else {
// ByPtr || ByPtrRef
out.println(" return " + callbackReturnCast + "rptr;");
}
}
out.println("}");
}
Aggregations