Usage of com.simiacryptus.mindseye.test.StepRecord in the MindsEye project by SimiaCryptus.
Class TrainingTester, method train.
/**
 * Trains {@code layer} against a mean-squared-error objective and writes a diagnostic report.
 *
 * <p>A {@link PipelineNetwork} with {@code data[0].length} inputs is built. All inputs except
 * the last are fed into {@code layer`, and the layer's output is compared with the last input
 * through a {@link MeanSqLossLayer}. That network is wrapped in an {@link ArrayTrainable} over
 * {@code data} and handed to {@code opt}.
 *
 * <p>If the lowest fitness in the returned history is above {@code 1e-5}, the report contains:
 * <ul>
 *   <li>the network state, unless the network is frozen;</li>
 *   <li>the first tensor of {@code data}, when a mask was supplied;</li>
 *   <li>the layer's output for {@code pop(data)}.</li>
 * </ul>
 * Here {@code pop} is defined elsewhere in this class; presumably it selects sample rows (verify).
 * Otherwise the report records "Training Converged".
 *
 * <p>Ownership: before returning, normally or exceptionally, this method releases one
 * reference to {@code layer} and one reference to every tensor in {@code data}.
 *
 * @param log   notebook that receives the report
 * @param opt   optimization routine; receives {@code log} and the trainable, returns its history
 * @param layer layer under test (one reference is consumed)
 * @param data  training rows; in each row the last tensor is the value the layer output is
 *              compared against (one reference per tensor is consumed)
 * @param mask  optional mask forwarded to {@code ArrayTrainable.setMask}; ignored when empty
 * @return the step history produced by {@code opt}
 */
private List<StepRecord> train(@Nonnull NotebookOutput log, @Nonnull BiFunction<NotebookOutput, Trainable, List<StepRecord>> opt, @Nonnull Layer layer, @Nonnull Tensor[][] data, @Nonnull boolean... mask) {
  try {
    final int inputs = data[0].length;
    @Nonnull final PipelineNetwork network = new PipelineNetwork(inputs);
    ArrayTrainable trainable = null;
    // Start the try right after the network exists, so it is released even if wiring the
    // graph or constructing the trainable throws.
    try {
      // Inputs [0, inputs - 1) feed the layer; input (inputs - 1) is compared against the
      // layer's output by the mean-squared loss.
      network.wrap(new MeanSqLossLayer(),
          network.add(layer, IntStream.range(0, inputs - 1).mapToObj(network::getInput).toArray(DAGNode[]::new)),
          network.getInput(inputs - 1));
      trainable = new ArrayTrainable(data, network);
      if (0 < mask.length) {
        trainable.setMask(mask);
      }
      final List<StepRecord> history = opt.apply(log, trainable);
      // If no step reached this fitness, dump diagnostics instead of declaring convergence.
      final double convergenceThreshold = 1e-5;
      if (history.stream().mapToDouble(x -> x.fitness).min().orElse(1) > convergenceThreshold) {
        if (!network.isFrozen()) {
          log.p("This training apply resulted in the following configuration:");
          log.code(() -> {
            return network.state().stream().map(Arrays::toString).reduce((a, b) -> a + "\n" + b).orElse("");
          });
        }
        if (0 < mask.length) {
          log.p("And regressed input:");
          log.code(() -> {
            return Arrays.stream(data).flatMap(x -> Arrays.stream(x)).limit(1).map(x -> x.prettyPrint()).reduce((a, b) -> a + "\n" + b).orElse("");
          });
        }
        log.p("To produce the following output:");
        log.code(() -> {
          final Result[] array = ConstantResult.batchResultArray(pop(data));
          @Nullable final Result eval;
          try {
            eval = layer.eval(array);
          } finally {
            // Release the inputs even if eval fails. Each Result's data is released before the
            // Result itself, so no Result is read after its own reference has been dropped.
            for (@Nonnull final Result result : array) {
              result.getData().freeRef();
              result.freeRef();
            }
          }
          final TensorList tensorList = eval.getData();
          eval.freeRef();
          try {
            return tensorList.stream().limit(1).map(x -> {
              final String s = x.prettyPrint();
              x.freeRef();
              return s;
            }).reduce((a, b) -> a + "\n" + b).orElse("");
          } finally {
            tensorList.freeRef();
          }
        });
      } else {
        log.p("Training Converged");
      }
      return history;
    } finally {
      if (null != trainable) {
        trainable.freeRef();
      }
      network.freeRef();
    }
  } finally {
    // This method owns one reference to the layer and to each input tensor; release them.
    layer.freeRef();
    for (@Nonnull final Tensor[] tensors : data) {
      for (@Nonnull final Tensor tensor : tensors) {
        tensor.freeRef();
      }
    }
  }
}
Aggregations