Use of org.tribuo.provenance.EnsembleModelProvenance in the Tribuo project by Oracle.
In the class StripProvenance, method cleanEnsembleProvenance.
/**
 * Builds a replacement {@link EnsembleModelProvenance} with the provenance sections selected in
 * {@code opt.removeProvenances} stripped out.
 * <p>
 * Each of the dataset, trainer and instance sections is replaced with an empty value when either
 * {@code ALL} or that specific section was requested. Stripping the instance section also resets
 * the training time to {@link OffsetDateTime#MIN}. When {@code opt.storeHash} is set, the supplied
 * hash is written into the instance map under the key {@code "original-provenance-hash"}.
 * The class name is carried over from {@code old}, and the member provenances are replaced by
 * {@code memberProvenance}.
 * @param old The ensemble provenance being cleaned.
 * @param memberProvenance The (already cleaned) provenances of the ensemble members.
 * @param provenanceHash The hash of the original provenance.
 * @param opt The program options.
 * @return A new ensemble provenance with the requested sections removed.
 */
private static EnsembleModelProvenance cleanEnsembleProvenance(EnsembleModelProvenance old, ListProvenance<ModelProvenance> memberProvenance, String provenanceHash, StripProvenanceOptions opt) {
    // Decide up front which sections to drop; ALL implies every section.
    boolean removeAll = opt.removeProvenances.contains(ALL);
    boolean removeDataset = removeAll || opt.removeProvenances.contains(DATASET);
    boolean removeTrainer = removeAll || opt.removeProvenances.contains(TRAINER);
    boolean removeInstance = removeAll || opt.removeProvenances.contains(INSTANCE);

    DatasetProvenance dataset = removeDataset ? new EmptyDatasetProvenance() : old.getDatasetProvenance();
    TrainerProvenance trainer = removeTrainer ? new EmptyTrainerProvenance() : old.getTrainerProvenance();

    // The instance map is always a fresh mutable copy so the hash can be added below.
    Map<String, Provenance> instance;
    OffsetDateTime trainingTime;
    if (removeInstance) {
        instance = new HashMap<>();
        trainingTime = OffsetDateTime.MIN;
    } else {
        instance = new HashMap<>(old.getInstanceProvenance().getMap());
        trainingTime = old.getTrainingTime();
    }

    if (opt.storeHash) {
        logger.info("Writing provenance hash into instance map.");
        instance.put("original-provenance-hash", new HashProvenance(opt.hashType, "original-provenance-hash", provenanceHash));
    }

    return new EnsembleModelProvenance(old.getClassName(), trainingTime, dataset, trainer, instance, memberProvenance);
}
Use of org.tribuo.provenance.EnsembleModelProvenance in the Tribuo project by Oracle.
In the class StripProvenance, method convertModel.
/**
 * Creates a copy of the old model with the requested provenance removed.
 * <p>
 * If the model is an {@link EnsembleModel}, each member is converted recursively. The ensemble is
 * then rebuilt from the converted members with a cleaned {@link EnsembleModelProvenance}. Otherwise
 * the model's provenance is cleaned and the model is copied with it.
 * <p>
 * The copy is obtained by reflectively invoking the model class's declared {@code copy} method.
 * It is named {@code "<old name>-deprovenanced"}, or just {@code "deprovenanced"} if the old name is empty.
 * @param oldModel The model to remove provenance from.
 * @param provenanceHash A hash of the old provenance.
 * @param opt The program options.
 * @param <T> The output type.
 * @return A copy of the model with redacted provenance.
 * @throws InvocationTargetException If the model's copy method throws an exception.
 * @throws IllegalAccessException If the model's copy method is not accessible.
 * @throws NoSuchMethodException If the model's class doesn't declare the expected copy method (all models should).
 */
