Search in sources:

Example 21 with ConfigurationManager

Use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in the tribuo project by oracle.

Class ConfigurableTrainTest, method main.

/**
 * Command line entry point: loads a training and a testing dataset, trains a
 * model with the configured trainer, evaluates it on the testing dataset,
 * optionally saves the model, and optionally runs cross-validation over the
 * training dataset.
 *
 * @param args the command line arguments
 * @param <T> The {@link Output} subclass.
 */
@SuppressWarnings("unchecked")
public static <T extends Output<T>> void main(String[] args) {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();

    ConfigurableTrainTestOptions o = new ConfigurableTrainTestOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException usage) {
        logger.info(usage.getMessage());
        return;
    }

    // Both data paths and an output factory are required before anything can be loaded.
    boolean missingRequired = o.general.trainingPath == null
            || o.general.testingPath == null
            || o.outputFactory == null;
    if (missingRequired) {
        logger.info(cm.usage());
        System.exit(1);
    }

    Pair<Dataset<T>, Dataset<T>> data = null;
    try {
        data = o.general.load((OutputFactory<T>) o.outputFactory);
    } catch (IOException e) {
        logger.log(Level.SEVERE, "Failed to load data", e);
        System.exit(1);
    }
    Dataset<T> train = data.getA();
    Dataset<T> test = data.getB();

    if (o.trainer == null) {
        logger.warning("No trainer supplied");
        logger.info(cm.usage());
        System.exit(1);
    }
    // Wrap the trainer when a transformation map has been configured.
    if (o.transformationMap != null) {
        o.trainer = new TransformTrainer<>(o.trainer, o.transformationMap);
    }
    Trainer<T> trainer = (Trainer<T>) o.trainer;

    logger.info("Trainer is " + trainer.getProvenance().toString());
    logger.info("Outputs are " + train.getOutputInfo().toReadableString());
    logger.info("Number of features: " + train.getFeatureMap().size());

    // Train, timing the call.
    final long trainStart = System.currentTimeMillis();
    Model<T> model = trainer.train(train);
    final long trainStop = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));

    // Evaluate on the testing dataset, timing the call.
    Evaluator<T, ? extends Evaluation<T>> evaluator = train.getOutputFactory().getEvaluator();
    final long testStart = System.currentTimeMillis();
    Evaluation<T> evaluation = evaluator.evaluate(model, test);
    final long testStop = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
    System.out.println(evaluation.toString());

    if (o.general.outputPath != null) {
        try {
            o.general.saveModel(model);
        } catch (IOException e) {
            // A failed save is logged but does not stop the run.
            logger.log(Level.SEVERE, "Error writing model", e);
        }
    }

    if (!o.crossValidation) {
        return;
    }
    if (o.numFolds <= 1) {
        logger.warning("The number of cross-validation folds must be greater than 1, found " + o.numFolds);
        return;
    }

    logger.info("Running " + o.numFolds + " fold cross-validation");
    CrossValidation<T, ? extends Evaluation<T>> cv = new CrossValidation<>(trainer, train, evaluator, o.numFolds, o.general.seed);
    List<? extends Pair<? extends Evaluation<T>, Model<T>>> foldResults = cv.evaluate();
    List<Evaluation<T>> evals = new ArrayList<>();
    for (Pair<? extends Evaluation<T>, Model<T>> fold : foldResults) {
        evals.add(fold.getA());
    }

    // Summarize across everything, printing the metrics ordered by the second element of their id.
    Map<MetricID<T>, DescriptiveStats> summary = EvaluationAggregator.summarize(evals);
    List<MetricID<T>> keys = new ArrayList<>(summary.keySet());
    keys.sort(Comparator.comparing(Pair::getB));
    System.out.println("Summary across the folds:");
    for (MetricID<T> key : keys) {
        DescriptiveStats stats = summary.get(key);
        System.out.printf("%-10s  %.5f (%.5f)%n", key, stats.getMean(), stats.getStandardDeviation());
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) TransformTrainer(org.tribuo.transform.TransformTrainer) Trainer(org.tribuo.Trainer) MetricID(org.tribuo.evaluation.metrics.MetricID) DescriptiveStats(org.tribuo.evaluation.DescriptiveStats) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) Pair(com.oracle.labs.mlrg.olcut.util.Pair) Evaluation(org.tribuo.evaluation.Evaluation) Dataset(org.tribuo.Dataset) IOException(java.io.IOException) Model(org.tribuo.Model) CrossValidation(org.tribuo.evaluation.CrossValidation) OutputFactory(org.tribuo.OutputFactory)

