Search in sources:

Example 1 with VariationalAutoencoder

use of org.deeplearning4j.nn.layers.variational.VariationalAutoencoder in project deeplearning4j by deeplearning4j.

Method call of the class BaseVaeScoreWithKeyFunctionAdapter.

/**
 * Scores every (key, features) pair in a Spark partition with the VAE layer, in mini-batches of
 * at most {@code batchSize} examples.
 *
 * <p>Each input {@link INDArray} must contain exactly one example (size 1 along dimension 0), so
 * that one score can be associated with each key.
 *
 * @param iterator the partition's (key, features) pairs
 * @return one (key, score) tuple per input pair, in input order; an empty list for an empty partition
 * @throws IllegalStateException if an input array does not contain exactly one example, or if
 *     {@code computeScore} returns a different number of scores than examples in the batch
 */
@Override
public Iterable<Tuple2<K, Double>> call(Iterator<Tuple2<K, INDArray>> iterator) throws Exception {
    // Empty partition: return early, before the VAE layer is built via getVaeLayer().
    if (!iterator.hasNext()) {
        return Collections.emptyList();
    }
    VariationalAutoencoder vae = getVaeLayer();
    List<Tuple2<K, Double>> ret = new ArrayList<>();
    List<INDArray> collect = new ArrayList<>(batchSize);
    List<K> collectKey = new ArrayList<>(batchSize);
    int totalCount = 0;
    while (iterator.hasNext()) {
        collect.clear();
        collectKey.clear();
        int nExamples = 0;
        // Gather up to batchSize single-example arrays (and their keys) into one mini-batch.
        while (iterator.hasNext() && nExamples < batchSize) {
            Tuple2<K, INDArray> t2 = iterator.next();
            INDArray features = t2._2();
            int n = features.size(0);
            if (n != 1) {
                throw new IllegalStateException("Cannot score examples with one key per data set if " + "data set contains more than 1 example (numExamples: " + n + ")");
            }
            collect.add(features);
            collectKey.add(t2._1());
            nExamples += n;
        }
        totalCount += nExamples;
        INDArray toScore = Nd4j.vstack(collect);
        INDArray scores = computeScore(vae, toScore);
        int nKeys = collectKey.size();
        // Guard against a score array that does not line up with the keys, instead of failing
        // with an index error or silently dropping keys.
        if (scores.length() != nKeys) {
            throw new IllegalStateException("Expected " + nKeys + " scores but computeScore returned " + scores.length());
        }
        // Read logical elements via getDouble rather than data().asDouble(): the latter exposes the
        // whole backing buffer, which need not match the array's logical contents for views.
        for (int i = 0; i < nKeys; i++) {
            ret.add(new Tuple2<>(collectKey.get(i), scores.getDouble(i)));
        }
    }
    // Grid executioners may queue ops; flush so all work is completed before returning.
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }
    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples ", totalCount);
    }
    return ret;
}
Also used : VariationalAutoencoder(org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) ArrayList(java.util.ArrayList) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Tuple2(scala.Tuple2)

Example 2 with VariationalAutoencoder

use of org.deeplearning4j.nn.layers.variational.VariationalAutoencoder in project deeplearning4j by deeplearning4j.

Method getVaeLayer of the class CGVaeReconstructionErrorWithKeyFunction.

/**
 * Rebuilds the {@link ComputationGraph} from the broadcast JSON configuration, loads a copy of the
 * broadcast parameters into it, and returns its first layer as a VAE.
 *
 * @return layer 0 of the reconstructed graph
 * @throws IllegalStateException if the broadcast parameter count does not match the network's
 * @throws RuntimeException if layer 0 is not a {@link VariationalAutoencoder}
 */
@Override
public VariationalAutoencoder getVaeLayer() {
    ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson((String) jsonConfig.getValue()));
    network.init();
    // Use a duplicate of the broadcast parameters (via unsafeDuplication) rather than the broadcast array itself.
    INDArray val = ((INDArray) params.value()).unsafeDuplication();
    long expected = network.numParams(false);
    long actual = val.length();
    if (actual != expected) {
        // Include both counts so a config/parameter mismatch can be diagnosed from the message alone.
        throw new IllegalStateException("Network did not have same number of parameters as the broadcasted set parameters"
                + " (network: " + expected + ", broadcast: " + actual + ")");
    }
    network.setParams(val);
    Layer l = network.getLayer(0);
    if (!(l instanceof VariationalAutoencoder)) {
        throw new RuntimeException("Cannot use CGVaeReconstructionErrorWithKeyFunction on network that doesn't have a VAE " + "layer as layer 0. Layer type: " + l.getClass());
    }
    return (VariationalAutoencoder) l;
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) VariationalAutoencoder(org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) Layer(org.deeplearning4j.nn.api.Layer)

Example 3 with VariationalAutoencoder

use of org.deeplearning4j.nn.layers.variational.VariationalAutoencoder in project deeplearning4j by deeplearning4j.

Method getVaeLayer of the class CGVaeReconstructionProbWithKeyFunction.

/**
 * Rebuilds the {@link ComputationGraph} from the broadcast JSON configuration, loads a copy of the
 * broadcast parameters into it, and returns its first layer as a VAE.
 *
 * @return layer 0 of the reconstructed graph
 * @throws IllegalStateException if the broadcast parameter count does not match the network's
 * @throws RuntimeException if layer 0 is not a {@link VariationalAutoencoder}
 */