// cast of model after call to copy which returns model.
@SuppressWarnings("unchecked")
private static <T extends Output<T>> ModelTuple<T> convertModel(Model<T> oldModel, String provenanceHash, StripProvenanceOptions opt) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException {
    if (oldModel instanceof EnsembleModel) {
        EnsembleModel<T> oldEnsemble = (EnsembleModel<T>) oldModel;
        EnsembleModelProvenance oldProvenance = oldEnsemble.getProvenance();
        // Convert every member first; the ensemble provenance embeds the members' cleaned provenances.
        List<ModelProvenance> newProvenances = new ArrayList<>();
        List<Model<T>> newModels = new ArrayList<>();
        for (Model<T> e : oldEnsemble.getModels()) {
            ModelTuple<T> tuple = convertModel(e, provenanceHash, opt);
            newProvenances.add(tuple.provenance);
            newModels.add(tuple.model);
        }
        ListProvenance<ModelProvenance> listProv = new ListProvenance<>(newProvenances);
        EnsembleModelProvenance cleanedProvenance = cleanEnsembleProvenance(oldProvenance, listProv, provenanceHash, opt);
        EnsembleModel<T> output = (EnsembleModel<T>) invokeCopy(oldModel,
                new Class<?>[]{String.class, ModelProvenance.class, List.class},
                deprovenancedName(oldModel), cleanedProvenance, newModels);
        return new ModelTuple<>(output, cleanedProvenance);
    } else {
        ModelProvenance oldProvenance = oldModel.getProvenance();
        ModelProvenance cleanedProvenance = cleanProvenance(oldProvenance, provenanceHash, opt);
        Model<T> output = (Model<T>) invokeCopy(oldModel,
                new Class<?>[]{String.class, ModelProvenance.class},
                deprovenancedName(oldModel), cleanedProvenance);
        return new ModelTuple<>(output, cleanedProvenance);
    }
}

/**
 * Computes the name given to a de-provenanced copy of a model.
 * @param model The model being copied.
 * @return {@code "deprovenanced"} if the model's name is empty, otherwise the name with {@code "-deprovenanced"} appended.
 */
private static String deprovenancedName(Model<?> model) {
    return model.getName().isEmpty() ? "deprovenanced" : model.getName() + "-deprovenanced";
}

/**
 * Reflectively invokes the {@code copy} method declared on the model's runtime class.
 * <p>
 * The method is temporarily made accessible. Its previous accessibility flag is restored even if
 * the invocation throws.
 * @param model The model to copy.
 * @param parameterTypes The parameter types of the {@code copy} overload to look up.
 * @param args The arguments to pass to {@code copy}.
 * @return The object returned by {@code copy}.
 * @throws InvocationTargetException If the copy method throws an exception.
 * @throws IllegalAccessException If the copy method is not accessible.
 * @throws NoSuchMethodException If the model's class doesn't declare a matching copy method.
 */
private static Object invokeCopy(Model<?> model, Class<?>[] parameterTypes, Object... args) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException {
    Method copyMethod = model.getClass().getDeclaredMethod("copy", parameterTypes);
    boolean accessible = copyMethod.isAccessible();
    copyMethod.setAccessible(true);
    try {
        return copyMethod.invoke(model, args);
    } finally {
        copyMethod.setAccessible(accessible);
    }
}
Use of org.tribuo.provenance.EnsembleModelProvenance in the Tribuo project by Oracle.
In the class BaggingTrainer, method train.
/**
 * Trains a bagged ensemble of {@code numMembers} models, each built by {@code trainSingleModel}.
 * <p>
 * Each member receives a seed drawn from an RNG split off this trainer's RNG. It also receives an
 * inner-trainer invocation count that starts at the inner trainer's current count and increases
 * by one per member. The members are combined with this trainer's {@code combiner} into a
 * {@link WeightedEnsembleModel} named by {@code ensembleName()}.
 * @param examples The training dataset.
 * @param runProvenance Run provenance recorded in the ensemble's provenance and passed to each member.
 * @param invocationCount The invocation count to set before training, or {@code INCREMENT_INVOCATION_COUNT}
 *                        to keep the current count.
 * @return The trained ensemble.
 */
@Override
public EnsembleModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    // Under the lock: optionally reset the invocation count, split off an RNG for this run,
    // snapshot the trainer provenance, then bump the invocation counter.
    SplittableRandom memberRng;
    TrainerProvenance ensembleTrainerProvenance;
    synchronized (this) {
        if (invocationCount != INCREMENT_INVOCATION_COUNT) {
            setInvocationCount(invocationCount);
        }
        memberRng = rng.split();
        ensembleTrainerProvenance = getProvenance();
        trainInvocationCounter++;
    }

    ImmutableFeatureMap featureDomain = examples.getFeatureIDMap();
    ImmutableOutputInfo<T> outputDomain = examples.getOutputIDInfo();
    ArrayList<Model<T>> members = new ArrayList<>();
    int baseInvocation = innerTrainer.getInvocationCount();

    int memberIdx = 0;
    while (memberIdx < numMembers) {
        logger.info("Building model " + memberIdx);
        int memberSeed = memberRng.nextInt();
        Model<T> member = trainSingleModel(examples, featureDomain, outputDomain, memberSeed, runProvenance, baseInvocation + memberIdx);
        members.add(member);
        memberIdx++;
    }

    EnsembleModelProvenance ensembleProvenance = new EnsembleModelProvenance(
            WeightedEnsembleModel.class.getName(),
            OffsetDateTime.now(),
            examples.getProvenance(),
            ensembleTrainerProvenance,
            runProvenance,
            ListProvenance.createListProvenance(members));
    return new WeightedEnsembleModel<>(ensembleName(), ensembleProvenance, featureDomain, outputDomain, members, combiner);
}
Use of org.tribuo.provenance.EnsembleModelProvenance in the Tribuo project by Oracle.
In the class WeightedEnsembleModel, method createEnsembleFromExistingModels.
/**
 * Creates an ensemble from existing models.
 * <p>
 * Uses the feature and output domain from the first model as the ensemble model's domains. It also
 * uses the first model's dataset provenance as the ensemble's dataset provenance. The individual
 * ensemble members use the domains that they contain.
 * <p>
 * Throws {@link IllegalArgumentException} in any of these cases:
 * <ul>
 *     <li>fewer than two models are supplied;</li>
 *     <li>the weights aren't the same length as the models;</li>
 *     <li>any member's output domain differs from the first model's;</li>
 *     <li>any member's feature domain differs from the first model's.</li>
 * </ul>
 * @param name The ensemble name.
 * @param models The ensemble members.
 * @param combiner The combination function.
 * @param weights The model combination weights.
 * @param <T> The output type.
 * @return A weighted ensemble model.
 * @throws IllegalArgumentException If the models or weights fail the validation described above.
 */
