
Example 1 with ImgBandSelectLayer

Use of com.simiacryptus.mindseye.layers.java.ImgBandSelectLayer in project MindsEye by SimiaCryptus.

Source: class EncodingUtil, method downExplodeTensors.

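/**
 * Splits each image in the stream into factor * factor lower-resolution sub-images.
 *
 * Element 1 of each array in the stream is the image tensor; element 0 is passed
 * through unchanged. The image is reshaped with ImgReshapeLayer (expand = false),
 * which trades spatial resolution for bands, and ImgBandSelectLayer then extracts
 * one contiguous group of the original band count for each sub-band.
 *
 * @param stream the stream of tensor arrays; element 1 of each array is the image
 * @param factor the downsampling factor along each spatial axis; must be positive
 * @return a stream with factor * factor arrays per input array, or the input stream unchanged when factor is 1
 */
@Nonnull
public static Stream<Tensor[]> downExplodeTensors(@Nonnull final Stream<Tensor[]> stream, final int factor) {
    // Rejects zero and all negative values, including -1.
    if (0 >= factor)
        throw new IllegalArgumentException("factor must be positive: " + factor);
    if (1 == factor)
        return stream;
    return stream.flatMap(pair -> IntStream.range(0, factor * factor).mapToObj(subband -> {
        // Band count of the original image; every sub-image keeps this many bands.
        final int bands = pair[1].getDimensions()[2];
        // Contiguous band indices belonging to this sub-band after the reshape.
        @Nonnull final int[] select = new int[bands];
        final int offset = subband * bands;
        for (int i = 0; i < bands; i++) {
            select[i] = offset + i;
        }
        // Fold factor x factor spatial blocks into bands, then keep one sub-band.
        @Nonnull final PipelineNetwork network = new PipelineNetwork();
        network.add(new ImgReshapeLayer(factor, factor, false));
        network.add(new ImgBandSelectLayer(select));
        @Nullable final Tensor result = network.eval(pair[1]).getData().get(0);
        return new Tensor[] { pair[0], result };
    }));
}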
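To make the reshape-then-select idea concrete without the MindsEye runtime, here is a minimal plain-Java sketch on double[x][y][band] arrays. It is a hypothetical illustration, not MindsEye code: the class DownExplodeSketch and its methods are invented names, and the band ordering used in spaceToDepth (subband = (y % f) * f + (x % f)) is an assumption chosen for illustration; the real ImgReshapeLayer may lay out bands differently.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.IntStream;

/**
 * Hypothetical plain-Java illustration of downExplodeTensors; not part of MindsEye.
 * Images are indexed as [x][y][band], mirroring Tensor dimensions {width, height, bands}.
 */
public class DownExplodeSketch {

    /**
     * Space-to-depth: pixel (x, y, b) moves to (x / f, y / f, subband * bands + b).
     * The subband = (y % f) * f + (x % f) ordering is an assumption for illustration.
     */
    public static double[][][] spaceToDepth(double[][][] img, int f) {
        int w = img.length, h = img[0].length, bands = img[0][0].length;
        if (w % f != 0 || h % f != 0) {
            throw new IllegalArgumentException("image size must be divisible by factor");
        }
        double[][][] out = new double[w / f][h / f][bands * f * f];
        for (int x = 0; x < w; x++) {
            for (int y = 0; y < h; y++) {
                int subband = (y % f) * f + (x % f);
                for (int b = 0; b < bands; b++) {
                    out[x / f][y / f][subband * bands + b] = img[x][y][b];
                }
            }
        }
        return out;
    }

    /** Keeps the listed bands in order, analogous in spirit to ImgBandSelectLayer(select). */
    public static double[][][] selectBands(double[][][] img, int[] select) {
        int w = img.length, h = img[0].length;
        double[][][] out = new double[w][h][select.length];
        for (int x = 0; x < w; x++) {
            for (int y = 0; y < h; y++) {
                for (int i = 0; i < select.length; i++) {
                    out[x][y][i] = img[x][y][select[i]];
                }
            }
        }
        return out;
    }

    /** One image in, factor * factor sub-images out (factor 1 returns the input as-is). */
    public static List<double[][][]> downExplode(double[][][] img, int factor) {
        if (factor <= 0) {
            throw new IllegalArgumentException("factor must be positive: " + factor);
        }
        if (factor == 1) {
            return Collections.singletonList(img);
        }
        int bands = img[0][0].length;
        double[][][] reshaped = spaceToDepth(img, factor);
        List<double[][][]> parts = new ArrayList<>();
        for (int subband = 0; subband < factor * factor; subband++) {
            // Same contiguous index range the original code builds in select[].
            int[] select = IntStream.range(subband * bands, (subband + 1) * bands).toArray();
            parts.add(selectBands(reshaped, select));
        }
        return parts;
    }

    public static void main(String[] args) {
        // 4 x 4 x 1 image whose value encodes its position as 10 * x + y.
        double[][][] img = new double[4][4][1];
        for (int x = 0; x < 4; x++) {
            for (int y = 0; y < 4; y++) {
                img[x][y][0] = 10 * x + y;
            }
        }
        List<double[][][]> parts = downExplode(img, 2);
        System.out.println(parts.size());          // 4
        System.out.println(parts.get(1)[0][0][0]); // 10.0 (pixel x = 1, y = 0)
    }
}
```

As a size check under the same assumptions: a 32 x 32 x 3 image with factor 2 becomes 16 x 16 x 12 after the reshape, and each of the 4 band selections yields a 16 x 16 x 3 sub-image. The sketch also builds the selection once per sub-band instead of constructing a new PipelineNetwork each time, which the original code does for every input and sub-band.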
Also used: PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork), IntStream (java.util.stream.IntStream), Coordinate (com.simiacryptus.mindseye.lang.Coordinate), Arrays (java.util.Arrays), DoubleStatistics (com.simiacryptus.util.data.DoubleStatistics), GifSequenceWriter (com.simiacryptus.util.io.GifSequenceWriter), TableOutput (com.simiacryptus.util.TableOutput), LoggerFactory (org.slf4j.LoggerFactory), Tensor (com.simiacryptus.mindseye.lang.Tensor), Caltech101 (com.simiacryptus.mindseye.test.data.Caltech101), Function (java.util.function.Function), LinkedHashMap (java.util.LinkedHashMap), ImgBandScaleLayer (com.simiacryptus.mindseye.layers.java.ImgBandScaleLayer), TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor), ImageIO (javax.imageio.ImageIO), Layer (com.simiacryptus.mindseye.lang.Layer), StepRecord (com.simiacryptus.mindseye.test.StepRecord), NotebookOutput (com.simiacryptus.util.io.NotebookOutput), PCAUtil (com.simiacryptus.mindseye.test.PCAUtil), Nonnull (javax.annotation.Nonnull), Nullable (javax.annotation.Nullable), PrintStream (java.io.PrintStream), MeanSqLossLayer (com.simiacryptus.mindseye.layers.java.MeanSqLossLayer), Logger (org.slf4j.Logger), SysOutInterceptor (com.simiacryptus.util.test.SysOutInterceptor), BufferedImage (java.awt.image.BufferedImage), ImgBandSelectLayer (com.simiacryptus.mindseye.layers.java.ImgBandSelectLayer), IOException (java.io.IOException), TestUtil (com.simiacryptus.mindseye.test.TestUtil), FastRandom (com.simiacryptus.util.FastRandom), Collectors (java.util.stream.Collectors), File (java.io.File), DoubleStream (java.util.stream.DoubleStream), ConvolutionLayer (com.simiacryptus.mindseye.layers.cudnn.ConvolutionLayer), List (java.util.List), Stream (java.util.stream.Stream), ScalarStatistics (com.simiacryptus.util.data.ScalarStatistics), ToDoubleFunction (java.util.function.ToDoubleFunction), ImgReshapeLayer (com.simiacryptus.mindseye.layers.java.ImgReshapeLayer), ImgBandBiasLayer (com.simiacryptus.mindseye.layers.java.ImgBandBiasLayer), DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork), Step (com.simiacryptus.mindseye.opt.Step), Comparator (java.util.Comparator)

Aggregations

Coordinate (com.simiacryptus.mindseye.lang.Coordinate): 1
Layer (com.simiacryptus.mindseye.lang.Layer): 1
Tensor (com.simiacryptus.mindseye.lang.Tensor): 1
ConvolutionLayer (com.simiacryptus.mindseye.layers.cudnn.ConvolutionLayer): 1
ImgBandBiasLayer (com.simiacryptus.mindseye.layers.java.ImgBandBiasLayer): 1
ImgBandScaleLayer (com.simiacryptus.mindseye.layers.java.ImgBandScaleLayer): 1
ImgBandSelectLayer (com.simiacryptus.mindseye.layers.java.ImgBandSelectLayer): 1
ImgReshapeLayer (com.simiacryptus.mindseye.layers.java.ImgReshapeLayer): 1
MeanSqLossLayer (com.simiacryptus.mindseye.layers.java.MeanSqLossLayer): 1
DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork): 1
PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork): 1
Step (com.simiacryptus.mindseye.opt.Step): 1
TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor): 1
PCAUtil (com.simiacryptus.mindseye.test.PCAUtil): 1
StepRecord (com.simiacryptus.mindseye.test.StepRecord): 1
TestUtil (com.simiacryptus.mindseye.test.TestUtil): 1
Caltech101 (com.simiacryptus.mindseye.test.data.Caltech101): 1
FastRandom (com.simiacryptus.util.FastRandom): 1
TableOutput (com.simiacryptus.util.TableOutput): 1
DoubleStatistics (com.simiacryptus.util.data.DoubleStatistics): 1