use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.
the class TrainTest method main.
/**
 * Runs a TrainTest CLI.
 * <p>
 * Builds a LibSVM trainer from the parsed options, lets the ensemble options wrap it,
 * and hands the result to {@code TrainTestHelper.run}. If the returned model is a
 * {@link LibSVMClassificationModel}, the number of support vectors is logged.
 * Usage requests print the usage text; malformed arguments print an error message.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    TrainTestOptions options = new TrainTestOptions();
    try (ConfigurationManager manager = new ConfigurationManager(args, options)) {
        Trainer<Label> baseTrainer = options.libsvmOptions.getTrainer();
        Trainer<Label> finalTrainer = options.ensembleOptions.wrapTrainer(baseTrainer);
        Model<Label> trained = TrainTestHelper.run(manager, options.general, finalTrainer);
        if (trained instanceof LibSVMClassificationModel) {
            LibSVMClassificationModel svmModel = (LibSVMClassificationModel) trained;
            logger.info("Used " + svmModel.getNumberOfSupportVectors() + " support vectors");
        }
    } catch (UsageException usage) {
        System.out.println(usage.getUsage());
    } catch (ArgumentException badArg) {
        System.out.println("Invalid argument: " + badArg.getMessage());
    }
}
use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.
the class RunAll method main.
/**
 * Runs the RunAll CLI.
 * <p>
 * Loads the train and test datasets, then trains every {@link Trainer} found in the
 * configuration and evaluates each resulting model on the test set. For each trainer it
 * writes a serialized model ({@code <name>.model}) and a text report ({@code <name>.output})
 * into the output directory, where {@code <name>} is the trainer's simple class name.
 * Finally it logs the micro-averaged F1 of every trainer. Missing required paths or a data
 * loading failure terminate the JVM with exit status 1.
 * @param args The CLI arguments.
 * @throws IOException If writing a model or report file failed.
 */
public static void main(String[] args) throws IOException {
    LabsLogFormatter.setAllLogFormatters();
    RunAllOptions o = new RunAllOptions();
    // try-with-resources ensures the ConfigurationManager is closed when the run finishes.
    try (ConfigurationManager cm = new ConfigurationManager(args, o)) {
        if (o.general.trainingPath == null || o.general.testingPath == null || o.directory == null) {
            logger.info(cm.usage());
            System.exit(1);
        }
        final Pair<Dataset<Label>, Dataset<Label>> data;
        try {
            data = o.general.load(new LabelFactory());
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Failed to load data", e);
            System.exit(1);
            // Not reached at runtime; makes 'data' definitely assigned below instead of relying on a null sentinel.
            return;
        }
        Dataset<Label> train = data.getA();
        Dataset<Label> test = data.getB();
        logger.info("Creating directory - " + o.directory.toString());
        if (!o.directory.exists() && !o.directory.mkdirs()) {
            logger.warning("Failed to create directory.");
        }
        Map<String, Double> performances = new HashMap<>();
        List<Trainer> trainers = cm.lookupAll(Trainer.class);
        for (Trainer<?> t : trainers) {
            String name = t.getClass().getSimpleName();
            logger.info("Training model using " + t.toString());
            // configuration system cast.
            @SuppressWarnings("unchecked") Model<Label> curModel = ((Trainer<Label>) t).train(train);
            LabelEvaluator evaluator = new LabelEvaluator();
            LabelEvaluation evaluation = evaluator.evaluate(curModel, test);
            Double old = performances.put(name, evaluation.microAveragedF1());
            if (old != null) {
                // The earlier trainer's F1 entry and its output files (same name) are overwritten.
                logger.warning("Found two trainers with the name " + name);
            }
            String outputPath = o.directory.toString() + "/" + name;
            try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(outputPath + ".model"))) {
                oos.writeObject(curModel);
            }
            try (PrintWriter writer = new PrintWriter(new OutputStreamWriter(new FileOutputStream(outputPath + ".output"), StandardCharsets.UTF_8))) {
                writer.println("Model = " + name);
                writer.println("Provenance = " + curModel.toString());
                writer.println();
                ConfusionMatrix<Label> matrix = evaluation.getConfusionMatrix();
                writer.println("ConfusionMatrix:\n" + matrix.toString());
                writer.println();
                writer.println("Evaluation:\n" + evaluation.toString());
            }
        }
        for (Map.Entry<String, Double> e : performances.entrySet()) {
            logger.info("Trainer = " + e.getKey() + ", F1 = " + e.getValue());
        }
    } catch (UsageException e) {
        logger.info(e.getMessage());
    }
}
use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.
the class TrainTest method main.
/**
 * Runs a TrainTest CLI.
 * <p>
 * Loads the train and test datasets and trains a LibSVM regressor built from the
 * options (SVM type, kernel, gamma, coefficient, degree, standardization). It logs the
 * training and evaluation durations and prints the regression evaluation to standard out.
 * The model is saved when an output path was supplied.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    //
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    LibSVMOptions o = new LibSVMOptions();
    // try-with-resources ensures the ConfigurationManager is closed when the run finishes.
    try (ConfigurationManager cm = new ConfigurationManager(args, o)) {
        if (o.general.trainingPath == null || o.general.testingPath == null) {
            logger.info(cm.usage());
            return;
        }
        RegressionFactory factory = new RegressionFactory();
        Pair<Dataset<Regressor>, Dataset<Regressor>> data = o.general.load(factory);
        Dataset<Regressor> train = data.getA();
        Dataset<Regressor> test = data.getB();
        SVMParameters<Regressor> parameters = new SVMParameters<>(new SVMRegressionType(o.svmType), o.kernelType);
        parameters.setGamma(o.gamma);
        parameters.setCoeff(o.coeff);
        parameters.setDegree(o.degree);
        Trainer<Regressor> trainer = new LibSVMRegressionTrainer(parameters, o.standardize);
        logger.info("Training using " + trainer.toString());
        final long trainStart = System.currentTimeMillis();
        Model<Regressor> model = trainer.train(train);
        final long trainStop = System.currentTimeMillis();
        logger.info("Finished training regressor " + Util.formatDuration(trainStart, trainStop));
        // Guard the downcast so an unexpected model type cannot abort the run after training succeeded.
        if (model instanceof LibSVMRegressionModel) {
            logger.info("Support vectors - " + ((LibSVMRegressionModel) model).getNumberOfSupportVectors());
        }
        final long testStart = System.currentTimeMillis();
        RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model, test);
        final long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
        System.out.println(evaluation.toString());
        if (o.general.outputPath != null) {
            o.general.saveModel(model);
        }
    } catch (UsageException e) {
        logger.info(e.getMessage());
    }
}
use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.
the class OCIModelCLI method main.
/**
 * Entry point for the OCIModelCLI.
 * <p>
 * Parses the arguments into {@code OCIModelOptions} and dispatches on the selected mode:
 * create-and-deploy, deploy, or score.
 * @param args The CLI arguments.
 * @throws IOException If the model or dataset failed to load.
 * @throws ClassNotFoundException If the model or dataset had an unknown class.
 */