public static <T extends Output<T>> WeightedEnsembleModel<T> createEnsembleFromExistingModels(String name, List<Model<T>> models, EnsembleCombiner<T> combiner, float[] weights) {
    // Basic parameter validation
    if (models.size() < 2) {
        throw new IllegalArgumentException("Must supply at least 2 models, found " + models.size());
    }
    if (weights.length != models.size()) {
        throw new IllegalArgumentException("Must supply one weight per model, models.size() = " + models.size() + ", weights.length = " + weights.length);
    }
    // Defensive copy the model list (the weights are copied in the constructor).
    // Copy before validating so the validated members, the provenance and the ensemble all
    // refer to the same list, independent of later changes to the caller's list.
    List<Model<T>> modelList = new ArrayList<>(models);
    Model<T> first = modelList.get(0);
    // Validate output & feature domains against the first model
    ImmutableOutputInfo<T> outputInfo = first.getOutputIDInfo();
    ImmutableFeatureMap featureMap = first.getFeatureIDMap();
    Set<T> firstOutputDomain = outputInfo.getDomain();
    for (int i = 1; i < modelList.size(); i++) {
        Model<T> member = modelList.get(i);
        if (!member.getOutputIDInfo().getDomain().equals(firstOutputDomain)) {
            throw new IllegalArgumentException("Model output domains are not equal.");
        }
        if (!member.getFeatureIDMap().domainEquals(featureMap)) {
            throw new IllegalArgumentException("Model feature domains are not equal.");
        }
    }
    // Build EnsembleModelProvenance from the members actually stored in the ensemble
    TimestampedTrainerProvenance trainerProvenance = new TimestampedTrainerProvenance();
    EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), first.getProvenance().getDatasetProvenance(), trainerProvenance, ListProvenance.createListProvenance(modelList));
    return new WeightedEnsembleModel<>(name, provenance, featureMap, outputInfo, modelList, combiner, weights);
}
Use of org.tribuo.provenance.EnsembleModelProvenance in the Tribuo project by Oracle.
In the class AdaBoostTrainer, method train.
/**
 * Trains a boosted {@link WeightedEnsembleModel} of up to {@code numMembers} models produced by the inner trainer.
 * <p>
 * The mode depends on the inner trainer:
 * <ul>
 *     <li>If it implements {@link WeightedExamples}, the dataset is copied and each example's weight is set
 *     from the current example weights ("weighted Adaboost").</li>
 *     <li>Otherwise each round trains on a weighted bootstrap view drawn with the current example weights
 *     ("sampling Adaboost").</li>
 * </ul>
 * After each round, the new member is scored on the full training set using the {@code accuracy}
 * helper, which is passed the current example weights. The member's combination weight is
 * {@code log(accuracy / error) + log(numClasses - 1)}; this matches the SAMME multi-class AdaBoost
 * update. Training stops early, returning the members built so far, if a member reaches (near)
 * perfect accuracy.
 * @param examples The training dataset; it must not contain unknown outputs.
 * @param runProvenance Run provenance recorded in the ensemble's provenance.
 * @param invocationCount The invocation count to set before training, or {@code INCREMENT_INVOCATION_COUNT}
 *                        to keep the current count.
 * @return A {@link WeightedEnsembleModel} that combines the members with a {@link VotingCombiner}.
 * @throws IllegalArgumentException If the dataset contains unknown outputs.
 */
