Search in sources:

Example 16 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

From the class TrainTest, method main.

/**
 * Runs a TrainTest CLI.
 * <p>
 * Obtains the LibSVM classification trainer from the parsed options, passes it through the
 * ensemble options' wrapper, and hands it to {@code TrainTestHelper.run}. If the returned model is
 * a {@code LibSVMClassificationModel}, its support vector count is logged.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    TrainTestOptions options = new TrainTestOptions();
    try (ConfigurationManager cm = new ConfigurationManager(args, options)) {
        Trainer<Label> baseTrainer = options.libsvmOptions.getTrainer();
        Trainer<Label> trainer = options.ensembleOptions.wrapTrainer(baseTrainer);
        Model<Label> model = TrainTestHelper.run(cm, options.general, trainer);
        if (!(model instanceof LibSVMClassificationModel)) {
            // Only LibSVM models expose a support vector count.
            return;
        }
        LibSVMClassificationModel svmModel = (LibSVMClassificationModel) model;
        logger.info("Used " + svmModel.getNumberOfSupportVectors() + " support vectors");
    } catch (UsageException e) {
        // Print the usage text carried by the exception.
        System.out.println(e.getUsage());
    } catch (ArgumentException e) {
        System.out.println("Invalid argument: " + e.getMessage());
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) Label(org.tribuo.classification.Label) ArgumentException(com.oracle.labs.mlrg.olcut.config.ArgumentException) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Example 17 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

From the class RunAll, method main.

/**
 * Runs the RunAll CLI.
 * <p>
 * Loads a train/test dataset pair, then trains every {@link Trainer} found in the configuration
 * and evaluates each resulting model on the test data. For each trainer it writes two files into
 * the output directory: {@code <name>.model} (the serialized model) and {@code <name>.output}
 * (model provenance, confusion matrix and evaluation). The micro-averaged F1 of each trainer is
 * logged at the end.
 * <p>
 * Trainers are named by their class's simple name. If several trainers share a class, the later
 * ones get a numeric suffix ({@code Name-2}, {@code Name-3}, ...). This keeps their results and
 * output files from overwriting each other.
 * @param args The CLI arguments.
 * @throws IOException If it failed to write a model or output file.
 */
public static void main(String[] args) throws IOException {
    LabsLogFormatter.setAllLogFormatters();
    RunAllOptions o = new RunAllOptions();
    // ConfigurationManager is AutoCloseable. Keep it open while the configured trainers are
    // looked up and used, and close it when the run finishes.
    try (ConfigurationManager cm = new ConfigurationManager(args, o)) {
        if (o.general.trainingPath == null || o.general.testingPath == null || o.directory == null) {
            logger.info(cm.usage());
            System.exit(1);
        }
        final Pair<Dataset<Label>, Dataset<Label>> data;
        try {
            data = o.general.load(new LabelFactory());
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Failed to load data", e);
            System.exit(1);
            // Not reached; lets the compiler see that data is assigned on every path below.
            return;
        }
        Dataset<Label> train = data.getA();
        Dataset<Label> test = data.getB();
        logger.info("Creating directory - " + o.directory.toString());
        if (!o.directory.exists() && !o.directory.mkdirs()) {
            logger.warning("Failed to create directory.");
        }
        Map<String, Double> performances = new HashMap<>();
        List<Trainer> trainers = cm.lookupAll(Trainer.class);
        for (Trainer<?> t : trainers) {
            // Disambiguate trainers of the same class. Otherwise a later trainer would replace an
            // earlier one's F1 entry and overwrite its .model/.output files.
            String baseName = t.getClass().getSimpleName();
            String name = baseName;
            int suffix = 2;
            while (performances.containsKey(name)) {
                name = baseName + "-" + suffix;
                suffix++;
            }
            if (!name.equals(baseName)) {
                logger.warning("Found two trainers with the name " + baseName + ", saving this one as " + name);
            }
            logger.info("Training model using " + t.toString());
            // configuration system cast.
            @SuppressWarnings("unchecked") Model<Label> curModel = ((Trainer<Label>) t).train(train);
            LabelEvaluator evaluator = new LabelEvaluator();
            LabelEvaluation evaluation = evaluator.evaluate(curModel, test);
            performances.put(name, evaluation.microAveragedF1());
            String outputPath = o.directory.toString() + "/" + name;
            try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(outputPath + ".model"))) {
                oos.writeObject(curModel);
            }
            try (PrintWriter writer = new PrintWriter(new OutputStreamWriter(new FileOutputStream(outputPath + ".output"), StandardCharsets.UTF_8))) {
                writer.println("Model = " + name);
                writer.println("Provenance = " + curModel.toString());
                writer.println();
                ConfusionMatrix<Label> matrix = evaluation.getConfusionMatrix();
                writer.println("ConfusionMatrix:\n" + matrix.toString());
                writer.println();
                writer.println("Evaluation:\n" + evaluation.toString());
            }
        }
        for (Map.Entry<String, Double> e : performances.entrySet()) {
            logger.info("Trainer = " + e.getKey() + ", F1 = " + e.getValue());
        }
    } catch (UsageException e) {
        logger.info(e.getMessage());
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) HashMap(java.util.HashMap) Label(org.tribuo.classification.Label) Trainer(org.tribuo.Trainer) ObjectOutputStream(java.io.ObjectOutputStream) LabelEvaluation(org.tribuo.classification.evaluation.LabelEvaluation) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) PrintWriter(java.io.PrintWriter) Dataset(org.tribuo.Dataset) IOException(java.io.IOException) LabelEvaluator(org.tribuo.classification.evaluation.LabelEvaluator) LabelFactory(org.tribuo.classification.LabelFactory) FileOutputStream(java.io.FileOutputStream) OutputStreamWriter(java.io.OutputStreamWriter) Map(java.util.Map)

