Search in sources :

Example 1 with RowData

Use of hex.genmodel.easy.RowData in project h2o-3 by h2oai.

In class MungeCsv, method main.

/**
 * CSV munging test program.
 *
 * <p>Reads {@code inputCSVFileName}, passes each data row through the {@link GenMunger} named by
 * {@code assemblyClassName}, and writes the munged rows to {@code outputCSVFileName}. The output
 * starts with a header line holding the munger's output column names, each wrapped in double quotes.
 *
 * <p>The first input line is treated as a header and skipped. {@code Double} values are written
 * bare. Every other value (including {@code null}) is written wrapped in double quotes. A row for
 * which the munger returns {@code null} is written as {@code NaN} in every column.
 *
 * <p>On success the JVM exits with status 0. If any row fails, the 1-based input line number and
 * the stack trace are printed and the JVM exits with status 1. Both files are closed first, so the
 * rows already written are flushed to disk.
 *
 * @param args Command-line args, handed to {@code parseArgs}.
 * @throws Exception if the munger class cannot be instantiated, a file cannot be opened, or the
 *     header cannot be written
 */
public static void main(String[] args) throws Exception {
    parseArgs(args);
    GenMunger rawMunger = (hex.genmodel.GenMunger) Class.forName(assemblyClassName).newInstance();
    int lineNum = 0;
    boolean failed = false;
    // try-with-resources closes (and flushes) both files on every path, including row failures.
    try (BufferedReader input = new BufferedReader(new FileReader(inputCSVFileName));
         BufferedWriter output = new BufferedWriter(new FileWriter(outputCSVFileName))) {
        // Look the output schema up once instead of once per cell.
        String[] outNames = rawMunger.outNames();
        // Emit outputCSV column names.
        StringBuilder header = new StringBuilder();
        for (int i = 0; i < outNames.length; ++i) {
            if (i > 0)
                header.append(",");
            header.append("\"").append(outNames[i]).append("\"");
        }
        output.write(header.toString());
        output.write("\n");
        // Loop over inputCSV one row at a time.
        try {
            String line;
            while ((line = input.readLine()) != null) {
                lineNum++;
                // skip the header.
                if (lineNum == 1)
                    continue;
                // Parse the CSV line.  Somewhat handles quoted commas.  But this ain't no parser test!
                RowData row;
                try {
                    row = parseDataRow(line, rawMunger);
                } catch (NumberFormatException nfe) {
                    System.out.println("Failed to parse row: " + lineNum);
                    // Keep the parse error as the cause so the handler below prints its stack trace.
                    throw new RuntimeException("Failed to parse row: " + lineNum, nfe);
                }
                RowData mungedRow = rawMunger.fit(row);
                for (int i = 0; i < outNames.length; ++i) {
                    Object val = mungedRow == null ? Double.NaN : mungedRow.get(outNames[i]);
                    if (i > 0)
                        output.write(",");
                    if (val instanceof Double)
                        output.write(String.valueOf(val));
                    else
                        output.write("\"" + val + "\"");
                }
                output.write("\n");
            }
        } catch (Exception e) {
            System.out.println("Caught exception on line " + lineNum);
            System.out.println("");
            e.printStackTrace();
            // Defer the exit: System.exit would skip closing the streams and lose buffered rows.
            failed = true;
        }
    }
    // Exit status tells the calling program whether the predictions are complete.
    System.exit(failed ? 1 : 0);
}
Also used : FileWriter(java.io.FileWriter) BufferedWriter(java.io.BufferedWriter) RowData(hex.genmodel.easy.RowData) GenMunger(hex.genmodel.GenMunger) BufferedReader(java.io.BufferedReader) FileReader(java.io.FileReader)

Example 2 with RowData

Use of hex.genmodel.easy.RowData in project h2o-3 by h2oai.

In class PredictCsv, method run.

/**
 * Scores every data row of {@code inputCSVFileName} with {@code model} and writes the
 * predictions to {@code outputCSVFileName}.
 *
 * <p>The first output line is a header chosen by the model category. Each later line holds the
 * prediction for one input data row. The first input line supplies the column names used to
 * build each {@link RowData}.
 *
 * <p>If scoring any row throws, the 1-based input line number and the stack trace are printed
 * and the JVM exits with status 1. Both files are closed (so the output is flushed) before that
 * exit.
 *
 * @throws Exception if a file cannot be opened, the header cannot be written, or the model
 *     category has no header format here
 */
