Search in sources :

Example 11 with Tensor

use of com.simiacryptus.mindseye.lang.Tensor in project MindsEye by SimiaCryptus.

The class ImgBandScaleLayer, method eval.

/**
 * Eval nn result.
 *
 * @param input the input
 * @return the nn result
 */
/**
 * Multiplies every band (third dimension) of each 3-d input tensor by the corresponding entry
 * of {@code getWeights()}.
 *
 * <p>Backward pass: unless the layer is frozen, the per-band weight gradient (the sum over each
 * band of {@code delta * input}) is added to {@code buffer}; if {@code input} is alive, it
 * receives {@code delta} scaled by the same weights.
 *
 * @param input the input; each tensor must be 3-d with exactly {@code weights.length} bands
 * @return the nn result holding the scaled tensors
 * @throws IllegalArgumentException if an input tensor is not 3-d or has the wrong band count
 */
@Nonnull
public Result eval(@Nonnull final Result input) {
    @Nullable final double[] weights = getWeights();
    final TensorList inData = input.getData();
    @Nullable Function<Tensor, Tensor> tensorTensorFunction = tensor -> {
        final int[] dims = tensor.getDimensions();
        if (dims.length != 3) {
            final String message = Arrays.toString(dims);
            // The streamed tensor is released by this lambda on success (see below), so release
            // it on rejection too.
            tensor.freeRef();
            throw new IllegalArgumentException(message);
        }
        if (dims[2] != weights.length) {
            final String message = String.format("%s: %s does not have %s bands", getName(), Arrays.toString(dims), weights.length);
            tensor.freeRef();
            throw new IllegalArgumentException(message);
        }
        @Nullable Tensor tensor1 = tensor.mapCoords(c -> tensor.get(c) * weights[c.getCoords()[2]]);
        tensor.freeRef();
        return tensor1;
    };
    Tensor[] data = inData.stream().parallel().map(tensorTensorFunction).toArray(Tensor[]::new);
    // Retain only after the forward pass succeeded, so a rejected input does not leak these
    // references; both are released in _free().
    inData.addRef();
    input.addRef();
    return new Result(TensorArray.wrap(data), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList delta) -> {
        if (!isFrozen()) {
            final Delta<Layer> deltaBuffer = buffer.get(ImgBandScaleLayer.this, weights);
            IntStream.range(0, delta.length()).forEach(index -> {
                @Nonnull int[] dimensions = delta.getDimensions();
                int z = dimensions[2];
                int y = dimensions[1];
                int x = dimensions[0];
                final double[] array = RecycleBin.DOUBLES.obtain(z);
                Tensor deltaTensor = delta.get(index);
                @Nullable final double[] deltaArray = deltaTensor.getData();
                Tensor inputTensor = inData.get(index);
                @Nullable final double[] inputData = inputTensor.getData();
                // Band i occupies the contiguous range [i*x*y, (i+1)*x*y) of the flat data.
                for (int i = 0; i < z; i++) {
                    for (int j = 0; j < y * x; j++) {
                        array[i] += deltaArray[i * x * y + j] * inputData[i * x * y + j];
                    }
                }
                inputTensor.freeRef();
                deltaTensor.freeRef();
                assert Arrays.stream(array).allMatch(v -> Double.isFinite(v));
                deltaBuffer.addInPlace(array);
                RecycleBin.DOUBLES.recycle(array, array.length);
            });
            deltaBuffer.freeRef();
        }
        if (input.isAlive()) {
            // d(out)/d(in) is the band weight, so the input gradient is delta scaled per band.
            Tensor[] tensors = delta.stream().map(t -> {
                @Nullable Tensor tensor = t.mapCoords((c) -> t.get(c) * weights[c.getCoords()[2]]);
                t.freeRef();
                return tensor;
            }).toArray(Tensor[]::new);
            @Nonnull TensorArray tensorArray = TensorArray.wrap(tensors);
            input.accumulate(buffer, tensorArray);
        }
    }) {

        @Override
        protected void _free() {
            inData.freeRef();
            input.freeRef();
        }

        @Override
        public boolean isAlive() {
            return input.isAlive() || !isFrozen();
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Util(com.simiacryptus.util.Util) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) IntToDoubleFunction(java.util.function.IntToDoubleFunction) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) Function(java.util.function.Function) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) JsonUtil(com.simiacryptus.util.io.JsonUtil) Delta(com.simiacryptus.mindseye.lang.Delta) RecycleBin(com.simiacryptus.mindseye.lang.RecycleBin) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) DoubleSupplier(java.util.function.DoubleSupplier) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Layer(com.simiacryptus.mindseye.lang.Layer) Result(com.simiacryptus.mindseye.lang.Result) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 12 with Tensor

