Search in sources :

Example 1 with DoubleBuffer

use of com.simiacryptus.mindseye.lang.DoubleBuffer in project MindsEye by SimiaCryptus.

the class QuantifyOrientationWrapper method orient.

/**
 * Delegates orientation to the wrapped strategy and logs, per weight group, how large the
 * proposed direction is relative to the current weights.
 *
 * <p>When the inner cursor is a {@link SimpleLineSearchCursor}, the weights of its origin are
 * grouped by {@code getId(...)}. For every entry the ratio
 * {@code rms(direction delta) / rms(weight delta)} is computed (a missing direction buffer
 * counts as 0, a zero weight RMS is replaced by 1). A group with a single entry is reported as
 * that ratio; larger groups are summarized through {@link DoubleStatistics}. Any other cursor
 * type is only logged.
 *
 * @param subject     the trainable being optimized, passed through to the inner strategy
 * @param measurement the current sample, passed through to the inner strategy
 * @param monitor     receives the diagnostic log line
 * @return the cursor produced by the inner strategy, unchanged
 */
@Override
public LineSearchCursor orient(final Trainable subject, final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
    final LineSearchCursor cursor = inner.orient(subject, measurement, monitor);
    if (!(cursor instanceof SimpleLineSearchCursor)) {
        monitor.log(String.format("Non-simple cursor: %s", cursor));
        return cursor;
    }
    final SimpleLineSearchCursor simpleCursor = (SimpleLineSearchCursor) cursor;
    final DeltaSet<Layer> direction = simpleCursor.direction;
    @Nonnull final StateSet<Layer> weights = simpleCursor.origin.weights;
    final Map<CharSequence, CharSequence> dataMap = weights.stream()
        .collect(Collectors.groupingBy(state -> getId(state), Collectors.toList()))
        .entrySet()
        .stream()
        .collect(Collectors.toMap(group -> group.getKey(), group -> {
            final List<Double> ratios = group.getValue().stream().map(state -> {
                final DoubleBuffer<Layer> directionBuffer = direction.getMap().get(state.layer);
                final double weightRms = state.deltaStatistics().rms();
                final double directionRms = null == directionBuffer ? 0 : directionBuffer.deltaStatistics().rms();
                // Guard against dividing by a zero weight magnitude.
                return directionRms / (0 == weightRms ? 1 : weightRms);
            }).collect(Collectors.toList());
            return 1 == ratios.size()
                ? Double.toString(ratios.get(0))
                : new DoubleStatistics().accept(ratios.stream().mapToDouble(x -> x).toArray()).toString();
        }));
    monitor.log(String.format("Line search stats: %s", dataMap));
    return cursor;
}
Also used : DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Collectors(java.util.stream.Collectors) StateSet(com.simiacryptus.mindseye.lang.StateSet) Trainable(com.simiacryptus.mindseye.eval.Trainable) List(java.util.List) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nonnull(javax.annotation.Nonnull) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) List(java.util.List)

Example 2 with DoubleBuffer

use of com.simiacryptus.mindseye.lang.DoubleBuffer in project MindsEye by SimiaCryptus.

the class TrustRegionStrategy method orient.

/**
 * Records {@code origin} in the bounded history, obtains a cursor from the inner strategy and
 * returns a wrapping cursor that projects every proposed position through the per-layer
 * {@link TrustRegion} returned by {@code getRegionPolicy(layer)}.
 *
 * @param subject the trainable that is re-measured on every {@code step}
 * @param origin  the current sample; inserted at the front of {@code history}
 * @param monitor passed to the inner strategy's {@code orient}
 * @return a cursor whose positions and derivative are constrained by the trust regions
 */
