Use of org.tribuo.regression.sgd.objectives.SquaredLoss in the project tribuo by Oracle, taken from the class TestSGDLinear, method testNegativeInvocationCount.
@Test
public void testNegativeInvocationCount() {
    // A negative invocation count is invalid, so the setter must reject it.
    assertThrows(IllegalArgumentException.class, () -> {
        LinearSGDTrainer trainer = new LinearSGDTrainer(new SquaredLoss(), new AdaGrad(0.1, 0.1), 5, 1000);
        trainer.setInvocationCount(-1);
    });
}
Use of org.tribuo.regression.sgd.objectives.SquaredLoss in the project tribuo by Oracle, taken from the class TrainTest, method main.
/**
 * Entry point for the TrainTest CLI.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    // Switch all loggers over to the labs log format.
    LabsLogFormatter.setAllLogFormatters();

    SGDOptions options = new SGDOptions();
    ConfigurationManager configManager;
    try {
        configManager = new ConfigurationManager(args, options);
    } catch (UsageException e) {
        // Bad command line arguments - report the usage message and exit.
        logger.info(e.getMessage());
        return;
    }

    // Both a training path and a testing path are required.
    if (options.general.trainingPath == null || options.general.testingPath == null) {
        logger.info(configManager.usage());
        return;
    }

    logger.info("Configuring gradient optimiser");
    // Map the requested loss enum onto the matching regression objective.
    RegressionObjective objective;
    switch (options.loss) {
        case ABSOLUTE:
            objective = new AbsoluteLoss();
            break;
        case SQUARED:
            objective = new SquaredLoss();
            break;
        case HUBER:
            objective = new Huber();
            break;
        default:
            logger.warning("Unknown objective function " + options.loss);
            logger.info(configManager.usage());
            return;
    }
    StochasticGradientOptimiser optimiser = options.gradientOptions.getOptimiser();
    logger.info(String.format("Set logging interval to %d", options.loggingInterval));

    // Load the train/test split from disk.
    RegressionFactory factory = new RegressionFactory();
    Pair<Dataset<Regressor>, Dataset<Regressor>> datasets = options.general.load(factory);
    Dataset<Regressor> trainData = datasets.getA();
    Dataset<Regressor> testData = datasets.getB();

    // Fit the linear SGD regression model, timing the training run.
    Trainer<Regressor> trainer = new LinearSGDTrainer(objective, optimiser, options.epochs, options.loggingInterval, options.minibatchSize, options.general.seed);
    logger.info("Training using " + trainer.toString());
    final long trainStart = System.currentTimeMillis();
    Model<Regressor> model = trainer.train(trainData);
    final long trainStop = System.currentTimeMillis();
    logger.info("Finished training regressor " + Util.formatDuration(trainStart, trainStop));

    // Evaluate on the held-out test set, timing the evaluation run.
    final long testStart = System.currentTimeMillis();
    RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model, testData);
    final long testStop = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
    System.out.println(evaluation.toString());

    // Optionally serialize the trained model.
    if (options.general.outputPath != null) {
        options.general.saveModel(model);
    }
}
Use of org.tribuo.regression.sgd.objectives.SquaredLoss in the project tribuo by Oracle, taken from the class EnsembleExportTest, method testHeterogeneousRegressionExport.
@Test
public void testHeterogeneousRegressionExport() throws IOException, OrtException {
    // Build small synthetic train/test datasets with different seeds.
    MutableDataset<Regressor> train = new MutableDataset<>(new NonlinearGaussianDataSource(100, 1L));
    MutableDataset<Regressor> test = new MutableDataset<>(new NonlinearGaussianDataSource(100, 2L));

    // Shared objective and optimiser across both ensemble members.
    SquaredLoss squaredLoss = new SquaredLoss();
    AdaGrad optimiser = new AdaGrad(0.1, 0.1);

    // Member one: a bagged ensemble of linear SGD regressors.
    org.tribuo.regression.sgd.linear.LinearSGDTrainer linearTrainer = new org.tribuo.regression.sgd.linear.LinearSGDTrainer(squaredLoss, optimiser, 2, 1L);
    BaggingTrainer<Regressor> baggingTrainer = new BaggingTrainer<>(linearTrainer, AVERAGING, 5);
    WeightedEnsembleModel<Regressor> bagModel = (WeightedEnsembleModel<Regressor>) baggingTrainer.train(train);

    // Member two: a factorization machine regressor.
    FMRegressionTrainer fmTrainer = new FMRegressionTrainer(squaredLoss, optimiser, 2, 100, 1, 1L, 5, 0.1, true);
    AbstractFMModel<Regressor> fmModel = fmTrainer.train(train);

    // Combine the two heterogeneous models into one weighted ensemble.
    WeightedEnsembleModel<Regressor> ensemble = WeightedEnsembleModel.createEnsembleFromExistingModels("Bag+FM", Arrays.asList(bagModel, fmModel), AVERAGING, new float[] { 0.3f, 0.7f });

    // Export to ONNX, compare the ONNX output to the Tribuo model, then clean up.
    Path onnxFile = Files.createTempFile("tribuo-bagging-test", ".onnx");
    ensemble.saveONNXModel("org.tribuo.ensemble.test", 1, onnxFile);
    OnnxTestUtils.onnxRegressorComparison(ensemble, onnxFile, test, 1e-5);
    onnxFile.toFile().delete();
}
Use of org.tribuo.regression.sgd.objectives.SquaredLoss in the project tribuo by Oracle, taken from the class EnsembleExportTest, method testHomogenousRegressionExport.
@Test
public void testHomogenousRegressionExport() throws IOException, OrtException {
    // Build small synthetic train/test datasets with different seeds.
    MutableDataset<Regressor> train = new MutableDataset<>(new NonlinearGaussianDataSource(100, 1L));
    MutableDataset<Regressor> test = new MutableDataset<>(new NonlinearGaussianDataSource(100, 2L));

    // Train a bagged ensemble of linear SGD regressors with averaged votes.
    org.tribuo.regression.sgd.linear.LinearSGDTrainer linearTrainer = new org.tribuo.regression.sgd.linear.LinearSGDTrainer(new SquaredLoss(), new AdaGrad(0.1, 0.1), 2, 1L);
    BaggingTrainer<Regressor> baggingTrainer = new BaggingTrainer<>(linearTrainer, AVERAGING, 5);
    WeightedEnsembleModel<Regressor> ensemble = (WeightedEnsembleModel<Regressor>) baggingTrainer.train(train);

    // Export to ONNX, compare the ONNX output to the Tribuo model, then clean up.
    Path onnxFile = Files.createTempFile("tribuo-bagging-test", ".onnx");
    ensemble.saveONNXModel("org.tribuo.ensemble.test", 1, onnxFile);
    OnnxTestUtils.onnxRegressorComparison(ensemble, onnxFile, test, 1e-5);
    onnxFile.toFile().delete();
}
Use of org.tribuo.regression.sgd.objectives.SquaredLoss in the project tribuo by Oracle, taken from the class TrainTest, method main.
/**
 * Runs a TrainTest CLI which trains and evaluates a factorization machine regression model.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
//
// Use the labs format logging.
LabsLogFormatter.setAllLogFormatters();
FMRegressionOptions o = new FMRegressionOptions();
ConfigurationManager cm;
try {
cm = new ConfigurationManager(args, o);
} catch (UsageException e) {
// Invalid command line arguments - log the usage error and exit.
logger.info(e.getMessage());
return;
}
// Both a training path and a testing path are required.
if (o.general.trainingPath == null || o.general.testingPath == null) {
logger.info(cm.usage());
return;
}
logger.info("Configuring gradient optimiser");
// Map the requested loss enum onto the matching regression objective.
RegressionObjective obj = null;
switch(o.loss) {
case ABSOLUTE:
obj = new AbsoluteLoss();
break;
case SQUARED:
obj = new SquaredLoss();
break;
case HUBER:
obj = new Huber();
break;
default:
// Unrecognized loss - warn and exit rather than training with a null objective.
logger.warning("Unknown objective function " + o.loss);
logger.info(cm.usage());
return;
}
StochasticGradientOptimiser grad = o.gradientOptions.getOptimiser();
logger.info(String.format("Set logging interval to %d", o.loggingInterval));
// Load the train/test split from disk.
RegressionFactory factory = new RegressionFactory();
Pair<Dataset<Regressor>, Dataset<Regressor>> data = o.general.load(factory);
Dataset<Regressor> train = data.getA();
Dataset<Regressor> test = data.getB();
logger.info("Feature domain - " + train.getFeatureIDMap());
// Fit the factorization machine regression model, timing the training run.
Trainer<Regressor> trainer = new FMRegressionTrainer(obj, grad, o.epochs, o.loggingInterval, o.minibatchSize, o.general.seed, o.factorSize, o.variance, o.standardise);
logger.info("Training using " + trainer.toString());
final long trainStart = System.currentTimeMillis();
Model<Regressor> model = trainer.train(train);
final long trainStop = System.currentTimeMillis();
logger.info("Finished training regressor " + Util.formatDuration(trainStart, trainStop));
// Evaluate on the held-out test set, timing the evaluation run.
final long testStart = System.currentTimeMillis();
RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model, test);
final long testStop = System.currentTimeMillis();
logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
System.out.println(evaluation.toString());
// Optionally serialize the trained model.
if (o.general.outputPath != null) {
o.general.saveModel(model);
}
}
Aggregations