Search in sources :

Example 1 with PredictException

use of hex.genmodel.easy.exception.PredictException in project h2o-3 by h2oai.

In class EasyPredictModelWrapper, method predictWord2Vec:

/**
 * Looks up word embeddings for the String values of a row.
 *
 * <p>Every entry of {@code data} whose value is a {@code String} is passed to the model's
 * {@code transform0} together with a freshly allocated {@code float[getVecSize()]} buffer, and
 * whatever {@code transform0} returns is stored under the same key. Entries whose value is not a
 * {@code String} are ignored.
 *
 * @param data RowData structure; every key with a String value will be translated to an embedding
 * @return a {@code Word2VecPrediction} whose {@code wordEmbeddings} map is keyed by the RowData keys
 * @throws PredictException if the model is not a WordEmbedding model
 */
public Word2VecPrediction predictWord2Vec(RowData data) throws PredictException {
    validateModelCategory(ModelCategory.WordEmbedding);
    if (m instanceof WordEmbeddingModel) {
        WordEmbeddingModel embeddingModel = (WordEmbeddingModel) m;
        int dimension = embeddingModel.getVecSize();
        HashMap<String, float[]> vectorsByKey = new HashMap<>(data.size());
        for (String key : data.keySet()) {
            Object cell = data.get(key);
            if (!(cell instanceof String)) {
                // Only String values are looked up; everything else is skipped.
                continue;
            }
            // Each word gets its own output buffer so stored vectors never alias one another.
            vectorsByKey.put(key, embeddingModel.transform0((String) cell, new float[dimension]));
        }
        Word2VecPrediction prediction = new Word2VecPrediction();
        prediction.wordEmbeddings = vectorsByKey;
        return prediction;
    }
    throw new PredictException("Model is not of the expected type, class = " + m.getClass().getSimpleName());
}
Also used : ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) WordEmbeddingModel(hex.genmodel.algos.word2vec.WordEmbeddingModel) PredictException(hex.genmodel.easy.exception.PredictException)

Example 2 with PredictException

use of hex.genmodel.easy.exception.PredictException in project h2o-3 by h2oai.

In class EasyPredictModelWrapper, method fillRawData:

/**
 * Recognizes strings that look like http(s)/ftp/file URLs (full match). Compiled once here instead
 * of on every call to {@code String.matches}.
 */
private static final java.util.regex.Pattern IMAGE_URL_PATTERN =
        java.util.regex.Pattern.compile("^(https?|ftp|file)://[-a-zA-Z0-9+&@#/%?=~_|!:,.;]*[-a-zA-Z0-9+&@#/%=~_|]");

/**
 * Copies the values of a RowData into the model's raw input array.
 *
 * <p>Columns are matched by name through {@code modelColumnNameToIndexMap}. Columns unknown to the
 * model, or whose index falls outside {@code rawData} (e.g. the response), are skipped. Entries of
 * {@code rawData} for columns absent from {@code data} are left untouched.
 *
 * <p>Columns without a domain accept:
 * <ul>
 *   <li>a String, parsed as a double;</li>
 *   <li>a Double;</li>
 *   <li>for Deepwater image models, an image path/URL string or raw image bytes.</li>
 * </ul>
 * Categorical columns accept a level name (a String) or {@code Double.NaN}.
 *
 * <p>For Deepwater image models the first image encountered is vectorized into a NEW array, which
 * is returned immediately. Callers must use the returned array, not the one they passed in.
 *
 * @param data    the input row, keyed by column name
 * @param rawData the array to fill
 * @return {@code rawData}, or a freshly allocated pixel array for image models
 * @throws PredictException in any of these cases:
 *     <ul>
 *       <li>an image cannot be read, decoded or vectorized;</li>
 *       <li>a text model is used;</li>
 *       <li>a numeric string cannot be parsed (unless {@code convertInvalidNumbersToNa});</li>
 *       <li>a categorical level is unknown (unless {@code convertUnknownCategoricalLevelsToNa});</li>
 *       <li>a value has an unsupported type, including {@code null}.</li>
 *     </ul>
 */
