Search in sources:

Example 96 with DeltaSet

Use of com.simiacryptus.mindseye.lang.DeltaSet in the MindsEye project by SimiaCryptus.

From class RecursiveSubspace, method buildSubspace.

/**
 * Builds a one-output layer whose trainable parameters are step weights along the negative
 * gradient of {@code measurement}.
 *
 * <p>Every non-placeholder layer present in the gradient gets its own weight. All
 * {@link PlaceholderLayer} entries share a single weight, which is stored at index 0 when any
 * placeholder is present. The weights live in the {@code weights} field. That field is
 * reallocated (zero-filled) only when its length no longer matches the number of weights needed,
 * so otherwise values persist between calls.
 *
 * <p>Each {@code eval} restores the backed-up origin state and applies each layer's direction
 * delta via {@code accumulate(weight)}. It then re-measures {@code subject} and emits the mean
 * as a single-element tensor. The backward pass reports, for each weight, the dot product of the
 * newly measured gradient with that weight's direction, divided by the direction's norm (floored
 * at 1e-8).
 *
 * <p>The returned layer cannot be serialized: {@code getJson} throws
 * {@link IllegalStateException} and {@code state()} returns {@code null}.
 *
 * @param subject     the trainable re-measured on every {@code eval}
 * @param measurement the starting point; its delta defines the search direction (negated)
 * @param monitor     receives gradient-magnitude warnings and per-eval progress lines
 * @return the subspace layer
 */
@Nullable
public Layer buildSubspace(@Nonnull Trainable subject, @Nonnull PointSample measurement, @Nonnull TrainingMonitor monitor) {
    @Nonnull PointSample origin = measurement.copyFull().backup();
    // Search direction is the negated gradient.
    @Nonnull final DeltaSet<Layer> direction = measurement.delta.scale(-1);
    final double magnitude = direction.getMagnitude();
    if (Math.abs(magnitude) < 1e-10) {
        monitor.log(String.format("Zero gradient: %s", magnitude));
    } else if (Math.abs(magnitude) < 1e-5) {
        monitor.log(String.format("Low gradient: %s", magnitude));
    }
    boolean hasPlaceholders = direction.getMap().keySet().stream().anyMatch(x -> x instanceof PlaceholderLayer);
    List<Layer> deltaLayers = direction.getMap().entrySet().stream().map(x -> x.getKey()).filter(x -> !(x instanceof PlaceholderLayer)).collect(Collectors.toList());
    // Placeholders (if any) share one weight at index 0; the remaining layers follow in deltaLayers order.
    int size = deltaLayers.size() + (hasPlaceholders ? 1 : 0);
    if (null == weights || weights.length != size)
        weights = new double[size];
    return new LayerBase() {

        @Nonnull
        Layer self = this;

        @Nonnull
        @Override
        public Result eval(Result... array) {
            assertAlive();
            // Start every evaluation from the same backed-up point.
            origin.restore();
            for (int i = 0; i < deltaLayers.size(); i++) {
                direction.getMap().get(deltaLayers.get(i)).accumulate(weights[hasPlaceholders ? (i + 1) : i]);
            }
            if (hasPlaceholders) {
                direction.getMap().entrySet().stream().filter(x -> x.getKey() instanceof PlaceholderLayer).forEach(entry -> entry.getValue().accumulate(weights[0]));
            }
            PointSample measure = subject.measure(monitor);
            double mean = measure.getMean();
            monitor.log(String.format("RecursiveSubspace: %s <- %s", mean, Arrays.toString(weights)));
            // Extra reference held by the Result below; released in its _free().
            direction.addRef();
            return new Result(TensorArray.wrap(new Tensor(mean)), (DeltaSet<Layer> buffer, TensorList data) -> {
                // NOTE(review): the incoming feedback `data` is not used, so this gradient assumes
                // unit feedback on the scalar output — confirm this matches how callers train it.
                DoubleStream deltaStream = deltaLayers.stream().mapToDouble(layer -> directionalGradient(direction, measure, layer));
                if (hasPlaceholders) {
                    // Shared placeholder weight: sum the projections of all placeholder layers.
                    double placeholderGradient = direction.getMap().keySet().stream().filter(x -> x instanceof PlaceholderLayer).mapToDouble(layer -> directionalGradient(direction, measure, layer)).sum();
                    deltaStream = DoubleStream.concat(DoubleStream.of(placeholderGradient), deltaStream);
                }
                buffer.get(self, weights).addInPlace(deltaStream.toArray()).freeRef();
            }) {

                @Override
                protected void _free() {
                    measure.freeRef();
                    direction.freeRef();
                }

                @Override
                public boolean isAlive() {
                    return true;
                }
            };
        }

        @Override
        protected void _free() {
            direction.freeRef();
            origin.freeRef();
            super._free();
        }

        @Nonnull
        @Override
        public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
            throw new IllegalStateException();
        }

        @Nullable
        @Override
        public List<double[]> state() {
            return null;
        }
    };
}

