Search in sources :

Example 26 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

the class IDXDataSourceTest method testFileNotFound.

/**
 * Checks that a missing IDX file is reported on both construction paths: a lookup through the
 * configuration system and a direct constructor call. In both cases the exception message must
 * contain the same "Failed to load from path - " fragment.
 *
 * @throws IOException if an I/O failure other than the expected {@link FileNotFoundException} escapes the test body.
 */
@Test
public void testFileNotFound() throws IOException {
    final String expectedFragment = "Failed to load from path - ";

    // Path one: the configuration system should throw out of postConfig.
    ConfigurationManager manager = new ConfigurationManager("/org/tribuo/datasource/config.xml");
    try {
        // The config file lives with the tests, so the concrete type of "train" is known.
        @SuppressWarnings("unchecked") IDXDataSource<MockOutput> fromConfig = (IDXDataSource<MockOutput>) manager.lookup("train");
        fail("Should have thrown PropertyException");
    } catch (PropertyException propertyFailure) {
        boolean hasExpectedMessage = propertyFailure.getMessage().contains(expectedFragment);
        if (!hasExpectedMessage) {
            fail("Incorrect exception message", propertyFailure);
        }
    } catch (RuntimeException unexpected) {
        // Any other runtime exception type is the wrong failure mode.
        fail("Incorrect exception thrown", unexpected);
    }

    // Path two: the constructor itself should throw when pointed at paths that do not exist.
    MockOutputFactory outputFactory = new MockOutputFactory();
    try {
        IDXDataSource<MockOutput> fromConstructor = new IDXDataSource<>(
                Paths.get("these-features-dont-exist"),
                Paths.get("these-outputs-dont-exist"),
                outputFactory);
        fail("Should have thrown FileNotFoundException");
    } catch (FileNotFoundException missingFile) {
        boolean hasExpectedMessage = missingFile.getMessage().contains(expectedFragment);
        if (!hasExpectedMessage) {
            fail("Incorrect exception message", missingFile);
        }
    }
}
Also used : MockOutput(org.tribuo.test.MockOutput) PropertyException(com.oracle.labs.mlrg.olcut.config.PropertyException) MockOutputFactory(org.tribuo.test.MockOutputFactory) FileNotFoundException(java.io.FileNotFoundException) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) Test(org.junit.jupiter.api.Test)

Example 27 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

the class Helpers method testConfigurableRoundtrip.

/**
 * Checks that an object which is both {@link Provenancable} and {@link Configurable} has the same structure
 * when described by its imported configuration and by the configuration extracted from its provenance.
 * The comparison itself is done by {@link ConfigurationData#structuralEquals(List, List, String, String)}.
 * @param itm The object whose two representations are compared.
 * @param <P> The provenance type produced by the object.
 * @param <C> The type of the object under test.
 */
public static <P extends ConfiguredObjectProvenance, C extends Configurable & Provenancable<P>> void testConfigurableRoundtrip(C itm) {
    ConfigurationManager manager = new ConfigurationManager();
    String importedName = manager.importConfigurable(itm, "item");

    // Configuration view: every component the manager knows about that has configuration data.
    List<ConfigurationData> fromConfig = manager.getComponentNames()
            .stream()
            .map(componentName -> manager.getConfigurationData(componentName))
            .filter(candidate -> candidate.isPresent())
            .map(candidate -> candidate.get())
            .collect(Collectors.toList());

    // Provenance view: the first extracted entry supplies the name passed to the comparison.
    List<ConfigurationData> fromProvenance = ProvenanceUtil.extractConfiguration(itm.getProvenance());
    String firstProvenanceName = fromProvenance.get(0).getName();

    boolean sameStructure = ConfigurationData.structuralEquals(fromConfig, fromProvenance, importedName, firstProvenanceName);
    Assertions.assertTrue(sameStructure);
}
Also used : Optional(java.util.Optional) ConfigurationData(com.oracle.labs.mlrg.olcut.config.ConfigurationData) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager)

Example 28 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

