
Example 1 with SumInputsLayer

Usage of com.simiacryptus.mindseye.layers.cudnn.SumInputsLayer in the MindsEye project by SimiaCryptus.

From the class StochasticSamplingSubnetLayer, method average.

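/**
 * Averages the given sample results.
 *
 * The samples are summed with a SumInputsLayer. The sum is then passed through
 * a one-input PipelineNetwork that multiplies it (ProductLayer) by a constant
 * 1x1x1 tensor holding 1 / samples.length (ValueLayer).
 *
 * @param samples the sample results to average
 * @return the averaged result
 */
public static Result average(final Result[] samples) {
    // Gate network with a single input: output = input * (1 / N)
    PipelineNetwork gateNetwork = new PipelineNetwork(1);
    // Constant node supplying the scale factor 1 / N as a 1x1x1 tensor
    DAGNode scaleNode = gateNetwork.wrap(
        new ValueLayer(new Tensor(1, 1, 1).mapAndFree(v -> 1.0 / samples.length)),
        new DAGNode[] {});
    // Multiply the network input by the constant
    gateNetwork.wrap(new ProductLayer(), gateNetwork.getInput(0), scaleNode);
    SumInputsLayer sumInputsLayer = new SumInputsLayer();
    try {
        // Sum all samples, then scale the sum by 1 / N
        return gateNetwork.evalAndFree(sumInputsLayer.evalAndFree(samples));
    } finally {
        // Release the reference-counted layer and network
        sumInputsLayer.freeRef();
        gateNetwork.freeRef();
    }
}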
Also used: PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork), IntStream (java.util.stream.IntStream), JsonObject (com.google.gson.JsonObject), Arrays (java.util.Arrays), CountingResult (com.simiacryptus.mindseye.network.CountingResult), SumInputsLayer (com.simiacryptus.mindseye.layers.cudnn.SumInputsLayer), Tensor (com.simiacryptus.mindseye.lang.Tensor), Random (java.util.Random), Result (com.simiacryptus.mindseye.lang.Result), DAGNode (com.simiacryptus.mindseye.network.DAGNode), DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer), ArrayList (java.util.ArrayList), List (java.util.List), LayerBase (com.simiacryptus.mindseye.lang.LayerBase), ProductLayer (com.simiacryptus.mindseye.layers.cudnn.ProductLayer), Map (java.util.Map), Layer (com.simiacryptus.mindseye.lang.Layer), DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork), Nonnull (javax.annotation.Nonnull), Nullable (javax.annotation.Nullable)
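The average method computes a mean in two stages: an element-wise sum of all samples, followed by multiplication with the constant 1 / N. The following is a minimal, library-free sketch of the same arithmetic on plain double[] arrays, written to make the two stages and the matching backward pass explicit. The class AverageSketch and its methods are hypothetical illustrations, not part of MindsEye. The backward pass uses the standard calculus result that d(mean)/d(x_k) = 1/N; it is an assumption, not something the source confirms, that the SumInputsLayer/ProductLayer pipeline propagates gradients exactly this way.

```java
import java.util.Arrays;

/**
 * Hypothetical plain-Java sketch (not part of MindsEye) of the averaging
 * scheme used above: sum the samples element-wise, then scale by 1 / N.
 */
final class AverageSketch {

    /** Forward pass: element-wise mean of N equally sized samples. */
    static double[] average(double[][] samples) {
        if (samples.length == 0) {
            throw new IllegalArgumentException("need at least one sample");
        }
        int size = samples[0].length;
        // Stage 1: element-wise sum (the role SumInputsLayer plays above)
        double[] sum = new double[size];
        for (double[] sample : samples) {
            if (sample.length != size) {
                throw new IllegalArgumentException("samples must have equal length");
            }
            for (int i = 0; i < size; i++) {
                sum[i] += sample[i];
            }
        }
        // Stage 2: multiply by the constant 1 / N (the ProductLayer x ValueLayer gate above)
        double scale = 1.0 / samples.length;
        for (int i = 0; i < size; i++) {
            sum[i] *= scale;
        }
        return sum;
    }

    /**
     * Backward pass: mean = (1/N) * sum_k x_k, so d(mean)/d(x_k) = 1/N and
     * every sample receives the upstream gradient scaled by 1 / N.
     */
    static double[][] averageGradient(double[] upstream, int sampleCount) {
        if (sampleCount <= 0) {
            throw new IllegalArgumentException("sampleCount must be positive");
        }
        double scale = 1.0 / sampleCount;
        double[] perSample = new double[upstream.length];
        for (int i = 0; i < upstream.length; i++) {
            perSample[i] = upstream[i] * scale;
        }
        // Each sample gets its own copy of the same scaled gradient
        double[][] grads = new double[sampleCount][];
        for (int k = 0; k < sampleCount; k++) {
            grads[k] = perSample.clone();
        }
        return grads;
    }

    public static void main(String[] args) {
        double[][] samples = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 9.0}};
        System.out.println(Arrays.toString(average(samples)));
        System.out.println(Arrays.deepToString(averageGradient(new double[] {3.0, 6.0}, 3)));
    }
}
```

Splitting the mean into "sum, then scale" mirrors the MindsEye code: the sum layer needs no knowledge of N, and N enters only through a single constant, so the same gate works for any number of samples.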