@Nonnull
@Override
public LineSearchCursor orient(@Nonnull final Trainable subject, final PointSample origin, final TrainingMonitor monitor) {
    // Newest sample goes first; drop the oldest entries once the history exceeds maxHistory.
    history.add(0, origin);
    while (history.size() > maxHistory) {
        history.remove(history.size() - 1);
    }
    final SimpleLineSearchCursor cursor = inner.orient(subject, origin, monitor);
    return new LineSearchCursorBase() {

        /** Returns the inner cursor's direction type with the suffix {@code "+Trust"}. */
        @Nonnull
        @Override
        public CharSequence getDirectionType() {
            return cursor.getDirectionType() + "+Trust";
        }

        /**
         * Resets the inner cursor, asks it for the position at {@code alpha} and projects that
         * position in place.
         */
        @Nonnull
        @Override
        public DeltaSet<Layer> position(final double alpha) {
            reset();
            @Nonnull final DeltaSet<Layer> adjustedPosVector = cursor.position(alpha);
            // Only the in-place adjustment of adjustedPosVector is used here; the derivative
            // returned by project(...) is discarded.
            // NOTE(review): if DeltaSet is reference counted, the discarded result may need freeRef() - confirm.
            project(adjustedPosVector, new TrainingMonitor());
            return adjustedPosVector;
        }

        /**
         * Projects each layer's proposed position onto its trust region, rewriting the deltas of
         * {@code deltaIn} in place, and builds an adjusted copy of the inner cursor's direction.
         *
         * @param deltaIn the proposed position deltas; modified in place when a projection moves a point
         * @param monitor not used by the active code (only by the commented-out diagnostics below)
         * @return a copy of {@code cursor.direction} where, for every layer whose position was
         *         moved by the projection, the component along the projection normal is removed
         */
        @Nonnull
        public DeltaSet<Layer> project(@Nonnull final DeltaSet<Layer> deltaIn, final TrainingMonitor monitor) {
            final DeltaSet<Layer> originalAlphaDerivative = cursor.direction;
            @Nonnull final DeltaSet<Layer> newAlphaDerivative = originalAlphaDerivative.copy();
            deltaIn.getMap().forEach((layer, buffer) -> {
                @Nullable final double[] delta = buffer.getDelta();
                // Nothing to project for a buffer without a delta.
                if (null == delta)
                    return;
                final double[] currentPosition = buffer.target;
                @Nullable final double[] originalAlphaD = originalAlphaDerivative.get(layer, currentPosition).getDelta();
                @Nullable final double[] newAlphaD = newAlphaDerivative.get(layer, currentPosition).getDelta();
                // presumably element-wise: proposed = current + delta - confirm against ArrayUtil
                @Nonnull final double[] proposedPosition = ArrayUtil.add(currentPosition, delta);
                final TrustRegion region = getRegionPolicy(layer);
                // Layers without a region policy are left untouched.
                if (null != region) {
                    // Collect this layer's weight deltas from every historical sample; samples
                    // that have no entry for the layer yield null and are filtered out below.
                    final Stream<double[]> zz = history.stream().map((@Nonnull final PointSample x) -> {
                        final DoubleBuffer<Layer> d = x.weights.getMap().get(layer);
                        @Nullable final double[] z = null == d ? null : d.getDelta();
                        return z;
                    });
                    final double[] projectedPosition = region.project(zz.filter(x -> null != x).toArray(i -> new double[i][]), proposedPosition);
                    // Reference comparison: the adjustment runs only when the region returned a
                    // different array instance than the one it was given.
                    if (projectedPosition != proposedPosition) {
                        // Rewrite the delta in place so that current + delta equals the projected position.
                        for (int i = 0; i < projectedPosition.length; i++) {
                            delta[i] = projectedPosition[i] - currentPosition[i];
                        }
                        @Nonnull final double[] normal = ArrayUtil.subtract(projectedPosition, proposedPosition);
                        final double normalMagSq = ArrayUtil.dot(normal, normal);
                        // (remnant of a removed diagnostic log statement) normalMagSq));
                        if (0 < normalMagSq) {
                            final double a = ArrayUtil.dot(originalAlphaD, normal);
                            // NOTE(review): the -1 sentinel is unexplained; a check against 0 may have been intended - confirm.
                            if (a != -1) {
                                // tangent = d - (d.n / |n|^2) * n, i.e. the original direction
                                // with its component along the normal subtracted.
                                @Nonnull final double[] tangent = ArrayUtil.add(originalAlphaD, ArrayUtil.multiply(normal, -a / normalMagSq));
                                for (int i = 0; i < tangent.length; i++) {
                                    newAlphaD[i] = tangent[i];
                                }
                            // Disabled sanity checks and diagnostics:
                            // double newAlphaDerivSq = ArrayUtil.dot(tangent, tangent);
                            // double originalAlphaDerivSq = ArrayUtil.dot(originalAlphaD, originalAlphaD);
                            // assert(newAlphaDerivSq <= originalAlphaDerivSq);
                            // assert(Math.abs(ArrayUtil.dot(tangent, normal)) <= 1e-4);
                            // monitor.log(String.format("%s: normalMagSq = %s, newAlphaDerivSq = %s, originalAlphaDerivSq = %s", layer, normalMagSq, newAlphaDerivSq, originalAlphaDerivSq));
                            }
                        }
                    }
                }
            });
            return newAlphaDerivative;
        }

        /** Resets the wrapped cursor. */
        @Override
        public void reset() {
            cursor.reset();
        }

        /**
         * Takes a step of size {@code alpha}: projects the inner cursor's position, accumulates
         * it, re-measures the subject and pairs the sample with the dot product of the adjusted
         * direction and the sample's delta.
         */
        @Nonnull
        @Override
        public LineSearchPoint step(final double alpha, final TrainingMonitor monitor) {
            cursor.reset();
            @Nonnull final DeltaSet<Layer> adjustedPosVector = cursor.position(alpha);
            @Nonnull final DeltaSet<Layer> adjustedGradient = project(adjustedPosVector, monitor);
            // presumably applies the projected deltas to the weights before measuring - confirm
            adjustedPosVector.accumulate(1);
            @Nonnull final PointSample sample = subject.measure(monitor).setRate(alpha);
            return new LineSearchPoint(sample, adjustedGradient.dot(sample.delta));
        }

        /** Releases this wrapper's reference to the inner cursor. */
        @Override
        public void _free() {
            cursor.freeRef();
        }
    };
}
Also used : TrustRegion(com.simiacryptus.mindseye.opt.region.TrustRegion) IntStream(java.util.stream.IntStream) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) TrustRegion(com.simiacryptus.mindseye.opt.region.TrustRegion) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) ArrayUtil(com.simiacryptus.util.ArrayUtil) Trainable(com.simiacryptus.mindseye.eval.Trainable) List(java.util.List) Stream(java.util.stream.Stream) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) LinkedList(java.util.LinkedList) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) LineSearchCursorBase(com.simiacryptus.mindseye.opt.line.LineSearchCursorBase) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Layer(com.simiacryptus.mindseye.lang.Layer) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) PointSample(com.simiacryptus.mindseye.lang.PointSample) LineSearchCursorBase(com.simiacryptus.mindseye.opt.line.LineSearchCursorBase) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 3 with DoubleBuffer

