Search in sources:

Example 21 with Layer

Use of com.simiacryptus.mindseye.lang.Layer in project MindsEye by SimiaCryptus.

Class ImgTileAssemblyLayer, method eval.

/**
 * Assembles a grid of image tiles into a single image for every element of the batch.
 *
 * <p>Inputs are consumed in row-major order: {@code inObj[row * columns + col]} is placed in grid
 * cell {@code (row, col)}. Within a row, each tile's x offset is the sum of {@code dims[0]} of the
 * tiles before it in that row. Each row's y offset is the running sum of the largest
 * {@code dims[1]} found in each previous row.
 *
 * <p>The backward pass walks the grid in the same order. For every live tile, it copies from the
 * incoming delta at the negated tile offset into a tile-sized tensor (via
 * {@code ImgTileAssemblyLayer.copy}) and accumulates that into the tile's result.
 *
 * @param inObj the tile results; the first {@code rows * columns} entries are read. Rank-3 data is
 *     asserted only for {@code inObj[0]} (and only when assertions are enabled)
 * @return a result holding the assembled batch; it keeps a reference to every input until freed
 */
@Nonnull
@Override
public Result eval(@Nonnull final Result... inObj) {
    // Retain every input for the lifetime of the returned Result; released in _free().
    Arrays.stream(inObj).forEach(nnResult -> nnResult.addRef());
    assert 3 == inObj[0].getData().getDimensions().length;
    int[] outputDims = getOutputDims(inObj);
    return new Result(TensorArray.wrap(IntStream.range(0, inObj[0].getData().length()).parallel().mapToObj(dataIndex -> {
        @Nonnull final Tensor outputData = new Tensor(outputDims);
        int totalHeight = 0;
        int inputIndex = 0;
        for (int row = 0; row < rows; row++) {
            int positionX = 0;
            int rowHeight = 0;
            for (int col = 0; col < columns; col++) {
                TensorList tileTensor = inObj[inputIndex].getData();
                int[] tileDimensions = tileTensor.getDimensions();
                // A row is as tall as its tallest tile.
                rowHeight = Math.max(rowHeight, tileDimensions[1]);
                Tensor inputData = tileTensor.get(dataIndex);
                ImgTileAssemblyLayer.copy(inputData, outputData, positionX, totalHeight);
                inputData.freeRef();
                positionX += tileDimensions[0];
                inputIndex += 1;
            }
            totalHeight += rowHeight;
        }
        return outputData;
    }).toArray(i -> new Tensor[i])), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList delta) -> {
        // Re-walk the grid exactly as the forward pass did to recover each tile's offset.
        int totalHeight = 0;
        int inputIndex = 0;
        for (int row = 0; row < rows; row++) {
            int positionX = 0;
            int rowHeight = 0;
            for (int col = 0; col < columns; col++) {
                Result in = inObj[inputIndex];
                int[] inputDataDimensions = in.getData().getDimensions();
                rowHeight = Math.max(rowHeight, inputDataDimensions[1]);
                if (in.isAlive()) {
                    // Effectively-final copies of the loop offsets so the lambda below can capture them.
                    final int tileX = positionX;
                    final int tileY = totalHeight;
                    @Nonnull TensorArray tensorArray = TensorArray.wrap(IntStream.range(0, delta.length()).parallel().mapToObj(dataIndex -> {
                        @Nullable final Tensor deltaTensor = delta.get(dataIndex);
                        @Nonnull final Tensor passbackTensor = new Tensor(inputDataDimensions);
                        // Negated offsets: presumably reads this tile's region out of the assembled delta.
                        ImgTileAssemblyLayer.copy(deltaTensor, passbackTensor, -tileX, -tileY);
                        deltaTensor.freeRef();
                        return passbackTensor;
                    }).toArray(i -> new Tensor[i]));
                    in.accumulate(buffer, tensorArray);
                }
                positionX += inputDataDimensions[0];
                inputIndex += 1;
            }
            totalHeight += rowHeight;
        }
    }) {

        @Override
        protected void _free() {
            Arrays.stream(inObj).forEach(nnResult -> nnResult.freeRef());
        }

        @Override
        public boolean isAlive() {
            // Every tile receives gradient in the backward pass, so the result is live if ANY tile is
            // live, not only the first one.
            return Arrays.stream(inObj).anyMatch(tile -> tile.isAlive()) || !isFrozen();
        }
    };
}
Also used: IntStream(java.util.stream.IntStream), JsonObject(com.google.gson.JsonObject), Arrays(java.util.Arrays), Tensor(com.simiacryptus.mindseye.lang.Tensor), Result(com.simiacryptus.mindseye.lang.Result), DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer), ArrayList(java.util.ArrayList), List(java.util.List), LayerBase(com.simiacryptus.mindseye.lang.LayerBase), TensorList(com.simiacryptus.mindseye.lang.TensorList), Map(java.util.Map), Layer(com.simiacryptus.mindseye.lang.Layer), TensorArray(com.simiacryptus.mindseye.lang.TensorArray), DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet), Nonnull(javax.annotation.Nonnull), Nullable(javax.annotation.Nullable)