Example 22 with ConfigurationManager

Use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in the tribuo project by oracle.

Class SQLToCSV, method main.

/**
 * Reads an SQL query from the input file, or from the standard input when no
 * input file is given, and writes the results of the query as CSV to the
 * output file, or to the standard output when no output file is given.
 * <p>
 * The process exits with status 1 when the options are invalid, the query
 * cannot be read or is empty, the database connection cannot be opened, or
 * the query cannot be run and written out.
 *
 * @param args The command line arguments, parsed into the options object
 *             (connection string via -n, user and password via -u and -p).
 */
public static void main(String[] args) {
    LabsLogFormatter.setAllLogFormatters();
    SQLToCSVOptions opts = new SQLToCSVOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, opts);
    } catch (UsageException e) {
        logger.info(e.getUsage());
        System.exit(1);
    }
    if (opts.dbConfig == null) {
        if (opts.connString == null) {
            logger.log(Level.SEVERE, "Must specify connection string with -n");
            System.exit(1);
        }
        if (opts.username != null || opts.password != null) {
            if (opts.username == null || opts.password == null) {
                logger.log(Level.SEVERE, "Must specify both of user and password with -u, -p if one is specified!");
                System.exit(1);
            }
        }
    } else if (opts.username != null || opts.password != null || opts.connString != null) {
        logger.warning("dbConfig provided but username/password/connstring also provided. Options from -u, -p, -n being ignored");
    }

    // Read the whole query, normalising line endings to "\n".
    String query;
    try (BufferedReader br = opts.inputPath != null ? Files.newBufferedReader(opts.inputPath) : new BufferedReader(new InputStreamReader(System.in, StandardCharsets.UTF_8))) {
        StringBuilder qsb = new StringBuilder();
        String l;
        while ((l = br.readLine()) != null) {
            qsb.append(l);
            qsb.append("\n");
        }
        query = qsb.toString().trim();
    } catch (IOException ex) {
        // Pass the exception as the throwable so its stack trace is logged.
        logger.log(Level.SEVERE, "Error reading query", ex);
        System.exit(1);
        return;
    }
    if (query.isEmpty()) {
        logger.log(Level.SEVERE, "Query is empty string");
        System.exit(1);
    }

    Connection conn;
    try {
        if (opts.dbConfig != null) {
            conn = opts.dbConfig.getConnection();
        } else if (opts.username != null) {
            conn = DriverManager.getConnection(opts.connString, opts.username, opts.password);
        } else {
            conn = DriverManager.getConnection(opts.connString);
        }
    } catch (SQLException ex) {
        logger.log(Level.SEVERE, "Can't connect to database: " + opts.connString, ex);
        System.exit(1);
        return;
    }

    // Close the connection on every path, including unchecked exceptions,
    // and only terminate the JVM once the cleanup has run.
    boolean succeeded;
    try {
        succeeded = writeQueryResults(conn, query, opts);
    } finally {
        try {
            conn.close();
        } catch (SQLException ex1) {
            logger.log(Level.SEVERE, "Failed to close connection", ex1);
        }
    }
    if (!succeeded) {
        System.exit(1);
    }
}

/**
 * Runs the query on the supplied connection and writes the result set as CSV.
 * Failures are logged here; the caller decides the exit status.
 *
 * @param conn The open database connection; it is not closed by this method.
 * @param query The SQL query to run.
 * @param opts The parsed options, used for the optional output path.
 * @return True if the query ran and all results were written, false otherwise.
 */
