Usage of com.simiacryptus.mindseye.opt.TrainingMonitor in the MindsEye project by SimiaCryptus.
From class TrustRegionStrategy, method orient:
/**
 * Orients a line search using the inner strategy, then wraps the resulting cursor so that every
 * position it proposes is projected into each layer's trust region (when the layer has one).
 * <p>
 * Side effect: {@code origin} is inserted at the front of {@code history}, and entries are removed
 * from the end of the list until at most {@code maxHistory} remain.
 *
 * @param subject the trainable being optimized; measured by the returned cursor's {@code step}
 * @param origin  the current point sample; becomes the newest history entry
 * @param monitor forwarded to the inner strategy's {@code orient}
 * @return a cursor delegating to the inner cursor, with trust-region projection applied
 */
@Nonnull
@Override
public LineSearchCursor orient(@Nonnull final Trainable subject, final PointSample origin, final TrainingMonitor monitor) {
// Newest sample goes first; the tail (oldest) is trimmed to bound the history length.
history.add(0, origin);
while (history.size() > maxHistory) {
history.remove(history.size() - 1);
}
final SimpleLineSearchCursor cursor = inner.orient(subject, origin, monitor);
return new LineSearchCursorBase() {
/** Reports the inner cursor's direction type with a "+Trust" suffix. */
@Nonnull
@Override
public CharSequence getDirectionType() {
return cursor.getDirectionType() + "+Trust";
}
/**
 * Returns the inner cursor's position at {@code alpha}, after projecting it in place into the
 * trust regions. A fresh TrainingMonitor is passed because project() does not currently use it.
 */
@Nonnull
@Override
public DeltaSet<Layer> position(final double alpha) {
reset();
@Nonnull final DeltaSet<Layer> adjustedPosVector = cursor.position(alpha);
project(adjustedPosVector, new TrainingMonitor());
return adjustedPosVector;
}
/**
 * Projects each layer's proposed position into that layer's trust region.
 * <p>
 * Mutates {@code deltaIn}: a projected layer has its delta array overwritten so that
 * current + delta equals the projected position. Returns a copy of the inner cursor's
 * direction in which, for every projected layer, the component along the projection
 * normal has been removed (tangent = d - (d.n / |n|^2) n).
 *
 * @param deltaIn the proposed step, modified in place
 * @param monitor currently unused (only referenced in commented-out diagnostics)
 * @return the adjusted line-search direction
 */
@Nonnull
public DeltaSet<Layer> project(@Nonnull final DeltaSet<Layer> deltaIn, final TrainingMonitor monitor) {
final DeltaSet<Layer> originalAlphaDerivative = cursor.direction;
@Nonnull final DeltaSet<Layer> newAlphaDerivative = originalAlphaDerivative.copy();
deltaIn.getMap().forEach((layer, buffer) -> {
@Nullable final double[] delta = buffer.getDelta();
// Layers without a delta are left untouched.
if (null == delta)
return;
final double[] currentPosition = buffer.target;
// NOTE(review): presumably get(layer, target) returns (or creates) the entry for this layer — confirm.
@Nullable final double[] originalAlphaD = originalAlphaDerivative.get(layer, currentPosition).getDelta();
@Nullable final double[] newAlphaD = newAlphaDerivative.get(layer, currentPosition).getDelta();
@Nonnull final double[] proposedPosition = ArrayUtil.add(currentPosition, delta);
final TrustRegion region = getRegionPolicy(layer);
if (null != region) {
// This layer's buffer from each history sample's weights; samples without the layer are filtered out below.
// NOTE(review): assumes getDelta() on a weights buffer yields the stored weight values — confirm.
final Stream<double[]> zz = history.stream().map((@Nonnull final PointSample x) -> {
final DoubleBuffer<Layer> d = x.weights.getMap().get(layer);
@Nullable final double[] z = null == d ? null : d.getDelta();
return z;
});
final double[] projectedPosition = region.project(zz.filter(x -> null != x).toArray(i -> new double[i][]), proposedPosition);
// Reference comparison: assumes TrustRegion.project returns the same array when no projection is
// needed. TODO confirm against the TrustRegion implementations.
if (projectedPosition != proposedPosition) {
// Rewrite the step in place so that current + delta lands on the projected point.
for (int i = 0; i < projectedPosition.length; i++) {
delta[i] = projectedPosition[i] - currentPosition[i];
}
@Nonnull final double[] normal = ArrayUtil.subtract(projectedPosition, proposedPosition);
final double normalMagSq = ArrayUtil.dot(normal, normal);
// normalMagSq));
if (0 < normalMagSq) {
final double a = ArrayUtil.dot(originalAlphaD, normal);
// NOTE(review): the rationale for excluding a == -1 is not evident here; a == 0 would
// already make the adjustment a no-op. Confirm intent.
if (a != -1) {
// Remove the component of the original direction along the projection normal.
@Nonnull final double[] tangent = ArrayUtil.add(originalAlphaD, ArrayUtil.multiply(normal, -a / normalMagSq));
for (int i = 0; i < tangent.length; i++) {
newAlphaD[i] = tangent[i];
}
// double newAlphaDerivSq = ArrayUtil.dot(tangent, tangent);
// double originalAlphaDerivSq = ArrayUtil.dot(originalAlphaD, originalAlphaD);
// assert(newAlphaDerivSq <= originalAlphaDerivSq);
// assert(Math.abs(ArrayUtil.dot(tangent, normal)) <= 1e-4);
// monitor.log(String.format("%s: normalMagSq = %s, newAlphaDerivSq = %s, originalAlphaDerivSq = %s", layer, normalMagSq, newAlphaDerivSq, originalAlphaDerivSq));
}
}
}
}
});
return newAlphaDerivative;
}
/** Resets the wrapped inner cursor. */
@Override
public void reset() {
cursor.reset();
}
/**
 * Takes a step of size {@code alpha}: computes and projects the step, applies it via
 * accumulate(1), measures the subject, and reports the directional derivative as the dot
 * product of the projected direction with the measured delta.
 * NOTE(review): presumably accumulate(1) writes the step into the live weights — confirm.
 */
@Nonnull
@Override
public LineSearchPoint step(final double alpha, final TrainingMonitor monitor) {
cursor.reset();
@Nonnull final DeltaSet<Layer> adjustedPosVector = cursor.position(alpha);
@Nonnull final DeltaSet<Layer> adjustedGradient = project(adjustedPosVector, monitor);
adjustedPosVector.accumulate(1);
@Nonnull final PointSample sample = subject.measure(monitor).setRate(alpha);
return new LineSearchPoint(sample, adjustedGradient.dot(sample.delta));
}
/** Releases this wrapper's reference to the inner cursor. */
@Override
public void _free() {
cursor.freeRef();
}
};
}
Usage of com.simiacryptus.mindseye.opt.TrainingMonitor in the MindsEye project by SimiaCryptus.
From class LocalSparkTrainable, method measure:
/**
 * Evaluates the network over the currently sampled RDD and returns the aggregated, normalized sample.
 * <p>
 * The RDD's partitions are iterated on the driver. Each one is pulled with
 * {@code collectPartitions} and evaluated locally by a {@code PartitionTask}. Empty partitions are
 * skipped, and the remaining per-partition results are combined with {@code ReducableResult::add}.
 *
 * @param monitor not used by this implementation
 * @return the normalized point sample built from the combined result
 * @throws IllegalStateException if no partition contained any rows, leaving nothing to reduce
 */
