Search in sources :

Example 16 with Tensor

Use of com.simiacryptus.mindseye.lang.Tensor in project MindsEye by SimiaCryptus.

Used in class ExplodedConvolutionLeg, method write.

/**
 * Slices a full convolution kernel into the per-sub-layer kernels of this exploded leg.
 *
 * <p>For every sub-layer a new kernel with {@code inputBands * inputBands} bands is built. Each
 * band is copied from {@code filter} at the band index chosen by {@code getFilterBand}, or filled
 * with zero when that index falls outside the filter's band range. The resulting kernel is stored
 * into the matching entry of {@code subKernels}, and the local reference is then released.
 * Sub-layers are processed on a parallel stream.
 *
 * @param filter the full kernel; its dimensions must equal the master filter dimensions with the
 *               band axis (index 2) set to {@code inputBands * outputBands} (checked by assertion)
 * @return this leg, for chaining
 */
@Nonnull
public ExplodedConvolutionLeg write(@Nonnull Tensor filter) {
    final int inBands = getInputBands();
    final int[] masterDims = this.convolutionParams.masterFilterDimensions;
    @Nonnull final int[] dims = Arrays.copyOf(masterDims, masterDims.length);
    final int outBands = this.convolutionParams.outputBands;
    // Smallest multiple of inBands that is >= the configured output band count.
    final int paddedOutBands = (int) (Math.ceil(convolutionParams.outputBands * 1.0 / inBands) * inBands);
    assert paddedOutBands >= convolutionParams.outputBands : String.format("%d >= %d", paddedOutBands, convolutionParams.outputBands);
    assert paddedOutBands % inBands == 0 : String.format("%d %% %d", paddedOutBands, inBands);
    dims[2] = inBands * outBands;
    assert Arrays.equals(filter.getDimensions(), dims) : Arrays.toString(filter.getDimensions()) + " != " + Arrays.toString(dims);
    final int bandsPerSubKernel = inBands * inBands;
    final int subLayerCount = subLayers.size();
    IntStream.range(0, subLayerCount).parallel().forEach(idx -> {
        final int bandOffset = idx * bandsPerSubKernel;
        @Nonnull Tensor subKernel = new Tensor(dims[0], dims[1], bandsPerSubKernel).setByCoord(coord -> {
            final int[] xyz = coord.getCoords();
            final int sourceBand = getFilterBand(bandOffset, xyz[2], paddedOutBands);
            // Bands past the end of the filter are zero-padded.
            return sourceBand < dims[2] ? filter.get(xyz[0], xyz[1], sourceBand) : 0;
        }, true);
        subKernels.get(idx).set(subKernel);
        subKernel.freeRef();
    });
    return this;
}
Also used: Tensor(com.simiacryptus.mindseye.lang.Tensor), Nonnull(javax.annotation.Nonnull)

Example 17 with Tensor

Use of com.simiacryptus.mindseye.lang.Tensor in project MindsEye by SimiaCryptus.

Used in class CudaMemory, method read.

/**
 * Copies this device memory into a newly allocated {@link Tensor}.
 *
 * <p>{@code synchronize()} is called first. For {@code Float} precision the buffer is read into a
 * temporary {@code float[]} and widened element by element into the tensor's {@code double[]}
 * storage; for {@code Double} precision it is read directly into that storage. If the read fails,
 * the newly allocated tensor is released before the exception propagates, so no reference leaks.
 *
 * @param precision  the precision the buffer is read at
 * @param dimensions the dimensions of the tensor to allocate
 * @return a newly allocated tensor holding the copied data
 * @throws IllegalStateException if {@code precision} is neither {@code Float} nor {@code Double}
 */
@Nonnull
public Tensor read(@Nonnull final Precision precision, final int[] dimensions) {
    synchronize();
    @Nonnull final Tensor tensor = new Tensor(dimensions);
    try {
        switch (precision) {
            case Float:
                final int length = tensor.length();
                @Nonnull final float[] data = new float[length];
                read(precision, data);
                final double[] doubles = tensor.getData();
                for (int i = 0; i < length; i++) {
                    doubles[i] = data[i];
                }
                break;
            case Double:
                read(precision, tensor.getData());
                break;
            default:
                throw new IllegalStateException("Unsupported precision: " + precision);
        }
    } catch (RuntimeException e) {
        // The tensor was allocated above and is never returned on this path; release it to avoid a leak.
        tensor.freeRef();
        throw e;
    }
    return tensor;
}
Also used: Tensor(com.simiacryptus.mindseye.lang.Tensor), Nonnull(javax.annotation.Nonnull), Nullable(javax.annotation.Nullable)

