Usage of com.simiacryptus.mindseye.network.DAGNetwork in the MindsEye project by SimiaCryptus.
Class IterativeTrainer, method measure.
/**
 * Measures the training subject, retrying while the measured mean is not finite.
 *
 * <p>When {@code reset} is true, every attempt first resets {@code orientation}, reshuffles each
 * {@link StochasticComponent} layer of a {@link DAGNetwork} subject, and reseeds the subject.
 * A failed reseed is tolerated on the first attempt but aborts training on a retry.
 *
 * @param reset whether to reset the optimizer state and reseed the subject before each attempt
 * @return the measured point sample; the caller is responsible for releasing it with
 *     {@code freeRef()}
 * @throws IterativeStopException if reseeding fails on a retry, or if no finite mean is obtained
 *     within the retry budget
 */
@Nullable
public PointSample measure(boolean reset) {
  // Maximum number of additional attempts after a non-finite measurement.
  final int maxRetries = 10;
  @Nullable PointSample currentPoint = null;
  int retries = 0;
  do {
    if (reset) {
      orientation.reset();
      if (subject.getLayer() instanceof DAGNetwork) {
        ((DAGNetwork) subject.getLayer()).visitLayers(layer -> {
          if (layer instanceof StochasticComponent)
            ((StochasticComponent) layer).shuffle(StochasticComponent.random.get().nextLong());
        });
      }
      if (!subject.reseed(System.nanoTime())) {
        // Only fatal on a retry: the first attempt proceeds even if reseeding is unsupported.
        if (retries > 0)
          throw new IterativeStopException("Failed to reset training subject");
      } else {
        monitor.log("Reset training subject");
      }
    }
    if (null != currentPoint) {
      // Release the rejected (non-finite) sample from the previous attempt.
      currentPoint.freeRef();
    }
    currentPoint = subject.measure(monitor);
    // Retry while the mean is non-finite and the retry budget is not exhausted.
    // (The comparison must be `retries++ < maxRetries`; `10 < retries++` never retried.)
  } while (!Double.isFinite(currentPoint.getMean()) && retries++ < maxRetries);
  if (!Double.isFinite(currentPoint.getMean())) {
    currentPoint.freeRef();
    throw new IterativeStopException();
  }
  return currentPoint;
}
Usage of com.simiacryptus.mindseye.network.DAGNetwork in the MindsEye project by SimiaCryptus.
Class IterativeTrainer, method run.
/**
 * Runs the optimization loop and returns the mean of the final point sample.
 *
 * <p>The loop stops when any of the following happens:
 * <ul>
 *   <li>the timeout expires;</li>
 *   <li>{@code currentIteration} exceeds {@code maxIterations};</li>
 *   <li>the mean falls to {@code terminateThreshold} or below;</li>
 *   <li>a step fails to decrease the mean and the subject cannot be reseeded.</li>
 * </ul>
 *
 * <p>Each outer iteration re-measures the subject. Each of up to {@code iterationsPerSample}
 * sub-iterations (unbounded when that value is {@code <= 0}) re-measures, obtains a search
 * direction from {@code orientation}, and performs a line search via {@code step}. A step that
 * increases the mean is replaced by the point at the start of the line
 * ({@code direction.step(0, ...)}). A step that does not decrease the mean either reseeds the
 * subject and returns to the outer loop, or aborts training if reseeding fails.
 *
 * <p>After the loop, noise is cleared on every {@link StochasticComponent} layer of a
 * {@link DAGNetwork} subject.
 *
 * @return the mean of the final point sample, or {@code NaN} if there is none
 * @throws IterativeStopException propagated from {@code measure(boolean)}
 */