private void run() throws Exception {
    ModelCategory category = model.getModelCategory();
    CSVReader reader = new CSVReader(new FileReader(inputCSVFileName));
    boolean failed = false;
    try {
        BufferedWriter output = new BufferedWriter(new FileWriter(outputCSVFileName));
        try {
            // Emit outputCSV column names.
            switch (category) {
                case AutoEncoder:
                    output.write(model.getHeader());
                    break;
                case Binomial:
                case Multinomial:
                    output.write("predict");
                    for (String s : model.getResponseDomainValues()) {
                        output.write(",");
                        output.write(s);
                    }
                    break;
                case Clustering:
                    output.write("cluster");
                    break;
                case Regression:
                    output.write("predict");
                    break;
                default:
                    throw new Exception("Unknown model category " + category);
            }
            output.write("\n");
            // Loop over inputCSV one row at a time.
            //
            // TODO: performance of scoring can be considerably improved if instead of scoring each row at a time we passed
            //       all the rows to the score function, in which case it can evaluate each tree for each row, avoiding
            //       multiple rounds of fetching each tree from the filesystem.
            //
            int lineNum = 0;
            try {
                String[] inputColumnNames = null;
                String[] splitLine;
                while ((splitLine = reader.readNext()) != null) {
                    lineNum++;
                    // Handle the header.
                    if (lineNum == 1) {
                        inputColumnNames = splitLine;
                        continue;
                    }
                    // CSVReader has already split the line into fields; key them by header name.
                    RowData row = formatDataRow(splitLine, inputColumnNames);
                    // Emit the result to the output file.
                    switch (category) {
                        case AutoEncoder:
                            // Row scoring for AutoEncoder models is not implemented here.
                            throw new UnsupportedOperationException();
                        case Binomial: {
                            BinomialModelPrediction p = model.predictBinomial(row);
                            writeLabelAndProbabilities(output, p.label, p.classProbabilities);
                            break;
                        }
                        case Multinomial: {
                            MultinomialModelPrediction p = model.predictMultinomial(row);
                            writeLabelAndProbabilities(output, p.label, p.classProbabilities);
                            break;
                        }
                        case Clustering: {
                            ClusteringModelPrediction p = model.predictClustering(row);
                            output.write(myDoubleToString(p.cluster));
                            break;
                        }
                        case Regression: {
                            RegressionModelPrediction p = model.predictRegression(row);
                            output.write(myDoubleToString(p.value));
                            break;
                        }
                        default:
                            throw new Exception("Unknown model category " + category);
                    }
                    output.write("\n");
                }
            } catch (Exception e) {
                System.out.println("Caught exception on line " + lineNum);
                System.out.println("");
                e.printStackTrace();
                // Defer the exit: System.exit would skip the finally blocks and lose buffered output.
                failed = true;
            }
        } finally {
            output.close();
        }
    } finally {
        reader.close();
    }
    if (failed) {
        System.exit(1);
    }
}

/**
 * Writes a classifier prediction as {@code label,p0,p1,...} with no trailing newline.
 *
 * @param output destination writer
 * @param label predicted class label
 * @param classProbabilities per-class probabilities, each formatted with {@code myDoubleToString}
 * @throws IOException if writing fails
 */
private void writeLabelAndProbabilities(BufferedWriter output, String label, double[] classProbabilities) throws IOException {
    output.write(label);
    output.write(",");
    for (int i = 0; i < classProbabilities.length; i++) {
        if (i > 0) {
            output.write(",");
        }
        output.write(myDoubleToString(classProbabilities[i]));
    }
}
Also used : CSVReader(au.com.bytecode.opencsv.CSVReader) FileWriter(java.io.FileWriter) IOException(java.io.IOException) BufferedWriter(java.io.BufferedWriter) RowData(hex.genmodel.easy.RowData) FileReader(java.io.FileReader) ModelCategory(hex.ModelCategory)

Example 3 with RowData

Use of hex.genmodel.easy.RowData in project h2o-3 by h2oai.

In class Model, method testJavaScoring.