@Nonnull
@Override
public PointSample measure(final TrainingMonitor monitor) {
  final long time1 = System.nanoTime();
  final JavaRDD<Tensor[]> javaRDD = sampledRDD.toJavaRDD();
  assert !javaRDD.isEmpty();
  final List<ReducableResult> mapPartitions = javaRDD.partitions().stream().map(partition -> {
    try {
      final List<Tensor[]>[] array = javaRDD.collectPartitions(new int[] { partition.index() });
      assert 0 < array.length;
      if (0 == Arrays.stream(array).mapToInt(List::size).sum()) {
        // Empty partition: contributes nothing; the null is filtered out below.
        return null;
      }
      assert 0 < Arrays.stream(array).mapToInt(x -> x.stream().mapToInt(y -> y.length).sum()).sum();
      final Stream<Tensor[]> stream = Arrays.stream(array).flatMap(List::stream);
      @Nonnull final Iterator<Tensor[]> iterator = stream.iterator();
      return new PartitionTask(network).call(iterator).next();
    } catch (@Nonnull final RuntimeException e) {
      throw e;
    } catch (@Nonnull final Exception e) {
      throw new RuntimeException(e);
    }
  }).filter(x -> null != x).collect(Collectors.toList());
  final long time2 = System.nanoTime();
  // The emptiness check above is only an assert, so report the empty case explicitly instead of
  // failing with a bare NoSuchElementException from Optional.get().
  @Nonnull final SparkTrainable.ReducableResult result = mapPartitions.stream()
      .reduce(SparkTrainable.ReducableResult::add)
      .orElseThrow(() -> new IllegalStateException("No non-empty partitions to measure"));
  if (isVerbose()) {
    log.info(String.format("Measure timing: %.3f / %.3f for %s items", (time2 - time1) * 1e-9, (System.nanoTime() - time2) * 1e-9, sampledRDD.count()));
  }
  @Nonnull final DeltaSet<Layer> xxx = getDelta(result);
  return new PointSample(xxx, new StateSet<Layer>(xxx), result.sum, 0.0, result.count).normalize();
}
Usage of com.simiacryptus.mindseye.opt.TrainingMonitor in the MindsEye project by SimiaCryptus.
From class TrainingTester, method trainCjGD:
/**
 * Runs the "CjGD" training pass on {@code trainable} and returns the steps it recorded.
 * <p>
 * The trainer combines a {@code GradientDescent} orientation with a {@code QuadraticSearch} line
 * search. It is limited to 30 seconds and 250 iterations, with a terminate threshold of 0. The
 * training call is executed through {@code log.code(...)}.
 *
 * @param log       notebook that receives the explanatory paragraph and the training code call
 * @param trainable the objective to optimize
 * @return the step records collected by the monitor. If training throws and
 *         {@code isThrowExceptions()} is false, this holds whatever was recorded before the failure.
 * @throws RuntimeException wrapping the failure, when training throws and {@code isThrowExceptions()} is true
 */