use of com.simiacryptus.mindseye.lang.DoubleBuffer in project MindsEye by SimiaCryptus.

the class LayerRateDiagnosticTrainer method run.

/**
 * Runs the per-layer rate diagnostic and returns the collected layer statistics.
 *
 * <p>The outer loop continues until the timeout elapses, the measured error ({@code sum}) is no
 * longer above {@code terminateThreshold}, or {@code maxIterations} is exceeded. Each
 * sub-iteration:
 * <ol>
 *   <li>probes the orientation at a tiny step and at zero to log an estimated ideal step size
 *       per layer;</li>
 *   <li>runs a line search restricted to each layer in turn and records the resulting rate and
 *       error improvement via {@code getLayerRates()};</li>
 *   <li>in strict mode, reverts every per-layer step and finally re-applies only the step that
 *       reached the lowest error.</li>
 * </ol>
 *
 * @return the map returned by {@code getLayerRates()}, keyed by layer
 */
@Nonnull
public Map<Layer, LayerStats> run() {
    final long timeoutMs = System.currentTimeMillis() + timeout.toMillis();
    PointSample measure = measure();
    @Nonnull final ArrayList<Layer> layers = new ArrayList<>(measure.weights.getMap().keySet());
    while (timeoutMs > System.currentTimeMillis() && measure.sum > terminateThreshold) {
        if (currentIteration.get() > maxIterations) {
            break;
        }
        final PointSample initialPhasePoint = measure();
        measure = initialPhasePoint;
        for (int subiteration = 0; subiteration < iterationsPerSample; subiteration++) {
            if (currentIteration.incrementAndGet() > maxIterations) {
                break;
            }
            {
                // Finite-difference probe: compare the deltas at a tiny step against the deltas
                // at zero to estimate how quickly the gradient changes along the direction.
                @Nonnull final SimpleLineSearchCursor orient = (SimpleLineSearchCursor) getOrientation().orient(subject, measure, monitor);
                final double stepSize = 1e-12 * orient.origin.sum;
                @Nonnull final DeltaSet<Layer> pointB = orient.step(stepSize, monitor).point.delta.copy();
                @Nonnull final DeltaSet<Layer> pointA = orient.step(0.0, monitor).point.delta.copy();
                @Nonnull final DeltaSet<Layer> d1 = pointA;
                // d2 = (pointA - pointB) / stepSize
                // NOTE(review): assumes add/scale return new sets rather than mutating d1 - confirm.
                @Nonnull final DeltaSet<Layer> d2 = d1.add(pointB.scale(-1)).scale(1.0 / stepSize);
                @Nonnull final Map<Layer, Double> steps = new HashMap<>();
                final double overallStepEstimate = d1.getMagnitude() / d2.getMagnitude();
                for (final Layer layer : layers) {
                    final DoubleBuffer<Layer> a = d2.get(layer, (double[]) null);
                    final DoubleBuffer<Layer> b = d1.get(layer, (double[]) null);
                    final double bmag = Math.sqrt(b.deltaStatistics().sumSq());
                    final double amag = Math.sqrt(a.deltaStatistics().sumSq());
                    // Cosine of the angle between the two per-layer deltas.
                    final double dot = a.dot(b) / (amag * bmag);
                    final double idealSize = bmag / (amag * dot);
                    steps.put(layer, idealSize);
                    monitor.log(String.format("Layers stats: %s (%s, %s, %s) => %s", layer, amag, bmag, dot, idealSize));
                }
                monitor.log(String.format("Estimated ideal rates for layers: %s (%s overall; probed at %s)", steps, overallStepEstimate, stepSize));
            }
            @Nullable SimpleLineSearchCursor bestOrient = null;
            @Nullable PointSample bestPoint = null;
            layerLoop: for (@Nonnull final Layer layer : layers) {
                @Nonnull SimpleLineSearchCursor orient = (SimpleLineSearchCursor) getOrientation().orient(subject, measure, monitor);
                // Restrict the search direction to the current layer only.
                @Nonnull final DeltaSet<Layer> direction = filterDirection(orient.direction, layer);
                if (direction.getMagnitude() == 0) {
                    monitor.log(String.format("Zero derivative for layer %s; skipping", layer));
                    continue layerLoop;
                }
                orient = new SimpleLineSearchCursor(orient.subject, orient.origin, direction);
                final PointSample previous = measure;
                measure = getLineSearchStrategy().step(orient, monitor);
                if (isStrict()) {
                    monitor.log(String.format("Iteration %s reverting. Error: %s", currentIteration.get(), measure.sum));
                    monitor.log(String.format("Optimal rate for layer %s: %s", layer.getName(), measure.getRate()));
                    // Keep the candidate with the LOWEST error. The loop above runs while
                    // sum > terminateThreshold, so a smaller sum is the better result.
                    if (null == bestPoint || measure.sum < bestPoint.sum) {
                        bestOrient = orient;
                        bestPoint = measure;
                    }
                    getLayerRates().put(layer, new LayerStats(measure.getRate(), initialPhasePoint.sum - measure.sum));
                    // Step back to alpha = 0 so the next layer is evaluated from the same starting point.
                    orient.step(0, monitor);
                    measure = previous;
                } else if (previous.sum == measure.sum) {
                    monitor.log(String.format("Iteration %s failed. Error: %s", currentIteration.get(), measure.sum));
                } else {
                    monitor.log(String.format("Iteration %s complete. Error: %s", currentIteration.get(), measure.sum));
                    monitor.log(String.format("Optimal rate for layer %s: %s", layer.getName(), measure.getRate()));
                    getLayerRates().put(layer, new LayerStats(measure.getRate(), initialPhasePoint.sum - measure.sum));
                }
            }
            monitor.log(String.format("Ideal rates: %s", getLayerRates()));
            // bestOrient and bestPoint are always assigned together, so one null check suffices.
            if (null != bestPoint) {
                bestOrient.step(bestPoint.rate, monitor);
            }
            monitor.onStepComplete(new Step(measure, currentIteration.get()));
        }
    }
    return getLayerRates();
}
Also used : DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Nonnull(javax.annotation.Nonnull) ArrayList(java.util.ArrayList) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) PointSample(com.simiacryptus.mindseye.lang.PointSample) HashMap(java.util.HashMap) Map(java.util.Map) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 4 with DoubleBuffer