private static boolean writeQueryResults(Connection conn, String query, SQLToCSVOptions opts) {
    try (Statement stmt = conn.createStatement()) {
        stmt.setFetchSize(1000);
        stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
        ResultSet results;
        try {
            results = stmt.executeQuery(query);
        } catch (SQLException ex) {
            logger.log(Level.SEVERE, "Error running query", ex);
            return false;
        }
        try (ICSVWriter writer = new CSVParserWriter(opts.outputPath != null ? Files.newBufferedWriter(opts.outputPath) : new BufferedWriter(new OutputStreamWriter(System.out, StandardCharsets.UTF_8), 1024 * 1024), new RFC4180Parser(), "\n")) {
            // The second argument is true as in the original call; presumably it
            // requests a header row of column names - confirm against the opencsv docs.
            writer.writeAll(results, true);
            return true;
        } catch (IOException ex) {
            logger.log(Level.SEVERE, "Error writing CSV", ex);
            return false;
        } catch (SQLException ex) {
            logger.log(Level.SEVERE, "Error retrieving results", ex);
            return false;
        }
    } catch (SQLException ex) {
        logger.log(Level.SEVERE, "Couldn't create statement", ex);
        return false;
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) InputStreamReader(java.io.InputStreamReader) RFC4180Parser(com.opencsv.RFC4180Parser) SQLException(java.sql.SQLException) Statement(java.sql.Statement) Connection(java.sql.Connection) IOException(java.io.IOException) BufferedWriter(java.io.BufferedWriter) ICSVWriter(com.opencsv.ICSVWriter) BufferedReader(java.io.BufferedReader) ResultSet(java.sql.ResultSet) CSVParserWriter(com.opencsv.CSVParserWriter) OutputStreamWriter(java.io.OutputStreamWriter) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Example 23 with ConfigurationManager

Use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in the tribuo project by oracle.

Class SplitTextData, method main.

/**
 * Runs the SplitTextData CLI.
 * <p>
 * Reads lines of the form {@code label##text} from the input file, shuffles
 * the valid ones with the configured seed, and writes the first
 * {@code splitFraction} share to the training file and the rest to the
 * validation file. Blank lines and lines that do not split into exactly two
 * fields are counted as invalid and skipped.
 *
 * @param args The CLI arguments.
 * @throws IOException If the files could not be read or written to.
 */
public static void main(String[] args) throws IOException {
    // Use the labs format logging.
    for (Handler h : Logger.getLogger("").getHandlers()) {
        h.setLevel(Level.ALL);
        h.setFormatter(new LabsLogFormatter());
        try {
            h.setEncoding("utf-8");
        } catch (SecurityException | UnsupportedEncodingException ex) {
            logger.severe("Error setting output encoding");
        }
    }
    TrainTestSplitOptions options = new TrainTestSplitOptions();
    ConfigurationManager cm = new ConfigurationManager(args, options);
    if ((options.inputPath == null) || (options.trainPath == null) || (options.validationPath == null) || (options.splitFraction < 0.0) || (options.splitFraction > 1.0)) {
        System.out.println("Incorrect arguments");
        System.out.println(cm.usage());
        return;
    }
    int n = 0;
    int validCounter = 0;
    int invalidCounter = 0;
    ArrayList<Line> lines = new ArrayList<>();
    // Read the whole input before touching the output files, so a read
    // failure does not leave truncated outputs behind.
    try (BufferedReader input = new BufferedReader(new InputStreamReader(new FileInputStream(options.inputPath.toFile()), StandardCharsets.UTF_8))) {
        String rawLine;
        // readLine() returning null is the end-of-stream signal; ready() only
        // reports whether a read would block.
        while ((rawLine = input.readLine()) != null) {
            n++;
            String line = rawLine.trim();
            if (line.isEmpty()) {
                invalidCounter++;
                continue;
            }
            String[] fields = line.split("##");
            if (fields.length != 2) {
                invalidCounter++;
                // Log at most the first 50 characters of the offending line.
                logger.warning(String.format("Bad line in %s at %d: %s", options.inputPath, n, line.substring(0, Math.min(50, line.length()))));
                continue;
            }
            String label = fields[0].trim().toUpperCase();
            lines.add(new Line(label, fields[1]));
            validCounter++;
        }
    }
    logger.info("Found " + validCounter + " valid examples, " + invalidCounter + " invalid examples out of " + n + " lines.");
    int numTraining = Math.round(options.splitFraction * validCounter);
    int numTesting = validCounter - numTraining;
    logger.info("Outputting " + numTraining + " training examples, and " + numTesting + " testing examples, with a " + options.splitFraction + " split.");
    Collections.shuffle(lines, new Random(options.seed));
    try (PrintWriter trainOutput = new PrintWriter(new OutputStreamWriter(new BufferedOutputStream(new FileOutputStream(options.trainPath.toFile())), StandardCharsets.UTF_8));
         PrintWriter testOutput = new PrintWriter(new OutputStreamWriter(new BufferedOutputStream(new FileOutputStream(options.validationPath.toFile())), StandardCharsets.UTF_8))) {
        for (int i = 0; i < numTraining; i++) {
            trainOutput.println(lines.get(i));
        }
        for (int i = numTraining; i < validCounter; i++) {
            testOutput.println(lines.get(i));
        }
    }
}
Also used : InputStreamReader(java.io.InputStreamReader) ArrayList(java.util.ArrayList) Handler(java.util.logging.Handler) UnsupportedEncodingException(java.io.UnsupportedEncodingException) FileInputStream(java.io.FileInputStream) Random(java.util.Random) FileOutputStream(java.io.FileOutputStream) LabsLogFormatter(com.oracle.labs.mlrg.olcut.util.LabsLogFormatter) BufferedReader(java.io.BufferedReader) OutputStreamWriter(java.io.OutputStreamWriter) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) BufferedOutputStream(java.io.BufferedOutputStream) PrintWriter(java.io.PrintWriter)

