Search in sources :

Example 21 with UsageException

use of com.oracle.labs.mlrg.olcut.config.UsageException in project tribuo by oracle.

the class SeqTrainTest method main.

/**
 * Command line entry point: obtains a training and a testing sequence dataset, trains a
 * sequence model with the configured trainer, evaluates it on the test data, prints the
 * evaluation and confusion matrix, and optionally serializes the trained model.
 * @param args the command line arguments
 * @throws ClassNotFoundException if a class in a serialized dataset could not be loaded.
 * @throws IOException            if there is any error reading the datasets or writing the model.
 */
public static void main(String[] args) throws ClassNotFoundException, IOException {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();

    SeqTrainTestOptions options = new SeqTrainTestOptions();
    ConfigurationManager configManager;
    try {
        configManager = new ConfigurationManager(args, options);
    } catch (UsageException e) {
        // Print the usage message carried by the exception and stop.
        logger.info(e.getMessage());
        return;
    }

    SequenceDataset<Label> train;
    SequenceDataset<Label> test;
    if (options.datasetName.equals("Gorilla") || options.datasetName.equals("gorilla")) {
        // Built-in generated dataset.
        logger.info("Generating gorilla dataset");
        train = SequenceDataGenerator.generateGorillaDataset(1);
        test = SequenceDataGenerator.generateGorillaDataset(1);
    } else if ((options.trainDataset == null) || (options.testDataset == null)) {
        // Neither a known dataset name nor a complete pair of serialized dataset paths.
        logger.warning("Unknown dataset " + options.datasetName);
        logger.info(configManager.usage());
        return;
    } else {
        logger.info("Loading training data from " + options.trainDataset);
        // Both streams are opened up front, then read one after the other.
        try (ObjectInputStream trainStream = new ObjectInputStream(new BufferedInputStream(new FileInputStream(options.trainDataset.toFile())));
            ObjectInputStream testStream = new ObjectInputStream(new BufferedInputStream(new FileInputStream(options.testDataset.toFile())))) {
            // deserialising a generic dataset.
            @SuppressWarnings("unchecked") SequenceDataset<Label> loadedTrain = (SequenceDataset<Label>) trainStream.readObject();
            train = loadedTrain;
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features");

            logger.info("Loading testing data from " + options.testDataset);
            // deserialising a generic dataset.
            @SuppressWarnings("unchecked") SequenceDataset<Label> loadedTest = (SequenceDataset<Label>) testStream.readObject();
            test = loadedTest;
            logger.info(String.format("Loaded %d testing examples", test.size()));
        }
    }

    // Train, timing the call.
    logger.info("Training using " + options.trainer.toString());
    final long trainBegin = System.currentTimeMillis();
    SequenceModel<Label> model = options.trainer.train(train);
    final long trainEnd = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainBegin, trainEnd));

    // Evaluate on the test data, timing the call.
    LabelSequenceEvaluator evaluator = new LabelSequenceEvaluator();
    final long evalBegin = System.currentTimeMillis();
    LabelSequenceEvaluation evaluation = evaluator.evaluate(model, test);
    final long evalEnd = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(evalBegin, evalEnd));

    System.out.println(evaluation.toString());
    System.out.println();
    System.out.println(evaluation.getConfusionMatrix().toString());

    if (options.outputPath == null) {
        return;
    }
    try (ObjectOutputStream modelStream = new ObjectOutputStream(new FileOutputStream(options.outputPath.toFile()))) {
        modelStream.writeObject(model);
        logger.info("Serialized model to file: " + options.outputPath);
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) Label(org.tribuo.classification.Label) SequenceDataset(org.tribuo.sequence.SequenceDataset) ObjectOutputStream(java.io.ObjectOutputStream) FileInputStream(java.io.FileInputStream) BufferedInputStream(java.io.BufferedInputStream) FileOutputStream(java.io.FileOutputStream) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) ObjectInputStream(java.io.ObjectInputStream)

Example 22 with UsageException

use of com.oracle.labs.mlrg.olcut.config.UsageException in project tribuo by oracle.

the class ConfigurableTrainTest method main.

/**
 * Command line entry point: loads the training and testing data, trains a model with the
 * configured trainer (optionally applying label or example weights), evaluates it on the
 * test data, prints the results, and optionally writes the per-example predictions and the
 * trained model to disk.
 * @param args the command line arguments
 */