@Override
public VariationalAutoencoder getVaeLayer() {
    ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson((String) jsonConfig.getValue()));
    network.init();
    // Use a duplicate of the broadcast parameters (via unsafeDuplication) rather than the broadcast array itself.
    INDArray val = ((INDArray) params.value()).unsafeDuplication();
    long expected = network.numParams(false);
    long actual = val.length();
    if (actual != expected) {
        // Include both counts so a config/parameter mismatch can be diagnosed from the message alone.
        throw new IllegalStateException("Network did not have same number of parameters as the broadcasted set parameters"
                + " (network: " + expected + ", broadcast: " + actual + ")");
    }
    network.setParams(val);
    Layer l = network.getLayer(0);
    if (!(l instanceof VariationalAutoencoder)) {
        throw new RuntimeException("Cannot use CGVaeReconstructionProbWithKeyFunction on network that doesn't have a VAE " + "layer as layer 0. Layer type: " + l.getClass());
    }
    return (VariationalAutoencoder) l;
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) VariationalAutoencoder(org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) Layer(org.deeplearning4j.nn.api.Layer)

Example 4 with VariationalAutoencoder

use of org.deeplearning4j.nn.layers.variational.VariationalAutoencoder in project deeplearning4j by deeplearning4j.

Method getVaeLayer of the class VaeReconstructionErrorWithKeyFunctionAdapter.

/**
 * Rebuilds the {@link MultiLayerNetwork} from the broadcast JSON configuration, loads a copy of the
 * broadcast parameters into it, and returns its first layer as a VAE.
 *
 * @return layer 0 of the reconstructed network
 * @throws IllegalStateException if the broadcast parameter count does not match the network's
 * @throws RuntimeException if layer 0 is not a {@link VariationalAutoencoder}
 */
@Override
public VariationalAutoencoder getVaeLayer() {
    MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String) jsonConfig.getValue()));
    network.init();
    // Use a duplicate of the broadcast parameters (via unsafeDuplication) rather than the broadcast array itself.
    INDArray val = ((INDArray) params.value()).unsafeDuplication();
    long expected = network.numParams(false);
    long actual = val.length();
    if (actual != expected) {
        // Include both counts so a config/parameter mismatch can be diagnosed from the message alone.
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters"
                + " (network: " + expected + ", broadcast: " + actual + ")");
    }
    network.setParameters(val);
    Layer l = network.getLayer(0);
    if (!(l instanceof VariationalAutoencoder)) {
        throw new RuntimeException("Cannot use VaeReconstructionErrorWithKeyFunction on network that doesn't have a VAE " + "layer as layer 0. Layer type: " + l.getClass());
    }
    return (VariationalAutoencoder) l;
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) VariationalAutoencoder(org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) Layer(org.deeplearning4j.nn.api.Layer)

Example 5 with VariationalAutoencoder

use of org.deeplearning4j.nn.layers.variational.VariationalAutoencoder in project deeplearning4j by deeplearning4j.

Method getVaeLayer of the class VaeReconstructionProbWithKeyFunctionAdapter.

/**
 * Rebuilds the {@link MultiLayerNetwork} from the broadcast JSON configuration, loads a copy of the
 * broadcast parameters into it, and returns its first layer as a VAE.
 *
 * @return layer 0 of the reconstructed network
 * @throws IllegalStateException if the broadcast parameter count does not match the network's
 * @throws RuntimeException if layer 0 is not a {@link VariationalAutoencoder}
 */
@Override
public VariationalAutoencoder getVaeLayer() {
    MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String) jsonConfig.getValue()));
    network.init();
    // Use a duplicate of the broadcast parameters (via unsafeDuplication) rather than the broadcast array itself.
    INDArray val = ((INDArray) params.value()).unsafeDuplication();
    long expected = network.numParams(false);
    long actual = val.length();
    if (actual != expected) {
        // Include both counts so a config/parameter mismatch can be diagnosed from the message alone.
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters"
                + " (network: " + expected + ", broadcast: " + actual + ")");
    }
    network.setParameters(val);
    Layer l = network.getLayer(0);
    if (!(l instanceof VariationalAutoencoder)) {
        throw new RuntimeException("Cannot use VaeReconstructionProbWithKeyFunction on network that doesn't have a VAE " + "layer as layer 0. Layer type: " + l.getClass());
    }
    return (VariationalAutoencoder) l;
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) VariationalAutoencoder(org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) Layer(org.deeplearning4j.nn.api.Layer)

Aggregations

VariationalAutoencoder (org.deeplearning4j.nn.layers.variational.VariationalAutoencoder)6 INDArray (org.nd4j.linalg.api.ndarray.INDArray)6 Layer (org.deeplearning4j.nn.api.Layer)4 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)3 ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)2 Tuple2 (scala.Tuple2)2 ArrayList (java.util.ArrayList)1 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)1 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)1 LossFunctionWrapper (org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper)1 BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest)1 Test (org.junit.Test)1 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)1 LossMSE (org.nd4j.linalg.lossfunctions.impl.LossMSE)1