Search in sources:

Example 96 with TensorList

Use of com.simiacryptus.mindseye.lang.TensorList in the MindsEye project by SimiaCryptus.

From the class LinearActivationLayer, method eval.

/**
 * Applies the affine map {@code y = scale * x + bias} element-wise to every tensor of the first
 * input, where {@code scale = weights[0]} and {@code bias = weights[1]} are read when this method
 * is called.
 *
 * <p>The backward pass of the returned result:
 * <ul>
 *   <li>when this layer is not frozen, adds {@code [sum(delta * x), sum(delta)]} for each item to
 *       the weight entry in {@code buffer} (a single-element {@code x} is broadcast);</li>
 *   <li>when the input is alive, passes {@code delta * weights[0]} back to it.</li>
 * </ul>
 *
 * @param inObj inputs; only {@code inObj[0]} is used
 * @return a result holding the transformed tensors; it keeps references to the input, its data
 *     and the weights until it is freed
 */
@Nonnull
@Override
public Result eval(final Result... inObj) {
    final Result in0 = inObj[0];
    final TensorList inData = in0.getData();
    // These references are held for the lifetime of the returned Result and released in _free().
    in0.addRef();
    inData.addRef();
    final int itemCnt = inData.length();
    final double scale = weights.get(0);
    final double bias = weights.get(1);
    weights.addRef();
    return new Result(TensorArray.wrap(IntStream.range(0, itemCnt).mapToObj(dataIndex -> {
        @Nullable final Tensor input = inData.get(dataIndex);
        @Nullable final Tensor map = input.map(v -> scale * v + bias);
        input.freeRef();
        return map;
    }).toArray(Tensor[]::new)), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList delta) -> {
        if (!isFrozen()) {
            IntStream.range(0, delta.length()).forEach(dataIndex -> {
                @Nullable final Tensor deltaT = delta.get(dataIndex);
                @Nullable final Tensor inputT = inData.get(dataIndex);
                @Nullable final double[] deltaData = deltaT.getData();
                @Nullable final double[] inputData = inputT.getData();
                // A single-element input is broadcast against every delta element.
                final boolean broadcastInput = inputData.length == 1;
                // Accumulate in locals rather than calling Tensor.add twice per element.
                double scaleGradient = 0;
                double biasGradient = 0;
                for (int i = 0; i < deltaData.length; i++) {
                    scaleGradient += deltaData[i] * inputData[broadcastInput ? 0 : i];
                    biasGradient += deltaData[i];
                }
                @Nonnull final Tensor weightDelta = new Tensor(weights.getDimensions());
                weightDelta.set(0, scaleGradient);
                weightDelta.set(1, biasGradient);
                buffer.get(LinearActivationLayer.this, weights.getData()).addInPlace(weightDelta.getData()).freeRef();
                inputT.freeRef();
                deltaT.freeRef();
                weightDelta.freeRef();
            });
        }
        if (in0.isAlive()) {
            // Loop-invariant: read the scale weight once per backward pass instead of once per element.
            // NOTE(review): this is the weight value at backward time, not the `scale` captured by the
            // forward pass; the two differ only if the weights change in between — confirm intended.
            final double passbackScale = weights.getData()[0];
            @Nonnull final TensorList tensorList = TensorArray.wrap(IntStream.range(0, delta.length()).mapToObj(dataIndex -> {
                @Nullable final Tensor tensor = delta.get(dataIndex);
                @Nullable final double[] deltaData = tensor.getData();
                @Nonnull final Tensor passback = new Tensor(inData.getDimensions());
                for (int i = 0; i < passback.length(); i++) {
                    passback.set(i, deltaData[i] * passbackScale);
                }
                tensor.freeRef();
                return passback;
            }).toArray(Tensor[]::new));
            in0.accumulate(buffer, tensorList);
        }
    }) {

        @Override
        public boolean isAlive() {
            return in0.isAlive() || !isFrozen();
        }

        @Override
        protected void _free() {
            weights.freeRef();
            inData.freeRef();
            in0.freeRef();
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Nullable(javax.annotation.Nullable) Result(com.simiacryptus.mindseye.lang.Result) Nonnull(javax.annotation.Nonnull)

Example 97 with TensorList

Use of com.simiacryptus.mindseye.lang.TensorList in the MindsEye project by SimiaCryptus.

From the class LoggingWrapperLayer, method eval.

/**
 * Evaluates the inner layer and logs every tensor list that flows through it.
 *
 * <p>Each input is wrapped so that the gradient it later receives is logged ("Feedback Output")
 * before being forwarded to the original input. The inputs and the inner layer's output are
 * logged immediately. The returned result logs the gradient it receives ("Feedback Input")
 * before passing it on to the inner layer's output.
 *
 * @param inObj the inputs to forward to the inner layer
 * @return a result wrapping the inner layer's output
 */
@Override
public Result eval(@Nonnull final Result... inObj) {
    final Result[] wrappedInput = IntStream.range(0, inObj.length).mapToObj(i -> {
        final Result inputToWrap = inObj[i];
        // Released in the wrapper's _free().
        inputToWrap.addRef();
        return new Result(inputToWrap.getData(), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
            @Nonnull final String formatted = formatIndented(data);
            log.info(String.format("Feedback Output %s for layer %s: \n\t%s", i, getInner().getName(), formatted));
            // Presumably accumulate() takes ownership of one reference, so one is added here — confirm.
            data.addRef();
            inputToWrap.accumulate(buffer, data);
        }) {

            @Override
            protected void _free() {
                inputToWrap.freeRef();
            }

            @Override
            public boolean isAlive() {
                return inputToWrap.isAlive();
            }
        };
    }).toArray(Result[]::new);
    for (int i = 0; i < inObj.length; i++) {
        @Nonnull final String formatted = formatIndented(inObj[i].getData());
        log.info(String.format("Input %s for layer %s: \n\t%s", i, getInner().getName(), formatted));
    }
    @Nullable final Result output = getInner().eval(wrappedInput);
    Arrays.stream(wrappedInput).forEach(ReferenceCounting::freeRef);
    {
        @Nonnull final String formatted = formatIndented(output.getData());
        log.info(String.format("Output for layer %s: \n\t%s", getInner().getName(), formatted));
    }
    return new Result(output.getData(), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
        @Nonnull final String formatted = formatIndented(data);
        log.info(String.format("Feedback Input for layer %s: \n\t%s", getInner().getName(), formatted));
        data.addRef();
        output.accumulate(buffer, data);
    }) {

        @Override
        protected void _free() {
            output.freeRef();
        }

        @Override
        public boolean isAlive() {
            return output.isAlive();
        }
    };
}

