Usage of `com.simiacryptus.mindseye.lang.ReferenceCountingBase` in the MindsEye project by SimiaCryptus.
Example: method `testNonstandardBounds` of class `CudaLayerTester`.
/**
 * Test nonstandard bounds tolerance statistics.
 *
 * <p>Builds randomized copies of {@code inputPrototype} and evaluates {@code reference} on the GPU twice:
 * <ul>
 *   <li>once with inputs uploaded through {@code gpu.getTensor(..., MemoryType.Managed, ...)} and
 *       wrapped as {@link CudaTensorList}s (the control);</li>
 *   <li>once with inputs produced by {@code buildIrregularCudaTensor}, which the logged description
 *       calls "non-dense" inputs.</li>
 * </ul>
 * The outputs and the derivatives of the two runs are then compared.
 *
 * <p>Every reference-counted intermediate (randomized tensors, both input sets, and both results)
 * is released in its own {@code finally}. Nothing leaks if a later step throws.
 *
 * @param log            the log
 * @param reference      the reference layer under test; NOTE(review): annotated {@code @Nullable}
 *                       but passed straight to {@code SimpleGpuEval.run} — confirm callers never pass null
 * @param inputPrototype the input prototype whose shapes are used for the randomized inputs
 * @return the derivative comparison statistics combined with the output comparison statistics
 */
@Nonnull
public ToleranceStatistics testNonstandardBounds(final NotebookOutput log, @Nullable final Layer reference, @Nonnull final Tensor[] inputPrototype) {
  log.h2("Irregular Input");
  log.p("This layer should be able to accept non-dense inputs.");
  return log.code(() -> {
    final Tensor[] randomized = Arrays.stream(inputPrototype).map(x -> x.map(v -> getRandom())).toArray(Tensor[]::new);
    try {
      logger.info("Input: " + Arrays.stream(randomized).map(Tensor::prettyPrint).collect(Collectors.toList()));
      final Precision precision = Precision.Double;
      // Control inputs: upload each randomized tensor and wrap it as a CudaTensorList.
      final TensorList[] controlInput = CudaSystem.run(gpu -> {
        return Arrays.stream(randomized).map(original -> {
          final TensorArray data = TensorArray.create(original);
          try {
            return CudaTensorList.wrap(gpu.getTensor(data, precision, MemoryType.Managed, false), 1, original.getDimensions(), precision);
          } finally {
            // Runs after wrap(...) completes, matching the original release order; also runs if upload fails.
            data.freeRef();
          }
        }).toArray(TensorList[]::new);
      }, 0);
      try {
        @Nonnull final SimpleResult controlResult = CudaSystem.run(gpu -> {
          return SimpleGpuEval.run(reference, gpu, controlInput);
        }, 1);
        try {
          // Test inputs: same values, but built through the irregular (non-dense) path.
          final TensorList[] irregularInput = CudaSystem.run(gpu -> {
            return Arrays.stream(randomized)
                .map(original -> buildIrregularCudaTensor(gpu, precision, original))
                .toArray(TensorList[]::new);
          }, 0);
          try {
            @Nonnull final SimpleResult testResult = CudaSystem.run(gpu -> {
              return SimpleGpuEval.run(reference, gpu, irregularInput);
            }, 1);
            try {
              final ToleranceStatistics compareOutput = compareOutput(controlResult, testResult);
              final ToleranceStatistics compareDerivatives = compareDerivatives(controlResult, testResult);
              return compareDerivatives.combine(compareOutput);
            } finally {
              testResult.freeRef();
            }
          } finally {
            Arrays.stream(irregularInput).forEach(TensorList::freeRef);
          }
        } finally {
          controlResult.freeRef();
        }
      } finally {
        Arrays.stream(controlInput).forEach(TensorList::freeRef);
      }
    } finally {
      Arrays.stream(randomized).forEach(Tensor::freeRef);
    }
  });
}
Aggregations