Use of com.simiacryptus.mindseye.layers.cudnn.GateBiasLayer in project MindsEye by SimiaCryptus.
Class StyleTransfer, method getStyleComponents.
/**
 * Builds the weighted style-loss terms for one layer output and wires them into {@code network}.
 *
 * <p>Up to two components are produced, each paired with its weight from {@code styleParams}:
 * <ul>
 *   <li>a covariance term (when {@code styleParams.cov != 0}): a {@code MeanSqLossLayer} between the
 *       target {@code covariance} and the tiled-average {@code GramianLayer} of the (optionally
 *       re-centered) {@code node}, scaled by {@code 1 / covariance.rms()};</li>
 *   <li>a mean term (when {@code styleParams.mean != 0}): a {@code MeanSqLossLayer} between the negated
 *       band average of {@code node} and the negated target {@code mean}, scaled by
 *       {@code 1 / mean.rms()}.</li>
 * </ul>
 * If an RMS value is exactly zero the corresponding scale falls back to {@code 1}.
 *
 * @param node          the layer output whose statistics are compared against the targets
 * @param network       the network that receives every node created here
 * @param styleParams   the per-layer weights; {@code null} or all-zero weights yield an empty result
 * @param mean          the target mean; dereferenced whenever any weight is non-zero
 * @param covariance    the target covariance; dereferenced only when the covariance weight is non-zero
 * @param centeringMode how {@code node} is re-centered before the Gramian is taken: {@code Origin}
 *                      uses it unchanged, {@code Dynamic} combines it with its own negated band
 *                      average, {@code Static} combines it with the negated target mean
 * @return a list of (weight, loss node) pairs, covariance term first; never {@code null}
 * @throws IllegalArgumentException if the covariance weight is non-zero and {@code centeringMode}
 *                                  is not one of the handled modes
 */
@Nonnull
public ArrayList<Tuple2<Double, DAGNode>> getStyleComponents(final DAGNode node, final PipelineNetwork network, final LayerStyleParams styleParams, final Tensor mean, final Tensor covariance, final CenteringMode centeringMode) {
  ArrayList<Tuple2<Double, DAGNode>> styleComponents = new ArrayList<>();
  if (null != styleParams && (styleParams.cov != 0 || styleParams.mean != 0)) {
    // Normalize the mean loss by the magnitude of the target; guard against division by zero.
    double meanRms = mean.rms();
    double meanScale = 0 == meanRms ? 1 : (1.0 / meanRms);
    // Both helper nodes are created up front, in this order, regardless of which terms are enabled.
    // NOTE(review): network.wrap presumably registers the node in the network, so the call order is
    // kept exactly as before - confirm before reordering or making these lazy.
    InnerNode negTarget = network.wrap(new ValueLayer(mean.scale(-1)), new DAGNode[] {});
    InnerNode negAvg = network.wrap(new BandAvgReducerLayer().setAlpha(-1), node);
    if (styleParams.cov != 0) {
      DAGNode recentered;
      switch (centeringMode) {
        case Origin:
          // No re-centering: the Gramian is taken of the raw signal.
          recentered = node;
          break;
        case Dynamic:
          // Re-center using the signal's own (negated) band average.
          recentered = network.wrap(new GateBiasLayer(), node, negAvg);
          break;
        case Static:
          // Re-center using the (negated) target mean.
          recentered = network.wrap(new GateBiasLayer(), node, negTarget);
          break;
        default:
          // Name the offending mode so the failure is diagnosable from the stack trace alone.
          throw new IllegalArgumentException("Unsupported centering mode: " + centeringMode);
      }
      // Sanity checks (active only with -ea): the covariance band count must be a positive
      // multiple-compatible size relative to the mean's band count.
      int[] covDim = covariance.getDimensions();
      assert 0 < covDim[2] : Arrays.toString(covDim);
      int inputBands = mean.getDimensions()[2];
      assert 0 < inputBands : Arrays.toString(mean.getDimensions());
      int outputBands = covDim[2] / inputBands;
      assert 0 < outputBands : Arrays.toString(covDim) + " / " + inputBands;
      // Normalize the covariance loss by the magnitude of the target; guard against division by zero.
      double covRms = covariance.rms();
      double covScale = 0 == covRms ? 1 : (1.0 / covRms);
      styleComponents.add(new Tuple2<>(styleParams.cov, network.wrap(new MeanSqLossLayer().setAlpha(covScale), network.wrap(new ValueLayer(covariance), new DAGNode[] {}), network.wrap(ArtistryUtil.wrapTilesAvg(new GramianLayer()), recentered))));
    }
    if (styleParams.mean != 0) {
      // Both operands are negated, so their squared difference equals that of the un-negated values.
      styleComponents.add(new Tuple2<>(styleParams.mean, network.wrap(new MeanSqLossLayer().setAlpha(meanScale), negAvg, negTarget)));
    }
  }
  return styleComponents;
}
Aggregations