public static void main(String[] args) {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();

    ConfigurableTrainTestOptions options = new ConfigurableTrainTestOptions();
    ConfigurationManager configManager;
    try {
        configManager = new ConfigurationManager(args, options);
    } catch (UsageException e) {
        // Print the usage message carried by the exception and stop.
        logger.info(e.getMessage());
        return;
    }

    boolean pathsMissing = options.general.trainingPath == null || options.general.testingPath == null;
    if (pathsMissing) {
        logger.info(configManager.usage());
        System.exit(1);
    }

    Pair<Dataset<Label>, Dataset<Label>> data = null;
    try {
        data = options.general.load(new LabelFactory());
    } catch (IOException e) {
        logger.log(Level.SEVERE, "Failed to load data", e);
        System.exit(1);
    }
    Dataset<Label> train = data.getA();
    Dataset<Label> test = data.getB();

    if (options.trainer == null) {
        logger.warning("No trainer supplied");
        logger.info(configManager.usage());
        System.exit(1);
    }
    logger.info("Trainer is " + options.trainer.toString());

    if (options.weights != null) {
        Map<Label, Float> weightsMap = processWeights(options.weights);
        boolean supportsLabelWeights = options.trainer instanceof WeightedLabels;
        boolean supportsExampleWeights = options.trainer instanceof WeightedExamples;
        if (supportsLabelWeights) {
            // The trainer itself takes the per-label weights.
            ((WeightedLabels) options.trainer).setLabelWeights(weightsMap);
            logger.info("Setting label weights using " + weightsMap.toString());
        } else if (supportsExampleWeights) {
            // The weights are applied to the training dataset instead of the trainer.
            ((MutableDataset<Label>) train).setWeights(weightsMap);
            logger.info("Setting example weights using " + weightsMap.toString());
        } else {
            logger.warning("The selected trainer does not support weighted training. The chosen trainer is " + options.trainer.toString());
            logger.info(configManager.usage());
            System.exit(1);
        }
    }

    // Train, timing the call.
    logger.info("Labels are " + train.getOutputInfo().toReadableString());
    final long trainBegin = System.currentTimeMillis();
    Model<Label> model = options.trainer.train(train);
    final long trainEnd = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainBegin, trainEnd));

    // Predict and evaluate on the test data, timing both together.
    LabelEvaluator evaluator = new LabelEvaluator();
    final long evalBegin = System.currentTimeMillis();
    List<Prediction<Label>> predictions = model.predict(test);
    LabelEvaluation evaluation = evaluator.evaluate(model, predictions, test.getProvenance());
    final long evalEnd = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(evalBegin, evalEnd));

    System.out.println(evaluation.toString());
    ConfusionMatrix<Label> confusion = evaluation.getConfusionMatrix();
    System.out.println(confusion.toString());
    if (model.generatesProbabilities()) {
        System.out.println("Average AUC = " + evaluation.averageAUCROC(false));
        System.out.println("Average weighted AUC = " + evaluation.averageAUCROC(true));
    }

    if (options.predictionPath != null) {
        // One header row, then one row per prediction: true label followed by the score of
        // every label in sorted label-name order.
        try (BufferedWriter out = Files.newBufferedWriter(options.predictionPath)) {
            List<String> labelNames = model.getOutputIDInfo().getDomain().stream()
                    .map(Label::getLabel)
                    .sorted()
                    .collect(Collectors.toList());
            out.write("Label," + String.join(",", labelNames));
            out.newLine();
            for (Prediction<Label> prediction : predictions) {
                out.write(prediction.getExample().getOutput().getLabel() + ",");
                StringBuilder scoreRow = new StringBuilder();
                boolean first = true;
                for (String labelName : labelNames) {
                    if (!first) {
                        scoreRow.append(',');
                    }
                    first = false;
                    scoreRow.append(Double.toString(prediction.getOutputScores().get(labelName).getScore()));
                }
                out.write(scoreRow.toString());
                out.newLine();
            }
            out.flush();
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Error writing predictions", e);
        }
    }

    if (options.general.outputPath != null) {
        try {
            options.general.saveModel(model);
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Error writing model", e);
        }
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) Label(org.tribuo.classification.Label) WeightedExamples(org.tribuo.WeightedExamples) BufferedWriter(java.io.BufferedWriter) WeightedLabels(org.tribuo.classification.WeightedLabels) LabelEvaluation(org.tribuo.classification.evaluation.LabelEvaluation) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) Dataset(org.tribuo.Dataset) MutableDataset(org.tribuo.MutableDataset) Prediction(org.tribuo.Prediction) IOException(java.io.IOException) LabelEvaluator(org.tribuo.classification.evaluation.LabelEvaluator) LabelFactory(org.tribuo.classification.LabelFactory)

