Use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus:
class OwlQn, method orient.
/**
 * OWL-QN orientation: takes the base direction produced by the inner orienter
 * and applies the orthant-wise L1 modifications — an L1 subgradient term
 * (scaled by {@code factor_L1}) is added to each search-direction component,
 * and any component whose adjusted sign disagrees with the raw direction sign
 * is reset to the raw direction value. The returned cursor additionally
 * projects each trial point back onto the starting orthant: a weight that
 * would cross zero within one step is clamped to zero.
 *
 * @param subject     the trainable being optimized
 * @param measurement the current point (weights and gradient)
 * @param monitor     the training monitor
 * @return a line-search cursor labeled "OWL/QN"
 */
@Nonnull
@Override
public LineSearchCursor orient(final Trainable subject, @Nonnull final PointSample measurement, final TrainingMonitor monitor) {
// Delegate to the wrapped orienter for the base search direction.
@Nonnull final SimpleLineSearchCursor gradient = (SimpleLineSearchCursor) inner.orient(subject, measurement, monitor);
@Nonnull final DeltaSet<Layer> searchDirection = gradient.direction.copy();
@Nonnull final DeltaSet<Layer> orthant = new DeltaSet<Layer>();
for (@Nonnull final Layer layer : getLayers(gradient.direction.getMap().keySet())) {
final double[] weights = gradient.direction.getMap().get(layer).target;
@Nullable final double[] delta = gradient.direction.getMap().get(layer).getDelta();
@Nullable final double[] searchDir = searchDirection.get(layer, weights).getDelta();
@Nullable final double[] suborthant = orthant.get(layer, weights).getDelta();
for (int i = 0; i < searchDir.length; i++) {
final int positionSign = sign(weights[i]);
final int directionSign = sign(delta[i]);
// Record the chosen orthant: the weight's own sign, or the direction's sign when
// the weight is exactly zero.
// NOTE(review): `orthant` is populated here but never read afterward in this
// method — confirm whether it is still needed or can be removed.
suborthant[i] = 0 == positionSign ? directionSign : positionSign;
// Add the L1 subgradient contribution: sign of the weight, defaulting to +1 at zero.
searchDir[i] += factor_L1 * (weights[i] < 0 ? -1.0 : 1.0);
// If the L1 adjustment flipped this component against the raw direction, fall
// back to the raw direction component for this coordinate.
if (sign(searchDir[i]) != directionSign) {
searchDir[i] = delta[i];
}
}
// NOTE(review): this assert runs after searchDir has already been dereferenced
// above; it can never fire usefully.
assert null != searchDir;
}
// The returned cursor re-measures the subject at each alpha, applying the
// orthant projection per-coordinate.
return new SimpleLineSearchCursor(subject, measurement, searchDirection) {
@Nonnull
@Override
public LineSearchPoint step(final double alpha, final TrainingMonitor monitor) {
// Restore the starting weights before applying the scaled step.
origin.weights.stream().forEach(d -> d.restore());
@Nonnull final DeltaSet<Layer> currentDirection = direction.copy();
direction.getMap().forEach((layer, buffer) -> {
if (null == buffer.getDelta())
return;
@Nullable final double[] currentDelta = currentDirection.get(layer, buffer.target).getDelta();
for (int i = 0; i < buffer.getDelta().length; i++) {
final double prevValue = buffer.target[i];
final double newValue = prevValue + buffer.getDelta()[i] * alpha;
// Orthant projection: a nonzero weight may not change sign within one step;
// clamp it (and its direction component) to zero instead.
if (sign(prevValue) != 0 && sign(prevValue) != sign(newValue)) {
currentDelta[i] = 0;
buffer.target[i] = 0;
} else {
buffer.target[i] = newValue;
}
}
});
@Nonnull final PointSample measure = subject.measure(monitor).setRate(alpha);
// Report the directional derivative along the projected direction.
return new LineSearchPoint(measure, currentDirection.dot(measure.delta));
}
}.setDirectionType("OWL/QN");
}
Use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus:
class QQN, method orient.
/**
 * Quadratic Quasi-Newton orientation. When the L-BFGS direction and the raw
 * gradient-descent direction differ noticeably in magnitude, returns a cursor
 * tracing the quadratic path
 * {@code position(t) = scaledGradient * (t - t^2) + lbfgs * t^2},
 * which starts along the (rescaled) gradient at t=0 and bends toward the
 * quasi-Newton direction at t=1; otherwise the plain L-BFGS cursor is
 * returned unchanged.
 *
 * @param subject the trainable being optimized
 * @param origin  the current point (weights and gradient)
 * @param monitor the training monitor
 * @return either the quadratic-blend cursor or the inner L-BFGS cursor
 */
