Use of `com.simiacryptus.mindseye.test.ToleranceStatistics` in the project MindsEye by SimiaCryptus.
From the class `CudaLayerTestBase`, method `getPerformanceTester`:
/**
 * Builds the performance test used for CUDA layers.
 *
 * <p>The returned test delegates to a {@link PerformanceTester} configured with
 * {@code testingBatchSize}. While the delegate runs, a file named
 * {@code "cuda_" + log.getName() + "_perf.log"} is linked from the notebook ("GPU Log"),
 * opened as a {@link PrintStream}, and registered via {@link CudaSystem#addLog}.
 * NOTE(review): presumably this captures CUDA API logging for the run — confirm against CudaSystem.
 * The stream is deregistered from {@code CudaSystem.apiLog} and then closed once the delegate returns
 * or throws.
 *
 * @return a component test wrapping the performance tester
 */
@Nullable
@Override
public ComponentTest<ToleranceStatistics> getPerformanceTester() {
  @Nullable ComponentTest<ToleranceStatistics> inner = new PerformanceTester().setBatches(testingBatchSize);
  return new ComponentTestBase<ToleranceStatistics>() {
    /** Releases the wrapped performance tester (via {@code freeRef}) before freeing this test. */
    @Override
    protected void _free() {
      inner.freeRef();
      super._free();
    }

    /**
     * Runs the wrapped performance tester while a per-notebook GPU log stream is registered.
     *
     * @param log notebook the log link and log file are written to
     * @param component layer under test
     * @param inputPrototype prototype inputs passed through to the wrapped tester
     * @return the wrapped tester's result
     */
    @Override
    public ToleranceStatistics test(@Nonnull NotebookOutput log, Layer component, Tensor... inputPrototype) {
      @Nonnull String logName = "cuda_" + log.getName() + "_perf.log";
      log.p(log.file((String) null, logName, "GPU Log"));
      try (PrintStream apiLog = new PrintStream(log.file(logName))) {
        try {
          CudaSystem.addLog(apiLog);
          return inner.test(log, component, inputPrototype);
        } finally {
          // Deregister BEFORE try-with-resources closes the stream, so CudaSystem never
          // holds (and writes to) a stream that has already been closed.
          CudaSystem.apiLog.remove(apiLog);
        }
      }
    }
  };
}
Aggregations