Example 18 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

From the class TrainTest, method main.

/**
 * Runs a TrainTest CLI for LibSVM regression.
 * <p>
 * Builds a {@code LibSVMRegressionTrainer} from the command line options and trains it on the
 * training data. It logs the training time and support vector count, evaluates the model on the
 * test data and prints the evaluation. The model is saved if an output path was supplied.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    LibSVMOptions o = new LibSVMOptions();
    // ConfigurationManager is AutoCloseable; close it once the run completes.
    try (ConfigurationManager cm = new ConfigurationManager(args, o)) {
        if (o.general.trainingPath == null || o.general.testingPath == null) {
            logger.info(cm.usage());
            return;
        }
        RegressionFactory factory = new RegressionFactory();
        Pair<Dataset<Regressor>, Dataset<Regressor>> data = o.general.load(factory);
        Dataset<Regressor> train = data.getA();
        Dataset<Regressor> test = data.getB();
        SVMParameters<Regressor> parameters = new SVMParameters<>(new SVMRegressionType(o.svmType), o.kernelType);
        parameters.setGamma(o.gamma);
        parameters.setCoeff(o.coeff);
        parameters.setDegree(o.degree);
        Trainer<Regressor> trainer = new LibSVMRegressionTrainer(parameters, o.standardize);
        logger.info("Training using " + trainer.toString());
        final long trainStart = System.currentTimeMillis();
        Model<Regressor> model = trainer.train(train);
        final long trainStop = System.currentTimeMillis();
        logger.info("Finished training regressor " + Util.formatDuration(trainStart, trainStop));
        // Guard the cast, as the classification CLI does. The static type is only Model<Regressor>.
        if (model instanceof LibSVMRegressionModel) {
            logger.info("Support vectors - " + ((LibSVMRegressionModel) model).getNumberOfSupportVectors());
        }
        final long testStart = System.currentTimeMillis();
        RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model, test);
        final long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
        System.out.println(evaluation.toString());
        if (o.general.outputPath != null) {
            o.general.saveModel(model);
        }
    } catch (UsageException e) {
        logger.info(e.getMessage());
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) RegressionFactory(org.tribuo.regression.RegressionFactory) Dataset(org.tribuo.Dataset) SVMParameters(org.tribuo.common.libsvm.SVMParameters) RegressionEvaluation(org.tribuo.regression.evaluation.RegressionEvaluation) Regressor(org.tribuo.regression.Regressor) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Example 19 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

From the class OCIModelCLI, method main.

/**
 * Entry point for the OCIModelCLI.
 * <p>
 * Parses the options and dispatches on {@code o.mode}:
 * <ul>
 * <li>{@code CREATE_AND_DEPLOY} runs {@code createModelAndDeploy}.</li>
 * <li>{@code DEPLOY} runs {@code deploy}.</li>
 * <li>{@code SCORE} runs {@code modelScoring}.</li>
 * </ul>
 * If no mode was supplied, the usage text is logged instead.
 * @param args The CLI arguments.
 * @throws IOException If the model or dataset failed to load.
 * @throws ClassNotFoundException If the model or dataset had an unknown class.
 */
