Search in sources:

Example 1 using LayerBase

Use of com.simiacryptus.mindseye.lang.LayerBase in the MindsEye project by SimiaCryptus.

From method getInvocations in class StandardLayerTests.

/**
 * Gets invocations: evaluates a copy of the given network once and records every distinct
 * (layer, input dimensions) pair that was called.
 *
 * <p>Each node of a copy of {@code smallLayer} has its layer replaced by a recording wrapper. The
 * wrapper delegates {@code eval}, {@code getJson} and {@code state} to the original layer. Each
 * {@code eval} call also adds an {@link Invocation} built from the layer and the dimensions of its
 * inputs. Nodes that have no layer are left untouched, since there is nothing to delegate to or
 * record. The copy and the input tensors are released before returning.
 *
 * @param smallLayer the small layer; it is copied and the copy is cast to {@link DAGNetwork}
 * @param smallDims  the small dims, one dimension array per network input (a new {@link Tensor}
 *                   is constructed for each)
 * @return the invocations observed during the single evaluation
 */
@Nonnull
public Collection<Invocation> getInvocations(@Nonnull Layer smallLayer, @Nonnull int[][] smallDims) {
    @Nonnull DAGNetwork smallCopy = (DAGNetwork) smallLayer.copy();
    @Nonnull HashSet<Invocation> invocations = new HashSet<>();
    smallCopy.visitNodes(node -> {
        @Nullable Layer inner = node.getLayer();
        // A node without a layer has nothing to wrap; skip it instead of failing on addRef().
        if (null == inner)
            return;
        // The wrapper keeps its own reference to inner and releases it in _free().
        inner.addRef();
        @Nullable Layer wrapper = new LayerBase() {

            @Nullable
            @Override
            public Result eval(@Nonnull Result... array) {
                @Nullable Result result = inner.eval(array);
                invocations.add(new Invocation(inner, Arrays.stream(array).map(x -> x.getData().getDimensions()).toArray(i -> new int[i][])));
                return result;
            }

            @Override
            public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
                return inner.getJson(resources, dataSerializer);
            }

            @Nullable
            @Override
            public List<double[]> state() {
                return inner.state();
            }

            @Override
            protected void _free() {
                inner.freeRef();
            }
        };
        node.setLayer(wrapper);
        // NOTE(review): assumes setLayer retains its own reference to the wrapper — confirm.
        wrapper.freeRef();
    });
    Tensor[] input = Arrays.stream(smallDims).map(i -> new Tensor(i)).toArray(i -> new Tensor[i]);
    try {
        Result eval = smallCopy.eval(input);
        // Release the output data while its Result is still live, then the Result itself;
        // calling getData() after eval.freeRef() would touch an already-released object.
        eval.getData().freeRef();
        eval.freeRef();
        return invocations;
    } finally {
        Arrays.stream(input).forEach(ReferenceCounting::freeRef);
        smallCopy.freeRef();
    }
}
Also used : JsonObject(com.google.gson.JsonObject) Graphviz(guru.nidi.graphviz.engine.Graphviz) Arrays(java.util.Arrays) Tensor(com.simiacryptus.mindseye.lang.Tensor) ReferenceCountingBase(com.simiacryptus.mindseye.lang.ReferenceCountingBase) NotebookReportBase(com.simiacryptus.mindseye.test.NotebookReportBase) HashMap(java.util.HashMap) Random(java.util.Random) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) Format(guru.nidi.graphviz.engine.Format) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) CudaError(com.simiacryptus.mindseye.lang.cudnn.CudaError) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) SysOutInterceptor(com.simiacryptus.util.test.SysOutInterceptor) Collection(java.util.Collection) TestUtil(com.simiacryptus.mindseye.test.TestUtil) File(java.io.File) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) Explodable(com.simiacryptus.mindseye.layers.cudnn.Explodable) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) LifecycleException(com.simiacryptus.mindseye.lang.LifecycleException) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Layer(com.simiacryptus.mindseye.lang.Layer) Result(com.simiacryptus.mindseye.lang.Result) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) HashMap(java.util.HashMap) Map(java.util.Map) Nullable(javax.annotation.Nullable) HashSet(java.util.HashSet) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) Nonnull(javax.annotation.Nonnull)

Example 2 using LayerBase

Use of com.simiacryptus.mindseye.lang.LayerBase in the MindsEye project by SimiaCryptus.

From method getCompatibilityLayer in class SimpleConvolutionLayer.

/**
 * Gets compatibility layer: a forward-only stand-in for this layer, backed by the aparapi
 * {@code ConvolutionLayer}.
 *
 * <p>The aparapi layer receives a copy of this layer's kernel with the band index transposed.
 * Band {@code bandX + bandY * bands} of the new kernel takes band
 * {@code bandY + bandX * bands} of {@code kernel}, where {@code bands = (int) sqrt(depth)} and
 * {@code depth} is the kernel's third dimension. NOTE(review): this mapping is only meaningful
 * when the depth is a perfect square — confirm callers guarantee that.
 *
 * <p>The returned layer is not trainable or serializable. Its results report
 * {@code isAlive() == false} and throw {@link IllegalStateException} on backpropagation, and
 * {@code getJson}/{@code state} throw {@link IllegalStateException}. Releasing the returned layer
 * also releases the wrapped aparapi layer.
 *
 * @return the compatibility layer
 */
