Search in sources:

Example 11 with DAGNetwork

use of com.simiacryptus.mindseye.network.DAGNetwork in project MindsEye by SimiaCryptus.

Method `run` of class ValidatingTrainer.

/**
 * Runs the validating training loop until a halt condition is reached.
 * <p>
 * Each epoch runs every phase of the training regimen, then re-measures the validation subject and
 * adapts the iteration count and training sample size for the next epoch. The adaptation is based on
 * the ratio of the new training/validation means to the previous ones. The loop stops when
 * {@code shouldHalt} returns true or when the primary phase reports it should not continue. Once the
 * maximum training size has been reached, it can also stop through the random-convergence roll or the
 * disappointment (staleness) threshold.
 * <p>
 * Stochastic noise in the validation network is cleared before the initial measurement, before every
 * validation measurement, and once more after the loop ends.
 *
 * @return the mean of the last validation measurement stored in the epoch parameters
 * @throws RuntimeException wrapping any throwable raised while training
 */
public double run() {
    try {
        final long timeoutAt = System.currentTimeMillis() + timeout.toMillis();
        clearValidationNoise();
        @Nonnull final EpochParams epochParams = new EpochParams(timeoutAt, epochIterations, getTrainingSize(), validationSubject.measure(monitor));
        int epochNumber = 0;
        int iterationNumber = 0;
        int lastImprovement = 0;
        double lowestValidation = Double.POSITIVE_INFINITY;
        while (true) {
            if (shouldHalt(monitor, timeoutAt)) {
                monitor.log("Training halted");
                break;
            }
            monitor.log(String.format("Epoch parameters: %s, %s", epochParams.trainingSize, epochParams.iterations));
            @Nonnull final List<TrainingPhase> regimen = getRegimen();
            final long seed = System.nanoTime();
            // Index into the same regimen snapshot whose size bounds the range.
            final List<EpochResult> epochResults = IntStream.range(0, regimen.size())
                    .mapToObj(i -> runPhase(epochParams, regimen.get(i), i, seed))
                    .collect(Collectors.toList());
            final EpochResult primaryPhase = epochResults.get(0);
            iterationNumber += primaryPhase.iterations;
            // Ratio < 1 means the training objective decreased during this epoch.
            final double trainingDelta = primaryPhase.currentPoint.getMean() / primaryPhase.priorMean;
            clearValidationNoise();
            final PointSample currentValidation = validationSubject.measure(monitor);
            final double validationDelta = currentValidation.getMean() / epochParams.validation.getMean();
            // Ratio of log-improvements: how much training improved relative to validation.
            final double overtraining = Math.log(trainingDelta) / Math.log(validationDelta);
            final double adj1 = Math.pow(Math.log(getTrainingTarget()) / Math.log(validationDelta), adjustmentFactor);
            final double adj2 = Math.pow(overtraining / getOvertrainingTarget(), adjustmentFactor);
            final double validationMean = currentValidation.getMean();
            if (validationMean < lowestValidation) {
                lowestValidation = validationMean;
                lastImprovement = iterationNumber;
            }
            monitor.log(String.format("Epoch %d result apply %s iterations, %s/%s samples: {validation *= 2^%.5f; training *= 2^%.3f; Overtraining = %.2f}, {itr*=%.2f, len*=%.2f} %s since improvement; %.4f validation time", ++epochNumber, primaryPhase.iterations, epochParams.trainingSize, getMaxTrainingSize(), Math.log(validationDelta) / Math.log(2), Math.log(trainingDelta) / Math.log(2), overtraining, adj1, adj2, iterationNumber - lastImprovement, validatingMeasurementTime.getAndSet(0) / 1e9));
            if (!primaryPhase.continueTraining) {
                monitor.log(String.format("Training %d runPhase halted", epochNumber));
                break;
            }
            if (epochParams.trainingSize >= getMaxTrainingSize()) {
                final double roll = FastRandom.INSTANCE.random();
                if (roll > Math.pow(2 - validationDelta, pessimism)) {
                    monitor.log(String.format("Training randomly converged: %3f", roll));
                    break;
                } else {
                    if (iterationNumber - lastImprovement > improvmentStaleThreshold) {
                        if (disappointments.incrementAndGet() > getDisappointmentThreshold()) {
                            monitor.log(String.format("Training converged after %s iterations", iterationNumber - lastImprovement));
                            break;
                        } else {
                            monitor.log(String.format("Training failed to converged on %s attempt after %s iterations", disappointments.get(), iterationNumber - lastImprovement));
                        }
                    } else {
                        disappointments.set(0);
                    }
                }
            }
            if (validationDelta < 1.0 && trainingDelta < 1.0) {
                // Only adjust when the factor falls outside the [1 - tol, 1 + tol] dead band.
                if (adj1 < 1 - adjustmentTolerance || adj1 > 1 + adjustmentTolerance) {
                    epochParams.iterations = Math.max(getMinEpochIterations(), Math.min(getMaxEpochIterations(), (int) (primaryPhase.iterations * adj1)));
                }
                // Same dead-band test as for adj1 (the original bounds were swapped, making this always true).
                if (adj2 < 1 - adjustmentTolerance || adj2 > 1 + adjustmentTolerance) {
                    epochParams.trainingSize = Math.max(0, Math.min(Math.max(getMinTrainingSize(), Math.min(getMaxTrainingSize(), (int) (epochParams.trainingSize * adj2))), epochParams.trainingSize));
                }
            } else {
                epochParams.trainingSize = Math.max(0, Math.min(Math.max(getMinTrainingSize(), Math.min(getMaxTrainingSize(), epochParams.trainingSize * 5)), epochParams.trainingSize));
                epochParams.iterations = 1;
            }
            epochParams.validation = currentValidation;
        }
        clearValidationNoise();
        return epochParams.validation.getMean();
    } catch (@Nonnull final Throwable e) {
        throw new RuntimeException(e);
    }
}

