Example use of com.simiacryptus.mindseye.lang.Layer in project MindsEye by SimiaCryptus: class L2NormalizationTest, method train.
/**
 * Trains the given network against the supplied data using an L2-regularized
 * loss, logging each stage to the notebook.
 *
 * @param log          the notebook output to write narrative and code results to
 * @param network      the model to train
 * @param trainingData the labeled training tensors
 * @param monitor      receives per-iteration training callbacks
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  // Fix: the original prose contained "mapCoords" where "with" belongs —
  // apparently an over-eager mechanical rename that corrupted the log text.
  log.p("Training a model involves a few different components. First, our model is combined with a loss function. "
      + "Then we take that model and combine it with our training data to define a trainable object. "
      + "Finally, we use a simple iterative scheme to refine the weights of our model. "
      + "The final output is the last output value of the loss function when evaluating the last batch.");
  log.code(() -> {
    @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    // Wrap the sampled trainable in an L1/L2 normalizer: L1 disabled, strong L2.
    @Nonnull final Trainable trainable = new L12Normalizer(new SampledArrayTrainable(trainingData, supervisedNetwork, 1000)) {
      @Override
      public Layer getLayer() {
        return inner.getLayer();
      }

      @Override
      protected double getL1(final Layer layer) {
        return 0.0; // no L1 penalty
      }

      @Override
      protected double getL2(final Layer layer) {
        return 1e4; // heavy L2 regularization coefficient
      }
    };
    // Returns the final loss value; runAndFree releases the trainer's resources.
    return new IterativeTrainer(trainable)
        .setMonitor(monitor)
        .setTimeout(3, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .runAndFree();
  });
}
Example use of com.simiacryptus.mindseye.lang.Layer in project MindsEye by SimiaCryptus: class NLayerTest, method test.
/**
 * Runs the layered-network test suite: for each successive prefix of the
 * configured dimension list, builds a network, renders its graph, and tests it.
 *
 * @param log the notebook output to write results to
 */
public void test(@Nonnull final NotebookOutput log) {
  log.h1("%s", getClass().getSimpleName());
  @Nonnull final int[] inputDims = getInputDims();
  @Nonnull final ArrayList<int[]> workingSpec = new ArrayList<>();
  for (final int[] dims : dimList) {
    // Grow the spec one layer at a time so each iteration exercises a
    // progressively deeper network.
    workingSpec.add(dims);
    @Nonnull final Layer network = buildNetwork(concat(inputDims, workingSpec));
    graphviz(log, network);
    test(log, network, inputDims);
  }
}
Example use of com.simiacryptus.mindseye.lang.Layer in project MindsEye by SimiaCryptus: class MnistTestBase, method validate.
/**
 * Validates the trained network against the full MNIST validation set,
 * reporting overall accuracy and a sample of up to ten mispredicted digits.
 *
 * @param log     the notebook output to write results to
 * @param network the trained network to evaluate
 */
public void validate(@Nonnull final NotebookOutput log, @Nonnull final Layer network) {
  log.h1("Validation");
  log.p("If we apply our model against the entire validation dataset, we get this accuracy:");
  log.code(() -> {
    // Percentage of validation images whose top prediction matches the label.
    return MNIST.validationDataStream()
        .mapToDouble(labeledObject -> predict(network, labeledObject)[0] == parse(labeledObject.label) ? 1 : 0)
        .average().getAsDouble() * 100;
  });
  log.p("Let's examine some incorrectly predicted results in more detail:");
  log.code(() -> {
    @Nonnull final TableOutput table = new TableOutput();
    MNIST.validationDataStream().map(labeledObject -> {
      final int actualCategory = parse(labeledObject.label);
      @Nullable final double[] predictionSignal = network.eval(labeledObject.data).getData().get(0).getData();
      // Categories ordered by descending predicted score. boxed()/Integer::intValue
      // replaces the redundant mapToObj(x -> x)/mapToInt(x -> x) round-trips, and
      // comparingDouble avoids boxing the negated score.
      final int[] predictionList = IntStream.range(0, 10).boxed()
          .sorted(Comparator.comparingDouble((Integer i) -> -predictionSignal[i]))
          .mapToInt(Integer::intValue).toArray();
      // We will only examine mispredicted rows
      if (predictionList[0] == actualCategory) return null;
      @Nonnull final LinkedHashMap<CharSequence, Object> row = new LinkedHashMap<>();
      row.put("Image", log.image(labeledObject.data.toGrayImage(), labeledObject.label));
      // Top-3 guesses with confidence; joining avoids the bare Optional.get()
      // on reduce (the stream is never empty here, but joining is the idiom).
      row.put("Prediction", Arrays.stream(predictionList).limit(3)
          .mapToObj(i -> String.format("%d (%.1f%%)", i, 100.0 * predictionSignal[i]))
          .collect(Collectors.joining(", ")));
      return row;
    }).filter(x -> null != x).limit(10).forEach(table::putRow);
    return table;
  });
}
Example use of com.simiacryptus.mindseye.lang.Layer in project MindsEye by SimiaCryptus: class MnistTestBase, method report.
/**
 * Writes the end-of-run report: a convergence plot of the training history,
 * the serialized model as a downloadable file, and the collected metrics as JSON.
 *
 * @param log            the notebook output to write results to
 * @param monitoringRoot the monitored object whose metrics are dumped
 * @param history        the recorded training steps (may be empty)
 * @param network        the trained network to serialize
 */
