Search in sources:

Example 51 with TensorList

use of com.simiacryptus.mindseye.lang.TensorList in project MindsEye by SimiaCryptus.

the class BatchDerivativeTester method testUnFrozen.

/**
 * Checks that a non-frozen copy of {@code component} records a delta for its own state and
 * invokes the feedback callback of its input when its result is accumulated.
 *
 * @param component      the layer that is copied and evaluated with {@code setFrozen(false)}
 * @param inputPrototype the tensors wrapped into the single input result
 * @throws AssertionError   if the layer reports state but no delta in the buffer targets it
 * @throws RuntimeException if the input's feedback callback was never invoked
 */
public void testUnFrozen(@Nonnull final Layer component, final Tensor[] inputPrototype) {
    @Nonnull final AtomicBoolean inputFeedbackSeen = new AtomicBoolean(false);
    @Nonnull final Layer unfrozen = component.copy().setFrozen(false);
    // Input result that always reports itself alive and only records that feedback arrived.
    @Nonnull final Result probe = new Result(TensorArray.create(inputPrototype), (@Nonnull final DeltaSet<Layer> unusedBuffer, @Nonnull final TensorList unusedData) -> {
        inputFeedbackSeen.set(true);
    }) {

        @Override
        public boolean isAlive() {
            return true;
        }
    };
    @Nullable final Result output = unfrozen.eval(probe);
    @Nonnull final DeltaSet<Layer> collected = new DeltaSet<Layer>();
    // The layer's own output data is passed back as the feedback signal.
    final TensorList feedback = output.getData();
    output.accumulate(collected, feedback);
    @Nullable final List<double[]> layerState = unfrozen.state();
    // For every state array, pick the first delta whose target is that exact array (identity match).
    final List<Delta<Layer>> matched = layerState.stream()
        .map(state -> collected.stream().filter(delta -> delta.target == state).findFirst().orElse(null))
        .filter(delta -> null != delta)
        .collect(Collectors.toList());
    final boolean hasState = !layerState.isEmpty();
    if (hasState && matched.isEmpty()) {
        throw new AssertionError("Nonfrozen component not listed in delta. Deltas: " + matched);
    }
    final boolean feedbackArrived = inputFeedbackSeen.get();
    if (!feedbackArrived) {
        throw new RuntimeException("Nonfrozen component did not pass input backwards");
    }
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Result(com.simiacryptus.mindseye.lang.Result) Collectors(java.util.stream.Collectors) Delta(com.simiacryptus.mindseye.lang.Delta) List(java.util.List) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) SimpleEval(com.simiacryptus.mindseye.test.SimpleEval) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Result(com.simiacryptus.mindseye.lang.Result) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) Delta(com.simiacryptus.mindseye.lang.Delta) Nullable(javax.annotation.Nullable)

Example 52 with TensorList

use of com.simiacryptus.mindseye.lang.TensorList in project MindsEye by SimiaCryptus.

the class SumMetaLayer method eval.

/**
 * Produces a single tensor holding the coordinate-wise sum over all items of the first input.
 * The sum is cached in {@code lastResult} and is recomputed only when nothing is cached yet or
 * when the batch holds more than {@code minBatches} items.
 *
 * @param inObj the inputs; only the first one is read, but every one is reference-held until
 *              the returned result is freed
 * @return a result wrapping the (possibly cached) sum tensor
 */
@Nullable
@Override
public Result eval(@Nonnull final Result... inObj) {
    final Result input = inObj[0];
    for (final Result held : inObj) {
        held.addRef();
    }
    final int itemCnt = input.getData().length();
    final boolean needsRefresh = lastResult == null || itemCnt > minBatches;
    if (needsRefresh) {
        // Kept as a DoubleStream sum so the numeric result matches the stream's summation exactly.
        @Nonnull final ToDoubleFunction<Coordinate> batchSum = (coord) -> IntStream.range(0, itemCnt).mapToDouble(index -> input.getData().get(index).get(coord)).sum();
        lastResult = input.getData().get(0).mapCoords(batchSum);
    }
    return new Result(TensorArray.wrap(lastResult), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
        if (!input.isAlive()) {
            return;
        }
        @Nullable final Tensor delta = data.get(0);
        @Nonnull final Tensor[] feedback = new Tensor[itemCnt];
        Arrays.parallelSetAll(feedback, slot -> new Tensor(delta.getDimensions()));
        // Visits every coordinate of delta and adds its value into each per-item feedback tensor.
        @Nonnull final ToDoubleFunction<Coordinate> spread = (coord) -> {
            for (final Tensor itemFeedback : feedback) {
                itemFeedback.add(coord, delta.get(coord));
            }
            return 0;
        };
        // mapCoords is called for its traversal only; the tensor it returns is discarded.
        delta.mapCoords(spread);
        @Nonnull TensorArray wrappedFeedback = TensorArray.wrap(feedback);
        input.accumulate(buffer, wrappedFeedback);
    }) {

        @Override
        protected void _free() {
            for (final Result held : inObj) {
                held.freeRef();
            }
        }

        @Override
        public boolean isAlive() {
            return input.isAlive();
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Coordinate(com.simiacryptus.mindseye.lang.Coordinate) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) ToDoubleFunction(java.util.function.ToDoubleFunction) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) Coordinate(com.simiacryptus.mindseye.lang.Coordinate) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Nullable(javax.annotation.Nullable) Result(com.simiacryptus.mindseye.lang.Result) Nullable(javax.annotation.Nullable)

