use of hex.genmodel.easy.RowData in project h2o-3 by h2oai.
the class MungeCsv method main.
/**
 * CSV reader and munger test program.
 *
 * <p>Instantiates the generated munger class named by {@code assemblyClassName}, writes the
 * munger's output column names as a quoted header line, then reads {@code inputCSVFileName}
 * one line at a time (the first line is treated as a header and skipped), runs every row
 * through {@code rawMunger.fit(...)} and writes the resulting values to
 * {@code outputCSVFileName}.
 *
 * <p>The process always terminates through {@link System#exit(int)}: status 0 when all rows
 * were processed, status 1 when processing a row failed. Both files are closed (and the
 * output therefore flushed) before the process exits.
 *
 * @param args Command-line args.
 * @throws Exception if the munger class cannot be instantiated, or if a file cannot be
 *                   opened, written (header) or closed.
 */
public static void main(String[] args) throws Exception {
    parseArgs(args);
    GenMunger rawMunger = (hex.genmodel.GenMunger) Class.forName(assemblyClassName).newInstance();

    // Exit status is decided inside the resource block but only acted upon after it:
    // System.exit() never returns, so calling it while the streams are still open would
    // skip their close() and lose any buffered output.
    int exitCode = 0;
    int lineNum = 0;
    try (BufferedReader input = new BufferedReader(new FileReader(inputCSVFileName));
         BufferedWriter output = new BufferedWriter(new FileWriter(outputCSVFileName))) {
        // Emit outputCSV column names.
        String[] rawHeader = rawMunger.outNames();
        StringBuilder header = new StringBuilder();
        for (int i = 0; i < rawHeader.length; ++i) {
            header.append("\"").append(rawHeader[i]).append("\"");
            if (i < rawHeader.length - 1) {
                header.append(",");
            }
        }
        output.write(header.toString());
        output.write("\n");

        // Loop over inputCSV one row at a time.
        try {
            String line;
            while ((line = input.readLine()) != null) {
                lineNum++;
                // skip the header.
                if (lineNum == 1) {
                    continue;
                }
                // Parse the CSV line. Somewhat handles quoted commas. But this ain't no parser test!
                RowData row;
                try {
                    row = parseDataRow(line, rawMunger);
                } catch (NumberFormatException nfe) {
                    System.out.println("Failed to parse row: " + lineNum);
                    // Keep the original failure as the cause so its stack trace is reported below.
                    throw new RuntimeException("Failed to parse row: " + lineNum, nfe);
                }
                RowData mungedRow = rawMunger.fit(row);
                // Fetch the output names once per row instead of several times per cell.
                String[] outNames = rawMunger.outNames();
                for (int i = 0; i < outNames.length; ++i) {
                    // A null munged row is written as a row of NaN values.
                    Object val = mungedRow == null ? Double.NaN : mungedRow.get(outNames[i]);
                    if (val instanceof Double) {
                        output.write(String.valueOf(val));
                    } else {
                        output.write("\"" + val + "\"");
                    }
                    if (i < outNames.length - 1) {
                        output.write(",");
                    }
                }
                output.write("\n");
            }
        } catch (Exception e) {
            System.out.println("Caught exception on line " + lineNum);
            System.out.println("");
            e.printStackTrace();
            exitCode = 1;
        }
    }
    // On success the output was fully generated; the calling program can now compare it with something.
    System.exit(exitCode);
}
use of hex.genmodel.easy.RowData in project h2o-3 by h2oai.
the class PredictCsv method run.
/**
 * Scores every data row of {@code inputCSVFileName} with {@code model} and writes the
 * predictions to {@code outputCSVFileName}.
 *
 * <p>The first input line is taken as the list of column names. The output starts with a
 * header line chosen by the model category, followed by one prediction line per input row.
 * If scoring a row fails, the error is printed and the process exits with status 1 after
 * both files have been closed.
 *
 * @throws Exception if the model category is not supported, or if a file cannot be opened,
 *                   written (header) or closed.
 */
