Usage of org.tribuo.ensemble.EnsembleModel in the Tribuo project by Oracle.
The class StripProvenance, method convertModel.
/**
 * Creates a copy of the old model with the requested provenance removed.
 * <p>
 * If {@code oldModel} is an {@link EnsembleModel}, each ensemble member is converted
 * recursively first. The ensemble is then copied with the converted members and a cleaned
 * ensemble provenance built from the members' cleaned provenances. Otherwise the model's
 * provenance is cleaned directly. In both cases the copy is produced by reflectively invoking
 * the {@code copy} method declared on the model's runtime class. The copy is named
 * {@code "<oldName>-deprovenanced"}, or {@code "deprovenanced"} if the old name is empty.
 * @param oldModel The model to remove provenance from.
 * @param provenanceHash A hash of the old provenance.
 * @param opt The program options.
 * @param <T> The output type.
 * @return A copy of the model with redacted provenance.
 * @throws InvocationTargetException If the model's copy method throws an exception.
 * @throws IllegalAccessException If the model's copy method is not accessible.
 * @throws NoSuchMethodException If the model's runtime class does not declare the expected copy method.
 */
// cast of model after call to copy which returns model.
@SuppressWarnings("unchecked")
private static <T extends Output<T>> ModelTuple<T> convertModel(Model<T> oldModel, String provenanceHash, StripProvenanceOptions opt) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException {
    if (oldModel instanceof EnsembleModel) {
        EnsembleModel<T> oldEnsemble = (EnsembleModel<T>) oldModel;
        EnsembleModelProvenance oldProvenance = oldEnsemble.getProvenance();
        List<ModelProvenance> newProvenances = new ArrayList<>();
        List<Model<T>> newModels = new ArrayList<>();
        // Strip every member first; the ensemble copy is assembled from the converted members
        // and its provenance references the members' cleaned provenances.
        for (Model<T> member : oldEnsemble.getModels()) {
            ModelTuple<T> tuple = convertModel(member, provenanceHash, opt);
            newProvenances.add(tuple.provenance);
            newModels.add(tuple.model);
        }
        ListProvenance<ModelProvenance> listProv = new ListProvenance<>(newProvenances);
        EnsembleModelProvenance cleanedProvenance = cleanEnsembleProvenance(oldProvenance, listProv, provenanceHash, opt);
        EnsembleModel<T> output = (EnsembleModel<T>) invokeDeclaredCopy(oldModel,
                new Class<?>[]{String.class, ModelProvenance.class, List.class},
                deprovenancedModelName(oldModel), cleanedProvenance, newModels);
        return new ModelTuple<>(output, cleanedProvenance);
    } else {
        ModelProvenance oldProvenance = oldModel.getProvenance();
        ModelProvenance cleanedProvenance = cleanProvenance(oldProvenance, provenanceHash, opt);
        Model<T> output = (Model<T>) invokeDeclaredCopy(oldModel,
                new Class<?>[]{String.class, ModelProvenance.class},
                deprovenancedModelName(oldModel), cleanedProvenance);
        return new ModelTuple<>(output, cleanedProvenance);
    }
}

/**
 * Computes the name given to a deprovenanced copy of a model.
 * @param model The model being copied.
 * @return {@code "deprovenanced"} if the model's name is empty, otherwise the name suffixed with {@code "-deprovenanced"}.
 */
private static String deprovenancedModelName(Model<?> model) {
    String oldName = model.getName();
    return oldName.isEmpty() ? "deprovenanced" : oldName + "-deprovenanced";
}

/**
 * Reflectively invokes the {@code copy} method declared directly on the model's runtime class.
 * <p>
 * The method is made accessible for the call. Its previous accessibility flag is restored in a
 * {@code finally} block, so it is reset even when the invocation throws.
 * @param model The model to invoke {@code copy} on.
 * @param parameterTypes The parameter types of the {@code copy} overload to look up.
 * @param args The arguments passed to {@code copy}.
 * @return The value returned by {@code copy}.
 * @throws InvocationTargetException If the copy method throws an exception.
 * @throws IllegalAccessException If the copy method is not accessible.
 * @throws NoSuchMethodException If the runtime class does not declare a matching copy method.
 */
private static Object invokeDeclaredCopy(Model<?> model, Class<?>[] parameterTypes, Object... args) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException {
    Method copyMethod = model.getClass().getDeclaredMethod("copy", parameterTypes);
    boolean accessible = copyMethod.isAccessible();
    copyMethod.setAccessible(true);
    try {
        return copyMethod.invoke(model, args);
    } finally {
        copyMethod.setAccessible(accessible);
    }
}
Aggregations