Search in sources:

Example 1 with LabsLogFormatter

Use of com.oracle.labs.mlrg.olcut.util.LabsLogFormatter in the Tribuo project by Oracle.

In class SplitTextData, method main.

/**
 * Runs the SplitTextData CLI.
 * <p>
 * Reads {@code inputPath} as UTF-8 text in which each non-blank line is expected to have the
 * form {@code label##text}. Lines that are blank or do not split into exactly two fields are
 * counted as invalid and skipped. The valid lines are shuffled using {@code seed}. The first
 * {@code round(splitFraction * validLines)} of them are written to {@code trainPath` and the
 * rest to {@code validationPath}, both as UTF-8.
 * @param args The CLI arguments.
 * @throws IOException If the files could not be read or written to.
 */
public static void main(String[] args) throws IOException {
    // Use the labs format logging.
    for (Handler h : Logger.getLogger("").getHandlers()) {
        h.setLevel(Level.ALL);
        h.setFormatter(new LabsLogFormatter());
        try {
            h.setEncoding("utf-8");
        } catch (SecurityException | UnsupportedEncodingException ex) {
            logger.severe("Error setting output encoding");
        }
    }
    TrainTestSplitOptions options = new TrainTestSplitOptions();
    ConfigurationManager cm = new ConfigurationManager(args, options);
    if ((options.inputPath == null) || (options.trainPath == null) || (options.validationPath == null) || (options.splitFraction < 0.0) || (options.splitFraction > 1.0)) {
        System.out.println("Incorrect arguments");
        System.out.println(cm.usage());
        return;
    }
    int n = 0;
    int validCounter = 0;
    int invalidCounter = 0;
    ArrayList<Line> lines = new ArrayList<>();
    // Read all input before opening the outputs, so a read failure does not leave truncated
    // output files behind. try-with-resources closes the reader on every exit path.
    try (BufferedReader input = new BufferedReader(new InputStreamReader(new FileInputStream(options.inputPath.toFile()), StandardCharsets.UTF_8))) {
        String rawLine;
        // readLine() returning null is the end-of-stream signal. ready() only says whether the
        // next read would block, so it can end the loop before the file is exhausted.
        while ((rawLine = input.readLine()) != null) {
            n++;
            String line = rawLine.trim();
            if (line.isEmpty()) {
                invalidCounter++;
                continue;
            }
            String[] fields = line.split("##");
            if (fields.length != 2) {
                invalidCounter++;
                // Log at most the first 50 characters of the offending line.
                logger.warning(String.format("Bad line in %s at %d: %s", options.inputPath, n, line.substring(0, Math.min(50, line.length()))));
                continue;
            }
            // Locale.ROOT keeps label normalisation independent of the JVM's default locale.
            String label = fields[0].trim().toUpperCase(java.util.Locale.ROOT);
            lines.add(new Line(label, fields[1]));
            validCounter++;
        }
    }
    logger.info("Found " + validCounter + " valid examples, " + invalidCounter + " invalid examples out of " + n + " lines.");
    int numTraining = Math.round(options.splitFraction * validCounter);
    int numTesting = validCounter - numTraining;
    logger.info("Outputting " + numTraining + " training examples, and " + numTesting + " testing examples, with a " + options.splitFraction + " split.");
    Collections.shuffle(lines, new Random(options.seed));
    // Both writers are closed even if opening the second one or writing fails.
    try (PrintWriter trainOutput = new PrintWriter(new OutputStreamWriter(new BufferedOutputStream(new FileOutputStream(options.trainPath.toFile())), StandardCharsets.UTF_8));
         PrintWriter testOutput = new PrintWriter(new OutputStreamWriter(new BufferedOutputStream(new FileOutputStream(options.validationPath.toFile())), StandardCharsets.UTF_8))) {
        for (int i = 0; i < numTraining; i++) {
            trainOutput.println(lines.get(i));
        }
        for (int i = numTraining; i < validCounter; i++) {
            testOutput.println(lines.get(i));
        }
        // PrintWriter never throws IOException itself. checkError() flushes and reports whether
        // any write failed, so failures surface as the documented IOException.
        if (trainOutput.checkError()) {
            throw new IOException("Failed to write training data to " + options.trainPath);
        }
        if (testOutput.checkError()) {
            throw new IOException("Failed to write validation data to " + options.validationPath);
        }
    }
}
Also used: InputStreamReader (java.io.InputStreamReader), ArrayList (java.util.ArrayList), Handler (java.util.logging.Handler), UnsupportedEncodingException (java.io.UnsupportedEncodingException), FileInputStream (java.io.FileInputStream), Random (java.util.Random), FileOutputStream (java.io.FileOutputStream), LabsLogFormatter (com.oracle.labs.mlrg.olcut.util.LabsLogFormatter), BufferedReader (java.io.BufferedReader), OutputStreamWriter (java.io.OutputStreamWriter), ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager), BufferedOutputStream (java.io.BufferedOutputStream), PrintWriter (java.io.PrintWriter)

Aggregations

ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager): 1, LabsLogFormatter (com.oracle.labs.mlrg.olcut.util.LabsLogFormatter): 1, BufferedOutputStream (java.io.BufferedOutputStream): 1, BufferedReader (java.io.BufferedReader): 1, FileInputStream (java.io.FileInputStream): 1, FileOutputStream (java.io.FileOutputStream): 1, InputStreamReader (java.io.InputStreamReader): 1, OutputStreamWriter (java.io.OutputStreamWriter): 1, PrintWriter (java.io.PrintWriter): 1, UnsupportedEncodingException (java.io.UnsupportedEncodingException): 1, ArrayList (java.util.ArrayList): 1, Random (java.util.Random): 1, Handler (java.util.logging.Handler): 1