use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.
the class QuadraticSearch method _step.
/**
* Runs a line search along {@code cursor}, looking for a step size at which the
* derivative reported by the cursor is zero.
*
* <p>The search starts at step size 0. If the derivative there is exactly zero the
* underlying {@link PointSample} of that evaluation is returned directly. Otherwise
* {@code LocateInitialRightPoint} supplies an initial right-hand point, and the loop
* repeatedly (1) estimates the root of the straight line through the derivatives at the
* left and right points, (2) evaluates the cursor there, and (3) replaces the left or
* right point with the new evaluation.
*
* <p>The loop ends when the candidate position or evaluation is considered the same as
* the left or right one by {@code isSame}, when the loop counter has passed 10, when the
* left and right points are considered the same, or when the new evaluation has a larger
* mean than the bracket end it would replace. In all of those cases the returned value
* is produced by {@code filter(cursor, ..., monitor)}.
*
* <p>Every {@link LineSearchPoint} held in a local variable owns its own reference
* ({@code addRef}); all of them are released in the {@code finally} block.
*
* @param cursor the cursor that evaluates the objective at a given step size
* @param monitor the monitor that receives the progress log messages
* @return the selected point sample
*/
public PointSample _step(@Nonnull final LineSearchCursor cursor, @Nonnull final TrainingMonitor monitor) {
// Evaluate the starting position (step size 0).
double thisX = 0;
LineSearchPoint thisPoint = cursor.step(thisX, monitor);
// thisPoint, initialPoint and leftPoint alias the same object; each alias takes its own
// reference so that the three can be released independently later on.
final LineSearchPoint initialPoint = thisPoint;
initialPoint.addRef();
double leftX = thisX;
LineSearchPoint leftPoint = thisPoint;
leftPoint.addRef();
monitor.log(String.format("F(%s) = %s", leftX, leftPoint));
if (0 == leftPoint.derivative) {
// Derivative is exactly zero at the start: release all three aliases and hand the
// caller its own reference to the underlying sample.
initialPoint.freeRef();
thisPoint.freeRef();
PointSample point = leftPoint.point;
point.addRef();
leftPoint.freeRef();
return point;
}
// LocateInitialRightPoint is defined elsewhere; here it only supplies the initial right
// point and its position.
@Nonnull final LocateInitialRightPoint locateInitialRightPoint = new LocateInitialRightPoint(cursor, monitor, leftPoint).apply();
@Nonnull LineSearchPoint rightPoint = locateInitialRightPoint.getRightPoint();
rightPoint.addRef();
double rightX = locateInitialRightPoint.getRightX();
try {
int loops = 0;
while (true) {
// Straight line d(x) = a * x + b through the derivatives at leftX and rightX;
// -b / a is where that line crosses zero.
final double a = (rightPoint.derivative - leftPoint.derivative) / (rightX - leftX);
final double b = rightPoint.derivative - a * rightX;
thisX = -b / a;
// "Bracketed" means the derivative has different signs at the two ends.
final boolean isBracketed = Math.signum(leftPoint.derivative) != Math.signum(rightPoint.derivative);
// Fall back to the midpoint when the estimate is not finite, or when the ends are
// bracketed but the estimate lies outside [leftX, rightX].
if (!Double.isFinite(thisX) || isBracketed && (leftX > thisX || rightX < thisX)) {
thisX = (rightX + leftX) / 2;
}
// Without a bracket a negative estimate is replaced by twice the right position.
if (!isBracketed && thisX < 0) {
thisX = rightX * 2;
}
// Stop if the candidate position is considered the same as one of the ends
// (isSame(double, double, double) is defined elsewhere).
if (isSame(leftX, thisX, 1.0)) {
monitor.log(String.format("Converged to left"));
return filter(cursor, leftPoint.point, monitor);
} else if (isSame(thisX, rightX, 1.0)) {
monitor.log(String.format("Converged to right"));
return filter(cursor, rightPoint.point, monitor);
}
// Release the previous evaluation before replacing it; the variable is nulled first so
// the finally block does not release it again if cursor.step throws.
thisPoint.freeRef();
thisPoint = null;
thisPoint = cursor.step(thisX, monitor);
// Stop if the new evaluation is considered the same as one of the ends
// (isSame(cursor, monitor, a, b) is defined elsewhere).
if (isSame(cursor, monitor, leftPoint, thisPoint)) {
monitor.log(String.format("%s ~= %s", leftX, thisX));
return filter(cursor, leftPoint.point, monitor);
}
if (isSame(cursor, monitor, thisPoint, rightPoint)) {
monitor.log(String.format("%s ~= %s", thisX, rightX));
return filter(cursor, rightPoint.point, monitor);
}
// NOTE(review): this evaluates the same thisX a second time; the reason is not visible
// in this method - presumably the isSame calls above can move the cursor. Confirm
// before removing.
thisPoint.freeRef();
thisPoint = null;
thisPoint = cursor.step(thisX, monitor);
// Decide which end the new evaluation replaces: without a bracket, the end whose rate
// is closer to the new rate; with a bracket, the left end when the derivative is
// negative and the right end otherwise.
boolean isLeft;
if (!isBracketed) {
isLeft = Math.abs(rightPoint.point.rate - thisPoint.point.rate) > Math.abs(leftPoint.point.rate - thisPoint.point.rate);
} else {
isLeft = thisPoint.derivative < 0;
}
// monitor.log(String.format("isLeft=%s; isBracketed=%s; leftPoint=%s; rightPoint=%s", isLeft, isBracketed, leftPoint, rightPoint));
monitor.log(String.format("F(%s) = %s, delta = %s", thisX, thisPoint, thisPoint.point.getMean() - initialPoint.point.getMean()));
// Iteration cap: give up with the latest evaluation once the counter has passed 10.
if (loops++ > 10) {
monitor.log(String.format("Loops = %s", loops));
PointSample filter = filter(cursor, thisPoint.point, monitor);
return filter;
}
// The two ends are considered the same: return the latest evaluation.
if (isSame(cursor, monitor, leftPoint, rightPoint)) {
monitor.log(String.format("%s ~= %s", leftX, rightX));
PointSample filter = filter(cursor, thisPoint.point, monitor);
return filter;
}
if (isLeft) {
// The new evaluation has a larger mean than the current left end: keep the left end.
if (thisPoint.point.getMean() > leftPoint.point.getMean()) {
monitor.log(String.format("%s > %s", thisPoint.point.getMean(), leftPoint.point.getMean()));
return filter(cursor, leftPoint.point, monitor);
}
// Without a bracket, when the old left end has the smaller mean it becomes the new
// right end (taking its own reference) before being replaced as the left end.
if (!isBracketed && leftPoint.point.getMean() < rightPoint.point.getMean()) {
rightX = leftX;
if (null != rightPoint)
rightPoint.freeRef();
rightPoint = leftPoint;
rightPoint.addRef();
}
// The new evaluation becomes the left end; leftPoint holds its own reference to it.
if (null != leftPoint)
leftPoint.freeRef();
leftPoint = thisPoint;
leftPoint.addRef();
leftX = thisX;
monitor.log(String.format("Left bracket at %s", thisX));
} else {
// Mirror image of the branch above for the right end.
if (thisPoint.point.getMean() > rightPoint.point.getMean()) {
monitor.log(String.format("%s > %s", thisPoint.point.getMean(), rightPoint.point.getMean()));
return filter(cursor, rightPoint.point, monitor);
}
if (!isBracketed && rightPoint.point.getMean() < leftPoint.point.getMean()) {
leftX = rightX;
if (null != leftPoint)
leftPoint.freeRef();
leftPoint = rightPoint;
leftPoint.addRef();
}
rightX = thisX;
if (null != rightPoint)
rightPoint.freeRef();
rightPoint = thisPoint;
rightPoint.addRef();
monitor.log(String.format("Right bracket at %s", thisX));
}
}
} finally {
// Release every reference still held by a local, whichever way the loop was left.
if (null != leftPoint)
leftPoint.freeRef();
if (null != rightPoint)
rightPoint.freeRef();
if (null != thisPoint)
thisPoint.freeRef();
if (null != initialPoint)
initialPoint.freeRef();
if (null != locateInitialRightPoint)
locateInitialRightPoint.freeRef();
}
}
use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.
the class QuantifyOrientationWrapper method orient.
/**
 * Delegates orientation to the wrapped strategy and logs a diagnostic about the
 * resulting cursor before returning it unchanged.
 *
 * <p>For a {@code SimpleLineSearchCursor}, the weights of its origin are grouped by
 * {@code getId(...)} (defined elsewhere). For every weight state in a group the value
 * {@code rms(direction buffer) / rms(weight state)} is computed, using 0 for a missing
 * direction buffer and a divisor of 1 when the weight rms is 0. A group with a single
 * value is reported as that number; larger groups are summarised with
 * {@code DoubleStatistics}. Any other cursor type is only logged as non-simple.
 *
 * @param subject     the trainable passed through to the wrapped strategy
 * @param measurement the current measurement passed through to the wrapped strategy
 * @param monitor     receives the diagnostic log line
 * @return the cursor produced by the wrapped strategy
 */