the class SequenceModelExplorer method main.

/**
 * Entry point for the sequence model explorer: processes the CLI arguments, optionally loads a model
 * named by the options, then starts the interactive shell.
 * @param args CLI arguments.
 */
public static void main(String[] args) {
    SequenceModelExplorerOptions options = new SequenceModelExplorerOptions();
    // The manager instance is never read afterwards; it is constructed for its effect of
    // processing args against the options object (presumably populating its fields).
    new ConfigurationManager(args, options, false);

    SequenceModelExplorer explorer = new SequenceModelExplorer();
    if (options.modelFilename != null) {
        File modelFile = new File(options.modelFilename);
        String loadMessage = explorer.loadModel(explorer.shell, modelFile);
        logger.log(Level.INFO, loadMessage);
    }
    explorer.startShell();
}
Also used : ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) File(java.io.File)

Example 29 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

the class SeqTrainTest method main.

/**
 * Trains a sequence labelling model, evaluates it on a test set and prints the evaluation, with all
 * settings taken from the command line.
 * <p>
 * The data is either the generated gorilla dataset (dataset name "Gorilla" or "gorilla") or a pair of
 * serialized datasets read from the train and test paths in the options. If an output path is set,
 * the trained model is serialized to it.
 * @param args the command line arguments
 * @throws ClassNotFoundException if a class needed while deserializing the datasets cannot be found.
 * @throws IOException            if reading the datasets or writing the model fails.
 */
public static void main(String[] args) throws ClassNotFoundException, IOException {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();

    SeqTrainTestOptions options = new SeqTrainTestOptions();
    ConfigurationManager manager;
    try {
        manager = new ConfigurationManager(args, options);
    } catch (UsageException usage) {
        // Bad or help-style arguments: log the message and stop.
        logger.info(usage.getMessage());
        return;
    }

    SequenceDataset<Label> trainingData;
    SequenceDataset<Label> testingData;
    switch (options.datasetName) {
        case "Gorilla":
        case "gorilla":
            logger.info("Generating gorilla dataset");
            trainingData = SequenceDataGenerator.generateGorillaDataset(1);
            testingData = SequenceDataGenerator.generateGorillaDataset(1);
            break;
        default: {
            // Any other name is only usable when both serialized dataset paths were supplied.
            boolean pathMissing = (options.trainDataset == null) || (options.testDataset == null);
            if (pathMissing) {
                logger.warning("Unknown dataset " + options.datasetName);
                logger.info(manager.usage());
                return;
            }
            logger.info("Loading training data from " + options.trainDataset);
            // NOTE(review): this is plain Java deserialization of user-supplied files; only use trusted inputs.
            try (ObjectInputStream trainStream = new ObjectInputStream(new BufferedInputStream(new FileInputStream(options.trainDataset.toFile())));
                 ObjectInputStream testStream = new ObjectInputStream(new BufferedInputStream(new FileInputStream(options.testDataset.toFile())))) {
                // The cast cannot be checked at runtime because the dataset is generic.
                @SuppressWarnings("unchecked") SequenceDataset<Label> loadedTrain = (SequenceDataset<Label>) trainStream.readObject();
                trainingData = loadedTrain;
                logger.info(String.format("Loaded %d training examples for %s", trainingData.size(), trainingData.getOutputs().toString()));
                logger.info("Found " + trainingData.getFeatureIDMap().size() + " features");

                logger.info("Loading testing data from " + options.testDataset);
                // Same unchecked generic cast as for the training data.
                @SuppressWarnings("unchecked") SequenceDataset<Label> loadedTest = (SequenceDataset<Label>) testStream.readObject();
                testingData = loadedTest;
                logger.info(String.format("Loaded %d testing examples", testingData.size()));
            }
        }
    }

    // Train, timing the call.
    logger.info("Training using " + options.trainer.toString());
    final long trainingBegan = System.currentTimeMillis();
    SequenceModel<Label> trainedModel = options.trainer.train(trainingData);
    final long trainingEnded = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainingBegan, trainingEnded));

    // Evaluate, timing the call.
    LabelSequenceEvaluator evaluator = new LabelSequenceEvaluator();
    final long evaluationBegan = System.currentTimeMillis();
    LabelSequenceEvaluation evaluation = evaluator.evaluate(trainedModel, testingData);
    final long evaluationEnded = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(evaluationBegan, evaluationEnded));

    System.out.println(evaluation.toString());
    System.out.println();
    System.out.println(evaluation.getConfusionMatrix().toString());

    // Saving the model is optional.
    if (options.outputPath == null) {
        return;
    }
    try (ObjectOutputStream modelStream = new ObjectOutputStream(new FileOutputStream(options.outputPath.toFile()))) {
        modelStream.writeObject(trainedModel);
        logger.info("Serialized model to file: " + options.outputPath);
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) Label(org.tribuo.classification.Label) SequenceDataset(org.tribuo.sequence.SequenceDataset) ObjectOutputStream(java.io.ObjectOutputStream) FileInputStream(java.io.FileInputStream) BufferedInputStream(java.io.BufferedInputStream) FileOutputStream(java.io.FileOutputStream) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) ObjectInputStream(java.io.ObjectInputStream)

