Use of hex.ModelCategory in project h2o-3 by h2oai.
The method run of the class PredictCsv.
/**
 * Scores the input CSV with the loaded model and writes the predictions to the output CSV.
 *
 * <p>The first line written is a header whose columns depend on the model category. The first
 * record read from the input is treated as the header naming the input columns; every following
 * record is scored and produces one output line.
 *
 * <p>If anything fails while reading, scoring or writing a data row, the current input line
 * number and the stack trace are printed and the JVM exits with status 1.
 *
 * @throws Exception if the model category is not one of the handled categories, or if opening
 *                   the files, writing the header, or closing the files fails
 */
private void run() throws Exception {
  ModelCategory category = model.getModelCategory();
  CSVReader reader = new CSVReader(new FileReader(inputCSVFileName));
  try {
    BufferedWriter output = new BufferedWriter(new FileWriter(outputCSVFileName));
    try {
      // Emit outputCSV column names.
      switch (category) {
        case AutoEncoder:
          output.write(model.getHeader());
          break;
        case Binomial:
        case Multinomial:
          output.write("predict");
          String[] responseDomainValues = model.getResponseDomainValues();
          for (String s : responseDomainValues) {
            output.write(",");
            output.write(s);
          }
          break;
        case Clustering:
          output.write("cluster");
          break;
        case Regression:
          output.write("predict");
          break;
        default:
          throw new Exception("Unknown model category " + category);
      }
      output.write("\n");

      // Loop over inputCSV one row at a time.
      //
      // TODO: performance of scoring can be considerably improved if instead of scoring each row at a time we passed
      // all the rows to the score function, in which case it can evaluate each tree for each row, avoiding
      // multiple rounds of fetching each tree from the filesystem.
      //
      // Declared outside the try so the catch block can report where the failure happened.
      int lineNum = 0;
      try {
        String[] inputColumnNames = null;
        String[] splitLine;
        while ((splitLine = reader.readNext()) != null) {
          lineNum++;

          // Handle the header: remember the column names and score nothing.
          if (lineNum == 1) {
            inputColumnNames = splitLine;
            continue;
          }

          // Pair the already-split CSV fields with the header's column names.
          RowData row = formatDataRow(splitLine, inputColumnNames);

          // Emit the result to the output file.
          switch (category) {
            case AutoEncoder: {
              // Row scoring for AutoEncoder models is not implemented here.
              throw new UnsupportedOperationException();
            }
            case Binomial: {
              BinomialModelPrediction p = model.predictBinomial(row);
              output.write(p.label);
              output.write(",");
              for (int i = 0; i < p.classProbabilities.length; i++) {
                if (i > 0) {
                  output.write(",");
                }
                output.write(myDoubleToString(p.classProbabilities[i]));
              }
              break;
            }
            case Multinomial: {
              MultinomialModelPrediction p = model.predictMultinomial(row);
              output.write(p.label);
              output.write(",");
              for (int i = 0; i < p.classProbabilities.length; i++) {
                if (i > 0) {
                  output.write(",");
                }
                output.write(myDoubleToString(p.classProbabilities[i]));
              }
              break;
            }
            case Clustering: {
              ClusteringModelPrediction p = model.predictClustering(row);
              output.write(myDoubleToString(p.cluster));
              break;
            }
            case Regression: {
              RegressionModelPrediction p = model.predictRegression(row);
              output.write(myDoubleToString(p.value));
              break;
            }
            default:
              throw new Exception("Unknown model category " + category);
          }
          output.write("\n");
        }
      } catch (Exception e) {
        System.out.println("Caught exception on line " + lineNum);
        System.out.println("");
        e.printStackTrace();
        // System.exit halts the JVM, so the finally blocks below do not run on this path.
        System.exit(1);
      }
    } finally {
      // Always release the output file, including when writing the header throws.
      output.close();
    }
  } finally {
    // Always release the input file, including when opening or closing the output fails.
    reader.close();
  }
}
Aggregations