public double run() {
  final long timeoutMs = System.currentTimeMillis() + timeout.toMillis();
  long lastIterationTime = System.nanoTime();
  @Nullable PointSample currentPoint = measure(true);
  mainLoop: while (timeoutMs > System.currentTimeMillis() && currentPoint.getMean() > terminateThreshold) {
    if (currentIteration.get() > maxIterations) {
      break;
    }
    currentPoint.freeRef();
    currentPoint = measure(true);
    assert 0 < currentPoint.delta.getMap().size() : "Nothing to optimize";
    subiterationLoop: for (int subiteration = 0; subiteration < iterationsPerSample || iterationsPerSample <= 0; subiteration++) {
      if (timeoutMs < System.currentTimeMillis()) {
        break mainLoop;
      }
      if (currentIteration.incrementAndGet() > maxIterations) {
        break mainLoop;
      }
      currentPoint.freeRef();
      currentPoint = measure(true);
      // The lambda below needs an effectively-final view of the current point.
      @Nullable final PointSample _currentPoint = currentPoint;
      @Nonnull final TimedResult<LineSearchCursor> timedOrientation = TimedResult.time(() -> orientation.orient(subject, _currentPoint, monitor));
      final LineSearchCursor direction = timedOrientation.result;
      final CharSequence directionType = direction.getDirectionType();
      // Extra reference keeps the starting point alive after currentPoint is released/reassigned.
      @Nullable final PointSample previous = currentPoint;
      previous.addRef();
      try {
        @Nonnull final TimedResult<PointSample> timedLineSearch = TimedResult.time(() -> step(direction, directionType, previous));
        currentPoint.freeRef();
        currentPoint = timedLineSearch.result;
        final long now = System.nanoTime();
        final CharSequence perfString = String.format("Total: %.4f; Orientation: %.4f; Line Search: %.4f", (now - lastIterationTime) / 1e9, timedOrientation.timeNanos / 1e9, timedLineSearch.timeNanos / 1e9);
        lastIterationTime = now;
        monitor.log(String.format("Fitness changed from %s to %s", previous.getMean(), currentPoint.getMean()));
        if (previous.getMean() <= currentPoint.getMean()) {
          if (previous.getMean() < currentPoint.getMean()) {
            // The line search made things worse: fall back to the start of the line.
            monitor.log(String.format("Resetting Iteration %s", perfString));
            currentPoint.freeRef();
            currentPoint = direction.step(0, monitor).point;
          } else {
            monitor.log(String.format("Static Iteration %s", perfString));
          }
          if (subject.reseed(System.nanoTime())) {
            monitor.log(String.format("Iteration %s failed, retrying. Error: %s", currentIteration.get(), currentPoint.getMean()));
            monitor.log(String.format("Previous Error: %s -> %s", previous.getRate(), previous.getMean()));
            break subiterationLoop;
          } else {
            monitor.log(String.format("Iteration %s failed, aborting. Error: %s", currentIteration.get(), currentPoint.getMean()));
            monitor.log(String.format("Previous Error: %s -> %s", previous.getRate(), previous.getMean()));
            break mainLoop;
          }
        } else {
          // Pass perfString as an argument rather than splicing it into the format string.
          monitor.log(String.format("Iteration %s complete. Error: %s %s", currentIteration.get(), currentPoint.getMean(), perfString));
        }
        monitor.onStepComplete(new Step(currentPoint, currentIteration.get()));
      } finally {
        previous.freeRef();
        direction.freeRef();
      }
    }
  }
  if (subject.getLayer() instanceof DAGNetwork) {
    ((DAGNetwork) subject.getLayer()).visitLayers(layer -> {
      if (layer instanceof StochasticComponent)
        ((StochasticComponent) layer).clearNoise();
    });
  }
  double mean = null == currentPoint ? Double.NaN : currentPoint.getMean();
  // Release consistently with the null check above instead of risking an NPE.
  if (null != currentPoint) {
    currentPoint.freeRef();
  }
  return mean;
}
Usage of com.simiacryptus.mindseye.network.DAGNetwork in the MindsEye project by SimiaCryptus.
Class MnistTestBase, method run.
/**
 * Runs the full MNIST experiment.
 *
 * <p>The steps are: load the training data, build the model, attach monitoring, train, report
 * on the recorded step history and monitoring data, validate the model, and finally detach
 * monitoring.
 *
 * @param log notebook output receiving every section written by the experiment
 */
