
Example 1 with ValueLayer

Use of com.simiacryptus.mindseye.layers.java.ValueLayer in project MindsEye by SimiaCryptus.

In class StochasticSamplingSubnetLayer, method average.

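/**
 * Averages a set of sample results.
 *
 * The samples are first summed elementwise by a SumInputsLayer. The sum is then
 * fed through a one-input gate network that multiplies it by a constant 1x1x1
 * tensor holding 1.0 / samples.length, supplied by a ValueLayer.
 *
 * @param samples   the sample results to average
 * @param precision the numeric precision used by the product and sum layers
 * @return the averaged result
 */
public static Result average(final Result[] samples, final Precision precision) {
    // Gate network with a single input: output = input * (1.0 / samples.length)
    PipelineNetwork gateNetwork = new PipelineNetwork(1);
    gateNetwork.wrap(
        new ProductLayer().setPrecision(precision),
        gateNetwork.getInput(0),
        // Constant node with no inputs: a 1x1x1 tensor holding 1.0 / samples.length
        gateNetwork.wrap(new ValueLayer(new Tensor(1, 1, 1).mapAndFree(v -> 1.0 / samples.length)), new DAGNode[] {}));
    // Sums all samples elementwise
    SumInputsLayer sumInputsLayer = new SumInputsLayer().setPrecision(precision);
    try {
        // Sum the samples, then scale the sum by 1/N through the gate network
        return gateNetwork.evalAndFree(sumInputsLayer.evalAndFree(samples));
    } finally {
        // Release the reference-counted layers even if evaluation throws
        sumInputsLayer.freeRef();
        gateNetwork.freeRef();
    }
}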
Also used: PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) StochasticComponent(com.simiacryptus.mindseye.layers.java.StochasticComponent) CountingResult(com.simiacryptus.mindseye.network.CountingResult) Tensor(com.simiacryptus.mindseye.lang.Tensor) Random(java.util.Random) WrapperLayer(com.simiacryptus.mindseye.layers.java.WrapperLayer) Result(com.simiacryptus.mindseye.lang.Result) ValueLayer(com.simiacryptus.mindseye.layers.java.ValueLayer) DAGNode(com.simiacryptus.mindseye.network.DAGNode) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) Precision(com.simiacryptus.mindseye.lang.cudnn.Precision) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable)
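The arithmetic that average() performs can be illustrated without MindsEye. The following plain-Java sketch is a stand-in under stated assumptions, not the library's code: it treats each sample as a flat double[] of equal length, sums the samples elementwise (the SumInputsLayer step), and then multiplies every element by 1.0 / N (the ValueLayer plus ProductLayer gate). The class name AverageSketch is hypothetical.

```java
import java.util.Arrays;

/**
 * Hypothetical plain-Java sketch of the arithmetic in average().
 * Assumption: each sample is a flat double[] of equal length; no MindsEye types are used.
 */
public class AverageSketch {

    static double[] average(double[][] samples) {
        if (samples.length == 0) {
            throw new IllegalArgumentException("need at least one sample");
        }
        int n = samples[0].length;
        // Step 1: elementwise sum (the role SumInputsLayer plays in the original)
        double[] sum = new double[n];
        for (double[] sample : samples) {
            if (sample.length != n) {
                throw new IllegalArgumentException("samples must have equal length");
            }
            for (int i = 0; i < n; i++) {
                sum[i] += sample[i];
            }
        }
        // Step 2: multiply by the constant 1/N (the role of the ValueLayer + ProductLayer gate)
        double scale = 1.0 / samples.length;
        double[] out = new double[n];
        for (int i = 0; i < n; i++) {
            out[i] = sum[i] * scale;
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] samples = { {1.0, 2.0}, {3.0, 6.0} };
        System.out.println(Arrays.toString(average(samples))); // prints [2.0, 4.0]
    }
}
```

Multiplying by a precomputed 1/N mirrors the original's structure, where the scale factor lives in a constant network node rather than in a separate division step; presumably this keeps the averaging inside the network graph so it takes part in backpropagation like any other layer.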
