use of com.oracle.labs.mlrg.olcut.util.LabsLogFormatter in project tribuo by oracle.
the class SplitTextData method main.
/**
 * Runs the SplitTextData CLI.
 * <p>
 * Reads the input file line by line, expecting each non-empty line to hold
 * exactly two fields separated by {@code ##} (label first, text second).
 * Lines which are empty or do not split into exactly two fields are counted
 * as invalid and skipped. The valid lines are shuffled using the configured
 * seed, then the first {@code round(splitFraction * validCount)} are written
 * to the training path and the remainder to the validation path, all in UTF-8.
 * <p>
 * The output files are only opened once the input has been read completely,
 * so a failure while reading does not truncate existing output files.
 * @param args The CLI arguments.
 * @throws IOException If the files could not be read or written to.
 */
public static void main(String[] args) throws IOException {
    // Use the labs format logging.
    for (Handler h : Logger.getLogger("").getHandlers()) {
        h.setLevel(Level.ALL);
        h.setFormatter(new LabsLogFormatter());
        try {
            h.setEncoding("utf-8");
        } catch (SecurityException | UnsupportedEncodingException ex) {
            // Keep going with the default encoding, but record why the switch failed.
            logger.log(Level.SEVERE, "Error setting output encoding", ex);
        }
    }

    TrainTestSplitOptions options = new TrainTestSplitOptions();
    ConfigurationManager cm = new ConfigurationManager(args, options);
    // Written as a positive range test so that a NaN fraction is rejected too.
    boolean fractionInRange = (options.splitFraction >= 0.0) && (options.splitFraction <= 1.0);
    if ((options.inputPath == null) || (options.trainPath == null) || (options.validationPath == null) || !fractionInRange) {
        System.out.println("Incorrect arguments");
        System.out.println(cm.usage());
        return;
    }

    int n = 0;
    int validCounter = 0;
    int invalidCounter = 0;
    ArrayList<Line> lines = new ArrayList<>();

    try (BufferedReader input = new BufferedReader(new InputStreamReader(new FileInputStream(options.inputPath.toFile()), StandardCharsets.UTF_8))) {
        String rawLine;
        // readLine() returning null is the end-of-stream signal; ready() only says
        // whether a read would block and is not a reliable EOF test.
        while ((rawLine = input.readLine()) != null) {
            n++;
            String line = rawLine.trim();
            if (line.isEmpty()) {
                invalidCounter++;
                continue;
            }
            String[] fields = line.split("##");
            if (fields.length != 2) {
                invalidCounter++;
                // Log at most the first 50 characters of the offending line.
                logger.warning(String.format("Bad line in %s at %d: %s", options.inputPath, n, line.substring(0, Math.min(50, line.length()))));
                continue;
            }
            String label = fields[0].trim().toUpperCase();
            lines.add(new Line(label, fields[1]));
            validCounter++;
        }
    }

    logger.info("Found " + validCounter + " valid examples, " + invalidCounter + " invalid examples out of " + n + " lines.");
    int numTraining = Math.round(options.splitFraction * validCounter);
    int numTesting = validCounter - numTraining;
    logger.info("Outputting " + numTraining + " training examples, and " + numTesting + " testing examples, with a " + options.splitFraction + " split.");

    Collections.shuffle(lines, new Random(options.seed));

    try (PrintWriter trainOutput = new PrintWriter(new OutputStreamWriter(new BufferedOutputStream(new FileOutputStream(options.trainPath.toFile())), StandardCharsets.UTF_8));
         PrintWriter testOutput = new PrintWriter(new OutputStreamWriter(new BufferedOutputStream(new FileOutputStream(options.validationPath.toFile())), StandardCharsets.UTF_8))) {
        for (int i = 0; i < numTraining; i++) {
            trainOutput.println(lines.get(i));
        }
        for (int i = numTraining; i < validCounter; i++) {
            testOutput.println(lines.get(i));
        }
        // PrintWriter never throws on write failure; checkError() flushes and
        // reports whether any write failed so the failure can be surfaced.
        if (trainOutput.checkError()) {
            throw new IOException("Failed to write training examples to " + options.trainPath);
        }
        if (testOutput.checkError()) {
            throw new IOException("Failed to write testing examples to " + options.validationPath);
        }
    }
}
Aggregations