Search in sources :

Example 16 with PointSample

use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

the class QuadraticSearch method _step.

/**
 * Runs a one-dimensional line search along the cursor by repeatedly
 * interpolating the derivative between a left and a right sample point and
 * evaluating the cursor where the interpolated derivative crosses zero.
 *
 * <p>The search starts at x = 0. If the derivative there is exactly zero the
 * sample at that point is returned immediately. Otherwise an initial right
 * point is obtained from {@code LocateInitialRightPoint} and the
 * left/right pair is refined until one of the termination tests in the loop
 * fires. Most exits return through {@code filter(...)}, which is defined
 * outside this method.
 *
 * <p>Line search points are reference counted by hand here: every alias that
 * is kept gets an {@code addRef()} and the {@code finally} block releases the
 * references still held when the loop exits.
 *
 * @param cursor  the cursor that is evaluated at each trial x via {@code step}
 * @param monitor the monitor that receives the progress log lines
 * @return the point sample selected by the search
 */
public PointSample _step(@Nonnull final LineSearchCursor cursor, @Nonnull final TrainingMonitor monitor) {
    // Evaluate the starting position (x = 0); it serves as the initial point,
    // the first left point, and the first "current" point.
    double thisX = 0;
    LineSearchPoint thisPoint = cursor.step(thisX, monitor);
    final LineSearchPoint initialPoint = thisPoint;
    initialPoint.addRef();
    double leftX = thisX;
    LineSearchPoint leftPoint = thisPoint;
    leftPoint.addRef();
    monitor.log(String.format("F(%s) = %s", leftX, leftPoint));
    // Exactly zero derivative at the start: nothing to search, hand back the
    // starting sample after releasing the three references taken above.
    if (0 == leftPoint.derivative) {
        initialPoint.freeRef();
        thisPoint.freeRef();
        PointSample point = leftPoint.point;
        point.addRef();
        leftPoint.freeRef();
        return point;
    }
    // Obtain the initial right-hand point and its x from the helper.
    @Nonnull final LocateInitialRightPoint locateInitialRightPoint = new LocateInitialRightPoint(cursor, monitor, leftPoint).apply();
    @Nonnull LineSearchPoint rightPoint = locateInitialRightPoint.getRightPoint();
    rightPoint.addRef();
    double rightX = locateInitialRightPoint.getRightX();
    try {
        int loops = 0;
        while (true) {
            // Fit the line d(x) = a * x + b through the two derivative samples
            // (leftX, left derivative) and (rightX, right derivative); its root
            // -b / a is the next trial position.
            final double a = (rightPoint.derivative - leftPoint.derivative) / (rightX - leftX);
            final double b = rightPoint.derivative - a * rightX;
            thisX = -b / a;
            // "Bracketed" means the derivative changes sign between left and right.
            final boolean isBracketed = Math.signum(leftPoint.derivative) != Math.signum(rightPoint.derivative);
            // Fall back to the midpoint when the root is not finite, or when it
            // lies outside a bracketing interval.
            if (!Double.isFinite(thisX) || isBracketed && (leftX > thisX || rightX < thisX)) {
                thisX = (rightX + leftX) / 2;
            }
            // Without a bracket, a negative trial position is replaced by twice
            // the right position.
            if (!isBracketed && thisX < 0) {
                thisX = rightX * 2;
            }
            // Trial x is the same as an existing end point according to the
            // three-argument isSame helper (defined elsewhere): stop there.
            if (isSame(leftX, thisX, 1.0)) {
                monitor.log(String.format("Converged to left"));
                return filter(cursor, leftPoint.point, monitor);
            } else if (isSame(thisX, rightX, 1.0)) {
                monitor.log(String.format("Converged to right"));
                return filter(cursor, rightPoint.point, monitor);
            }
            // Drop the previous trial point; nulling it first keeps the finally
            // block from releasing it a second time if step() throws.
            thisPoint.freeRef();
            thisPoint = null;
            thisPoint = cursor.step(thisX, monitor);
            // Stop if the new point is the same as either end point according
            // to the four-argument isSame helper (defined elsewhere).
            if (isSame(cursor, monitor, leftPoint, thisPoint)) {
                monitor.log(String.format("%s ~= %s", leftX, thisX));
                return filter(cursor, leftPoint.point, monitor);
            }
            if (isSame(cursor, monitor, thisPoint, rightPoint)) {
                monitor.log(String.format("%s ~= %s", thisX, rightX));
                return filter(cursor, rightPoint.point, monitor);
            }
            // NOTE(review): the cursor is evaluated a second time at the same
            // thisX here. Presumably this restores cursor state after the
            // isSame(cursor, ...) calls above - confirm before removing.
            thisPoint.freeRef();
            thisPoint = null;
            thisPoint = cursor.step(thisX, monitor);
            // Decide which end point the new point replaces: with a bracket, a
            // negative derivative marks it as the new left point; without one,
            // it replaces whichever end point is nearer in rate.
            boolean isLeft;
            if (!isBracketed) {
                isLeft = Math.abs(rightPoint.point.rate - thisPoint.point.rate) > Math.abs(leftPoint.point.rate - thisPoint.point.rate);
            } else {
                isLeft = thisPoint.derivative < 0;
            }
            // monitor.log(String.format("isLeft=%s; isBracketed=%s; leftPoint=%s; rightPoint=%s", isLeft, isBracketed, leftPoint, rightPoint));
            monitor.log(String.format("F(%s) = %s, delta = %s", thisX, thisPoint, thisPoint.point.getMean() - initialPoint.point.getMean()));
            // Iteration cap: give up once the pre-increment counter exceeds 10
            // and return the current trial point.
            if (loops++ > 10) {
                monitor.log(String.format("Loops = %s", loops));
                PointSample filter = filter(cursor, thisPoint.point, monitor);
                return filter;
            }
            // The two end points are themselves the same per isSame: return the
            // current trial point.
            if (isSame(cursor, monitor, leftPoint, rightPoint)) {
                monitor.log(String.format("%s ~= %s", leftX, rightX));
                PointSample filter = filter(cursor, thisPoint.point, monitor);
                return filter;
            }
            if (isLeft) {
                // The trial point has a larger mean than the left point: keep
                // the left point.
                if (thisPoint.point.getMean() > leftPoint.point.getMean()) {
                    monitor.log(String.format("%s > %s", thisPoint.point.getMean(), leftPoint.point.getMean()));
                    return filter(cursor, leftPoint.point, monitor);
                }
                // Unbracketed and the left point has the smaller mean of the two
                // end points: the old left point becomes the new right point.
                if (!isBracketed && leftPoint.point.getMean() < rightPoint.point.getMean()) {
                    rightX = leftX;
                    if (null != rightPoint)
                        rightPoint.freeRef();
                    rightPoint = leftPoint;
                    rightPoint.addRef();
                }
                // The trial point becomes the new left point.
                if (null != leftPoint)
                    leftPoint.freeRef();
                leftPoint = thisPoint;
                leftPoint.addRef();
                leftX = thisX;
                monitor.log(String.format("Left bracket at %s", thisX));
            } else {
                // Mirror image of the left-hand case above.
                if (thisPoint.point.getMean() > rightPoint.point.getMean()) {
                    monitor.log(String.format("%s > %s", thisPoint.point.getMean(), rightPoint.point.getMean()));
                    return filter(cursor, rightPoint.point, monitor);
                }
                if (!isBracketed && rightPoint.point.getMean() < leftPoint.point.getMean()) {
                    leftX = rightX;
                    if (null != leftPoint)
                        leftPoint.freeRef();
                    leftPoint = rightPoint;
                    leftPoint.addRef();
                }
                rightX = thisX;
                if (null != rightPoint)
                    rightPoint.freeRef();
                rightPoint = thisPoint;
                rightPoint.addRef();
                monitor.log(String.format("Right bracket at %s", thisX));
            }
        }
    } finally {
        // Release every reference still held, whichever return (or exception)
        // left the loop.
        if (null != leftPoint)
            leftPoint.freeRef();
        if (null != rightPoint)
            rightPoint.freeRef();
        if (null != thisPoint)
            thisPoint.freeRef();
        if (null != initialPoint)
            initialPoint.freeRef();
        if (null != locateInitialRightPoint)
            locateInitialRightPoint.freeRef();
    }
}
Also used : Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample)