Example 24 with ConfigurationManager

Use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in the tribuo project by oracle.

Class PreprocessAndSerialize, method main.

/**
 * Run the PreprocessAndSerialize CLI.
 * <p>
 * Reads the configured data source into a {@code MutableDataset} and writes
 * it with Java serialization to the configured output, using the zipped form
 * when the output name ends in "gz". Exits with status 1 on a usage error or
 * when the dataset cannot be written.
 *
 * @param args The CLI args.
 */
public static void main(String[] args) {
    LabsLogFormatter.setAllLogFormatters();
    PreprocessAndSerializeOptions opts = new PreprocessAndSerializeOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, opts);
    } catch (UsageException e) {
        logger.info(e.getUsage());
        System.exit(1);
    }
    logger.info("Reading datasource into dataset");
    MutableDataset<?> dataset = new MutableDataset<>(opts.dataSource);
    logger.info("Finished reading dataset");
    // Test the suffix on the textual form of the output. If opts.output is a
    // java.nio.file.Path, Path.endsWith(String) compares whole path elements,
    // so a file such as "data.ser.gz" would never be detected as zipped.
    String outputName = opts.output.toString();
    boolean zipped = outputName.endsWith("gz");
    if (zipped) {
        logger.info("Writing zipped dataset");
    }
    try (ObjectOutputStream os = IOUtil.getObjectOutputStream(outputName, zipped)) {
        os.writeObject(dataset);
    } catch (IOException e) {
        logger.log(Level.SEVERE, "Error writing serialized dataset", e);
        System.exit(1);
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) IOException(java.io.IOException) ObjectOutputStream(java.io.ObjectOutputStream) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) MutableDataset(org.tribuo.MutableDataset)

Example 25 with ConfigurationManager

Use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in the tribuo project by oracle.

Class CoreTokenizerOptionsTest, method testComma.

@Test
public void testComma() {
    // An escaped comma inside the comma-separated value is expected to be
    // parsed as a literal ',' element rather than as a separator.
    final char esc = ConfigurationManager.CUR_ESCAPE_CHAR;
    final String charList = "a," + esc + ",,b,c";
    final String[] cliArgs = new String[] { "--my-chars", charList };
    final CommaOptions parsed = new CommaOptions();
    final ConfigurationManager manager = new ConfigurationManager(cliArgs, parsed);
    manager.close();

    final char[] expected = { 'a', ',', 'b', 'c' };
    assertEquals(4, parsed.myChars.length);
    for (int i = 0; i < expected.length; i++) {
        assertEquals(expected[i], parsed.myChars[i]);
    }
}
Also used : ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) Test(org.junit.jupiter.api.Test)

Aggregations

ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager) 42, UsageException (com.oracle.labs.mlrg.olcut.config.UsageException) 32, Label (org.tribuo.classification.Label) 16, Dataset (org.tribuo.Dataset) 15, IOException (java.io.IOException) 8, FileOutputStream (java.io.FileOutputStream) 7, ObjectOutputStream (java.io.ObjectOutputStream) 7, RegressionFactory (org.tribuo.regression.RegressionFactory) 7, Regressor (org.tribuo.regression.Regressor) 7, RegressionEvaluation (org.tribuo.regression.evaluation.RegressionEvaluation) 7, BufferedWriter (java.io.BufferedWriter) 4, File (java.io.File) 4, OutputStreamWriter (java.io.OutputStreamWriter) 4, MutableDataset (org.tribuo.MutableDataset) 4, LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation) 4, LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator) 4, FileInputStream (java.io.FileInputStream) 3, ObjectInputStream (java.io.ObjectInputStream) 3, PrintWriter (java.io.PrintWriter) 3, ArrayList (java.util.ArrayList) 3