private double[] fillRawData(RowData data, double[] rawData) throws PredictException {
    // TODO: refactor
    // Constant-first equals() so a missing problem type cannot cause a NullPointerException.
    boolean isDeepwater = m instanceof DeepwaterMojoModel;
    boolean isImage = isDeepwater && "image".equals(((DeepwaterMojoModel) m)._problem_type);
    boolean isText = isDeepwater && "text".equals(((DeepwaterMojoModel) m)._problem_type);
    for (String dataColumnName : data.keySet()) {
        Integer index = modelColumnNameToIndexMap.get(dataColumnName);
        // Skip the "response" column which should not be included in `rawData`
        if (index == null || index >= rawData.length) {
            continue;
        }
        String[] domainValues = m.getDomainValues(index);
        Object o = data.get(dataColumnName);
        // Null-safe type name for error messages (a RowData may map a key to null).
        String typeName = o == null ? "null" : o.getClass().getName();
        if (domainValues == null) {
            // Column is either numeric or a string (for images or text)
            BufferedImage img = null;
            double value = Double.NaN;
            if (o instanceof String) {
                String s = ((String) o).trim();
                if (isImage) {
                    // Path or URL to an image given
                    boolean isURL = IMAGE_URL_PATTERN.matcher(s).matches();
                    try {
                        img = isURL ? ImageIO.read(new URL(s)) : ImageIO.read(new File(s));
                    } catch (IOException e) {
                        throw (PredictException) new PredictException("Couldn't read image from " + s).initCause(e);
                    }
                    if (img == null) {
                        // ImageIO.read returns null when no registered reader recognizes the data.
                        throw new PredictException("Couldn't read image from " + s);
                    }
                } else if (isText) {
                    // TODO: use model-specific vectorization of text
                    throw new PredictException("MOJO scoring for text classification is not yet implemented.");
                } else {
                    // numeric
                    try {
                        value = Double.parseDouble(s);
                    } catch (NumberFormatException nfe) {
                        if (!convertInvalidNumbersToNa)
                            throw new PredictNumberFormatException("Unable to parse value: " + s + ", from column: " + dataColumnName + ", as Double; " + nfe.getMessage());
                    }
                }
            } else if (o instanceof Double) {
                value = (Double) o;
            } else if (o instanceof byte[] && isImage) {
                // Read the image from raw bytes
                try (InputStream is = new ByteArrayInputStream((byte[]) o)) {
                    img = ImageIO.read(is);
                } catch (IOException e) {
                    throw (PredictException) new PredictException("Couldn't interpret raw bytes as an image.").initCause(e);
                }
                if (img == null) {
                    // ImageIO.read returns null when no registered reader recognizes the data.
                    throw new PredictException("Couldn't interpret raw bytes as an image.");
                }
            } else {
                throw new PredictUnknownTypeException("Unexpected object type " + typeName + " for numeric column " + dataColumnName);
            }
            if (isImage && img != null) {
                DeepwaterMojoModel dwm = (DeepwaterMojoModel) m;
                int W = dwm._width;
                int H = dwm._height;
                int C = dwm._channels;
                float[] destData = new float[W * H * C];
                try {
                    GenModel.img2pixels(img, W, H, C, destData, 0, dwm._meanImageData);
                } catch (IOException e) {
                    throw (PredictException) new PredictException("Couldn't vectorize image.").initCause(e);
                }
                // The pixel vector replaces the whole input row for image models.
                double[] pixels = new double[destData.length];
                for (int i = 0; i < pixels.length; ++i) pixels[i] = destData[i];
                return pixels;
            }
            rawData[index] = value;
        } else {
            // Column has categorical value.
            double value;
            if (o instanceof String) {
                String levelName = (String) o;
                HashMap<String, Integer> columnDomainMap = domainMap.get(index);
                Integer levelIndex = columnDomainMap.get(levelName);
                if (levelIndex == null) {
                    // Fall back to the "<column>.<level>" spelling of the level.
                    levelIndex = columnDomainMap.get(dataColumnName + "." + levelName);
                }
                if (levelIndex == null) {
                    if (convertUnknownCategoricalLevelsToNa) {
                        value = Double.NaN;
                        unknownCategoricalLevelsSeenPerColumn.get(dataColumnName).incrementAndGet();
                    } else {
                        throw new PredictUnknownCategoricalLevelException("Unknown categorical level (" + dataColumnName + "," + levelName + ")", dataColumnName, levelName);
                    }
                } else {
                    value = levelIndex;
                }
            } else if (o instanceof Double && Double.isNaN((double) o)) {
                //Missing factor is the only Double value allowed
                value = (double) o;
            } else {
                throw new PredictUnknownTypeException("Unexpected object type " + typeName + " for categorical column " + dataColumnName);
            }
            rawData[index] = value;
        }
    }
    return rawData;
}
Also used : PredictUnknownCategoricalLevelException(hex.genmodel.easy.exception.PredictUnknownCategoricalLevelException) PredictException(hex.genmodel.easy.exception.PredictException) ByteArrayInputStream(java.io.ByteArrayInputStream) InputStream(java.io.InputStream) PredictNumberFormatException(hex.genmodel.easy.exception.PredictNumberFormatException) IOException(java.io.IOException) BufferedImage(java.awt.image.BufferedImage) URL(java.net.URL) PredictNumberFormatException(hex.genmodel.easy.exception.PredictNumberFormatException) PredictUnknownTypeException(hex.genmodel.easy.exception.PredictUnknownTypeException) DeepwaterMojoModel(hex.genmodel.algos.deepwater.DeepwaterMojoModel) ByteArrayInputStream(java.io.ByteArrayInputStream) File(java.io.File)