/**
 * Cross-checks this model's in-H2O predictions against its exported scoring code.
 *
 * <p>Up to three passes run, each over its own random sample of rows. Rows are sampled with
 * probability {@code fraction} from an RNG seeded by {@code data.byteSize()}.
 * <ol>
 *   <li>The compiled POJO via the raw {@code score0(double[], double[])} API, if a POJO exists.</li>
 *   <li>The POJO via {@link EasyPredictModelWrapper}, if a POJO exists.</li>
 *   <li>The MOJO via {@link EasyPredictModelWrapper}, if a MOJO exists. It is written to a
 *       temporary {@code <modelName>.zip}, loaded back, and the file is then deleted.</li>
 * </ol>
 *
 * <p>Each result is compared with the matching row of {@code model_predictions} using
 * {@code MathUtils.compare} with the given tolerances. For classifiers, column 0 (the label) is
 * not compared. AutoEncoder models are skipped in the EasyPredict passes. The first 20 failures
 * are described on stderr, followed by a summary.
 *
 * @param data frame to score; adapted to the training frame's layout for the duration of the call
 * @param model_predictions H2O's own predictions for {@code data}, same row count
 * @param rel_epsilon relative tolerance for comparisons
 * @param abs_epsilon absolute tolerance for comparisons
 * @param fraction probability that any given row is tested in each pass
 * @return {@code true} if no mismatch or EasyPredict exception occurred
 */
