Use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus — class TensorListTrainable, method eval:
/**
 * Evaluates the network once over the given batch of tensor lists and packages the
 * resulting loss and gradient into a {@link PointSample}.
 *
 * @param list    the input data, one {@link TensorList} per network input
 * @param monitor the training monitor used for progress logging; may be null
 * @return the point sample holding the accumulated gradient deltas, the captured
 *         weight state, the summed loss over all items, and the item count
 */
@Nonnull
protected PointSample eval(@Nonnull final TensorList[] list, @Nullable final TrainingMonitor monitor) {
  int inputs = data.length;
  assert 0 < inputs;
  int items = data[0].length();
  assert 0 < items;
  // Time the full forward/backward pass so it can be reported below.
  @Nonnull final TimedResult<PointSample> timedResult = TimedResult.time(() -> {
    final Result[] nnContext = TensorListTrainable.getNNContext(list, mask);
    final Result result = network.eval(nnContext);
    // Release the input contexts; the network's result retains what it still needs.
    for (@Nonnull Result nnResult : nnContext) {
      nnResult.getData().freeRef();
      nnResult.freeRef();
    }
    final TensorList resultData = result.getData();
    // Copy each tensor's values out before freeing it, then aggregate into summary statistics.
    final DoubleSummaryStatistics statistics = resultData.stream().flatMapToDouble(x -> {
      double[] array = Arrays.stream(x.getData()).toArray();
      x.freeRef();
      return Arrays.stream(array);
    }).summaryStatistics();
    final double sum = statistics.getSum();
    @Nonnull final DeltaSet<Layer> deltaSet = new DeltaSet<Layer>();
    @Nonnull PointSample pointSample;
    try {
      // Backpropagate with a unit gradient to accumulate parameter deltas.
      result.accumulate(deltaSet, 1.0);
      // log.info(String.format("Evaluated to %s delta buffers, %s mag", DeltaSet<LayerBase>.getMap().size(), DeltaSet<LayerBase>.getMagnitude()));
      @Nonnull StateSet<Layer> stateSet = new StateSet<>(deltaSet);
      pointSample = new PointSample(deltaSet, stateSet, sum, 0.0, items);
      // NOTE(review): this freeRef is skipped if accumulate() or the StateSet
      // constructor throws (it is inside try, not finally) — verify intended.
      stateSet.freeRef();
    } finally {
      resultData.freeRef();
      result.freeRef();
      deltaSet.freeRef();
    }
    return pointSample;
  });
  if (null != monitor && verbosity() > 0) {
    monitor.log(String.format("Device completed %s items in %.3f sec", items, timedResult.timeNanos / 1e9));
  }
  // Normalize the sample and release the timed wrapper's reference before returning.
  @Nonnull PointSample normalize = timedResult.result.normalize();
  timedResult.result.freeRef();
  return normalize;
}
Use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus — class LayerRateDiagnosticTrainer, method measure:
/**
 * Measures the current point sample, retrying with a fresh reseed when the
 * measured sum is not finite.
 *
 * @return a point sample whose sum is finite
 * @throws IterativeStopException if reseeding fails after the first attempt,
 *         or after more than ten retries
 */
public PointSample measure() {
  int attempts = 0;
  while (true) {
    // A failed reseed is only fatal once we are actually retrying.
    if (!subject.reseed(System.nanoTime()) && attempts > 0) {
      throw new IterativeStopException();
    }
    if (10 < attempts++) {
      throw new IterativeStopException();
    }
    final PointSample sample = subject.measure(monitor);
    if (Double.isFinite(sample.sum)) {
      return sample;
    }
  }
}
Use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus — class ValidatingTrainer, method runStep:
/**
 * Runs a single optimization step: orients a search direction for the given phase,
 * then performs a line search along it.
 *
 * @param previousPoint the point sample from the previous step
 * @param phase         the training phase supplying the orientation and line-search strategies
 * @return the step result containing the previous and best points plus the orientation
 *         and line-search timings (in seconds)
 * @throws IllegalStateException if the line search produced a worse mean than the input point
 */