/**
 * Calls {@code clearNoise()} on every {@link StochasticComponent} layer of the validation subject,
 * if the subject's layer is a {@link DAGNetwork}; otherwise does nothing.
 */
private void clearValidationNoise() {
    if (validationSubject.getLayer() instanceof DAGNetwork) {
        ((DAGNetwork) validationSubject.getLayer()).visitLayers(layer -> {
            if (layer instanceof StochasticComponent)
                ((StochasticComponent) layer).clearNoise();
        });
    }
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) TemporalUnit(java.time.temporal.TemporalUnit) LineSearchStrategy(com.simiacryptus.mindseye.opt.line.LineSearchStrategy) SampledTrainable(com.simiacryptus.mindseye.eval.SampledTrainable) TrainableBase(com.simiacryptus.mindseye.eval.TrainableBase) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) HashMap(java.util.HashMap) SampledCachedTrainable(com.simiacryptus.mindseye.eval.SampledCachedTrainable) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) Function(java.util.function.Function) StateSet(com.simiacryptus.mindseye.lang.StateSet) ArrayList(java.util.ArrayList) TrainableWrapper(com.simiacryptus.mindseye.eval.TrainableWrapper) Trainable(com.simiacryptus.mindseye.eval.Trainable) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Duration(java.time.Duration) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) ManagementFactory(java.lang.management.ManagementFactory) FailsafeLineSearchCursor(com.simiacryptus.mindseye.opt.line.FailsafeLineSearchCursor) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) Nonnull(javax.annotation.Nonnull) IterativeStopException(com.simiacryptus.mindseye.lang.IterativeStopException) Util(com.simiacryptus.util.Util) StochasticComponent(com.simiacryptus.mindseye.layers.java.StochasticComponent) QQN(com.simiacryptus.mindseye.opt.orient.QQN) OrientationStrategy(com.simiacryptus.mindseye.opt.orient.OrientationStrategy) FastRandom(com.simiacryptus.util.FastRandom) Collectors(java.util.stream.Collectors) TimeUnit(java.util.concurrent.TimeUnit) AtomicLong(java.util.concurrent.atomic.AtomicLong) List(java.util.List) ChronoUnit(java.time.temporal.ChronoUnit) TimedResult(com.simiacryptus.util.lang.TimedResult) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) PointSample(com.simiacryptus.mindseye.lang.PointSample) 
Nonnull(javax.annotation.Nonnull) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) StochasticComponent(com.simiacryptus.mindseye.layers.java.StochasticComponent) PointSample(com.simiacryptus.mindseye.lang.PointSample)

Example 12 with DAGNetwork

use of com.simiacryptus.mindseye.network.DAGNetwork in project MindsEye by SimiaCryptus.