Example 53 with TensorList

use of com.simiacryptus.mindseye.lang.TensorList in project MindsEye by SimiaCryptus.

the class ProductInputsLayer method eval.

/**
 * Multiplies the data of all inputs together element-wise and returns the product as a result
 * whose feedback routes, to each live input, the product of the incoming delta with the data of
 * every other input.
 *
 * A data list of length 1 is reused for every batch index when combined with a longer list.
 *
 * @param inObj the inputs to multiply; more than one is expected (checked by assertion only)
 * @return a result wrapping the per-item product tensors
 * @throws IllegalArgumentException if an input's element count is neither 1 nor equal to that
 *                                  of the first input (and the first input's count is not 1)
 */
@Nonnull
@Override
public Result eval(@Nonnull final Result... inObj) {
    assert inObj.length > 1;
    // Hold a reference on every input and on its data; both are released in _free() below.
    Arrays.stream(inObj).forEach(x -> x.getData().addRef());
    Arrays.stream(inObj).forEach(nnResult -> nnResult.addRef());
    // Validate that each input's element count is compatible with that of the first input.
    for (int i = 1; i < inObj.length; i++) {
        final int dim0 = Tensor.length(inObj[0].getData().getDimensions());
        final int dimI = Tensor.length(inObj[i].getData().getDimensions());
        if (dim0 != 1 && dimI != 1 && dim0 != dimI) {
            throw new IllegalArgumentException(Arrays.toString(inObj[0].getData().getDimensions()) + " != " + Arrays.toString(inObj[i].getData().getDimensions()));
        }
    }
    // Forward pass: pairwise-reduce the inputs' data lists into one list of products.
    return new Result(Arrays.stream(inObj).parallel().map(x -> {
        TensorList data = x.getData();
        // Extra reference per list; the reducer below frees both of its operands.
        data.addRef();
        return data;
    }).reduce((l, r) -> {
        TensorArray productArray = TensorArray.wrap(IntStream.range(0, Math.max(l.length(), r.length())).parallel().mapToObj(i1 -> {
            // A single-item list contributes its only tensor to every index.
            @Nullable final Tensor left = l.get(1 == l.length() ? 0 : i1);
            @Nullable final Tensor right = r.get(1 == r.length() ? 0 : i1);
            Tensor product = Tensor.product(left, right);
            left.freeRef();
            right.freeRef();
            return product;
        }).toArray(i -> new Tensor[i]));
        l.freeRef();
        r.freeRef();
        return productArray;
    }).get(), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList delta) -> {
        // Backward pass: handle each live input separately.
        for (@Nonnull final Result input : inObj) {
            if (input.isAlive()) {
                // Same reduction as the forward pass, with delta substituted for this input's data.
                @Nonnull TensorList passback = Arrays.stream(inObj).parallel().map(x -> {
                    TensorList tensorList = x == input ? delta : x.getData();
                    tensorList.addRef();
                    return tensorList;
                }).reduce((l, r) -> {
                    TensorArray productList = TensorArray.wrap(IntStream.range(0, Math.max(l.length(), r.length())).parallel().mapToObj(j -> {
                        @Nullable final Tensor left = l.get(1 == l.length() ? 0 : j);
                        @Nullable final Tensor right = r.get(1 == r.length() ? 0 : j);
                        Tensor product = Tensor.product(left, right);
                        left.freeRef();
                        right.freeRef();
                        return product;
                    }).toArray(j -> new Tensor[j]));
                    l.freeRef();
                    r.freeRef();
                    return productList;
                }).get();
                final TensorList inputData = input.getData();
                // The input had one item but the passback has several: fold them into one tensor by addition.
                if (1 == inputData.length() && 1 < passback.length()) {
                    TensorArray newValue = TensorArray.wrap(passback.stream().reduce((a, b) -> {
                        @Nullable Tensor c = a.addAndFree(b);
                        b.freeRef();
                        return c;
                    }).get());
                    passback.freeRef();
                    passback = newValue;
                }
                // The input had one element per item but the passback has more: replace each
                // passback tensor with one built from its sum() (presumably a one-element tensor; confirm).
                if (1 == Tensor.length(inputData.getDimensions()) && 1 < Tensor.length(passback.getDimensions())) {
                    TensorArray newValue = TensorArray.wrap(passback.stream().map((a) -> {
                        @Nonnull Tensor b = new Tensor(a.sum());
                        a.freeRef();
                        return b;
                    }).toArray(i -> new Tensor[i]));
                    passback.freeRef();
                    passback = newValue;
                }
                input.accumulate(buffer, passback);
            }
        }
    }) {

        // Alive as soon as any single input is alive.
        @Override
        public boolean isAlive() {
            for (@Nonnull final Result element : inObj) if (element.isAlive()) {
                return true;
            }
            return false;
        }

        // Releases the references taken at the top of eval.
        @Override
        protected void _free() {
            Arrays.stream(inObj).forEach(nnResult -> nnResult.freeRef());
            Arrays.stream(inObj).forEach(x -> x.getData().freeRef());
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Nullable(javax.annotation.Nullable) Result(com.simiacryptus.mindseye.lang.Result) Nonnull(javax.annotation.Nonnull)

Example 54 with TensorList

use of com.simiacryptus.mindseye.lang.TensorList in project MindsEye by SimiaCryptus.

the class ReLuActivationLayer method eval.

/**
 * Scales every input tensor by the layer's first weight and clamps negative products to zero.
 * The returned result's feedback accumulates a weight delta (unless the layer is frozen) and
 * passes a gated, weight-scaled delta back to the input (if the input is alive).
 *
 * @param inObj the inputs; only the first one is read
 * @return a result wrapping one output tensor per input item
 */
@Nonnull
@Override
public Result eval(final Result... inObj) {
    final Result input = inObj[0];
    final TensorList indata = input.getData();
    // References taken here are released in _free() of the returned result.
    input.addRef();
    indata.addRef();
    weights.addRef();
    final int itemCnt = indata.length();
    return new Result(TensorArray.wrap(IntStream.range(0, itemCnt).parallel().mapToObj(dataIndex -> {
        @Nullable Tensor tensorElement = indata.get(dataIndex);
        @Nonnull final Tensor tensor = tensorElement.multiply(weights.get(0));
        tensorElement.freeRef();
        // Clamp in place: negative entries of the scaled tensor become 0.
        @Nullable final double[] outputData = tensor.getData();
        for (int i = 0; i < outputData.length; i++) {
            if (outputData[i] < 0) {
                outputData[i] = 0;
            }
        }
        return tensor;
    }).toArray(i -> new Tensor[i])), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList delta) -> {
        // Weight gradient: only computed when this layer is not frozen.
        if (!isFrozen()) {
            IntStream.range(0, delta.length()).parallel().forEach(dataIndex -> {
                @Nullable Tensor deltaTensor = delta.get(dataIndex);
                @Nullable final double[] deltaData = deltaTensor.getData();
                @Nullable Tensor inputTensor = indata.get(dataIndex);
                @Nullable final double[] inputData = inputTensor.getData();
                @Nonnull final Tensor weightDelta = new Tensor(weights.getDimensions());
                @Nullable final double[] weightDeltaData = weightDelta.getData();
                // Sum delta * input into slot 0, skipping entries whose input is negative.
                // NOTE(review): the gate tests the raw input's sign, whereas the forward pass
                // clamps weight * input; the two match only for a positive weight - confirm intent.
                for (int i = 0; i < deltaData.length; i++) {
                    weightDeltaData[0] += inputData[i] < 0 ? 0 : deltaData[i] * inputData[i];
                }
                buffer.get(ReLuActivationLayer.this, weights.getData()).addInPlace(weightDeltaData).freeRef();
                deltaTensor.freeRef();
                inputTensor.freeRef();
                weightDelta.freeRef();
            });
        }
        // Input gradient: only computed when the input is alive.
        if (input.isAlive()) {
            final double weight = weights.getData()[0];
            @Nonnull TensorArray tensorArray = TensorArray.wrap(IntStream.range(0, delta.length()).parallel().mapToObj(dataIndex -> {
                @Nullable Tensor deltaTensor = delta.get(dataIndex);
                @Nullable final double[] deltaData = deltaTensor.getData();
                @Nullable Tensor inTensor = indata.get(dataIndex);
                @Nullable final double[] inputData = inTensor.getData();
                @Nonnull final int[] dims = inTensor.getDimensions();
                @Nonnull final Tensor passback = new Tensor(dims);
                // Zero where the input was negative, otherwise delta scaled by the weight.
                for (int i = 0; i < passback.length(); i++) {
                    passback.set(i, inputData[i] < 0 ? 0 : deltaData[i] * weight);
                }
                inTensor.freeRef();
                deltaTensor.freeRef();
                return passback;
            }).toArray(i -> new Tensor[i]));
            input.accumulate(buffer, tensorArray);
        }
    }) {

        // Releases the references taken at the top of eval.
        @Override
        protected void _free() {
            input.freeRef();
            indata.freeRef();
            weights.freeRef();
        }

        // Alive when the input is alive or when this layer is not frozen.
        @Override
        public boolean isAlive() {
            return input.isAlive() || !isFrozen();
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Util(com.simiacryptus.util.Util) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) DoubleSupplier(java.util.function.DoubleSupplier) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Nullable(javax.annotation.Nullable) Result(com.simiacryptus.mindseye.lang.Result) Nonnull(javax.annotation.Nonnull)

Example 55 with TensorList

use of com.simiacryptus.mindseye.lang.TensorList in project MindsEye by SimiaCryptus.

the class RescaledSubnetLayer method eval.

/**
 * Evaluates {@code subnetwork} on the input. When {@code scale} is 1 the call is delegated
 * directly. Otherwise a temporary {@link PipelineNetwork} is assembled: an
 * {@code ImgReshapeLayer(scale, scale, false)}, then {@code scale * scale} branches that each
 * select a run of {@code inputDims[2]} consecutive bands and feed them to {@code subnetwork},
 * an {@code ImgConcatLayer} joining the branches, and a final
 * {@code ImgReshapeLayer(scale, scale, true)}.
 *
 * @param inObj exactly one input whose data has three dimensions (both checked by assertion only)
 * @return the result of evaluating the subnetwork or the temporary network on {@code inObj}
 */
@Nullable
@Override
public Result eval(@Nonnull final Result... inObj) {
    assert inObj.length == 1;
    @Nonnull final int[] inputDims = inObj[0].getData().getDimensions();
    assert inputDims.length == 3;
    if (scale == 1) {
        return subnetwork.eval(inObj);
    }
    final int bandCount = inputDims[2];
    @Nonnull final PipelineNetwork pipeline = new PipelineNetwork();
    @Nullable final DAGNode condensed = pipeline.wrap(new ImgReshapeLayer(scale, scale, false));
    pipeline.wrap(new ImgConcatLayer(), IntStream.range(0, scale * scale).mapToObj(subband -> {
        // Bands [subband * bandCount, (subband + 1) * bandCount) belong to this branch.
        @Nonnull final int[] select = IntStream.range(0, bandCount).map(offset -> subband * bandCount + offset).toArray();
        final DAGNode selected = pipeline.wrap(new ImgBandSelectLayer(select), condensed);
        return pipeline.add(subnetwork, selected);
    }).toArray(DAGNode[]::new));
    pipeline.wrap(new ImgReshapeLayer(scale, scale, true));
    final Result outcome = pipeline.eval(inObj);
    pipeline.freeRef();
    return outcome;
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Result(com.simiacryptus.mindseye.lang.Result) DAGNode(com.simiacryptus.mindseye.network.DAGNode) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) ArrayList(java.util.ArrayList) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) ImgConcatLayer(com.simiacryptus.mindseye.layers.cudnn.ImgConcatLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull) ImgConcatLayer(com.simiacryptus.mindseye.layers.cudnn.ImgConcatLayer) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) TensorList(com.simiacryptus.mindseye.lang.TensorList) DAGNode(com.simiacryptus.mindseye.network.DAGNode) Nullable(javax.annotation.Nullable) Result(com.simiacryptus.mindseye.lang.Result) Nullable(javax.annotation.Nullable)

Aggregations

TensorList (com.simiacryptus.mindseye.lang.TensorList)110 Nonnull (javax.annotation.Nonnull)109 Nullable (javax.annotation.Nullable)103 Result (com.simiacryptus.mindseye.lang.Result)95 Arrays (java.util.Arrays)93 Layer (com.simiacryptus.mindseye.lang.Layer)91 Tensor (com.simiacryptus.mindseye.lang.Tensor)88 DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)87 IntStream (java.util.stream.IntStream)82 List (java.util.List)80 TensorArray (com.simiacryptus.mindseye.lang.TensorArray)76 Map (java.util.Map)68 JsonObject (com.google.gson.JsonObject)64 DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer)63 LayerBase (com.simiacryptus.mindseye.lang.LayerBase)61 Logger (org.slf4j.Logger)57 LoggerFactory (org.slf4j.LoggerFactory)57 ReferenceCounting (com.simiacryptus.mindseye.lang.ReferenceCounting)33 CudaTensor (com.simiacryptus.mindseye.lang.cudnn.CudaTensor)30 CudaTensorList (com.simiacryptus.mindseye.lang.cudnn.CudaTensorList)30