Search in sources:

Example 31 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

From the class TrainTest, method main.

/**
 * CLI entry point.
 * <p>
 * Loads a train/test dataset pair, builds a TensorFlow trainer from the
 * supplied graph and gradient options, trains a classifier, evaluates it on
 * the test set, and optionally serializes the model to disk.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    // 
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    TensorflowOptions o = new TensorflowOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }
    if (o.trainingPath == null || o.testingPath == null) {
        logger.info(cm.usage());
        return;
    }
    // Validate the graph node names before loading the datasets, so bad
    // arguments fail fast instead of after the (potentially slow) data load.
    if ((o.inputName == null || o.inputName.isEmpty()) || (o.outputName == null || o.outputName.isEmpty())) {
        throw new IllegalArgumentException("Must specify both 'input-name' and 'output-name'");
    }
    Pair<Dataset<Label>, Dataset<Label>> data = load(o.trainingPath, o.testingPath, new LabelFactory());
    Dataset<Label> train = data.getA();
    Dataset<Label> test = data.getB();
    FeatureConverter inputConverter;
    switch(o.inputType) {
        case IMAGE:
            // Image inputs are specified on the CLI as "width,height,channels".
            String[] splitFormat = o.imageFormat.split(",");
            if (splitFormat.length != 3) {
                logger.info(cm.usage());
                logger.info("Invalid image format specified. Found " + o.imageFormat);
                return;
            }
            int width;
            int height;
            int channels;
            try {
                width = Integer.parseInt(splitFormat[0]);
                height = Integer.parseInt(splitFormat[1]);
                channels = Integer.parseInt(splitFormat[2]);
            } catch (NumberFormatException e) {
                // Report a usage error rather than crashing with a raw
                // NumberFormatException when a dimension isn't an integer.
                logger.info(cm.usage());
                logger.info("Invalid image format specified. Found " + o.imageFormat);
                return;
            }
            inputConverter = new ImageConverter(o.inputName, width, height, channels);
            break;
        case DENSE:
            inputConverter = new DenseFeatureConverter(o.inputName);
            break;
        default:
            logger.info(cm.usage());
            logger.info("Unknown input type. Found " + o.inputType);
            return;
    }
    OutputConverter<Label> labelConverter = new LabelConverter();
    Trainer<Label> trainer;
    if (o.checkpointPath == null) {
        logger.info("Using TensorflowTrainer");
        trainer = new TensorFlowTrainer<>(o.protobufPath, o.outputName, o.optimiser, o.getGradientParams(), inputConverter, labelConverter, o.batchSize, o.epochs, o.testBatchSize, o.loggingInterval);
    } else {
        logger.info("Using TensorflowCheckpointTrainer, writing to path " + o.checkpointPath);
        trainer = new TensorFlowTrainer<>(o.protobufPath, o.outputName, o.optimiser, o.getGradientParams(), inputConverter, labelConverter, o.batchSize, o.epochs, o.testBatchSize, o.loggingInterval, o.checkpointPath);
    }
    logger.info("Training using " + trainer.toString());
    final long trainStart = System.currentTimeMillis();
    Model<Label> model = trainer.train(train);
    final long trainStop = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));
    final long testStart = System.currentTimeMillis();
    LabelEvaluator evaluator = new LabelEvaluator();
    LabelEvaluation evaluation = evaluator.evaluate(model, test);
    final long testStop = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
    // AUC is only defined when the model emits probabilities.
    if (model.generatesProbabilities()) {
        logger.info("Average AUC = " + evaluation.averageAUCROC(false));
        logger.info("Average weighted AUC = " + evaluation.averageAUCROC(true));
    }
    System.out.println(evaluation.toString());
    System.out.println(evaluation.getConfusionMatrix().toString());
    if (o.outputPath != null) {
        try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(o.outputPath.toFile()))) {
            oos.writeObject(model);
        }
        logger.info("Serialized model to file: " + o.outputPath);
    }
    // TensorFlow models hold native resources and must be closed explicitly;
    // the concrete model type depends on whether checkpointing was used.
    if (o.checkpointPath == null) {
        ((TensorFlowNativeModel<?>) model).close();
    } else {
        ((TensorFlowCheckpointModel<?>) model).close();
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) Label(org.tribuo.classification.Label) ObjectOutputStream(java.io.ObjectOutputStream) LabelEvaluation(org.tribuo.classification.evaluation.LabelEvaluation) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) ImmutableDataset(org.tribuo.ImmutableDataset) Dataset(org.tribuo.Dataset) MutableDataset(org.tribuo.MutableDataset) LabelEvaluator(org.tribuo.classification.evaluation.LabelEvaluator) LabelFactory(org.tribuo.classification.LabelFactory) FileOutputStream(java.io.FileOutputStream)

Example 32 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

From the class BERTFeatureExtractor, method main.

/**
 * Test harness for running a BERT model and inspecting the output.
 * <p>
 * Reads the input file as UTF-8, runs each line through the configured BERT
 * model, and writes the collected results out as a single JSON document.
 * @param args The CLI arguments.
 * @throws IOException If the files couldn't be read or written to.
 * @throws OrtException If the BERT model failed to load, or threw an exception during computation.
 */