public void run(@Nonnull NotebookOutput log) {
  @Nonnull final MonitoredObject metrics = new MonitoredObject();
  @Nonnull final List<Step> stepHistory = new ArrayList<>();
  @Nonnull final TrainingMonitor trainingMonitor = getMonitor(stepHistory);

  final Tensor[][] data = getTrainingData(log);
  final DAGNetwork model = buildModel(log);

  // Monitoring stays attached for training, reporting and validation.
  addMonitoring(model, metrics);
  log.h1("Training");
  train(log, model, data, trainingMonitor);
  report(log, metrics, stepHistory, model);
  validate(log, model);
  removeMonitoring(model);
}
Usage of com.simiacryptus.mindseye.network.DAGNetwork in the MindsEye project by SimiaCryptus.
Class ImageClassifierTestBase, method run.
/**
 * Renders the classifier's network diagram, predicts labels for a set of test images, and writes
 * a table pairing each image with its predictions.
 *
 * <p>The images are loaded on a background thread while the classifier is obtained and its
 * diagram is rendered.
 *
 * @param log notebook output receiving the diagram, the predictions and the result table
 * @throws RuntimeException if loading the images fails, or if the calling thread is interrupted
 *     while waiting for them (the interrupt flag is restored in that case)
 */
public void run(@Nonnull NotebookOutput log) {
  final java.util.concurrent.ExecutorService imageLoader = Executors.newSingleThreadExecutor();
  final Future<Tensor[][]> submit;
  try {
    submit = imageLoader.submit(() -> Arrays.stream(EncodingUtil.getImages(log, img -> img, 10, new CharSequence[] {})).toArray(i -> new Tensor[i][]));
  } finally {
    // Let the submitted task finish, but release the worker thread afterwards
    // (the executor was previously never shut down).
    imageLoader.shutdown();
  }
  ImageClassifier vgg16 = getImageClassifier(log);
  @Nonnull Layer network = vgg16.getNetwork();
  log.h1("Network Diagram");
  log.p("This is a diagram of the imported network:");
  log.code(() -> {
    return Graphviz.fromGraph(TestUtil.toGraph((DAGNetwork) network)).height(4000).width(800).render(Format.PNG).toImage();
  });
  log.h1("Predictions");
  Tensor[][] images;
  try {
    images = submit.get();
  } catch (InterruptedException e) {
    Thread.currentThread().interrupt();
    throw new RuntimeException(e);
  } catch (Exception e) {
    throw new RuntimeException(e);
  }
  @Nonnull Map<CharSequence, List<LinkedHashMap<CharSequence, Double>>> modelPredictions = new HashMap<>();
  try {
    modelPredictions.put("Source", predict(log, vgg16, network, images));
  } finally {
    // Release the network reference even if prediction fails.
    network.freeRef();
  }
  log.h1("Result");
  log.code(() -> {
    @Nonnull TableOutput tableOutput = new TableOutput();
    for (int i = 0; i < images.length; i++) {
      int index = i;
      @Nonnull HashMap<CharSequence, Object> row = new HashMap<>();
      // NOTE(review): assumes index 1 of each loaded entry is the image tensor — confirm against EncodingUtil.getImages.
      row.put("Image", log.image(images[i][1].toImage(), ""));
      modelPredictions.forEach((model, predictions) -> {
        // orElse("") renders an empty cell instead of throwing when there are no predictions.
        row.put(model, predictions.get(index).entrySet().stream()
            .map(e -> String.format("%s -> %.2f", e.getKey(), 100 * e.getValue()))
            .reduce((a, b) -> a + "<br/>" + b)
            .orElse(""));
      });
      tableOutput.putRow(row);
    }
    return tableOutput;
  }, 256 * 1024);
}
Aggregations