@Override
public LineSearchCursor orient(@Nonnull final Trainable subject, @Nonnull final PointSample origin, @Nonnull final TrainingMonitor monitor) {
// Record the starting point in the inner orienter's history before orienting.
inner.addToHistory(origin, monitor);
final SimpleLineSearchCursor lbfgsCursor = inner.orient(subject, origin, monitor);
final DeltaSet<Layer> lbfgs = lbfgsCursor.direction;
// Steepest-descent direction: the negated gradient.
@Nonnull final DeltaSet<Layer> gd = origin.delta.scale(-1.0);
final double lbfgsMag = lbfgs.getMagnitude();
final double gdMag = gd.getMagnitude();
// Blend only when the relative magnitude difference exceeds 1%.
if (Math.abs(lbfgsMag - gdMag) / (lbfgsMag + gdMag) > 1e-2) {
// Rescale the gradient direction to match the L-BFGS magnitude.
@Nonnull final DeltaSet<Layer> scaledGradient = gd.scale(lbfgsMag / gdMag);
monitor.log(String.format("Returning Quadratic Cursor %s GD, %s QN", gdMag, lbfgsMag));
gd.freeRef();
return new LineSearchCursorBase() {
@Nonnull
@Override
public CharSequence getDirectionType() {
return CURSOR_NAME;
}
@Override
public DeltaSet<Layer> position(final double t) {
if (!Double.isFinite(t))
throw new IllegalArgumentException();
// Quadratic interpolation between the two directions.
return scaledGradient.scale(t - t * t).add(lbfgs.scale(t * t));
}
@Override
public void reset() {
lbfgsCursor.reset();
}
@Nonnull
@Override
public LineSearchPoint step(final double t, @Nonnull final TrainingMonitor monitor) {
if (!Double.isFinite(t))
throw new IllegalArgumentException();
// Restore the starting weights, then apply the blended delta in full.
reset();
position(t).accumulate(1);
@Nonnull final PointSample sample = subject.measure(monitor).setRate(t);
// monitor.log(String.format("delta buffers %d %d %d %d %d", sample.delta.apply.size(), origin.delta.apply.size(), lbfgs.apply.size(), gd.apply.size(), scaledGradient.apply.size()));
inner.addToHistory(sample, monitor);
// Tangent of the quadratic path at t: d/dt position(t).
@Nonnull final DeltaSet<Layer> tangent = scaledGradient.scale(1 - 2 * t).add(lbfgs.scale(2 * t));
return new LineSearchPoint(sample, tangent.dot(sample.delta));
}
@Override
public void _free() {
// Release the references captured by this cursor.
scaledGradient.freeRef();
lbfgsCursor.freeRef();
}
};
} else {
// Magnitudes agree closely; the blend adds nothing — use L-BFGS directly.
gd.freeRef();
return lbfgsCursor;
}
}
Use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus:
class RecursiveSubspace, method buildSubspace.
/**
 * Build subspace nn layer.
 * <p>
 * Constructs a synthetic {@link Layer} whose scalar output is the subject's
 * mean loss evaluated at {@code origin + sum_i weights[i] * direction_i},
 * where the directions are the per-layer components of the negated gradient
 * at {@code measurement}. All {@code PlaceholderLayer} components share the
 * single coordinate {@code weights[0]}; each remaining layer gets its own
 * coordinate (offset by one when placeholders are present). The layer's
 * backward pass accumulates the per-coordinate directional derivatives so
 * that an outer optimizer can train {@code weights} directly.
 *
 * @param subject the subject
 * @param measurement the measurement
 * @param monitor the monitor
 * @return the nn layer
 */