@Override
public LineSearchCursor orient(final Trainable subject, final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
  final LineSearchCursor cursor = inner.orient(subject, measurement, monitor);
  if (!(cursor instanceof SimpleLineSearchCursor)) {
    monitor.log(String.format("Non-simple cursor: %s", cursor));
    return cursor;
  }
  final SimpleLineSearchCursor simpleCursor = (SimpleLineSearchCursor) cursor;
  final DeltaSet<Layer> direction = simpleCursor.direction;
  @Nonnull final StateSet<Layer> weights = simpleCursor.origin.weights;
  final Map<CharSequence, CharSequence> dataMap = weights.stream()
      .collect(Collectors.groupingBy(state -> getId(state), Collectors.toList()))
      .entrySet().stream()
      .collect(Collectors.toMap(group -> group.getKey(), group -> {
        final List<Double> ratios = group.getValue().stream().map(state -> {
          final DoubleBuffer<Layer> directionBuffer = direction.getMap().get(state.layer);
          final double weightRms = state.deltaStatistics().rms();
          final double directionRms = null == directionBuffer ? 0 : directionBuffer.deltaStatistics().rms();
          return directionRms / (0 == weightRms ? 1 : weightRms);
        }).collect(Collectors.toList());
        if (1 == ratios.size()) {
          return Double.toString(ratios.get(0));
        }
        final double[] values = ratios.stream().mapToDouble(value -> value).toArray();
        return new DoubleStatistics().accept(values).toString();
      }));
  monitor.log(String.format("Line search stats: %s", dataMap));
  return cursor;
}
use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.
the class TrustRegionStrategy method orient.
/**
 * Records {@code origin} at the front of the history (trimmed to {@code maxHistory}
 * entries), obtains a cursor from the inner strategy and wraps it in a cursor that
 * projects every proposed position through the per-layer trust region returned by
 * {@code getRegionPolicy(layer)}.
 *
 * @param subject the trainable that is measured on each step
 * @param origin  the current point; added to the history used by the trust regions
 * @param monitor passed to the inner strategy
 * @return a cursor whose positions and derivatives are constrained by the trust regions
 */
