
Example 21 with ToleranceStatistics

Use of com.simiacryptus.mindseye.test.ToleranceStatistics in project MindsEye by SimiaCryptus.

From class CudaLayerTestBase, method getPerformanceTester.

@Nullable
@Override
public ComponentTest<ToleranceStatistics> getPerformanceTester() {
    @Nullable ComponentTest<ToleranceStatistics> inner = new PerformanceTester().setBatches(testingBatchSize);
    return new ComponentTestBase<ToleranceStatistics>() {

        @Override
        protected void _free() {
            inner.freeRef();
            super._free();
        }

        @Override
        public ToleranceStatistics test(@Nonnull NotebookOutput log, Layer component, Tensor... inputPrototype) {
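            @Nullable PrintStream apiLog = null;
            try {
                @Nonnull String logName = "cuda_" + log.getName() + "_perf.log";
                // Link the per-test GPU log file from the notebook page
                log.p(log.file((String) null, logName, "GPU Log"));
                // Open the log file and register it as an additional CUDA API log sink
                apiLog = new PrintStream(log.file(logName));
                CudaSystem.addLog(apiLog);
                // Delegate the actual measurement to the wrapped PerformanceTester
                return inner.test(log, component, inputPrototype);
            } finally {
                if (null != apiLog) {
                    // Unregister the sink before closing it, so no CUDA call writes to a closed stream
                    CudaSystem.apiLog.remove(apiLog);
                    apiLog.close();
                }
            }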
        }
    };
}
Also used: ComponentTestBase(com.simiacryptus.mindseye.test.unit.ComponentTestBase) PrintStream(java.io.PrintStream) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) PerformanceTester(com.simiacryptus.mindseye.test.unit.PerformanceTester) Layer(com.simiacryptus.mindseye.lang.Layer) Nullable(javax.annotation.Nullable)
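The wrapper above follows a common pattern: register an extra log sink, delegate to the inner tester, and unregister the sink in a finally block so it never outlives the test. Below is a minimal, self-contained sketch of that pattern, assuming a global set of sinks similar to CudaSystem.apiLog. The names LogTeeDemo, apiLogs, apiCall, and withApiLog are hypothetical illustrations, not MindsEye APIs.

```java
import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.function.Supplier;

class LogTeeDemo {
    // Hypothetical stand-in for a global registry of API log sinks (like CudaSystem.apiLog).
    static final Set<PrintStream> apiLogs = new CopyOnWriteArraySet<>();

    // Hypothetical stand-in for a GPU API call that reports to every registered sink.
    static void apiCall(String name) {
        for (PrintStream sink : apiLogs) {
            sink.println("call " + name);
        }
    }

    // Runs body with sink registered as an extra log; always unregisters and closes it afterwards.
    static <T> T withApiLog(PrintStream sink, Supplier<T> body) {
        apiLogs.add(sink);
        try {
            return body.get();
        } finally {
            apiLogs.remove(sink); // unregister first so no caller writes to a closed stream
            sink.close();
        }
    }

    public static void main(String[] args) {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        String result = withApiLog(new PrintStream(buffer, true), () -> {
            apiCall("forward");
            return "done";
        });
        apiCall("afterwards"); // not captured: the sink has already been removed
        System.out.println(result);
        System.out.print(buffer.toString());
        System.out.println(apiLogs.isEmpty());
    }
}
```

Running main prints "done", then "call forward", then "true": only the call made inside withApiLog reaches the captured log. A concurrent set is used here because, as with a real GPU runtime, sinks may be iterated from other threads while a test registers or removes its own.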

Aggregations

ToleranceStatistics (com.simiacryptus.mindseye.test.ToleranceStatistics): 21
Nonnull (javax.annotation.Nonnull): 20
Layer (com.simiacryptus.mindseye.lang.Layer): 19
Tensor (com.simiacryptus.mindseye.lang.Tensor): 19
NotebookOutput (com.simiacryptus.util.io.NotebookOutput): 19
Nullable (javax.annotation.Nullable): 19
Arrays (java.util.Arrays): 16
IntStream (java.util.stream.IntStream): 14
Logger (org.slf4j.Logger): 14
LoggerFactory (org.slf4j.LoggerFactory): 14
TensorArray (com.simiacryptus.mindseye.lang.TensorArray): 13
TensorList (com.simiacryptus.mindseye.lang.TensorList): 13
Collectors (java.util.stream.Collectors): 13
SimpleEval (com.simiacryptus.mindseye.test.SimpleEval): 9
ReferenceCounting (com.simiacryptus.mindseye.lang.ReferenceCounting): 7
SimpleResult (com.simiacryptus.mindseye.test.SimpleResult): 7
ScalarStatistics (com.simiacryptus.util.data.ScalarStatistics): 7
List (java.util.List): 7
IntFunction (java.util.function.IntFunction): 7
ConstantResult (com.simiacryptus.mindseye.lang.ConstantResult): 6