/**
 * Pretty-prints every tensor in {@code tensorList} and joins the results with newlines. A tab is
 * inserted after every newline, including newlines inside a single tensor's rendering. Each tensor
 * obtained from the stream is released after printing.
 *
 * <p>An empty list yields {@code ""}. The previous {@code reduce(...).get()} form threw
 * {@code NoSuchElementException} in that case, which crashed the wrapped evaluation.
 *
 * @param tensorList the tensors to render
 * @return the indented, newline-joined rendering
 */
private static String formatIndented(final TensorList tensorList) {
    final String[] rendered = tensorList.stream().map(x -> {
        final String str = x.prettyPrint();
        x.freeRef();
        return str;
    }).toArray(String[]::new);
    // Literal replacement; equivalent to the former replaceAll("\n", "\n\t") without a regex.
    return String.join("\n", rendered).replace("\n", "\n\t");
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Result(com.simiacryptus.mindseye.lang.Result) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Layer(com.simiacryptus.mindseye.lang.Layer) Result(com.simiacryptus.mindseye.lang.Result) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Nullable(javax.annotation.Nullable)

Example 98 with TensorList

Use of com.simiacryptus.mindseye.lang.TensorList in the MindsEye project by SimiaCryptus.

From the class MaxDropoutNoiseLayer, method eval.

/**
 * Keeps, within every cell of each input tensor, only the element with the largest input value
 * and zeroes all other elements.
 *
 * <p>Cells are obtained from {@code getCellMap_cached} for the tensor's dimensions. A 0/1 mask is
 * built for each item and retained for the backward pass. That pass multiplies the incoming
 * gradient by the same mask and hands it to the input, but only when the input is alive.
 *
 * @param inObj inputs; only {@code inObj[0]} is used
 * @return the masked tensors
 */
@Nonnull
@Override
public Result eval(final Result... inObj) {
    final Result source = inObj[0];
    final TensorList sourceData = source.getData();
    final int count = sourceData.length();
    source.addRef();
    sourceData.addRef();

    // One 0/1 selection tensor per item: the arg-max coordinate of every cell is set to 1.
    final Tensor[] masks = new Tensor[count];
    for (int item = 0; item < count; item++) {
        @Nullable final Tensor original = sourceData.get(item);
        @Nullable final Tensor selection = original.map(x -> 0);
        final List<List<Coordinate>> cells = getCellMap_cached.apply(new IntArray(selection.getDimensions()));
        for (final List<Coordinate> cell : cells) {
            final Coordinate winner = cell.stream().max(Comparator.comparingDouble(c -> original.get(c))).get();
            selection.set(winner, 1);
        }
        original.freeRef();
        masks[item] = selection;
    }

    // Forward output: element-wise product of each input item with its mask.
    final Tensor[] gatedItems = new Tensor[count];
    for (int item = 0; item < count; item++) {
        final Tensor original = sourceData.get(item);
        @Nullable final double[] originalValues = original.getData();
        @Nullable final double[] maskValues = masks[item].getData();
        @Nonnull final Tensor gated = new Tensor(original.getDimensions());
        @Nullable final double[] gatedValues = gated.getData();
        for (int k = 0; k < gatedValues.length; k++) {
            gatedValues[k] = originalValues[k] * maskValues[k];
        }
        original.freeRef();
        gatedItems[item] = gated;
    }

    return new Result(TensorArray.wrap(gatedItems), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList delta) -> {
        if (!source.isAlive()) {
            return;
        }
        final int deltaCount = delta.length();
        final Tensor[] passbacks = new Tensor[deltaCount];
        for (int item = 0; item < deltaCount; item++) {
            final Tensor gradient = delta.get(item);
            @Nullable final double[] gradientValues = gradient.getData();
            @Nonnull final int[] dims = sourceData.getDimensions();
            @Nullable final double[] maskValues = masks[item].getData();
            @Nonnull final Tensor passback = new Tensor(dims);
            for (int k = 0; k < passback.length(); k++) {
                passback.set(k, maskValues[k] * gradientValues[k]);
            }
            gradient.freeRef();
            passbacks[item] = passback;
        }
        @Nonnull final TensorArray passbackArray = TensorArray.wrap(passbacks);
        source.accumulate(buffer, passbackArray);
    }) {

        @Override
        public boolean isAlive() {
            return source.isAlive() || !isFrozen();
        }

        @Override
        protected void _free() {
            source.freeRef();
            sourceData.freeRef();
            for (final Tensor selection : masks) {
                selection.freeRef();
            }
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Coordinate(com.simiacryptus.mindseye.lang.Coordinate) Arrays(java.util.Arrays) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) IntArray(com.simiacryptus.util.data.IntArray) Function(java.util.function.Function) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) ArrayList(java.util.ArrayList) JsonUtil(com.simiacryptus.util.io.JsonUtil) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Util(com.simiacryptus.util.Util) Logger(org.slf4j.Logger) Collectors(java.util.stream.Collectors) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Comparator(java.util.Comparator) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Result(com.simiacryptus.mindseye.lang.Result) IntArray(com.simiacryptus.util.data.IntArray) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Coordinate(com.simiacryptus.mindseye.lang.Coordinate) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) ArrayList(java.util.ArrayList) List(java.util.List) TensorList(com.simiacryptus.mindseye.lang.TensorList) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 99 with TensorList

