Search in sources :

Example 11 with Delta

use of com.simiacryptus.mindseye.lang.Delta in project MindsEye by SimiaCryptus.

From the class SingleDerivativeTester, method getFeedbackGradient.

/**
 * Collects, one output element at a time, the feedback that the component sends back to a
 * single selected input.
 *
 * <p>For every output index {@code j} the component is evaluated on the input prototypes, then
 * a delta tensor with a 1 at index {@code j} is passed to {@code accumulate}. Only the input at
 * {@code inputIndex} is wrapped in a live {@link Result}; its feedback handler copies the
 * received values into column {@code j} of a buffer stored under a placeholder key, and that
 * buffer is then added into the returned tensor. All other inputs report {@code isAlive() ==
 * false} and ignore feedback.
 *
 * @param component       the layer being evaluated
 * @param inputIndex      index into {@code inputPrototype} of the input whose feedback is collected
 * @param outputPrototype tensor whose dimensions and length describe the component output
 * @param inputPrototype  the input tensors passed to the component
 * @return a tensor created with dimensions (input length, output length) holding the collected feedback
 * @throws AssertionError if the feedback does not contain exactly one item or its dimensions
 *                        differ from those of the selected input
 */
@Nonnull
private Tensor getFeedbackGradient(@Nonnull final Layer component, final int inputIndex, @Nonnull final Tensor outputPrototype, @Nonnull final Tensor... inputPrototype) {
    final Tensor inputTensor = inputPrototype[inputIndex];
    final int inputDims = inputTensor.length();
    final int outputDims = outputPrototype.length();
    @Nonnull final Tensor result = new Tensor(inputDims, outputDims);
    for (int j = 0; j < outputDims; j++) {
        final int j_ = j;
        // Fresh key per output element so the feedback for column j can be looked up in the delta set.
        @Nonnull final PlaceholderLayer<Tensor> inputKey = new PlaceholderLayer<Tensor>(new Tensor(1));
        inputKey.getKey().freeRef();
        // Start with every input wrapped as a "dead" result that ignores feedback.
        final Result[] copyInput = Arrays.stream(inputPrototype).map(x -> new Result(TensorArray.create(x), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
        }) {

            @Override
            public boolean isAlive() {
                return false;
            }
        }).toArray(i -> new Result[i]);
        // Replace the selected input with a live result that records the feedback it receives.
        copyInput[inputIndex].getData().freeRef();
        copyInput[inputIndex].freeRef();
        double[] target = new double[inputDims * outputDims];
        copyInput[inputIndex] = new Result(TensorArray.create(inputTensor), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
            if (1 != data.length())
                throw new AssertionError("Expected exactly one feedback item but got " + data.length());
            // Validate before allocating so a failed check does not strand the buffer.
            if (!Arrays.equals(inputTensor.getDimensions(), data.getDimensions())) {
                throw new AssertionError("Feedback dimensions " + Arrays.toString(data.getDimensions()) + " do not match input dimensions " + Arrays.toString(inputTensor.getDimensions()));
            }
            @Nonnull final Tensor gradientBuffer = new Tensor(inputDims, outputDims);
            IntStream.range(0, data.length()).forEach(dataIndex -> {
                // Fetch the feedback tensor once per item rather than once per element.
                @Nullable Tensor tensor = data.get(dataIndex);
                for (int i = 0; i < inputDims; i++) {
                    gradientBuffer.set(new int[] { i, j_ }, tensor.getData()[i]);
                }
                tensor.freeRef();
            });
            buffer.get(inputKey, target).addInPlace(gradientBuffer.getData()).freeRef();
            gradientBuffer.freeRef();
        }) {

            @Override
            public boolean isAlive() {
                return true;
            }
        };
        @Nullable final Result eval;
        try {
            eval = component.eval(copyInput);
        } finally {
            for (@Nonnull Result nnResult : copyInput) {
                // Release the data before the wrapper, matching the order used elsewhere in this
                // method, so the wrapper is never touched after its reference has been released.
                nnResult.getData().freeRef();
                nnResult.freeRef();
            }
        }
        @Nonnull final DeltaSet<Layer> deltaSet = new DeltaSet<Layer>();
        // Delta with a 1 at output index j and default values elsewhere.
        @Nonnull TensorArray tensorArray = TensorArray.wrap(new Tensor(outputPrototype.getDimensions()).set(j, 1));
        try {
            eval.accumulate(deltaSet, tensorArray);
            final Delta<Layer> inputDelta = deltaSet.getMap().get(inputKey);
            // No entry means the live input never received feedback for this output element.
            if (null != inputDelta) {
                @Nonnull Tensor tensor = new Tensor(inputDelta.getDelta(), result.getDimensions());
                result.addInPlace(tensor);
                tensor.freeRef();
            }
        } finally {
            eval.getData().freeRef();
            eval.freeRef();
            deltaSet.freeRef();
            inputKey.freeRef();
        }
    }
    return result;
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Result(com.simiacryptus.mindseye.lang.Result) Collectors(java.util.stream.Collectors) Delta(com.simiacryptus.mindseye.lang.Delta) List(java.util.List) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Optional(java.util.Optional) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) SimpleEval(com.simiacryptus.mindseye.test.SimpleEval) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Result(com.simiacryptus.mindseye.lang.Result) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) Nullable(javax.annotation.Nullable) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Nonnull(javax.annotation.Nonnull)