Example 17 with PointSample

use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

the class QuantifyOrientationWrapper method orient.

/**
 * Delegates orientation to the wrapped strategy and logs, per weight group,
 * the ratio between the RMS of the search direction and the RMS of the
 * weights themselves.
 *
 * <p>Statistics are only produced when the inner cursor is a
 * {@code SimpleLineSearchCursor}; any other cursor is logged and returned
 * untouched.
 *
 * @param subject     the trainable being optimized
 * @param measurement the current measurement passed to the inner strategy
 * @param monitor     the monitor that receives the log line
 * @return the cursor produced by the inner strategy
 */
@Override
public LineSearchCursor orient(final Trainable subject, final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
    final LineSearchCursor cursor = inner.orient(subject, measurement, monitor);
    if (!(cursor instanceof SimpleLineSearchCursor)) {
        monitor.log(String.format("Non-simple cursor: %s", cursor));
        return cursor;
    }
    final SimpleLineSearchCursor simpleCursor = (SimpleLineSearchCursor) cursor;
    final DeltaSet<Layer> direction = simpleCursor.direction;
    @Nonnull final StateSet<Layer> weights = simpleCursor.origin.weights;
    final Map<CharSequence, CharSequence> dataMap = weights.stream()
        // Bucket the weight states by the id assigned by getId.
        .collect(Collectors.groupingBy(state -> getId(state), Collectors.toList()))
        .entrySet().stream()
        .collect(Collectors.toMap(group -> group.getKey(), group -> {
            final List<Double> ratios = group.getValue().stream().map(state -> {
                final DoubleBuffer<Layer> directionBuffer = direction.getMap().get(state.layer);
                final double weightRms = state.deltaStatistics().rms();
                // A layer absent from the direction contributes a zero numerator.
                final double directionRms = null == directionBuffer ? 0 : directionBuffer.deltaStatistics().rms();
                // Guard against division by zero by substituting 1.
                final double divisor = 0 == weightRms ? 1 : weightRms;
                return directionRms / divisor;
            }).collect(Collectors.toList());
            // A single ratio is printed as-is; several are summarized.
            if (ratios.size() == 1) {
                return Double.toString(ratios.get(0));
            }
            final double[] values = ratios.stream().mapToDouble(ratio -> ratio).toArray();
            return new DoubleStatistics().accept(values).toString();
        }));
    monitor.log(String.format("Line search stats: %s", dataMap));
    return cursor;
}
Also used : DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Collectors(java.util.stream.Collectors) StateSet(com.simiacryptus.mindseye.lang.StateSet) Trainable(com.simiacryptus.mindseye.eval.Trainable) List(java.util.List) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nonnull(javax.annotation.Nonnull) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) List(java.util.List)

