Search in sources :

Example 1 with EnsembleModel

Use of org.tribuo.ensemble.EnsembleModel in project tribuo by oracle.

The class StripProvenance, method convertModel.

/**
 * Creates a copy of the old model with the requested provenance removed.
 * <p>
 * Ensembles are handled recursively: each member model is converted first, then the
 * ensemble is rebuilt from the converted members and a cleaned ensemble provenance.
 * Any other model is copied directly with its cleaned provenance.
 * <p>
 * The copy is produced by reflectively calling the {@code copy} method declared on the
 * model's runtime class, and is named {@code "deprovenanced"} when the old model's name
 * is empty, or the old name suffixed with {@code "-deprovenanced"} otherwise.
 * @param oldModel The model to remove provenance from.
 * @param provenanceHash A hash of the old provenance.
 * @param opt The program options.
 * @param <T> The output type.
 * @return A copy of the model with redacted provenance.
 * @throws InvocationTargetException If the model's copy method itself throws an exception.
 * @throws IllegalAccessException If the model's copy method is not accessible.
 * @throws NoSuchMethodException If the model's class does not declare a copy method with the expected signature (all models should do).
 */
// cast of model after call to copy which returns model.
@SuppressWarnings("unchecked")
private static <T extends Output<T>> ModelTuple<T> convertModel(Model<T> oldModel, String provenanceHash, StripProvenanceOptions opt) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException {
    if (oldModel instanceof EnsembleModel) {
        EnsembleModel<T> oldEnsemble = (EnsembleModel<T>) oldModel;
        EnsembleModelProvenance oldProvenance = oldEnsemble.getProvenance();
        List<ModelProvenance> newProvenances = new ArrayList<>();
        List<Model<T>> newModels = new ArrayList<>();
        // Strip the provenance from every ensemble member before rebuilding the ensemble itself.
        for (Model<T> e : oldEnsemble.getModels()) {
            ModelTuple<T> tuple = convertModel(e, provenanceHash, opt);
            newProvenances.add(tuple.provenance);
            newModels.add(tuple.model);
        }
        ListProvenance<ModelProvenance> listProv = new ListProvenance<>(newProvenances);
        EnsembleModelProvenance cleanedProvenance = cleanEnsembleProvenance(oldProvenance, listProv, provenanceHash, opt);
        // The ensemble copy method additionally takes the list of replacement member models.
        Method copyMethod = oldModel.getClass().getDeclaredMethod("copy", String.class, ModelProvenance.class, List.class);
        String newName = deprovenancedName(oldModel);
        EnsembleModel<T> output = (EnsembleModel<T>) invokeCopyMethod(copyMethod, oldModel, newName, cleanedProvenance, newModels);
        return new ModelTuple<>(output, cleanedProvenance);
    } else {
        ModelProvenance oldProvenance = oldModel.getProvenance();
        ModelProvenance cleanedProvenance = cleanProvenance(oldProvenance, provenanceHash, opt);
        Method copyMethod = oldModel.getClass().getDeclaredMethod("copy", String.class, ModelProvenance.class);
        String newName = deprovenancedName(oldModel);
        Model<T> output = (Model<T>) invokeCopyMethod(copyMethod, oldModel, newName, cleanedProvenance);
        return new ModelTuple<>(output, cleanedProvenance);
    }
}

/**
 * Builds the name given to a model copy whose provenance has been stripped.
 * @param model The model being copied.
 * @return {@code "deprovenanced"} if the model's name is empty, otherwise the name suffixed with {@code "-deprovenanced"}.
 */
private static String deprovenancedName(Model<?> model) {
    String oldName = model.getName();
    return oldName.isEmpty() ? "deprovenanced" : oldName + "-deprovenanced";
}

/**
 * Invokes the supplied copy method on the target with access checks suppressed,
 * restoring the method's previous accessibility flag afterwards even if the call throws.
 * @param copyMethod The reflectively looked up copy method.
 * @param target The model to invoke the copy method on.
 * @param args The arguments to pass to the copy method.
 * @return The object returned by the copy method.
 * @throws InvocationTargetException If the copy method itself throws an exception.
 * @throws IllegalAccessException If the copy method is not accessible.
 */
private static Object invokeCopyMethod(Method copyMethod, Object target, Object... args) throws InvocationTargetException, IllegalAccessException {
    boolean accessible = copyMethod.isAccessible();
    copyMethod.setAccessible(true);
    try {
        return copyMethod.invoke(target, args);
    } finally {
        // Put the flag back on both the normal and the exceptional path.
        copyMethod.setAccessible(accessible);
    }
}
Also used : ModelProvenance(org.tribuo.provenance.ModelProvenance) EnsembleModelProvenance(org.tribuo.provenance.EnsembleModelProvenance) ArrayList(java.util.ArrayList) Method(java.lang.reflect.Method) EnsembleModelProvenance(org.tribuo.provenance.EnsembleModelProvenance) DATASET(org.tribuo.json.StripProvenance.ProvenanceTypes.DATASET) Model(org.tribuo.Model) EnsembleModel(org.tribuo.ensemble.EnsembleModel) ListProvenance(com.oracle.labs.mlrg.olcut.provenance.ListProvenance) EnsembleModel(org.tribuo.ensemble.EnsembleModel)

Aggregations

ListProvenance (com.oracle.labs.mlrg.olcut.provenance.ListProvenance)1 Method (java.lang.reflect.Method)1 ArrayList (java.util.ArrayList)1 Model (org.tribuo.Model)1 EnsembleModel (org.tribuo.ensemble.EnsembleModel)1 DATASET (org.tribuo.json.StripProvenance.ProvenanceTypes.DATASET)1 EnsembleModelProvenance (org.tribuo.provenance.EnsembleModelProvenance)1 ModelProvenance (org.tribuo.provenance.ModelProvenance)1