private void run() throws Exception {
    ModelCategory category = model.getModelCategory();
    // Set when a row fails; the exit happens only after the files are closed so that
    // everything written so far reaches the disk.
    boolean failed = false;
    // Resources are closed in reverse order: the writer first, then the reader.
    try (CSVReader reader = new CSVReader(new FileReader(inputCSVFileName));
         BufferedWriter output = new BufferedWriter(new FileWriter(outputCSVFileName))) {
        // Emit outputCSV column names.
        switch (category) {
            case AutoEncoder:
                output.write(model.getHeader());
                break;
            case Binomial:
            case Multinomial:
                output.write("predict");
                String[] responseDomainValues = model.getResponseDomainValues();
                for (String s : responseDomainValues) {
                    output.write(",");
                    output.write(s);
                }
                break;
            case Clustering:
                output.write("cluster");
                break;
            case Regression:
                output.write("predict");
                break;
            default:
                throw new Exception("Unknown model category " + category);
        }
        output.write("\n");

        // Loop over inputCSV one row at a time.
        //
        // TODO: performance of scoring can be considerably improved if instead of scoring each row at a time we passed
        // all the rows to the score function, in which case it can evaluate each tree for each row, avoiding
        // multiple rounds of fetching each tree from the filesystem.
        //
        int lineNum = 0;
        try {
            String[] inputColumnNames = null;
            String[] splitLine;
            while ((splitLine = reader.readNext()) != null) {
                lineNum++;
                // Handle the header.
                if (lineNum == 1) {
                    inputColumnNames = splitLine;
                    continue;
                }
                // The CSV reader has already split the line into cells; map them to column names.
                RowData row = formatDataRow(splitLine, inputColumnNames);
                // Emit the result to the output file.
                switch (category) {
                    case AutoEncoder:
                        // Row-level AutoEncoder predictions are not implemented here.
                        throw new UnsupportedOperationException("AutoEncoder predictions are not supported");
                    case Binomial: {
                        BinomialModelPrediction p = model.predictBinomial(row);
                        writeLabelAndProbabilities(output, p.label, p.classProbabilities);
                        break;
                    }
                    case Multinomial: {
                        MultinomialModelPrediction p = model.predictMultinomial(row);
                        writeLabelAndProbabilities(output, p.label, p.classProbabilities);
                        break;
                    }
                    case Clustering: {
                        ClusteringModelPrediction p = model.predictClustering(row);
                        output.write(myDoubleToString(p.cluster));
                        break;
                    }
                    case Regression: {
                        RegressionModelPrediction p = model.predictRegression(row);
                        output.write(myDoubleToString(p.value));
                        break;
                    }
                    default:
                        throw new Exception("Unknown model category " + category);
                }
                output.write("\n");
            }
        } catch (Exception e) {
            System.out.println("Caught exception on line " + lineNum);
            System.out.println("");
            e.printStackTrace();
            failed = true;
        }
    }
    if (failed) {
        System.exit(1);
    }
}

/** Writes a predicted label followed by the comma-separated class probabilities (no line terminator). */
private void writeLabelAndProbabilities(BufferedWriter output, String label, double[] classProbabilities)
        throws java.io.IOException {
    output.write(label);
    output.write(",");
    for (int i = 0; i < classProbabilities.length; i++) {
        if (i > 0) {
            output.write(",");
        }
        output.write(myDoubleToString(classProbabilities[i]));
    }
}
use of hex.genmodel.easy.RowData in project h2o-3 by h2oai.
the class Model method testJavaScoring.
/**
 * Cross-checks this model's stored predictions against predictions recomputed through the
 * generated Java scoring code.
 *
 * <p>Up to three passes are made over a random sample of rows:
 * <ol>
 *   <li>when a POJO is available, the compiled POJO is scored through {@code score0} on a raw
 *       {@code double[]} feature array;</li>
 *   <li>when a POJO is available, the same POJO is scored through {@code EasyPredictModelWrapper};</li>
 *   <li>when a MOJO is available, the MOJO is written to a temporary zip file, loaded, and scored
 *       through {@code EasyPredictModelWrapper}.</li>
 * </ol>
 * Mismatches (and {@code PredictException}s thrown by the wrapper) are counted and a limited
 * number of them is printed to {@code System.err}.
 *
 * @param data              frame the predictions were computed on; must have as many rows as
 *                          {@code model_predictions}
 * @param model_predictions predictions to verify against
 * @param rel_epsilon       relative tolerance handed to {@code MathUtils.compare}
 * @param abs_epsilon       absolute tolerance handed to {@code MathUtils.compare}
 * @param fraction          a row is tested when a fresh {@code rnd.nextDouble()} is below this value
 * @return {@code true} if no mismatch or prediction error was counted
 */
