Search in sources:

Example 1 with BackendModel

Use of deepwater.backends.BackendModel in project h2o-3 by h2oai.

From class DeepWaterMXNetIntegrationTest, method inceptionPredictionMX.

/**
 * Integration test of the deepwater MXNet backend ({@code deepwater/backends/mxnet}).
 *
 * <p>For GPU and then CPU, it loads the pretrained Inception-BN network and classifies one test
 * image. It prints the top-5 predictions (best first) and asserts that the top-1 label contains
 * "Pembroke".
 *
 * <p>The only H2O functionality involved is parsing {@code synset.txt} into a {@code Frame} to
 * look up class labels. That frame is removed even if the test fails part-way.
 */
@Test
public void inceptionPredictionMX() throws IOException {
    for (boolean gpu : new boolean[] { true, false }) {
        // Set model parameters
        int w = 224, h = 224, channels = 3, nclasses = 1000;
        ImageDataSet id = new ImageDataSet(w, h, channels, nclasses);
        RuntimeOptions opts = new RuntimeOptions();
        opts.setSeed(1234);
        opts.setUseGPU(gpu);
        BackendParams bparm = new BackendParams();
        bparm.set("mini_batch_size", 1);
        // Load the model
        String path = "deepwater/backends/mxnet/models/Inception/";
        BackendModel _model = backend.buildNet(id, opts, bparm, nclasses, StringUtils.expandPath(extractFile(path, "Inception_BN-symbol.json")));
        backend.loadParam(_model, StringUtils.expandPath(extractFile(path, "Inception_BN-0039.params")));
        water.fvec.Frame labels = parse_test_file(extractFile(path, "synset.txt"));
        try {
            float[] mean = backend.loadMeanImage(_model, extractFile(path, "mean_224.nd"));
            // Turn the image into a vector of the correct size
            File imgFile = FileUtils.getFile("smalldata/deepwater/imagenet/test2.jpg");
            BufferedImage img = ImageIO.read(imgFile);
            // ImageIO.read returns null (rather than throwing) when no reader understands the file.
            Assert.assertNotNull("Could not decode image " + imgFile, img);
            float[] pixels = toMeanCenteredRgbPlanes(img, w, h, mean);
            float[] preds = backend.predict(_model, pixels);
            int K = 5;
            int[] topK = topKIndices(preds, K);
            // Display the top predictions, best first
            StringBuilder sb = new StringBuilder();
            sb.append("\nTop " + topK.length + " predictions:\n");
            BufferedString str = new BufferedString();
            String bestLabel = null;
            for (int j = 0; j < topK.length; j++) {
                String label = labels.anyVec().atStr(str, topK[j]).toString();
                if (j == 0) {
                    bestLabel = label;
                }
                sb.append(" Score: " + String.format("%.4f", preds[topK[j]]) + "\t" + label + "\n");
            }
            System.out.println("\n\n" + sb.toString() + "\n\n");
            // Check the top-1 label directly rather than a fixed character range of the report,
            // which depends on the exact score/label formatting.
            Assert.assertTrue("Illegal predictions!", bestLabel != null && bestLabel.contains("Pembroke"));
        } finally {
            labels.remove();
        }
    }
}

/**
 * Scales {@code img} to {@code w x h} and returns three planar channels with the mean subtracted.
 *
 * <p>The layout is all red values, then all green, then all blue, each plane in row-major order.
 * The matching element of {@code mean} is subtracted from each value.
 *
 * @param img source image
 * @param w target width in pixels
 * @param h target height in pixels
 * @param mean mean image in the same planar layout; assumed to hold at least {@code 3 * w * h}
 *     values (TODO confirm against the .nd file)
 * @return array of length {@code 3 * w * h}
 */
private static float[] toMeanCenteredRgbPlanes(BufferedImage img, int w, int h, float[] mean) {
    // TYPE_INT_RGB instead of img.getType(): ImageIO may report TYPE_CUSTOM (0) for some JPEGs,
    // which the BufferedImage constructor rejects. Only getRGB() is read below anyway.
    BufferedImage scaled = new BufferedImage(w, h, BufferedImage.TYPE_INT_RGB);
    Graphics2D g2d = scaled.createGraphics();
    try {
        g2d.drawImage(img, 0, 0, w, h, null);
    } finally {
        g2d.dispose();
    }
    int plane = w * h;
    float[] pixels = new float[3 * plane];
    for (int i = 0; i < h; i++) {
        for (int j = 0; j < w; j++) {
            int idx = i * w + j;
            Color c = new Color(scaled.getRGB(j, i));
            pixels[idx] = c.getRed() - mean[idx];
            pixels[plane + idx] = c.getGreen() - mean[plane + idx];
            pixels[2 * plane + idx] = c.getBlue() - mean[2 * plane + idx];
        }
    }
    return pixels;
}

/**
 * Returns the indices of the {@code k} largest values in {@code scores}, ordered from largest
 * to smallest.
 *
 * <p>Ties keep the earlier index first. If {@code scores} has fewer than {@code k} entries, the
 * result has length {@code scores.length}.
 */