Example 22 with Layer

Use of com.simiacryptus.mindseye.lang.Layer in project MindsEye by SimiaCryptus.

Class L1NormalizationLayer, method eval.

/**
 * L1-normalizes every tensor of the first input's batch by dividing it by the sum of its entries.
 *
 * <p>Tensors whose entry sum is zero or non-finite are passed through unchanged. Dividing by such
 * a sum would produce non-finite values.
 *
 * <p>Backward pass: for a normalized tensor {@code y = x / S} with {@code S = sum(x)}, the input
 * gradient is {@code dL/dx_j = (delta_j - dot(x, delta) / S) / S}. Tensors that were passed through
 * in the forward pass receive a zero gradient.
 *
 * @param input the input results; only {@code input[0]} is normalized, but all are retained until
 *     the returned result is freed
 * @return a result holding the normalized batch
 */
@Nonnull
@Override
public Result eval(@Nonnull final Result... input) {
    Arrays.stream(input).forEach(nnResult -> nnResult.addRef());
    final Result in = input[0];
    final TensorList inData = in.getData();
    // Kept alive for the backward pass; released in _free().
    inData.addRef();
    return new Result(TensorArray.wrap(IntStream.range(0, inData.length()).mapToObj(dataIndex -> {
        @Nullable final Tensor value = inData.get(dataIndex);
        try {
            final double sum = value.sum();
            if (!Double.isFinite(sum) || 0 == sum) {
                // Pass through: balance the freeRef in the finally block so the caller owns a reference.
                value.addRef();
                return value;
            } else {
                return value.scale(1.0 / sum);
            }
        } finally {
            value.freeRef();
        }
    }).toArray(i -> new Tensor[i])), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList outDelta) -> {
        if (in.isAlive()) {
            final Tensor[] passbackArray = IntStream.range(0, outDelta.length()).mapToObj(dataIndex -> {
                Tensor inputTensor = inData.get(dataIndex);
                @Nullable final double[] value = inputTensor.getData();
                Tensor outputTensor = outDelta.get(dataIndex);
                @Nullable final double[] delta = outputTensor.getData();
                final double dot = ArrayUtil.dot(value, delta);
                final double sum = Arrays.stream(value).sum();
                @Nonnull final Tensor passback = new Tensor(outputTensor.getDimensions());
                @Nullable final double[] passbackData = passback.getData();
                // Exact negation of the forward-pass pass-through guard. Only tensors that were
                // actually scaled (finite, non-zero sum) get the normalization gradient; with `||`
                // a zero sum would reach the division below and yield Inf/NaN.
                // NOTE(review): pass-through tensors are the identity in the forward pass, so
                // `passback = delta` may be the intended gradient; zero is kept here.
                if (Double.isFinite(sum) && 0 != sum) {
                    for (int i = 0; i < value.length; i++) {
                        passbackData[i] = (delta[i] - dot / sum) / sum;
                    }
                }
                outputTensor.freeRef();
                inputTensor.freeRef();
                return passback;
            }).toArray(i -> new Tensor[i]);
            assert Arrays.stream(passbackArray).flatMapToDouble(x -> Arrays.stream(x.getData())).allMatch(v -> Double.isFinite(v));
            @Nonnull TensorArray tensorArray = TensorArray.wrap(passbackArray);
            in.accumulate(buffer, tensorArray);
        }
    }) {

        @Override
        protected void _free() {
            inData.freeRef();
            Arrays.stream(input).forEach(nnResult -> nnResult.freeRef());
        }

        @Override
        public boolean isAlive() {
            return in.isAlive();
        }
    };
}
Also used: IntStream(java.util.stream.IntStream), JsonObject(com.google.gson.JsonObject), Arrays(java.util.Arrays), Logger(org.slf4j.Logger), LoggerFactory(org.slf4j.LoggerFactory), Tensor(com.simiacryptus.mindseye.lang.Tensor), Result(com.simiacryptus.mindseye.lang.Result), ArrayUtil(com.simiacryptus.util.ArrayUtil), DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer), List(java.util.List), LayerBase(com.simiacryptus.mindseye.lang.LayerBase), TensorList(com.simiacryptus.mindseye.lang.TensorList), Map(java.util.Map), Layer(com.simiacryptus.mindseye.lang.Layer), TensorArray(com.simiacryptus.mindseye.lang.TensorArray), DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet), Nonnull(javax.annotation.Nonnull), Nullable(javax.annotation.Nullable)