public boolean testJavaScoring(Frame data, Frame model_predictions, double rel_epsilon, double abs_epsilon, double fraction) {
    ModelBuilder mb = ModelBuilder.make(_parms.algoName().toLowerCase(), null, null);
    boolean havePojo = mb.havePojo();
    boolean haveMojo = mb.haveMojo();
    // The generator is seeded from the frame's byte size.
    Random rnd = RandomUtils.getRNG(data.byteSize());
    assert data.numRows() == model_predictions.numRows();
    Frame fr = new Frame(data);
    boolean computeMetrics = data.vec(_output.responseName()) != null && !data.vec(_output.responseName()).isBad();
    try {
        String[] warns = adaptTestForTrain(fr, true, computeMetrics);
        if (warns.length > 0) {
            System.err.println(Arrays.toString(warns));
        }
        // Output is in the model's domain, but needs to be mapped to the scored
        // dataset's domain.
        int[] omap = null;
        if (_output.isClassifier()) {
            Vec actual = fr.vec(_output.responseName());
            // Scored/test domain; can be null
            String[] sdomain = actual == null ? null : actual.domain();
            // Domain of predictions (union of test and train)
            String[] mdomain = model_predictions.vec(0).domain();
            if (sdomain != null && !Arrays.equals(mdomain, sdomain)) {
                // Map from model-domain to scoring-domain
                omap = CategoricalWrappedVec.computeMap(mdomain, sdomain);
            }
        }
        String modelName = JCodeGen.toJavaId(_key.toString());
        boolean preview = false;
        GenModel genmodel = null;
        Vec[] dvecs = fr.vecs();
        Vec[] pvecs = model_predictions.vecs();
        double[] features = null;
        int num_errors = 0;
        int num_total = 0;

        // First try internal POJO via fast double[] API
        if (havePojo) {
            try {
                String java_text = toJava(preview, true);
                Class<?> clz = JCodeGen.compile(modelName, java_text);
                genmodel = (GenModel) clz.newInstance();
            } catch (Exception e) {
                e.printStackTrace();
                throw H2O.fail("Internal POJO compilation failed", e);
            }
            features = MemoryManager.malloc8d(genmodel._names.length);
            double[] predictions = MemoryManager.malloc8d(genmodel.nclasses() + 1);
            // Compare predictions, counting mis-predicts
            for (int row = 0; row < fr.numRows(); row++) {
                // For all rows, single-threaded
                if (rnd.nextDouble() >= fraction) {
                    continue;
                }
                num_total++;
                // Native Java API: build feature set
                for (int col = 0; col < features.length; col++) {
                    features[col] = dvecs[col].at(row);
                }
                // POJO predictions
                genmodel.score0(features, predictions);
                // For classifiers the label column (0) is skipped; only the remaining columns are compared.
                for (int col = _output.isClassifier() ? 1 : 0; col < pvecs.length; col++) {
                    // Load internal scoring predictions
                    double d = pvecs[col].at(row);
                    // map categorical response to scoring domain
                    if (col == 0 && omap != null) {
                        d = omap[(int) d];
                    }
                    if (!MathUtils.compare(predictions[col], d, abs_epsilon, rel_epsilon)) {
                        if (num_errors++ < 10) {
                            System.err.println("Predictions mismatch, row " + row + ", col " + model_predictions._names[col] + ", internal prediction=" + d + ", POJO prediction=" + predictions[col]);
                        }
                        // At most one error is counted per row.
                        break;
                    }
                }
            }
        }

        // EasyPredict API with POJO (i == 0) and/or MOJO (i == 1)
        for (int i = 0; i < 2; ++i) {
            if (i == 0 && !havePojo) {
                continue;
            }
            if (i == 1 && !haveMojo) {
                continue;
            }
            if (i == 1) {
                // MOJO: write it to a temporary zip file, load it back, then remove the file.
                final String filename = modelName + ".zip";
                StreamingSchema ss = new StreamingSchema(getMojo(), filename);
                try {
                    // try-with-resources so the stream is closed even when writeTo() fails.
                    try (FileOutputStream os = new FileOutputStream(ss.getFilename())) {
                        ss.getStreamWriter().writeTo(os);
                    }
                    genmodel = MojoModel.load(filename);
                    features = MemoryManager.malloc8d(genmodel._names.length);
                } catch (IOException e1) {
                    e1.printStackTrace();
                    throw H2O.fail("Internal MOJO loading failed", e1);
                } finally {
                    boolean deleted = new File(filename).delete();
                    if (!deleted) {
                        Log.warn("Failed to delete the file");
                    }
                }
            }
            EasyPredictModelWrapper epmw = new EasyPredictModelWrapper(new EasyPredictModelWrapper.Config().setModel(genmodel).setConvertUnknownCategoricalLevelsToNa(true));
            // The same RowData instance is reused for every row; each feature key is overwritten per row.
            RowData rowData = new RowData();
            BufferedString bStr = new BufferedString();
            for (int row = 0; row < fr.numRows(); row++) {
                // For all rows, single-threaded
                if (rnd.nextDouble() >= fraction) {
                    continue;
                }
                // AutoEncoder models are not verified through this path.
                if (genmodel.getModelCategory() == ModelCategory.AutoEncoder) {
                    continue;
                }
                // Generate input row
                for (int col = 0; col < features.length; col++) {
                    Object cell;
                    if (dvecs[col].isString()) {
                        cell = dvecs[col].atStr(bStr, row).toString();
                    } else {
                        double val = dvecs[col].at(row);
                        if (genmodel._domains[col] == null) {
                            // numeric column: pass the value through
                            cell = Double.valueOf(val);
                        } else if (Double.isNaN(val)) {
                            // missing categorical values are kept as NaN, the score0 logic passes it on to bitSetContains()
                            cell = Double.valueOf(val);
                        } else if ((int) val < genmodel._domains[col].length) {
                            cell = genmodel._domains[col][(int) val];
                        } else {
                            // unseen levels are treated as such
                            cell = "UnknownLevel";
                        }
                    }
                    rowData.put(genmodel._names[col], cell);
                }
                // Make a prediction
                AbstractPrediction p;
                try {
                    p = epmw.predict(rowData);
                } catch (PredictException e) {
                    num_errors++;
                    if (num_errors < 20) {
                        System.err.println("EasyPredict threw an exception when predicting row " + rowData);
                        e.printStackTrace();
                    }
                    continue;
                }
                // Convert model predictions and "internal" predictions into the same shape
                double[] expected_preds = new double[pvecs.length];
                double[] actual_preds = new double[pvecs.length];
                for (int col = 0; col < pvecs.length; col++) {
                    // Load internal scoring predictions
                    double d = pvecs[col].at(row);
                    // map categorical response to scoring domain
                    if (col == 0 && omap != null) {
                        d = omap[(int) d];
                    }
                    // Stays NaN for model categories not handled below.
                    double d2 = Double.NaN;
                    switch (genmodel.getModelCategory()) {
                        case Clustering:
                            d2 = ((ClusteringModelPrediction) p).cluster;
                            break;
                        case Regression:
                            d2 = ((RegressionModelPrediction) p).value;
                            break;
                        case Binomial:
                            BinomialModelPrediction bmp = (BinomialModelPrediction) p;
                            d2 = (col == 0) ? bmp.labelIndex : bmp.classProbabilities[col - 1];
                            break;
                        case Multinomial:
                            MultinomialModelPrediction mmp = (MultinomialModelPrediction) p;
                            d2 = (col == 0) ? mmp.labelIndex : mmp.classProbabilities[col - 1];
                            break;
                        case DimReduction:
                            d2 = ((DimReductionModelPrediction) p).dimensions[col];
                            break;
                    }
                    expected_preds[col] = d;
                    actual_preds[col] = d2;
                }
                // Verify the correctness of the prediction
                num_total++;
                // For classifiers the label column (0) is skipped; only the remaining columns are compared.
                for (int col = genmodel.isClassifier() ? 1 : 0; col < pvecs.length; col++) {
                    if (!MathUtils.compare(actual_preds[col], expected_preds[col], abs_epsilon, rel_epsilon)) {
                        num_errors++;
                        if (num_errors < 20) {
                            System.err.println((i == 0 ? "POJO" : "MOJO") + " EasyPredict Predictions mismatch for row " + rowData);
                            System.err.println("  Expected predictions: " + Arrays.toString(expected_preds));
                            System.err.println("  Actual predictions:   " + Arrays.toString(actual_preds));
                        }
                        // At most one error is counted per row.
                        break;
                    }
                }
            }
        }
        if (num_errors != 0) {
            System.err.println("Number of errors: " + num_errors + (num_errors > 20 ? " (only first 20 are shown)" : "") + " out of " + num_total + " rows tested.");
        }
        return num_errors == 0;
    } finally {
        // Remove temp keys.
        cleanup_adapt(fr, data);
    }
}
use of hex.genmodel.easy.RowData in project h2o-3 by h2oai.
the class PredictCsv method formatDataRow.
/**
 * Builds a {@code RowData} from one split CSV line.
 *
 * <p>Cells are paired with column names by position; only the first
 * {@code min(inputColumnNames.length, splitLine.length)} positions are considered. A cell
 * whose text is exactly {@code ""}, {@code "NA"}, {@code "N/A"} or {@code "-"} is left out
 * of the result; every other cell is stored under its column name as a string.
 *
 * @param splitLine        cell values of the current line
 * @param inputColumnNames column names taken from the header line
 * @return the assembled row
 */
