Search in sources:

Example 1 with PredictUnknownTypeException

Use of hex.genmodel.easy.exception.PredictUnknownTypeException in project h2o-3 by h2oai.

Class EasyPredictModelWrapper, method fillRawData.

/**
 * Copies the values of {@code data} into {@code rawData}, positioned by the model's column index.
 *
 * <p>Row entries whose name does not map to a model column index, or whose index falls outside
 * {@code rawData} (e.g. the response column), are skipped. Numeric columns accept a {@code Double}
 * or a parseable {@code String}; categorical columns accept a level name (optionally prefixed with
 * {@code "<column>."}) or {@code Double.NaN} for a missing value.
 *
 * <p>For Deepwater models whose problem type is {@code "image"}, a {@code String} (URL or file
 * path) or a {@code byte[]} is decoded into an image, and its pixels are returned in a NEW array
 * instead of {@code rawData}.
 *
 * @param data    the input row, keyed by column name
 * @param rawData the destination array, indexed by model column index; filled in place
 * @return {@code rawData}, or a newly allocated pixel array for image models
 * @throws PredictException if an image cannot be read, decoded or vectorized; if text scoring is
 *     requested; if a numeric string cannot be parsed and {@code convertInvalidNumbersToNa} is off
 *     ({@code PredictNumberFormatException}); if a categorical level is unknown and
 *     {@code convertUnknownCategoricalLevelsToNa} is off
 *     ({@code PredictUnknownCategoricalLevelException}); or if a value has an unsupported type,
 *     including {@code null} ({@code PredictUnknownTypeException})
 */