Example 23 with UsageException

use of com.oracle.labs.mlrg.olcut.config.UsageException in project tribuo by oracle.

the class TrainTest method main.

/**
 * CLI entry point.
 * <p>
 * Loads the training and testing data, builds a feature converter for the configured input
 * type, trains a TensorFlow model, evaluates it on the test data, prints the evaluation and
 * optionally serializes the model. The trained model is always closed before this method
 * returns or propagates an exception.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples or writing the model.
 * @throws IllegalArgumentException if the input name or the output name is missing or empty.
 */
public static void main(String[] args) throws IOException {
    // 
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    TensorflowOptions o = new TensorflowOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        // Print the usage message carried by the exception and stop.
        logger.info(e.getMessage());
        return;
    }
    if (o.trainingPath == null || o.testingPath == null) {
        logger.info(cm.usage());
        return;
    }
    Pair<Dataset<Label>, Dataset<Label>> data = load(o.trainingPath, o.testingPath, new LabelFactory());
    Dataset<Label> train = data.getA();
    Dataset<Label> test = data.getB();
    if ((o.inputName == null || o.inputName.isEmpty()) || (o.outputName == null || o.outputName.isEmpty())) {
        throw new IllegalArgumentException("Must specify both 'input-name' and 'output-name'");
    }
    FeatureConverter inputConverter;
    switch(o.inputType) {
        case IMAGE:
            // The image format is three comma separated integers: width,height,channels.
            String[] splitFormat = o.imageFormat.split(",");
            if (splitFormat.length != 3) {
                logger.info(cm.usage());
                logger.info("Invalid image format specified. Found " + o.imageFormat);
                return;
            }
            int width;
            int height;
            int channels;
            try {
                width = Integer.parseInt(splitFormat[0].trim());
                height = Integer.parseInt(splitFormat[1].trim());
                channels = Integer.parseInt(splitFormat[2].trim());
            } catch (NumberFormatException e) {
                // A non-numeric field is reported the same way as a wrong field count,
                // rather than crashing with a stack trace.
                logger.info(cm.usage());
                logger.info("Invalid image format specified. Found " + o.imageFormat);
                return;
            }
            inputConverter = new ImageConverter(o.inputName, width, height, channels);
            break;
        case DENSE:
            inputConverter = new DenseFeatureConverter(o.inputName);
            break;
        default:
            logger.info(cm.usage());
            logger.info("Unknown input type. Found " + o.inputType);
            return;
    }
    OutputConverter<Label> labelConverter = new LabelConverter();
    Trainer<Label> trainer;
    if (o.checkpointPath == null) {
        logger.info("Using TensorflowTrainer");
        trainer = new TensorFlowTrainer<>(o.protobufPath, o.outputName, o.optimiser, o.getGradientParams(), inputConverter, labelConverter, o.batchSize, o.epochs, o.testBatchSize, o.loggingInterval);
    } else {
        logger.info("Using TensorflowCheckpointTrainer, writing to path " + o.checkpointPath);
        trainer = new TensorFlowTrainer<>(o.protobufPath, o.outputName, o.optimiser, o.getGradientParams(), inputConverter, labelConverter, o.batchSize, o.epochs, o.testBatchSize, o.loggingInterval, o.checkpointPath);
    }
    logger.info("Training using " + trainer.toString());
    final long trainStart = System.currentTimeMillis();
    Model<Label> model = trainer.train(train);
    try {
        final long trainStop = System.currentTimeMillis();
        logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));
        final long testStart = System.currentTimeMillis();
        LabelEvaluator evaluator = new LabelEvaluator();
        LabelEvaluation evaluation = evaluator.evaluate(model, test);
        final long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
        if (model.generatesProbabilities()) {
            logger.info("Average AUC = " + evaluation.averageAUCROC(false));
            logger.info("Average weighted AUC = " + evaluation.averageAUCROC(true));
        }
        System.out.println(evaluation.toString());
        System.out.println(evaluation.getConfusionMatrix().toString());
        if (o.outputPath != null) {
            try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(o.outputPath.toFile()))) {
                oos.writeObject(model);
            }
            logger.info("Serialized model to file: " + o.outputPath);
        }
    } finally {
        // Close the model even when evaluation or serialization fails, so it is not
        // left open on the error path.
        if (o.checkpointPath == null) {
            ((TensorFlowNativeModel<?>) model).close();
        } else {
            ((TensorFlowCheckpointModel<?>) model).close();
        }
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) Label(org.tribuo.classification.Label) ObjectOutputStream(java.io.ObjectOutputStream) LabelEvaluation(org.tribuo.classification.evaluation.LabelEvaluation) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) ImmutableDataset(org.tribuo.ImmutableDataset) Dataset(org.tribuo.Dataset) MutableDataset(org.tribuo.MutableDataset) LabelEvaluator(org.tribuo.classification.evaluation.LabelEvaluator) LabelFactory(org.tribuo.classification.LabelFactory) FileOutputStream(java.io.FileOutputStream)