/**
 * Projects the measured gradient of {@code layer} onto its search direction.
 *
 * @param direction the search-direction delta set
 * @param measure   the freshly measured point whose delta holds the gradient
 * @param layer     the layer whose entries are projected
 * @return {@code gradient . direction / max(|direction|, 1e-8)}
 */
private static double directionalGradient(@Nonnull DeltaSet<Layer> direction, @Nonnull PointSample measure, Layer layer) {
    Delta<Layer> a = direction.getMap().get(layer);
    Delta<Layer> b = measure.delta.getMap().get(layer);
    // The 1e-8 floor avoids dividing by zero for a vanishing direction.
    return b.dot(a) / Math.max(Math.sqrt(a.dot(a)), 1e-8);
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) StateSet(com.simiacryptus.mindseye.lang.StateSet) Trainable(com.simiacryptus.mindseye.eval.Trainable) Delta(com.simiacryptus.mindseye.lang.Delta) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) Collectors(java.util.stream.Collectors) DoubleStream(java.util.stream.DoubleStream) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) PointSample(com.simiacryptus.mindseye.lang.PointSample) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Result(com.simiacryptus.mindseye.lang.Result) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) DoubleStream(java.util.stream.DoubleStream) PointSample(com.simiacryptus.mindseye.lang.PointSample) Map(java.util.Map) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) 
DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) Nullable(javax.annotation.Nullable)

Example 97 with DeltaSet

Use of com.simiacryptus.mindseye.lang.DeltaSet in the MindsEye project by SimiaCryptus.

From class SimpleEval, method call.

/**
 * Evaluates {@code layer} once on copies of {@code input}.
 *
 * <p>The first tensor of the result is copied into {@code output}, replacing and freeing any
 * previous output. The result is then back-propagated with the feedback from
 * {@code getFeedback}; each input's received gradient is summed into the matching zero-initialized
 * tensor in {@code derivative}. All temporary copies, the result and the scratch delta set are
 * freed before returning.
 *
 * @return this instance
 */
@Nonnull
@Override
public SimpleEval call() {
    Tensor[] inputCopy = Arrays.stream(input).map(x -> x.copy()).toArray(i -> new Tensor[i]);
    // One zero-filled gradient accumulator per input, shaped like that input.
    derivative = Arrays.stream(inputCopy).map(source -> new Tensor(source.getDimensions())).toArray(i -> new Tensor[i]);
    // Named inputResults so it no longer shadows the `input` field.
    Result[] inputResults = IntStream.range(0, inputCopy.length).mapToObj(i -> {
        return new Result(TensorArray.create(inputCopy[i]), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
            data.stream().forEach(t -> {
                derivative[i].addInPlace(t);
                t.freeRef();
            });
        }) {

            // Intentionally empty: the wrapped data is freed explicitly after evaluation.
            @Override
            protected void _free() {
            }

            @Override
            public boolean isAlive() {
                return true;
            }
        };
    }).toArray(i -> new Result[i]);
    @Nullable final Result eval;
    try {
        eval = layer.eval(inputResults);
    } finally {
        for (@Nonnull Result result : inputResults) {
            result.getData().freeRef();
            result.freeRef();
        }
        for (@Nonnull Tensor tensor : inputCopy) {
            tensor.freeRef();
        }
    }
    TensorList evalData = eval.getData();
    TensorList outputTensorList = evalData.copy();
    @Nullable Tensor outputTensor = outputTensorList.get(0);
    @Nonnull DeltaSet<Layer> deltaSet = new DeltaSet<>();
    try {
        // Copy first, then free-and-replace `output` inside one critical section; previously the
        // new value was assigned after the lock had been released.
        @Nonnull Tensor newOutput = outputTensor.copy();
        synchronized (this) {
            if (null != output) {
                output.freeRef();
            }
            output = newOutput;
        }
        @Nonnull TensorList tensorList = getFeedback(outputTensorList);
        eval.accumulate(deltaSet, tensorList);
        return this;
    } finally {
        outputTensor.freeRef();
        evalData.freeRef();
        outputTensorList.freeRef();
        eval.freeRef();
        deltaSet.freeRef();
    }
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) TensorList(com.simiacryptus.mindseye.lang.TensorList) Layer(com.simiacryptus.mindseye.lang.Layer) Tensor(com.simiacryptus.mindseye.lang.Tensor) ReferenceCountingBase(com.simiacryptus.mindseye.lang.ReferenceCountingBase) Callable(java.util.concurrent.Callable) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Result(com.simiacryptus.mindseye.lang.Result) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Layer(com.simiacryptus.mindseye.lang.Layer) Nullable(javax.annotation.Nullable) Result(com.simiacryptus.mindseye.lang.Result) Nonnull(javax.annotation.Nonnull)