public static void main(String[] args) throws IOException, ClassNotFoundException {
    OCIModelOptions o = new OCIModelOptions();
    // The ConfigurationManager is only used here to populate 'o'; try-with-resources
    // closes it once the selected mode has finished.
    try (ConfigurationManager cm = new ConfigurationManager(args, o, false)) {
        switch (o.mode) {
            case CREATE_AND_DEPLOY:
                createModelAndDeploy(o);
                break;
            case DEPLOY:
                deploy(o);
                break;
            case SCORE:
                modelScoring(o);
                break;
        }
    } catch (UsageException e) {
        logger.info(e.getMessage());
    }
}
use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.
the class TrainTest method main.
/**
 * Runs a TrainTest CLI.
 * <p>
 * Loads the training data and trains a K-Means model configured from the options. The
 * model is evaluated on that same training data, since only the first dataset of the
 * loaded pair is used. The normalized and adjusted mutual information are printed to
 * standard out, and the model is saved when an output path was supplied.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    //
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    KMeansOptions o = new KMeansOptions();
    // try-with-resources ensures the ConfigurationManager is closed when the run finishes.
    try (ConfigurationManager cm = new ConfigurationManager(args, o)) {
        if (o.general.trainingPath == null) {
            logger.info(cm.usage());
            return;
        }
        ClusteringFactory factory = new ClusteringFactory();
        Pair<Dataset<ClusterID>, Dataset<ClusterID>> data = o.general.load(factory);
        Dataset<ClusterID> train = data.getA();
        // Constructor arguments, in order: centroids, iterations, distance type,
        // initialisation, number of threads, seed.
        KMeansTrainer trainer = new KMeansTrainer(o.centroids, o.iterations, o.distType, o.initialisation, o.numThreads, o.general.seed);
        Model<ClusterID> model = trainer.train(train);
        logger.info("Finished training model");
        ClusteringEvaluation evaluation = factory.getEvaluator().evaluate(model, train);
        logger.info("Finished evaluating model");
        System.out.println("Normalized MI = " + evaluation.normalizedMI());
        System.out.println("Adjusted MI = " + evaluation.adjustedMI());
        if (o.general.outputPath != null) {
            o.general.saveModel(model);
        }
    } catch (UsageException e) {
        logger.info(e.getMessage());
    }
}
Aggregations