Example 18 with Tensor

Use of com.simiacryptus.mindseye.lang.Tensor in project MindsEye by SimiaCryptus.

Used in class CudaTensorList, method addAndFree.

/**
 * Adds {@code right} to this list element-wise and consumes this list's reference.
 *
 * <p>Paths, in order:
 * <ul>
 *   <li>A {@link ReshapedTensorList} argument is unwrapped to its inner list.</li>
 *   <li>If this list has other holders ({@code currentRefCount() > 1}), delegate to {@code add}
 *       and then release this reference.</li>
 *   <li>If both lists are GPU-resident with the same precision and both have device memory, run on
 *       the GPU: in place when the executing device matches this list's memory device, otherwise
 *       via {@code add} followed by releasing this reference.</li>
 *   <li>Otherwise, sum element-wise on the heap into a new {@link TensorArray} and release this
 *       reference.</li>
 * </ul>
 *
 * @param right the list to add; must have the same length as this list (checked by assertion)
 * @return the sum
 * @throws IllegalArgumentException if {@code right} is non-empty but this list is empty
 */
@Override
public TensorList addAndFree(@Nonnull final TensorList right) {
    assertAlive();
    right.assertAlive();
    if (right instanceof ReshapedTensorList)
        return addAndFree(((ReshapedTensorList) right).getInner());
    if (1 < currentRefCount()) {
        // Another holder still references this list, so it must not be mutated in place.
        TensorList sum = add(right);
        freeRef();
        return sum;
    }
    assert length() == right.length();
    if (heapCopy == null) {
        if (right instanceof CudaTensorList) {
            @Nonnull final CudaTensorList nativeRight = (CudaTensorList) right;
            if (nativeRight.getPrecision() == this.getPrecision()) {
                if (nativeRight.heapCopy == null) {
                    assert (!nativeRight.gpuCopy.equals(CudaTensorList.this.gpuCopy));
                    // Each operand's memory is read from its own list; previously both locals aliased
                    // this list's memory, so the right operand was never actually checked.
                    CudaMemory leftMem = gpuCopy.memory;
                    CudaMemory rightMem = null == nativeRight.gpuCopy ? null : nativeRight.gpuCopy.memory;
                    if (null != leftMem && null != rightMem)
                        return CudaSystem.run(gpu -> {
                            if (gpu.getDeviceId() == leftMem.getDeviceId()) {
                                return gpu.addInPlace(this, nativeRight);
                            } else {
                                assertAlive();
                                right.assertAlive();
                                TensorList add = add(right);
                                freeRef();
                                return add;
                            }
                        }, this, right);
                }
            }
        }
    }
    if (right.length() == 0)
        return this;
    if (length() == 0)
        throw new IllegalArgumentException();
    assert length() == right.length();
    final TensorList sum = TensorArray.wrap(IntStream.range(0, length()).mapToObj(i -> {
        Tensor a = get(i);
        Tensor b = right.get(i);
        @Nullable Tensor r = a.addAndFree(b);
        b.freeRef();
        return r;
    }).toArray(i -> new Tensor[i]));
    // This method consumes this list's reference on every path; the heap path previously leaked it.
    freeRef();
    return sum;
}
Also used: IntStream(java.util.stream.IntStream), Arrays(java.util.Arrays), Logger(org.slf4j.Logger), LoggerFactory(org.slf4j.LoggerFactory), Tensor(com.simiacryptus.mindseye.lang.Tensor), TestUtil(com.simiacryptus.mindseye.test.TestUtil), ReshapedTensorList(com.simiacryptus.mindseye.lang.ReshapedTensorList), Stream(java.util.stream.Stream), RegisteredObjectBase(com.simiacryptus.mindseye.lang.RegisteredObjectBase), TensorList(com.simiacryptus.mindseye.lang.TensorList), TimedResult(com.simiacryptus.util.lang.TimedResult), TensorArray(com.simiacryptus.mindseye.lang.TensorArray), Nonnull(javax.annotation.Nonnull), Nullable(javax.annotation.Nullable)