@Nonnull
protected StepResult runStep(@Nonnull final PointSample previousPoint, @Nonnull final TrainingPhase phase) {
  currentIteration.incrementAndGet();
  // Time the orientation (direction-finding) phase.
  @Nonnull final TimedResult<LineSearchCursor> timedOrientation = TimedResult.time(() -> phase.orientation.orient(phase.trainingSubject, previousPoint, monitor));
  final LineSearchCursor direction = timedOrientation.result;
  final CharSequence directionType = direction.getDirectionType();
  // Reuse a line-search strategy per direction type, creating it lazily on first use.
  LineSearchStrategy lineSearchStrategy;
  if (phase.lineSearchStrategyMap.containsKey(directionType)) {
    lineSearchStrategy = phase.lineSearchStrategyMap.get(directionType);
  } else {
    monitor.log(String.format("Constructing line search parameters: %s", directionType));
    // Fix: reuse the directionType already fetched above instead of calling
    // direction.getDirectionType() a second time.
    lineSearchStrategy = phase.lineSearchFactory.apply(directionType);
    phase.lineSearchStrategyMap.put(directionType, lineSearchStrategy);
  }
  // Time the line search along the chosen direction.
  @Nonnull final TimedResult<PointSample> timedLineSearch = TimedResult.time(() -> {
    @Nonnull final FailsafeLineSearchCursor cursor = new FailsafeLineSearchCursor(direction, previousPoint, monitor);
    lineSearchStrategy.step(cursor, monitor);
    @Nonnull final PointSample restore = cursor.getBest(monitor).restore();
    // cursor.step(restore.rate, monitor);
    return restore;
  });
  final PointSample bestPoint = timedLineSearch.result;
  // The failsafe cursor should never yield a point worse than where we started.
  if (bestPoint.getMean() > previousPoint.getMean()) {
    throw new IllegalStateException(bestPoint.getMean() + " > " + previousPoint.getMean());
  }
  monitor.log(compare(previousPoint, bestPoint));
  return new StepResult(previousPoint, bestPoint, new double[] { timedOrientation.timeNanos / 1e9, timedLineSearch.timeNanos / 1e9 });
}
Use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus — class ValidatingTrainer, method measure:
/**
 * Measures the training subject for the given phase, resetting the phase's
 * orientation and retrying whenever the measured mean is not finite.
 *
 * @param phase the training phase to measure
 * @return a point sample whose mean is finite
 * @throws IterativeStopException after more than ten failed attempts
 */
private PointSample measure(@Nonnull final TrainingPhase phase) {
  for (int attempt = 0; ; attempt++) {
    if (attempt > 10) {
      throw new IterativeStopException();
    }
    final PointSample sample = phase.trainingSubject.measure(monitor);
    if (Double.isFinite(sample.getMean())) {
      return sample;
    }
    // Non-finite measurement: reset the orientation state and try again.
    phase.orientation.reset();
  }
}
Use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus — class ValidatingTrainer, method run:
/**
 * Runs the full training loop — repeated epochs with adaptive iteration count and
 * training-set size — until timeout, convergence, or an explicit halt.
 *
 * @return the mean of the final validation measurement
 * @throws RuntimeException wrapping any failure raised during training
 */