@Nonnull
@Override
public LineSearchCursor orient(@Nonnull final Trainable subject, final PointSample origin, final TrainingMonitor monitor) {
  history.add(0, origin);
  while (maxHistory < history.size()) {
    history.remove(history.size() - 1);
  }
  final SimpleLineSearchCursor innerCursor = inner.orient(subject, origin, monitor);
  return new LineSearchCursorBase() {
    @Nonnull
    @Override
    public CharSequence getDirectionType() {
      return innerCursor.getDirectionType() + "+Trust";
    }

    @Override
    public void reset() {
      innerCursor.reset();
    }

    @Nonnull
    @Override
    public DeltaSet<Layer> position(final double alpha) {
      reset();
      @Nonnull final DeltaSet<Layer> positionDelta = innerCursor.position(alpha);
      // Only the in-place adjustment of positionDelta is needed; the returned derivative is not used here.
      project(positionDelta, new TrainingMonitor());
      return positionDelta;
    }

    /**
     * Adjusts {@code deltaIn} in place so that, for every layer with a trust region,
     * target + delta equals the region's projection of the proposed position, and
     * returns a copy of the inner cursor's direction from which the component along
     * the projection normal (projected - proposed) has been removed for those layers.
     */
    @Nonnull
    public DeltaSet<Layer> project(@Nonnull final DeltaSet<Layer> deltaIn, final TrainingMonitor monitor) {
      final DeltaSet<Layer> baseDirection = innerCursor.direction;
      @Nonnull final DeltaSet<Layer> projectedDirection = baseDirection.copy();
      deltaIn.getMap().forEach((layer, buffer) -> {
        @Nullable final double[] delta = buffer.getDelta();
        if (null == delta) {
          return;
        }
        final double[] current = buffer.target;
        @Nullable final double[] baseD = baseDirection.get(layer, current).getDelta();
        @Nullable final double[] projectedD = projectedDirection.get(layer, current).getDelta();
        @Nonnull final double[] proposed = ArrayUtil.add(current, delta);
        final TrustRegion region = getRegionPolicy(layer);
        if (null == region) {
          return;
        }
        // Deltas recorded for this layer in the history; entries without one are skipped.
        final double[][] pastDeltas = history.stream().map((@Nonnull final PointSample past) -> {
          final DoubleBuffer<Layer> pastBuffer = past.weights.getMap().get(layer);
          return null == pastBuffer ? null : pastBuffer.getDelta();
        }).filter(entry -> null != entry).toArray(double[][]::new);
        final double[] projected = region.project(pastDeltas, proposed);
        if (projected == proposed) {
          // The region handed back the very same array: nothing to adjust.
          return;
        }
        for (int i = 0; i < projected.length; i++) {
          delta[i] = projected[i] - current[i];
        }
        @Nonnull final double[] normal = ArrayUtil.subtract(projected, proposed);
        final double normalMagSq = ArrayUtil.dot(normal, normal);
        if (!(0 < normalMagSq)) {
          return;
        }
        final double a = ArrayUtil.dot(baseD, normal);
        // NOTE(review): -1 is special-cased exactly as in the original; the reason is not evident here.
        if (a == -1) {
          return;
        }
        @Nonnull final double[] tangent = ArrayUtil.add(baseD, ArrayUtil.multiply(normal, -a / normalMagSq));
        for (int i = 0; i < tangent.length; i++) {
          projectedD[i] = tangent[i];
        }
      });
      return projectedDirection;
    }

    @Nonnull
    @Override
    public LineSearchPoint step(final double alpha, final TrainingMonitor monitor) {
      innerCursor.reset();
      @Nonnull final DeltaSet<Layer> positionDelta = innerCursor.position(alpha);
      @Nonnull final DeltaSet<Layer> projectedGradient = project(positionDelta, monitor);
      positionDelta.accumulate(1);
      @Nonnull final PointSample sample = subject.measure(monitor).setRate(alpha);
      return new LineSearchPoint(sample, projectedGradient.dot(sample.delta));
    }

    @Override
    public void _free() {
      innerCursor.freeRef();
    }
  };
}
use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.
the class LocalSparkTrainable method measure.
/**
 * Measures the network over the sampled RDD, one partition at a time, on the local
 * process.
 *
 * <p>Each partition is collected with {@code collectPartitions}, evaluated by a
 * {@code PartitionTask}, and the per-partition results are combined with
 * {@code ReducableResult::add}. Partitions that contain no rows are skipped.
 *
 * @param monitor the training monitor (not used by this implementation)
 * @return the normalized point sample built from the combined partition results
 * @throws java.util.NoSuchElementException if no partition produced a result, i.e. the
 *         sampled RDD contained no rows
 */