public boolean testJavaScoring(Frame data, Frame model_predictions, double rel_epsilon, double abs_epsilon, double fraction) {
    // Number of individual failures echoed to stderr; the summary always reports the full count.
    final int maxReported = 20;
    ModelBuilder mb = ModelBuilder.make(_parms.algoName().toLowerCase(), null, null);
    boolean havePojo = mb.havePojo();
    boolean haveMojo = mb.haveMojo();
    Random rnd = RandomUtils.getRNG(data.byteSize());
    assert data.numRows() == model_predictions.numRows();
    Frame fr = new Frame(data);
    boolean computeMetrics = data.vec(_output.responseName()) != null && !data.vec(_output.responseName()).isBad();
    try {
        String[] warns = adaptTestForTrain(fr, true, computeMetrics);
        if (warns.length > 0)
            System.err.println(Arrays.toString(warns));
        // Output is in the model's domain, but needs to be mapped to the scored
        // dataset's domain.
        int[] omap = null;
        if (_output.isClassifier()) {
            Vec actual = fr.vec(_output.responseName());
            // Scored/test domain; can be null
            String[] sdomain = actual == null ? null : actual.domain();
            // Domain of predictions (union of test and train)
            String[] mdomain = model_predictions.vec(0).domain();
            if (sdomain != null && !Arrays.equals(mdomain, sdomain)) {
                // Map from model-domain to scoring-domain
                omap = CategoricalWrappedVec.computeMap(mdomain, sdomain);
            }
        }
        String modelName = JCodeGen.toJavaId(_key.toString());
        boolean preview = false;
        GenModel genmodel = null;
        Vec[] dvecs = fr.vecs();
        Vec[] pvecs = model_predictions.vecs();
        double[] features = null;
        int num_errors = 0;
        int num_total = 0;
        // First try internal POJO via fast double[] API
        if (havePojo) {
            try {
                String java_text = toJava(preview, true);
                Class clz = JCodeGen.compile(modelName, java_text);
                genmodel = (GenModel) clz.newInstance();
            } catch (Exception e) {
                e.printStackTrace();
                throw H2O.fail("Internal POJO compilation failed", e);
            }
            features = MemoryManager.malloc8d(genmodel._names.length);
            double[] predictions = MemoryManager.malloc8d(genmodel.nclasses() + 1);
            // Compare predictions, counting mis-predicts
            for (int row = 0; row < fr.numRows(); row++) {
                // For all rows, single-threaded; test only a random sample
                if (rnd.nextDouble() >= fraction)
                    continue;
                num_total++;
                // Build the feature set for the native Java API
                for (int col = 0; col < features.length; col++)
                    features[col] = dvecs[col].at(row);
                // POJO predictions
                genmodel.score0(features, predictions);
                for (int col = _output.isClassifier() ? 1 : 0; col < pvecs.length; col++) {
                    // Load internal scoring predictions
                    double d = pvecs[col].at(row);
                    // map categorical response to scoring domain
                    if (col == 0 && omap != null)
                        d = omap[(int) d];
                    if (!MathUtils.compare(predictions[col], d, abs_epsilon, rel_epsilon)) {
                        num_errors++;
                        if (num_errors <= maxReported)
                            System.err.println("Predictions mismatch, row " + row + ", col " + model_predictions._names[col] + ", internal prediction=" + d + ", POJO prediction=" + predictions[col]);
                        break;
                    }
                }
            }
        }
        // EasyPredict API with POJO (i == 0) and/or MOJO (i == 1)
        for (int i = 0; i < 2; ++i) {
            if (i == 0 && !havePojo)
                continue;
            if (i == 1 && !haveMojo)
                continue;
            if (i == 1) {
                // MOJO
                final String filename = modelName + ".zip";
                StreamingSchema ss = new StreamingSchema(getMojo(), filename);
                try {
                    // try-with-resources so the stream is closed even if writeTo fails.
                    try (FileOutputStream os = new FileOutputStream(ss.getFilename())) {
                        ss.getStreamWriter().writeTo(os);
                    }
                    genmodel = MojoModel.load(filename);
                    features = MemoryManager.malloc8d(genmodel._names.length);
                } catch (IOException e1) {
                    e1.printStackTrace();
                    throw H2O.fail("Internal MOJO loading failed", e1);
                } finally {
                    boolean deleted = new File(filename).delete();
                    if (!deleted)
                        Log.warn("Failed to delete the file");
                }
            }
            EasyPredictModelWrapper epmw = new EasyPredictModelWrapper(new EasyPredictModelWrapper.Config().setModel(genmodel).setConvertUnknownCategoricalLevelsToNa(true));
            RowData rowData = new RowData();
            BufferedString bStr = new BufferedString();
            for (int row = 0; row < fr.numRows(); row++) {
                // For all rows, single-threaded; test only a random sample
                if (rnd.nextDouble() >= fraction)
                    continue;
                if (genmodel.getModelCategory() == ModelCategory.AutoEncoder)
                    continue;
                // Generate input row
                for (int col = 0; col < features.length; col++) {
                    String colName = genmodel._names[col];
                    if (dvecs[col].isString()) {
                        rowData.put(colName, dvecs[col].atStr(bStr, row).toString());
                    } else {
                        double val = dvecs[col].at(row);
                        String[] domain = genmodel._domains[col];
                        Object value;
                        if (domain == null) {
                            // Numeric column: passed as a Double.
                            value = val;
                        } else if (Double.isNaN(val)) {
                            // missing categorical values are kept as NaN, the score0 logic passes it on to bitSetContains()
                            value = val;
                        } else {
                            //unseen levels are treated as such
                            value = (int) val < domain.length ? domain[(int) val] : "UnknownLevel";
                        }
                        rowData.put(colName, value);
                    }
                }
                // Count the row before predicting so that EasyPredict exceptions are included
                // in the "out of N rows tested" total.
                num_total++;
                // Make a prediction
                AbstractPrediction p;
                try {
                    p = epmw.predict(rowData);
                } catch (PredictException e) {
                    num_errors++;
                    if (num_errors <= maxReported) {
                        System.err.println("EasyPredict threw an exception when predicting row " + rowData);
                        e.printStackTrace();
                    }
                    continue;
                }
                // Convert model predictions and "internal" predictions into the same shape
                double[] expected_preds = new double[pvecs.length];
                double[] actual_preds = new double[pvecs.length];
                for (int col = 0; col < pvecs.length; col++) {
                    // Load internal scoring predictions
                    double d = pvecs[col].at(row);
                    // map categorical response to scoring domain
                    if (col == 0 && omap != null)
                        d = omap[(int) d];
                    double d2 = Double.NaN;
                    switch (genmodel.getModelCategory()) {
                        case Clustering:
                            d2 = ((ClusteringModelPrediction) p).cluster;
                            break;
                        case Regression:
                            d2 = ((RegressionModelPrediction) p).value;
                            break;
                        case Binomial:
                            BinomialModelPrediction bmp = (BinomialModelPrediction) p;
                            d2 = (col == 0) ? bmp.labelIndex : bmp.classProbabilities[col - 1];
                            break;
                        case Multinomial:
                            MultinomialModelPrediction mmp = (MultinomialModelPrediction) p;
                            d2 = (col == 0) ? mmp.labelIndex : mmp.classProbabilities[col - 1];
                            break;
                        case DimReduction:
                            d2 = ((DimReductionModelPrediction) p).dimensions[col];
                            break;
                    }
                    expected_preds[col] = d;
                    actual_preds[col] = d2;
                }
                // Verify the correctness of the prediction
                for (int col = genmodel.isClassifier() ? 1 : 0; col < pvecs.length; col++) {
                    if (!MathUtils.compare(actual_preds[col], expected_preds[col], abs_epsilon, rel_epsilon)) {
                        num_errors++;
                        if (num_errors <= maxReported) {
                            System.err.println((i == 0 ? "POJO" : "MOJO") + " EasyPredict Predictions mismatch for row " + rowData);
                            System.err.println("  Expected predictions: " + Arrays.toString(expected_preds));
                            System.err.println("  Actual predictions:   " + Arrays.toString(actual_preds));
                        }
                        break;
                    }
                }
            }
        }
        if (num_errors != 0)
            System.err.println("Number of errors: " + num_errors + (num_errors > maxReported ? " (only first " + maxReported + " are shown)" : "") + " out of " + num_total + " rows tested.");
        return num_errors == 0;
    } finally {
        // Remove temp keys.
        cleanup_adapt(fr, data);
    }
}
Also used : PredictException(hex.genmodel.easy.exception.PredictException) BufferedString(water.parser.BufferedString) EasyPredictModelWrapper(hex.genmodel.easy.EasyPredictModelWrapper) RowData(hex.genmodel.easy.RowData) GenModel(hex.genmodel.GenModel) StreamingSchema(water.api.StreamingSchema)