public static void main(String[] args) throws IOException, ClassNotFoundException {
    OCIModelOptions o = new OCIModelOptions();
    // ConfigurationManager is AutoCloseable; close it once the selected mode has run.
    try (ConfigurationManager cm = new ConfigurationManager(args, o, false)) {
        if (o.mode == null) {
            // Switching on a null enum throws NullPointerException; show the usage instead.
            logger.info(cm.usage());
            return;
        }
        switch (o.mode) {
            case CREATE_AND_DEPLOY:
                createModelAndDeploy(o);
                break;
            case DEPLOY:
                deploy(o);
                break;
            case SCORE:
                modelScoring(o);
                break;
            default:
                // Fail loudly rather than silently doing nothing for a mode this CLI does not handle.
                throw new IllegalArgumentException("Unsupported mode " + o.mode);
        }
    } catch (UsageException e) {
        logger.info(e.getMessage());
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Example 20 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

From the class TrainTest, method main.

/**
 * Runs a TrainTest CLI for K-Means clustering.
 * <p>
 * Builds a {@code KMeansTrainer} from the command line options and trains it on the training data.
 * It then evaluates the resulting model on that same training data and prints the normalized and
 * adjusted mutual information. The model is saved if an output path was supplied.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    KMeansOptions o = new KMeansOptions();
    // ConfigurationManager is AutoCloseable; close it once training and evaluation are done.
    try (ConfigurationManager cm = new ConfigurationManager(args, o)) {
        if (o.general.trainingPath == null) {
            logger.info(cm.usage());
            return;
        }
        ClusteringFactory factory = new ClusteringFactory();
        Pair<Dataset<ClusterID>, Dataset<ClusterID>> data = o.general.load(factory);
        Dataset<ClusterID> train = data.getA();
        // Arguments: centroids, iterations, distance type, initialisation, number of threads, seed.
        KMeansTrainer trainer = new KMeansTrainer(o.centroids, o.iterations, o.distType, o.initialisation, o.numThreads, o.general.seed);
        Model<ClusterID> model = trainer.train(train);
        logger.info("Finished training model");
        // Evaluation is performed against the same data the model was trained on.
        ClusteringEvaluation evaluation = factory.getEvaluator().evaluate(model, train);
        logger.info("Finished evaluating model");
        System.out.println("Normalized MI = " + evaluation.normalizedMI());
        System.out.println("Adjusted MI = " + evaluation.adjustedMI());
        if (o.general.outputPath != null) {
            o.general.saveModel(model);
        }
    } catch (UsageException e) {
        logger.info(e.getMessage());
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) ClusterID(org.tribuo.clustering.ClusterID) Dataset(org.tribuo.Dataset) ClusteringFactory(org.tribuo.clustering.ClusteringFactory) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) ClusteringEvaluation(org.tribuo.clustering.evaluation.ClusteringEvaluation)

Aggregations

ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager)42 UsageException (com.oracle.labs.mlrg.olcut.config.UsageException)32 Label (org.tribuo.classification.Label)16 Dataset (org.tribuo.Dataset)15 IOException (java.io.IOException)8 FileOutputStream (java.io.FileOutputStream)7 ObjectOutputStream (java.io.ObjectOutputStream)7 RegressionFactory (org.tribuo.regression.RegressionFactory)7 Regressor (org.tribuo.regression.Regressor)7 RegressionEvaluation (org.tribuo.regression.evaluation.RegressionEvaluation)7 BufferedWriter (java.io.BufferedWriter)4 File (java.io.File)4 OutputStreamWriter (java.io.OutputStreamWriter)4 MutableDataset (org.tribuo.MutableDataset)4 LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation)4 LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator)4 FileInputStream (java.io.FileInputStream)3 ObjectInputStream (java.io.ObjectInputStream)3 PrintWriter (java.io.PrintWriter)3 ArrayList (java.util.ArrayList)3