use of com.simiacryptus.mindseye.lang.Tensor in project MindsEye by SimiaCryptus.

The class ImgConcatLayer, method eval.

/**
 * Concatenates 3-d image inputs along the band (third) dimension.
 *
 * <p>All inputs must share the batch size and the first two dimensions. When {@code maxBands}
 * is positive, the output is truncated to at most that many bands. Input data that falls past
 * the truncation point is dropped in the forward pass. In the backward pass, the corresponding
 * gradient slices are left as freshly constructed tensors.
 *
 * @param inObj the inputs to concatenate, in band order
 * @return the concatenated result
 */
@Nullable
@Override
public Result eval(@Nonnull final Result... inObj) {
    assert Arrays.stream(inObj).allMatch(x -> x.getData().getDimensions().length == 3) : "This component is for use with 3d image tensors only";
    final int numBatches = inObj[0].getData().length();
    assert Arrays.stream(inObj).allMatch(x -> x.getData().length() == numBatches) : "All inputs must use same batch size";
    @Nonnull final int[] outputDims = Arrays.copyOf(inObj[0].getData().getDimensions(), 3);
    outputDims[2] = Arrays.stream(inObj).mapToInt(x -> x.getData().getDimensions()[2]).sum();
    if (maxBands > 0) {
        outputDims[2] = Math.min(maxBands, outputDims[2]);
    }
    assert Arrays.stream(inObj).allMatch(x -> x.getData().getDimensions()[0] == outputDims[0]) : "Inputs must be same size";
    assert Arrays.stream(inObj).allMatch(x -> x.getData().getDimensions()[1] == outputDims[1]) : "Inputs must be same size";
    @Nonnull final Tensor[] outputTensors = new Tensor[numBatches];
    for (int b = 0; b < numBatches; b++) {
        @Nonnull final Tensor outputTensor = new Tensor(outputDims);
        @Nullable final double[] outputTensorData = outputTensor.getData();
        int pos = 0;
        for (int i = 0; i < inObj.length; i++) {
            @Nullable Tensor tensor = inObj[i].getData().get(b);
            @Nullable final double[] data = tensor.getData();
            // Each input is copied as one contiguous run after the previous one. Once maxBands
            // has truncated the output, pos can reach or pass the end; copy only what still fits
            // (System.arraycopy rejects a negative length and a destPos past the array end).
            final int count = Math.min(data.length, outputTensorData.length - pos);
            if (count > 0) {
                System.arraycopy(data, 0, outputTensorData, pos, count);
            }
            pos += data.length;
            tensor.freeRef();
        }
        outputTensors[b] = outputTensor;
    }
    // Retained for the backward pass; released in _free(). Taken last so a failing assertion
    // or copy above does not leak these references.
    Arrays.stream(inObj).forEach(nnResult -> nnResult.addRef());
    return new Result(TensorArray.wrap(outputTensors), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
        assert numBatches == data.length();
        // splitData[i][b] is the gradient slice routed to input i for batch item b.
        @Nonnull final Tensor[][] splitData = new Tensor[inObj.length][numBatches];
        for (int b = 0; b < numBatches; b++) {
            @Nullable final Tensor tensor = data.get(b);
            @Nullable final double[] tensorData = tensor.getData();
            int pos = 0;
            for (int i = 0; i < inObj.length; i++) {
                @Nonnull final Tensor dest = new Tensor(inObj[i].getData().getDimensions());
                // Slices beyond a maxBands truncation have no incoming gradient: leave dest as is.
                final int count = Math.min(dest.length(), tensorData.length - pos);
                if (count > 0) {
                    System.arraycopy(tensorData, pos, dest.getData(), 0, count);
                }
                pos += dest.length();
                splitData[i][b] = dest;
            }
            tensor.freeRef();
        }
        for (int i = 0; i < inObj.length; i++) {
            @Nonnull TensorArray tensorArray = TensorArray.wrap(splitData[i]);
            inObj[i].accumulate(buffer, tensorArray);
        }
    }) {

        @Override
        protected void _free() {
            Arrays.stream(inObj).forEach(nnResult -> nnResult.freeRef());
        }

        @Override
        public boolean isAlive() {
            for (@Nonnull final Result element : inObj) {
                if (element.isAlive()) {
                    return true;
                }
            }
            return false;
        }
    };
}
Also used : Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) ArrayList(java.util.ArrayList) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Result(com.simiacryptus.mindseye.lang.Result) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) Nullable(javax.annotation.Nullable) Nullable(javax.annotation.Nullable)

