Use of com.oracle.labs.mlrg.olcut.config.json.JsonProvenanceSerialization in project tribuo by oracle.
The class StripProvenance, method main.
/**
 * Runs StripProvenance.
 * <p>
 * Reads a serialized {@link Model} from {@code inputModel}, serializes its
 * provenance to JSON, hashes that JSON with the digest chosen by
 * {@code hashType}, optionally writes the JSON to {@code provenanceFile},
 * converts the model via {@code convertModel} and writes the converted model
 * to {@code outputModel}. Both the old and the new provenance are logged.
 * <p>
 * The output file is only opened once the converted model is ready to be
 * written, so a failure while loading or converting the input does not create
 * or truncate the output file.
 * <p>
 * Checked failures raised while processing are logged at
 * {@link Level#SEVERE} and the method then returns normally.
 * @param args the command line arguments
 * @param <T> The {@link Output} subclass.
 */
@SuppressWarnings("unchecked")
public static <T extends Output<T>> void main(String[] args) {
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();

    StripProvenanceOptions o = new StripProvenanceOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        // The exception message carries the usage text; print it and stop.
        logger.info(e.getMessage());
        return;
    }

    // Both the input and the output model locations are mandatory.
    if (o.inputModel == null || o.outputModel == null) {
        logger.info(cm.usage());
        System.exit(1);
    }

    try (ObjectInputStream ois = IOUtil.getObjectInputStream(o.inputModel)) {
        logger.info("Loading model from " + o.inputModel);
        // Unchecked cast: the element type of the deserialized model cannot be verified at runtime.
        Model<T> input = (Model<T>) ois.readObject();
        ModelProvenance oldProvenance = input.getProvenance();

        logger.info("Marshalling provenance and creating JSON.");
        JsonProvenanceSerialization jsonProvenanceSerialization = new JsonProvenanceSerialization(true);
        String jsonResult = jsonProvenanceSerialization.marshalAndSerialize(oldProvenance);

        logger.info("Hashing JSON file");
        MessageDigest digest = o.hashType.getDigest();
        byte[] digestBytes = digest.digest(jsonResult.getBytes(StandardCharsets.UTF_8));
        String provenanceHash = ProvenanceUtil.bytesToHexString(digestBytes);
        logger.info("Provenance hash = " + provenanceHash);

        // Writing the original provenance out as JSON is optional.
        if (o.provenanceFile != null) {
            logger.info("Writing JSON provenance to " + o.provenanceFile.toString());
            try (PrintWriter writer = new PrintWriter(new OutputStreamWriter(new FileOutputStream(o.provenanceFile), StandardCharsets.UTF_8))) {
                writer.println(jsonResult);
            }
        }

        ModelTuple<T> tuple = convertModel(input, provenanceHash, o);

        logger.info("Writing model to " + o.outputModel);
        // Open the output only now: creating the FileOutputStream truncates the
        // target, so doing it earlier would destroy an existing file whenever
        // loading or conversion failed.
        try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(o.outputModel))) {
            oos.writeObject(tuple.model);
        }

        ModelProvenance newProvenance = tuple.provenance;
        logger.info("Marshalling provenance and creating JSON.");
        String newJsonResult = jsonProvenanceSerialization.marshalAndSerialize(newProvenance);

        logger.info("Old provenance = \n" + jsonResult);
        logger.info("New provenance = \n" + newJsonResult);
    } catch (NoSuchMethodException e) {
        logger.log(Level.SEVERE, "Model.copy method missing on a class which extends Model.", e);
    } catch (IllegalAccessException e) {
        logger.log(Level.SEVERE, "Failed to modify protection on inner copy method on Model.", e);
    } catch (InvocationTargetException e) {
        logger.log(Level.SEVERE, "Failed to invoke inner copy method on Model.", e);
    } catch (UnsupportedEncodingException e) {
        logger.log(Level.SEVERE, "Unsupported encoding exception.", e);
    } catch (FileNotFoundException e) {
        logger.log(Level.SEVERE, "Failed to find the input file.", e);
    } catch (IOException e) {
        logger.log(Level.SEVERE, "IO error when reading or writing a file.", e);
    } catch (ClassNotFoundException e) {
        logger.log(Level.SEVERE, "The model and/or provenance classes are not on the classpath.", e);
    }
}
Aggregations