public static void main(String[] args) throws IOException, OrtException {
    BERTFeatureExtractorOptions opts = new BERTFeatureExtractorOptions();
    ConfigurationManager cm = new ConfigurationManager(args, opts);
    List<String> lines = Files.readAllLines(opts.inputFile, StandardCharsets.UTF_8);
    List<BERTResult> results = new ArrayList<>();
    for (String line : lines) {
        results.add(opts.bert.bert(line));
    }
    ObjectMapper mapper = new ObjectMapper();
    // Write with an explicit UTF-8 charset. The previous FileWriter used the
    // platform default encoding, which could mangle non-ASCII characters and
    // disagreed with the UTF-8 used when reading the input above.
    try (BufferedWriter writer = Files.newBufferedWriter(opts.outputFile, StandardCharsets.UTF_8)) {
        writer.write(mapper.writeValueAsString(results));
    }
}
Also used : FileWriter(java.io.FileWriter) ArrayList(java.util.ArrayList) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) BufferedWriter(java.io.BufferedWriter)

Example 33 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

From the class ModelExplorer, method main.

/**
 * Entry point.
 * <p>
 * Parses the CLI options, optionally pre-loads a model, then starts the
 * interactive shell.
 * @param args CLI args.
 */
public static void main(String[] args) {
    ModelExplorerOptions opts = new ModelExplorerOptions();
    ConfigurationManager cm = new ConfigurationManager(args, opts, false);
    ModelExplorer explorer = new ModelExplorer();
    String modelPath = opts.modelFilename;
    if (modelPath != null) {
        // Load the model up front and echo the shell's status message.
        logger.log(Level.INFO, explorer.loadModel(explorer.shell, new File(modelPath)));
    }
    explorer.startShell();
}
Also used : ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) File(java.io.File)

Example 34 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

From the class StripProvenance, method main.

/**
 * Runs StripProvenance.
 * <p>
 * Loads a serialized model, hashes the JSON form of its provenance, optionally
 * writes that JSON to a file, then writes out a copy of the model with the
 * provenance stripped/replaced (via {@code convertModel}).
 * @param args the command line arguments
 * @param <T>  The {@link Output} subclass.
 */
@SuppressWarnings("unchecked")
public static <T extends Output<T>> void main(String[] args) {
    // 
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    StripProvenanceOptions o = new StripProvenanceOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        // Malformed CLI arguments; the exception message carries the usage text.
        logger.info(e.getMessage());
        return;
    }
    if (o.inputModel == null || o.outputModel == null) {
        logger.info(cm.usage());
        // Exits the JVM with a failure code rather than returning, so calling
        // scripts can detect the missing arguments.
        System.exit(1);
    }
    // Open both streams up front so a failure on either side aborts before any work.
    try (ObjectInputStream ois = IOUtil.getObjectInputStream(o.inputModel);
        ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(o.outputModel))) {
        logger.info("Loading model from " + o.inputModel);
        // Unchecked cast is covered by the @SuppressWarnings on the method.
        Model<T> input = (Model<T>) ois.readObject();
        ModelProvenance oldProvenance = input.getProvenance();
        logger.info("Marshalling provenance and creating JSON.");
        JsonProvenanceSerialization jsonProvenanceSerialization = new JsonProvenanceSerialization(true);
        String jsonResult = jsonProvenanceSerialization.marshalAndSerialize(oldProvenance);
        logger.info("Hashing JSON file");
        // Hash the JSON form of the old provenance so the stripped model can
        // still be linked back to it via the digest.
        MessageDigest digest = o.hashType.getDigest();
        byte[] digestBytes = digest.digest(jsonResult.getBytes(StandardCharsets.UTF_8));
        String provenanceHash = ProvenanceUtil.bytesToHexString(digestBytes);
        logger.info("Provenance hash = " + provenanceHash);
        if (o.provenanceFile != null) {
            logger.info("Writing JSON provenance to " + o.provenanceFile.toString());
            // Explicit UTF-8 so the JSON round-trips regardless of platform charset.
            try (PrintWriter writer = new PrintWriter(new OutputStreamWriter(new FileOutputStream(o.provenanceFile), StandardCharsets.UTF_8))) {
                writer.println(jsonResult);
            }
        }
        // convertModel performs the actual strip/replace of the provenance.
        ModelTuple<T> tuple = convertModel(input, provenanceHash, o);
        logger.info("Writing model to " + o.outputModel);
        oos.writeObject(tuple.model);
        ModelProvenance newProvenance = tuple.provenance;
        logger.info("Marshalling provenance and creating JSON.");
        String newJsonResult = jsonProvenanceSerialization.marshalAndSerialize(newProvenance);
        // Log both versions so the effect of the strip is visible side by side.
        logger.info("Old provenance = \n" + jsonResult);
        logger.info("New provenance = \n" + newJsonResult);
    } catch (NoSuchMethodException e) {
        // The next three catches cover reflective access to Model's copy method
        // (see the messages); they are distinct so the logs pinpoint the failure.
        logger.log(Level.SEVERE, "Model.copy method missing on a class which extends Model.", e);
    } catch (IllegalAccessException e) {
        logger.log(Level.SEVERE, "Failed to modify protection on inner copy method on Model.", e);
    } catch (InvocationTargetException e) {
        logger.log(Level.SEVERE, "Failed to invoke inner copy method on Model.", e);
    } catch (UnsupportedEncodingException e) {
        // Narrower IOException subtypes are caught before IOException itself.
        logger.log(Level.SEVERE, "Unsupported encoding exception.", e);
    } catch (FileNotFoundException e) {
        logger.log(Level.SEVERE, "Failed to find the input file.", e);
    } catch (IOException e) {
        logger.log(Level.SEVERE, "IO error when reading or writing a file.", e);
    } catch (ClassNotFoundException e) {
        logger.log(Level.SEVERE, "The model and/or provenance classes are not on the classpath.", e);
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) ModelProvenance(org.tribuo.provenance.ModelProvenance) EnsembleModelProvenance(org.tribuo.provenance.EnsembleModelProvenance) FileNotFoundException(java.io.FileNotFoundException) UnsupportedEncodingException(java.io.UnsupportedEncodingException) IOException(java.io.IOException) ObjectOutputStream(java.io.ObjectOutputStream) InvocationTargetException(java.lang.reflect.InvocationTargetException) DATASET(org.tribuo.json.StripProvenance.ProvenanceTypes.DATASET) JsonProvenanceSerialization(com.oracle.labs.mlrg.olcut.config.json.JsonProvenanceSerialization) FileOutputStream(java.io.FileOutputStream) Model(org.tribuo.Model) EnsembleModel(org.tribuo.ensemble.EnsembleModel) OutputStreamWriter(java.io.OutputStreamWriter) MessageDigest(java.security.MessageDigest) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) ObjectInputStream(java.io.ObjectInputStream) PrintWriter(java.io.PrintWriter)