Example 13 with Tensor

use of com.simiacryptus.mindseye.lang.Tensor in project MindsEye by SimiaCryptus.

The class ValueLayer, method getJson.

/**
 * Serializes this layer to JSON.
 *
 * <p>The output is the JSON stub plus the first tensor of {@code tensorList} under
 * {@code "value"} and the precision name under {@code "precision"}.
 *
 * @param resources      resource map passed through to the tensor serializer
 * @param dataSerializer serializer used for the tensor data
 * @return the JSON representation of this layer
 */
@Nonnull
@Override
public JsonObject getJson(Map<CharSequence, byte[]> resources, @Nonnull DataSerializer dataSerializer) {
    @Nonnull final JsonObject json = super.getJsonStub();
    final Tensor tensor = tensorList.get(0);
    try {
        json.add("value", tensor.toJson(resources, dataSerializer));
    } finally {
        // Release the reference obtained from get(0) even if serialization fails.
        tensor.freeRef();
    }
    json.addProperty("precision", precision.name());
    return json;
}
Also used : CudaTensor(com.simiacryptus.mindseye.lang.cudnn.CudaTensor) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) JsonObject(com.google.gson.JsonObject) Nonnull(javax.annotation.Nonnull)

Example 14 with Tensor

use of com.simiacryptus.mindseye.lang.Tensor in project MindsEye by SimiaCryptus.

The class AvgMetaLayer, method eval.

/**
 * Averages the items of the input batch element-wise into a single tensor.
 *
 * <p>A fresh mean is computed, and cached in {@code lastResult}, when there is no cached result
 * or the batch has more than {@code minBatchCount} items. Otherwise the cached mean is reused
 * and no gradient is passed back. On the backward pass, every input item receives the first
 * delta tensor divided by the batch size.
 *
 * @param inObj the inputs; only {@code inObj[0]} is used
 * @return a result holding the single mean tensor
 */
