Search in sources:

Example 1 with RuntimeOptions

Use of deepwater.backends.RuntimeOptions in project h2o-3 by h2oai.

From class DeepwaterMojoReader, method readModelData.

/**
 * Reads the Deep Water model state stored in the MOJO and rebuilds a scoring-ready native model.
 *
 * <p>Reads the network-definition and parameter blobs plus the key/value metadata, instantiates
 * the backend named by the {@code backend} key, spills the blobs to temporary files (the backend
 * calls used here take file paths), builds the network and loads the weights. If a
 * {@code mean_image_file} blob can be read, it is passed to the backend as well and the resulting
 * mean data is kept in {@code _model._meanImageData}.
 *
 * @throws IOException if one of the temporary files cannot be written
 * @throws RuntimeException wrapping the {@link IOException} if the network or parameter blob
 *     cannot be read
 * @throws IllegalArgumentException if the backend cannot be instantiated
 */
@Override
protected void readModelData() throws IOException {
    try {
        _model._network = readblob("model_network");
        _model._parameters = readblob("model_params");
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    _model._backend = DeepwaterMojoModel.createDeepWaterBackend((String) readkv("backend"));
    if (_model._backend == null) {
        throw new IllegalArgumentException("Couldn't instantiate the Deep Water backend.");
    }
    _model._problem_type = readkv("problem_type");
    _model._mini_batch_size = readkv("mini_batch_size");
    _model._height = readkv("height");
    _model._width = readkv("width");
    _model._channels = readkv("channels");
    _model._nums = readkv("nums");
    _model._cats = readkv("cats");
    _model._catOffsets = readkv("cat_offsets");
    _model._normMul = readkv("norm_mul");
    _model._normSub = readkv("norm_sub");
    _model._normRespMul = readkv("norm_resp_mul");
    _model._normRespSub = readkv("norm_resp_sub");
    _model._useAllFactorLevels = readkv("use_all_factor_levels");
    _model._imageDataSet = new ImageDataSet(_model._width, _model._height, _model._channels, _model._nclasses);
    _model._opts = new RuntimeOptions();
    // ignored - not needed during scoring
    _model._opts.setSeed(0);
    _model._opts.setUseGPU((boolean) readkv("gpu"));
    _model._opts.setDeviceID((int[]) readkv("device_id"));
    _model._backendParams = new BackendParams();
    // Scoring uses a mini-batch size of 1.
    _model._backendParams.set("mini_batch_size", 1);

    // The backend builds the network from a file path, so write the definition out first.
    // A failed write now surfaces as an IOException instead of a confusing buildNet failure.
    File networkFile = writeBlobToTempFile(_model._network, ".json");
    _model._model = _model._backend.buildNet(_model._imageDataSet, _model._opts, _model._backendParams, _model._nclasses, networkFile.toString());

    // The mean image is optional: if its blob cannot be read, no mean data is set.
    byte[] meanBlob = null;
    try {
        meanBlob = readblob("mean_image_file");
    } catch (IOException ignored) {
        // Deliberately best-effort; presumably the MOJO simply has no mean image.
    }
    if (meanBlob != null) {
        File meanFile = writeBlobToTempFile(meanBlob, ".mean");
        // Tell the backend to use that mean image file (in case it needs it) ...
        _model._imageDataSet.setMeanData(_model._backend.loadMeanImage(_model._model, meanFile.toString()));
        // ... and keep a float[] copy for use during image preprocessing.
        _model._meanImageData = _model._imageDataSet.getMeanData();
    }

    File paramsFile = writeBlobToTempFile(_model._parameters, "");
    _model._backend.loadParam(_model._model, paramsFile.toString());
}

/**
 * Writes {@code data} to a new, uniquely named file in {@code java.io.tmpdir}.
 *
 * <p>The file is registered with {@link File#deleteOnExit()} so repeated MOJO loads do not
 * accumulate stale temp files; it stays available for the rest of the JVM's lifetime.
 *
 * @param data bytes to write
 * @param suffix file-name suffix appended after a random UUID (may be empty)
 * @return the written file
 * @throws IOException if the file cannot be created or written
 */
private static File writeBlobToTempFile(byte[] data, String suffix) throws IOException {
    File file = new File(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString() + suffix);
    file.deleteOnExit();
    // try-with-resources closes the stream even if write() throws.
    try (FileOutputStream os = new FileOutputStream(file)) {
        os.write(data);
    }
    return file;
}
Also used: BackendParams (deepwater.backends.BackendParams), ImageDataSet (deepwater.datasets.ImageDataSet), FileOutputStream (java.io.FileOutputStream), IOException (java.io.IOException), RuntimeOptions (deepwater.backends.RuntimeOptions), File (java.io.File)

Example 2 with RuntimeOptions

Use of deepwater.backends.RuntimeOptions in project h2o-3 by h2oai.

From class DeepWaterMXNetIntegrationTest, method inceptionPredictionMX.

// This test has nothing to do with H2O - Pure integration test of deepwater/backends/mxnet
/**
 * Loads the pre-trained MXNet Inception-BN model, classifies a test JPEG once with the GPU
 * option enabled and once disabled, prints the top-5 predictions and asserts that the
 * highest-scoring label contains "Pembroke".
 *
 * @throws IOException if the test image cannot be read
 */
@Test
public void inceptionPredictionMX() throws IOException {
    for (boolean gpu : new boolean[] { true, false }) {
        // Set model parameters
        int w = 224, h = 224, channels = 3, nclasses = 1000;
        ImageDataSet id = new ImageDataSet(w, h, channels, nclasses);
        RuntimeOptions opts = new RuntimeOptions();
        opts.setSeed(1234);
        opts.setUseGPU(gpu);
        BackendParams bparm = new BackendParams();
        bparm.set("mini_batch_size", 1);
        // Load the model
        String path = "deepwater/backends/mxnet/models/Inception/";
        BackendModel _model = backend.buildNet(id, opts, bparm, nclasses, StringUtils.expandPath(extractFile(path, "Inception_BN-symbol.json")));
        backend.loadParam(_model, StringUtils.expandPath(extractFile(path, "Inception_BN-0039.params")));
        water.fvec.Frame labels = parse_test_file(extractFile(path, "synset.txt"));
        // Always release the labels frame, even when an assertion below fails.
        try {
            float[] mean = backend.loadMeanImage(_model, extractFile(path, "mean_224.nd"));
            // Turn the image into a vector of the correct size
            File imgFile = FileUtils.getFile("smalldata/deepwater/imagenet/test2.jpg");
            BufferedImage img = ImageIO.read(imgFile);
            // ImageIO.read returns null when no registered reader can decode the file.
            Assert.assertNotNull("Unable to decode image " + imgFile, img);
            // Use a fixed RGB type: img.getType() can be TYPE_CUSTOM (0), which the
            // BufferedImage constructor rejects. Only the RGB components are read below.
            BufferedImage scaledImg = new BufferedImage(w, h, BufferedImage.TYPE_INT_RGB);
            Graphics2D g2d = scaledImg.createGraphics();
            g2d.drawImage(img, 0, 0, w, h, null);
            g2d.dispose();
            // Planar layout: all red values, then all green, then all blue, each minus the mean image.
            float[] pixels = new float[w * h * channels];
            int r_idx = 0;
            int g_idx = r_idx + w * h;
            int b_idx = g_idx + w * h;
            for (int i = 0; i < h; i++) {
                for (int j = 0; j < w; j++) {
                    Color mycolor = new Color(scaledImg.getRGB(j, i));
                    pixels[r_idx] = mycolor.getRed() - mean[r_idx];
                    r_idx++;
                    pixels[g_idx] = mycolor.getGreen() - mean[g_idx];
                    g_idx++;
                    pixels[b_idx] = mycolor.getBlue() - mean[b_idx];
                    b_idx++;
                }
            }
            float[] preds = backend.predict(_model, pixels);
            int K = 5;
            int[] topK = topKIndices(preds, K);
            // Display the top 5 predictions
            StringBuilder sb = new StringBuilder();
            sb.append("\nTop " + K + " predictions:\n");
            BufferedString str = new BufferedString();
            for (int j = 0; j < topK.length; j++) {
                String label = labels.anyVec().atStr(str, topK[j]).toString();
                sb.append(" Score: " + String.format("%.4f", preds[topK[j]]) + "\t" + label + "\n");
            }
            System.out.println("\n\n" + sb.toString() + "\n\n");
            // Check the best label directly instead of slicing fixed offsets of the printed text.
            Assert.assertTrue("Illegal predictions!", topK.length > 0);
            String best = labels.anyVec().atStr(str, topK[0]).toString();
            Assert.assertTrue("Illegal predictions!", best.contains("Pembroke"));
        } finally {
            labels.remove();
        }
    }
}

/**
 * Returns the indices of the {@code k} largest entries of {@code values}, ordered from largest
 * to smallest value; among equal values the lower index comes first.
 *
 * @param values scores to rank
 * @param k number of indices wanted; fewer are returned if {@code values} is shorter
 * @return array of {@code min(k, values.length)} indices
 */
private static int[] topKIndices(float[] values, int k) {
    int n = Math.min(k, values.length);
    int[] top = new int[n];
    int filled = 0;
    for (int i = 0; i < values.length; i++) {
        // Find where values[i] belongs among the current (sorted) top entries.
        int pos = filled;
        while (pos > 0 && values[i] > values[top[pos - 1]]) {
            pos--;
        }
        if (pos >= n) {
            continue;
        }
        // Shift smaller entries down one slot, dropping the last one if the array is full.
        for (int m = Math.min(filled, n - 1); m > pos; m--) {
            top[m] = top[m - 1];
        }
        top[pos] = i;
        if (filled < n) {
            filled++;
        }
    }
    return top;
}
Also used: BackendParams (deepwater.backends.BackendParams), ImageDataSet (deepwater.datasets.ImageDataSet), BufferedString (water.parser.BufferedString), BufferedImage (java.awt.image.BufferedImage), BackendModel (deepwater.backends.BackendModel), RuntimeOptions (deepwater.backends.RuntimeOptions), Test (org.junit.Test)

Example 3 with RuntimeOptions

Use of deepwater.backends.RuntimeOptions in project h2o-3 by h2oai.

From class DeepWaterAbstractIntegrationTest, method buildLENET.

/**
 * Builds the backend's built-in "lenet" network for 28x28 single-channel images with 10 classes,
 * using GPU device 0, seed 1234 and a mini-batch size of 64.
 *
 * @return the freshly built backend model
 */
private BackendModel buildLENET() {
    final int numClasses = 10;
    final int miniBatchSize = 64;

    ImageDataSet inputShape = new ImageDataSet(28, 28, 1, numClasses);

    RuntimeOptions runtime = new RuntimeOptions();
    runtime.setUseGPU(true);
    runtime.setSeed(1234);
    runtime.setDeviceID(0);

    BackendParams params = new BackendParams();
    params.set("mini_batch_size", miniBatchSize);

    return backend.buildNet(inputShape, runtime, params, numClasses, "lenet");
}
Also used: BackendParams (deepwater.backends.BackendParams), ImageDataSet (deepwater.datasets.ImageDataSet), RuntimeOptions (deepwater.backends.RuntimeOptions)

Example 4 with RuntimeOptions

Use of deepwater.backends.RuntimeOptions in project h2o-3 by h2oai.

From class DeepWaterModelInfo, method getRuntimeOptions.

/**
 * Creates the backend runtime options from this model's parameters: the seed returned by
 * {@code getOrMakeRealSeed()} (narrowed to {@code int}), the GPU flag and the device ID setting.
 *
 * @return a new, fully populated {@link RuntimeOptions}
 */
private RuntimeOptions getRuntimeOptions() {
    final DeepWaterParameters p = get_params();
    RuntimeOptions runtimeOptions = new RuntimeOptions();
    runtimeOptions.setSeed((int) p.getOrMakeRealSeed());
    runtimeOptions.setUseGPU(p._gpu);
    runtimeOptions.setDeviceID(p._device_id);
    return runtimeOptions;
}
Also used: RuntimeOptions (deepwater.backends.RuntimeOptions)

Example 5 with RuntimeOptions

Use of deepwater.backends.RuntimeOptions in project h2o-3 by h2oai.

From class DeepWaterModelInfo, method setupNativeBackend.

/**
 * Instantiates the configured native backend and builds the model.
 *
 * <p>Steps:
 * <ul>
 *   <li>Unless the network type is {@code user}, builds a fresh built-in network; "MLP" is used
 *       when no type is set.</li>
 *   <li>Optionally (re)builds the model from {@code _network_definition_file}.</li>
 *   <li>Loads the optional mean image.</li>
 *   <li>Loads weights from {@code _network_parameters_file} if one is given.</li>
 *   <li>Calls {@code nativeToJava()} to store the initial state.</li>
 * </ul>
 *
 * @throws RuntimeException wrapping any failure; the original exception is kept as the cause
 */
private void setupNativeBackend() {
    try {
        _backend = createDeepWaterBackend(parameters._backend.toString());
        if (_backend == null) {
            throw new IllegalArgumentException("No backend found. Cannot build a Deep Water model.");
        }
        ImageDataSet imageDataSet = getImageDataSet();
        RuntimeOptions opts = getRuntimeOptions();
        BackendParams bparms = getBackendParams();
        if (parameters._network != DeepWaterParameters.Network.user) {
            // No explicit network type falls back to a multi-layer perceptron ("MLP").
            String network = parameters._network == null ? "MLP" : parameters._network.toString();
            Log.info("Creating a fresh model of the following network type: " + network);
            _model = _backend.buildNet(imageDataSet, opts, bparms, _classes, network);
        }
        // load a network if specified
        final String networkDef = parameters._network_definition_file;
        if (networkDef != null && !networkDef.isEmpty()) {
            File f = new File(networkDef);
            if (!f.exists() || f.isDirectory()) {
                throw new RuntimeException("Network definition file " + f + " not found.");
            } else {
                Log.info("Loading the network from: " + f.getAbsolutePath());
                Log.info("Setting the optimizer and initializing the first and last layer.");
                _model = _backend.buildNet(imageDataSet, opts, bparms, _classes, f.getAbsolutePath());
            }
        }
        if (parameters._mean_image_file != null && !parameters._mean_image_file.isEmpty()) {
            imageDataSet.setMeanData(_backend.loadMeanImage(_model, parameters._mean_image_file));
        }
        _meanData = imageDataSet.getMeanData();
        final String networkParms = parameters._network_parameters_file;
        if (networkParms != null && !networkParms.isEmpty()) {
            File f = new File(networkParms);
            if (!f.exists() || f.isDirectory()) {
                throw new RuntimeException("Network parameter file " + f + " not found.");
            } else {
                Log.info("Loading the parameters (weights/biases) from: " + f.getAbsolutePath());
                // Explicit check instead of `assert`, which is skipped unless the JVM runs with -ea;
                // a null model can happen for a `user` network without a definition file.
                if (_model == null) {
                    throw new IllegalStateException("No network was built; cannot load parameters from " + f.getAbsolutePath());
                }
                _backend.loadParam(_model, f.getAbsolutePath());
            }
        } else {
            Log.warn("No network parameters file specified. Starting from scratch.");
        }
        //store initial state as early as it's created
        nativeToJava();
    } catch (Throwable t) {
        // Keep the original throwable as the cause so its type and stack trace are not lost.
        throw new RuntimeException("Unable to initialize the native Deep Learning backend: " + t.getMessage(), t);
    }
}
Also used: BackendParams (deepwater.backends.BackendParams), ImageDataSet (deepwater.datasets.ImageDataSet), RuntimeOptions (deepwater.backends.RuntimeOptions), File (java.io.File), H2OIllegalArgumentException (water.exceptions.H2OIllegalArgumentException)

Aggregations

RuntimeOptions (deepwater.backends.RuntimeOptions): 5; BackendParams (deepwater.backends.BackendParams): 4; ImageDataSet (deepwater.datasets.ImageDataSet): 4; File (java.io.File): 2; BackendModel (deepwater.backends.BackendModel): 1; BufferedImage (java.awt.image.BufferedImage): 1; FileOutputStream (java.io.FileOutputStream): 1; IOException (java.io.IOException): 1; Test (org.junit.Test): 1; H2OIllegalArgumentException (water.exceptions.H2OIllegalArgumentException): 1; BufferedString (water.parser.BufferedString): 1