public double run() {
  try {
    final long timeoutAt = System.currentTimeMillis() + timeout.toMillis();
    // Clear any stochastic noise (e.g. dropout state) before the initial validation measurement.
    if (validationSubject.getLayer() instanceof DAGNetwork) {
      ((DAGNetwork) validationSubject.getLayer()).visitLayers(layer -> {
        if (layer instanceof StochasticComponent)
          ((StochasticComponent) layer).clearNoise();
      });
    }
    @Nonnull final EpochParams epochParams = new EpochParams(timeoutAt, epochIterations, getTrainingSize(), validationSubject.measure(monitor));
    int epochNumber = 0;
    int iterationNumber = 0;
    int lastImprovement = 0;
    double lowestValidation = Double.POSITIVE_INFINITY;
    while (true) {
      if (shouldHalt(monitor, timeoutAt)) {
        monitor.log("Training halted");
        break;
      }
      monitor.log(String.format("Epoch parameters: %s, %s", epochParams.trainingSize, epochParams.iterations));
      @Nonnull final List<TrainingPhase> regimen = getRegimen();
      final long seed = System.nanoTime();
      // Run every phase of the regimen with a shared seed so all phases see the same sampling.
      final List<EpochResult> epochResults = IntStream.range(0, regimen.size()).mapToObj(i -> {
        final TrainingPhase phase = getRegimen().get(i);
        return runPhase(epochParams, phase, i, seed);
      }).collect(Collectors.toList());
      // Phase 0 is the primary phase; its progress drives the adaptive schedule below.
      final EpochResult primaryPhase = epochResults.get(0);
      iterationNumber += primaryPhase.iterations;
      // Ratio < 1 means the training loss improved this epoch.
      final double trainingDelta = primaryPhase.currentPoint.getMean() / primaryPhase.priorMean;
      // Clear stochastic noise again before the validation measurement.
      if (validationSubject.getLayer() instanceof DAGNetwork) {
        ((DAGNetwork) validationSubject.getLayer()).visitLayers(layer -> {
          if (layer instanceof StochasticComponent)
            ((StochasticComponent) layer).clearNoise();
        });
      }
      final PointSample currentValidation = validationSubject.measure(monitor);
      // Overtraining metric: log-ratio of training progress to validation progress.
      final double overtraining = Math.log(trainingDelta) / Math.log(currentValidation.getMean() / epochParams.validation.getMean());
      final double validationDelta = currentValidation.getMean() / epochParams.validation.getMean();
      // adj1 steers the per-epoch iteration count toward the training target;
      // adj2 steers the training-set size toward the overtraining target.
      final double adj1 = Math.pow(Math.log(getTrainingTarget()) / Math.log(validationDelta), adjustmentFactor);
      final double adj2 = Math.pow(overtraining / getOvertrainingTarget(), adjustmentFactor);
      final double validationMean = currentValidation.getMean();
      if (validationMean < lowestValidation) {
        lowestValidation = validationMean;
        lastImprovement = iterationNumber;
      }
      monitor.log(String.format("Epoch %d result apply %s iterations, %s/%s samples: {validation *= 2^%.5f; training *= 2^%.3f; Overtraining = %.2f}, {itr*=%.2f, len*=%.2f} %s since improvement; %.4f validation time", ++epochNumber, primaryPhase.iterations, epochParams.trainingSize, getMaxTrainingSize(), Math.log(validationDelta) / Math.log(2), Math.log(trainingDelta) / Math.log(2), overtraining, adj1, adj2, iterationNumber - lastImprovement, validatingMeasurementTime.getAndSet(0) / 1e9));
      if (!primaryPhase.continueTraining) {
        monitor.log(String.format("Training %d runPhase halted", epochNumber));
        break;
      }
      // Once the training set is at maximum size, apply the stochastic convergence tests.
      if (epochParams.trainingSize >= getMaxTrainingSize()) {
        final double roll = FastRandom.INSTANCE.random();
        if (roll > Math.pow(2 - validationDelta, pessimism)) {
          monitor.log(String.format("Training randomly converged: %3f", roll));
          break;
        } else {
          // Stale-improvement check: repeated disappointments eventually end training.
          if (iterationNumber - lastImprovement > improvmentStaleThreshold) {
            if (disappointments.incrementAndGet() > getDisappointmentThreshold()) {
              monitor.log(String.format("Training converged after %s iterations", iterationNumber - lastImprovement));
              break;
            } else {
              monitor.log(String.format("Training failed to converged on %s attempt after %s iterations", disappointments.get(), iterationNumber - lastImprovement));
            }
          } else {
            disappointments.set(0);
          }
        }
      }
      if (validationDelta < 1.0 && trainingDelta < 1.0) {
        // Both losses improved: adjust schedule toward the targets when outside tolerance.
        if (adj1 < 1 - adjustmentTolerance || adj1 > 1 + adjustmentTolerance) {
          epochParams.iterations = Math.max(getMinEpochIterations(), Math.min(getMaxEpochIterations(), (int) (primaryPhase.iterations * adj1)));
        }
        // NOTE(review): for any adjustmentTolerance > 0 this condition is always true
        // (`< 1 + tol || > 1 - tol` covers all values); the adj1 check above uses the
        // outside-the-band form — likely a typo, verify intended bounds.
        if (adj2 < 1 + adjustmentTolerance || adj2 > 1 - adjustmentTolerance) {
          epochParams.trainingSize = Math.max(0, Math.min(Math.max(getMinTrainingSize(), Math.min(getMaxTrainingSize(), (int) (epochParams.trainingSize * adj2))), epochParams.trainingSize));
        }
      } else {
        // A loss got worse: the outer min against the current trainingSize caps the
        // 5x growth, so this clamps rather than grows — NOTE(review): verify intent.
        epochParams.trainingSize = Math.max(0, Math.min(Math.max(getMinTrainingSize(), Math.min(getMaxTrainingSize(), epochParams.trainingSize * 5)), epochParams.trainingSize));
        epochParams.iterations = 1;
      }
      epochParams.validation = currentValidation;
    }
    // Clear stochastic noise one final time before returning the last validation mean.
    if (validationSubject.getLayer() instanceof DAGNetwork) {
      ((DAGNetwork) validationSubject.getLayer()).visitLayers(layer -> {
        if (layer instanceof StochasticComponent)
          ((StochasticComponent) layer).clearNoise();
      });
    }
    return epochParams.validation.getMean();
  } catch (@Nonnull final Throwable e) {
    throw new RuntimeException(e);
  }
}
Aggregations