Example 30 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

the class ConfigurableTrainTest method main.

/**
 * Trains a classifier with the configured trainer, evaluates it on the test data and prints the results.
 * <p>
 * Steps: process the arguments, load the train/test datasets, optionally apply label or example weights,
 * train, evaluate, print the evaluation and confusion matrix (plus AUC figures when the model generates
 * probabilities), then optionally write per-example prediction scores as comma separated rows and save
 * the model. Missing paths, a missing trainer, a data load failure or an unsupported weighting request
 * terminate the JVM with exit status 1.
 * @param args the command line arguments
 */
public static void main(String[] args) {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();

    ConfigurableTrainTestOptions options = new ConfigurableTrainTestOptions();
    ConfigurationManager manager;
    try {
        manager = new ConfigurationManager(args, options);
    } catch (UsageException usage) {
        logger.info(usage.getMessage());
        return;
    }

    boolean pathMissing = options.general.trainingPath == null || options.general.testingPath == null;
    if (pathMissing) {
        logger.info(manager.usage());
        System.exit(1);
    }

    // Initialised to null so the compiler accepts the reads below; the catch branch exits the JVM.
    Pair<Dataset<Label>, Dataset<Label>> splits = null;
    try {
        splits = options.general.load(new LabelFactory());
    } catch (IOException loadFailure) {
        logger.log(Level.SEVERE, "Failed to load data", loadFailure);
        System.exit(1);
    }
    Dataset<Label> trainingData = splits.getA();
    Dataset<Label> testingData = splits.getB();

    if (options.trainer == null) {
        logger.warning("No trainer supplied");
        logger.info(manager.usage());
        System.exit(1);
    }
    logger.info("Trainer is " + options.trainer.toString());

    // Weights go to the trainer when it accepts label weights, otherwise onto the training examples.
    if (options.weights != null) {
        Map<Label, Float> labelWeights = processWeights(options.weights);
        if (options.trainer instanceof WeightedLabels) {
            WeightedLabels weightedTrainer = (WeightedLabels) options.trainer;
            weightedTrainer.setLabelWeights(labelWeights);
            logger.info("Setting label weights using " + labelWeights.toString());
        } else if (options.trainer instanceof WeightedExamples) {
            MutableDataset<Label> mutableTraining = (MutableDataset<Label>) trainingData;
            mutableTraining.setWeights(labelWeights);
            logger.info("Setting example weights using " + labelWeights.toString());
        } else {
            logger.warning("The selected trainer does not support weighted training. The chosen trainer is " + options.trainer.toString());
            logger.info(manager.usage());
            System.exit(1);
        }
    }

    logger.info("Labels are " + trainingData.getOutputInfo().toReadableString());

    // Train, timing the call.
    final long trainingBegan = System.currentTimeMillis();
    Model<Label> trainedModel = options.trainer.train(trainingData);
    final long trainingEnded = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainingBegan, trainingEnded));

    // Predict and evaluate, timing both together.
    LabelEvaluator evaluator = new LabelEvaluator();
    final long evaluationBegan = System.currentTimeMillis();
    List<Prediction<Label>> predictions = trainedModel.predict(testingData);
    LabelEvaluation evaluation = evaluator.evaluate(trainedModel, predictions, testingData.getProvenance());
    final long evaluationEnded = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(evaluationBegan, evaluationEnded));

    System.out.println(evaluation.toString());
    ConfusionMatrix<Label> confusion = evaluation.getConfusionMatrix();
    System.out.println(confusion.toString());
    if (trainedModel.generatesProbabilities()) {
        System.out.println("Average AUC = " + evaluation.averageAUCROC(false));
        System.out.println("Average weighted AUC = " + evaluation.averageAUCROC(true));
    }

    // Optional prediction dump: a header row, then one row per prediction with the true label
    // followed by the score of every label in sorted label order.
    if (options.predictionPath != null) {
        try (BufferedWriter writer = Files.newBufferedWriter(options.predictionPath)) {
            List<String> sortedLabels = trainedModel.getOutputIDInfo().getDomain().stream()
                    .map(Label::getLabel)
                    .sorted()
                    .collect(Collectors.toList());
            writer.write("Label,");
            writer.write(String.join(",", sortedLabels));
            writer.newLine();
            for (Prediction<Label> prediction : predictions) {
                writer.write(prediction.getExample().getOutput().getLabel() + ",");
                StringBuilder scoreColumns = new StringBuilder();
                String separator = "";
                for (String labelName : sortedLabels) {
                    scoreColumns.append(separator);
                    scoreColumns.append(Double.toString(prediction.getOutputScores().get(labelName).getScore()));
                    separator = ",";
                }
                writer.write(scoreColumns.toString());
                writer.newLine();
            }
            writer.flush();
        } catch (IOException writeFailure) {
            logger.log(Level.SEVERE, "Error writing predictions", writeFailure);
        }
    }

    // Optional model save.
    if (options.general.outputPath != null) {
        try {
            options.general.saveModel(trainedModel);
        } catch (IOException saveFailure) {
            logger.log(Level.SEVERE, "Error writing model", saveFailure);
        }
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) Label(org.tribuo.classification.Label) WeightedExamples(org.tribuo.WeightedExamples) BufferedWriter(java.io.BufferedWriter) WeightedLabels(org.tribuo.classification.WeightedLabels) LabelEvaluation(org.tribuo.classification.evaluation.LabelEvaluation) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) Dataset(org.tribuo.Dataset) MutableDataset(org.tribuo.MutableDataset) Prediction(org.tribuo.Prediction) IOException(java.io.IOException) LabelEvaluator(org.tribuo.classification.evaluation.LabelEvaluator) LabelFactory(org.tribuo.classification.LabelFactory)

Aggregations

ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager): 42; UsageException (com.oracle.labs.mlrg.olcut.config.UsageException): 32; Label (org.tribuo.classification.Label): 16; Dataset (org.tribuo.Dataset): 15; IOException (java.io.IOException): 8; FileOutputStream (java.io.FileOutputStream): 7; ObjectOutputStream (java.io.ObjectOutputStream): 7; RegressionFactory (org.tribuo.regression.RegressionFactory): 7; Regressor (org.tribuo.regression.Regressor): 7; RegressionEvaluation (org.tribuo.regression.evaluation.RegressionEvaluation): 7; BufferedWriter (java.io.BufferedWriter): 4; File (java.io.File): 4; OutputStreamWriter (java.io.OutputStreamWriter): 4; MutableDataset (org.tribuo.MutableDataset): 4; LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation): 4; LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator): 4; FileInputStream (java.io.FileInputStream): 3; ObjectInputStream (java.io.ObjectInputStream): 3; PrintWriter (java.io.PrintWriter): 3; ArrayList (java.util.ArrayList): 3