use of com.simiacryptus.mindseye.lang.DoubleBuffer in project MindsEye by SimiaCryptus.

the class SingleDerivativeTester method getLearningGradient.

/**
 * Builds the gradient of the component's output with respect to one of its state arrays.
 *
 * <p>For every output element a one-hot feedback tensor is passed to the evaluation result's
 * {@code accumulate}; the delta buffer whose target is the selected state array (compared by
 * reference) then supplies one column of the result. If no such buffer exists the column is
 * left as created.
 *
 * @param component       the layer under test; it is unfrozen before evaluation
 * @param layerNum        index into {@code component.state()} selecting the state array
 * @param outputPrototype tensor whose dimensions and length describe the layer output
 * @param inputPrototype  input tensors used for each evaluation
 * @return a tensor of shape {@code (state length, output length)}
 */
@Nonnull
private Tensor getLearningGradient(@Nonnull final Layer component, final int layerNum, @Nonnull final Tensor outputPrototype, final Tensor... inputPrototype) {
    component.setFrozen(false);
    final double[] targetState = component.state().get(layerNum);
    final int stateSize = targetState.length;
    @Nonnull final Tensor learningGradient = new Tensor(stateSize, outputPrototype.length());
    for (int outputIndex = 0; outputIndex < outputPrototype.length(); outputIndex++) {
        // Effectively-final copy of the loop index for use inside the lambdas below.
        final int column = outputIndex;
        @Nonnull final DeltaSet<Layer> deltas = new DeltaSet<Layer>();
        final Result[] inputs = ConstantResult.batchResultArray(new Tensor[][] { inputPrototype });
        @Nullable final Result evaluation = component.eval(inputs);
        for (@Nonnull final Result input : inputs) {
            input.getData().freeRef();
            input.freeRef();
        }
        // Feedback signal that is 1 at the current output element and 0 everywhere else.
        @Nonnull final TensorArray oneHot = TensorArray.wrap(new Tensor(outputPrototype.getDimensions()).set(k -> k == column ? 1 : 0));
        evaluation.accumulate(deltas, oneHot);
        evaluation.getData().freeRef();
        evaluation.freeRef();
        deltas.getMap().values().stream()
            .filter(candidate -> candidate.target == targetState)
            .findFirst()
            .ifPresent(match -> {
                for (int row = 0; row < stateSize; row++) {
                    learningGradient.set(new int[] { row, column }, match.getDelta()[row]);
                }
            });
        deltas.freeRef();
    }
    return learningGradient;
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Result(com.simiacryptus.mindseye.lang.Result) Collectors(java.util.stream.Collectors) Delta(com.simiacryptus.mindseye.lang.Delta) List(java.util.List) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Optional(java.util.Optional) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) SimpleEval(com.simiacryptus.mindseye.test.SimpleEval) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Result(com.simiacryptus.mindseye.lang.Result) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Aggregations

DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)4 DoubleBuffer (com.simiacryptus.mindseye.lang.DoubleBuffer)4 Layer (com.simiacryptus.mindseye.lang.Layer)4 Nonnull (javax.annotation.Nonnull)4 PointSample (com.simiacryptus.mindseye.lang.PointSample)3 SimpleLineSearchCursor (com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor)3 List (java.util.List)3 Nullable (javax.annotation.Nullable)3 Trainable (com.simiacryptus.mindseye.eval.Trainable)2 TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)2 LineSearchCursor (com.simiacryptus.mindseye.opt.line.LineSearchCursor)2 Map (java.util.Map)2 Collectors (java.util.stream.Collectors)2 IntStream (java.util.stream.IntStream)2 ConstantResult (com.simiacryptus.mindseye.lang.ConstantResult)1 Delta (com.simiacryptus.mindseye.lang.Delta)1 Result (com.simiacryptus.mindseye.lang.Result)1 StateSet (com.simiacryptus.mindseye.lang.StateSet)1 Tensor (com.simiacryptus.mindseye.lang.Tensor)1 TensorArray (com.simiacryptus.mindseye.lang.TensorArray)1