@Nullable
public Layer buildSubspace(@Nonnull Trainable subject, @Nonnull PointSample measurement, @Nonnull TrainingMonitor monitor) {
// Snapshot the starting weights so each eval can restart from the same point.
@Nonnull PointSample origin = measurement.copyFull().backup();
// Steepest-descent direction: the negated gradient.
@Nonnull final DeltaSet<Layer> direction = measurement.delta.scale(-1);
final double magnitude = direction.getMagnitude();
// Warn on vanishing gradients; the subspace would be degenerate.
if (Math.abs(magnitude) < 1e-10) {
monitor.log(String.format("Zero gradient: %s", magnitude));
} else if (Math.abs(magnitude) < 1e-5) {
monitor.log(String.format("Low gradient: %s", magnitude));
}
boolean hasPlaceholders = direction.getMap().entrySet().stream().filter(x -> x.getKey() instanceof PlaceholderLayer).findAny().isPresent();
List<Layer> deltaLayers = direction.getMap().entrySet().stream().map(x -> x.getKey()).filter(x -> !(x instanceof PlaceholderLayer)).collect(Collectors.toList());
// One coordinate per non-placeholder layer, plus one shared placeholder coordinate.
int size = deltaLayers.size() + (hasPlaceholders ? 1 : 0);
// Reuse the existing coefficient vector when its size still matches.
if (null == weights || weights.length != size)
weights = new double[size];
return new LayerBase() {
@Nonnull
Layer self = this;
@Nonnull
@Override
public Result eval(Result... array) {
assertAlive();
// Reset to the origin, then apply each direction scaled by its coordinate.
origin.restore();
IntStream.range(0, deltaLayers.size()).forEach(i -> {
// Non-placeholder coordinates start at index 1 when placeholders exist.
direction.getMap().get(deltaLayers.get(i)).accumulate(weights[hasPlaceholders ? (i + 1) : i]);
});
if (hasPlaceholders) {
// All placeholder layers share the first coordinate.
direction.getMap().entrySet().stream().filter(x -> x.getKey() instanceof PlaceholderLayer).distinct().forEach(entry -> entry.getValue().accumulate(weights[0]));
}
PointSample measure = subject.measure(monitor);
double mean = measure.getMean();
monitor.log(String.format("RecursiveSubspace: %s <- %s", mean, Arrays.toString(weights)));
// Hold an extra reference; the Result's _free below releases it.
direction.addRef();
return new Result(TensorArray.wrap(new Tensor(mean)), (DeltaSet<Layer> buffer, TensorList data) -> {
// Gradient w.r.t. each coordinate: dot of the measured gradient with the
// direction, normalized by the direction's magnitude (floored at 1e-8).
DoubleStream deltaStream = deltaLayers.stream().mapToDouble(layer -> {
Delta<Layer> a = direction.getMap().get(layer);
Delta<Layer> b = measure.delta.getMap().get(layer);
return b.dot(a) / Math.max(Math.sqrt(a.dot(a)), 1e-8);
});
if (hasPlaceholders) {
// Placeholder contributions are summed into the single leading coordinate.
deltaStream = DoubleStream.concat(DoubleStream.of(direction.getMap().keySet().stream().filter(x -> x instanceof PlaceholderLayer).distinct().mapToDouble(layer -> {
Delta<Layer> a = direction.getMap().get(layer);
Delta<Layer> b = measure.delta.getMap().get(layer);
return b.dot(a) / Math.max(Math.sqrt(a.dot(a)), 1e-8);
}).sum()), deltaStream);
}
buffer.get(self, weights).addInPlace(deltaStream.toArray()).freeRef();
}) {
@Override
protected void _free() {
measure.freeRef();
direction.freeRef();
}
@Override
public boolean isAlive() {
// Always differentiable w.r.t. the subspace coordinates.
return true;
}
};
}
@Override
protected void _free() {
direction.freeRef();
origin.freeRef();
super._free();
}
@Nonnull
@Override
public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
// This synthetic layer is not serializable.
throw new IllegalStateException();
}
@Nullable
@Override
public List<double[]> state() {
// Coordinates live in the enclosing `weights` array, not in layer state.
return null;
}
};
}
Use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus:
class RecursiveSubspace, method train.
/**
 * Runs an inner L-BFGS / Armijo-Wolfe optimization loop over the supplied
 * macro layer, forwarding the inner loop's log output to the outer monitor
 * with one level of indentation.
 *
 * @param monitor    the outer monitor receiving indented inner-loop log lines
 * @param macroLayer the macro layer to be trained
 */
public void train(@Nonnull TrainingMonitor monitor, Layer macroLayer) {
  // Relay inner-loop messages to the caller's monitor, indented one tab.
  @Nonnull final TrainingMonitor indentedMonitor = new TrainingMonitor() {
    @Override
    public void log(String msg) {
      monitor.log("\t" + msg);
    }
  };
  @Nonnull final BasicTrainable basicTrainable = new BasicTrainable(macroLayer);
  // A single empty sample row: the macro layer carries all trainable state.
  @Nonnull final ArrayTrainable arrayTrainable = new ArrayTrainable(basicTrainable, new Tensor[][] { {} });
  basicTrainable.freeRef();
  new IterativeTrainer(arrayTrainable)
      .setOrientation(new LBFGS())
      .setLineSearchFactory(n -> new ArmijoWolfeSearch())
      .setMonitor(indentedMonitor)
      .setMaxIterations(getIterations())
      .setIterationsPerSample(getIterations())
      .setTerminateThreshold(terminateThreshold)
      .runAndFree();
  arrayTrainable.freeRef();
}
Aggregations