Example 3 with PredictException

use of hex.genmodel.easy.exception.PredictException in project h2o-3 by h2oai.

In class Model, method testJavaScoring:

/**
 * Cross-checks in-H2O predictions against the generated POJO and/or MOJO scorers.
 *
 * <p>Each row of {@code data} is tested with probability {@code fraction}, single-threaded. The
 * rows are checked in two passes:
 * <ol>
 *   <li>If the algorithm supports POJOs, the compiled POJO is scored through the raw
 *       {@code double[]} API.</li>
 *   <li>{@code EasyPredictModelWrapper} is run on top of the POJO and/or MOJO, whichever the
 *       algorithm supports.</li>
 * </ol>
 * Results are compared with the matching row of {@code model_predictions} using
 * {@code MathUtils.compare} and the given tolerances. Mismatches are reported on stderr.
 *
 * @param data              the frame that was scored in H2O; a copy of it is adapted to the training frame
 * @param model_predictions the in-H2O predictions for {@code data} (asserted to have the same row count)
 * @param rel_epsilon       relative tolerance passed to {@code MathUtils.compare}
 * @param abs_epsilon       absolute tolerance passed to {@code MathUtils.compare}
 * @param fraction          probability with which each row is tested
 * @return {@code true} if no mismatches and no EasyPredict exceptions were seen
 */
