
Example 36 with TensorList

Use of com.simiacryptus.mindseye.lang.TensorList in project MindsEye by SimiaCryptus.

From the class ClassifyProblem, the run method:

@Nonnull
@Override
public ClassifyProblem run(@Nonnull final NotebookOutput log) {
    @Nonnull final TrainingMonitor monitor = TestUtil.getMonitor(history);
    final Tensor[][] trainingData = getTrainingData(log);
    @Nonnull final DAGNetwork network = fwdFactory.imageToVector(log, categories);
    log.h3("Network Diagram");
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph(network)).height(400).width(600).render(Format.PNG).toImage();
    });
    log.h3("Training");
    @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    TestUtil.instrumentPerformance(supervisedNetwork);
    int initialSampleSize = Math.max(trainingData.length / 5, Math.min(10, trainingData.length / 2));
    @Nonnull final ValidatingTrainer trainer = optimizer.train(log, new SampledArrayTrainable(trainingData, supervisedNetwork, initialSampleSize, getBatchSize()), new ArrayTrainable(trainingData, supervisedNetwork, getBatchSize()), monitor);
    log.code(() -> {
        trainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(10000).run();
    });
    if (!history.isEmpty()) {
        log.code(() -> {
            return TestUtil.plot(history);
        });
        log.code(() -> {
            return TestUtil.plotTime(history);
        });
    }
    try {
        @Nonnull String filename = log.getName() + "_" + ClassifyProblem.modelNo++ + "_plot.png";
        ImageIO.write(Util.toImage(TestUtil.plot(history)), "png", log.file(filename));
        @Nonnull File file = new File(log.getResourceDir(), filename);
        log.appendFrontMatterProperty("result_plot", file.toString(), ";");
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    TestUtil.extractPerformance(log, supervisedNetwork);
    @Nonnull final String modelName = "classification_model_" + ClassifyProblem.modelNo++ + ".json";
    log.appendFrontMatterProperty("result_model", modelName, ";");
    log.p("Saved model as " + log.file(network.getJson().toString(), modelName, modelName));
    log.h3("Validation");
    log.p("If we apply our model against the entire validation dataset, we get this accuracy:");
    log.code(() -> {
        return data.validationData().mapToDouble(labeledObject -> predict(network, labeledObject)[0] == parse(labeledObject.label) ? 1 : 0).average().getAsDouble() * 100;
    });
    log.p("Let's examine some incorrectly predicted results in more detail:");
    log.code(() -> {
        try {
            @Nonnull final TableOutput table = new TableOutput();
            Lists.partition(data.validationData().collect(Collectors.toList()), 100).stream().flatMap(batch -> {
                @Nonnull TensorList batchIn = TensorArray.create(batch.stream().map(x -> x.data).toArray(i -> new Tensor[i]));
                TensorList batchOut = network.eval(new ConstantResult(batchIn)).getData();
                return IntStream.range(0, batchOut.length()).mapToObj(i -> toRow(log, batch.get(i), batchOut.get(i).getData()));
            }).filter(x -> null != x).limit(10).forEach(table::putRow);
            return table;
        } catch (@Nonnull final IOException e) {
            throw new RuntimeException(e);
        }
    });
    return this;
}
Also used: IntStream(java.util.stream.IntStream) Graphviz(guru.nidi.graphviz.engine.Graphviz) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) Arrays(java.util.Arrays) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) Lists(com.google.common.collect.Lists) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) Format(guru.nidi.graphviz.engine.Format) LabeledObject(com.simiacryptus.util.test.LabeledObject) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) ImageIO(javax.imageio.ImageIO) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) Layer(com.simiacryptus.mindseye.lang.Layer) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Util(com.simiacryptus.util.Util) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) IOException(java.io.IOException) TestUtil(com.simiacryptus.mindseye.test.TestUtil) Collectors(java.util.stream.Collectors) File(java.io.File) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Stream(java.util.stream.Stream) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Comparator(java.util.Comparator)
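The TensorList idiom worth extracting from this example is the validation loop: each page of labeled images is packed into a single TensorList so the whole batch goes through the network in one eval call. Below is a minimal sketch of that pattern only; the names network (a Layer) and samples (a Tensor[]) are assumptions, not part of the original example.

// Minimal sketch of the batch-evaluation idiom; `network` and `samples` are assumed.
TensorList batchIn = TensorArray.create(samples);                      // pack the batch into one TensorList
TensorList batchOut = network.eval(new ConstantResult(batchIn)).getData();
for (int i = 0; i < batchOut.length(); i++) {
    Tensor prediction = batchOut.get(i);                               // per-sample output tensor
    double[] scores = prediction.getData();                            // raw category scores for sample i
    prediction.freeRef();                                              // tensors are reference-counted
}
batchOut.freeRef();                                                    // release the batch result as well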

Example 37 with TensorList

Use of com.simiacryptus.mindseye.lang.TensorList in project MindsEye by SimiaCryptus.

From the class BatchingTester, the test method:

/**
 * Test tolerance statistics.
 *
 * @param reference      the reference
 * @param inputPrototype the input prototype
 * @return the tolerance statistics
 */
@Nonnull
public ToleranceStatistics test(@Nullable final Layer reference, @Nonnull final Tensor[] inputPrototype) {
    if (null == reference)
        return new ToleranceStatistics();
    final TensorList[] inputTensorLists = Arrays.stream(inputPrototype).map(t -> TensorArray.wrap(IntStream.range(0, getBatchSize()).mapToObj(i -> t.map(v -> getRandom())).toArray(i -> new Tensor[i]))).toArray(i -> new TensorList[i]);
    @Nonnull final SimpleResult asABatch;
    final List<SimpleEval> oneAtATime;
    try {
        asABatch = SimpleListEval.run(reference, inputTensorLists);
        oneAtATime = IntStream.range(0, getBatchSize()).mapToObj(batch -> {
            Tensor[] inputTensors = IntStream.range(0, inputTensorLists.length).mapToObj(i -> inputTensorLists[i].get(batch)).toArray(i -> new Tensor[i]);
            @Nonnull SimpleEval eval = SimpleEval.run(reference, inputTensors);
            for (@Nonnull Tensor tensor : inputTensors) {
                tensor.freeRef();
            }
            return eval;
        }).collect(Collectors.toList());
    } finally {
        for (@Nonnull TensorList tensorList : inputTensorLists) {
            tensorList.freeRef();
        }
    }
    try {
        TensorList batchOutput = asABatch.getOutput();
        @Nonnull IntFunction<ToleranceStatistics> toleranceStatisticsIntFunction = batch -> {
            @Nullable Tensor batchTensor = batchOutput.get(batch);
            @Nonnull ToleranceStatistics accumulate = new ToleranceStatistics().accumulate(batchTensor.getData(), oneAtATime.get(batch).getOutput().getData());
            batchTensor.freeRef();
            return accumulate;
        };
        int batchLength = batchOutput.length();
        @Nonnull final ToleranceStatistics outputAgreement = IntStream.range(0, Math.min(getBatchSize(), batchLength)).mapToObj(toleranceStatisticsIntFunction).reduce((a, b) -> a.combine(b)).get();
        if (!(outputAgreement.absoluteTol.getMax() < tolerance)) {
            logger.info("Batch Output: " + batchOutput.stream().map(x -> {
                String str = x.prettyPrint();
                x.freeRef();
                return str;
            }).collect(Collectors.toList()));
            logger.info("Singular Output: " + oneAtATime.stream().map(x -> x.getOutput().prettyPrint()).collect(Collectors.toList()));
            throw new AssertionError("Output Corrupt: " + outputAgreement);
        }
        ToleranceStatistics derivativeAgreement = IntStream.range(0, Math.min(getBatchSize(), batchLength)).mapToObj(batch -> {
            IntFunction<ToleranceStatistics> statisticsFunction = input -> {
                @Nullable Tensor a = asABatch.getInputDerivative()[input].get(batch);
                Tensor b = oneAtATime.get(batch).getDerivative()[input];
                @Nonnull Tensor diff = a.minus(b);
                logger.info("Error: " + diff.prettyPrint());
                logger.info("Scalar Statistics: " + new ScalarStatistics().add(diff.getData()).getMetrics());
                double[][] points = Arrays.stream(diff.getData()).mapToObj(x -> new double[] { x }).toArray(i -> new double[i][]);
                // logger.info("Density: " + new DensityTree("x").setMinSplitFract(1e-8).setSplitSizeThreshold(2).new Node(points));
                diff.freeRef();
                @Nonnull ToleranceStatistics toleranceStatistics = new ToleranceStatistics().accumulate(a.getData(), b.getData());
                a.freeRef();
                return toleranceStatistics;
            };
            return IntStream.range(0, Math.min(inputPrototype.length, batchLength)).mapToObj(statisticsFunction).reduce((a, b) -> a.combine(b)).orElse(null);
        }).filter(x -> x != null).reduce((a, b) -> a.combine(b)).orElse(null);
        if (null != derivativeAgreement && !(derivativeAgreement.absoluteTol.getMax() < tolerance)) {
            throw new AssertionError("Derivatives Corrupt: " + derivativeAgreement);
        }
        return null != derivativeAgreement ? derivativeAgreement.combine(outputAgreement) : outputAgreement;
    } finally {
        asABatch.freeRef();
        oneAtATime.forEach(x -> x.freeRef());
    }
}
Also used: IntStream(java.util.stream.IntStream) SimpleResult(com.simiacryptus.mindseye.test.SimpleResult) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Collectors(java.util.stream.Collectors) List(java.util.List) SimpleListEval(com.simiacryptus.mindseye.test.SimpleListEval) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) TensorList(com.simiacryptus.mindseye.lang.TensorList) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) SimpleEval(com.simiacryptus.mindseye.test.SimpleEval) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) IntFunction(java.util.function.IntFunction)
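The core of the test is the agreement check: evaluating a whole TensorList batch must match evaluating each row on its own. This compressed sketch shows just that comparison; layer (a Layer) and proto (a single-input prototype Tensor) are assumed names, and reference-counting cleanup is omitted for brevity.

// Minimal sketch of the batch-vs-singular comparison; `layer` and `proto` are assumed.
TensorList batch = TensorArray.wrap(
    IntStream.range(0, 4).mapToObj(i -> proto.map(v -> Math.random())).toArray(n -> new Tensor[n]));
SimpleResult batched = SimpleListEval.run(layer, new TensorList[] { batch });   // one batched evaluation
Tensor row0 = batch.get(0);
SimpleEval single = SimpleEval.run(layer, row0);                                // the same row, evaluated alone
ToleranceStatistics agreement = new ToleranceStatistics()
    .accumulate(batched.getOutput().get(0).getData(), single.getOutput().getData());
// agreement.absoluteTol.getMax() should stay below the configured tolerance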

Example 38 with TensorList

Use of com.simiacryptus.mindseye.lang.TensorList in project MindsEye by SimiaCryptus.

From the class SingleDerivativeTester, the testUnFrozen method:

/**
 * Test that an un-frozen copy of the component produces parameter deltas and passes gradients back to its inputs.
 *
 * @param component      the component
 * @param inputPrototype the input prototype
 */
public void testUnFrozen(@Nonnull final Layer component, Tensor[] inputPrototype) {
    inputPrototype = Arrays.stream(inputPrototype).map(tensor -> tensor.copy()).toArray(i -> new Tensor[i]);
    @Nonnull final AtomicBoolean reachedInputFeedback = new AtomicBoolean(false);
    @Nonnull final Layer frozen = component.copy().setFrozen(false);
    List<TensorArray> inputCopies = Arrays.stream(inputPrototype).map(TensorArray::wrap).collect(Collectors.toList());
    Result[] inputs = inputCopies.stream().map(tensor -> new Result(tensor, (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
        reachedInputFeedback.set(true);
    }) {

        @Override
        public boolean isAlive() {
            return true;
        }
    }).toArray(i -> new Result[i]);
    @Nullable final Result eval;
    try {
        eval = frozen.eval(inputs);
    } finally {
        for (@Nonnull Result result : inputs) {
            result.freeRef();
        }
        for (@Nonnull TensorArray tensorArray : inputCopies) {
            tensorArray.freeRef();
        }
    }
    @Nonnull final DeltaSet<Layer> buffer = new DeltaSet<Layer>();
    TensorList tensorList = eval.getData();
    eval.accumulate(buffer, tensorList);
    eval.freeRef();
    @Nullable final List<double[]> stateList = frozen.state();
    final List<Delta<Layer>> deltas = stateList.stream().map(doubles -> {
        return buffer.stream().filter(x -> x.target == doubles).findFirst().orElse(null);
    }).filter(x -> x != null).collect(Collectors.toList());
    if (deltas.isEmpty() && !stateList.isEmpty()) {
        throw new AssertionError("Nonfrozen component not listed in delta. Deltas: " + deltas);
    }
    frozen.freeRef();
    buffer.freeRef();
    if (!reachedInputFeedback.get() && inputPrototype.length != 0) {
        throw new RuntimeException("Nonfrozen component did not pass input backwards");
    }
}
Also used: IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Result(com.simiacryptus.mindseye.lang.Result) Collectors(java.util.stream.Collectors) Delta(com.simiacryptus.mindseye.lang.Delta) List(java.util.List) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Optional(java.util.Optional) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) SimpleEval(com.simiacryptus.mindseye.test.SimpleEval) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable)
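The pattern to note here is how a TensorList becomes a live network input: it is wrapped in a Result whose accumulator records that a gradient arrived and whose isAlive() returns true, so the layer knows to back-propagate into it. A stripped-down sketch of that mechanism follows; someLayer is an assumed name and the tensor values are arbitrary.

// Minimal sketch of observing input feedback through a hand-built Result; `someLayer` is assumed.
AtomicBoolean sawGradient = new AtomicBoolean(false);
TensorArray data = TensorArray.wrap(new Tensor(new int[] { 3 }).setAll(1.0));
Result input = new Result(data, (DeltaSet<Layer> buffer, TensorList delta) -> sawGradient.set(true)) {

    @Override
    public boolean isAlive() {
        return true;                                              // ask the layer to send gradients back
    }
};
Result output = someLayer.eval(input);
output.accumulate(new DeltaSet<Layer>(), output.getData());       // drive the backward pass
// sawGradient.get() is now true if someLayer propagated a gradient to its input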

Example 39 with TensorList

Use of com.simiacryptus.mindseye.lang.TensorList in project MindsEye by SimiaCryptus.

From the class ImageClassifier, the evaluatePrototype method:

/**
 * Evaluate prototype tensor.
 *
 * @param layer         the layer
 * @param prevPrototype the prev prototype
 * @param cnt           the cnt
 * @return the tensor
 */
@Nonnull
protected static Tensor evaluatePrototype(@Nonnull final Layer layer, final Tensor prevPrototype, int cnt) {
    int numberOfParameters = layer.state().stream().mapToInt(x -> x.length).sum();
    @Nonnull int[] prev_dimensions = prevPrototype.getDimensions();
    Result eval = layer.eval(prevPrototype);
    TensorList newPrototype = eval.getData();
    if (null != prevPrototype)
        prevPrototype.freeRef();
    eval.freeRef();
    try {
        @Nonnull int[] new_dimensions = newPrototype.getDimensions();
        log.info(String.format("Added layer #%d: %s; %s params, dimensions %s (%s) -> %s (%s)",
            cnt, layer, numberOfParameters,
            Arrays.toString(prev_dimensions), Tensor.length(prev_dimensions),
            Arrays.toString(new_dimensions), Tensor.length(new_dimensions)));
        return newPrototype.get(0);
    } finally {
        newPrototype.freeRef();
    }
}
Also used: PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) Arrays(java.util.Arrays) ActivationLayer(com.simiacryptus.mindseye.layers.cudnn.ActivationLayer) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) SimpleConvolutionLayer(com.simiacryptus.mindseye.layers.cudnn.SimpleConvolutionLayer) FullyConnectedLayer(com.simiacryptus.mindseye.layers.cudnn.FullyConnectedLayer) Result(com.simiacryptus.mindseye.lang.Result) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) Function(java.util.function.Function) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) Trainable(com.simiacryptus.mindseye.eval.Trainable) Precision(com.simiacryptus.mindseye.lang.cudnn.Precision) Lists(com.google.common.collect.Lists) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Layer(com.simiacryptus.mindseye.lang.Layer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) NetworkFactory(com.simiacryptus.mindseye.models.NetworkFactory) Logger(org.slf4j.Logger) QQN(com.simiacryptus.mindseye.opt.orient.QQN) TestUtil(com.simiacryptus.mindseye.test.TestUtil) Collectors(java.util.stream.Collectors) ConvolutionLayer(com.simiacryptus.mindseye.layers.cudnn.ConvolutionLayer) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) Explodable(com.simiacryptus.mindseye.layers.cudnn.Explodable) TensorList(com.simiacryptus.mindseye.lang.TensorList) LinearActivationLayer(com.simiacryptus.mindseye.layers.java.LinearActivationLayer) MultiPrecision(com.simiacryptus.mindseye.layers.cudnn.MultiPrecision) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Comparator(java.util.Comparator) BiasLayer(com.simiacryptus.mindseye.layers.java.BiasLayer)
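Here the TensorList is just a carrier for a single prototype: the layer is evaluated once, the result's dimensions are logged, and the first (and only) tensor is handed on while the list itself is released. The shape-probing step in isolation might look like the sketch below; layer and prototype are assumed names.

// Minimal sketch of probing a layer's output shape through its TensorList; `layer` and `prototype` are assumed.
Result eval = layer.eval(prototype);          // Layer.eval also accepts bare Tensors
TensorList out = eval.getData();
int[] dims = out.getDimensions();             // dimensions shared by every tensor in the list
Tensor next = out.get(0);                     // keep the evaluated prototype for the next layer
eval.freeRef();
out.freeRef();                                // `next` stays valid; only the list wrapper is released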

Example 40 with TensorList

Use of com.simiacryptus.mindseye.lang.TensorList in project MindsEye by SimiaCryptus.

From the class AvgReducerLayer, the evalAndFree method:

@Nullable
@Override
public Result evalAndFree(final Result... inObj) {
    if (!CudaSystem.isEnabled())
        return getCompatibilityLayer().evalAndFree(inObj);
    final Result input = inObj[0];
    final TensorList inputData = input.getData();
    @Nonnull final int[] inputSize = inputData.getDimensions();
    int length = inputData.length();
    CudaTensorList result = CudaSystem.run(gpu -> {
        CudaTensor inputTensor = gpu.getTensor(inputData, precision, MemoryType.Device, false);
        inputData.freeRef();
        CudaMemory inputMemory = inputTensor.getMemory(gpu);
        @Nonnull final CudaDevice.CudaTensorDescriptor outputDescriptor = gpu.newTensorDescriptor(precision, length, 1, 1, 1);
        long size = (long) precision.size * outputDescriptor.nStride * length;
        @Nonnull final CudaMemory outputMemory = gpu.allocate(size, MemoryType.Managed, true);
        CudaResource<cudnnReduceTensorDescriptor> reduceTensorDescriptor = gpu.cudnnCreateReduceTensorDescriptor(cudnnReduceTensorOp.CUDNN_REDUCE_TENSOR_AVG, precision.code, cudnnNanPropagation.CUDNN_NOT_PROPAGATE_NAN, cudnnReduceTensorIndices.CUDNN_REDUCE_TENSOR_NO_INDICES, cudnnIndicesType.CUDNN_32BIT_INDICES);
        @Nonnull final CudaMemory workspacePtr = gpu.allocate(inputMemory.size, MemoryType.Device, true);
        @Nonnull final CudaMemory indexPtr = gpu.allocate(12 * length, MemoryType.Device, false);
        // outputPtr.synchronize();
        gpu.cudnnReduceTensor(reduceTensorDescriptor.getPtr(), indexPtr.getPtr(), indexPtr.size, workspacePtr.getPtr(), workspacePtr.size, precision.getPointer(1.0), inputTensor.descriptor.getPtr(), inputMemory.getPtr(), precision.getPointer(0.0), outputDescriptor.getPtr(), outputMemory.getPtr());
        outputMemory.dirty();
        inputMemory.dirty();
        Stream.of(inputTensor, inputMemory, reduceTensorDescriptor, workspacePtr, indexPtr).forEach(ReferenceCounting::freeRef);
        return CudaTensorList.wrap(CudaTensor.wrap(outputMemory, outputDescriptor, precision), length, new int[] { 1, 1, 1 }, precision);
    });
    return new Result(result, (DeltaSet<Layer> ctx, TensorList delta) -> {
        // Not supported by CuDNN?
        // CudaTensorList passback = CudaSystem.run(gpu -> {
        // CudaTensor deltaTensor = gpu.getTensor(delta, precision, MemoryType.Device, false);
        // CudaMemory deltaMemory = deltaTensor.getMemory(gpu);
        // 
        // @Nonnull final CudaDevice.CudaTensorDescriptor passbackDescriptor1 = gpu.newTensorDescriptor(
        // precision, length, inputSize[2], inputSize[1], inputSize[0]
        // );
        // @Nonnull final CudaMemory passbackPtr1 = gpu.allocate((long) precision.size * passbackDescriptor1.nStride * length, MemoryType.Device, false);
        // gpu.cudnnAddTensor(precision.getPointer(1.0), deltaTensor.descriptor.getPtr(), deltaMemory.getPtr(),
        // precision.getPointer(1.0), passbackDescriptor1.getPtr(), passbackPtr1.getPtr());
        // passbackPtr1.dirty();
        // 
        // Stream.of(deltaTensor, deltaMemory, passbackDescriptor1, passbackPtr1).forEach(ReferenceCounting::freeRef);
        // return CudaTensorList.wrap(CudaTensor.wrap(passbackPtr1, passbackDescriptor1, precision), length, inputSize, precision);
        // });
        TensorList passback = TensorArray.wrap(IntStream.range(0, length).mapToObj(i -> {
            Tensor tensor = delta.get(i);
            Tensor tensor1 = new Tensor(inputSize).setAll((double) tensor.get(0) / Tensor.length(inputSize));
            tensor.freeRef();
            return tensor1;
        }).toArray(i -> new Tensor[i]));
        input.accumulate(ctx, passback);
    }) {

        @Override
        protected void _free() {
            super._free();
            input.freeRef();
        }
    };
}
Also used: IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) CudaMemory(com.simiacryptus.mindseye.lang.cudnn.CudaMemory) jcuda.jcudnn.cudnnReduceTensorDescriptor(jcuda.jcudnn.cudnnReduceTensorDescriptor) Tensor(com.simiacryptus.mindseye.lang.Tensor) jcuda.jcudnn.cudnnReduceTensorOp(jcuda.jcudnn.cudnnReduceTensorOp) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) Precision(com.simiacryptus.mindseye.lang.cudnn.Precision) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) CudaResource(com.simiacryptus.mindseye.lang.cudnn.CudaResource) CudaDevice(com.simiacryptus.mindseye.lang.cudnn.CudaDevice) CudaTensor(com.simiacryptus.mindseye.lang.cudnn.CudaTensor) CudaTensorList(com.simiacryptus.mindseye.lang.cudnn.CudaTensorList) jcuda.jcudnn.cudnnIndicesType(jcuda.jcudnn.cudnnIndicesType) jcuda.jcudnn.cudnnNanPropagation(jcuda.jcudnn.cudnnNanPropagation) jcuda.jcudnn.cudnnReduceTensorIndices(jcuda.jcudnn.cudnnReduceTensorIndices) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) Stream(java.util.stream.Stream) CudaSystem(com.simiacryptus.mindseye.lang.cudnn.CudaSystem) TensorList(com.simiacryptus.mindseye.lang.TensorList) MemoryType(com.simiacryptus.mindseye.lang.cudnn.MemoryType) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet)
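On the forward pass the CuDNN reduction collapses each batch element to a single scalar (its mean); on the backward pass the plain-Java fallback spreads each scalar delta uniformly over the input dimensions, as shown in the passback lambda above. The CPU-only sketch below restates the forward semantics with ordinary Tensors; the input values are invented for illustration and none of this touches the CuDNN path.

// CPU-only sketch of what the CUDNN_REDUCE_TENSOR_AVG call computes: one mean per batch element.
TensorList input = TensorArray.wrap(new Tensor[] {
    new Tensor(new int[] { 2, 2 }).setAll(3.0),
    new Tensor(new int[] { 2, 2 }).setAll(5.0) });
Tensor[] means = IntStream.range(0, input.length()).mapToObj(i -> {
    Tensor t = input.get(i);
    double mean = Arrays.stream(t.getData()).average().orElse(0.0);
    t.freeRef();
    return new Tensor(new int[] { 1, 1, 1 }).setAll(mean);    // mirrors the {1, 1, 1} output descriptor
}).toArray(n -> new Tensor[n]);
TensorList output = TensorArray.wrap(means);                  // expected: per-element means 3.0 and 5.0
input.freeRef();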

Aggregations

Classes most often used together with TensorList across the indexed examples, with occurrence counts:

TensorList (com.simiacryptus.mindseye.lang.TensorList): 110
Nonnull (javax.annotation.Nonnull): 109
Nullable (javax.annotation.Nullable): 103
Result (com.simiacryptus.mindseye.lang.Result): 95
Arrays (java.util.Arrays): 93
Layer (com.simiacryptus.mindseye.lang.Layer): 91
Tensor (com.simiacryptus.mindseye.lang.Tensor): 88
DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet): 87
IntStream (java.util.stream.IntStream): 82
List (java.util.List): 80
TensorArray (com.simiacryptus.mindseye.lang.TensorArray): 76
Map (java.util.Map): 68
JsonObject (com.google.gson.JsonObject): 64
DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer): 63
LayerBase (com.simiacryptus.mindseye.lang.LayerBase): 61
Logger (org.slf4j.Logger): 57
LoggerFactory (org.slf4j.LoggerFactory): 57
ReferenceCounting (com.simiacryptus.mindseye.lang.ReferenceCounting): 33
CudaTensor (com.simiacryptus.mindseye.lang.cudnn.CudaTensor): 30
CudaTensorList (com.simiacryptus.mindseye.lang.cudnn.CudaTensorList): 30