Usage of org.deeplearning4j.ui.activation.PathUpdate in the deeplearning4j project (by deeplearning4j):
the iterationDone method of class RemoteConvolutionalIterationListener.
/**
 * Event listener for each iteration.
 *
 * <p>Every {@code freq} iterations this renders the activations of every convolutional layer for
 * one randomly chosen example of the current minibatch into a PNG written to a temporary file. It
 * then queues a {@link PathUpdate} carrying that file's path via {@code WebReporter}.
 *
 * @param model the model iterating; must be a {@link MultiLayerNetwork} (the cast below throws
 *     {@link ClassCastException} otherwise)
 * @param iteration the iteration number
 * @throws RuntimeException wrapping any failure while restoring the source image, writing the PNG,
 *     or queueing the report
 */
@Override
public void iterationDone(Model model, int iteration) {
    if (iteration % freq == 0) {
        List<INDArray> tensors = new ArrayList<>();
        int cnt = 0;
        Random rnd = new Random();
        MultiLayerNetwork l = (MultiLayerNetwork) model;
        BufferedImage sourceImage = null;
        // Index of the minibatch example to visualize. It is chosen once, from the first
        // convolutional layer's output, and reused for every later layer so all panels
        // show the same example.
        int sampleIndex = -1;
        for (Layer layer : l.getLayers()) {
            if (layer.type() == Layer.Type.CONVOLUTIONAL) {
                INDArray output = layer.activate();
                if (sampleIndex < 0) {
                    // Uniform over every example in [0, batchSize). Random.nextInt requires a
                    // positive bound, so this also works for a minibatch of size 1.
                    sampleIndex = rnd.nextInt((int) output.shape()[0]);
                }
                if (cnt == 0) {
                    // The first convolutional layer's input is rendered as the source image.
                    INDArray inputs = ((ConvolutionLayer) layer).input();
                    try {
                        sourceImage = restoreRGBImage(inputs.tensorAlongDimension(sampleIndex, 3, 2, 1));
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                }
                INDArray tad = output.tensorAlongDimension(sampleIndex, 3, 2, 1);
                tensors.add(tad);
                cnt++;
            }
        }
        BufferedImage render = rasterizeConvoLayers(tensors, sourceImage);
        try {
            // A fresh file is created per reported iteration. It is only removed at JVM exit
            // because the queued report refers to it by path after this method returns.
            File tempFile = File.createTempFile("cnn_activations", ".png");
            tempFile.deleteOnExit();
            ImageIO.write(render, "png", tempFile);
            PathUpdate update = new PathUpdate();
            // ensure path is set
            update.setPath(tempFile.getPath());
            // ensure the server is hooked up with the path
            WebReporter.getInstance().queueReport(target, Entity.entity(update, MediaType.APPLICATION_JSON));
            if (firstIteration) {
                firstIteration = false;
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        minibatchNum++;
    }
}
Aggregations