
Example 1 with ImgConcatLayer

Use of com.simiacryptus.mindseye.layers.cudnn.ImgConcatLayer in the MindsEye project by SimiaCryptus.

From the class RescaledSubnetLayer, method eval:

@Nullable
@Override
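public Result eval(@Nonnull final Result... inObj) {
    // Expects exactly one input whose items are 3-D images (width, height, bands).
    assert 1 == inObj.length;
    final TensorList batch = inObj[0].getData();
    @Nonnull final int[] inputDims = batch.getDimensions();
    assert 3 == inputDims.length;
    // No rescaling requested: delegate directly to the wrapped subnetwork.
    if (1 == scale)
        return subnetwork.eval(inObj);
    @Nonnull final PipelineNetwork network = new PipelineNetwork();
    // Fold each scale x scale pixel block into the band dimension.
    @Nullable final DAGNode condensed = network.wrap(new ImgReshapeLayer(scale, scale, false));
    // For each of the scale * scale subbands: select its inputDims[2] bands,
    // run the subnetwork on them, then concatenate all outputs band-wise.
    network.wrap(new ImgConcatLayer(), IntStream.range(0, scale * scale).mapToObj(subband -> {
        @Nonnull final int[] select = new int[inputDims[2]];
        for (int i = 0; i < inputDims[2]; i++) {
            select[i] = subband * inputDims[2] + i;
        }
        return network.add(subnetwork, network.wrap(new ImgBandSelectLayer(select), condensed));
    }).toArray(DAGNode[]::new));
    // Unfold the concatenated bands back to full spatial resolution.
    network.wrap(new ImgReshapeLayer(scale, scale, true));
    Result eval = network.eval(inObj);
    // Release this temporary network's reference before returning the result.
    network.freeRef();
    return eval;
}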
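Judging from the layer arguments, eval folds each scale x scale pixel block into the band dimension, runs the subnetwork once per subband, concatenates the results band-wise, and unfolds them again. The following is a minimal plain-array sketch of that data flow, not the MindsEye API: the class RescaledSubnetSketch and all of its helpers are hypothetical, images are assumed to be double[width][height][bands], and the band ordering after condensing (band = subband * bands + i, with subband = dy * scale + dx) is an assumption chosen to match the select[i] = subband * inputDims[2] + i indexing. ImgReshapeLayer's real memory layout may differ.

```java
import java.util.Arrays;
import java.util.function.UnaryOperator;

// Hypothetical plain-array sketch of the RescaledSubnetLayer data flow (NOT MindsEye API).
// Images are double[width][height][bands].
// ASSUMED layout after condensing: band = subband * bands + i, where subband = dy * s + dx.
class RescaledSubnetSketch {

    // Space-to-depth, analogous to ImgReshapeLayer(s, s, false): [w][h][b] -> [w/s][h/s][s*s*b].
    static double[][][] condense(double[][][] img, int s) {
        int w = img.length / s, h = img[0].length / s, b = img[0][0].length;
        double[][][] out = new double[w][h][s * s * b];
        for (int x = 0; x < w; x++)
            for (int y = 0; y < h; y++)
                for (int dy = 0; dy < s; dy++)
                    for (int dx = 0; dx < s; dx++)
                        for (int i = 0; i < b; i++)
                            out[x][y][(dy * s + dx) * b + i] = img[x * s + dx][y * s + dy][i];
        return out;
    }

    // Depth-to-space, analogous to ImgReshapeLayer(s, s, true); exact inverse of condense.
    static double[][][] expand(double[][][] img, int s) {
        int w = img.length, h = img[0].length, b = img[0][0].length / (s * s);
        double[][][] out = new double[w * s][h * s][b];
        for (int x = 0; x < w; x++)
            for (int y = 0; y < h; y++)
                for (int dy = 0; dy < s; dy++)
                    for (int dx = 0; dx < s; dx++)
                        for (int i = 0; i < b; i++)
                            out[x * s + dx][y * s + dy][i] = img[x][y][(dy * s + dx) * b + i];
        return out;
    }

    // Mirrors select[i] = subband * bands + i followed by a band-select step.
    static double[][][] selectBands(double[][][] img, int subband, int bands) {
        int w = img.length, h = img[0].length;
        double[][][] out = new double[w][h][bands];
        for (int x = 0; x < w; x++)
            for (int y = 0; y < h; y++)
                for (int i = 0; i < bands; i++)
                    out[x][y][i] = img[x][y][subband * bands + i];
        return out;
    }

    // Band-wise concatenation, analogous to ImgConcatLayer (all parts share one shape).
    static double[][][] concatBands(double[][][][] parts) {
        int w = parts[0].length, h = parts[0][0].length, b = parts[0][0][0].length;
        double[][][] out = new double[w][h][parts.length * b];
        for (int p = 0; p < parts.length; p++)
            for (int x = 0; x < w; x++)
                for (int y = 0; y < h; y++)
                    System.arraycopy(parts[p][x][y], 0, out[x][y], p * b, b);
        return out;
    }

    // Condense, apply the subnet to each of the s*s subbands, concatenate, expand.
    static double[][][] rescaledSubnet(double[][][] img, int s, UnaryOperator<double[][][]> subnet) {
        if (s == 1) return subnet.apply(img); // same shortcut as eval()
        int bands = img[0][0].length;
        double[][][] condensed = condense(img, s);
        double[][][][] outputs = new double[s * s][][][];
        for (int subband = 0; subband < s * s; subband++)
            outputs[subband] = subnet.apply(selectBands(condensed, subband, bands));
        return expand(concatBands(outputs), s);
    }

    public static void main(String[] args) {
        double[][][] img = new double[4][4][2];
        for (int x = 0; x < 4; x++)
            for (int y = 0; y < 4; y++)
                for (int i = 0; i < 2; i++)
                    img[x][y][i] = 100 * x + 10 * y + i;
        // An identity subnet should reproduce the input exactly.
        System.out.println("identity round trip: " + Arrays.deepEquals(img, rescaledSubnet(img, 2, t -> t)));
    }
}
```

With an identity subnetwork the sketch returns its input unchanged, which is a quick check that the fold and unfold steps are inverses under the assumed layout; each subnetwork call sees an image downsampled by scale along both spatial axes.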
Also used: PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork), IntStream (java.util.stream.IntStream), JsonObject (com.google.gson.JsonObject), Result (com.simiacryptus.mindseye.lang.Result), DAGNode (com.simiacryptus.mindseye.network.DAGNode), DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer), ArrayList (java.util.ArrayList), List (java.util.List), LayerBase (com.simiacryptus.mindseye.lang.LayerBase), TensorList (com.simiacryptus.mindseye.lang.TensorList), Map (java.util.Map), ImgConcatLayer (com.simiacryptus.mindseye.layers.cudnn.ImgConcatLayer), Layer (com.simiacryptus.mindseye.lang.Layer), Nonnull (javax.annotation.Nonnull), Nullable (javax.annotation.Nullable)