Example 24 with UsageException

use of com.oracle.labs.mlrg.olcut.config.UsageException in project tribuo by oracle.

the class StripProvenance method main.

/**
 * Runs StripProvenance.
 * <p>
 * Loads the input model, serializes its provenance to JSON, hashes that JSON with the
 * configured hash type, optionally writes the JSON to the provenance file, converts the
 * model via {@code convertModel} and writes the converted model to the output path.
 * <p>
 * The output file is only opened once the converted model is ready to be written, so a
 * failure while loading or converting the input does not create or truncate the output
 * file, and the input is fully read before the output is opened.
 * @param args the command line arguments
 * @param <T>  The {@link Output} subclass.
 */
@SuppressWarnings("unchecked")
public static <T extends Output<T>> void main(String[] args) {
    // 
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    StripProvenanceOptions o = new StripProvenanceOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        // Print the usage message carried by the exception and stop.
        logger.info(e.getMessage());
        return;
    }
    if (o.inputModel == null || o.outputModel == null) {
        logger.info(cm.usage());
        System.exit(1);
    }
    try {
        logger.info("Loading model from " + o.inputModel);
        Model<T> input;
        // Read the whole model and close the input before anything is written.
        try (ObjectInputStream ois = IOUtil.getObjectInputStream(o.inputModel)) {
            input = (Model<T>) ois.readObject();
        }
        ModelProvenance oldProvenance = input.getProvenance();
        logger.info("Marshalling provenance and creating JSON.");
        JsonProvenanceSerialization jsonProvenanceSerialization = new JsonProvenanceSerialization(true);
        String jsonResult = jsonProvenanceSerialization.marshalAndSerialize(oldProvenance);
        logger.info("Hashing JSON file");
        MessageDigest digest = o.hashType.getDigest();
        byte[] digestBytes = digest.digest(jsonResult.getBytes(StandardCharsets.UTF_8));
        String provenanceHash = ProvenanceUtil.bytesToHexString(digestBytes);
        logger.info("Provenance hash = " + provenanceHash);
        if (o.provenanceFile != null) {
            logger.info("Writing JSON provenance to " + o.provenanceFile.toString());
            try (PrintWriter writer = new PrintWriter(new OutputStreamWriter(new FileOutputStream(o.provenanceFile), StandardCharsets.UTF_8))) {
                writer.println(jsonResult);
            }
        }
        ModelTuple<T> tuple = convertModel(input, provenanceHash, o);
        logger.info("Writing model to " + o.outputModel);
        // Open the output only now: FileOutputStream truncates an existing file on open.
        try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(o.outputModel))) {
            oos.writeObject(tuple.model);
        }
        ModelProvenance newProvenance = tuple.provenance;
        logger.info("Marshalling provenance and creating JSON.");
        String newJsonResult = jsonProvenanceSerialization.marshalAndSerialize(newProvenance);
        logger.info("Old provenance = \n" + jsonResult);
        logger.info("New provenance = \n" + newJsonResult);
    } catch (NoSuchMethodException e) {
        logger.log(Level.SEVERE, "Model.copy method missing on a class which extends Model.", e);
    } catch (IllegalAccessException e) {
        logger.log(Level.SEVERE, "Failed to modify protection on inner copy method on Model.", e);
    } catch (InvocationTargetException e) {
        logger.log(Level.SEVERE, "Failed to invoke inner copy method on Model.", e);
    } catch (UnsupportedEncodingException e) {
        logger.log(Level.SEVERE, "Unsupported encoding exception.", e);
    } catch (FileNotFoundException e) {
        logger.log(Level.SEVERE, "Failed to find the input file.", e);
    } catch (IOException e) {
        logger.log(Level.SEVERE, "IO error when reading or writing a file.", e);
    } catch (ClassNotFoundException e) {
        logger.log(Level.SEVERE, "The model and/or provenance classes are not on the classpath.", e);
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) ModelProvenance(org.tribuo.provenance.ModelProvenance) EnsembleModelProvenance(org.tribuo.provenance.EnsembleModelProvenance) FileNotFoundException(java.io.FileNotFoundException) UnsupportedEncodingException(java.io.UnsupportedEncodingException) IOException(java.io.IOException) ObjectOutputStream(java.io.ObjectOutputStream) InvocationTargetException(java.lang.reflect.InvocationTargetException) DATASET(org.tribuo.json.StripProvenance.ProvenanceTypes.DATASET) JsonProvenanceSerialization(com.oracle.labs.mlrg.olcut.config.json.JsonProvenanceSerialization) FileOutputStream(java.io.FileOutputStream) Model(org.tribuo.Model) EnsembleModel(org.tribuo.ensemble.EnsembleModel) OutputStreamWriter(java.io.OutputStreamWriter) MessageDigest(java.security.MessageDigest) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) ObjectInputStream(java.io.ObjectInputStream) PrintWriter(java.io.PrintWriter)

