use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.
the class ConfigurableTrainTest method main.
/**
 * Trains a model with the configured trainer, evaluates it on the test data and
 * prints the evaluation. The model is saved when an output path is supplied, and
 * cross-validation over the training data is run when requested.
 *
 * @param args the command line arguments
 * @param <T> The {@link Output} subclass.
 */
@SuppressWarnings("unchecked")
public static <T extends Output<T>> void main(String[] args) {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();

    ConfigurableTrainTestOptions opts = new ConfigurableTrainTestOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, opts);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }

    // The training path, testing path and output factory are all mandatory.
    boolean missingRequiredOption = opts.general.trainingPath == null
            || opts.general.testingPath == null
            || opts.outputFactory == null;
    if (missingRequiredOption) {
        logger.info(cm.usage());
        System.exit(1);
    }

    Pair<Dataset<T>, Dataset<T>> data = null;
    try {
        data = opts.general.load((OutputFactory<T>) opts.outputFactory);
    } catch (IOException e) {
        logger.log(Level.SEVERE, "Failed to load data", e);
        System.exit(1);
    }
    Dataset<T> train = data.getA();
    Dataset<T> test = data.getB();

    if (opts.trainer == null) {
        logger.warning("No trainer supplied");
        logger.info(cm.usage());
        System.exit(1);
    }

    // Wrap the trainer so the transformations are applied before training.
    if (opts.transformationMap != null) {
        opts.trainer = new TransformTrainer<>(opts.trainer, opts.transformationMap);
    }

    logger.info("Trainer is " + opts.trainer.getProvenance().toString());
    logger.info("Outputs are " + train.getOutputInfo().toReadableString());
    logger.info("Number of features: " + train.getFeatureMap().size());

    // Train and time the model.
    final long trainBegin = System.currentTimeMillis();
    Model<T> model = ((Trainer<T>) opts.trainer).train(train);
    final long trainEnd = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainBegin, trainEnd));

    // Evaluate and time the model on the test data.
    Evaluator<T, ? extends Evaluation<T>> evaluator = train.getOutputFactory().getEvaluator();
    final long testBegin = System.currentTimeMillis();
    Evaluation<T> evaluation = evaluator.evaluate(model, test);
    final long testEnd = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(testBegin, testEnd));
    System.out.println(evaluation.toString());

    if (opts.general.outputPath != null) {
        try {
            opts.general.saveModel(model);
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Error writing model", e);
        }
    }

    if (!opts.crossValidation) {
        return;
    }
    if (opts.numFolds <= 1) {
        logger.warning("The number of cross-validation folds must be greater than 1, found " + opts.numFolds);
        return;
    }

    logger.info("Running " + opts.numFolds + " fold cross-validation");
    CrossValidation<T, ? extends Evaluation<T>> cv = new CrossValidation<>((Trainer<T>) opts.trainer, train, evaluator, opts.numFolds, opts.general.seed);
    List<? extends Pair<? extends Evaluation<T>, Model<T>>> foldResults = cv.evaluate();
    List<Evaluation<T>> foldEvaluations = new ArrayList<>();
    for (Pair<? extends Evaluation<T>, Model<T>> foldResult : foldResults) {
        foldEvaluations.add(foldResult.getA());
    }

    // Summarize across everything
    Map<MetricID<T>, DescriptiveStats> summary = EvaluationAggregator.summarize(foldEvaluations);
    List<MetricID<T>> sortedKeys = new ArrayList<>(summary.keySet());
    sortedKeys.sort(Comparator.comparing(Pair::getB));
    System.out.println("Summary across the folds:");
    for (MetricID<T> metric : sortedKeys) {
        DescriptiveStats stats = summary.get(metric);
        System.out.printf("%-10s %.5f (%.5f)%n", metric, stats.getMean(), stats.getStandardDeviation());
    }
}
use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.
the class SQLToCSV method main.
/**
 * Reads an SQL query from the configured input path (or the standard input when no
 * input path is given), runs it, and writes the results as CSV to the configured
 * output path (or the standard output when no output path is given).
 * <p>
 * The process exits with status 1 when the options are invalid, the query cannot be
 * read or is empty, the database connection cannot be opened, or running the query or
 * writing the CSV fails.
 *
 * @param args The command line arguments, parsed into the options object.
 */