Example 18 with PointSample

use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

the class TrustRegionStrategy method orient.

/**
 * Records the origin sample in the bounded history, asks the inner strategy
 * for a cursor, and wraps that cursor so that every proposed position is
 * passed through the per-layer trust region (when one is configured).
 *
 * @param subject the trainable that is measured on each step
 * @param origin  the sample the search starts from; stored at the head of the history
 * @param monitor the monitor forwarded to the inner strategy
 * @return a cursor that applies trust-region projection around the inner cursor
 */
@Nonnull
@Override
public LineSearchCursor orient(@Nonnull final Trainable subject, final PointSample origin, final TrainingMonitor monitor) {
    // Newest sample goes first; drop entries from the tail beyond maxHistory.
    history.add(0, origin);
    while (maxHistory < history.size()) {
        history.remove(history.size() - 1);
    }
    final SimpleLineSearchCursor cursor = inner.orient(subject, origin, monitor);
    return new LineSearchCursorBase() {

        @Nonnull
        @Override
        public CharSequence getDirectionType() {
            return cursor.getDirectionType() + "+Trust";
        }

        @Nonnull
        @Override
        public DeltaSet<Layer> position(final double alpha) {
            reset();
            @Nonnull final DeltaSet<Layer> projectedPosition = cursor.position(alpha);
            // Only the in-place adjustment of the position is wanted here.
            project(projectedPosition, new TrainingMonitor());
            return projectedPosition;
        }

        /**
         * Projects each layer's proposed position through its trust region,
         * rewriting the deltas of {@code deltaIn} in place, and returns a copy
         * of the inner cursor's direction adjusted for the layers that moved.
         */
        @Nonnull
        public DeltaSet<Layer> project(@Nonnull final DeltaSet<Layer> deltaIn, final TrainingMonitor monitor) {
            final DeltaSet<Layer> baseDirection = cursor.direction;
            @Nonnull final DeltaSet<Layer> projectedDirection = baseDirection.copy();
            deltaIn.getMap().forEach((layer, buffer) -> {
                @Nullable final double[] stepDelta = buffer.getDelta();
                if (null == stepDelta) {
                    return;
                }
                final double[] current = buffer.target;
                @Nullable final double[] baseD = baseDirection.get(layer, current).getDelta();
                @Nullable final double[] projectedD = projectedDirection.get(layer, current).getDelta();
                @Nonnull final double[] proposed = ArrayUtil.add(current, stepDelta);
                final TrustRegion region = getRegionPolicy(layer);
                if (null == region) {
                    return;
                }
                // Gather this layer's historical deltas, skipping samples that lack it.
                final Stream<double[]> pastDeltas = history.stream().map((@Nonnull final PointSample past) -> {
                    final DoubleBuffer<Layer> pastBuffer = past.weights.getMap().get(layer);
                    @Nullable final double[] pastDelta = null == pastBuffer ? null : pastBuffer.getDelta();
                    return pastDelta;
                }).filter(pastDelta -> null != pastDelta);
                final double[] projected = region.project(pastDeltas.toArray(n -> new double[n][]), proposed);
                // The region hands back the same array when it leaves the position alone.
                if (projected == proposed) {
                    return;
                }
                for (int k = 0; k < projected.length; k++) {
                    stepDelta[k] = projected[k] - current[k];
                }
                @Nonnull final double[] normal = ArrayUtil.subtract(projected, proposed);
                final double normalMagSq = ArrayUtil.dot(normal, normal);
                if (!(0 < normalMagSq)) {
                    return;
                }
                final double a = ArrayUtil.dot(baseD, normal);
                if (a == -1) {
                    return;
                }
                // Subtract the component of the base direction along the normal.
                @Nonnull final double[] tangent = ArrayUtil.add(baseD, ArrayUtil.multiply(normal, -a / normalMagSq));
                for (int k = 0; k < tangent.length; k++) {
                    projectedD[k] = tangent[k];
                }
            });
            return projectedDirection;
        }

        @Override
        public void reset() {
            cursor.reset();
        }

        @Nonnull
        @Override
        public LineSearchPoint step(final double alpha, final TrainingMonitor monitor) {
            cursor.reset();
            @Nonnull final DeltaSet<Layer> projectedPosition = cursor.position(alpha);
            @Nonnull final DeltaSet<Layer> projectedGradient = project(projectedPosition, monitor);
            projectedPosition.accumulate(1);
            @Nonnull final PointSample sample = subject.measure(monitor).setRate(alpha);
            return new LineSearchPoint(sample, projectedGradient.dot(sample.delta));
        }

        @Override
        public void _free() {
            cursor.freeRef();
        }
    };
}
Also used : TrustRegion(com.simiacryptus.mindseye.opt.region.TrustRegion) IntStream(java.util.stream.IntStream) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) TrustRegion(com.simiacryptus.mindseye.opt.region.TrustRegion) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) ArrayUtil(com.simiacryptus.util.ArrayUtil) Trainable(com.simiacryptus.mindseye.eval.Trainable) List(java.util.List) Stream(java.util.stream.Stream) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) LinkedList(java.util.LinkedList) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) LineSearchCursorBase(com.simiacryptus.mindseye.opt.line.LineSearchCursorBase) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Layer(com.simiacryptus.mindseye.lang.Layer) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) PointSample(com.simiacryptus.mindseye.lang.PointSample) LineSearchCursorBase(com.simiacryptus.mindseye.opt.line.LineSearchCursorBase) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 19 with PointSample