private double[] fillRawData(RowData data, double[] rawData) throws PredictException {
    // TODO: refactor
    // Constant-first equals() so a Deepwater model without a problem type does not NPE here.
    boolean isImage = m instanceof DeepwaterMojoModel && "image".equals(((DeepwaterMojoModel) m)._problem_type);
    boolean isText = m instanceof DeepwaterMojoModel && "text".equals(((DeepwaterMojoModel) m)._problem_type);
    for (String dataColumnName : data.keySet()) {
        Integer index = modelColumnNameToIndexMap.get(dataColumnName);
        // Skip the "response" column which should not be included in `rawData`
        if (index == null || index >= rawData.length) {
            continue;
        }
        Object o = data.get(dataColumnName);
        // Used in error messages; calling o.getClass() on a null value would throw an NPE instead
        // of the intended PredictUnknownTypeException.
        String typeName = (o == null) ? "null" : o.getClass().getName();
        String[] domainValues = m.getDomainValues(index);
        if (domainValues == null) {
            // Column is either numeric or a string (for images or text)
            BufferedImage img = null;
            double value = Double.NaN;
            if (o instanceof String) {
                String s = ((String) o).trim();
                if (isImage) {
                    // A URL to an image, otherwise the string is treated as a local file path
                    boolean isURL = s.matches("^(https?|ftp|file)://[-a-zA-Z0-9+&@#/%?=~_|!:,.;]*[-a-zA-Z0-9+&@#/%=~_|]");
                    try {
                        img = isURL ? ImageIO.read(new URL(s)) : ImageIO.read(new File(s));
                    } catch (IOException e) {
                        PredictException pe = new PredictException("Couldn't read image from " + s);
                        pe.initCause(e);
                        throw pe;
                    }
                    // ImageIO.read returns null (rather than throwing) when no registered reader
                    // can decode the input; without this check the row would be scored with NaN.
                    if (img == null) {
                        throw new PredictException("Couldn't read image from " + s);
                    }
                } else if (isText) {
                    // TODO: use model-specific vectorization of text
                    throw new PredictException("MOJO scoring for text classification is not yet implemented.");
                } else {
                    // numeric
                    try {
                        value = Double.parseDouble(s);
                    } catch (NumberFormatException nfe) {
                        // When conversion is enabled, the value stays NaN (treated as missing).
                        if (!convertInvalidNumbersToNa) {
                            throw new PredictNumberFormatException("Unable to parse value: " + s + ", from column: " + dataColumnName + ", as Double; " + nfe.getMessage());
                        }
                    }
                }
            } else if (o instanceof Double) {
                value = (Double) o;
            } else if (o instanceof byte[] && isImage) {
                // Read the image from raw bytes
                try (InputStream is = new ByteArrayInputStream((byte[]) o)) {
                    img = ImageIO.read(is);
                } catch (IOException e) {
                    PredictException pe = new PredictException("Couldn't interpret raw bytes as an image.");
                    pe.initCause(e);
                    throw pe;
                }
                // Same null-on-unsupported-format contract as above.
                if (img == null) {
                    throw new PredictException("Couldn't interpret raw bytes as an image.");
                }
            } else {
                throw new PredictUnknownTypeException("Unexpected object type " + typeName + " for numeric column " + dataColumnName);
            }
            if (isImage && img != null) {
                // Image models are fed the pixel vector instead of the per-column rawData array.
                DeepwaterMojoModel dwm = (DeepwaterMojoModel) m;
                int W = dwm._width;
                int H = dwm._height;
                int C = dwm._channels;
                float[] _destData = new float[W * H * C];
                try {
                    GenModel.img2pixels(img, W, H, C, _destData, 0, dwm._meanImageData);
                } catch (IOException e) {
                    PredictException pe = new PredictException("Couldn't vectorize image.");
                    pe.initCause(e);
                    throw pe;
                }
                double[] pixels = new double[_destData.length];
                for (int i = 0; i < pixels.length; ++i) {
                    pixels[i] = _destData[i];
                }
                return pixels;
            }
            rawData[index] = value;
        } else {
            // Column has categorical value.
            double value;
            if (o instanceof String) {
                String levelName = (String) o;
                HashMap<String, Integer> columnDomainMap = domainMap.get(index);
                Integer levelIndex = columnDomainMap.get(levelName);
                if (levelIndex == null) {
                    // Fall back to the "<column>.<level>" spelling of the level name.
                    levelIndex = columnDomainMap.get(dataColumnName + "." + levelName);
                }
                if (levelIndex == null) {
                    if (convertUnknownCategoricalLevelsToNa) {
                        value = Double.NaN;
                        unknownCategoricalLevelsSeenPerColumn.get(dataColumnName).incrementAndGet();
                    } else {
                        throw new PredictUnknownCategoricalLevelException("Unknown categorical level (" + dataColumnName + "," + levelName + ")", dataColumnName, levelName);
                    }
                } else {
                    value = levelIndex;
                }
            } else if (o instanceof Double && Double.isNaN((double) o)) {
                //Missing factor is the only Double value allowed
                value = (double) o;
            } else {
                throw new PredictUnknownTypeException("Unexpected object type " + typeName + " for categorical column " + dataColumnName);
            }
            rawData[index] = value;
        }
    }
    return rawData;
}
Also used: PredictUnknownCategoricalLevelException (hex.genmodel.easy.exception.PredictUnknownCategoricalLevelException), PredictException (hex.genmodel.easy.exception.PredictException), PredictNumberFormatException (hex.genmodel.easy.exception.PredictNumberFormatException), PredictUnknownTypeException (hex.genmodel.easy.exception.PredictUnknownTypeException), DeepwaterMojoModel (hex.genmodel.algos.deepwater.DeepwaterMojoModel), BufferedImage (java.awt.image.BufferedImage), ByteArrayInputStream (java.io.ByteArrayInputStream), InputStream (java.io.InputStream), IOException (java.io.IOException), File (java.io.File), URL (java.net.URL)

Aggregations

DeepwaterMojoModel (hex.genmodel.algos.deepwater.DeepwaterMojoModel): 1, PredictException (hex.genmodel.easy.exception.PredictException): 1, PredictNumberFormatException (hex.genmodel.easy.exception.PredictNumberFormatException): 1, PredictUnknownCategoricalLevelException (hex.genmodel.easy.exception.PredictUnknownCategoricalLevelException): 1, PredictUnknownTypeException (hex.genmodel.easy.exception.PredictUnknownTypeException): 1, BufferedImage (java.awt.image.BufferedImage): 1, ByteArrayInputStream (java.io.ByteArrayInputStream): 1, File (java.io.File): 1, IOException (java.io.IOException): 1, InputStream (java.io.InputStream): 1, URL (java.net.URL): 1