@Nonnull
public Layer getCompatibilityLayer() {
    log.info(String.format("Using compatibility layer for %s", this));
    final int[] kernelDims = this.kernel.getDimensions();
    int bands = (int) Math.sqrt(kernelDims[2]);
    @Nonnull final com.simiacryptus.mindseye.layers.aparapi.ConvolutionLayer convolutionLayer = new com.simiacryptus.mindseye.layers.aparapi.ConvolutionLayer(kernelDims[0], kernelDims[1], kernelDims[2], true);
    @Nonnull final Tensor tensor = new Tensor(kernelDims);
    tensor.setByCoord(c -> {
        final int band = c.getCoords()[2];
        final int bandX = band % bands;
        final int bandY = (band - bandX) / bands;
        assert band == bandX + bandY * bands;
        // Read the transposed band from the original kernel.
        final int bandT = bandY + bandX * bands;
        return kernel.get(c.getCoords()[0], c.getCoords()[1], bandT);
    });
    convolutionLayer.kernel.set(tensor);
    // The temporary transposed kernel is no longer needed once copied into the aparapi layer.
    tensor.freeRef();
    return new LayerBase() {

        @Nonnull
        @Override
        public Result eval(@Nonnull Result... array) {
            // Inputs are retained here and released when the returned Result is freed.
            Arrays.stream(array).forEach(x -> x.addRef());
            // NOTE(review): `result` itself is never released; only its data is re-wrapped — confirm ownership.
            @Nonnull Result result = convolutionLayer.eval(array);
            return new Result(result.getData(), (DeltaSet<Layer> buffer, TensorList data) -> {
                throw new IllegalStateException();
            }) {

                @Override
                protected void _free() {
                    Arrays.stream(array).forEach(x -> x.freeRef());
                }

                @Override
                public boolean isAlive() {
                    return false;
                }
            };
        }

        @Nonnull
        @Override
        public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
            throw new IllegalStateException();
        }

        @Nonnull
        @Override
        public List<double[]> state() {
            throw new IllegalStateException();
        }

        @Override
        protected void _free() {
            // convolutionLayer is created above and referenced only by this wrapper.
            convolutionLayer.freeRef();
            super._free();
        }
    };
}
Also used : Tensor(com.simiacryptus.mindseye.lang.Tensor) CudaTensor(com.simiacryptus.mindseye.lang.cudnn.CudaTensor) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) CudaTensorList(com.simiacryptus.mindseye.lang.cudnn.CudaTensorList) TensorList(com.simiacryptus.mindseye.lang.TensorList) Result(com.simiacryptus.mindseye.lang.Result) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) Map(java.util.Map) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) Nonnull(javax.annotation.Nonnull)

Example 3 using LayerBase

Use of com.simiacryptus.mindseye.lang.LayerBase in the MindsEye project by SimiaCryptus.

From method buildSubspace in class RecursiveSubspace.

/**
 * Build subspace nn layer.
 *
 * <p>The returned layer treats the per-layer step sizes along the negated measured gradient as its
 * trainable weights. Every evaluation follows the same three steps:
 * <ol>
 *   <li>Restore the parameter backup taken from {@code measurement}.</li>
 *   <li>Accumulate each layer's direction delta scaled by its entry in {@code weights}. All
 *       placeholder layers share {@code weights[0]} when any are present.</li>
 *   <li>Re-measure {@code subject} and output the mean as a single-element tensor.</li>
 * </ol>
 * Backpropagation adds to the gradient of {@code weights}. The increment is each layer's measured
 * gradient projected onto its direction, divided by the direction's norm (floored at 1e-8).
 * Placeholder projections are summed into the first component. The {@code weights} field is
 * reallocated when its length does not match the number of subspace dimensions.
 *
 * @param subject     the subject that is re-measured on every evaluation
 * @param measurement the measurement whose (negated) gradient defines the search directions
 * @param monitor     the monitor used for logging
 * @return the nn layer
 */