use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

the class LocalSparkTrainable method measure.

/**
 * Measures the network over the sampled RDD by collecting each partition
 * locally, running a {@code PartitionTask} on it, and reducing the
 * per-partition results into a single normalized point sample.
 *
 * @param monitor the training monitor (not used by this implementation)
 * @return the normalized sample built from the combined partition results
 */
@Nonnull
@Override
public PointSample measure(final TrainingMonitor monitor) {
    final long startNanos = System.nanoTime();
    final JavaRDD<Tensor[]> javaRDD = sampledRDD.toJavaRDD();
    assert !javaRDD.isEmpty();
    final List<ReducableResult> partials = javaRDD.partitions().stream().map(partition -> {
        try {
            final List<Tensor[]>[] collected = javaRDD.collectPartitions(new int[] { partition.index() });
            assert 0 < collected.length;
            // Partitions without any rows produce no result and are filtered out below.
            final int rowCount = Arrays.stream(collected).mapToInt((@Nonnull final List<Tensor[]> rows) -> rows.size()).sum();
            if (rowCount == 0) {
                return null;
            }
            assert 0 < Arrays.stream(collected).mapToInt(rows -> rows.stream().mapToInt(row -> row.length).sum()).sum();
            final Stream<Tensor[]> rowStream = Arrays.stream(collected).flatMap(rows -> rows.stream());
            @Nonnull final Iterator<Tensor[]> rowIterator = rowStream.iterator();
            return new PartitionTask(network).call(rowIterator).next();
        } catch (@Nonnull final RuntimeException e) {
            throw e;
        } catch (@Nonnull final Exception e) {
            // Checked exceptions cannot escape the lambda; wrap them.
            throw new RuntimeException(e);
        }
    }).filter(partial -> null != partial).collect(Collectors.toList());
    final long mappedNanos = System.nanoTime();
    @Nonnull final SparkTrainable.ReducableResult result = partials.stream().reduce(SparkTrainable.ReducableResult::add).get();
    if (isVerbose()) {
        final double mapSeconds = (mappedNanos - startNanos) * 1e-9;
        final double reduceSeconds = (System.nanoTime() - mappedNanos) * 1e-9;
        log.info(String.format("Measure timing: %.3f / %.3f for %s items", mapSeconds, reduceSeconds, sampledRDD.count()));
    }
    @Nonnull final DeltaSet<Layer> delta = getDelta(result);
    final StateSet<Layer> state = new StateSet<Layer>(delta);
    return new PointSample(delta, state, result.sum, 0.0, result.count).normalize();
}
Also used : Arrays(java.util.Arrays) Logger(org.slf4j.Logger) Iterator(java.util.Iterator) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Collectors(java.util.stream.Collectors) StateSet(com.simiacryptus.mindseye.lang.StateSet) List(java.util.List) Stream(java.util.stream.Stream) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Layer(com.simiacryptus.mindseye.lang.Layer) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) RDD(org.apache.spark.rdd.RDD) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) JavaRDD(org.apache.spark.api.java.JavaRDD) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) Layer(com.simiacryptus.mindseye.lang.Layer) List(java.util.List) PointSample(com.simiacryptus.mindseye.lang.PointSample) StateSet(com.simiacryptus.mindseye.lang.StateSet) Nonnull(javax.annotation.Nonnull)