public static void main(String[] args) {
    LabsLogFormatter.setAllLogFormatters();
    SQLToCSVOptions opts = new SQLToCSVOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, opts);
    } catch (UsageException e) {
        logger.info(e.getUsage());
        System.exit(1);
    }
    if (opts.dbConfig == null) {
        if (opts.connString == null) {
            logger.log(Level.SEVERE, "Must specify connection string with -n");
            System.exit(1);
        }
        if (opts.username != null || opts.password != null) {
            if (opts.username == null || opts.password == null) {
                logger.log(Level.SEVERE, "Must specify both of user and password with -u, -p if one is specified!");
                System.exit(1);
            }
        }
    } else if (opts.username != null || opts.password != null || opts.connString != null) {
        logger.warning("dbConfig provided but username/password/connstring also provided. Options from -u, -p, -n being ignored");
    }

    String query;
    try (BufferedReader br = opts.inputPath != null ? Files.newBufferedReader(opts.inputPath) : new BufferedReader(new InputStreamReader(System.in, StandardCharsets.UTF_8))) {
        StringBuilder qsb = new StringBuilder();
        String l;
        while ((l = br.readLine()) != null) {
            qsb.append(l);
            qsb.append("\n");
        }
        query = qsb.toString().trim();
    } catch (IOException ex) {
        // Pass the exception as the throwable so the stack trace is logged.
        logger.log(Level.SEVERE, "Error reading query", ex);
        System.exit(1);
        return;
    }
    if (query.isEmpty()) {
        logger.log(Level.SEVERE, "Query is empty string");
        System.exit(1);
    }

    Connection conn;
    try {
        if (opts.dbConfig != null) {
            conn = opts.dbConfig.getConnection();
        } else if (opts.username != null) {
            conn = DriverManager.getConnection(opts.connString, opts.username, opts.password);
        } else {
            conn = DriverManager.getConnection(opts.connString);
        }
    } catch (SQLException ex) {
        logger.log(Level.SEVERE, "Can't connect to database: " + opts.connString, ex);
        System.exit(1);
        return;
    }

    // Run the export without calling System.exit inside it, so the connection is
    // always closed in one place before the exit status is reported.
    boolean succeeded;
    try {
        succeeded = writeQueryResults(conn, query, opts);
    } finally {
        try {
            conn.close();
        } catch (SQLException ex1) {
            logger.log(Level.SEVERE, "Failed to close connection", ex1);
        }
    }
    if (!succeeded) {
        System.exit(1);
    }
}

/**
 * Runs the query on the supplied connection and writes the result set as CSV.
 * Failures are logged rather than thrown; the connection is left open for the caller to close.
 *
 * @param conn The open database connection.
 * @param query The SQL query to run.
 * @param opts The parsed options, read for the output path.
 * @return True if the query ran and all results were written, false otherwise.
 */
private static boolean writeQueryResults(Connection conn, String query, SQLToCSVOptions opts) {
    try (Statement stmt = conn.createStatement()) {
        stmt.setFetchSize(1000);
        stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
        ResultSet results;
        try {
            results = stmt.executeQuery(query);
        } catch (SQLException ex) {
            logger.log(Level.SEVERE, "Error running query", ex);
            return false;
        }
        try (ICSVWriter writer = new CSVParserWriter(opts.outputPath != null ? Files.newBufferedWriter(opts.outputPath) : new BufferedWriter(new OutputStreamWriter(System.out, StandardCharsets.UTF_8), 1024 * 1024), new RFC4180Parser(), "\n")) {
            writer.writeAll(results, true);
        } catch (IOException ex) {
            logger.log(Level.SEVERE, "Error writing CSV", ex);
            return false;
        } catch (SQLException ex) {
            logger.log(Level.SEVERE, "Error retrieving results", ex);
            return false;
        }
        return true;
    } catch (SQLException ex) {
        // Also reached when configuring or closing the statement fails.
        logger.log(Level.SEVERE, "Couldn't create statement", ex);
        return false;
    }
}
use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.
the class SplitTextData method main.
/**
 * Runs the SplitTextData CLI.
 * <p>
 * Reads lines of the form {@code label##text} from the input path, shuffles the valid
 * lines with the configured seed and writes the first {@code splitFraction} share of them
 * to the training path and the remainder to the validation path. Blank lines and lines
 * which do not split into exactly two fields are counted as invalid and skipped.
 *
 * @param args The CLI arguments.
 * @throws IOException If the files could not be read or written to.
 */