public void report(@Nonnull final NotebookOutput log, @Nonnull final MonitoredObject monitoringRoot, @Nonnull final List<Step> history, @Nonnull final Layer network) {
  // Convergence plot — skipped entirely when no training steps were recorded.
  if (!history.isEmpty()) {
    log.code(() -> {
      final double[][] points = history.stream()
          .map(step -> new double[] { step.iteration, Math.log10(step.point.getMean()) })
          .toArray(i -> new double[i][]);
      @Nonnull final PlotCanvas plot = ScatterPlot.plot(points);
      plot.setTitle("Convergence Plot");
      plot.setAxisLabels("Iteration", "log10(Fitness)");
      plot.setSize(600, 400);
      return plot;
    });
  }
  // Serialize the model under a unique, monotonically numbered file name.
  @Nonnull final String modelName = "model" + modelNo++ + ".json";
  log.p("Saved model as " + log.file(network.getJson().toString(), modelName, modelName));
  log.h1("Metrics");
  log.code(() -> {
    try {
      @Nonnull final ByteArrayOutputStream buffer = new ByteArrayOutputStream();
      JsonUtil.writeJson(buffer, monitoringRoot.getMetrics());
      return buffer.toString();
    } catch (@Nonnull final IOException e) {
      throw new RuntimeException(e);
    }
  });
}
Example use of com.simiacryptus.mindseye.lang.Layer in project MindsEye by SimiaCryptus: class LBFGS, method orient.
/**
 * Produces a line-search cursor for the given measurement, using an L-BFGS
 * direction when the accumulated history yields one, and falling back to
 * plain gradient descent (negated gradient) otherwise. Also records the
 * measurement into the history and trims the history to its size bound.
 *
 * @param log-free; mutates this.history (adds the measurement, trims old entries)
 */
@Override
public SimpleLineSearchCursor orient(final Trainable subject, @Nonnull final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
// Disabled input-verification check, retained for reference:
// if (getClass().desiredAssertionStatus()) {
// double verify = subject.measure(monitor).getMean();
// double input = measurement.getMean();
// boolean isDifferent = Math.abs(verify - input) > 1e-2;
// if (isDifferent) throw new AssertionError(String.format("Invalid input point: %s != %s", verify, input));
// monitor.log(String.format("Verified input point: %s == %s", verify, input));
// }
addToHistory(measurement, monitor);
// Immutable snapshot of the history taken BEFORE any trimming below.
@Nonnull final List<PointSample> history = Arrays.asList(this.history.toArray(new PointSample[] {}));
@Nullable final DeltaSet<Layer> result = lbfgs(measurement, monitor, history);
SimpleLineSearchCursor returnValue;
if (null == result) {
// L-BFGS could not produce a direction: fall back to steepest descent.
@Nonnull DeltaSet<Layer> scale = measurement.delta.scale(-1);
returnValue = cursor(subject, measurement, "GD", scale);
scale.freeRef();
} else {
returnValue = cursor(subject, measurement, "LBFGS", result);
// Only the reference is released; the null-check below is still safe.
result.freeRef();
}
// On L-BFGS failure, trim aggressively down to minHistory; otherwise keep up
// to maxHistory entries. Oldest entries are evicted first.
while (this.history.size() > (null == result ? minHistory : maxHistory)) {
@Nullable final PointSample remove = this.history.pollFirst();
if (verbose) {
// NOTE(review): message says "to history" but this is a removal — likely
// meant "from history"; also logs the pre-trim snapshot size, not the
// live this.history size. Confirm intent before changing.
monitor.log(String.format("Removed measurement %s to history. Total: %s", Long.toHexString(System.identityHashCode(remove)), history.size()));
}
remove.freeRef();
}
return returnValue;
}
Aggregations