@Nonnull
@Override
public PointSample measure(final TrainingMonitor monitor) {
  final long time1 = System.nanoTime();
  final JavaRDD<Tensor[]> javaRDD = sampledRDD.toJavaRDD();
  assert !javaRDD.isEmpty();
  final List<ReducableResult> mapPartitions = javaRDD.partitions().stream().map(partition -> {
    try {
      final List<Tensor[]>[] array = javaRDD.collectPartitions(new int[] { partition.index() });
      assert 0 < array.length;
      // Empty partitions contribute nothing; they are marked with null and filtered out below.
      if (0 == Arrays.stream(array).mapToInt((@Nonnull final List<Tensor[]> x) -> x.size()).sum()) {
        return null;
      }
      assert 0 < Arrays.stream(array).mapToInt(x -> x.stream().mapToInt(y -> y.length).sum()).sum();
      final Stream<Tensor[]> stream = Arrays.stream(array).flatMap(i -> i.stream());
      @Nonnull final Iterator<Tensor[]> iterator = stream.iterator();
      return new PartitionTask(network).call(iterator).next();
    } catch (@Nonnull final RuntimeException e) {
      throw e;
    } catch (@Nonnull final Exception e) {
      // Checked exceptions cannot leave the lambda; wrap them and keep the cause.
      throw new RuntimeException(e);
    }
  }).filter(x -> null != x).collect(Collectors.toList());
  final long time2 = System.nanoTime();
  // The assert above is disabled in production runs, so an all-empty RDD must be
  // reported explicitly instead of via a message-less Optional.get() failure.
  @Nonnull final SparkTrainable.ReducableResult result = mapPartitions.stream()
      .reduce(SparkTrainable.ReducableResult::add)
      .orElseThrow(() -> new java.util.NoSuchElementException(
          "No partition of the sampled RDD produced a result; cannot measure an empty sample"));
  if (isVerbose()) {
    log.info(String.format("Measure timing: %.3f / %.3f for %s items", (time2 - time1) * 1e-9, (System.nanoTime() - time2) * 1e-9, sampledRDD.count()));
  }
  @Nonnull final DeltaSet<Layer> xxx = getDelta(result);
  return new PointSample(xxx, new StateSet<Layer>(xxx), result.sum, 0.0, result.count).normalize();
}
use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.
the class FailsafeLineSearchCursor method accumulate.
/**
 * Offers a sample as a candidate for the best point seen so far.
 *
 * <p>The sample replaces the stored best when no best exists yet or when the stored
 * best has a larger mean. In that case a full copy of {@code step} is stored, the
 * replacement is logged (only when a previous best existed), and the previous best's
 * reference is released.
 *
 * @param step the sample to compare against the current best
 */
public void accumulate(@Nonnull final PointSample step) {
  // Written as a negated ">" so that a NaN mean behaves exactly as in the original comparison.
  if (null != best && !(best.getMean() > step.getMean())) {
    return;
  }
  @Nonnull final PointSample replacement = step.copyFull();
  final PointSample previous = this.best;
  if (null != previous) {
    monitor.log(String.format("New Minimum: %s > %s", previous.getMean(), step.getMean()));
    previous.freeRef();
  }
  this.best = replacement;
}
Aggregations