Example 1 with DataSerializer

Use of com.simiacryptus.mindseye.lang.DataSerializer in project MindsEye by SimiaCryptus.

In class StandardLayerTests, method getInvocations.

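/**
 * Collects the distinct inner-layer invocations made while evaluating a copy of the given network.
 * Each node's layer is wrapped so that every call records the layer and its input dimensions.
 *
 * @param smallLayer the network to inspect; it is cast to DAGNetwork after copying
 * @param smallDims  the dimensions of each test input tensor
 * @return the recorded invocations
 */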
@Nonnull
public Collection<Invocation> getInvocations(@Nonnull Layer smallLayer, @Nonnull int[][] smallDims) {
    @Nonnull DAGNetwork smallCopy = (DAGNetwork) smallLayer.copy();
    @Nonnull HashSet<Invocation> invocations = new HashSet<>();
    smallCopy.visitNodes(node -> {
        @Nullable Layer inner = node.getLayer();
        inner.addRef();
        @Nullable Layer wrapper = new LayerBase() {

            @Nullable
            @Override
            public Result eval(@Nonnull Result... array) {
                if (null == inner)
                    return null;
                @Nullable Result result = inner.eval(array);
                invocations.add(new Invocation(inner, Arrays.stream(array).map(x -> x.getData().getDimensions()).toArray(i -> new int[i][])));
                return result;
            }

            @Override
            public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
                return inner.getJson(resources, dataSerializer);
            }

            @Nullable
            @Override
            public List<double[]> state() {
                return inner.state();
            }

            @Override
            protected void _free() {
                inner.freeRef();
            }
        };
        node.setLayer(wrapper);
        wrapper.freeRef();
    });
    Tensor[] input = Arrays.stream(smallDims).map(i -> new Tensor(i)).toArray(i -> new Tensor[i]);
    try {
        Result eval = smallCopy.eval(input);
        // Release the data first so that getData() is not called on an already-freed result.
        eval.getData().freeRef();
        eval.freeRef();
        return invocations;
    } finally {
        Arrays.stream(input).forEach(ReferenceCounting::freeRef);
        smallCopy.freeRef();
    }
}
Also used : JsonObject(com.google.gson.JsonObject) Graphviz(guru.nidi.graphviz.engine.Graphviz) Arrays(java.util.Arrays) Tensor(com.simiacryptus.mindseye.lang.Tensor) ReferenceCountingBase(com.simiacryptus.mindseye.lang.ReferenceCountingBase) NotebookReportBase(com.simiacryptus.mindseye.test.NotebookReportBase) HashMap(java.util.HashMap) Random(java.util.Random) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) Format(guru.nidi.graphviz.engine.Format) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) CudaError(com.simiacryptus.mindseye.lang.cudnn.CudaError) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) SysOutInterceptor(com.simiacryptus.util.test.SysOutInterceptor) Collection(java.util.Collection) TestUtil(com.simiacryptus.mindseye.test.TestUtil) File(java.io.File) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) Explodable(com.simiacryptus.mindseye.layers.cudnn.Explodable) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) LifecycleException(com.simiacryptus.mindseye.lang.LifecycleException)
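
The wrapper in Example 1 delegates each call to the inner layer and records the shape of the inputs it saw. The same idea can be sketched without any MindsEye types. In the sketch below, RecordingWrapper, Op and record are hypothetical names introduced only for illustration, and plain double arrays stand in for tensors; it is a minimal approximation of the pattern, not the project's implementation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

// Hypothetical stand-in for the wrapping pattern; none of these names come from MindsEye.
final class RecordingWrapper {
    // Simplified "layer": maps any number of input arrays to one output array.
    interface Op {
        double[] eval(double[]... inputs);
    }

    // Returns an Op that delegates to inner and records the length of every input it receives.
    static Op record(Op inner, Set<List<Integer>> seen) {
        return inputs -> {
            double[] result = inner.eval(inputs);
            List<Integer> lengths = new ArrayList<>();
            for (double[] input : inputs) {
                lengths.add(input.length);
            }
            // A Set keeps one entry per distinct input shape, much like the HashSet of invocations.
            seen.add(lengths);
            return result;
        };
    }
}
```

Because the shapes are stored in a Set, evaluating the wrapped operation twice with identically sized inputs leaves a single recorded entry.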

Example 2 with DataSerializer

Use of com.simiacryptus.mindseye.lang.DataSerializer in project MindsEye by SimiaCryptus.

In class ImgLinearSubnetLayer, method getJson.

@Nonnull
@Override
public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
    @Nonnull final JsonObject json = super.getJsonStub();
    json.addProperty("precision", precision.name());
    json.addProperty("parallel", isParallel());
    JsonArray jsonArray = new JsonArray();
    legs.stream().map(x -> x.getJson(resources, dataSerializer)).forEach(jsonArray::add);
    json.add("legs", jsonArray);
    return json;
}
Also used : JsonArray(com.google.gson.JsonArray) IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) CudaMemory(com.simiacryptus.mindseye.lang.cudnn.CudaMemory) LoggerFactory(org.slf4j.LoggerFactory) ReferenceCountingBase(com.simiacryptus.mindseye.lang.ReferenceCountingBase) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) ArrayList(java.util.ArrayList) Precision(com.simiacryptus.mindseye.lang.cudnn.Precision) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Logger(org.slf4j.Logger) CudaDevice(com.simiacryptus.mindseye.lang.cudnn.CudaDevice) CudaTensor(com.simiacryptus.mindseye.lang.cudnn.CudaTensor) CudaTensorList(com.simiacryptus.mindseye.lang.cudnn.CudaTensorList) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) Stream(java.util.stream.Stream) CudaSystem(com.simiacryptus.mindseye.lang.cudnn.CudaSystem) TensorList(com.simiacryptus.mindseye.lang.TensorList) MemoryType(com.simiacryptus.mindseye.lang.cudnn.MemoryType) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet)

Example 3 with DataSerializer

Use of com.simiacryptus.mindseye.lang.DataSerializer in project MindsEye by SimiaCryptus.

In class SimpleConvolutionLayer, method getCompatibilityLayer.

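/**
 * Builds a fallback layer backed by the Aparapi ConvolutionLayer, using a copy of the
 * kernel whose band indices are transposed.
 *
 * @return the compatibility layer; it supports forward evaluation only, and its
 *         backpropagation, getJson and state all throw IllegalStateException
 */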
@Nonnull
public Layer getCompatibilityLayer() {
    log.info(String.format("Using compatibility layer for %s", this));
    int bands = (int) Math.sqrt(this.kernel.getDimensions()[2]);
    @Nonnull final com.simiacryptus.mindseye.layers.aparapi.ConvolutionLayer convolutionLayer = new com.simiacryptus.mindseye.layers.aparapi.ConvolutionLayer(this.kernel.getDimensions()[0], this.kernel.getDimensions()[1], this.kernel.getDimensions()[2], true);
    @Nonnull final Tensor tensor = new Tensor(kernel.getDimensions());
    tensor.setByCoord(c -> {
        final int band = c.getCoords()[2];
        final int bandX = band % bands;
        final int bandY = (band - bandX) / bands;
        assert band == bandX + bandY * bands;
        final int bandT = bandY + bandX * bands;
        return kernel.get(c.getCoords()[0], c.getCoords()[1], bandT);
    });
    convolutionLayer.kernel.set(tensor);
    return new LayerBase() {

        @Nonnull
        @Override
        public Result eval(@Nonnull Result... array) {
            Arrays.stream(array).forEach(x -> x.addRef());
            @Nonnull Result result = convolutionLayer.eval(array);
            return new Result(result.getData(), (DeltaSet<Layer> buffer, TensorList data) -> {
                throw new IllegalStateException();
            }) {

                @Override
                protected void _free() {
                    Arrays.stream(array).forEach(x -> x.freeRef());
                }

                @Override
                public boolean isAlive() {
                    return false;
                }
            };
        }

        @Nonnull
        @Override
        public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
            throw new IllegalStateException();
        }

        @Nonnull
        @Override
        public List<double[]> state() {
            throw new IllegalStateException();
        }
    };
}
Also used : Tensor(com.simiacryptus.mindseye.lang.Tensor) CudaTensor(com.simiacryptus.mindseye.lang.cudnn.CudaTensor) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) CudaTensorList(com.simiacryptus.mindseye.lang.cudnn.CudaTensorList) TensorList(com.simiacryptus.mindseye.lang.TensorList) Result(com.simiacryptus.mindseye.lang.Result) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) Map(java.util.Map) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer)
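
The setByCoord callback in Example 3 remaps each flattened band index by splitting it into bandX and bandY and then swapping the two, which amounts to transposing a bands-by-bands grid. The sketch below isolates that arithmetic; BandTranspose and transpose are hypothetical names that do not exist in MindsEye, and the value bands = 3 is chosen only for the demonstration.

```java
import java.util.Arrays;

// Hypothetical helper that reproduces only the index arithmetic from Example 3.
final class BandTranspose {
    // Maps band = bandX + bandY * bands to its transpose bandY + bandX * bands.
    static int transpose(int band, int bands) {
        int bandX = band % bands;
        int bandY = (band - bandX) / bands;
        return bandY + bandX * bands;
    }

    public static void main(String[] args) {
        int bands = 3; // assumed for illustration: 9 flattened bands form a 3x3 grid
        int[] remapped = new int[bands * bands];
        for (int band = 0; band < remapped.length; band++) {
            remapped[band] = transpose(band, bands);
        }
        System.out.println(Arrays.toString(remapped)); // prints [0, 3, 6, 1, 4, 7, 2, 5, 8]
    }
}
```

Applying transpose twice returns the original index, so the remapping is its own inverse.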

Example 4 with DataSerializer

Use of com.simiacryptus.mindseye.lang.DataSerializer in project MindsEye by SimiaCryptus.

In class RecursiveSubspace, method buildSubspace.

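/**
 * Builds a layer that evaluates the subject along the negated gradient direction.
 * Each non-placeholder layer gets its own step weight; all placeholder layers share one.
 *
 * @param subject     the trainable to measure
 * @param measurement the measurement whose delta defines the search direction
 * @param monitor     the monitor used for logging
 * @return the subspace layer
 */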
@Nullable
public Layer buildSubspace(@Nonnull Trainable subject, @Nonnull PointSample measurement, @Nonnull TrainingMonitor monitor) {
    @Nonnull PointSample origin = measurement.copyFull().backup();
    @Nonnull final DeltaSet<Layer> direction = measurement.delta.scale(-1);
    final double magnitude = direction.getMagnitude();
    if (Math.abs(magnitude) < 1e-10) {
        monitor.log(String.format("Zero gradient: %s", magnitude));
    } else if (Math.abs(magnitude) < 1e-5) {
        monitor.log(String.format("Low gradient: %s", magnitude));
    }
    boolean hasPlaceholders = direction.getMap().keySet().stream().anyMatch(x -> x instanceof PlaceholderLayer);
    List<Layer> deltaLayers = direction.getMap().entrySet().stream().map(x -> x.getKey()).filter(x -> !(x instanceof PlaceholderLayer)).collect(Collectors.toList());
    int size = deltaLayers.size() + (hasPlaceholders ? 1 : 0);
    if (null == weights || weights.length != size)
        weights = new double[size];
    return new LayerBase() {

        @Nonnull
        Layer self = this;

        @Nonnull
        @Override
        public Result eval(Result... array) {
            assertAlive();
            origin.restore();
            IntStream.range(0, deltaLayers.size()).forEach(i -> {
                direction.getMap().get(deltaLayers.get(i)).accumulate(weights[hasPlaceholders ? (i + 1) : i]);
            });
            if (hasPlaceholders) {
                direction.getMap().entrySet().stream().filter(x -> x.getKey() instanceof PlaceholderLayer).distinct().forEach(entry -> entry.getValue().accumulate(weights[0]));
            }
            PointSample measure = subject.measure(monitor);
            double mean = measure.getMean();
            monitor.log(String.format("RecursiveSubspace: %s <- %s", mean, Arrays.toString(weights)));
            direction.addRef();
            return new Result(TensorArray.wrap(new Tensor(mean)), (DeltaSet<Layer> buffer, TensorList data) -> {
                DoubleStream deltaStream = deltaLayers.stream().mapToDouble(layer -> {
                    Delta<Layer> a = direction.getMap().get(layer);
                    Delta<Layer> b = measure.delta.getMap().get(layer);
                    return b.dot(a) / Math.max(Math.sqrt(a.dot(a)), 1e-8);
                });
                if (hasPlaceholders) {
                    deltaStream = DoubleStream.concat(DoubleStream.of(direction.getMap().keySet().stream().filter(x -> x instanceof PlaceholderLayer).distinct().mapToDouble(layer -> {
                        Delta<Layer> a = direction.getMap().get(layer);
                        Delta<Layer> b = measure.delta.getMap().get(layer);
                        return b.dot(a) / Math.max(Math.sqrt(a.dot(a)), 1e-8);
                    }).sum()), deltaStream);
                }
                buffer.get(self, weights).addInPlace(deltaStream.toArray()).freeRef();
            }) {

                @Override
                protected void _free() {
                    measure.freeRef();
                    direction.freeRef();
                }

                @Override
                public boolean isAlive() {
                    return true;
                }
            };
        }

        @Override
        protected void _free() {
            direction.freeRef();
            origin.freeRef();
            super._free();
        }

        @Nonnull
        @Override
        public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
            throw new IllegalStateException();
        }

        @Nullable
        @Override
        public List<double[]> state() {
            return null;
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) StateSet(com.simiacryptus.mindseye.lang.StateSet) Trainable(com.simiacryptus.mindseye.eval.Trainable) Delta(com.simiacryptus.mindseye.lang.Delta) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) Collectors(java.util.stream.Collectors) DoubleStream(java.util.stream.DoubleStream) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) PointSample(com.simiacryptus.mindseye.lang.PointSample)
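
In Example 4, the gradient for each weight is computed as b.dot(a) / Math.max(Math.sqrt(a.dot(a)), 1e-8), that is, the measured gradient b projected onto the direction a, with the direction's norm clamped from below so a zero direction does not divide by zero. The sketch below restates that formula on plain arrays. DirectionalProjection, dot and project are hypothetical names used only here, and double arrays stand in for the Delta objects.

```java
// Hypothetical helper that restates the projection formula from Example 4 on plain arrays.
final class DirectionalProjection {
    // Plain dot product of two equally sized vectors.
    static double dot(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    // Component of gradient along direction; the norm is clamped at 1e-8 as in the example.
    static double project(double[] gradient, double[] direction) {
        double norm = Math.sqrt(dot(direction, direction));
        return dot(gradient, direction) / Math.max(norm, 1e-8);
    }
}
```

For a gradient of (3, 4) and a direction of (0, 2) this gives 8 / 2 = 4, the gradient's component along the second axis. For an all-zero direction the numerator is zero as well, so the result is 0 rather than NaN.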

Aggregations

DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer): 4
LayerBase (com.simiacryptus.mindseye.lang.LayerBase): 4
Result (com.simiacryptus.mindseye.lang.Result): 4
Map (java.util.Map): 4
Nonnull (javax.annotation.Nonnull): 4
JsonObject (com.google.gson.JsonObject): 3
DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet): 3
Layer (com.simiacryptus.mindseye.lang.Layer): 3
Tensor (com.simiacryptus.mindseye.lang.Tensor): 3
TensorList (com.simiacryptus.mindseye.lang.TensorList): 3
List (java.util.List): 3
Nullable (javax.annotation.Nullable): 3
ReferenceCounting (com.simiacryptus.mindseye.lang.ReferenceCounting): 2
ReferenceCountingBase (com.simiacryptus.mindseye.lang.ReferenceCountingBase): 2
CudaTensor (com.simiacryptus.mindseye.lang.cudnn.CudaTensor): 2
CudaTensorList (com.simiacryptus.mindseye.lang.cudnn.CudaTensorList): 2
ArrayList (java.util.ArrayList): 2
Arrays (java.util.Arrays): 2
IntStream (java.util.stream.IntStream): 2
JsonArray (com.google.gson.JsonArray): 1