Search in sources :

Example 1 with ToleranceStatistics

use of com.simiacryptus.mindseye.test.ToleranceStatistics in project MindsEye by SimiaCryptus.

the class ImgCropLayerTest method getPerformanceTester.

/**
 * Builds a performance tester that wraps a {@link PerformanceTester} (100 samples, 10 batches)
 * and, for the duration of each test run, registers a {@code cuda_perf.log} print stream with
 * {@link CudaSystem}.
 *
 * @return a component test delegating to the wrapped performance tester
 */
@Nullable
@Override
public ComponentTest<ToleranceStatistics> getPerformanceTester() {
    @Nonnull ComponentTest<ToleranceStatistics> inner = new PerformanceTester().setSamples(100).setBatches(10);
    return new ComponentTestBase<ToleranceStatistics>() {

        /**
         * Runs the wrapped tester while CUDA API logging is directed to {@code cuda_perf.log}.
         */
        @Override
        public ToleranceStatistics test(@Nonnull NotebookOutput log, Layer component, Tensor... inputPrototype) {
            @Nullable PrintStream apiLog = null;
            try {
                apiLog = new PrintStream(log.file("cuda_perf.log"));
                CudaSystem.addLog(apiLog);
                return inner.test(log, component, inputPrototype);
            } finally {
                // NOTE(review): presumably emits a notebook link labelled "GPU Log" - confirm against NotebookOutput.
                log.p(log.file((String) null, "cuda_perf.log", "GPU Log"));
                if (null != apiLog) {
                    // Unregister before closing so no further log output is directed at a closed stream.
                    CudaSystem.apiLog.remove(apiLog);
                    apiLog.close();
                }
            }
        }

        /**
         * Releases the wrapped tester before delegating to the superclass cleanup.
         */
        @Override
        protected void _free() {
            inner.freeRef();
            super._free();
        }
    };
}
Also used : ComponentTestBase(com.simiacryptus.mindseye.test.unit.ComponentTestBase) PrintStream(java.io.PrintStream) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) PerformanceTester(com.simiacryptus.mindseye.test.unit.PerformanceTester) Layer(com.simiacryptus.mindseye.lang.Layer) Nullable(javax.annotation.Nullable) Nullable(javax.annotation.Nullable)

Example 2 with ToleranceStatistics

use of com.simiacryptus.mindseye.test.ToleranceStatistics in project MindsEye by SimiaCryptus.

the class BatchingTester method test.

/**
 * Checks that evaluating a layer on a whole batch agrees with evaluating each batch item on its own.
 * <p>
 * Random input batches of {@code getBatchSize()} items are generated from the prototypes; the layer
 * is evaluated once on the full batch and once per item, and both the outputs and the input
 * derivatives are compared.
 *
 * @param reference      the layer under test; when {@code null} an empty statistics object is returned
 * @param inputPrototype the input prototypes used as templates for the random batch inputs
 * @return the combined output (and, when available, derivative) agreement statistics
 * @throws AssertionError if the maximum absolute difference of the outputs or of the derivatives is
 *                        not below {@code tolerance}
 */