Example 4 with RowData

Use of hex.genmodel.easy.RowData in project h2o-3 by h2oai.

In class PredictCsv, method formatDataRow.

/**
 * Builds a {@link RowData} from one parsed CSV line.
 *
 * <p>Fields are paired positionally with the header names. Extra fields or extra header names
 * beyond the shorter of the two arrays are ignored. Fields equal to {@code ""}, {@code "NA"},
 * {@code "N/A"} or {@code "-"} are treated as missing and left out of the row.
 *
 * @param splitLine the fields of one data line
 * @param inputColumnNames the header fields naming each column
 * @return a row mapping column name to raw field text for every non-missing field
 */
private static RowData formatDataRow(String[] splitLine, String[] inputColumnNames) {
    RowData assembled = new RowData();
    int usable = inputColumnNames.length < splitLine.length ? inputColumnNames.length : splitLine.length;
    int idx = 0;
    while (idx < usable) {
        String value = splitLine[idx];
        boolean missing = value.isEmpty()
                || value.equals("NA")
                || value.equals("N/A")
                || value.equals("-");
        if (!missing) {
            assembled.put(inputColumnNames[idx], value);
        }
        idx++;
    }
    return assembled;
}
Also used : RowData(hex.genmodel.easy.RowData)

Example 5 with RowData

Use of hex.genmodel.easy.RowData in project h2o-3 by h2oai.

In class GenMunger, method fillDefault.

/**
 * Builds a row holding one entry per input column of this munger.
 *
 * <p>The row is keyed by {@code inNames()}. When {@code vals} is {@code null}, every entry is
 * {@code null}. Otherwise entry {@code i} is {@code valueOf(inTypes()[i], vals[i])}.
 *
 * @param vals raw values aligned with the input columns, or {@code null}
 * @return the populated row
 */
public RowData fillDefault(String[] vals) {
    final RowData result = new RowData();
    final String[] columnTypes = inTypes();
    final String[] columnNames = inNames();
    int col = 0;
    for (String type : columnTypes) {
        String name = columnNames[col];
        result.put(name, vals == null ? null : valueOf(type, vals[col]));
        col++;
    }
    return result;
}
Also used : RowData(hex.genmodel.easy.RowData)

Aggregations

RowData (hex.genmodel.easy.RowData)7 EasyPredictModelWrapper (hex.genmodel.easy.EasyPredictModelWrapper)2 BufferedWriter (java.io.BufferedWriter)2 FileReader (java.io.FileReader)2 FileWriter (java.io.FileWriter)2 CSVReader (au.com.bytecode.opencsv.CSVReader)1 ModelCategory (hex.ModelCategory)1 GenModel (hex.genmodel.GenModel)1 GenMunger (hex.genmodel.GenMunger)1 PredictException (hex.genmodel.easy.exception.PredictException)1 RegressionModelPrediction (hex.genmodel.easy.prediction.RegressionModelPrediction)1 BufferedReader (java.io.BufferedReader)1 IOException (java.io.IOException)1 URL (java.net.URL)1 StreamingSchema (water.api.StreamingSchema)1 BufferedString (water.parser.BufferedString)1