private static int[] topKIndices(float[] scores, int k) {
    int n = Math.max(0, Math.min(k, scores.length));
    int[] top = new int[n];
    int filled = 0;
    for (int i = 0; i < scores.length; i++) {
        // Walk left past every kept entry this score strictly beats.
        int pos = filled;
        while (pos > 0 && scores[i] > scores[top[pos - 1]]) {
            pos--;
        }
        if (pos >= n) {
            continue; // not among the k best seen so far
        }
        // Shift lower-ranked entries right by one (dropping the last if already full).
        int end = Math.min(filled, n - 1);
        System.arraycopy(top, pos, top, pos + 1, end - pos);
        top[pos] = i;
        if (filled < n) {
            filled++;
        }
    }
    return top;
}
Also used: BackendParams (deepwater.backends.BackendParams), ImageDataSet (deepwater.datasets.ImageDataSet), BufferedString (water.parser.BufferedString), BufferedImage (java.awt.image.BufferedImage), BackendModel (deepwater.backends.BackendModel), RuntimeOptions (deepwater.backends.RuntimeOptions), Test (org.junit.Test)

Example 2 with BackendModel

Use of deepwater.backends.BackendModel in project h2o-3 by h2oai.

From class DeepWaterAbstractIntegrationTest, method saveLoop.

/**
 * Saves the parameters of a freshly built LeNet model to the same temporary file three times
 * in a row. The test passes if no save call throws.
 *
 * <p>The temporary file is deleted afterwards so repeated runs do not accumulate files in the
 * system temp directory.
 */
@Test
public void saveLoop() throws IOException {
    BackendModel m = buildLENET();
    File f = File.createTempFile("saveLoop", ".tmp");
    try {
        for (int count = 0; count < 3; count++) {
            Log.info("Iteration: " + count);
            backend.saveParam(m, f.getAbsolutePath());
        }
    } finally {
        // Fall back to deletion at JVM exit if the file is still in use (e.g. on Windows).
        if (!f.delete()) {
            f.deleteOnExit();
        }
    }
}
Also used: BackendModel (deepwater.backends.BackendModel)

Example 3 with BackendModel

Use of deepwater.backends.BackendModel in project h2o-3 by h2oai.

From class DeepWaterAbstractIntegrationTest, method trainPredictLoop.

/**
 * Alternates one training step and one prediction on the same zero-filled LeNet batch, 1000
 * times. Before each pair of calls it logs the 1-based iteration number.
 *
 * <p>The test passes if no backend call throws; the predictions themselves are not inspected.
 */
@Test
public void trainPredictLoop() {
    int batch_size = 64;
    BackendModel m = buildLENET();
    // Java zero-initializes arrays, so inputs and labels are all 0.
    // Presumably one 28x28x1 image per batch row (TODO confirm against buildLENET).
    float[] data = new float[28 * 28 * 1 * batch_size];
    float[] labels = new float[batch_size];
    int count = 0;
    while (count++ < 1000) {
        Log.info("Iteration: " + count);
        backend.train(m, data, labels);
        // Result intentionally discarded: this loop only exercises predict() after each train step.
        backend.predict(m, data);
    }
}
Also used: BackendModel (deepwater.backends.BackendModel)

Example 4 with BackendModel

Use of deepwater.backends.BackendModel in project h2o-3 by h2oai.

From class DeepWaterAbstractIntegrationTest, method predictLoop.

/**
 * Calls {@code backend.predict} three times on a zero-filled batch for a fresh LeNet model.
 * Before each call it logs the 1-based iteration number.
 *
 * <p>The test passes as long as no prediction throws.
 */
@Test
public void predictLoop() {
    BackendModel lenet = buildLENET();
    final int batchRows = 64;
    float[] zeros = new float[28 * 28 * 1 * batchRows];
    for (int iteration = 1; iteration <= 3; iteration++) {
        Log.info("Iteration: " + iteration);
        backend.predict(lenet, zeros);
    }
}
Also used: BackendModel (deepwater.backends.BackendModel)

Example 5 with BackendModel

Use of deepwater.backends.BackendModel in project h2o-3 by h2oai.

From class DeepWaterAbstractIntegrationTest, method trainLoop.

/**
 * Runs 1000 training steps of a fresh LeNet model on a zero-filled batch with zero labels.
 * Before each step it logs the 1-based iteration number.
 *
 * <p>The test passes if no training call throws.
 *
 * <p>NOTE(review): the method declares {@code InterruptedException}. Confirm whether
 * {@code backend.train} can actually throw it; otherwise the clause is unnecessary.
 */
@Test
public void trainLoop() throws InterruptedException {
    final int batchRows = 64;
    BackendModel lenet = buildLENET();
    float[] inputs = new float[28 * 28 * 1 * batchRows];
    float[] targets = new float[batchRows];
    for (int iteration = 1; iteration <= 1000; iteration++) {
        Log.info("Iteration: " + iteration);
        backend.train(lenet, inputs, targets);
    }
}
Also used: BackendModel (deepwater.backends.BackendModel)

Aggregations

BackendModel (deepwater.backends.BackendModel): 5, BackendParams (deepwater.backends.BackendParams): 1, RuntimeOptions (deepwater.backends.RuntimeOptions): 1, ImageDataSet (deepwater.datasets.ImageDataSet): 1, BufferedImage (java.awt.image.BufferedImage): 1, Test (org.junit.Test): 1, BufferedString (water.parser.BufferedString): 1