@Nonnull
public ToleranceStatistics test(@Nullable final Layer reference, @Nonnull final Tensor[] inputPrototype) {
    if (null == reference)
        return new ToleranceStatistics();
    final TensorList[] inputTensorLists = Arrays.stream(inputPrototype).map(t -> TensorArray.wrap(IntStream.range(0, getBatchSize()).mapToObj(i -> t.map(v -> getRandom())).toArray(i -> new Tensor[i]))).toArray(i -> new TensorList[i]);
    @Nonnull final SimpleResult asABatch;
    final List<SimpleEval> oneAtATime;
    try {
        asABatch = SimpleListEval.run(reference, inputTensorLists);
        oneAtATime = IntStream.range(0, getBatchSize()).mapToObj(batch -> {
            Tensor[] inputTensors = IntStream.range(0, inputTensorLists.length).mapToObj(i -> inputTensorLists[i].get(batch)).toArray(i -> new Tensor[i]);
            @Nonnull SimpleEval eval = SimpleEval.run(reference, inputTensors);
            for (@Nonnull Tensor tensor : inputTensors) {
                tensor.freeRef();
            }
            return eval;
        }).collect(Collectors.toList());
    } finally {
        // The generated inputs are no longer needed once both evaluations have run (or failed).
        for (@Nonnull TensorList tensorList : inputTensorLists) {
            tensorList.freeRef();
        }
    }
    try {
        TensorList batchOutput = asABatch.getOutput();
        @Nonnull IntFunction<ToleranceStatistics> toleranceStatisticsIntFunction = batch -> {
            @Nullable Tensor batchTensor = batchOutput.get(batch);
            @Nonnull ToleranceStatistics accumulate = new ToleranceStatistics().accumulate(batchTensor.getData(), oneAtATime.get(batch).getOutput().getData());
            batchTensor.freeRef();
            return accumulate;
        };
        int batchLength = batchOutput.length();
        @Nonnull final ToleranceStatistics outputAgreement = IntStream.range(0, Math.min(getBatchSize(), batchLength)).mapToObj(toleranceStatisticsIntFunction).reduce((a, b) -> a.combine(b)).get();
        // Written as !(x < tol) rather than (x >= tol) so that a NaN maximum also fails.
        if (!(outputAgreement.absoluteTol.getMax() < tolerance)) {
            logger.info("Batch Output: " + batchOutput.stream().map(x -> {
                String str = x.prettyPrint();
                x.freeRef();
                return str;
            }).collect(Collectors.toList()));
            logger.info("Singular Output: " + oneAtATime.stream().map(x -> x.getOutput().prettyPrint()).collect(Collectors.toList()));
            throw new AssertionError("Output Corrupt: " + outputAgreement);
        }
        ToleranceStatistics derivativeAgreement = IntStream.range(0, Math.min(getBatchSize(), batchLength)).mapToObj(batch -> {
            IntFunction<ToleranceStatistics> statisticsFunction = input -> {
                @Nullable Tensor a = asABatch.getInputDerivative()[input].get(batch);
                Tensor b = oneAtATime.get(batch).getDerivative()[input];
                @Nonnull Tensor diff = a.minus(b);
                logger.info("Error: " + diff.prettyPrint());
                logger.info("Scalar Statistics: " + new ScalarStatistics().add(diff.getData()).getMetrics());
                diff.freeRef();
                @Nonnull ToleranceStatistics toleranceStatistics = new ToleranceStatistics().accumulate(a.getData(), b.getData());
                a.freeRef();
                return toleranceStatistics;
            };
            // Every input's derivative is compared; the input index is independent of the batch length.
            return IntStream.range(0, inputPrototype.length).mapToObj(statisticsFunction).reduce((a, b) -> a.combine(b)).orElse(null);
        }).filter(x -> x != null).reduce((a, b) -> a.combine(b)).orElse(null);
        if (null != derivativeAgreement && !(derivativeAgreement.absoluteTol.getMax() < tolerance)) {
            throw new AssertionError("Derivatives Corrupt: " + derivativeAgreement);
        }
        return null != derivativeAgreement ? derivativeAgreement.combine(outputAgreement) : outputAgreement;
    } finally {
        asABatch.freeRef();
        oneAtATime.forEach(x -> x.freeRef());
    }
}
Also used : IntStream(java.util.stream.IntStream) SimpleResult(com.simiacryptus.mindseye.test.SimpleResult) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Collectors(java.util.stream.Collectors) List(java.util.List) SimpleListEval(com.simiacryptus.mindseye.test.SimpleListEval) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) TensorList(com.simiacryptus.mindseye.lang.TensorList) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) SimpleEval(com.simiacryptus.mindseye.test.SimpleEval) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) IntFunction(java.util.function.IntFunction) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) TensorList(com.simiacryptus.mindseye.lang.TensorList) SimpleResult(com.simiacryptus.mindseye.test.SimpleResult) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) IntFunction(java.util.function.IntFunction) SimpleEval(com.simiacryptus.mindseye.test.SimpleEval) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 3 with ToleranceStatistics

use of com.simiacryptus.mindseye.test.ToleranceStatistics in project MindsEye by SimiaCryptus.

the class EquivalencyTester method test.

/**
 * Evaluates the subject layer and the reference layer on the same inputs and compares their outputs
 * element by element.
 *
 * @param subject        the layer under test; when it (or the reference layer) is {@code null} an
 *                       empty statistics object is returned
 * @param inputPrototype the inputs fed to both layers
 * @return the tolerance statistics accumulated over all output elements
 * @throws AssertionError if the maximum absolute difference is not below {@code tolerance}
 */