Example 19 with Tensor

Use of com.simiacryptus.mindseye.lang.Tensor in project MindsEye by SimiaCryptus.

Used in class CudaTensorList, method add.

/**
 * Returns the element-wise sum of this list and {@code right} without consuming this list's reference.
 *
 * <p>A {@link ReshapedTensorList} argument is unwrapped to its inner list. When both lists are
 * GPU-resident with the same precision, the sum is computed on the GPU; otherwise it is computed on
 * the heap into a new {@link TensorArray}.
 *
 * @param right the list to add; must have the same length as this list (checked by assertion)
 * @return the sum, as a reference independent of this list's own reference
 * @throws IllegalArgumentException if {@code right} is non-empty but this list is empty
 */
@Override
public TensorList add(@Nonnull final TensorList right) {
    assertAlive();
    right.assertAlive();
    assert length() == right.length();
    if (right instanceof ReshapedTensorList)
        return add(((ReshapedTensorList) right).getInner());
    if (heapCopy == null) {
        if (right instanceof CudaTensorList) {
            @Nonnull final CudaTensorList nativeRight = (CudaTensorList) right;
            if (nativeRight.getPrecision() == this.getPrecision()) {
                if (nativeRight.heapCopy == null) {
                    return CudaSystem.run(gpu -> {
                        return gpu.add(this, nativeRight);
                    }, this);
                }
            }
        }
    }
    if (right.length() == 0) {
        // Callers may release this list after receiving the result (addAndFree does exactly that),
        // so handing back `this` requires taking an extra reference for the caller.
        addRef();
        return this;
    }
    if (length() == 0)
        throw new IllegalArgumentException();
    assert length() == right.length();
    return TensorArray.wrap(IntStream.range(0, length()).mapToObj(i -> {
        Tensor a = get(i);
        Tensor b = right.get(i);
        @Nullable Tensor r = a.addAndFree(b);
        b.freeRef();
        return r;
    }).toArray(i -> new Tensor[i]));
}
Also used: IntStream(java.util.stream.IntStream), Arrays(java.util.Arrays), Logger(org.slf4j.Logger), LoggerFactory(org.slf4j.LoggerFactory), Tensor(com.simiacryptus.mindseye.lang.Tensor), TestUtil(com.simiacryptus.mindseye.test.TestUtil), ReshapedTensorList(com.simiacryptus.mindseye.lang.ReshapedTensorList), Stream(java.util.stream.Stream), RegisteredObjectBase(com.simiacryptus.mindseye.lang.RegisteredObjectBase), TensorList(com.simiacryptus.mindseye.lang.TensorList), TimedResult(com.simiacryptus.util.lang.TimedResult), TensorArray(com.simiacryptus.mindseye.lang.TensorArray), Nonnull(javax.annotation.Nonnull), Nullable(javax.annotation.Nullable)

Example 20 with Tensor

Use of com.simiacryptus.mindseye.lang.Tensor in project MindsEye by SimiaCryptus.

Used in class BiasMetaLayer, method eval.

/**
 * Adds the first item of the second input, as a bias, to every item of the first input.
 *
 * <p>Forward: output item {@code k} is {@code inObj[0][k] + inObj[1][0]}, computed coordinate-wise.
 * Backward: the incoming delta is passed through unchanged to {@code inObj[0]}. The bias delta is the
 * sum of the incoming delta over all items. It is accumulated into slot 0 of {@code inObj[1]}, and
 * any further slots receive zeros.
 *
 * @param inObj exactly two inputs: the data to bias, and the bias source (only its first item is used)
 * @return the result holding the biased tensors and the backward-pass accumulator
 */