Example 20 with PointSample

use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

the class FailsafeLineSearchCursor method accumulate.

/**
 * Offers a sample to the cursor, keeping a full copy of it when it has a
 * lower mean than the best sample recorded so far (or when none has been
 * recorded yet). The previously held best sample is released on replacement.
 *
 * @param step the candidate sample
 */
public void accumulate(@Nonnull final PointSample step) {
    // Keep the current best unless the candidate's mean is strictly lower.
    if (null != best && !(best.getMean() > step.getMean())) {
        return;
    }
    @Nonnull final PointSample replacement = step.copyFull();
    if (null != this.best) {
        monitor.log(String.format("New Minimum: %s > %s", best.getMean(), step.getMean()));
        this.best.freeRef();
    }
    this.best = replacement;
}
Also used : Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample)

Aggregations

PointSample (com.simiacryptus.mindseye.lang.PointSample)33 Nonnull (javax.annotation.Nonnull)24 Layer (com.simiacryptus.mindseye.lang.Layer)16 Nullable (javax.annotation.Nullable)14 DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)10 TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)9 SimpleLineSearchCursor (com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor)9 StateSet (com.simiacryptus.mindseye.lang.StateSet)8 LineSearchCursor (com.simiacryptus.mindseye.opt.line.LineSearchCursor)8 List (java.util.List)8 Trainable (com.simiacryptus.mindseye.eval.Trainable)7 Arrays (java.util.Arrays)7 Collectors (java.util.stream.Collectors)7 IterativeStopException (com.simiacryptus.mindseye.lang.IterativeStopException)6 Map (java.util.Map)6 IntStream (java.util.stream.IntStream)6 DoubleBuffer (com.simiacryptus.mindseye.lang.DoubleBuffer)5 PlaceholderLayer (com.simiacryptus.mindseye.layers.java.PlaceholderLayer)5 FailsafeLineSearchCursor (com.simiacryptus.mindseye.opt.line.FailsafeLineSearchCursor)5 LineSearchStrategy (com.simiacryptus.mindseye.opt.line.LineSearchStrategy)5