public ToleranceStatistics test(@Nullable final Layer subject, @Nonnull final Tensor[] inputPrototype) {
    if (null == reference || null == subject)
        return new ToleranceStatistics();
    reference.assertAlive();
    // Each tensor is released by its own finally block, so a failure while producing a later
    // tensor (or while computing the statistics) no longer leaks the earlier ones.
    final Tensor subjectOutput = SimpleEval.run(subject, inputPrototype).getOutputAndFree();
    try {
        final Tensor referenceOutput = SimpleEval.run(reference, inputPrototype).getOutputAndFree();
        try {
            @Nonnull final Tensor error = subjectOutput.minus(referenceOutput);
            try {
                @Nonnull final ToleranceStatistics result = IntStream.range(0, subjectOutput.length()).mapToObj(i1 -> {
                    return new ToleranceStatistics().accumulate(subjectOutput.getData()[i1], referenceOutput.getData()[i1]);
                }).reduce((a, b) -> a.combine(b)).get();
                try {
                    // Written as !(x < tol) rather than (x >= tol) so that a NaN maximum also fails.
                    if (!(result.absoluteTol.getMax() < tolerance))
                        throw new AssertionError(result.toString());
                } catch (@Nonnull final Throwable e) {
                    // Unwrap the reduced Optional so the inputs are logged, not "Optional[...]".
                    log.info(String.format("Inputs: %s", Arrays.stream(inputPrototype).map(t -> t.prettyPrint()).reduce((a, b) -> a + ",\n" + b).orElse("")));
                    log.info(String.format("Subject Output: %s", subjectOutput.prettyPrint()));
                    log.info(String.format("Reference Output: %s", referenceOutput.prettyPrint()));
                    log.info(String.format("Error: %s", error.prettyPrint()));
                    System.out.flush();
                    throw e;
                }
                // orElse("") keeps an empty input array from failing the success-path logging.
                log.info(String.format("Inputs: %s", Arrays.stream(inputPrototype).map(t -> t.prettyPrint()).reduce((a, b) -> a + ",\n" + b).orElse("")));
                log.info(String.format("Error: %s", error.prettyPrint()));
                log.info(String.format("Accuracy:"));
                log.info(String.format("absoluteTol: %s", result.absoluteTol.toString()));
                log.info(String.format("relativeTol: %s", result.relativeTol.toString()));
                return result;
            } finally {
                error.freeRef();
            }
        } finally {
            referenceOutput.freeRef();
        }
    } finally {
        subjectOutput.freeRef();
    }
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) LoggerFactory(org.slf4j.LoggerFactory) Layer(com.simiacryptus.mindseye.lang.Layer) Tensor(com.simiacryptus.mindseye.lang.Tensor) SimpleEval(com.simiacryptus.mindseye.test.SimpleEval) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) GsonBuilder(com.google.gson.GsonBuilder) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics)

Example 4 with ToleranceStatistics

use of com.simiacryptus.mindseye.test.ToleranceStatistics in project MindsEye by SimiaCryptus.

the class SingleDerivativeTester method testFeedback.

/**
 * Compares, for every input, the implemented feedback gradient with the measured one and folds the
 * resulting statistics into the supplied accumulator.
 * <p>
 * The measured gradient is only computed when {@code verify} is set; otherwise it is {@code null},
 * no element-wise comparison takes place, and only the implemented gradient is reported.
 *
 * @param statistics      the statistics gathered so far
 * @param component       the component under test
 * @param inputPrototype  the input prototype
 * @param outputPrototype the output prototype
 * @return {@code statistics} combined with the per-input results, or {@code statistics} itself when
 *         there are no inputs
 * @throws AssertionError if the maximum absolute difference for an input is not below {@code tolerance}
 */