Example 25 with UsageException

use of com.oracle.labs.mlrg.olcut.config.UsageException in project tribuo by oracle.

the class TrainTest method main.

/**
 * Runs a TrainTest CLI: loads the training data, trains a model with the trainer built from
 * the HDBSCAN options, evaluates the model against that same training data, prints the
 * normalized and adjusted MI, and optionally saves the model.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();

    HdbscanCLIOptions options = new HdbscanCLIOptions();
    ConfigurationManager configManager;
    try {
        configManager = new ConfigurationManager(args, options);
    } catch (UsageException e) {
        // Print the usage message carried by the exception and stop.
        logger.info(e.getMessage());
        return;
    }

    if (options.general.trainingPath == null) {
        logger.info(configManager.usage());
        return;
    }

    ClusteringFactory factory = new ClusteringFactory();
    // Only the first element of the loaded pair (the training data) is used here.
    Dataset<ClusterID> train = options.general.load(factory).getA();

    Model<ClusterID> model = options.hdbscanOptions.getTrainer().train(train);
    logger.info("Finished training model");

    ClusteringEvaluation evaluation = factory.getEvaluator().evaluate(model, train);
    logger.info("Finished evaluating model");

    System.out.println("Normalized MI = " + evaluation.normalizedMI());
    System.out.println("Adjusted MI = " + evaluation.adjustedMI());

    if (options.general.outputPath != null) {
        options.general.saveModel(model);
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) ClusterID(org.tribuo.clustering.ClusterID) Dataset(org.tribuo.Dataset) ClusteringFactory(org.tribuo.clustering.ClusteringFactory) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) ClusteringEvaluation(org.tribuo.clustering.evaluation.ClusteringEvaluation)

Aggregations

ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager)32 UsageException (com.oracle.labs.mlrg.olcut.config.UsageException)32 Label (org.tribuo.classification.Label)15 Dataset (org.tribuo.Dataset)14 IOException (java.io.IOException)8 ObjectOutputStream (java.io.ObjectOutputStream)7 RegressionFactory (org.tribuo.regression.RegressionFactory)7 Regressor (org.tribuo.regression.Regressor)7 RegressionEvaluation (org.tribuo.regression.evaluation.RegressionEvaluation)7 FileOutputStream (java.io.FileOutputStream)6 MutableDataset (org.tribuo.MutableDataset)4 LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation)4 LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator)4 BufferedWriter (java.io.BufferedWriter)3 ObjectInputStream (java.io.ObjectInputStream)3 OutputStreamWriter (java.io.OutputStreamWriter)3 Model (org.tribuo.Model)3 LabelFactory (org.tribuo.classification.LabelFactory)3 ArgumentException (com.oracle.labs.mlrg.olcut.config.ArgumentException)2 BufferedInputStream (java.io.BufferedInputStream)2