Search in sources :

Example 1 with RegressorImpurity

Use of org.tribuo.regression.rtree.impurity.RegressorImpurity in project Tribuo by Oracle.

From the class TrainTest, method main.

/**
 * Runs a TrainTest CLI.
 * <p>
 * Parses the command line options, loads the training and testing datasets,
 * trains a regression tree with the requested impurity and tree type,
 * evaluates it on the testing dataset, prints the evaluation to standard out
 * and, if an output path was supplied, saves the trained model.
 * <p>
 * If the options cannot be parsed, or either the training or testing path is
 * missing, the usage message is logged and the method returns without
 * loading any data.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    RegressionTreeOptions o = new RegressionTreeOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }
    // Validate the required paths before asking the options to load the data,
    // so a missing path produces the usage message rather than reaching the loader.
    if (o.general.trainingPath == null || o.general.testingPath == null) {
        logger.info(cm.usage());
        return;
    }
    RegressionFactory factory = new RegressionFactory(o.splitChar);
    Pair<Dataset<Regressor>, Dataset<Regressor>> data = o.general.load(factory);
    Dataset<Regressor> train = data.getA();
    Dataset<Regressor> test = data.getB();
    // Select the impurity measure used by the tree trainer.
    RegressorImpurity impurity;
    switch(o.impurityType) {
        case MAE:
            impurity = new MeanAbsoluteError();
            break;
        case MSE:
            impurity = new MeanSquaredError();
            break;
        default:
            logger.severe("unknown impurity type " + o.impurityType);
            return;
    }
    // Build the trainer for the requested tree type; the joint trainer
    // additionally takes the normalize option.
    SparseTrainer<Regressor> trainer;
    switch(o.treeType) {
        case CART_INDEPENDENT:
            trainer = new CARTRegressionTrainer(o.depth, o.minChildWeight, o.minImpurityDecrease, o.fraction, o.useRandomSplitPoints, impurity, o.general.seed);
            break;
        case CART_JOINT:
            trainer = new CARTJointRegressionTrainer(o.depth, o.minChildWeight, o.minImpurityDecrease, o.fraction, o.useRandomSplitPoints, impurity, o.normalize, o.general.seed);
            break;
        default:
            logger.severe("unknown tree type " + o.treeType);
            return;
    }
    logger.info("Training using " + trainer.toString());
    // Train the model, timing the call for the log.
    final long trainStart = System.currentTimeMillis();
    SparseModel<Regressor> model = trainer.train(train);
    final long trainStop = System.currentTimeMillis();
    logger.info("Finished training regressor " + Util.formatDuration(trainStart, trainStop));
    if (o.printTree) {
        logger.info(model.toString());
    }
    logger.info("Selected features: " + model.getActiveFeatures());
    // Evaluate on the testing dataset, timing the call for the log.
    final long testStart = System.currentTimeMillis();
    RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model, test);
    final long testStop = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
    // The evaluation is the program's primary output, so it goes to standard out rather than the logger.
    System.out.println(evaluation.toString());
    if (o.general.outputPath != null) {
        o.general.saveModel(model);
    }
}
Also used : RegressorImpurity(org.tribuo.regression.rtree.impurity.RegressorImpurity) UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) RegressionFactory(org.tribuo.regression.RegressionFactory) Dataset(org.tribuo.Dataset) MeanAbsoluteError(org.tribuo.regression.rtree.impurity.MeanAbsoluteError) MeanSquaredError(org.tribuo.regression.rtree.impurity.MeanSquaredError) RegressionEvaluation(org.tribuo.regression.evaluation.RegressionEvaluation) Regressor(org.tribuo.regression.Regressor) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Aggregations

ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager)1 UsageException (com.oracle.labs.mlrg.olcut.config.UsageException)1 Dataset (org.tribuo.Dataset)1 RegressionFactory (org.tribuo.regression.RegressionFactory)1 Regressor (org.tribuo.regression.Regressor)1 RegressionEvaluation (org.tribuo.regression.evaluation.RegressionEvaluation)1 MeanAbsoluteError (org.tribuo.regression.rtree.impurity.MeanAbsoluteError)1 MeanSquaredError (org.tribuo.regression.rtree.impurity.MeanSquaredError)1 RegressorImpurity (org.tribuo.regression.rtree.impurity.RegressorImpurity)1