Usage of org.deeplearning4j.nn.layers.variational.VariationalAutoencoder in the deeplearning4j project (by deeplearning4j).
Class BaseVaeScoreWithKeyFunctionAdapter, method call:
/**
 * Scores every keyed example in a partition with the VAE layer returned by getVaeLayer().
 * Entries are grouped into minibatches of up to batchSize examples, and each batch is scored
 * with one computeScore call. Each entry must hold exactly one example (one row), so that
 * every score can be paired with the key it came from.
 *
 * @param iterator keyed single-example feature arrays for this partition
 * @return (key, score) pairs in the same order as the input entries
 * @throws IllegalStateException if an entry's feature array does not have exactly one row
 */
@Override
public Iterable<Tuple2<K, Double>> call(Iterator<Tuple2<K, INDArray>> iterator) throws Exception {
    // An empty partition yields no scores; return before building the VAE layer.
    if (!iterator.hasNext()) {
        return Collections.emptyList();
    }

    VariationalAutoencoder vaeLayer = getVaeLayer();
    List<Tuple2<K, Double>> keyedScores = new ArrayList<>();
    List<INDArray> batchFeatures = new ArrayList<>(batchSize);
    List<K> batchKeys = new ArrayList<>(batchSize);
    int examplesScored = 0;

    while (iterator.hasNext()) {
        // The per-batch buffers are reused for every minibatch.
        batchFeatures.clear();
        batchKeys.clear();

        // Gather up to batchSize single-row entries, keeping keys in the same order as rows.
        int batchCount = 0;
        while (iterator.hasNext() && batchCount < batchSize) {
            Tuple2<K, INDArray> entry = iterator.next();
            INDArray row = entry._2();
            int rows = row.size(0);
            if (rows != 1) {
                throw new IllegalStateException("Cannot score examples with one key per data set if " + "data set contains more than 1 example (numExamples: " + rows + ")");
            }
            batchFeatures.add(row);
            batchKeys.add(entry._1());
            batchCount += rows;
        }
        examplesScored += batchCount;

        // Stack the rows into one minibatch and score them in a single pass.
        INDArray minibatch = Nd4j.vstack(batchFeatures);
        double[] batchScores = computeScore(vaeLayer, minibatch).data().asDouble();

        // Pair each score with the key of the example at the same position.
        int pos = 0;
        for (double score : batchScores) {
            keyedScores.add(new Tuple2<>(batchKeys.get(pos++), score));
        }
    }

    // When a GridExecutioner is active, flush its queued operations (blocking) before returning.
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples ", examplesScored);
    }
    return keyedScores;
}
Usage of org.deeplearning4j.nn.layers.variational.VariationalAutoencoder in the deeplearning4j project (by deeplearning4j).
Class CGVaeReconstructionErrorWithKeyFunction, method getVaeLayer:
/**
 * Rebuilds the ComputationGraph from the broadcast JSON configuration and loads a duplicate of
 * the broadcast parameters into it. Returns its layer 0, which must be a VariationalAutoencoder.
 *
 * @return layer 0 of the reconstructed network, as a VariationalAutoencoder
 * @throws IllegalStateException if the broadcast parameter count does not match the network's
 *     parameter count, or if layer 0 is not a VariationalAutoencoder
 */
@Override
public VariationalAutoencoder getVaeLayer() {
    ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson((String) jsonConfig.getValue()));
    network.init();
    // Work on a duplicate of the broadcast array; presumably this keeps the shared value unmodified.
    INDArray val = ((INDArray) params.value()).unsafeDuplication();
    long expected = network.numParams(false);
    long actual = val.length();
    if (actual != expected) {
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters" + " (network: " + expected + ", broadcast: " + actual + ")");
    }
    network.setParams(val);
    Layer l = network.getLayer(0);
    if (!(l instanceof VariationalAutoencoder)) {
        // Guard against a null layer so the message is built without an NPE.
        throw new IllegalStateException("Cannot use CGVaeReconstructionErrorWithKeyFunction on network that doesn't have a VAE " + "layer as layer 0. Layer type: " + (l == null ? "null" : l.getClass().getName()));
    }
    return (VariationalAutoencoder) l;
}
Usage of org.deeplearning4j.nn.layers.variational.VariationalAutoencoder in the deeplearning4j project (by deeplearning4j).
Class CGVaeReconstructionProbWithKeyFunction, method getVaeLayer:
/**
 * Rebuilds the ComputationGraph from the broadcast JSON configuration and loads a duplicate of
 * the broadcast parameters into it. Returns its layer 0, which must be a VariationalAutoencoder.
 *
 * @return layer 0 of the reconstructed network, as a VariationalAutoencoder
 * @throws IllegalStateException if the broadcast parameter count does not match the network's
 *     parameter count, or if layer 0 is not a VariationalAutoencoder
 */