Use of com.simiacryptus.mindseye.lang.TensorList in the MindsEye project by SimiaCryptus.

From the class SoftmaxActivationLayer, method eval.

/**
 * Computes a softmax over each item of {@code inObj[0]}.
 *
 * <p>For numerical stability, the largest finite input value is subtracted before
 * exponentiating. Any exponential that is not finite is replaced by 0. If an item's exponentials
 * sum to 0, a sum of 1 is used instead, so that item's output is all zeros.
 *
 * <p>The un-normalized exponentials and their sums are kept for the backward pass. When
 * {@code inObj[0]} is alive, that pass sends it
 * {@code (sum * delta_i - dot) * exp_i / sum^2}, where {@code dot = sum_j delta_j * exp_j}.
 *
 * @param inObj inputs; every element is ref-counted for the life of the result, but only
 *     {@code inObj[0]} is read
 * @return a result holding the softmax outputs
 */
@Nonnull
@Override
public Result eval(@Nonnull final Result... inObj) {
    final TensorList inData = inObj[0].getData();
    final int itemCnt = inData.length();
    @Nonnull final double[] sumA = new double[itemCnt];
    Arrays.stream(inObj).forEach(nnResult -> nnResult.addRef());
    @Nonnull final Tensor[] expA = new Tensor[itemCnt];
    // Sequential stream: the lambda fills expA/sumA by index as a side effect.
    final Tensor[] outputA = IntStream.range(0, itemCnt).mapToObj(dataIndex -> {
        @Nullable final Tensor input = inData.get(dataIndex);
        assert 1 < input.length() : "input.length() = " + input.length();
        final DoubleSummaryStatistics summaryStatistics = DoubleStream.of(input.getData()).filter(Double::isFinite).summaryStatistics();
        final double max = summaryStatistics.getMax();
        @Nullable final Tensor exp = input.map(x -> {
            final double xx = Math.exp(x - max);
            return Double.isFinite(xx) ? xx : 0;
        });
        input.freeRef();
        assert Arrays.stream(exp.getData()).allMatch(Double::isFinite);
        assert Arrays.stream(exp.getData()).allMatch(v -> v >= 0);
        // Compute the sum once; fall back to 1 when every exponential is 0, to avoid dividing by zero.
        final double rawSum = exp.sum();
        final double sum = 0 < rawSum ? rawSum : 1;
        assert Double.isFinite(sum);
        expA[dataIndex] = exp;
        sumA[dataIndex] = sum;
        return exp.map(x -> x / sum);
    }).toArray(Tensor[]::new);
    assert Arrays.stream(outputA).flatMapToDouble(x -> Arrays.stream(x.getData())).allMatch(Double::isFinite);
    return new Result(TensorArray.wrap(outputA), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
        if (inObj[0].isAlive()) {
            final Tensor[] passbackA = IntStream.range(0, itemCnt).mapToObj(dataIndex -> {
                final Tensor deltaTensor = data.get(dataIndex);
                @Nullable final double[] delta = deltaTensor.getData();
                @Nullable final double[] expdata = expA[dataIndex].getData();
                @Nonnull final Tensor passback = new Tensor(data.getDimensions());
                double dot = 0;
                for (int i = 0; i < expdata.length; i++) {
                    dot += delta[i] * expdata[i];
                }
                final double sum = sumA[dataIndex];
                final double sumSquared = sum * sum;
                for (int i = 0; i < expdata.length; i++) {
                    passback.set(i, (sum * delta[i] - dot) * expdata[i] / sumSquared);
                }
                deltaTensor.freeRef();
                return passback;
            }).toArray(Tensor[]::new);
            assert Arrays.stream(passbackA).flatMapToDouble(x -> Arrays.stream(x.getData())).allMatch(Double::isFinite);
            @Nonnull final TensorArray tensorArray = TensorArray.wrap(passbackA);
            inObj[0].accumulate(buffer, tensorArray);
        }
    }) {

        @Override
        protected void _free() {
            Arrays.stream(expA).forEach(ReferenceCounting::freeRef);
            Arrays.stream(inObj).forEach(ReferenceCounting::freeRef);
        }

        @Override
        public boolean isAlive() {
            return inObj[0].isAlive();
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) DoubleSummaryStatistics(java.util.DoubleSummaryStatistics) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) DoubleStream(java.util.stream.DoubleStream) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) DoubleSummaryStatistics(java.util.DoubleSummaryStatistics) TensorList(com.simiacryptus.mindseye.lang.TensorList) Result(com.simiacryptus.mindseye.lang.Result) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) Nonnull(javax.annotation.Nonnull)