Example 12 with Delta

use of com.simiacryptus.mindseye.lang.Delta in project MindsEye by SimiaCryptus.

From the class SingleDerivativeTester, method testFrozen.

/**
 * Checks the behaviour of a frozen copy of the component.
 *
 * <p>A frozen copy ({@code component.copy().freeze()}) is evaluated on copies of the input
 * prototypes and a copy of its own output is fed back through {@code accumulate}. The check
 * fails if the resulting delta set contains an entry whose target is one of the arrays returned
 * by {@code component.state()}, or if none of the input feedback callbacks was invoked even
 * though the inputs contain elements.
 *
 * @param component      the layer whose frozen copy is exercised
 * @param inputPrototype the tensors that are copied and used as inputs
 * @throws AssertionError   if a delta targeting the component state is found
 * @throws RuntimeException if feedback never reached the inputs although they are non-empty
 */
public void testFrozen(@Nonnull final Layer component, @Nonnull Tensor[] inputPrototype) {
    final int inElements = Arrays.stream(inputPrototype).mapToInt(prototype -> prototype.length()).sum();
    // Work on copies; the caller's tensors are not wrapped directly.
    final Tensor[] prototypeCopies = Arrays.stream(inputPrototype).map(prototype -> prototype.copy()).toArray(Tensor[]::new);
    @Nonnull final AtomicBoolean reachedInputFeedback = new AtomicBoolean(false);
    @Nonnull final Layer frozen = component.copy().freeze();
    final List<TensorArray> inputCopies = Arrays.stream(prototypeCopies).map(TensorArray::wrap).collect(Collectors.toList());
    // Every input is live and flags the shared marker when it receives feedback.
    final Result[] input = new Result[inputCopies.size()];
    for (int index = 0; index < input.length; index++) {
        input[index] = new Result(inputCopies.get(index), (@Nonnull final DeltaSet<Layer> sink, @Nonnull final TensorList feedback) -> {
            reachedInputFeedback.set(true);
        }) {

            @Override
            public boolean isAlive() {
                return true;
            }
        };
    }
    @Nullable final Result eval;
    try {
        eval = frozen.eval(input);
    } finally {
        for (@Nonnull Result wrapped : input) {
            wrapped.freeRef();
        }
        frozen.freeRef();
        for (@Nonnull TensorArray wrappedData : inputCopies) {
            wrappedData.freeRef();
        }
    }
    @Nonnull final DeltaSet<Layer> accumulated;
    final TensorList evalData = eval.getData();
    try {
        accumulated = new DeltaSet<Layer>();
        // Feed a copy of the output back as the delta.
        final TensorList outputCopy = evalData.copy();
        eval.accumulate(accumulated, outputCopy);
    } finally {
        evalData.freeRef();
        eval.freeRef();
    }
    // For each state array keep the first delta (if any) whose target is that very array (identity match).
    final List<Delta<Layer>> deltas = component.state().stream()
        .map(stateArray -> accumulated.stream().filter(delta -> delta.target == stateArray).findFirst().orElse(null))
        .filter(delta -> null != delta)
        .collect(Collectors.toList());
    accumulated.freeRef();
    if (!(deltas.isEmpty() || component.state().isEmpty())) {
        throw new AssertionError("Frozen component listed in delta. Deltas: " + deltas);
    }
    final boolean feedbackSeen = reachedInputFeedback.get();
    if (!feedbackSeen && inElements > 0) {
        throw new RuntimeException("Frozen component did not pass input backwards");
    }
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Result(com.simiacryptus.mindseye.lang.Result) Collectors(java.util.stream.Collectors) Delta(com.simiacryptus.mindseye.lang.Delta) List(java.util.List) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Optional(java.util.Optional) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) SimpleEval(com.simiacryptus.mindseye.test.SimpleEval) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Result(com.simiacryptus.mindseye.lang.Result) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) Delta(com.simiacryptus.mindseye.lang.Delta) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) Nullable(javax.annotation.Nullable)

