Search in sources :

Example 1 with RegressionObjective

Use of org.tribuo.regression.sgd.RegressionObjective in the Tribuo project by Oracle.

The main method of the class TrainTest.

/**
 * Command line entry point: builds an {@code FMRegressionTrainer} from the supplied
 * options, trains it on the training dataset, evaluates the resulting model on the
 * testing dataset and prints the evaluation to standard out. The model is saved
 * afterwards when an output path has been configured.
 * <p>
 * The method returns early, after logging, when the arguments cannot be parsed
 * (the exception message is logged), when either the training or the testing path
 * is missing (the usage statement is logged), or when the requested loss is not one
 * of the handled values (a warning and the usage statement are logged).
 *
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    // Install the labs log format on all loggers before anything is logged.
    LabsLogFormatter.setAllLogFormatters();

    FMRegressionOptions options = new FMRegressionOptions();
    ConfigurationManager configManager;
    try {
        configManager = new ConfigurationManager(args, options);
    } catch (UsageException usageProblem) {
        logger.info(usageProblem.getMessage());
        return;
    }

    // Both a training and a testing path are required to run the experiment.
    boolean pathsSupplied = options.general.trainingPath != null
            && options.general.testingPath != null;
    if (!pathsSupplied) {
        logger.info(configManager.usage());
        return;
    }

    logger.info("Configuring gradient optimiser");
    // Every branch either assigns the objective or leaves the method, so no
    // placeholder initial value is needed.
    final RegressionObjective objective;
    switch (options.loss) {
        case ABSOLUTE:
            objective = new AbsoluteLoss();
            break;
        case SQUARED:
            objective = new SquaredLoss();
            break;
        case HUBER:
            objective = new Huber();
            break;
        default:
            logger.warning("Unknown objective function " + options.loss);
            logger.info(configManager.usage());
            return;
    }

    StochasticGradientOptimiser optimiser = options.gradientOptions.getOptimiser();
    logger.info(String.format("Set logging interval to %d", options.loggingInterval));

    // Load the train/test pair using the regression output factory.
    RegressionFactory outputFactory = new RegressionFactory();
    Pair<Dataset<Regressor>, Dataset<Regressor>> splits = options.general.load(outputFactory);
    Dataset<Regressor> trainingData = splits.getA();
    Dataset<Regressor> testingData = splits.getB();
    logger.info("Feature domain - " + trainingData.getFeatureIDMap());

    Trainer<Regressor> fmTrainer = new FMRegressionTrainer(
            objective,
            optimiser,
            options.epochs,
            options.loggingInterval,
            options.minibatchSize,
            options.general.seed,
            options.factorSize,
            options.variance,
            options.standardise);
    logger.info("Training using " + fmTrainer.toString());

    // Wall-clock timing of the training phase.
    final long trainingBegan = System.currentTimeMillis();
    Model<Regressor> trainedModel = fmTrainer.train(trainingData);
    final long trainingEnded = System.currentTimeMillis();
    logger.info("Finished training regressor " + Util.formatDuration(trainingBegan, trainingEnded));

    // Wall-clock timing of the evaluation phase.
    final long evaluationBegan = System.currentTimeMillis();
    RegressionEvaluation results = outputFactory.getEvaluator().evaluate(trainedModel, testingData);
    final long evaluationEnded = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(evaluationBegan, evaluationEnded));
    System.out.println(results.toString());

    // Saving is optional and only happens when an output path was given.
    if (options.general.outputPath != null) {
        options.general.saveModel(trainedModel);
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) RegressionObjective(org.tribuo.regression.sgd.RegressionObjective) SquaredLoss(org.tribuo.regression.sgd.objectives.SquaredLoss) RegressionFactory(org.tribuo.regression.RegressionFactory) Dataset(org.tribuo.Dataset) Huber(org.tribuo.regression.sgd.objectives.Huber) AbsoluteLoss(org.tribuo.regression.sgd.objectives.AbsoluteLoss) RegressionEvaluation(org.tribuo.regression.evaluation.RegressionEvaluation) StochasticGradientOptimiser(org.tribuo.math.StochasticGradientOptimiser) Regressor(org.tribuo.regression.Regressor) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Aggregations

ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager)1 UsageException (com.oracle.labs.mlrg.olcut.config.UsageException)1 Dataset (org.tribuo.Dataset)1 StochasticGradientOptimiser (org.tribuo.math.StochasticGradientOptimiser)1 RegressionFactory (org.tribuo.regression.RegressionFactory)1 Regressor (org.tribuo.regression.Regressor)1 RegressionEvaluation (org.tribuo.regression.evaluation.RegressionEvaluation)1 RegressionObjective (org.tribuo.regression.sgd.RegressionObjective)1 AbsoluteLoss (org.tribuo.regression.sgd.objectives.AbsoluteLoss)1 Huber (org.tribuo.regression.sgd.objectives.Huber)1 SquaredLoss (org.tribuo.regression.sgd.objectives.SquaredLoss)1