Example 1 with BasePretrainNetwork

Use of org.deeplearning4j.nn.layers.BasePretrainNetwork in project deeplearning4j by deeplearning4j.

From the class LoadAndDraw, method main.

/**
 * Loads a serialized network and shows each MNIST digit next to its reconstruction.
 *
 * @param args args[0] is the path to the serialized network
 */
public static void main(String[] args) throws Exception {
    MnistDataSetIterator iter = new MnistDataSetIterator(60, 60000);
    // Deserialize the pretrained network from the file named by the first argument.
    ObjectInputStream ois = new ObjectInputStream(new FileInputStream(args[0]));
    BasePretrainNetwork network = (BasePretrainNetwork) ois.readObject();
    try {
        ois.close();
    } catch (IOException e) {
        // Ignored: the network is already loaded, so a failed close is harmless here.
    }
    while (iter.hasNext()) {
        DataSet test = iter.next();
        // One row of reconstructed activations per example in the batch.
        INDArray reconstructed = network.activate(test.getFeatureMatrix());
        for (int i = 0; i < test.numExamples(); i++) {
            // Scale the original pixel values by 255 for display.
            INDArray draw1 = test.get(i).getFeatureMatrix().mul(255);
            INDArray reconstructed2 = reconstructed.getRow(i);
            // Sample a binary image from the reconstruction and scale it by 255.
            INDArray draw2 = Nd4j.getDistributions().createBinomial(1, reconstructed2).sample(reconstructed2.shape()).mul(255);
            DrawReconstruction d = new DrawReconstruction(draw1);
            d.title = "REAL";
            d.draw();
            DrawReconstruction d2 = new DrawReconstruction(draw2, 100, 100);
            d2.title = "TEST";
            d2.draw();
            // Keep both windows open for 10 seconds, then close them.
            Thread.sleep(10000);
            d.frame.dispose();
            d2.frame.dispose();
        }
    }
}
Also used: MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator), BasePretrainNetwork (org.deeplearning4j.nn.layers.BasePretrainNetwork), INDArray (org.nd4j.linalg.api.ndarray.INDArray), DataSet (org.nd4j.linalg.dataset.DataSet), IOException (java.io.IOException), FileInputStream (java.io.FileInputStream), ObjectInputStream (java.io.ObjectInputStream)
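
The draw2 line in the example appears to treat each reconstructed activation as a probability and draw a single-trial binomial (Bernoulli) sample per pixel before scaling by 255. The following is a minimal plain-Java sketch of that idea, without ND4J. The class name BernoulliPixels, the method sample, and the input values are hypothetical and chosen only for illustration; this is not the library's implementation.

```java
import java.util.Arrays;
import java.util.Random;

class BernoulliPixels {

    /**
     * Draws one Bernoulli sample per entry and scales it to 0 or 255.
     * Each entry of probabilities is assumed to lie in [0, 1].
     */
    static int[] sample(double[] probabilities, Random rng) {
        int[] pixels = new int[probabilities.length];
        for (int i = 0; i < probabilities.length; i++) {
            // nextDouble() is in [0, 1), so p = 0 never fires and p = 1 always fires.
            pixels[i] = rng.nextDouble() < probabilities[i] ? 255 : 0;
        }
        return pixels;
    }

    public static void main(String[] args) {
        // Made-up activations standing in for one row of the reconstruction.
        double[] reconstructed = {0.0, 0.2, 0.9, 1.0};
        int[] pixels = sample(reconstructed, new Random(42));
        System.out.println(Arrays.toString(pixels));
    }
}
```

Sampling gives a hard black-and-white image. Multiplying the probabilities by 255 directly would give a grayscale image instead.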

Aggregations

FileInputStream (java.io.FileInputStream): 1
IOException (java.io.IOException): 1
ObjectInputStream (java.io.ObjectInputStream): 1
MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator): 1
BasePretrainNetwork (org.deeplearning4j.nn.layers.BasePretrainNetwork): 1
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1
DataSet (org.nd4j.linalg.dataset.DataSet): 1
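
The example above loads its network with ObjectInputStream and closes the stream by hand. As a hedged alternative, the sketch below shows the same Java serialization round trip with try-with-resources, which closes the stream even when reading fails. The class LoadSerialized, its save and load helpers, and the double[] payload are hypothetical stand-ins, since a real BasePretrainNetwork needs the deeplearning4j classes on the classpath.

```java
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

class LoadSerialized {

    /** Writes any Serializable value to the given file. */
    static void save(Serializable value, File file) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(file))) {
            out.writeObject(value);
        }
    }

    /** Reads one object back; the stream is closed even if readObject throws. */
    static Object load(File file) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(file))) {
            return in.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        File tmp = File.createTempFile("model", ".ser");
        tmp.deleteOnExit();
        // A plain array stands in for the serialized network.
        save(new double[] {0.5, 0.25}, tmp);
        double[] restored = (double[]) load(tmp);
        System.out.println("restored " + restored.length + " values");
    }
}
```

Java deserialization should only be used on files from a trusted source, since readObject can instantiate arbitrary classes from the classpath.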