@Override
public VariationalAutoencoder getVaeLayer() {
    ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson((String) jsonConfig.getValue()));
    network.init();
    // Work on a duplicate of the broadcast array; presumably this keeps the shared value unmodified.
    INDArray val = ((INDArray) params.value()).unsafeDuplication();
    long expected = network.numParams(false);
    long actual = val.length();
    if (actual != expected) {
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters" + " (network: " + expected + ", broadcast: " + actual + ")");
    }
    network.setParams(val);
    Layer l = network.getLayer(0);
    if (!(l instanceof VariationalAutoencoder)) {
        // Guard against a null layer so the message is built without an NPE.
        throw new IllegalStateException("Cannot use CGVaeReconstructionProbWithKeyFunction on network that doesn't have a VAE " + "layer as layer 0. Layer type: " + (l == null ? "null" : l.getClass().getName()));
    }
    return (VariationalAutoencoder) l;
}
Usage of org.deeplearning4j.nn.layers.variational.VariationalAutoencoder in the deeplearning4j project (by deeplearning4j).
Class VaeReconstructionErrorWithKeyFunctionAdapter, method getVaeLayer:
/**
 * Rebuilds the MultiLayerNetwork from the broadcast JSON configuration and loads a duplicate of
 * the broadcast parameters into it. Returns its layer 0, which must be a VariationalAutoencoder.
 *
 * @return layer 0 of the reconstructed network, as a VariationalAutoencoder
 * @throws IllegalStateException if the broadcast parameter count does not match the network's
 *     parameter count, or if layer 0 is not a VariationalAutoencoder
 */
@Override
public VariationalAutoencoder getVaeLayer() {
    MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String) jsonConfig.getValue()));
    network.init();
    // Work on a duplicate of the broadcast array; presumably this keeps the shared value unmodified.
    INDArray val = ((INDArray) params.value()).unsafeDuplication();
    long expected = network.numParams(false);
    long actual = val.length();
    if (actual != expected) {
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters" + " (network: " + expected + ", broadcast: " + actual + ")");
    }
    network.setParameters(val);
    Layer l = network.getLayer(0);
    if (!(l instanceof VariationalAutoencoder)) {
        // Guard against a null layer so the message is built without an NPE.
        throw new IllegalStateException("Cannot use VaeReconstructionErrorWithKeyFunction on network that doesn't have a VAE " + "layer as layer 0. Layer type: " + (l == null ? "null" : l.getClass().getName()));
    }
    return (VariationalAutoencoder) l;
}
Usage of org.deeplearning4j.nn.layers.variational.VariationalAutoencoder in the deeplearning4j project (by deeplearning4j).
Class VaeReconstructionProbWithKeyFunctionAdapter, method getVaeLayer:
/**
 * Rebuilds the MultiLayerNetwork from the broadcast JSON configuration and loads a duplicate of
 * the broadcast parameters into it. Returns its layer 0, which must be a VariationalAutoencoder.
 *
 * @return layer 0 of the reconstructed network, as a VariationalAutoencoder
 * @throws IllegalStateException if the broadcast parameter count does not match the network's
 *     parameter count, or if layer 0 is not a VariationalAutoencoder
 */
@Override
public VariationalAutoencoder getVaeLayer() {
    MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String) jsonConfig.getValue()));
    network.init();
    // Work on a duplicate of the broadcast array; presumably this keeps the shared value unmodified.
    INDArray val = ((INDArray) params.value()).unsafeDuplication();
    long expected = network.numParams(false);
    long actual = val.length();
    if (actual != expected) {
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters" + " (network: " + expected + ", broadcast: " + actual + ")");
    }
    network.setParameters(val);
    Layer l = network.getLayer(0);
    if (!(l instanceof VariationalAutoencoder)) {
        // Guard against a null layer so the message is built without an NPE.
        throw new IllegalStateException("Cannot use VaeReconstructionProbWithKeyFunction on network that doesn't have a VAE " + "layer as layer 0. Layer type: " + (l == null ? "null" : l.getClass().getName()));
    }
    return (VariationalAutoencoder) l;
}
Aggregations