@Nonnull
public ToleranceStatistics testFeedback(@Nonnull ToleranceStatistics statistics, @Nonnull Layer component, @Nonnull Tensor[] inputPrototype, @Nonnull Tensor outputPrototype) {
    Optional<ToleranceStatistics> optional = IntStream.range(0, inputPrototype.length).mapToObj(i -> {
        @Nullable final Tensor measuredGradient = !verify ? null : measureFeedbackGradient(component, i, outputPrototype, inputPrototype);
        @Nonnull final Tensor implementedGradient = getFeedbackGradient(component, i, outputPrototype, inputPrototype);
        // Without a measured gradient there is nothing to subtract from; keep the difference null too.
        @Nullable final Tensor difference = null == measuredGradient ? null : measuredGradient.minus(implementedGradient);
        try {
            final ToleranceStatistics result = IntStream.range(0, null == measuredGradient ? 0 : measuredGradient.length()).mapToObj(i1 -> {
                return new ToleranceStatistics().accumulate(measuredGradient.getData()[i1], implementedGradient.getData()[i1]);
            }).reduce((a, b) -> a.combine(b)).orElse(new ToleranceStatistics());
            // Written as !(x < tol) rather than (x >= tol) so that a NaN maximum also fails.
            if (!(result.absoluteTol.getMax() < tolerance))
                throw new AssertionError(result.toString());
            if (verbose) {
                log.info(String.format("Feedback for input %s", i));
                log.info(String.format("Inputs Values: %s", inputPrototype[i].prettyPrint()));
                log.info(String.format("Value Statistics: %s", new ScalarStatistics().add(inputPrototype[i].getData())));
                log.info(String.format("Implemented Feedback: %s", implementedGradient.prettyPrint()));
                log.info(String.format("Implemented Statistics: %s", new ScalarStatistics().add(implementedGradient.getData())));
                if (null != measuredGradient) {
                    log.info(String.format("Measured Feedback: %s", measuredGradient.prettyPrint()));
                    log.info(String.format("Measured Statistics: %s", new ScalarStatistics().add(measuredGradient.getData())));
                    log.info(String.format("Feedback Error: %s", difference.prettyPrint()));
                    log.info(String.format("Error Statistics: %s", new ScalarStatistics().add(difference.getData())));
                }
            }
            return result;
        } catch (@Nonnull final Throwable e) {
            // On failure the full diagnostics are logged regardless of the verbose flag.
            log.info(String.format("Feedback for input %s", i));
            log.info(String.format("Inputs Values: %s", inputPrototype[i].prettyPrint()));
            log.info(String.format("Value Statistics: %s", new ScalarStatistics().add(inputPrototype[i].getData())));
            log.info(String.format("Implemented Feedback: %s", implementedGradient.prettyPrint()));
            log.info(String.format("Implemented Statistics: %s", new ScalarStatistics().add(implementedGradient.getData())));
            if (null != measuredGradient) {
                log.info(String.format("Measured: %s", measuredGradient.prettyPrint()));
                log.info(String.format("Measured Statistics: %s", new ScalarStatistics().add(measuredGradient.getData())));
                log.info(String.format("Feedback Error: %s", difference.prettyPrint()));
                log.info(String.format("Error Statistics: %s", new ScalarStatistics().add(difference.getData())));
            }
            throw e;
        } finally {
            // Single, null-safe release point for both the success and the failure path.
            if (null != difference)
                difference.freeRef();
            if (null != measuredGradient)
                measuredGradient.freeRef();
            implementedGradient.freeRef();
        }
    }).reduce((a, b) -> a.combine(b));
    return optional.isPresent() ? statistics.combine(optional.get()) : statistics;
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Result(com.simiacryptus.mindseye.lang.Result) Collectors(java.util.stream.Collectors) Delta(com.simiacryptus.mindseye.lang.Delta) List(java.util.List) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Optional(java.util.Optional) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) SimpleEval(com.simiacryptus.mindseye.test.SimpleEval) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) Nonnull(javax.annotation.Nonnull) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) Nonnull(javax.annotation.Nonnull)

Example 5 with ToleranceStatistics

use of com.simiacryptus.mindseye.test.ToleranceStatistics in project MindsEye by SimiaCryptus.

the class SingleDerivativeTester method testLearning.

/**
 * Compares, for every entry of the component's state, the implemented learning gradient with the
 * measured one and folds the resulting statistics into the supplied accumulator.
 * <p>
 * The measured gradient is only computed when {@code verify} is set; otherwise it is {@code null},
 * no element-wise comparison takes place, and only the implemented gradient is reported.
 *
 * @param prev            the statistics gathered so far
 * @param component       the component under test
 * @param inputPrototype  the input prototype
 * @param outputPrototype the output prototype
 * @return the per-weight results combined with {@code prev}, or {@code prev} itself when the
 *         component has no state
 * @throws AssertionError if the maximum absolute difference for a weight is not below {@code tolerance}
 */