Method `eval` of class StochasticSamplingSubnetLayer.

/**
 * Evaluates the inner layer once per seed returned by {@code getSeeds()} and averages the outputs.
 * <p>
 * Every input is wrapped in a {@link CountingResult} built with {@code samples}. Before each
 * evaluation the following steps run:
 * <ul>
 *   <li>If the inner layer is a {@link DAGNetwork}, each node's layer is reseeded (if it is a
 *   {@link StochasticComponent}) and given this layer's precision (if it is a {@code MultiPrecision}).</li>
 *   <li>The inner layer itself receives the same precision and reseeding.</li>
 *   <li>The inner layer copies this layer's frozen flag.</li>
 * </ul>
 *
 * @param inObj the input results
 * @return the average of the per-seed results, computed with this layer's precision
 */
@Nullable
@Override
public Result eval(@Nonnull final Result... inObj) {
    final Result[] countedInputs = new Result[inObj.length];
    for (int idx = 0; idx < inObj.length; idx++) {
        countedInputs[idx] = new CountingResult(inObj[idx], samples);
    }
    final Result[] perSeed = Arrays.stream(getSeeds()).mapToObj(seed -> {
        final Layer innerLayer = getInner();
        if (innerLayer instanceof DAGNetwork) {
            ((DAGNetwork) innerLayer).visitNodes(node -> {
                final Layer nodeLayer = node.getLayer();
                if (nodeLayer instanceof StochasticComponent) {
                    ((StochasticComponent) nodeLayer).shuffle(seed);
                }
                if (nodeLayer instanceof MultiPrecision<?>) {
                    ((MultiPrecision) nodeLayer).setPrecision(precision);
                }
            });
        }
        if (innerLayer instanceof MultiPrecision<?>) {
            ((MultiPrecision) innerLayer).setPrecision(precision);
        }
        if (innerLayer instanceof StochasticComponent) {
            ((StochasticComponent) innerLayer).shuffle(seed);
        }
        innerLayer.setFrozen(isFrozen());
        return innerLayer.eval(countedInputs);
    }).toArray(Result[]::new);
    return average(perSeed, precision);
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) StochasticComponent(com.simiacryptus.mindseye.layers.java.StochasticComponent) CountingResult(com.simiacryptus.mindseye.network.CountingResult) Tensor(com.simiacryptus.mindseye.lang.Tensor) Random(java.util.Random) WrapperLayer(com.simiacryptus.mindseye.layers.java.WrapperLayer) Result(com.simiacryptus.mindseye.lang.Result) ValueLayer(com.simiacryptus.mindseye.layers.java.ValueLayer) DAGNode(com.simiacryptus.mindseye.network.DAGNode) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) Precision(com.simiacryptus.mindseye.lang.cudnn.Precision) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) StochasticComponent(com.simiacryptus.mindseye.layers.java.StochasticComponent) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) WrapperLayer(com.simiacryptus.mindseye.layers.java.WrapperLayer) ValueLayer(com.simiacryptus.mindseye.layers.java.ValueLayer) Layer(com.simiacryptus.mindseye.lang.Layer) CountingResult(com.simiacryptus.mindseye.network.CountingResult) Result(com.simiacryptus.mindseye.lang.Result) CountingResult(com.simiacryptus.mindseye.network.CountingResult) Nullable(javax.annotation.Nullable)

Example 13 with DAGNetwork

use of com.simiacryptus.mindseye.network.DAGNetwork in project MindsEye by SimiaCryptus.

Method `add` of class ExplodedConvolutionLeg.

/**
 * Appends this leg's convolution to the network that owns {@code input}.
 * <p>
 * When the input band count equals the configured output band count, the leg must consist of
 * exactly one sub-layer (asserted), which is added directly after {@code input}. Otherwise every
 * sub-layer is applied to {@code input} independently. Their outputs are then joined by an
 * {@code ImgConcatLayer} that is limited to the configured output bands and uses the configured
 * precision. The parallel flag is taken from {@code CudaSettings.INSTANCE.isConv_para_2()}.
 *
 * @param input the node to attach to; its network receives the new nodes
 * @return the node producing this leg's output
 */