@Nonnull
public List<StepRecord> trainCjGD(@Nonnull final NotebookOutput log, final Trainable trainable) {
  log.p("First, we use a conjugate gradient descent method, which converges the fastest for purely linear functions.");
  @Nonnull final List<StepRecord> steps = new ArrayList<>();
  @Nonnull final TrainingMonitor monitor = TrainingTester.getMonitor(steps);
  try {
    log.code(() -> {
      return new IterativeTrainer(trainable).setLineSearchFactory(label -> new QuadraticSearch()).setOrientation(new GradientDescent()).setMonitor(monitor).setTimeout(30, TimeUnit.SECONDS).setMaxIterations(250).setTerminateThreshold(0).runAndFree();
    });
  } catch (final Throwable failure) {
    if (!isThrowExceptions()) {
      // Best-effort mode: report the partial history instead of failing.
      return steps;
    }
    throw new RuntimeException(failure);
  }
  return steps;
}
Usage of com.simiacryptus.mindseye.opt.TrainingMonitor in the MindsEye project by SimiaCryptus.
From class TrainingTester, method trainLBFGS:
/**
 * Runs the L-BFGS training pass on {@code trainable} and returns the steps it recorded.
 * <p>
 * The trainer combines an {@code LBFGS} orientation with an {@code ArmijoWolfeSearch} line search.
 * It is limited to 30 seconds and 250 iterations, with 100 iterations per sample and a terminate
 * threshold of 0. The training call is executed through {@code log.code(...)}.
 *
 * @param log       notebook that receives the explanatory paragraph and the training code call
 * @param trainable the objective to optimize
 * @return the step records collected by the monitor. If training throws and
 *         {@code isThrowExceptions()} is false, this holds whatever was recorded before the failure.
 * @throws RuntimeException wrapping the failure, when training throws and {@code isThrowExceptions()} is true
 */