@Nullable
@Override
public Result eval(@Nonnull final Result... inObj) {
    final int itemCnt = inObj[0].getData().length();
    // Only the first item of the second input is used as the bias.
    Tensor tensor1 = inObj[1].getData().get(0);
    final Tensor[] tensors = IntStream.range(0, itemCnt).parallel().mapToObj(dataIndex -> {
        Tensor tensor = inObj[0].getData().get(dataIndex);
        Tensor mapIndex = tensor.mapIndex((v, c) -> {
            return v + tensor1.get(c);
        });
        tensor.freeRef();
        return mapIndex;
    }).toArray(i -> new Tensor[i]);
    tensor1.freeRef();
    // Kept alive as a shape template for building the bias gradient in the backward pass.
    Tensor tensor0 = tensors[0];
    tensor0.addRef();
    Arrays.stream(inObj).forEach(nnResult -> nnResult.addRef());
    return new Result(TensorArray.wrap(tensors), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
        if (inObj[0].isAlive()) {
            data.addRef();
            inObj[0].accumulate(buffer, data);
        }
        if (inObj[1].isAlive()) {
            // Bias gradient at coordinate c: sum of the incoming delta over all items.
            @Nonnull final ToDoubleFunction<Coordinate> f = (c) -> {
                return IntStream.range(0, itemCnt).mapToDouble(i -> {
                    Tensor tensor = data.get(i);
                    double v = tensor.get(c);
                    tensor.freeRef();
                    return v;
                }).sum();
            };
            @Nullable final Tensor passback = tensor0.mapCoords(f);
            @Nonnull TensorArray tensorArray = TensorArray.wrap(IntStream.range(0, inObj[1].getData().length()).mapToObj(i -> {
                if (i == 0)
                    return passback;
                // Slot 0 owns passback's reference, so it must not be released here; the previous
                // freeRef() over-released it once per extra bias item.
                return passback.map(v -> 0);
            }).toArray(i -> new Tensor[i]));
            inObj[1].accumulate(buffer, tensorArray);
        }
    }) {

        @Override
        protected void _free() {
            tensor0.freeRef();
            Arrays.stream(inObj).forEach(nnResult -> nnResult.freeRef());
        }

        @Override
        public boolean isAlive() {
            return inObj[0].isAlive() || inObj[1].isAlive();
        }
    };
}
Also used: IntStream(java.util.stream.IntStream), JsonObject(com.google.gson.JsonObject), Coordinate(com.simiacryptus.mindseye.lang.Coordinate), Arrays(java.util.Arrays), Logger(org.slf4j.Logger), LoggerFactory(org.slf4j.LoggerFactory), Tensor(com.simiacryptus.mindseye.lang.Tensor), Result(com.simiacryptus.mindseye.lang.Result), DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer), List(java.util.List), LayerBase(com.simiacryptus.mindseye.lang.LayerBase), ToDoubleFunction(java.util.function.ToDoubleFunction), TensorList(com.simiacryptus.mindseye.lang.TensorList), Map(java.util.Map), Layer(com.simiacryptus.mindseye.lang.Layer), TensorArray(com.simiacryptus.mindseye.lang.TensorArray), DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet), Nonnull(javax.annotation.Nonnull), Nullable(javax.annotation.Nullable)

Aggregations

Tensor (com.simiacryptus.mindseye.lang.Tensor): 183, Nonnull (javax.annotation.Nonnull): 172, Nullable (javax.annotation.Nullable): 137, Layer (com.simiacryptus.mindseye.lang.Layer): 126, Arrays (java.util.Arrays): 119, IntStream (java.util.stream.IntStream): 109, List (java.util.List): 108, Result (com.simiacryptus.mindseye.lang.Result): 96, TensorList (com.simiacryptus.mindseye.lang.TensorList): 96, TensorArray (com.simiacryptus.mindseye.lang.TensorArray): 90, Logger (org.slf4j.Logger): 81, LoggerFactory (org.slf4j.LoggerFactory): 81, DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet): 80, Map (java.util.Map): 72, NotebookOutput (com.simiacryptus.util.io.NotebookOutput): 67, JsonObject (com.google.gson.JsonObject): 59, DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer): 56, LayerBase (com.simiacryptus.mindseye.lang.LayerBase): 56, Collectors (java.util.stream.Collectors): 51, Stream (java.util.stream.Stream): 42