public DAGNode add(@Nonnull final DAGNode input) {
    assertAlive();
    final DAGNetwork network = input.getNetwork();
    if (getInputBands() == this.convolutionParams.outputBands) {
        assert 1 == subLayers.size();
        return network.add(subLayers.get(0), input);
    }
    final boolean parallel = CudaSettings.INSTANCE.isConv_para_2();
    return network.wrap(
            new ImgConcatLayer()
                    .setMaxBands(this.convolutionParams.outputBands)
                    .setPrecision(this.convolutionParams.precision)
                    .setParallel(parallel),
            subLayers.stream().map(l -> network.add(l, input)).toArray(i -> new DAGNode[i]))
            .setParallel(parallel);
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) Coordinate(com.simiacryptus.mindseye.lang.Coordinate) CudaSettings(com.simiacryptus.mindseye.lang.cudnn.CudaSettings) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) ReferenceCountingBase(com.simiacryptus.mindseye.lang.ReferenceCountingBase) DAGNode(com.simiacryptus.mindseye.network.DAGNode) Function(java.util.function.Function) ArrayList(java.util.ArrayList) Delta(com.simiacryptus.mindseye.lang.Delta) Precision(com.simiacryptus.mindseye.lang.cudnn.Precision) List(java.util.List) Layer(com.simiacryptus.mindseye.lang.Layer) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) DAGNode(com.simiacryptus.mindseye.network.DAGNode)

Example 14 with DAGNetwork

use of com.simiacryptus.mindseye.network.DAGNetwork in project MindsEye by SimiaCryptus.

Method `add` of class ExplodedConvolutionGrid.

/**
 * Appends this grid's (possibly multi-leg) convolution to the network that owns {@code input}.
 * <p>
 * If a custom (non-null, non-zero) padding is configured on either axis, a zero-padding layer is
 * added before the convolution. The convolution is either the single leg or an
 * {@code ImgLinearSubnetLayer} that routes each leg's band range to its own sub-network. Any
 * negative custom padding is then applied again after the convolution as a second zero-padding
 * layer.
 *
 * @param input the node to attach to; its network receives the new nodes
 * @return the node producing the grid's output
 */
public DAGNode add(@Nonnull DAGNode input) {
    assertAlive();
    DAGNetwork network = input.getNetwork();
    int defaultPaddingX = 0;
    int defaultPaddingY = 0;
    boolean customPaddingX = this.convolutionParams.paddingX != null && convolutionParams.paddingX != defaultPaddingX;
    boolean customPaddingY = this.convolutionParams.paddingY != null && convolutionParams.paddingY != defaultPaddingY;
    final DAGNode paddedInput;
    if (customPaddingX || customPaddingY) {
        // Only unbox the axes that are actually configured: the other one may be null, and
        // unboxing it would throw a NullPointerException.
        int x = customPaddingX ? inputPadding(this.convolutionParams.paddingX, defaultPaddingX) : 0;
        int y = customPaddingY ? inputPadding(this.convolutionParams.paddingY, defaultPaddingY) : 0;
        if (x != 0 || y != 0) {
            paddedInput = network.wrap(new ImgZeroPaddingLayer(x, y).setPrecision(convolutionParams.precision), input);
        } else {
            paddedInput = input;
        }
    } else {
        paddedInput = input;
    }
    InnerNode output;
    if (subLayers.size() == 1) {
        output = (InnerNode) subLayers.get(0).add(paddedInput);
    } else {
        ImgLinearSubnetLayer linearSubnetLayer = new ImgLinearSubnetLayer();
        subLayers.forEach(leg -> {
            PipelineNetwork subnet = new PipelineNetwork();
            leg.add(subnet.getHead());
            linearSubnetLayer.add(leg.fromBand, leg.toBand, subnet);
        });
        boolean isParallel = CudaSettings.INSTANCE.isConv_para_1();
        linearSubnetLayer.setPrecision(convolutionParams.precision).setParallel(isParallel);
        output = network.wrap(linearSubnetLayer, paddedInput).setParallel(isParallel);
    }
    if (customPaddingX || customPaddingY) {
        int x = !customPaddingX ? 0 : (this.convolutionParams.paddingX - defaultPaddingX);
        int y = !customPaddingY ? 0 : (this.convolutionParams.paddingY - defaultPaddingY);
        // Only negative (cropping) offsets are applied after the convolution.
        if (x > 0)
            x = 0;
        if (y > 0)
            y = 0;
        if (x != 0 || y != 0) {
            return network.wrap(new ImgZeroPaddingLayer(x, y).setPrecision(convolutionParams.precision), output);
        }
    }
    return output;
}

