Search in sources :

Example 1 with ModelCategory

use of hex.ModelCategory in project h2o-3 by h2oai.

In the class PredictCsv, method run.

/**
 * Scores every data row of the input CSV with {@code model} and writes the
 * predictions to the output CSV.
 *
 * <p>The output starts with a header line whose columns depend on the model
 * category, followed by one line per input data row. The first line of the
 * input file is taken as the column names and is not scored.
 *
 * <p>If reading, scoring or writing a data row fails, the failing input line
 * number and the stack trace are printed and the JVM exits with status 1.
 *
 * @throws Exception if the model category is not recognised while the header
 *         is being written, or if opening, writing the header to, or closing
 *         one of the files fails
 */
private void run() throws Exception {
    ModelCategory category = model.getModelCategory();
    CSVReader reader = new CSVReader(new FileReader(inputCSVFileName));
    try {
        BufferedWriter output = new BufferedWriter(new FileWriter(outputCSVFileName));
        try {
            // Emit outputCSV column names.
            switch(category) {
                case AutoEncoder:
                    output.write(model.getHeader());
                    break;
                case Binomial:
                case Multinomial:
                    // "predict" followed by one column per response domain value.
                    output.write("predict");
                    String[] responseDomainValues = model.getResponseDomainValues();
                    for (String s : responseDomainValues) {
                        output.write(",");
                        output.write(s);
                    }
                    break;
                case Clustering:
                    output.write("cluster");
                    break;
                case Regression:
                    output.write("predict");
                    break;
                default:
                    throw new Exception("Unknown model category " + category);
            }
            output.write("\n");
            // Loop over inputCSV one row at a time.
            //
            // TODO: performance of scoring can be considerably improved if instead of scoring each row at a time we passed
            //       all the rows to the score function, in which case it can evaluate each tree for each row, avoiding
            //       multiple rounds of fetching each tree from the filesystem.
            //
            int lineNum = 0;
            try {
                String[] inputColumnNames = null;
                String[] splitLine;
                while ((splitLine = reader.readNext()) != null) {
                    lineNum++;
                    // Handle the header.
                    if (lineNum == 1) {
                        inputColumnNames = splitLine;
                        continue;
                    }
                    // The reader has already split the line into fields; pair them with the column names.
                    RowData row = formatDataRow(splitLine, inputColumnNames);
                    // Emit the result to the output file.
                    switch(category) {
                        case AutoEncoder:
                            {
                                // Per-row AutoEncoder output is not supported by this method.
                                throw new UnsupportedOperationException();
                            }
                        case Binomial:
                            {
                                BinomialModelPrediction p = model.predictBinomial(row);
                                output.write(p.label);
                                output.write(",");
                                for (int i = 0; i < p.classProbabilities.length; i++) {
                                    if (i > 0) {
                                        output.write(",");
                                    }
                                    output.write(myDoubleToString(p.classProbabilities[i]));
                                }
                                break;
                            }
                        case Multinomial:
                            {
                                MultinomialModelPrediction p = model.predictMultinomial(row);
                                output.write(p.label);
                                output.write(",");
                                for (int i = 0; i < p.classProbabilities.length; i++) {
                                    if (i > 0) {
                                        output.write(",");
                                    }
                                    output.write(myDoubleToString(p.classProbabilities[i]));
                                }
                                break;
                            }
                        case Clustering:
                            {
                                ClusteringModelPrediction p = model.predictClustering(row);
                                output.write(myDoubleToString(p.cluster));
                                break;
                            }
                        case Regression:
                            {
                                RegressionModelPrediction p = model.predictRegression(row);
                                output.write(myDoubleToString(p.value));
                                break;
                            }
                        default:
                            throw new Exception("Unknown model category " + category);
                    }
                    output.write("\n");
                }
            } catch (Exception e) {
                System.out.println("Caught exception on line " + lineNum);
                System.out.println("");
                e.printStackTrace();
                // System.exit does not return, so the finally blocks below are not run on this path.
                System.exit(1);
            }
        } finally {
            // Clean up the writer even when the header could not be written.
            output.close();
        }
    } finally {
        // Clean up the reader even when the output file could not be opened.
        reader.close();
    }
}
Also used : CSVReader(au.com.bytecode.opencsv.CSVReader) FileWriter(java.io.FileWriter) IOException(java.io.IOException) BufferedWriter(java.io.BufferedWriter) RowData(hex.genmodel.easy.RowData) FileReader(java.io.FileReader) ModelCategory(hex.ModelCategory)

Aggregations

CSVReader (au.com.bytecode.opencsv.CSVReader): 1, ModelCategory (hex.ModelCategory): 1, RowData (hex.genmodel.easy.RowData): 1, BufferedWriter (java.io.BufferedWriter): 1, FileReader (java.io.FileReader): 1, FileWriter (java.io.FileWriter): 1, IOException (java.io.IOException): 1