Example 23 with Layer

Use of com.simiacryptus.mindseye.lang.Layer in project MindsEye by SimiaCryptus.

Class DAGNetwork, method visitLayers.

/**
 * Walks every layer registered in this network, descending into nested networks.
 *
 * <p>For each registered layer, any chain of {@link WrapperLayer}s is first followed to its
 * innermost layer. If that innermost layer is a {@code DAGNetwork}, its own layers are visited
 * (recursively) before anything else. The registered layer is then handed to {@code visitor},
 * followed by every layer reached by unwrapping it one level at a time.
 *
 * @param visitor callback invoked for each visited layer
 */
public void visitLayers(@Nonnull final Consumer<Layer> visitor) {
    layersById.values().forEach(registered -> {
        // Pass 1: find what sits at the bottom of the wrapper chain and recurse into sub-networks.
        Layer innermost = registered;
        while (innermost instanceof WrapperLayer) {
            innermost = ((WrapperLayer) innermost).getInner();
        }
        if (innermost instanceof DAGNetwork) {
            ((DAGNetwork) innermost).visitLayers(visitor);
        }
        // Pass 2: report the registered layer and then each successively unwrapped layer.
        for (Layer current = registered; ; current = ((WrapperLayer) current).getInner()) {
            visitor.accept(current);
            if (!(current instanceof WrapperLayer)) {
                break;
            }
        }
    });
}
Also used: WrapperLayer(com.simiacryptus.mindseye.layers.java.WrapperLayer), Layer(com.simiacryptus.mindseye.lang.Layer)

Example 24 with Layer

Use of com.simiacryptus.mindseye.lang.Layer in project MindsEye by SimiaCryptus.

Class DAGNetwork, method add.

/**
 * Creates a new {@code InnerNode} for {@code layer} fed by {@code head} and registers it in this
 * network.
 *
 * <p>The layer is stored in {@code layersById} only if no entry with its id exists yet. In that
 * case one reference to the layer is taken. The new node is stored in {@code nodesById}, and any
 * node it displaces is released. When {@code label} is non-null, it is mapped to the new node's id.
 *
 * @param label optional label for the new node; {@code null} leaves it unlabeled
 * @param layer the layer evaluated by the new node
 * @param head  the upstream nodes passed to the new node
 * @return the newly created node
 */