@Nonnull
@Override
public Result eval(final Result... inObj) {
    final Result input = inObj[0];
    input.addRef();
    TensorList inputData = input.getData();
    final int itemCnt = inputData.length();
    @Nullable Tensor thisResult;
    boolean passback;
    if (null == lastResult || itemCnt > minBatchCount) {
        // Mean over the batch at coordinate c.
        @Nonnull final ToDoubleFunction<Coordinate> f = (c) -> IntStream.range(0, itemCnt).mapToDouble(dataIndex -> {
            Tensor tensor = inputData.get(dataIndex);
            double v = tensor.get(c);
            tensor.freeRef();
            return v;
        }).sum() / itemCnt;
        Tensor tensor = inputData.get(0);
        thisResult = tensor.mapCoords(f);
        tensor.freeRef();
        passback = true;
        if (null != lastResult) {
            lastResult.freeRef();
        }
        // One reference is kept by the lastResult cache; the other belongs to this call and is
        // released in _free().
        lastResult = thisResult;
        lastResult.addRef();
    } else {
        passback = false;
        thisResult = lastResult;
        // Take this call's own reference (released in _free()). Calling freeRef() here instead
        // would drop the cache's reference, so _free() would release the tensor a second time.
        thisResult.addRef();
    }
    return new Result(TensorArray.create(thisResult), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
        if (passback && input.isAlive()) {
            @Nullable final Tensor delta = data.get(0);
            @Nonnull final Tensor[] feedback = new Tensor[itemCnt];
            Arrays.parallelSetAll(feedback, i -> new Tensor(delta.getDimensions()));
            // Each item contributed 1/itemCnt to the mean, so each gets delta / itemCnt back.
            thisResult.coordStream(true).forEach((inputCoord) -> {
                for (int inputItem = 0; inputItem < itemCnt; inputItem++) {
                    feedback[inputItem].add(inputCoord, delta.get(inputCoord) / itemCnt);
                }
            });
            delta.freeRef();
            @Nonnull TensorArray tensorArray = TensorArray.wrap(feedback);
            input.accumulate(buffer, tensorArray);
        }
    }) {

        @Override
        public boolean isAlive() {
            return input.isAlive();
        }

        @Override
        protected void _free() {
            thisResult.freeRef();
            input.freeRef();
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Coordinate(com.simiacryptus.mindseye.lang.Coordinate) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) ToDoubleFunction(java.util.function.ToDoubleFunction) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Result(com.simiacryptus.mindseye.lang.Result) Coordinate(com.simiacryptus.mindseye.lang.Coordinate) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 15 with Tensor

use of com.simiacryptus.mindseye.lang.Tensor in project MindsEye by SimiaCryptus.

The class ExplodedConvolutionLeg, method read.

/**
 * Read tensor.
 *
 * @return the tensor
 */
/**
 * Reads the tensor assembled from each sub-layer's {@code kernel}.
 *
 * <p>The extractor passed to {@link #read(java.util.function.Function)} adds a reference to
 * each kernel before returning it. Presumably the overload releases that reference after use;
 * confirm against its implementation.
 *
 * @return the tensor
 */
@Nonnull
public Tensor read() {
    return read(leg -> {
        final Tensor k = leg.kernel;
        k.addRef();
        return k;
    });
}
Also used : Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull)

Aggregations

Tensor (com.simiacryptus.mindseye.lang.Tensor)183 Nonnull (javax.annotation.Nonnull)172 Nullable (javax.annotation.Nullable)137 Layer (com.simiacryptus.mindseye.lang.Layer)126 Arrays (java.util.Arrays)119 IntStream (java.util.stream.IntStream)109 List (java.util.List)108 Result (com.simiacryptus.mindseye.lang.Result)96 TensorList (com.simiacryptus.mindseye.lang.TensorList)96 TensorArray (com.simiacryptus.mindseye.lang.TensorArray)90 Logger (org.slf4j.Logger)81 LoggerFactory (org.slf4j.LoggerFactory)81 DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)80 Map (java.util.Map)72 NotebookOutput (com.simiacryptus.util.io.NotebookOutput)67 JsonObject (com.google.gson.JsonObject)59 DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer)56 LayerBase (com.simiacryptus.mindseye.lang.LayerBase)56 Collectors (java.util.stream.Collectors)51 Stream (java.util.stream.Stream)42