Use of com.simiacryptus.mindseye.network.DAGNode in project MindsEye by SimiaCryptus.
The class StyleTransfer, method getContentComponents.
/**
 * Builds the weighted content-loss components of the fitness network.
 *
 * @param setup   the neural setup holding the content/style configuration and targets
 * @param nodeMap the map from layer type to its DAGNode in the pipeline
 * @return the list of (weight, loss node) content components
 */
@Nonnull
public ArrayList<Tuple2<Double, DAGNode>> getContentComponents(NeuralSetup<T> setup, final Map<T, DAGNode> nodeMap) {
  ArrayList<Tuple2<Double, DAGNode>> contentComponents = new ArrayList<>();
  for (final T layerType : getLayerTypes()) {
    final DAGNode node = nodeMap.get(layerType);
    final double coeff_content = !setup.style.content.params.containsKey(layerType) ? 0 : setup.style.content.params.get(layerType);
    final PipelineNetwork network1 = (PipelineNetwork) node.getNetwork();
    if (coeff_content != 0) {
      Tensor content = setup.contentTarget.content.get(layerType);
      // Mean-squared error against the fixed content target, normalized by the target's RMS.
      contentComponents.add(new Tuple2<>(coeff_content,
        network1.wrap(new MeanSqLossLayer().setAlpha(1.0 / content.rms()),
          node,
          network1.wrap(new ValueLayer(content), new DAGNode[] {}))));
    }
  }
  return contentComponents;
}
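For context, a caller might fold these weighted loss nodes into a single fitness value. A minimal sketch, not from the project, assuming a StyleTransfer instance named styleTransfer and that Tuple2 exposes its elements as the public fields _1 and _2:

// Hypothetical consumer of the (weight, loss node) tuples returned above.
ArrayList<Tuple2<Double, DAGNode>> components = styleTransfer.getContentComponents(setup, nodeMap);
for (Tuple2<Double, DAGNode> component : components) {
  double weight = component._1;    // coeff_content for this layer type
  DAGNode lossNode = component._2; // RMS-normalized MSE vs. the fixed content target
  // e.g. scale each node's output by its weight and sum into one fitness node
}

Each node already embeds its ValueLayer target, so only the weighting and summation remain for the caller.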
Use of com.simiacryptus.mindseye.network.DAGNode in project MindsEye by SimiaCryptus.
The class TrainingTester, method train.
private List<StepRecord> train(@Nonnull NotebookOutput log, @Nonnull BiFunction<NotebookOutput, Trainable, List<StepRecord>> opt, @Nonnull Layer layer, @Nonnull Tensor[][] data, @Nonnull boolean... mask) {
  try {
    int inputs = data[0].length;
    // Wire the layer under test into an MSE loss: the first (inputs - 1) tensors of each
    // data row feed the layer, and the last tensor is the regression target.
    @Nonnull final PipelineNetwork network = new PipelineNetwork(inputs);
    network.wrap(new MeanSqLossLayer(),
      network.add(layer, IntStream.range(0, inputs - 1).mapToObj(i -> network.getInput(i)).toArray(i -> new DAGNode[i])),
      network.getInput(inputs - 1));
    @Nonnull ArrayTrainable trainable = new ArrayTrainable(data, network);
    if (0 < mask.length)
      trainable.setMask(mask);
    List<StepRecord> history;
    try {
      history = opt.apply(log, trainable);
      if (history.stream().mapToDouble(x -> x.fitness).min().orElse(1) > 1e-5) {
        // Training did not converge below the fitness threshold; log the final state.
        if (!network.isFrozen()) {
          log.p("This training run resulted in the following configuration:");
          log.code(() -> {
            return network.state().stream().map(Arrays::toString).reduce((a, b) -> a + "\n" + b).orElse("");
          });
        }
        if (0 < mask.length) {
          log.p("And regressed input:");
          log.code(() -> {
            return Arrays.stream(data).flatMap(x -> Arrays.stream(x)).limit(1).map(x -> x.prettyPrint()).reduce((a, b) -> a + "\n" + b).orElse("");
          });
        }
        log.p("To produce the following output:");
        log.code(() -> {
          Result[] array = ConstantResult.batchResultArray(pop(data));
          @Nullable Result eval = layer.eval(array);
          for (@Nonnull Result result : array) {
            result.freeRef();
            result.getData().freeRef();
          }
          TensorList tensorList = eval.getData();
          eval.freeRef();
          String str = tensorList.stream().limit(1).map(x -> {
            String s = x.prettyPrint();
            x.freeRef();
            return s;
          }).reduce((a, b) -> a + "\n" + b).orElse("");
          tensorList.freeRef();
          return str;
        });
      } else {
        log.p("Training Converged");
      }
    } finally {
      // Release reference-counted resources regardless of outcome.
      trainable.freeRef();
      network.freeRef();
    }
    return history;
  } finally {
    layer.freeRef();
    for (@Nonnull Tensor[] tensors : data) {
      for (@Nonnull Tensor tensor : tensors) {
        tensor.freeRef();
      }
    }
  }
}
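To illustrate the calling convention: each row of data supplies the layer's inputs followed by the expected output, and the varargs mask flags which inputs the optimizer may regress. A hypothetical invocation sketch, where runOptimizer, inputA/inputB, and expectedA/expectedB are illustrative names, not from the project:

// Two rows, each pairing one layer input with its expected output.
Tensor[][] data = {
  { inputA, expectedA },
  { inputB, expectedB }
};
List<StepRecord> history = train(log,
  (notebook, trainable) -> runOptimizer(notebook, trainable), // the optimization strategy under test
  layer, data,
  true, false); // regress the first input; hold the target column fixed

Passing no mask skips setMask entirely, so all inputs are treated as fixed data and only the layer's own parameters are trained.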