Example 35 with ConfigurationManager

use of com.oracle.labs.mlrg.olcut.config.ConfigurationManager in project tribuo by oracle.

From the class TrainTest, method main.

/**
 * Runs a TrainTest CLI.
 * <p>
 * Trains an HDBSCAN model on the supplied dataset and reports clustering
 * quality via normalized and adjusted mutual information.
 * @param args the command line arguments
 * @throws IOException if there is any error reading the examples.
 */
public static void main(String[] args) throws IOException {
    // 
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    HdbscanCLIOptions opts = new HdbscanCLIOptions();
    ConfigurationManager configManager;
    try {
        configManager = new ConfigurationManager(args, opts);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }
    if (opts.general.trainingPath == null) {
        // No data supplied; print the usage and bail out.
        logger.info(configManager.usage());
        return;
    }
    ClusteringFactory clusteringFactory = new ClusteringFactory();
    Pair<Dataset<ClusterID>, Dataset<ClusterID>> datasets = opts.general.load(clusteringFactory);
    Dataset<ClusterID> trainData = datasets.getA();
    Model<ClusterID> clusterModel = opts.hdbscanOptions.getTrainer().train(trainData);
    logger.info("Finished training model");
    // Clustering is evaluated on the training data itself.
    ClusteringEvaluation eval = clusteringFactory.getEvaluator().evaluate(clusterModel, trainData);
    logger.info("Finished evaluating model");
    System.out.println("Normalized MI = " + eval.normalizedMI());
    System.out.println("Adjusted MI = " + eval.adjustedMI());
    if (opts.general.outputPath != null) {
        opts.general.saveModel(clusterModel);
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) ClusterID(org.tribuo.clustering.ClusterID) Dataset(org.tribuo.Dataset) ClusteringFactory(org.tribuo.clustering.ClusteringFactory) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) ClusteringEvaluation(org.tribuo.clustering.evaluation.ClusteringEvaluation)

Aggregations

ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager)42 UsageException (com.oracle.labs.mlrg.olcut.config.UsageException)32 Label (org.tribuo.classification.Label)16 Dataset (org.tribuo.Dataset)15 IOException (java.io.IOException)8 FileOutputStream (java.io.FileOutputStream)7 ObjectOutputStream (java.io.ObjectOutputStream)7 RegressionFactory (org.tribuo.regression.RegressionFactory)7 Regressor (org.tribuo.regression.Regressor)7 RegressionEvaluation (org.tribuo.regression.evaluation.RegressionEvaluation)7 BufferedWriter (java.io.BufferedWriter)4 File (java.io.File)4 OutputStreamWriter (java.io.OutputStreamWriter)4 MutableDataset (org.tribuo.MutableDataset)4 LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation)4 LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator)4 FileInputStream (java.io.FileInputStream)3 ObjectInputStream (java.io.ObjectInputStream)3 PrintWriter (java.io.PrintWriter)3 ArrayList (java.util.ArrayList)3