public static void main(String[] args) throws IOException {
    // Use the labs format logging.
    for (Handler h : Logger.getLogger("").getHandlers()) {
        h.setLevel(Level.ALL);
        h.setFormatter(new LabsLogFormatter());
        try {
            h.setEncoding("utf-8");
        } catch (SecurityException | UnsupportedEncodingException ex) {
            logger.severe("Error setting output encoding");
        }
    }
    TrainTestSplitOptions options = new TrainTestSplitOptions();
    ConfigurationManager cm = new ConfigurationManager(args, options);
    if ((options.inputPath == null) || (options.trainPath == null) || (options.validationPath == null) || (options.splitFraction < 0.0) || (options.splitFraction > 1.0)) {
        System.out.println("Incorrect arguments");
        System.out.println(cm.usage());
        return;
    }
    int n = 0;
    int validCounter = 0;
    int invalidCounter = 0;
    ArrayList<Line> lines = new ArrayList<>();
    // Read all the input before opening (and truncating) the output files.
    try (BufferedReader input = new BufferedReader(new InputStreamReader(new FileInputStream(options.inputPath.toFile()), StandardCharsets.UTF_8))) {
        String rawLine;
        // readLine returns null at end of stream; ready() is not an end-of-file test.
        while ((rawLine = input.readLine()) != null) {
            n++;
            String line = rawLine.trim();
            if (line.isEmpty()) {
                invalidCounter++;
                continue;
            }
            String[] fields = line.split("##");
            if (fields.length != 2) {
                invalidCounter++;
                // Log at most the first 50 characters of the offending line.
                logger.warning(String.format("Bad line in %s at %d: %s", options.inputPath, n, line.substring(0, Math.min(50, line.length()))));
                continue;
            }
            String label = fields[0].trim().toUpperCase();
            lines.add(new Line(label, fields[1]));
            validCounter++;
        }
    }
    logger.info("Found " + validCounter + " valid examples, " + invalidCounter + " invalid examples out of " + n + " lines.");
    int numTraining = Math.round(options.splitFraction * validCounter);
    int numTesting = validCounter - numTraining;
    logger.info("Outputting " + numTraining + " training examples, and " + numTesting + " testing examples, with a " + options.splitFraction + " split.");
    Collections.shuffle(lines, new Random(options.seed));
    try (PrintWriter trainOutput = new PrintWriter(new OutputStreamWriter(new BufferedOutputStream(new FileOutputStream(options.trainPath.toFile())), StandardCharsets.UTF_8));
         PrintWriter testOutput = new PrintWriter(new OutputStreamWriter(new BufferedOutputStream(new FileOutputStream(options.validationPath.toFile())), StandardCharsets.UTF_8))) {
        for (int i = 0; i < numTraining; i++) {
            trainOutput.println(lines.get(i));
        }
        for (int i = numTraining; i < validCounter; i++) {
            testOutput.println(lines.get(i));
        }
        // PrintWriter swallows IOExceptions, so surface any write failure explicitly.
        if (trainOutput.checkError()) {
            throw new IOException("Failed to write training examples to " + options.trainPath);
        }
        if (testOutput.checkError()) {
            throw new IOException("Failed to write testing examples to " + options.validationPath);
        }
    }
}
use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.
the class PreprocessAndSerialize method main.
/**
 * Run the PreprocessAndSerialize CLI.
 * <p>
 * Loads the configured data source into a {@code MutableDataset} and serializes it to
 * the configured output location. The output is gzipped when the output location's
 * string form ends with "gz".
 *
 * @param args The CLI args.
 */
public static void main(String[] args) {
    LabsLogFormatter.setAllLogFormatters();
    PreprocessAndSerializeOptions opts = new PreprocessAndSerializeOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, opts);
    } catch (UsageException e) {
        logger.info(e.getUsage());
        System.exit(1);
    }
    logger.info("Reading datasource into dataset");
    MutableDataset<?> dataset = new MutableDataset<>(opts.dataSource);
    logger.info("Finished reading dataset");
    // Test the suffix on the string form: on a java.nio.file.Path, endsWith(String)
    // compares whole path elements, so "data.ser.gz" would not match "gz".
    String outputLocation = opts.output.toString();
    boolean gzip = outputLocation.endsWith("gz");
    if (gzip) {
        logger.info("Writing zipped dataset");
    }
    try (ObjectOutputStream os = IOUtil.getObjectOutputStream(outputLocation, gzip)) {
        os.writeObject(dataset);
    } catch (IOException e) {
        logger.log(Level.SEVERE, "Error writing serialized dataset", e);
        System.exit(1);
    }
}
use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.
the class CoreTokenizerOptionsTest method testComma.
/**
 * Checks that an escaped comma inside a comma separated char list is parsed as a
 * literal comma element rather than as a separator.
 */
@Test
public void testComma() {
    char escape = ConfigurationManager.CUR_ESCAPE_CHAR;
    // The second element is an escaped comma, followed by the usual separator.
    String charList = "a," + escape + ",,b,c";
    String[] args = new String[] { "--my-chars", charList };
    CommaOptions options = new CommaOptions();
    ConfigurationManager cm = new ConfigurationManager(args, options);
    cm.close();

    char[] expected = { 'a', ',', 'b', 'c' };
    assertEquals(4, options.myChars.length);
    for (int i = 0; i < expected.length; i++) {
        assertEquals(expected[i], options.myChars[i]);
    }
}
Aggregations