Use of org.tribuo.regression.rtree.impurity.RegressorImpurity in the project tribuo by oracle.
The method main of the class TrainTest.
/**
 * Runs a TrainTest CLI.
 * <p>
 * Parses the command line options, loads the training and testing datasets,
 * trains the requested regression tree, evaluates it on the test set and
 * prints the evaluation to standard out. The trained model is saved when an
 * output path is supplied.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    //
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();

    RegressionTreeOptions o = new RegressionTreeOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        // Usage was requested or the arguments were invalid; report and exit.
        logger.info(e.getMessage());
        return;
    }

    RegressorImpurity impurity;
    switch (o.impurityType) {
        case MAE:
            impurity = new MeanAbsoluteError();
            break;
        case MSE:
            impurity = new MeanSquaredError();
            break;
        default:
            logger.severe("unknown impurity type " + o.impurityType);
            return;
    }

    // Check the required paths BEFORE loading any data, so a missing path
    // produces the usage message instead of whatever the loader does when
    // handed an unset path.
    if (o.general.trainingPath == null || o.general.testingPath == null) {
        logger.info(cm.usage());
        return;
    }

    RegressionFactory factory = new RegressionFactory(o.splitChar);
    Pair<Dataset<Regressor>, Dataset<Regressor>> data = o.general.load(factory);
    Dataset<Regressor> train = data.getA();
    Dataset<Regressor> test = data.getB();

    SparseTrainer<Regressor> trainer;
    switch (o.treeType) {
        case CART_INDEPENDENT:
            trainer = new CARTRegressionTrainer(o.depth, o.minChildWeight, o.minImpurityDecrease, o.fraction, o.useRandomSplitPoints, impurity, o.general.seed);
            break;
        case CART_JOINT:
            trainer = new CARTJointRegressionTrainer(o.depth, o.minChildWeight, o.minImpurityDecrease, o.fraction, o.useRandomSplitPoints, impurity, o.normalize, o.general.seed);
            break;
        default:
            logger.severe("unknown tree type " + o.treeType);
            return;
    }

    logger.info("Training using " + trainer.toString());
    final long trainStart = System.currentTimeMillis();
    SparseModel<Regressor> model = trainer.train(train);
    final long trainStop = System.currentTimeMillis();
    logger.info("Finished training regressor " + Util.formatDuration(trainStart, trainStop));

    if (o.printTree) {
        logger.info(model.toString());
    }
    logger.info("Selected features: " + model.getActiveFeatures());

    final long testStart = System.currentTimeMillis();
    RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model, test);
    final long testStop = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
    System.out.println(evaluation.toString());

    // Only persist the model when the user asked for an output location.
    if (o.general.outputPath != null) {
        o.general.saveModel(model);
    }
}
Aggregations