private static RowData formatDataRow(String[] splitLine, String[] inputColumnNames) {
    final RowData result = new RowData();
    final int cellCount = Math.min(inputColumnNames.length, splitLine.length);
    int idx = 0;
    while (idx < cellCount) {
        final String value = splitLine[idx];
        // These markers denote a missing value and produce no entry.
        final boolean missing = value.equals("")
                || value.equals("NA")
                || value.equals("N/A")
                || value.equals("-");
        if (!missing) {
            result.put(inputColumnNames[idx], value);
        }
        idx++;
    }
    return result;
}
use of hex.genmodel.easy.RowData in project h2o-3 by h2oai.
the class GenMunger method fillDefault.
/**
 * Creates a {@code RowData} holding one entry per input column of this munger.
 *
 * <p>The entry for column {@code i} is keyed by {@code inNames()[i]}. When {@code vals} is
 * {@code null} every entry is {@code null}; otherwise entry {@code i} is
 * {@code valueOf(inTypes()[i], vals[i])}.
 *
 * @param vals raw values, indexed like {@code inTypes()}; may be {@code null}
 * @return the populated row
 */
public RowData fillDefault(String[] vals) {
    final RowData result = new RowData();
    final String[] columnTypes = inTypes();
    final String[] columnNames = inNames();
    int idx = 0;
    while (idx < columnTypes.length) {
        if (vals == null) {
            result.put(columnNames[idx], null);
        } else {
            result.put(columnNames[idx], valueOf(columnTypes[idx], vals[idx]));
        }
        ++idx;
    }
    return result;
}
Aggregations