/**
 * Computes the zero-padding to apply before the convolution along one axis.
 *
 * @param padding        the configured padding for the axis
 * @param defaultPadding the padding the convolution already provides on that axis
 * @return {@code padding + defaultPadding} below {@code -defaultPadding},
 *         {@code padding - defaultPadding} above {@code defaultPadding}, otherwise 0
 */
private int inputPadding(int padding, int defaultPadding) {
    if (padding < -defaultPadding) {
        return padding + defaultPadding;
    }
    if (padding > defaultPadding) {
        return padding - defaultPadding;
    }
    return 0;
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) DAGNode(com.simiacryptus.mindseye.network.DAGNode) InnerNode(com.simiacryptus.mindseye.network.InnerNode)

Example 15 with DAGNetwork

use of com.simiacryptus.mindseye.network.DAGNetwork in project MindsEye by SimiaCryptus.

Method `add` of class ImageClassifier.

/**
 * Names {@code layer} and appends it to {@code model}, exploding it first when possible.
 * <p>
 * If the layer is {@link Explodable}, it is replaced by its exploded form. That form is added
 * recursively, and the original layer's reference is released in a {@code finally} block. When the
 * exploded form is a {@link DAGNetwork}, every node's layer is also named before the recursion.
 * A non-explodable layer is wrapped into {@code model} and returned as-is.
 *
 * @param layer the layer to add
 * @param model the pipeline receiving the layer
 * @return the layer that was finally wrapped into {@code model}
 */
@Nonnull
protected static Layer add(@Nonnull Layer layer, @Nonnull PipelineNetwork model) {
    name(layer);
    if (!(layer instanceof Explodable)) {
        model.wrap(layer);
        return layer;
    }
    Layer explode = ((Explodable) layer).explode();
    try {
        final String detail;
        if (explode instanceof DAGNetwork) {
            ((DAGNetwork) explode).visitNodes(node -> name(node.getLayer()));
            detail = ((DAGNetwork) explode).getNodes().size() + " nodes";
        } else {
            // A non-network layer has no node count; report its name instead of
            // formatting the name as a number of nodes.
            detail = explode.getName();
        }
        log.info(String.format("Exploded %s to %s (%s)", layer.getName(), explode.getClass().getSimpleName(), detail));
        return add(explode, model);
    } finally {
        layer.freeRef();
    }
}
Also used : Explodable(com.simiacryptus.mindseye.layers.cudnn.Explodable) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) ActivationLayer(com.simiacryptus.mindseye.layers.cudnn.ActivationLayer) SimpleConvolutionLayer(com.simiacryptus.mindseye.layers.cudnn.SimpleConvolutionLayer) FullyConnectedLayer(com.simiacryptus.mindseye.layers.cudnn.FullyConnectedLayer) Layer(com.simiacryptus.mindseye.lang.Layer) ConvolutionLayer(com.simiacryptus.mindseye.layers.cudnn.ConvolutionLayer) LinearActivationLayer(com.simiacryptus.mindseye.layers.java.LinearActivationLayer) BiasLayer(com.simiacryptus.mindseye.layers.java.BiasLayer) Nonnull(javax.annotation.Nonnull)

Aggregations

DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork)19 Nonnull (javax.annotation.Nonnull)16 Layer (com.simiacryptus.mindseye.lang.Layer)11 Nullable (javax.annotation.Nullable)11 Tensor (com.simiacryptus.mindseye.lang.Tensor)10 Arrays (java.util.Arrays)10 List (java.util.List)10 ArrayList (java.util.ArrayList)9 StochasticComponent (com.simiacryptus.mindseye.layers.java.StochasticComponent)7 NotebookOutput (com.simiacryptus.util.io.NotebookOutput)7 Map (java.util.Map)7 TimeUnit (java.util.concurrent.TimeUnit)7 DAGNode (com.simiacryptus.mindseye.network.DAGNode)6 TestUtil (com.simiacryptus.mindseye.test.TestUtil)6 IntStream (java.util.stream.IntStream)6 Format (guru.nidi.graphviz.engine.Format)5 Graphviz (guru.nidi.graphviz.engine.Graphviz)5 File (java.io.File)5 IOException (java.io.IOException)5 HashMap (java.util.HashMap)5