Search in sources:

Example 41 with ConfigurationManager

Use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by Oracle.

From class TrainTest, method main.

/**
 * Runs a TrainTest CLI for CART regression trees.
 * <p>
 * Parses the command line into {@link RegressionTreeOptions}. It then checks that both a
 * training path and a testing path were supplied, and builds the requested impurity and
 * trainer. After that it loads the train/test data, trains a model, prints the evaluation
 * to standard out, and saves the model if an output path was given.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    //
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    RegressionTreeOptions o = new RegressionTreeOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }
    // Validate the required paths before any loading happens, so a missing path produces the
    // usage text rather than a load attempt with a null path.
    if (o.general.trainingPath == null || o.general.testingPath == null) {
        logger.info(cm.usage());
        logger.info("Please supply a training path and a testing path");
        return;
    }
    // Resolve the impurity and trainer from the options before touching the data, so that a
    // configuration error is reported without first paying the cost of loading the datasets.
    RegressorImpurity impurity;
    switch (o.impurityType) {
        case MAE:
            impurity = new MeanAbsoluteError();
            break;
        case MSE:
            impurity = new MeanSquaredError();
            break;
        default:
            logger.severe("unknown impurity type " + o.impurityType);
            return;
    }
    SparseTrainer<Regressor> trainer;
    switch (o.treeType) {
        case CART_INDEPENDENT:
            trainer = new CARTRegressionTrainer(o.depth, o.minChildWeight, o.minImpurityDecrease, o.fraction, o.useRandomSplitPoints, impurity, o.general.seed);
            break;
        case CART_JOINT:
            trainer = new CARTJointRegressionTrainer(o.depth, o.minChildWeight, o.minImpurityDecrease, o.fraction, o.useRandomSplitPoints, impurity, o.normalize, o.general.seed);
            break;
        default:
            logger.severe("unknown tree type " + o.treeType);
            return;
    }
    // The factory is built with the user-supplied split character (o.splitChar).
    RegressionFactory factory = new RegressionFactory(o.splitChar);
    Pair<Dataset<Regressor>, Dataset<Regressor>> data = o.general.load(factory);
    Dataset<Regressor> train = data.getA();
    Dataset<Regressor> test = data.getB();
    logger.info("Training using " + trainer.toString());
    final long trainStart = System.currentTimeMillis();
    SparseModel<Regressor> model = trainer.train(train);
    final long trainStop = System.currentTimeMillis();
    logger.info("Finished training regressor " + Util.formatDuration(trainStart, trainStop));
    if (o.printTree) {
        logger.info(model.toString());
    }
    logger.info("Selected features: " + model.getActiveFeatures());
    final long testStart = System.currentTimeMillis();
    RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model, test);
    final long testStop = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
    System.out.println(evaluation.toString());
    if (o.general.outputPath != null) {
        o.general.saveModel(model);
    }
}
Also used : RegressorImpurity(org.tribuo.regression.rtree.impurity.RegressorImpurity) UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) RegressionFactory(org.tribuo.regression.RegressionFactory) Dataset(org.tribuo.Dataset) MeanAbsoluteError(org.tribuo.regression.rtree.impurity.MeanAbsoluteError) MeanSquaredError(org.tribuo.regression.rtree.impurity.MeanSquaredError) RegressionEvaluation(org.tribuo.regression.evaluation.RegressionEvaluation) Regressor(org.tribuo.regression.Regressor) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Example 42 with ConfigurationManager

Use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by Oracle.

From class TrainTest, method main.

/**
 * Runs a TrainTest CLI for XGBoost regression.
 * <p>
 * Parses the command line into {@link XGBoostOptions}. It requires a training path, a testing
 * path and an explicit ensemble size (an ensemble size of -1 is treated as "not supplied").
 * It then loads the train/test data and trains an {@link XGBoostRegressionTrainer}. The
 * evaluation on the test data is printed to standard out, and the model is saved if an output
 * path was given.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    //
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    XGBoostOptions o = new XGBoostOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }
    if (o.general.trainingPath == null || o.general.testingPath == null) {
        logger.info(cm.usage());
        logger.info("Please supply a training path and a testing path");
        return;
    }
    // -1 is the sentinel for "number of trees not supplied".
    if (o.ensembleSize == -1) {
        logger.info(cm.usage());
        logger.info("Please supply the number of trees.");
        return;
    }
    RegressionFactory factory = new RegressionFactory();
    Pair<Dataset<Regressor>, Dataset<Regressor>> data = o.general.load(factory);
    Dataset<Regressor> train = data.getA();
    Dataset<Regressor> test = data.getB();
    // Arguments, in order: regression type, number of trees, eta, gamma, max depth,
    // min child weight, subsample, feature subsample, lambda, alpha, number of threads,
    // quiet flag, seed.
    // NOTE(review): positional mapping taken from the option names; confirm against the
    // XGBoostRegressionTrainer constructor Javadoc.
    XGBoostRegressionTrainer trainer = new XGBoostRegressionTrainer(o.rType, o.ensembleSize, o.eta, o.gamma, o.depth, o.minWeight, o.subsample, o.subsampleFeatures, o.lambda, o.alpha, o.numThreads, o.quiet, o.general.seed);
    logger.info("Training using " + trainer.toString());
    final long trainStart = System.currentTimeMillis();
    Model<Regressor> model = trainer.train(train);
    final long trainStop = System.currentTimeMillis();
    logger.info("Finished training regressor " + Util.formatDuration(trainStart, trainStop));
    final long testStart = System.currentTimeMillis();
    RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model, test);
    final long testStop = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
    System.out.println(evaluation.toString());
    if (o.general.outputPath != null) {
        o.general.saveModel(model);
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) RegressionFactory(org.tribuo.regression.RegressionFactory) Dataset(org.tribuo.Dataset) RegressionEvaluation(org.tribuo.regression.evaluation.RegressionEvaluation) Regressor(org.tribuo.regression.Regressor) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Aggregations

ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager)42 UsageException (com.oracle.labs.mlrg.olcut.config.UsageException)32 Label (org.tribuo.classification.Label)16 Dataset (org.tribuo.Dataset)15 IOException (java.io.IOException)8 FileOutputStream (java.io.FileOutputStream)7 ObjectOutputStream (java.io.ObjectOutputStream)7 RegressionFactory (org.tribuo.regression.RegressionFactory)7 Regressor (org.tribuo.regression.Regressor)7 RegressionEvaluation (org.tribuo.regression.evaluation.RegressionEvaluation)7 BufferedWriter (java.io.BufferedWriter)4 File (java.io.File)4 OutputStreamWriter (java.io.OutputStreamWriter)4 MutableDataset (org.tribuo.MutableDataset)4 LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation)4 LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator)4 FileInputStream (java.io.FileInputStream)3 ObjectInputStream (java.io.ObjectInputStream)3 PrintWriter (java.io.PrintWriter)3 ArrayList (java.util.ArrayList)3