Example 13 with Delta

use of com.simiacryptus.mindseye.lang.Delta in project MindsEye by SimiaCryptus.

From the class RecursiveSubspace, method buildSubspace.

/**
 * Creates a layer whose evaluation re-measures the subject after applying the negated
 * measurement delta, scaled by this object's {@code weights}.
 *
 * <p>The direction is {@code measurement.delta.scale(-1)}. Each non-placeholder layer found in
 * the direction gets its own weight slot; all {@link PlaceholderLayer} keys share slot 0 when
 * any are present. The {@code weights} array is re-allocated when it is missing or its length
 * does not match the number of slots.
 *
 * <p>On {@code eval} the returned layer restores the backed-up origin, calls {@code accumulate}
 * on every direction delta with its weight, measures the subject and returns a result wrapping
 * the measured mean. The result's feedback handler adds, per weight slot, the value
 * {@code b.dot(a) / max(sqrt(a.dot(a)), 1e-8)} (with {@code a} the direction delta and
 * {@code b} the freshly measured delta) to the buffer entry keyed by the layer and
 * {@code weights}.
 *
 * @param subject     the trainable that is re-measured on each evaluation
 * @param measurement the sample providing the origin state and the delta
 * @param monitor     receives log messages and is passed to {@code subject.measure}
 * @return the constructed layer
 */