@Nonnull
public List<StepRecord> trainLBFGS(@Nonnull final NotebookOutput log, final Trainable trainable) {
  log.p("Next, we apply the same optimization using L-BFGS, which is nearly ideal for purely second-order or quadratic functions.");
  @Nonnull final List<StepRecord> steps = new ArrayList<>();
  @Nonnull final TrainingMonitor monitor = TrainingTester.getMonitor(steps);
  try {
    log.code(() -> {
      return new IterativeTrainer(trainable).setLineSearchFactory(label -> new ArmijoWolfeSearch()).setOrientation(new LBFGS()).setMonitor(monitor).setTimeout(30, TimeUnit.SECONDS).setIterationsPerSample(100).setMaxIterations(250).setTerminateThreshold(0).runAndFree();
    });
  } catch (final Throwable failure) {
    if (!isThrowExceptions()) {
      // Best-effort mode: report the partial history instead of failing.
      return steps;
    }
    throw new RuntimeException(failure);
  }
  return steps;
}
Usage of com.simiacryptus.mindseye.opt.TrainingMonitor in the MindsEye project by SimiaCryptus.
From class LBFGS, method lbfgs:
/**
 * Attempts to compute an L-BFGS search direction and, if acceptable, copies it into {@code direction}.
 * <p>
 * Uses the two-loop recursion over consecutive history pairs
 * s_i = weights[i+1] - weights[i] and y_i = delta[i+1] - delta[i]. The last history entry is
 * treated as the most recent. The first loop runs from newest pair to oldest. The result is then
 * scaled by (s_k . y_k) / (y_k . y_k) for the newest pair, and the second loop runs from oldest
 * pair to newest.
 * <p>
 * Any failure is logged to {@code monitor} and reported as {@code false} rather than propagated.
 * This covers insufficient history, a vanishing s.y denominator, and non-finite intermediate values.
 *
 * @param measurement the current sample; its {@code delta} seeds the recursion and is used for acceptance
 * @param monitor     receives "Accepted"/"Rejected"/error log lines
 * @param history     past samples, oldest first; at least two are required
 * @param direction   overwritten with the computed direction only when it is accepted
 * @return {@code true} if {@code measurement.delta . p < 0} and {@code p} was copied into
 *         {@code direction}; {@code false} otherwise
 */
private boolean lbfgs(@Nonnull PointSample measurement, @Nonnull TrainingMonitor monitor, @Nonnull List<PointSample> history, @Nonnull DeltaSet<Layer> direction) {
  try {
    final int size = history.size();
    if (size < 2) {
      // At least one (s, y) pair is required. Without this guard the failure surfaced only as an
      // incidental IndexOutOfBoundsException from history.get(-1).
      throw new IllegalStateException("Insufficient history: " + size);
    }
    @Nonnull DeltaSet<Layer> p = measurement.delta.copy();
    requireFinite(p);
    @Nonnull final double[] alphas = new double[size];
    // First loop: newest pair to oldest.
    for (int i = size - 2; i >= 0; i--) {
      @Nonnull final DeltaSet<Layer> sd = history.get(i + 1).weights.subtract(history.get(i).weights);
      @Nonnull final DeltaSet<Layer> yd = history.get(i + 1).delta.subtract(history.get(i).delta);
      final double denominator = sd.dot(yd);
      if (0 == denominator) {
        throw new IllegalStateException("Orientation vanished.");
      }
      alphas[i] = p.dot(sd) / denominator;
      p = p.subtract(yd.scale(alphas[i]));
      requireFinite(p);
    }
    // Initial Hessian approximation: scalar multiple from the newest pair.
    @Nonnull final DeltaSet<Layer> sk = history.get(size - 1).weights.subtract(history.get(size - 2).weights);
    @Nonnull final DeltaSet<Layer> yk = history.get(size - 1).delta.subtract(history.get(size - 2).delta);
    p = p.scale(sk.dot(yk) / yk.dot(yk));
    requireFinite(p);
    // Second loop: oldest pair to newest. Denominators were already checked non-zero above.
    for (int i = 0; i < size - 1; i++) {
      @Nonnull final DeltaSet<Layer> sd = history.get(i + 1).weights.subtract(history.get(i).weights);
      @Nonnull final DeltaSet<Layer> yd = history.get(i + 1).delta.subtract(history.get(i).delta);
      final double beta = p.dot(yd) / sd.dot(yd);
      p = p.add(sd.scale(alphas[i] - beta));
      requireFinite(p);
    }
    final boolean accept = measurement.delta.dot(p) < 0;
    if (accept) {
      monitor.log("Accepted: " + new Stats(direction, p));
      copy(p, direction);
    } else {
      monitor.log("Rejected: " + new Stats(direction, p));
    }
    return accept;
  } catch (Throwable e) {
    monitor.log(String.format("LBFGS Orientation Error: %s", e.getMessage()));
    return false;
  }
}

/**
 * Throws if any component of {@code set} is NaN or infinite.
 *
 * @param set the delta set to check
 * @throws IllegalStateException with message "Non-finite value" if a non-finite component is found
 */
private static void requireFinite(@Nonnull final DeltaSet<Layer> set) {
  if (!set.stream().parallel().allMatch(y -> Arrays.stream(y.getDelta()).allMatch(Double::isFinite))) {
    throw new IllegalStateException("Non-finite value");
  }
}
Aggregations