Usage of com.simiacryptus.mindseye.test.SimpleGpuEval in the MindsEye project by SimiaCryptus.
Example: the testNonstandardBoundsBackprop method of the CudaLayerTester class.
/**
 * Checks that {@code layer} accepts a non-dense ("irregular") CUDA tensor as the backpropagated
 * delta, and that doing so does not change its outputs or derivatives.
 *
 * <p>Procedure:
 * <ol>
 *   <li>Writes an "Irregular Backprop" heading and explanation to {@code log}.</li>
 *   <li>Builds a randomized copy of every tensor in {@code inputPrototype} (each value replaced via
 *       {@code getRandom()}).</li>
 *   <li>Evaluates the layer on the GPU with {@link SimpleGpuEval}. Its {@code getFeedback} is
 *       overridden to return an irregular CUDA tensor of ones, built with
 *       {@code buildIrregularCudaTensor} at {@link Precision#Double}.</li>
 *   <li>Evaluates the layer again via {@code SimpleGpuEval.run}, with whatever feedback that
 *       helper supplies by default. This is presumably a regular dense delta; confirm in
 *       {@code SimpleGpuEval}.</li>
 *   <li>Compares the outputs and the derivatives of the two runs and returns the combined
 *       statistics.</li>
 * </ol>
 *
 * <p>All reference-counted intermediates (input lists, per-run copies, feedback tensor and both
 * results) are released in {@code finally} blocks. This holds even when a GPU run or a comparison
 * throws.
 *
 * @param log            the notebook the section and code output are written to
 * @param layer          the layer under test
 * @param inputPrototype prototype tensors whose shapes define the randomized inputs
 * @return the derivative tolerance statistics combined with the output tolerance statistics
 */
@Nonnull
public ToleranceStatistics testNonstandardBoundsBackprop(final NotebookOutput log, @Nullable final Layer layer, @Nonnull final Tensor[] inputPrototype) {
  log.h2("Irregular Backprop");
  log.p("This layer should accept non-dense tensors as delta input.");
  return log.code(() -> {
    Tensor[] randomized = Arrays.stream(inputPrototype).map(x -> x.map(v -> getRandom())).toArray(i -> new Tensor[i]);
    logger.info("Input: " + Arrays.stream(randomized).map(Tensor::prettyPrint).collect(Collectors.toList()));
    Precision precision = Precision.Double;
    // The randomized tensors are not freed separately. Presumably TensorArray.wrap hands their
    // references to the lists, so freeing controlInput releases them — TODO confirm.
    TensorList[] controlInput = Arrays.stream(randomized).map(original -> {
      return TensorArray.wrap(original);
    }).toArray(i -> new TensorList[i]);
    SimpleResult testResult = null;
    SimpleResult controlResult = null;
    // This try covers both GPU runs, so controlInput and any already-produced result are released
    // even if a later run or comparison throws.
    try {
      testResult = CudaSystem.run(gpu -> {
        TensorList[] copy = copy(controlInput);
        try {
          return new SimpleGpuEval(layer, gpu, copy) {
            @Nonnull
            @Override
            public TensorList getFeedback(@Nonnull final TensorList original) {
              // NOTE(review): only the first element of `original` shapes the delta; confirm this
              // is intended for batch sizes > 1.
              Tensor originalTensor = original.get(0).mapAndFree(x -> 1);
              try {
                return buildIrregularCudaTensor(gpu, precision, originalTensor);
              } finally {
                originalTensor.freeRef();
              }
            }
          }.call();
        } finally {
          Arrays.stream(copy).forEach(ReferenceCounting::freeRef);
        }
      });
      // The extra argument (1) is passed through to CudaSystem.run as in the original; its meaning
      // (device hint?) is defined by CudaSystem — TODO confirm.
      controlResult = CudaSystem.run(gpu -> {
        TensorList[] copy = copy(controlInput);
        try {
          return SimpleGpuEval.run(layer, gpu, copy);
        } finally {
          Arrays.stream(copy).forEach(ReferenceCounting::freeRef);
        }
      }, 1);
      ToleranceStatistics compareOutput = compareOutput(controlResult, testResult);
      ToleranceStatistics compareDerivatives = compareDerivatives(controlResult, testResult);
      return compareDerivatives.combine(compareOutput);
    } finally {
      Arrays.stream(controlInput).forEach(ReferenceCounting::freeRef);
      if (null != controlResult) controlResult.freeRef();
      if (null != testResult) testResult.freeRef();
    }
  });
}
Aggregations