public boolean testJavaScoring(Frame data, Frame model_predictions, double rel_epsilon, double abs_epsilon, double fraction) {
    // Locale.ROOT: algorithm names must not be lower-cased with locale rules (e.g. Turkish dotless i).
    ModelBuilder mb = ModelBuilder.make(_parms.algoName().toLowerCase(java.util.Locale.ROOT), null, null);
    boolean havePojo = mb.havePojo();
    boolean haveMojo = mb.haveMojo();
    Random rnd = RandomUtils.getRNG(data.byteSize());
    assert data.numRows() == model_predictions.numRows();
    Frame fr = new Frame(data);
    boolean computeMetrics = data.vec(_output.responseName()) != null && !data.vec(_output.responseName()).isBad();
    try {
        String[] warns = adaptTestForTrain(fr, true, computeMetrics);
        if (warns.length > 0)
            System.err.println(Arrays.toString(warns));
        // Output is in the model's domain, but needs to be mapped to the scored
        // dataset's domain.
        int[] omap = null;
        if (_output.isClassifier()) {
            Vec actual = fr.vec(_output.responseName());
            // Scored/test domain; can be null
            String[] sdomain = actual == null ? null : actual.domain();
            // Domain of predictions (union of test and train)
            String[] mdomain = model_predictions.vec(0).domain();
            if (sdomain != null && !Arrays.equals(mdomain, sdomain)) {
                // Map from model-domain to scoring-domain
                omap = CategoricalWrappedVec.computeMap(mdomain, sdomain);
            }
        }
        String modelName = JCodeGen.toJavaId(_key.toString());
        boolean preview = false;
        GenModel genmodel = null;
        Vec[] dvecs = fr.vecs();
        Vec[] pvecs = model_predictions.vecs();
        double[] features = null;
        int num_errors = 0;
        int num_total = 0;
        // Upper bound on the number of EasyPredict errors printed; also used in the summary line.
        final int maxReportedErrors = 20;
        // First try internal POJO via fast double[] API
        if (havePojo) {
            try {
                String java_text = toJava(preview, true);
                Class<?> clz = JCodeGen.compile(modelName, java_text);
                genmodel = (GenModel) clz.newInstance();
            } catch (Exception e) {
                e.printStackTrace();
                throw H2O.fail("Internal POJO compilation failed", e);
            }
            features = MemoryManager.malloc8d(genmodel._names.length);
            double[] predictions = MemoryManager.malloc8d(genmodel.nclasses() + 1);
            // Compare predictions, counting mis-predicts
            for (int row = 0; row < fr.numRows(); row++) {
                // For all rows, single-threaded
                if (rnd.nextDouble() >= fraction)
                    continue;
                num_total++;
                // Native Java API: build the feature set
                for (int col = 0; col < features.length; col++)
                    features[col] = dvecs[col].at(row);
                // POJO predictions
                genmodel.score0(features, predictions);
                for (int col = _output.isClassifier() ? 1 : 0; col < pvecs.length; col++) {
                    // Compare predictions
                    // Load internal scoring predictions
                    double d = pvecs[col].at(row);
                    // map categorical response to scoring domain
                    if (col == 0 && omap != null)
                        d = omap[(int) d];
                    if (!MathUtils.compare(predictions[col], d, abs_epsilon, rel_epsilon)) {
                        if (num_errors++ < 10)
                            System.err.println("Predictions mismatch, row " + row + ", col " + model_predictions._names[col] + ", internal prediction=" + d + ", POJO prediction=" + predictions[col]);
                        break;
                    }
                }
            }
        }
        // EasyPredict API with POJO (i == 0) and/or MOJO (i == 1)
        for (int i = 0; i < 2; ++i) {
            if (i == 0 && !havePojo)
                continue;
            if (i == 1 && !haveMojo)
                continue;
            if (i == 1) {
                // MOJO
                final String filename = modelName + ".zip";
                StreamingSchema ss = new StreamingSchema(getMojo(), filename);
                try {
                    // try-with-resources: the stream is closed even if writeTo() fails.
                    try (FileOutputStream os = new FileOutputStream(ss.getFilename())) {
                        ss.getStreamWriter().writeTo(os);
                    }
                    genmodel = MojoModel.load(filename);
                    features = MemoryManager.malloc8d(genmodel._names.length);
                } catch (IOException e1) {
                    e1.printStackTrace();
                    throw H2O.fail("Internal MOJO loading failed", e1);
                } finally {
                    boolean deleted = new File(filename).delete();
                    if (!deleted)
                        Log.warn("Failed to delete the file");
                }
            }
            EasyPredictModelWrapper epmw = new EasyPredictModelWrapper(new EasyPredictModelWrapper.Config().setModel(genmodel).setConvertUnknownCategoricalLevelsToNa(true));
            RowData rowData = new RowData();
            BufferedString bStr = new BufferedString();
            for (int row = 0; row < fr.numRows(); row++) {
                // For all rows, single-threaded
                if (rnd.nextDouble() >= fraction)
                    continue;
                if (genmodel.getModelCategory() == ModelCategory.AutoEncoder)
                    continue;
                // Generate input row
                for (int col = 0; col < features.length; col++) {
                    String colName = genmodel._names[col];
                    if (dvecs[col].isString()) {
                        rowData.put(colName, dvecs[col].atStr(bStr, row).toString());
                    } else {
                        double val = dvecs[col].at(row);
                        String[] domain = genmodel._domains[col];
                        if (domain == null) {
                            // numeric column
                            rowData.put(colName, Double.valueOf(val));
                        } else if (Double.isNaN(val)) {
                            // missing categorical values are kept as NaN, the score0 logic passes it on to bitSetContains()
                            rowData.put(colName, Double.valueOf(val));
                        } else {
                            // unseen levels are treated as such
                            rowData.put(colName, (int) val < domain.length ? domain[(int) val] : "UnknownLevel");
                        }
                    }
                }
                // Make a prediction
                AbstractPrediction p;
                try {
                    p = epmw.predict(rowData);
                } catch (PredictException e) {
                    num_errors++;
                    if (num_errors <= maxReportedErrors) {
                        System.err.println("EasyPredict threw an exception when predicting row " + rowData);
                        e.printStackTrace();
                    }
                    continue;
                }
                // Convert model predictions and "internal" predictions into the same shape
                double[] expected_preds = new double[pvecs.length];
                double[] actual_preds = new double[pvecs.length];
                for (int col = 0; col < pvecs.length; col++) {
                    // Load internal scoring predictions
                    double d = pvecs[col].at(row);
                    // map categorical response to scoring domain
                    if (col == 0 && omap != null)
                        d = omap[(int) d];
                    // Categories not listed below leave d2 as NaN.
                    double d2 = Double.NaN;
                    switch(genmodel.getModelCategory()) {
                        case Clustering:
                            d2 = ((ClusteringModelPrediction) p).cluster;
                            break;
                        case Regression:
                            d2 = ((RegressionModelPrediction) p).value;
                            break;
                        case Binomial:
                            BinomialModelPrediction bmp = (BinomialModelPrediction) p;
                            d2 = (col == 0) ? bmp.labelIndex : bmp.classProbabilities[col - 1];
                            break;
                        case Multinomial:
                            MultinomialModelPrediction mmp = (MultinomialModelPrediction) p;
                            d2 = (col == 0) ? mmp.labelIndex : mmp.classProbabilities[col - 1];
                            break;
                        case DimReduction:
                            d2 = ((DimReductionModelPrediction) p).dimensions[col];
                            break;
                    }
                    expected_preds[col] = d;
                    actual_preds[col] = d2;
                }
                // Verify the correctness of the prediction
                num_total++;
                for (int col = genmodel.isClassifier() ? 1 : 0; col < pvecs.length; col++) {
                    if (!MathUtils.compare(actual_preds[col], expected_preds[col], abs_epsilon, rel_epsilon)) {
                        num_errors++;
                        if (num_errors <= maxReportedErrors) {
                            System.err.println((i == 0 ? "POJO" : "MOJO") + " EasyPredict Predictions mismatch for row " + rowData);
                            System.err.println("  Expected predictions: " + Arrays.toString(expected_preds));
                            System.err.println("  Actual predictions:   " + Arrays.toString(actual_preds));
                        }
                        break;
                    }
                }
            }
        }
        if (num_errors != 0)
            System.err.println("Number of errors: " + num_errors + (num_errors > maxReportedErrors ? " (only first " + maxReportedErrors + " are shown)" : "") + " out of " + num_total + " rows tested.");
        return num_errors == 0;
    } finally {
        // Remove temp keys.
        cleanup_adapt(fr, data);
    }
}
Also used : PredictException(hex.genmodel.easy.exception.PredictException) BufferedString(water.parser.BufferedString) EasyPredictModelWrapper(hex.genmodel.easy.EasyPredictModelWrapper) RowData(hex.genmodel.easy.RowData) BufferedString(water.parser.BufferedString) PredictException(hex.genmodel.easy.exception.PredictException) GenModel(hex.genmodel.GenModel) StreamingSchema(water.api.StreamingSchema)

Aggregations

PredictException (hex.genmodel.easy.exception.PredictException)3 GenModel (hex.genmodel.GenModel)1 DeepwaterMojoModel (hex.genmodel.algos.deepwater.DeepwaterMojoModel)1 WordEmbeddingModel (hex.genmodel.algos.word2vec.WordEmbeddingModel)1 EasyPredictModelWrapper (hex.genmodel.easy.EasyPredictModelWrapper)1 RowData (hex.genmodel.easy.RowData)1 PredictNumberFormatException (hex.genmodel.easy.exception.PredictNumberFormatException)1 PredictUnknownCategoricalLevelException (hex.genmodel.easy.exception.PredictUnknownCategoricalLevelException)1 PredictUnknownTypeException (hex.genmodel.easy.exception.PredictUnknownTypeException)1 BufferedImage (java.awt.image.BufferedImage)1 ByteArrayInputStream (java.io.ByteArrayInputStream)1 File (java.io.File)1 IOException (java.io.IOException)1 InputStream (java.io.InputStream)1 URL (java.net.URL)1 HashMap (java.util.HashMap)1 ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap)1 StreamingSchema (water.api.StreamingSchema)1 BufferedString (water.parser.BufferedString)1