@Nullable
public Layer buildSubspace(@Nonnull Trainable subject, @Nonnull PointSample measurement, @Nonnull TrainingMonitor monitor) {
    @Nonnull PointSample origin = measurement.copyFull().backup();
    @Nonnull final DeltaSet<Layer> direction = measurement.delta.scale(-1);
    final double magnitude = direction.getMagnitude();
    if (Math.abs(magnitude) < 1e-10) {
        monitor.log(String.format("Zero gradient: %s", magnitude));
    } else if (Math.abs(magnitude) < 1e-5) {
        monitor.log(String.format("Low gradient: %s", magnitude));
    }
    boolean hasPlaceholders = direction.getMap().keySet().stream().anyMatch(x -> x instanceof PlaceholderLayer);
    List<Layer> deltaLayers = direction.getMap().keySet().stream().filter(x -> !(x instanceof PlaceholderLayer)).collect(Collectors.toList());
    // One weight per non-placeholder layer, plus one shared weight (index 0) for all placeholders.
    int size = deltaLayers.size() + (hasPlaceholders ? 1 : 0);
    if (null == weights || weights.length != size)
        weights = new double[size];
    return new LayerBase() {

        // Explicit handle to this layer, used as the DeltaSet key in the backprop lambda below.
        @Nonnull
        Layer self = this;

        @Nonnull
        @Override
        public Result eval(Result... array) {
            assertAlive();
            origin.restore();
            IntStream.range(0, deltaLayers.size()).forEach(i -> {
                direction.getMap().get(deltaLayers.get(i)).accumulate(weights[hasPlaceholders ? (i + 1) : i]);
            });
            if (hasPlaceholders) {
                direction.getMap().entrySet().stream().filter(x -> x.getKey() instanceof PlaceholderLayer).distinct().forEach(entry -> entry.getValue().accumulate(weights[0]));
            }
            PointSample measure = subject.measure(monitor);
            double mean = measure.getMean();
            monitor.log(String.format("RecursiveSubspace: %s <- %s", mean, Arrays.toString(weights)));
            // Held by the Result below and released in its _free().
            direction.addRef();
            return new Result(TensorArray.wrap(new Tensor(mean)), (DeltaSet<Layer> buffer, TensorList data) -> {
                DoubleStream deltaStream = deltaLayers.stream().mapToDouble(layer -> projectGradient(measure.delta.getMap().get(layer), direction.getMap().get(layer)));
                if (hasPlaceholders) {
                    // Placeholders share weights[0], so their projections are summed and placed first.
                    deltaStream = DoubleStream.concat(DoubleStream.of(direction.getMap().keySet().stream().filter(x -> x instanceof PlaceholderLayer).distinct().mapToDouble(layer -> projectGradient(measure.delta.getMap().get(layer), direction.getMap().get(layer))).sum()), deltaStream);
                }
                buffer.get(self, weights).addInPlace(deltaStream.toArray()).freeRef();
            }) {

                @Override
                protected void _free() {
                    measure.freeRef();
                    direction.freeRef();
                }

                @Override
                public boolean isAlive() {
                    return true;
                }
            };
        }

        @Override
        protected void _free() {
            direction.freeRef();
            origin.freeRef();
            super._free();
        }

        @Nonnull
        @Override
        public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
            throw new IllegalStateException();
        }

        @Nullable
        @Override
        public List<double[]> state() {
            return null;
        }
    };
}

/**
 * Projects a measured gradient onto a search direction as
 * {@code gradient·direction / max(|direction|, 1e-8)}. The floor avoids dividing by zero for a
 * zero direction.
 *
 * @param gradient  the measured gradient for one layer
 * @param direction the search direction for the same layer
 * @return the scalar projection of {@code gradient} onto {@code direction}
 */
private static double projectGradient(@Nonnull Delta<Layer> gradient, @Nonnull Delta<Layer> direction) {
    return gradient.dot(direction) / Math.max(Math.sqrt(direction.dot(direction)), 1e-8);
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) StateSet(com.simiacryptus.mindseye.lang.StateSet) Trainable(com.simiacryptus.mindseye.eval.Trainable) Delta(com.simiacryptus.mindseye.lang.Delta) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) Collectors(java.util.stream.Collectors) DoubleStream(java.util.stream.DoubleStream) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) PointSample(com.simiacryptus.mindseye.lang.PointSample) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Result(com.simiacryptus.mindseye.lang.Result) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) DoubleStream(java.util.stream.DoubleStream) PointSample(com.simiacryptus.mindseye.lang.PointSample) Map(java.util.Map) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) 
DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) Nullable(javax.annotation.Nullable)

Aggregations (number of the examples above that use each type):

DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer): 3; LayerBase (com.simiacryptus.mindseye.lang.LayerBase): 3; Result (com.simiacryptus.mindseye.lang.Result): 3; Tensor (com.simiacryptus.mindseye.lang.Tensor): 3; Map (java.util.Map): 3; Nonnull (javax.annotation.Nonnull): 3; JsonObject (com.google.gson.JsonObject): 2; DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet): 2; Layer (com.simiacryptus.mindseye.lang.Layer): 2; TensorList (com.simiacryptus.mindseye.lang.TensorList): 2; Arrays (java.util.Arrays): 2; List (java.util.List): 2; Nullable (javax.annotation.Nullable): 2; ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable): 1; BasicTrainable (com.simiacryptus.mindseye.eval.BasicTrainable): 1; Trainable (com.simiacryptus.mindseye.eval.Trainable): 1; Delta (com.simiacryptus.mindseye.lang.Delta): 1; LifecycleException (com.simiacryptus.mindseye.lang.LifecycleException): 1; PointSample (com.simiacryptus.mindseye.lang.PointSample): 1; ReferenceCounting (com.simiacryptus.mindseye.lang.ReferenceCounting): 1