public ToleranceStatistics testLearning(@Nonnull ToleranceStatistics prev, @Nonnull Layer component, Tensor[] inputPrototype, @Nonnull Tensor outputPrototype) {
    return IntStream.range(0, component.state().size()).mapToObj(i -> {
        @Nullable final Tensor measuredGradient = !verify ? null : measureLearningGradient(component, i, outputPrototype, inputPrototype);
        @Nonnull final Tensor implementedGradient = getLearningGradient(component, i, outputPrototype, inputPrototype);
        // Without a measured gradient there is nothing to subtract from; keep the difference null too.
        @Nullable final Tensor difference = null == measuredGradient ? null : measuredGradient.minus(implementedGradient);
        try {
            final ToleranceStatistics result = IntStream.range(0, null == measuredGradient ? 0 : measuredGradient.length()).mapToObj(i1 -> {
                return new ToleranceStatistics().accumulate(measuredGradient.getData()[i1], implementedGradient.getData()[i1]);
            }).reduce((a, b) -> a.combine(b)).orElse(new ToleranceStatistics());
            // Written as !(x < tol) rather than (x >= tol) so that a NaN maximum also fails.
            if (!(result.absoluteTol.getMax() < tolerance)) {
                throw new AssertionError(result.toString());
            }
            if (verbose) {
                log.info(String.format("Learning Gradient for weight setByCoord %s", i));
                log.info(String.format("Weights: %s", Tensor.prettyPrint(component.state().get(i))));
                log.info(String.format("Implemented Gradient: %s", implementedGradient.prettyPrint()));
                log.info(String.format("Implemented Statistics: %s", new ScalarStatistics().add(implementedGradient.getData())));
                if (null != measuredGradient) {
                    log.info(String.format("Measured Gradient: %s", measuredGradient.prettyPrint()));
                    log.info(String.format("Measured Statistics: %s", new ScalarStatistics().add(measuredGradient.getData())));
                    log.info(String.format("Gradient Error: %s", difference.prettyPrint()));
                    log.info(String.format("Error Statistics: %s", new ScalarStatistics().add(difference.getData())));
                }
            }
            return result;
        } catch (@Nonnull final Throwable e) {
            // On failure the diagnostics are logged regardless of the verbose flag.
            log.info(String.format("Learning Gradient for weight setByCoord %s", i));
            log.info(String.format("Implemented Gradient: %s", implementedGradient.prettyPrint()));
            log.info(String.format("Implemented Statistics: %s", new ScalarStatistics().add(implementedGradient.getData())));
            if (null != measuredGradient) {
                log.info(String.format("Measured Gradient: %s", measuredGradient.prettyPrint()));
                log.info(String.format("Measured Statistics: %s", new ScalarStatistics().add(measuredGradient.getData())));
                log.info(String.format("Gradient Error: %s", difference.prettyPrint()));
                log.info(String.format("Error Statistics: %s", new ScalarStatistics().add(difference.getData())));
            }
            throw e;
        } finally {
            // Single, null-safe release point for both the success and the failure path.
            if (null != difference)
                difference.freeRef();
            if (null != measuredGradient)
                measuredGradient.freeRef();
            implementedGradient.freeRef();
        }
    }).reduce((a, b) -> a.combine(b)).map(x -> x.combine(prev)).orElseGet(() -> prev);
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Result(com.simiacryptus.mindseye.lang.Result) Collectors(java.util.stream.Collectors) Delta(com.simiacryptus.mindseye.lang.Delta) List(java.util.List) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Optional(java.util.Optional) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) SimpleEval(com.simiacryptus.mindseye.test.SimpleEval) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) Nonnull(javax.annotation.Nonnull) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics)

Aggregations

ToleranceStatistics (com.simiacryptus.mindseye.test.ToleranceStatistics)21 Nonnull (javax.annotation.Nonnull)20 Layer (com.simiacryptus.mindseye.lang.Layer)19 Tensor (com.simiacryptus.mindseye.lang.Tensor)19 NotebookOutput (com.simiacryptus.util.io.NotebookOutput)19 Nullable (javax.annotation.Nullable)19 Arrays (java.util.Arrays)16 IntStream (java.util.stream.IntStream)14 Logger (org.slf4j.Logger)14 LoggerFactory (org.slf4j.LoggerFactory)14 TensorArray (com.simiacryptus.mindseye.lang.TensorArray)13 TensorList (com.simiacryptus.mindseye.lang.TensorList)13 Collectors (java.util.stream.Collectors)13 SimpleEval (com.simiacryptus.mindseye.test.SimpleEval)9 ReferenceCounting (com.simiacryptus.mindseye.lang.ReferenceCounting)7 SimpleResult (com.simiacryptus.mindseye.test.SimpleResult)7 ScalarStatistics (com.simiacryptus.util.data.ScalarStatistics)7 List (java.util.List)7 IntFunction (java.util.function.IntFunction)7 ConstantResult (com.simiacryptus.mindseye.lang.ConstantResult)6