Search in sources:

Example 1 with PathUpdate

Use of org.deeplearning4j.ui.activation.PathUpdate in the project deeplearning4j by deeplearning4j.

From the class RemoteConvolutionalIterationListener, method iterationDone:

/**
 * Event listener for each iteration.
 *
 * <p>Every {@code freq} iterations, this method performs the following steps:
 * <ol>
 *   <li>Collects the activations of every convolutional layer of the network for one randomly
 *       chosen example of the current minibatch.</li>
 *   <li>Restores the matching input image from the first convolutional layer.</li>
 *   <li>Renders both into a temporary PNG file.</li>
 *   <li>Queues a {@link PathUpdate} that points at that file for the remote {@code target}.</li>
 * </ol>
 *
 * @param model     the model iterating; it is cast to {@link MultiLayerNetwork}
 * @param iteration the iteration number
 * @throws RuntimeException wrapping any failure while restoring the source image, writing the
 *                          PNG file, or queuing the report
 */
@Override
public void iterationDone(Model model, int iteration) {
    if (iteration % freq == 0) {
        List<INDArray> tensors = new ArrayList<>();
        int cnt = 0;
        Random rnd = new Random();
        MultiLayerNetwork l = (MultiLayerNetwork) model;
        BufferedImage sourceImage = null;
        // Index of the minibatch example to visualise. It is chosen once, at the first
        // convolutional layer, so that every layer renders the same example.
        int sampleDim = -1;
        for (Layer layer : l.getLayers()) {
            if (layer.type() == Layer.Type.CONVOLUTIONAL) {
                INDArray output = layer.activate();
                if (sampleDim < 0) {
                    // Uniform over the whole minibatch [0, batchSize). This includes example 0
                    // and also works for a minibatch of size 1.
                    sampleDim = rnd.nextInt(output.shape()[0]);
                }
                if (cnt == 0) {
                    // The input of the first convolutional layer is shown as the source image.
                    INDArray inputs = ((ConvolutionLayer) layer).input();
                    try {
                        sourceImage = restoreRGBImage(inputs.tensorAlongDimension(sampleDim, 3, 2, 1));
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                }
                INDArray tad = output.tensorAlongDimension(sampleDim, 3, 2, 1);
                tensors.add(tad);
                cnt++;
            }
        }
        // NOTE(review): if the network has no convolutional layers, tensors is empty and
        // sourceImage is null here; confirm that rasterizeConvoLayers tolerates this.
        BufferedImage render = rasterizeConvoLayers(tensors, sourceImage);
        try {
            File tempFile = File.createTempFile("cnn_activations", ".png");
            tempFile.deleteOnExit();
            ImageIO.write(render, "png", tempFile);
            PathUpdate update = new PathUpdate();
            // Ensure the path is set.
            update.setPath(tempFile.getPath());
            // Hand the update to the reporter so the server is hooked up with the new path.
            WebReporter.getInstance().queueReport(target, Entity.entity(update, MediaType.APPLICATION_JSON));
            firstIteration = false;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        minibatchNum++;
    }
}
Also used: ArrayList (java.util.ArrayList), Layer (org.deeplearning4j.nn.api.Layer), ConvolutionLayer (org.deeplearning4j.nn.layers.convolution.ConvolutionLayer), BufferedImage (java.awt.image.BufferedImage), IOException (java.io.IOException), INDArray (org.nd4j.linalg.api.ndarray.INDArray), Random (java.util.Random), PathUpdate (org.deeplearning4j.ui.activation.PathUpdate), MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork), File (java.io.File)

Aggregations

BufferedImage (java.awt.image.BufferedImage): 1, File (java.io.File): 1, IOException (java.io.IOException): 1, ArrayList (java.util.ArrayList): 1, Random (java.util.Random): 1, Layer (org.deeplearning4j.nn.api.Layer): 1, ConvolutionLayer (org.deeplearning4j.nn.layers.convolution.ConvolutionLayer): 1, MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 1, PathUpdate (org.deeplearning4j.ui.activation.PathUpdate): 1, INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1