Example 98 with DeltaSet

Use of com.simiacryptus.mindseye.lang.DeltaSet in the MindsEye project by SimiaCryptus.

From class SimpleListEval, method call.

/**
 * Evaluates {@code layer} once on copies of the input tensor lists.
 *
 * <p>A copy of the result's data is stored in {@code output}, freeing any previous output. The
 * result is then back-propagated with the feedback from {@code getFeedback}:
 * <ul>
 *   <li>gradients with respect to each input are accumulated into {@code inputDerivative} via
 *       {@code SimpleListEval.accumulate};</li>
 *   <li>gradients with respect to the layer's own parameters go into a fresh {@code layerDerivative}.</li>
 * </ul>
 * Temporary copies are freed even if {@code layer.eval} throws.
 *
 * @return this instance
 */
@Nonnull
@Override
public SimpleResult call() {
    TensorList[] inputCopy = Arrays.stream(input).map(x -> x.copy()).toArray(i -> new TensorList[i]);
    // Zero-filled gradient accumulators, one tensor per input tensor, matching its dimensions.
    inputDerivative = Arrays.stream(inputCopy).map(tensorList -> TensorArray.wrap(tensorList.stream().map(tensor -> {
        @Nonnull Tensor zeros = new Tensor(tensor.getDimensions());
        tensor.freeRef();
        return zeros;
    }).toArray(i -> new Tensor[i]))).toArray(i -> new TensorList[i]);
    Result[] inputs = IntStream.range(0, inputCopy.length).mapToObj(i -> {
        return new Result(inputCopy[i], (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
            SimpleListEval.accumulate(inputDerivative[i], data);
        }) {

            @Override
            public boolean isAlive() {
                return true;
            }
        };
    }).toArray(i -> new Result[i]);
    @Nullable final Result eval;
    @Nonnull final TensorList outputData;
    // try/finally so the input wrappers and copies are released even if eval or copy throws,
    // keeping the original release order (inputs first, then copies after the output is copied).
    try {
        try {
            eval = layer.eval(inputs);
        } finally {
            for (@Nonnull Result result : inputs) {
                result.freeRef();
            }
        }
        outputData = eval.getData().copy();
    } finally {
        for (@Nonnull TensorList tensorList : inputCopy) {
            tensorList.freeRef();
        }
    }
    eval.getData().freeRef();
    @Nonnull TensorList tensorList = getFeedback(outputData);
    if (null != this.layerDerivative) {
        this.layerDerivative.freeRef();
    }
    this.layerDerivative = new DeltaSet<>();
    eval.accumulate(layerDerivative, tensorList);
    eval.freeRef();
    // Release the previous output before replacing it, as SimpleEval.call does; assumes this
    // instance owns the reference held in `output` — confirm against getOutput() callers.
    if (null != output) {
        output.freeRef();
    }
    output = outputData;
    return this;
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) TensorList(com.simiacryptus.mindseye.lang.TensorList) Layer(com.simiacryptus.mindseye.lang.Layer) Tensor(com.simiacryptus.mindseye.lang.Tensor) ReferenceCountingBase(com.simiacryptus.mindseye.lang.ReferenceCountingBase) Callable(java.util.concurrent.Callable) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Result(com.simiacryptus.mindseye.lang.Result) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) TensorList(com.simiacryptus.mindseye.lang.TensorList) Layer(com.simiacryptus.mindseye.lang.Layer) Nullable(javax.annotation.Nullable) Result(com.simiacryptus.mindseye.lang.Result) Nonnull(javax.annotation.Nonnull)

Aggregations

DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)98 Nonnull (javax.annotation.Nonnull)98 Result (com.simiacryptus.mindseye.lang.Result)90 Nullable (javax.annotation.Nullable)88 Layer (com.simiacryptus.mindseye.lang.Layer)86 TensorList (com.simiacryptus.mindseye.lang.TensorList)86 Arrays (java.util.Arrays)77 List (java.util.List)75 Tensor (com.simiacryptus.mindseye.lang.Tensor)73 Map (java.util.Map)66 IntStream (java.util.stream.IntStream)65 TensorArray (com.simiacryptus.mindseye.lang.TensorArray)64 JsonObject (com.google.gson.JsonObject)62 DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer)61 LayerBase (com.simiacryptus.mindseye.lang.LayerBase)60 Logger (org.slf4j.Logger)47 LoggerFactory (org.slf4j.LoggerFactory)47 ReferenceCounting (com.simiacryptus.mindseye.lang.ReferenceCounting)23 CudaTensor (com.simiacryptus.mindseye.lang.cudnn.CudaTensor)22 CudaTensorList (com.simiacryptus.mindseye.lang.cudnn.CudaTensorList)22