@Nullable
public Layer buildSubspace(@Nonnull Trainable subject, @Nonnull PointSample measurement, @Nonnull TrainingMonitor monitor) {
    @Nonnull PointSample origin = measurement.copyFull().backup();
    @Nonnull final DeltaSet<Layer> direction = measurement.delta.scale(-1);
    final double magnitude = direction.getMagnitude();
    final double absMagnitude = Math.abs(magnitude);
    if (absMagnitude < 1e-10) {
        monitor.log(String.format("Zero gradient: %s", magnitude));
    } else if (absMagnitude < 1e-5) {
        monitor.log(String.format("Low gradient: %s", magnitude));
    }
    final boolean hasPlaceholders = direction.getMap().keySet().stream().anyMatch(key -> key instanceof PlaceholderLayer);
    final List<Layer> deltaLayers = direction.getMap().keySet().stream().filter(key -> !(key instanceof PlaceholderLayer)).collect(Collectors.toList());
    // One slot per regular layer, plus a shared leading slot for all placeholders.
    final int size = hasPlaceholders ? deltaLayers.size() + 1 : deltaLayers.size();
    if (null == weights || weights.length != size) {
        weights = new double[size];
    }
    return new LayerBase() {

        @Nonnull
        Layer self = this;

        /** Projection of the measured delta onto the direction delta for one layer, with a floor of 1e-8 on the divisor. */
        private double normalizedProjection(PointSample measure, Layer layer) {
            Delta<Layer> a = direction.getMap().get(layer);
            Delta<Layer> b = measure.delta.getMap().get(layer);
            return b.dot(a) / Math.max(Math.sqrt(a.dot(a)), 1e-8);
        }

        @Nonnull
        @Override
        public Result eval(Result... array) {
            assertAlive();
            origin.restore();
            // Regular layers start at slot 1 when slot 0 is reserved for placeholders.
            final int slotOffset = hasPlaceholders ? 1 : 0;
            for (int i = 0; i < deltaLayers.size(); i++) {
                direction.getMap().get(deltaLayers.get(i)).accumulate(weights[i + slotOffset]);
            }
            if (hasPlaceholders) {
                direction.getMap().entrySet().stream()
                    .filter(entry -> entry.getKey() instanceof PlaceholderLayer)
                    .distinct()
                    .forEach(entry -> entry.getValue().accumulate(weights[0]));
            }
            PointSample measure = subject.measure(monitor);
            double mean = measure.getMean();
            monitor.log(String.format("RecursiveSubspace: %s <- %s", mean, Arrays.toString(weights)));
            // Reference released again in the result's _free below.
            direction.addRef();
            return new Result(TensorArray.wrap(new Tensor(mean)), (DeltaSet<Layer> buffer, TensorList data) -> {
                final DoubleStream layerTerms = deltaLayers.stream().mapToDouble(layer -> normalizedProjection(measure, layer));
                final DoubleStream gradient;
                if (hasPlaceholders) {
                    // All placeholder contributions collapse into the single leading slot.
                    final double placeholderTerm = direction.getMap().keySet().stream()
                        .filter(key -> key instanceof PlaceholderLayer)
                        .distinct()
                        .mapToDouble(layer -> normalizedProjection(measure, layer))
                        .sum();
                    gradient = DoubleStream.concat(DoubleStream.of(placeholderTerm), layerTerms);
                } else {
                    gradient = layerTerms;
                }
                buffer.get(self, weights).addInPlace(gradient.toArray()).freeRef();
            }) {

                @Override
                protected void _free() {
                    measure.freeRef();
                    direction.freeRef();
                }

                @Override
                public boolean isAlive() {
                    return true;
                }
            };
        }

        @Override
        protected void _free() {
            direction.freeRef();
            origin.freeRef();
            super._free();
        }

        @Nonnull
        @Override
        public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
            throw new IllegalStateException();
        }

        @Nullable
        @Override
        public List<double[]> state() {
            return null;
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) StateSet(com.simiacryptus.mindseye.lang.StateSet) Trainable(com.simiacryptus.mindseye.eval.Trainable) Delta(com.simiacryptus.mindseye.lang.Delta) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) Collectors(java.util.stream.Collectors) DoubleStream(java.util.stream.DoubleStream) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) PointSample(com.simiacryptus.mindseye.lang.PointSample) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Result(com.simiacryptus.mindseye.lang.Result) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) DoubleStream(java.util.stream.DoubleStream) PointSample(com.simiacryptus.mindseye.lang.PointSample) Map(java.util.Map) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) 
DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) Nullable(javax.annotation.Nullable)

Aggregations

Delta (com.simiacryptus.mindseye.lang.Delta)13 DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)13 Layer (com.simiacryptus.mindseye.lang.Layer)13 Arrays (java.util.Arrays)13 List (java.util.List)13 Nonnull (javax.annotation.Nonnull)13 Nullable (javax.annotation.Nullable)13 Result (com.simiacryptus.mindseye.lang.Result)12 Tensor (com.simiacryptus.mindseye.lang.Tensor)12 TensorArray (com.simiacryptus.mindseye.lang.TensorArray)12 TensorList (com.simiacryptus.mindseye.lang.TensorList)12 Logger (org.slf4j.Logger)11 LoggerFactory (org.slf4j.LoggerFactory)11 IntStream (java.util.stream.IntStream)10 PlaceholderLayer (com.simiacryptus.mindseye.layers.java.PlaceholderLayer)8 Collectors (java.util.stream.Collectors)8 Map (java.util.Map)7 JsonObject (com.google.gson.JsonObject)6 ConstantResult (com.simiacryptus.mindseye.lang.ConstantResult)6 DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer)6