Use of com.oracle.labs.mlrg.olcut.config.UsageException in project tribuo by oracle.
From the class TrainTest, method main.
/**
 * Runs a TrainTest CLI.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    TrainTestOptions o = new TrainTestOptions();
    try (ConfigurationManager cm = new ConfigurationManager(args, o)) {
        Trainer<Label> trainer = o.trainerOptions.getTrainer();
        // Optionally wrap the base trainer in an ensemble.
        trainer = o.ensembleOptions.wrapTrainer(trainer);
        TrainTestHelper.run(cm, o.general, trainer);
    } catch (UsageException e) {
        // Thrown when the user requests usage or supplies invalid arguments.
        logger.info(e.getMessage());
    }
}
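All of these snippets follow the same OLCUT pattern: a class implementing Options declares @Option-annotated fields, ConfigurationManager populates them from the command line, and a UsageException signals a usage request or a malformed argument list. A minimal sketch of such an options class follows, using OLCUT's @Option annotation; the class and field names are illustrative, not Tribuo's real TrainTestOptions.

import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;

// Hypothetical options class for illustration only.
public class ExampleTrainOptions implements Options {

    @Option(charName = 'i', longName = "input-path", usage = "Path to the training data.")
    public String inputPath;

    @Option(charName = 'n', longName = "epochs", usage = "Number of training epochs.")
    public int epochs = 5;
}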
Use of com.oracle.labs.mlrg.olcut.config.UsageException in project tribuo by oracle.
From the class TrainTest, method main.
/**
 * Runs a TrainTest CLI.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    TrainTestOptions o = new TrainTestOptions();
    try (ConfigurationManager cm = new ConfigurationManager(args, o)) {
        Trainer<Label> trainer = o.trainerOptions.getTrainer();
        TrainTestHelper.run(cm, o.general, trainer);
    } catch (UsageException e) {
        logger.info(e.getMessage());
    }
}
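This variant omits the ensemble wrapping but is otherwise identical: try-with-resources closes the ConfigurationManager, and the catch block is the normal exit path when the user asks for usage. As a hedged sketch of exercising that path (this assumes OLCUT's built-in --usage flag triggers the UsageException; the wrapper class name is made up):

// Hypothetical driver showing the UsageException path.
public class UsageDemo {
    public static void main(String[] unused) throws java.io.IOException {
        // If OLCUT recognises --usage, ConfigurationManager throws UsageException,
        // which TrainTest.main catches and logs instead of training anything.
        TrainTest.main(new String[] {"--usage"});
    }
}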
Use of com.oracle.labs.mlrg.olcut.config.UsageException in project tribuo by oracle.
From the class TrainTest, method main.
/**
 * Runs a TrainTest CLI.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    SLMOptions o = new SLMOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }
    if (o.general.trainingPath == null || o.general.testingPath == null) {
        logger.info(cm.usage());
        return;
    }
    RegressionFactory factory = new RegressionFactory();
    Pair<Dataset<Regressor>, Dataset<Regressor>> data = o.general.load(factory);
    Dataset<Regressor> train = data.getA();
    Dataset<Regressor> test = data.getB();
    // Each branch caps the feature budget at the training set's feature count.
    SparseTrainer<Regressor> trainer;
    switch (o.algorithm) {
        case SFS:
            trainer = new SLMTrainer(false, Math.min(train.getFeatureMap().size(), o.maxNumFeatures));
            break;
        case LARS:
            trainer = new LARSTrainer(Math.min(train.getFeatureMap().size(), o.maxNumFeatures));
            break;
        case LARSLASSO:
            trainer = new LARSLassoTrainer(Math.min(train.getFeatureMap().size(), o.maxNumFeatures));
            break;
        case SFSN:
            trainer = new SLMTrainer(true, Math.min(train.getFeatureMap().size(), o.maxNumFeatures));
            break;
        case ELASTICNET:
            trainer = new ElasticNetCDTrainer(o.alpha, o.l1Ratio, 1e-4, o.iterations, false, o.general.seed);
            break;
        default:
            logger.warning("Unknown SLMType, found " + o.algorithm);
            return;
    }
    logger.info("Training using " + trainer.toString());
    final long trainStart = System.currentTimeMillis();
    SparseModel<Regressor> model = trainer.train(train);
    final long trainStop = System.currentTimeMillis();
    logger.info("Finished training regressor " + Util.formatDuration(trainStart, trainStop));
    logger.info("Selected features: " + model.getActiveFeatures());
    // Log the learned weight vector and its norms for each regression dimension.
    Map<String, SparseVector> weights = ((SparseLinearModel) model).getWeights();
    for (Map.Entry<String, SparseVector> e : weights.entrySet()) {
        logger.info("Target:" + e.getKey());
        logger.info("\tWeights: " + e.getValue());
        logger.info("\tWeights one norm: " + e.getValue().oneNorm());
        logger.info("\tWeights two norm: " + e.getValue().twoNorm());
    }
    final long testStart = System.currentTimeMillis();
    RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model, test);
    final long testStop = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
    System.out.println(evaluation.toString());
    if (o.general.outputPath != null) {
        o.general.saveModel(model);
    }
}
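The ELASTICNET branch is the only one that ignores the feature cap; instead, alpha scales the overall regularisation penalty and l1Ratio mixes the L1 and L2 terms, with 1e-4 as the convergence tolerance. A standalone sketch of constructing that trainer outside the CLI, where the constructor signature mirrors the snippet above but the hyperparameter values are illustrative assumptions:

import org.tribuo.regression.slm.ElasticNetCDTrainer;

public class ElasticNetExample {
    public static void main(String[] args) {
        // alpha = 1.0, l1Ratio = 0.5 (equal L1/L2 mix), tolerance 1e-4,
        // 500 iterations, the snippet's boolean flag passed as false, fixed seed.
        ElasticNetCDTrainer trainer = new ElasticNetCDTrainer(1.0, 0.5, 1e-4, 500, false, 1L);
        System.out.println(trainer);
    }
}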
Use of com.oracle.labs.mlrg.olcut.config.UsageException in project tribuo by oracle.
From the class TrainTest, method main.
/**
 * Runs a TrainTest CLI.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    SGDOptions o = new SGDOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }
    if (o.general.trainingPath == null || o.general.testingPath == null) {
        logger.info(cm.usage());
        return;
    }
    logger.info("Configuring gradient optimiser");
    // Pick the loss function for the regression objective.
    RegressionObjective obj = null;
    switch (o.loss) {
        case ABSOLUTE:
            obj = new AbsoluteLoss();
            break;
        case SQUARED:
            obj = new SquaredLoss();
            break;
        case HUBER:
            obj = new Huber();
            break;
        default:
            logger.warning("Unknown objective function " + o.loss);
            logger.info(cm.usage());
            return;
    }
    StochasticGradientOptimiser grad = o.gradientOptions.getOptimiser();
    logger.info(String.format("Set logging interval to %d", o.loggingInterval));
    RegressionFactory factory = new RegressionFactory();
    Pair<Dataset<Regressor>, Dataset<Regressor>> data = o.general.load(factory);
    Dataset<Regressor> train = data.getA();
    Dataset<Regressor> test = data.getB();
    Trainer<Regressor> trainer = new LinearSGDTrainer(obj, grad, o.epochs, o.loggingInterval, o.minibatchSize, o.general.seed);
    logger.info("Training using " + trainer.toString());
    final long trainStart = System.currentTimeMillis();
    Model<Regressor> model = trainer.train(train);
    final long trainStop = System.currentTimeMillis();
    logger.info("Finished training regressor " + Util.formatDuration(trainStart, trainStop));
    final long testStart = System.currentTimeMillis();
    RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model, test);
    final long testStop = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
    System.out.println(evaluation.toString());
    if (o.general.outputPath != null) {
        o.general.saveModel(model);
    }
}
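Because the loss and the gradient optimiser are chosen independently, any RegressionObjective can be paired with any StochasticGradientOptimiser. A minimal sketch building the same kind of trainer directly, assuming AdaGrad as the optimiser; the learning rate and the remaining hyperparameters are illustrative, not values taken from the snippet:

import org.tribuo.Trainer;
import org.tribuo.math.optimisers.AdaGrad;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.sgd.linear.LinearSGDTrainer;
import org.tribuo.regression.sgd.objectives.SquaredLoss;

public class SGDExample {
    public static void main(String[] args) {
        // Squared loss + AdaGrad(0.1): 5 epochs, log every 100 examples,
        // minibatch size 1, fixed seed for reproducibility.
        Trainer<Regressor> trainer = new LinearSGDTrainer(new SquaredLoss(), new AdaGrad(0.1), 5, 100, 1, 1L);
        System.out.println(trainer);
    }
}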
Use of com.oracle.labs.mlrg.olcut.config.UsageException in project tribuo by oracle.
From the class InformationTheoryDemo, method main.
/**
 * Runs a simple demo of the information theory functions.
 * @param args The CLI arguments.
 */
public static void main(String[] args) {
    DemoOptions options = new DemoOptions();
    try {
        ConfigurationManager cm = new ConfigurationManager(args, options, false);
    } catch (UsageException e) {
        // Print the usage message and stop; without valid arguments there is nothing to demo.
        System.out.println(e.getUsage());
        return;
    }
    List<Integer> x;
    List<Integer> y;
    List<Integer> z;
    switch (options.type) {
        case RANDOM:
            // Three independent uniform samples; mutual information should be near zero.
            x = generateUniform(1000, 5);
            y = generateUniform(1000, 5);
            z = generateUniform(1000, 5);
            break;
        case XOR:
            CachedTriple<List<Integer>, List<Integer>, List<Integer>> trip = generateXOR(1000);
            x = trip.getA();
            y = trip.getB();
            z = trip.getC();
            break;
        case CORRELATED:
            CachedTriple<List<Integer>, List<Integer>, List<Integer>> tripC = generateCorrelated(1000, 5, 0.7, 0.5);
            x = tripC.getA();
            y = tripC.getB();
            z = tripC.getC();
            break;
        default:
            logger.log(Level.WARNING, "Unknown test case, exiting");
            return;
    }
    double hx = InformationTheory.entropy(x);
    double hy = InformationTheory.entropy(y);
    double hz = InformationTheory.entropy(z);
    double hxy = InformationTheory.jointEntropy(x, y);
    double hxz = InformationTheory.jointEntropy(x, z);
    double hyz = InformationTheory.jointEntropy(y, z);
    double ixy = InformationTheory.mi(x, y);
    double ixz = InformationTheory.mi(x, z);
    double iyz = InformationTheory.mi(y, z);
    InformationTheory.GTestStatistics gxy = InformationTheory.gTest(x, y, null);
    InformationTheory.GTestStatistics gxz = InformationTheory.gTest(x, z, null);
    InformationTheory.GTestStatistics gyz = InformationTheory.gTest(y, z, null);
    if (InformationTheory.LOG_BASE == InformationTheory.LOG_2) {
        logger.log(Level.INFO, "Using log_2");
    } else if (InformationTheory.LOG_BASE == InformationTheory.LOG_E) {
        logger.log(Level.INFO, "Using log_e");
    } else {
        logger.log(Level.INFO, "Using unexpected log base, LOG_BASE = " + InformationTheory.LOG_BASE);
    }
    logger.log(Level.INFO, "The entropy of X, H(X) is " + hx);
    logger.log(Level.INFO, "The entropy of Y, H(Y) is " + hy);
    logger.log(Level.INFO, "The entropy of Z, H(Z) is " + hz);
    logger.log(Level.INFO, "The joint entropy of X and Y, H(X,Y) is " + hxy);
    logger.log(Level.INFO, "The joint entropy of X and Z, H(X,Z) is " + hxz);
    logger.log(Level.INFO, "The joint entropy of Y and Z, H(Y,Z) is " + hyz);
    logger.log(Level.INFO, "The mutual information between X and Y, I(X;Y) is " + ixy);
    logger.log(Level.INFO, "The mutual information between X and Z, I(X;Z) is " + ixz);
    logger.log(Level.INFO, "The mutual information between Y and Z, I(Y;Z) is " + iyz);
    logger.log(Level.INFO, "The G-Test between X and Y, G(X;Y) is " + gxy);
    logger.log(Level.INFO, "The G-Test between X and Z, G(X;Z) is " + gxz);
    logger.log(Level.INFO, "The G-Test between Y and Z, G(Y;Z) is " + gyz);
}
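The demo only logs its results, but the same static methods can be called directly on any pair of discrete lists. A tiny self-contained example, assuming InformationTheory's default log base of 2 so that a balanced binary sequence comes out at one bit:

import java.util.Arrays;
import java.util.List;
import org.tribuo.util.infotheory.InformationTheory;

public class EntropyExample {
    public static void main(String[] args) {
        // A balanced binary sequence has one bit of entropy in log base 2.
        List<Integer> coin = Arrays.asList(0, 1, 0, 1, 0, 1, 0, 1);
        System.out.println("H(coin) = " + InformationTheory.entropy(coin));
        // A variable shares all of its information with itself: I(X;X) = H(X).
        System.out.println("I(coin;coin) = " + InformationTheory.mi(coin, coin));
    }
}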