Search in sources :

Example 41 with DeltaSet

use of com.simiacryptus.mindseye.lang.DeltaSet in project MindsEye by SimiaCryptus.

The eval method of the class ProductInputsLayer.

/**
 * Multiplies the data of all inputs together, tensor by tensor.
 * <p>
 * Forward pass: the inputs' tensor lists are reduced pairwise with {@code Tensor.product};
 * a list that holds a single tensor is reused for every index of a longer list.
 * <p>
 * Backward pass: for every input that reports {@code isAlive()}, the incoming delta is
 * multiplied by the data of all the other inputs, summed down when that input was smaller
 * than the result, and handed to {@code input.accumulate}.
 *
 * @param inObj the inputs to multiply; the assertion expects more than one
 * @return a Result wrapping the product and the backpropagation callback
 * @throws IllegalArgumentException if {@code Tensor.length} of an input's dimensions differs
 *                                  from that of the first input and neither value is 1
 */
@Nonnull
@Override
public Result eval(@Nonnull final Result... inObj) {
    assert inObj.length > 1;
    // Hold the inputs and their data until the returned Result is freed (released in _free()).
    Arrays.stream(inObj).forEach(x -> x.getData().addRef());
    Arrays.stream(inObj).forEach(nnResult -> nnResult.addRef());
    // Every input is checked against the first one only: sizes must match unless one side is 1.
    for (int i = 1; i < inObj.length; i++) {
        final int dim0 = Tensor.length(inObj[0].getData().getDimensions());
        final int dimI = Tensor.length(inObj[i].getData().getDimensions());
        if (dim0 != 1 && dimI != 1 && dim0 != dimI) {
            throw new IllegalArgumentException(Arrays.toString(inObj[0].getData().getDimensions()) + " != " + Arrays.toString(inObj[i].getData().getDimensions()));
        }
    }
    return new Result(Arrays.stream(inObj).parallel().map(x -> {
        TensorList data = x.getData();
        // Extra reference per operand; the reduce step below releases both of its operands.
        data.addRef();
        return data;
    }).reduce((l, r) -> {
        TensorArray productArray = TensorArray.wrap(IntStream.range(0, Math.max(l.length(), r.length())).parallel().mapToObj(i1 -> {
            // A single-tensor list always contributes its tensor 0, whatever the index.
            @Nullable final Tensor left = l.get(1 == l.length() ? 0 : i1);
            @Nullable final Tensor right = r.get(1 == r.length() ? 0 : i1);
            Tensor product = Tensor.product(left, right);
            left.freeRef();
            right.freeRef();
            return product;
        }).toArray(i -> new Tensor[i]));
        l.freeRef();
        r.freeRef();
        return productArray;
    }).get(), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList delta) -> {
        for (@Nonnull final Result input : inObj) {
            if (input.isAlive()) {
                // Same pairwise product as the forward pass, with delta substituted for this input's own data.
                @Nonnull TensorList passback = Arrays.stream(inObj).parallel().map(x -> {
                    TensorList tensorList = x == input ? delta : x.getData();
                    tensorList.addRef();
                    return tensorList;
                }).reduce((l, r) -> {
                    TensorArray productList = TensorArray.wrap(IntStream.range(0, Math.max(l.length(), r.length())).parallel().mapToObj(j -> {
                        @Nullable final Tensor left = l.get(1 == l.length() ? 0 : j);
                        @Nullable final Tensor right = r.get(1 == r.length() ? 0 : j);
                        Tensor product = Tensor.product(left, right);
                        left.freeRef();
                        right.freeRef();
                        return product;
                    }).toArray(j -> new Tensor[j]));
                    l.freeRef();
                    r.freeRef();
                    return productList;
                }).get();
                final TensorList inputData = input.getData();
                // The input held one tensor but the passback holds several: add them all into a single tensor.
                if (1 == inputData.length() && 1 < passback.length()) {
                    TensorArray newValue = TensorArray.wrap(passback.stream().reduce((a, b) -> {
                        @Nullable Tensor c = a.addAndFree(b);
                        b.freeRef();
                        return c;
                    }).get());
                    passback.freeRef();
                    passback = newValue;
                }
                // The input's Tensor.length is 1 but the passback's is larger: replace each tensor by one built from its sum().
                if (1 == Tensor.length(inputData.getDimensions()) && 1 < Tensor.length(passback.getDimensions())) {
                    TensorArray newValue = TensorArray.wrap(passback.stream().map((a) -> {
                        @Nonnull Tensor b = new Tensor(a.sum());
                        a.freeRef();
                        return b;
                    }).toArray(i -> new Tensor[i]));
                    passback.freeRef();
                    passback = newValue;
                }
                input.accumulate(buffer, passback);
            }
        }
    }) {

        /** Reports true as soon as any one of the inputs is alive. */
        @Override
        public boolean isAlive() {
            for (@Nonnull final Result element : inObj) if (element.isAlive()) {
                return true;
            }
            return false;
        }

        /** Releases the input and input-data references taken at the top of eval. */
        @Override
        protected void _free() {
            Arrays.stream(inObj).forEach(nnResult -> nnResult.freeRef());
            Arrays.stream(inObj).forEach(x -> x.getData().freeRef());
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Nullable(javax.annotation.Nullable) Result(com.simiacryptus.mindseye.lang.Result) Nonnull(javax.annotation.Nonnull)

Example 42 with DeltaSet

use of com.simiacryptus.mindseye.lang.DeltaSet in project MindsEye by SimiaCryptus.

The eval method of the class ReLuActivationLayer.

/**
 * Scales every input tensor by the layer weight and clamps negative results to zero.
 * <p>
 * The backpropagation callback accumulates a weight delta (unless the layer is frozen) and
 * passes {@code delta * weight} back to the input; in both cases positions whose original
 * input value is negative contribute zero.
 *
 * @param inObj the inputs; only the first one is read
 * @return a Result wrapping the activations and the backpropagation callback
 */
@Nonnull
@Override
public Result eval(final Result... inObj) {
    final Result source = inObj[0];
    final TensorList sourceData = source.getData();
    // Keep the input, its data and the weights referenced until the returned Result is freed.
    source.addRef();
    sourceData.addRef();
    weights.addRef();
    final int batchSize = sourceData.length();
    final Tensor[] activations = IntStream.range(0, batchSize).parallel().mapToObj(index -> {
        @Nullable final Tensor original = sourceData.get(index);
        @Nonnull final Tensor scaled = original.multiply(weights.get(0));
        original.freeRef();
        // Clamp in place on the freshly produced tensor's backing array.
        @Nullable final double[] values = scaled.getData();
        for (int k = 0; k < values.length; k++) {
            if (values[k] < 0) {
                values[k] = 0;
            }
        }
        return scaled;
    }).toArray(Tensor[]::new);
    return new Result(TensorArray.wrap(activations), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList delta) -> {
        if (!isFrozen()) {
            // Weight gradient: one accumulation into the shared buffer per item of the batch.
            IntStream.range(0, delta.length()).parallel().forEach(index -> {
                @Nullable final Tensor deltaTensor = delta.get(index);
                @Nullable final double[] deltaValues = deltaTensor.getData();
                @Nullable final Tensor inputTensor = sourceData.get(index);
                @Nullable final double[] inputValues = inputTensor.getData();
                @Nonnull final Tensor weightDelta = new Tensor(weights.getDimensions());
                @Nullable final double[] weightDeltaValues = weightDelta.getData();
                for (int k = 0; k < deltaValues.length; k++) {
                    final double contribution;
                    if (inputValues[k] < 0) {
                        contribution = 0;
                    } else {
                        contribution = deltaValues[k] * inputValues[k];
                    }
                    weightDeltaValues[0] += contribution;
                }
                buffer.get(ReLuActivationLayer.this, weights.getData()).addInPlace(weightDeltaValues).freeRef();
                deltaTensor.freeRef();
                inputTensor.freeRef();
                weightDelta.freeRef();
            });
        }
        if (source.isAlive()) {
            final double weight = weights.getData()[0];
            // Input gradient: delta scaled by the weight, zeroed where the input was negative.
            final Tensor[] gradients = IntStream.range(0, delta.length()).parallel().mapToObj(index -> {
                @Nullable final Tensor deltaTensor = delta.get(index);
                @Nullable final double[] deltaValues = deltaTensor.getData();
                @Nullable final Tensor inputTensor = sourceData.get(index);
                @Nullable final double[] inputValues = inputTensor.getData();
                @Nonnull final int[] shape = inputTensor.getDimensions();
                @Nonnull final Tensor passback = new Tensor(shape);
                for (int k = 0; k < passback.length(); k++) {
                    final double gradient;
                    if (inputValues[k] < 0) {
                        gradient = 0;
                    } else {
                        gradient = deltaValues[k] * weight;
                    }
                    passback.set(k, gradient);
                }
                inputTensor.freeRef();
                deltaTensor.freeRef();
                return passback;
            }).toArray(Tensor[]::new);
            @Nonnull final TensorArray tensorArray = TensorArray.wrap(gradients);
            source.accumulate(buffer, tensorArray);
        }
    }) {

        /** Alive when the input is alive or when the layer's weight is not frozen. */
        @Override
        public boolean isAlive() {
            if (source.isAlive()) {
                return true;
            }
            return !isFrozen();
        }

        /** Releases the references taken at the top of eval. */
        @Override
        protected void _free() {
            source.freeRef();
            sourceData.freeRef();
            weights.freeRef();
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Util(com.simiacryptus.util.Util) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) DoubleSupplier(java.util.function.DoubleSupplier) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Nullable(javax.annotation.Nullable) Result(com.simiacryptus.mindseye.lang.Result) Nonnull(javax.annotation.Nonnull)

Example 43 with DeltaSet

use of com.simiacryptus.mindseye.lang.DeltaSet in project MindsEye by SimiaCryptus.

The eval method of the class ScaleMetaLayer.

/**
 * Scales every tensor of the first input, element by element, by the first tensor of the
 * second input.
 * <p>
 * Backward pass: the first input receives {@code delta * scale}; the second input receives,
 * in slot 0, the per-element sum over all items of {@code delta * input}, and zero-valued
 * tensors in any further slots.
 * <p>
 * Reference handling: the inputs, their data and the template tensor are retained for the
 * lifetime of the returned Result and released exactly once in {@code _free()}. Tensors
 * obtained through {@code get(...)} are released as soon as they have been read, matching
 * the get/freeRef pairing already used for the scale tensor in the backward pass.
 *
 * @param inObj two inputs: the data to scale and the per-element scale factors
 * @return a Result wrapping the scaled tensors and the backpropagation callback
 */
@Nullable
@Override
public Result eval(@Nonnull final Result... inObj) {
    final int itemCnt = inObj[0].getData().length();
    // Retained until the returned Result is freed; see _free().
    Arrays.stream(inObj).forEach(nnResult -> nnResult.addRef());
    Arrays.stream(inObj).forEach(x -> x.getData().addRef());
    // Fetch the scale tensor once rather than once per element, and release it after use.
    @Nullable final Tensor scale = inObj[1].getData().get(0);
    final Tensor[] tensors = IntStream.range(0, itemCnt).mapToObj(dataIndex -> {
        @Nullable final Tensor inputTensor = inObj[0].getData().get(dataIndex);
        @Nullable final Tensor scaled = inputTensor.mapIndex((v, c) -> v * scale.get(c));
        inputTensor.freeRef();
        return scaled;
    }).toArray(i -> new Tensor[i]);
    scale.freeRef();
    // Only the layout of tensor0 is used later (its values are ignored by the mapIndex call below).
    final Tensor tensor0 = tensors[0];
    tensor0.addRef();
    return new Result(TensorArray.wrap(tensors), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
        if (inObj[0].isAlive()) {
            @Nonnull TensorArray tensorArray = TensorArray.wrap(data.stream().map(t -> {
                @Nullable Tensor t1 = inObj[1].getData().get(0);
                @Nullable Tensor tensor = t.mapIndex((v, c) -> {
                    return v * t1.get(c);
                });
                t.freeRef();
                t1.freeRef();
                return tensor;
            }).toArray(i -> new Tensor[i]));
            inObj[0].accumulate(buffer, tensorArray);
        }
        if (inObj[1].isAlive()) {
            @Nullable final Tensor passback = tensor0.mapIndex((v, c) -> {
                return IntStream.range(0, itemCnt).mapToDouble(i -> {
                    @Nullable final Tensor deltaTensor = data.get(i);
                    @Nullable final Tensor inputTensor = inObj[0].getData().get(i);
                    final double term = deltaTensor.get(c) * inputTensor.get(c);
                    deltaTensor.freeRef();
                    inputTensor.freeRef();
                    return term;
                }).sum();
            });
            @Nonnull TensorArray tensorArray = TensorArray.wrap(IntStream.range(0, inObj[1].getData().length()).mapToObj(i -> i == 0 ? passback : passback.map(v -> 0)).toArray(i -> new Tensor[i]));
            inObj[1].accumulate(buffer, tensorArray);
        }
    }) {

        /**
         * Releases everything retained by eval: the template tensor, the inputs' data and
         * the inputs themselves. The data is released before its owning Result.
         */
        @Override
        protected void _free() {
            tensor0.freeRef();
            Arrays.stream(inObj).forEach(x -> x.getData().freeRef());
            Arrays.stream(inObj).forEach(nnResult -> nnResult.freeRef());
        }

        /** Alive when either of the two inputs is alive. */
        @Override
        public boolean isAlive() {
            return inObj[0].isAlive() || inObj[1].isAlive();
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Nullable(javax.annotation.Nullable) Result(com.simiacryptus.mindseye.lang.Result) Nullable(javax.annotation.Nullable)

Example 44 with DeltaSet

use of com.simiacryptus.mindseye.lang.DeltaSet in project MindsEye by SimiaCryptus.

The orient method of the class QuantifyOrientationWrapper.

/**
 * Delegates to the wrapped orientation strategy and logs, per weight group, the ratio of the
 * search direction's RMS to the weights' RMS.
 * <p>
 * Groups are keyed by {@code getId}. A group with a single member is logged as a plain
 * number; larger groups are logged as a {@code DoubleStatistics} summary. Cursors that are
 * not a {@code SimpleLineSearchCursor} are only reported as such.
 *
 * @param subject     the trainable being optimized
 * @param measurement the current point sample
 * @param monitor     receives the log line
 * @return the cursor produced by the wrapped strategy, unchanged
 */
@Override
public LineSearchCursor orient(final Trainable subject, final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
    final LineSearchCursor cursor = inner.orient(subject, measurement, monitor);
    if (!(cursor instanceof SimpleLineSearchCursor)) {
        monitor.log(String.format("Non-simple cursor: %s", cursor));
        return cursor;
    }
    final SimpleLineSearchCursor simpleCursor = (SimpleLineSearchCursor) cursor;
    final DeltaSet<Layer> direction = simpleCursor.direction;
    @Nonnull final StateSet<Layer> weights = simpleCursor.origin.weights;
    final Map<CharSequence, CharSequence> dataMap = weights.stream()
        .collect(Collectors.groupingBy(state -> getId(state), Collectors.toList()))
        .entrySet().stream()
        .collect(Collectors.toMap(group -> group.getKey(), group -> {
            final List<Double> ratios = group.getValue().stream().map(state -> {
                final DoubleBuffer<Layer> dirDelta = direction.getMap().get(state.layer);
                final double weightRms = state.deltaStatistics().rms();
                final double directionRms;
                if (null == dirDelta) {
                    directionRms = 0;
                } else {
                    directionRms = dirDelta.deltaStatistics().rms();
                }
                // A zero weight RMS is replaced by 1 so the ratio stays finite.
                final double divisor = 0 == weightRms ? 1 : weightRms;
                return directionRms / divisor;
            }).collect(Collectors.toList());
            if (1 == ratios.size()) {
                return Double.toString(ratios.get(0));
            }
            final double[] values = ratios.stream().mapToDouble(Double::doubleValue).toArray();
            return new DoubleStatistics().accept(values).toString();
        }));
    monitor.log(String.format("Line search stats: %s", dataMap));
    return cursor;
}
Also used : DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Collectors(java.util.stream.Collectors) StateSet(com.simiacryptus.mindseye.lang.StateSet) Trainable(com.simiacryptus.mindseye.eval.Trainable) List(java.util.List) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nonnull(javax.annotation.Nonnull) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) List(java.util.List)

Example 45 with DeltaSet

use of com.simiacryptus.mindseye.lang.DeltaSet in project MindsEye by SimiaCryptus.

The orient method of the class TrustRegionStrategy.

/**
 * Wraps the inner strategy's cursor so that every proposed position is passed through the
 * per-layer trust region (see {@code getRegionPolicy}) before it is used.
 * <p>
 * The given origin is pushed onto the front of {@code history}, which is then trimmed from
 * the back to at most {@code maxHistory} entries.
 *
 * @param subject the trainable that is measured at each step
 * @param origin  the current point sample; recorded in the history
 * @param monitor passed to the inner strategy
 * @return a cursor that delegates to the inner cursor and applies the trust-region projection
 */
@Nonnull
@Override
public LineSearchCursor orient(@Nonnull final Trainable subject, final PointSample origin, final TrainingMonitor monitor) {
    // Newest sample first; drop the oldest entries beyond maxHistory.
    history.add(0, origin);
    while (history.size() > maxHistory) {
        history.remove(history.size() - 1);
    }
    final SimpleLineSearchCursor cursor = inner.orient(subject, origin, monitor);
    return new LineSearchCursorBase() {

        /** Returns the inner cursor's direction type with "+Trust" appended. */
        @Nonnull
        @Override
        public CharSequence getDirectionType() {
            return cursor.getDirectionType() + "+Trust";
        }

        /**
         * Returns the inner cursor's position for {@code alpha} after the trust-region
         * projection has adjusted its deltas in place.
         */
        @Nonnull
        @Override
        public DeltaSet<Layer> position(final double alpha) {
            reset();
            @Nonnull final DeltaSet<Layer> adjustedPosVector = cursor.position(alpha);
            // NOTE(review): the DeltaSet returned by project(...) is discarded here; only its
            // in-place adjustment of adjustedPosVector is used.
            project(adjustedPosVector, new TrainingMonitor());
            return adjustedPosVector;
        }

        /**
         * Projects the proposed deltas onto each layer's trust region and derives a matching
         * search-direction derivative.
         * <p>
         * The delta arrays inside {@code deltaIn} are overwritten in place whenever the
         * region returns a different array than the proposed position. The returned set
         * starts as a copy of the inner cursor's direction; for each projected layer the
         * component along the projection normal is subtracted from it.
         *
         * @param deltaIn the proposed deltas; modified in place
         * @param monitor only referenced by the commented-out logging below
         * @return the adjusted copy of the inner cursor's direction
         */
        @Nonnull
        public DeltaSet<Layer> project(@Nonnull final DeltaSet<Layer> deltaIn, final TrainingMonitor monitor) {
            final DeltaSet<Layer> originalAlphaDerivative = cursor.direction;
            @Nonnull final DeltaSet<Layer> newAlphaDerivative = originalAlphaDerivative.copy();
            deltaIn.getMap().forEach((layer, buffer) -> {
                @Nullable final double[] delta = buffer.getDelta();
                if (null == delta)
                    return;
                final double[] currentPosition = buffer.target;
                @Nullable final double[] originalAlphaD = originalAlphaDerivative.get(layer, currentPosition).getDelta();
                @Nullable final double[] newAlphaD = newAlphaDerivative.get(layer, currentPosition).getDelta();
                // Proposed position = current position + delta (presumably elementwise; see ArrayUtil.add).
                @Nonnull final double[] proposedPosition = ArrayUtil.add(currentPosition, delta);
                final TrustRegion region = getRegionPolicy(layer);
                if (null != region) {
                    // Deltas recorded for this layer in the history; samples without an entry are filtered out below.
                    final Stream<double[]> zz = history.stream().map((@Nonnull final PointSample x) -> {
                        final DoubleBuffer<Layer> d = x.weights.getMap().get(layer);
                        @Nullable final double[] z = null == d ? null : d.getDelta();
                        return z;
                    });
                    final double[] projectedPosition = region.project(zz.filter(x -> null != x).toArray(i -> new double[i][]), proposedPosition);
                    // Identity comparison: a different array instance is treated as "the region changed the position".
                    if (projectedPosition != proposedPosition) {
                        for (int i = 0; i < projectedPosition.length; i++) {
                            delta[i] = projectedPosition[i] - currentPosition[i];
                        }
                        // Displacement introduced by the projection.
                        @Nonnull final double[] normal = ArrayUtil.subtract(projectedPosition, proposedPosition);
                        final double normalMagSq = ArrayUtil.dot(normal, normal);
                        // normalMagSq));
                        if (0 < normalMagSq) {
                            final double a = ArrayUtil.dot(originalAlphaD, normal);
                            // NOTE(review): the comparison against -1 looks arbitrary; confirm the intended condition.
                            if (a != -1) {
                                // tangent = originalAlphaD - normal * (a / normalMagSq)
                                @Nonnull final double[] tangent = ArrayUtil.add(originalAlphaD, ArrayUtil.multiply(normal, -a / normalMagSq));
                                for (int i = 0; i < tangent.length; i++) {
                                    newAlphaD[i] = tangent[i];
                                }
                            // double newAlphaDerivSq = ArrayUtil.dot(tangent, tangent);
                            // double originalAlphaDerivSq = ArrayUtil.dot(originalAlphaD, originalAlphaD);
                            // assert(newAlphaDerivSq <= originalAlphaDerivSq);
                            // assert(Math.abs(ArrayUtil.dot(tangent, normal)) <= 1e-4);
                            // monitor.log(String.format("%s: normalMagSq = %s, newAlphaDerivSq = %s, originalAlphaDerivSq = %s", layer, normalMagSq, newAlphaDerivSq, originalAlphaDerivSq));
                            }
                        }
                    }
                }
            });
            return newAlphaDerivative;
        }

        /** Resets the inner cursor. */
        @Override
        public void reset() {
            cursor.reset();
        }

        /**
         * Takes a step of size {@code alpha}: resets the inner cursor, projects the proposed
         * position, calls {@code accumulate(1)} on it, measures the subject, and pairs the
         * sample with the dot product of the adjusted gradient and the sample's delta.
         */
        @Nonnull
        @Override
        public LineSearchPoint step(final double alpha, final TrainingMonitor monitor) {
            cursor.reset();
            @Nonnull final DeltaSet<Layer> adjustedPosVector = cursor.position(alpha);
            @Nonnull final DeltaSet<Layer> adjustedGradient = project(adjustedPosVector, monitor);
            adjustedPosVector.accumulate(1);
            @Nonnull final PointSample sample = subject.measure(monitor).setRate(alpha);
            return new LineSearchPoint(sample, adjustedGradient.dot(sample.delta));
        }

        /** Releases the reference held on the inner cursor. */
        @Override
        public void _free() {
            cursor.freeRef();
        }
    };
}
Also used : TrustRegion(com.simiacryptus.mindseye.opt.region.TrustRegion) IntStream(java.util.stream.IntStream) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) TrustRegion(com.simiacryptus.mindseye.opt.region.TrustRegion) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) ArrayUtil(com.simiacryptus.util.ArrayUtil) Trainable(com.simiacryptus.mindseye.eval.Trainable) List(java.util.List) Stream(java.util.stream.Stream) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) LinkedList(java.util.LinkedList) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) LineSearchCursorBase(com.simiacryptus.mindseye.opt.line.LineSearchCursorBase) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Layer(com.simiacryptus.mindseye.lang.Layer) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) PointSample(com.simiacryptus.mindseye.lang.PointSample) LineSearchCursorBase(com.simiacryptus.mindseye.opt.line.LineSearchCursorBase) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Aggregations

DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)98 Nonnull (javax.annotation.Nonnull)98 Result (com.simiacryptus.mindseye.lang.Result)90 Nullable (javax.annotation.Nullable)88 Layer (com.simiacryptus.mindseye.lang.Layer)86 TensorList (com.simiacryptus.mindseye.lang.TensorList)86 Arrays (java.util.Arrays)77 List (java.util.List)75 Tensor (com.simiacryptus.mindseye.lang.Tensor)73 Map (java.util.Map)66 IntStream (java.util.stream.IntStream)65 TensorArray (com.simiacryptus.mindseye.lang.TensorArray)64 JsonObject (com.google.gson.JsonObject)62 DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer)61 LayerBase (com.simiacryptus.mindseye.lang.LayerBase)60 Logger (org.slf4j.Logger)47 LoggerFactory (org.slf4j.LoggerFactory)47 ReferenceCounting (com.simiacryptus.mindseye.lang.ReferenceCounting)23 CudaTensor (com.simiacryptus.mindseye.lang.cudnn.CudaTensor)22 CudaTensorList (com.simiacryptus.mindseye.lang.cudnn.CudaTensorList)22