@Override
public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance, int invocationCount) {
if (examples.getOutputInfo().getUnknownCount() > 0) {
throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
}
// Creates a new RNG, adds one to the invocation count.
SplittableRandom localRNG;
TrainerProvenance trainerProvenance;
synchronized (this) {
if (invocationCount != INCREMENT_INVOCATION_COUNT) {
setInvocationCount(invocationCount);
}
localRNG = rng.split();
trainerProvenance = getProvenance();
trainInvocationCounter++;
}
// Weighted mode is used when the inner trainer can consume per-example weights directly.
boolean weighted = innerTrainer instanceof WeightedExamples;
ImmutableFeatureMap featureIDs = examples.getFeatureIDMap();
ImmutableOutputInfo<Label> labelIDs = examples.getOutputIDInfo();
int numClasses = labelIDs.size();
logger.log(Level.INFO, "NumClasses = " + numClasses);
ArrayList<Model<Label>> models = new ArrayList<>();
float[] modelWeights = new float[numMembers];
// Start from a uniform distribution over the examples (each weight is 1 / examples.size()).
float[] exampleWeights = Util.generateUniformFloatVector(examples.size(), 1.0f / examples.size());
if (weighted) {
logger.info("Using weighted Adaboost.");
// Rebinds 'examples' to a copy, presumably so that the setWeight calls below don't mutate the
// caller's dataset. NOTE(review): from here on, examples.getProvenance() (used when building the
// ensemble provenance) is the copy's provenance; confirm this is the intended dataset provenance.
examples = ImmutableDataset.copyDataset(examples);
for (int i = 0; i < examples.size(); i++) {
Example<Label> e = examples.getExample(i);
e.setWeight(exampleWeights[i]);
}
} else {
logger.info("Using sampling Adaboost.");
}
for (int i = 0; i < numMembers; i++) {
logger.info("Building model " + i);
Model<Label> newModel;
if (weighted) {
// The example weights were already pushed into the (copied) dataset.
newModel = innerTrainer.train(examples);
} else {
// Draw a bootstrap view of examples.size() examples, sampled using the current example weights.
DatasetView<Label> bag = DatasetView.createWeightedBootstrapView(examples, examples.size(), localRNG.nextLong(), exampleWeights, featureIDs, labelIDs);
newModel = innerTrainer.train(bag);
}
//
// Score this model
List<Prediction<Label>> predictions = newModel.predict(examples);
float accuracy = accuracy(predictions, examples, exampleWeights);
float error = 1.0f - accuracy;
// Member weight: log(accuracy / error) + log(K - 1). With K == 2 the second term is log(1) = 0.
// When accuracy == 1 this is infinite; the early-return branch below replaces it.
float alpha = (float) (Math.log(accuracy / error) + Math.log(numClasses - 1));
models.add(newModel);
modelWeights[i] = alpha;
if ((accuracy + 1e-10) > 1.0) {
//
// Perfect accuracy, can no longer boost.
// Keep only the weights of the members trained so far.
float[] newModelWeights = Arrays.copyOf(modelWeights, models.size());
// Set the last weight to 1, as it's infinity.
newModelWeights[models.size() - 1] = 1.0f;
logger.log(Level.FINE, "Perfect accuracy reached on iteration " + i + ", returning current model.");
logger.log(Level.FINE, "Model weights:");
Util.logVector(logger, Level.FINE, newModelWeights);
EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(models));
return new WeightedEnsembleModel<>("boosted-ensemble", provenance, featureIDs, labelIDs, models, new VotingCombiner(), newModelWeights);
}
// Update the weights
// Misclassified examples are scaled up by exp(alpha); correctly classified ones are unchanged until
// the renormalization below. NOTE(review): if accuracy == 0 then alpha == -Infinity, exp(alpha) == 0,
// and every weight becomes 0. Confirm Util.inplaceNormalizeToDistribution tolerates an all-zero vector.
for (int j = 0; j < predictions.size(); j++) {
if (!predictions.get(j).getOutput().equals(examples.getExample(j).getOutput())) {
exampleWeights[j] *= Math.exp(alpha);
}
}
Util.inplaceNormalizeToDistribution(exampleWeights);
if (weighted) {
// Push the renormalized weights back into the copied dataset for the next round.
for (int j = 0; j < examples.size(); j++) {
examples.getExample(j).setWeight(exampleWeights[j]);
}
}
}
logger.log(Level.FINE, "Model weights:");
Util.logVector(logger, Level.FINE, modelWeights);
EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(models));
return new WeightedEnsembleModel<>("boosted-ensemble", provenance, featureIDs, labelIDs, models, new VotingCombiner(), modelWeights);
}
Aggregations