Example 100 with TensorList

Use of com.simiacryptus.mindseye.lang.TensorList in the MindsEye project by SimiaCryptus.

From the class StaticScalarLossLayer, method eval.

/**
 * Absolute-error loss of each input item's first element against {@code getTarget()}.
 *
 * <p>Each output item is a one-element tensor holding {@code |x[0] - getTarget()|}, where
 * {@code x} is the matching input item. The backward pass sends {@code delta[0] * sign} to the
 * input, with {@code sign = -1} when {@code x[0] - getTarget()} is negative and {@code +1}
 * otherwise. Both passes run as parallel streams.
 *
 * @param inObj exactly one input
 * @return the per-item loss values
 * @throws IllegalArgumentException if {@code inObj} does not contain exactly one element
 */
@Nonnull
@Override
public Result eval(@Nonnull final Result... inObj) {
    if (inObj.length != 1) {
        throw new IllegalArgumentException();
    }
    for (final Result input : inObj) {
        input.addRef();
    }
    final Result in0 = inObj[0];
    final TensorList indata = in0.getData();
    indata.addRef();
    final Tensor[] losses = IntStream.range(0, indata.length()).parallel().mapToObj(index -> {
        @Nullable final Tensor item = indata.get(index);
        final double distance = Math.abs(item.get(0) - getTarget());
        item.freeRef();
        return new Tensor(new double[] { distance }, 1);
    }).toArray(Tensor[]::new);
    return new Result(TensorArray.wrap(losses), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
        if (!in0.isAlive()) {
            return;
        }
        final Tensor[] gradients = IntStream.range(0, data.length()).parallel().mapToObj(index -> {
            @Nullable final Tensor item = indata.get(index);
            final Tensor upstream = data.get(index);
            final double sign = item.get(0) - getTarget() < 0 ? -1 : 1;
            final double deriv = upstream.get(0) * sign;
            upstream.freeRef();
            item.freeRef();
            return new Tensor(new double[] { deriv }, 1);
        }).toArray(Tensor[]::new);
        @Nonnull final TensorArray gradientArray = TensorArray.wrap(gradients);
        in0.accumulate(buffer, gradientArray);
    }) {

        @Override
        public boolean isAlive() {
            return in0.isAlive();
        }

        @Override
        protected void _free() {
            indata.freeRef();
            for (final Result input : inObj) {
                input.freeRef();
            }
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Nullable(javax.annotation.Nullable) Result(com.simiacryptus.mindseye.lang.Result) Nonnull(javax.annotation.Nonnull)

Aggregations

TensorList (com.simiacryptus.mindseye.lang.TensorList)110 Nonnull (javax.annotation.Nonnull)109 Nullable (javax.annotation.Nullable)103 Result (com.simiacryptus.mindseye.lang.Result)95 Arrays (java.util.Arrays)93 Layer (com.simiacryptus.mindseye.lang.Layer)91 Tensor (com.simiacryptus.mindseye.lang.Tensor)88 DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)87 IntStream (java.util.stream.IntStream)82 List (java.util.List)80 TensorArray (com.simiacryptus.mindseye.lang.TensorArray)76 Map (java.util.Map)68 JsonObject (com.google.gson.JsonObject)64 DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer)63 LayerBase (com.simiacryptus.mindseye.lang.LayerBase)61 Logger (org.slf4j.Logger)57 LoggerFactory (org.slf4j.LoggerFactory)57 ReferenceCounting (com.simiacryptus.mindseye.lang.ReferenceCounting)33 CudaTensor (com.simiacryptus.mindseye.lang.cudnn.CudaTensor)30 CudaTensorList (com.simiacryptus.mindseye.lang.cudnn.CudaTensorList)30