public InnerNode add(@Nullable final CharSequence label, @Nonnull final Layer layer, final DAGNode... head) {
    assertAlive();
    assertConsistent();
    assert null != getInput();
    @Nonnull final InnerNode created = new InnerNode(this, layer, head);
    synchronized (layersById) {
        final boolean alreadyRegistered = layersById.containsKey(layer.getId());
        if (!alreadyRegistered) {
            final Layer previousLayer = layersById.put(layer.getId(), layer);
            layer.addRef();
            if (previousLayer != null) {
                previousLayer.freeRef();
            }
        }
    }
    final DAGNode previousNode = nodesById.put(created.getId(), created);
    if (previousNode != null) {
        previousNode.freeRef();
    }
    if (label != null) {
        labels.put(label, created.getId());
    }
    assertConsistent();
    return created;
}
Also used: Nonnull(javax.annotation.Nonnull), Layer(com.simiacryptus.mindseye.lang.Layer), WrapperLayer(com.simiacryptus.mindseye.layers.java.WrapperLayer)

Example 25 with Layer

Use of com.simiacryptus.mindseye.lang.Layer in project MindsEye by SimiaCryptus.

Class LinearSumConstraintTest, method train.

/**
 * Trains {@code network} on {@code trainingData} against an entropy loss. Every layer uses a
 * {@link LinearSumConstraint} trust region. The whole training run executes inside
 * {@code log.code(...)}.
 *
 * <p>The trainer is configured with 100 iterations per sample, a 3-minute timeout and at most 500
 * iterations. The {@code SampledArrayTrainable} is built with a size argument of 10000 (presumably
 * the sample size — confirm).
 *
 * @param log          notebook whose {@code code} block wraps the training run
 * @param network      the layer under test
 * @param trainingData training examples handed to the sampled trainable
 * @param monitor      monitor attached to the trainer
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        // Supervised objective: the network under test followed by an entropy loss.
        @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final Trainable sampledObjective = new SampledArrayTrainable(trainingData, lossNetwork, 10000);
        // Same trust-region policy for every layer: a fresh LinearSumConstraint.
        @Nonnull final TrustRegionStrategy constrainedOrientation = new TrustRegionStrategy() {

            @Override
            public TrustRegion getRegionPolicy(final Layer layer) {
                return new LinearSumConstraint();
            }
        };
        return new IterativeTrainer(sampledObjective)
            .setIterationsPerSample(100)
            .setMonitor(monitor)
            .setOrientation(constrainedOrientation)
            .setTimeout(3, TimeUnit.MINUTES)
            .setMaxIterations(500)
            .runAndFree();
    });
}
Also used: IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer), Nonnull(javax.annotation.Nonnull), SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable), EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer), SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork), Trainable(com.simiacryptus.mindseye.eval.Trainable), Layer(com.simiacryptus.mindseye.lang.Layer), TrustRegionStrategy(com.simiacryptus.mindseye.opt.orient.TrustRegionStrategy)

Aggregations

Layer (com.simiacryptus.mindseye.lang.Layer): 167, Nonnull (javax.annotation.Nonnull): 159, Nullable (javax.annotation.Nullable): 128, Arrays (java.util.Arrays): 117, Tensor (com.simiacryptus.mindseye.lang.Tensor): 116, List (java.util.List): 108, Result (com.simiacryptus.mindseye.lang.Result): 103, IntStream (java.util.stream.IntStream): 98, DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet): 95, TensorList (com.simiacryptus.mindseye.lang.TensorList): 93, Map (java.util.Map): 83, TensorArray (com.simiacryptus.mindseye.lang.TensorArray): 76, Logger (org.slf4j.Logger): 76, LoggerFactory (org.slf4j.LoggerFactory): 76, JsonObject (com.google.gson.JsonObject): 70, DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer): 66, LayerBase (com.simiacryptus.mindseye.lang.LayerBase): 64, NotebookOutput (com.simiacryptus.util.io.NotebookOutput): 51, Collectors (java.util.stream.Collectors): 42, Stream (java.util.stream.Stream): 37