
Example 1 with CountingResult

Use of com.simiacryptus.mindseye.network.CountingResult in the MindsEye project by SimiaCryptus.

From the class StochasticSamplingSubnetLayer, method eval:

@Nullable
@Override
public Result eval(@Nonnull final Result... inObj) {
    // Wrap each input in a CountingResult, because the same inputs are fed
    // to one evaluation of the inner layer per seed.
    Result[] counting = Arrays.stream(inObj)
        .map(r -> new CountingResult(r, samples))
        .toArray(Result[]::new);
    // Evaluate the inner layer once per seed and average the per-seed results.
    return average(Arrays.stream(getSeeds()).mapToObj(seed -> {
        Layer inner = getInner();
        if (inner instanceof DAGNetwork) {
            // Re-seed every stochastic node and apply the configured precision to every node.
            ((DAGNetwork) inner).visitNodes(node -> {
                Layer layer = node.getLayer();
                if (layer instanceof StochasticComponent) {
                    ((StochasticComponent) layer).shuffle(seed);
                }
                if (layer instanceof MultiPrecision<?>) {
                    ((MultiPrecision<?>) layer).setPrecision(precision);
                }
            });
        }
        // The inner layer itself may also be precision-aware or stochastic.
        if (inner instanceof MultiPrecision<?>) {
            ((MultiPrecision<?>) inner).setPrecision(precision);
        }
        if (inner instanceof StochasticComponent) {
            ((StochasticComponent) inner).shuffle(seed);
        }
        // Propagate this layer's frozen flag to the inner layer.
        inner.setFrozen(isFrozen());
        return inner.eval(counting);
    }).toArray(Result[]::new), precision);
}
Also used: PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork), IntStream (java.util.stream.IntStream), JsonObject (com.google.gson.JsonObject), Arrays (java.util.Arrays), StochasticComponent (com.simiacryptus.mindseye.layers.java.StochasticComponent), CountingResult (com.simiacryptus.mindseye.network.CountingResult), Tensor (com.simiacryptus.mindseye.lang.Tensor), Random (java.util.Random), WrapperLayer (com.simiacryptus.mindseye.layers.java.WrapperLayer), Result (com.simiacryptus.mindseye.lang.Result), ValueLayer (com.simiacryptus.mindseye.layers.java.ValueLayer), DAGNode (com.simiacryptus.mindseye.network.DAGNode), DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer), Precision (com.simiacryptus.mindseye.lang.cudnn.Precision), Map (java.util.Map), Layer (com.simiacryptus.mindseye.lang.Layer), DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork), Nonnull (javax.annotation.Nonnull), Nullable (javax.annotation.Nullable)
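The eval method above follows a Monte Carlo pattern. For each seed, it re-seeds the stochastic components of the inner layer, evaluates it, and then averages the per-seed results. The sketch below shows the same pattern in plain Java on double arrays, with no MindsEye dependency. The names Dropout, shuffle, sampleAverage and average here are hypothetical stand-ins, not the MindsEye API. Using inverted dropout as the stochastic component is an assumption made for illustration. The sketch also leaves out backpropagation and the CountingResult wrapping of the inputs.

```java
import java.util.Arrays;
import java.util.Random;

/**
 * Hypothetical sketch (not the MindsEye API): evaluate a randomized layer once
 * per fixed seed and average the outputs element-wise.
 */
public class StochasticSamplingSketch {

  /** Hypothetical stand-in for a stochastic component: inverted dropout. */
  static final class Dropout {
    private final double keepProbability;
    private Random random = new Random(0);

    Dropout(double keepProbability) {
      if (!(keepProbability > 0.0 && keepProbability <= 1.0)) {
        throw new IllegalArgumentException("keepProbability must be in (0, 1]");
      }
      this.keepProbability = keepProbability;
    }

    /** Re-seed the mask generator, playing the role of shuffle(seed) above. */
    void shuffle(long seed) {
      this.random = new Random(seed);
    }

    /** Keep each element with probability p and scale it by 1/p; otherwise output 0. */
    double[] eval(double[] input) {
      double[] out = new double[input.length];
      for (int i = 0; i < input.length; i++) {
        out[i] = random.nextDouble() < keepProbability ? input[i] / keepProbability : 0.0;
      }
      return out;
    }
  }

  /** Element-wise mean of equally sized vectors. */
  static double[] average(double[][] samples) {
    if (samples.length == 0) {
      throw new IllegalArgumentException("need at least one sample");
    }
    double[] sum = new double[samples[0].length];
    for (double[] sample : samples) {
      for (int i = 0; i < sum.length; i++) {
        sum[i] += sample[i];
      }
    }
    for (int i = 0; i < sum.length; i++) {
      sum[i] /= samples.length;
    }
    return sum;
  }

  /** One evaluation per seed, each with freshly re-seeded randomness, then averaged. */
  static double[] sampleAverage(Dropout layer, double[] input, long[] seeds) {
    // The layer is shared and mutated per seed, so this stream must stay sequential.
    double[][] outputs = Arrays.stream(seeds)
        .mapToObj(seed -> {
          layer.shuffle(seed);
          return layer.eval(input);
        })
        .toArray(double[][]::new);
    return average(outputs);
  }

  public static void main(String[] args) {
    double[] input = {1.0, 2.0, 3.0, 4.0};
    long[] seeds = new Random(42).longs(2000).toArray();
    double[] mean = sampleAverage(new Dropout(0.5), input, seeds);
    System.out.println(Arrays.toString(mean));
  }
}
```

Re-seeding from a fixed seed list makes the averaged output deterministic: the same seeds always produce the same masks. Because each dropout sample is unbiased, the average approaches the input as the number of seeds grows. As in the original, the same layer instance is reconfigured once per seed. This is why the evaluations run one after another rather than in parallel.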
