Search in sources :

Example 1 with JsonProvenanceSerialization

Use of com.oracle.labs.mlrg.olcut.config.json.JsonProvenanceSerialization in project tribuo by oracle.

From the class StripProvenance, method main.

/**
 * Runs StripProvenance.
 * <p>
 * Loads a serialized model, serializes its provenance to JSON, hashes that JSON with the
 * configured digest, optionally writes the JSON to the provenance file, and then writes
 * the model returned by {@code convertModel} to the output path. The old and new
 * provenance JSON are both logged at the end.
 * <p>
 * The output file is opened only after the input model has been fully read and converted,
 * so a failure while loading or converting does not leave a truncated output file behind
 * (and does not clobber the input if both options point at the same file).
 * @param args the command line arguments
 * @param <T>  The {@link Output} subclass.
 */
@SuppressWarnings("unchecked")
public static <T extends Output<T>> void main(String[] args) {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    StripProvenanceOptions o = new StripProvenanceOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }
    if (o.inputModel == null || o.outputModel == null) {
        logger.info(cm.usage());
        System.exit(1);
    }
    try {
        // NOTE(review): this uses Java native deserialization; only run it on model files
        // from a trusted source.
        Model<T> input;
        try (ObjectInputStream ois = IOUtil.getObjectInputStream(o.inputModel)) {
            logger.info("Loading model from " + o.inputModel);
            // Unchecked cast: the stream carries no generic type information.
            input = (Model<T>) ois.readObject();
        }
        ModelProvenance oldProvenance = input.getProvenance();
        logger.info("Marshalling provenance and creating JSON.");
        JsonProvenanceSerialization jsonProvenanceSerialization = new JsonProvenanceSerialization(true);
        String jsonResult = jsonProvenanceSerialization.marshalAndSerialize(oldProvenance);
        logger.info("Hashing JSON file");
        MessageDigest digest = o.hashType.getDigest();
        // Hash the UTF-8 bytes, the same encoding used when the JSON is written to disk below.
        byte[] digestBytes = digest.digest(jsonResult.getBytes(StandardCharsets.UTF_8));
        String provenanceHash = ProvenanceUtil.bytesToHexString(digestBytes);
        logger.info("Provenance hash = " + provenanceHash);
        if (o.provenanceFile != null) {
            logger.info("Writing JSON provenance to " + o.provenanceFile);
            try (PrintWriter writer = new PrintWriter(new OutputStreamWriter(new FileOutputStream(o.provenanceFile), StandardCharsets.UTF_8))) {
                writer.println(jsonResult);
            }
        }
        ModelTuple<T> tuple = convertModel(input, provenanceHash, o);
        logger.info("Writing model to " + o.outputModel);
        // Open (and thereby create/truncate) the output file only now that there is
        // something valid to write into it.
        try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(o.outputModel))) {
            oos.writeObject(tuple.model);
        }
        ModelProvenance newProvenance = tuple.provenance;
        logger.info("Marshalling provenance and creating JSON.");
        String newJsonResult = jsonProvenanceSerialization.marshalAndSerialize(newProvenance);
        logger.info("Old provenance = \n" + jsonResult);
        logger.info("New provenance = \n" + newJsonResult);
    } catch (NoSuchMethodException e) {
        logger.log(Level.SEVERE, "Model.copy method missing on a class which extends Model.", e);
    } catch (IllegalAccessException e) {
        logger.log(Level.SEVERE, "Failed to modify protection on inner copy method on Model.", e);
    } catch (InvocationTargetException e) {
        logger.log(Level.SEVERE, "Failed to invoke inner copy method on Model.", e);
    } catch (UnsupportedEncodingException e) {
        logger.log(Level.SEVERE, "Unsupported encoding exception.", e);
    } catch (FileNotFoundException e) {
        logger.log(Level.SEVERE, "Failed to find the input file.", e);
    } catch (IOException e) {
        logger.log(Level.SEVERE, "IO error when reading or writing a file.", e);
    } catch (ClassNotFoundException e) {
        logger.log(Level.SEVERE, "The model and/or provenance classes are not on the classpath.", e);
    }
}
Also used : UsageException(com.oracle.labs.mlrg.olcut.config.UsageException) ModelProvenance(org.tribuo.provenance.ModelProvenance) EnsembleModelProvenance(org.tribuo.provenance.EnsembleModelProvenance) FileNotFoundException(java.io.FileNotFoundException) UnsupportedEncodingException(java.io.UnsupportedEncodingException) IOException(java.io.IOException) ObjectOutputStream(java.io.ObjectOutputStream) InvocationTargetException(java.lang.reflect.InvocationTargetException) DATASET(org.tribuo.json.StripProvenance.ProvenanceTypes.DATASET) JsonProvenanceSerialization(com.oracle.labs.mlrg.olcut.config.json.JsonProvenanceSerialization) FileOutputStream(java.io.FileOutputStream) Model(org.tribuo.Model) EnsembleModel(org.tribuo.ensemble.EnsembleModel) OutputStreamWriter(java.io.OutputStreamWriter) MessageDigest(java.security.MessageDigest) ConfigurationManager(com.oracle.labs.mlrg.olcut.config.ConfigurationManager) ObjectInputStream(java.io.ObjectInputStream) PrintWriter(java.io.PrintWriter)

Aggregations

ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager)1 UsageException (com.oracle.labs.mlrg.olcut.config.UsageException)1 JsonProvenanceSerialization (com.oracle.labs.mlrg.olcut.config.json.JsonProvenanceSerialization)1 FileNotFoundException (java.io.FileNotFoundException)1 FileOutputStream (java.io.FileOutputStream)1 IOException (java.io.IOException)1 ObjectInputStream (java.io.ObjectInputStream)1 ObjectOutputStream (java.io.ObjectOutputStream)1 OutputStreamWriter (java.io.OutputStreamWriter)1 PrintWriter (java.io.PrintWriter)1 UnsupportedEncodingException (java.io.UnsupportedEncodingException)1 InvocationTargetException (java.lang.reflect.InvocationTargetException)1 MessageDigest (java.security.MessageDigest)1 Model (org.tribuo.Model)1 EnsembleModel (org.tribuo.ensemble.EnsembleModel)1 DATASET (org.tribuo.json.StripProvenance.ProvenanceTypes.DATASET)1 EnsembleModelProvenance (org.tribuo.provenance.EnsembleModelProvenance)1 ModelProvenance (org.tribuo.provenance.ModelProvenance)1