Search in sources :

Example 1 with EnsembleModelProvenance

use of org.tribuo.provenance.EnsembleModelProvenance in project tribuo by oracle.

the class StripProvenance method cleanEnsembleProvenance.

/**
 * Builds a replacement ensemble provenance in which the parts selected by the options are blanked out.
 * <p>
 * The dataset provenance, trainer provenance and instance provenance (together with the training
 * time) are each either copied from {@code old} or replaced by an empty value, depending on whether
 * {@code opt.removeProvenances} contains {@code ALL} or the matching specific entry. When
 * {@code opt.storeHash} is set the supplied hash is written into the instance map under the key
 * {@code "original-provenance-hash"}, whether or not the instance values were stripped.
 * @param old The old ensemble provenance.
 * @param memberProvenance The new member provenances.
 * @param provenanceHash The old ensemble provenance hash.
 * @param opt The program options.
 * @return The new ensemble provenance with the requested fields removed.
 */
private static EnsembleModelProvenance cleanEnsembleProvenance(EnsembleModelProvenance old, ListProvenance<ModelProvenance> memberProvenance, String provenanceHash, StripProvenanceOptions opt) {
    // ALL overrides every individual selection, so look it up once.
    final boolean stripEverything = opt.removeProvenances.contains(ALL);

    // Dataset provenance: keep the original unless it was selected for removal.
    final boolean stripDataset = stripEverything || opt.removeProvenances.contains(DATASET);
    final DatasetProvenance datasetProvenance = stripDataset ? new EmptyDatasetProvenance() : old.getDatasetProvenance();

    // Trainer provenance: same rule as the dataset provenance.
    final boolean stripTrainer = stripEverything || opt.removeProvenances.contains(TRAINER);
    final TrainerProvenance trainerProvenance = stripTrainer ? new EmptyTrainerProvenance() : old.getTrainerProvenance();

    // Instance provenance and training time are kept or dropped together.
    final boolean keepInstance = !(stripEverything || opt.removeProvenances.contains(INSTANCE));
    final Map<String, Provenance> instanceProvenance;
    final OffsetDateTime time;
    if (keepInstance) {
        // Copy into a fresh map so the hash entry below can be added to it.
        instanceProvenance = new HashMap<>(old.getInstanceProvenance().getMap());
        time = old.getTrainingTime();
    } else {
        instanceProvenance = new HashMap<>();
        time = OffsetDateTime.MIN;
    }

    if (opt.storeHash) {
        logger.info("Writing provenance hash into instance map.");
        HashProvenance hashEntry = new HashProvenance(opt.hashType, "original-provenance-hash", provenanceHash);
        instanceProvenance.put("original-provenance-hash", hashEntry);
    }

    return new EnsembleModelProvenance(old.getClassName(), time, datasetProvenance, trainerProvenance, instanceProvenance, memberProvenance);
}
Also used : HashProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance) ObjectProvenance(com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance) ModelProvenance(org.tribuo.provenance.ModelProvenance) HashProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance) DatasetProvenance(org.tribuo.provenance.DatasetProvenance) EmptyTrainerProvenance(org.tribuo.provenance.impl.EmptyTrainerProvenance) EmptyDatasetProvenance(org.tribuo.provenance.impl.EmptyDatasetProvenance) Provenance(com.oracle.labs.mlrg.olcut.provenance.Provenance) TrainerProvenance(org.tribuo.provenance.TrainerProvenance) EnsembleModelProvenance(org.tribuo.provenance.EnsembleModelProvenance) ListProvenance(com.oracle.labs.mlrg.olcut.provenance.ListProvenance) EmptyTrainerProvenance(org.tribuo.provenance.impl.EmptyTrainerProvenance) OffsetDateTime(java.time.OffsetDateTime) DatasetProvenance(org.tribuo.provenance.DatasetProvenance) EmptyDatasetProvenance(org.tribuo.provenance.impl.EmptyDatasetProvenance) EmptyTrainerProvenance(org.tribuo.provenance.impl.EmptyTrainerProvenance) TrainerProvenance(org.tribuo.provenance.TrainerProvenance) EnsembleModelProvenance(org.tribuo.provenance.EnsembleModelProvenance) EmptyDatasetProvenance(org.tribuo.provenance.impl.EmptyDatasetProvenance)

Example 2 with EnsembleModelProvenance

use of org.tribuo.provenance.EnsembleModelProvenance in project tribuo by oracle.

the class StripProvenance method convertModel.

/**
 * Creates a copy of the old model with the requested provenance removed.
 * <p>
 * Ensemble models are handled recursively: every member is converted first (with the same
 * {@code provenanceHash} and options), then the ensemble itself is copied with the cleaned
 * ensemble provenance and the converted member list. All other models are copied with their
 * cleaned provenance only.
 * <p>
 * The copy is named {@code "deprovenanced"} when the old name is empty, otherwise the old name
 * with {@code "-deprovenanced"} appended.
 * <p>
 * The copy is produced by reflectively invoking the {@code copy} method declared on the model's
 * runtime class. The method's accessibility flag is restored in a {@code finally} block, so it is
 * reset even when the invocation throws.
 * @param oldModel The model to remove provenance from.
 * @param provenanceHash A hash of the old provenance.
 * @param opt The program options.
 * @param <T> The output type.
 * @return A copy of the model with redacted provenance.
 * @throws InvocationTargetException If the model doesn't expose a copy method (all models should do).
 * @throws IllegalAccessException If the model's copy method is not accessible.
 * @throws NoSuchMethodException If the model's copy method isn't present.
 */
// The reflective copy call returns Object, which has to be cast back to the model type.
@SuppressWarnings("unchecked")
private static <T extends Output<T>> ModelTuple<T> convertModel(Model<T> oldModel, String provenanceHash, StripProvenanceOptions opt) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException {
    if (oldModel instanceof EnsembleModel) {
        EnsembleModelProvenance oldProvenance = ((EnsembleModel<T>) oldModel).getProvenance();
        // Strip each member first; the ensemble copy needs both the new models and their provenances.
        List<ModelProvenance> newProvenances = new ArrayList<>();
        List<Model<T>> newModels = new ArrayList<>();
        for (Model<T> e : ((EnsembleModel<T>) oldModel).getModels()) {
            ModelTuple<T> tuple = convertModel(e, provenanceHash, opt);
            newProvenances.add(tuple.provenance);
            newModels.add(tuple.model);
        }
        ListProvenance<ModelProvenance> listProv = new ListProvenance<>(newProvenances);
        EnsembleModelProvenance cleanedProvenance = cleanEnsembleProvenance(oldProvenance, listProv, provenanceHash, opt);
        Class<?> clazz = oldModel.getClass();
        Method copyMethod = clazz.getDeclaredMethod("copy", String.class, ModelProvenance.class, List.class);
        // Compute the name before touching accessibility so nothing can throw in between.
        String newName = oldModel.getName().isEmpty() ? "deprovenanced" : oldModel.getName() + "-deprovenanced";
        boolean accessible = copyMethod.isAccessible();
        copyMethod.setAccessible(true);
        try {
            EnsembleModel<T> output = (EnsembleModel<T>) copyMethod.invoke(oldModel, newName, cleanedProvenance, newModels);
            return new ModelTuple<>(output, cleanedProvenance);
        } finally {
            // Restore the original flag on both the success and the failure path.
            copyMethod.setAccessible(accessible);
        }
    } else {
        ModelProvenance oldProvenance = oldModel.getProvenance();
        ModelProvenance cleanedProvenance = cleanProvenance(oldProvenance, provenanceHash, opt);
        Class<?> clazz = oldModel.getClass();
        Method copyMethod = clazz.getDeclaredMethod("copy", String.class, ModelProvenance.class);
        // Compute the name before touching accessibility so nothing can throw in between.
        String newName = oldModel.getName().isEmpty() ? "deprovenanced" : oldModel.getName() + "-deprovenanced";
        boolean accessible = copyMethod.isAccessible();
        copyMethod.setAccessible(true);
        try {
            Model<T> output = (Model<T>) copyMethod.invoke(oldModel, newName, cleanedProvenance);
            return new ModelTuple<>(output, cleanedProvenance);
        } finally {
            // Restore the original flag on both the success and the failure path.
            copyMethod.setAccessible(accessible);
        }
    }
}
Also used : ModelProvenance(org.tribuo.provenance.ModelProvenance) EnsembleModelProvenance(org.tribuo.provenance.EnsembleModelProvenance) ArrayList(java.util.ArrayList) Method(java.lang.reflect.Method) EnsembleModelProvenance(org.tribuo.provenance.EnsembleModelProvenance) DATASET(org.tribuo.json.StripProvenance.ProvenanceTypes.DATASET) Model(org.tribuo.Model) EnsembleModel(org.tribuo.ensemble.EnsembleModel) ListProvenance(com.oracle.labs.mlrg.olcut.provenance.ListProvenance) EnsembleModel(org.tribuo.ensemble.EnsembleModel)

Example 3 with EnsembleModelProvenance

use of org.tribuo.provenance.EnsembleModelProvenance in project tribuo by oracle.

the class BaggingTrainer method train.

/**
 * Trains {@code numMembers} models with the inner trainer and wraps them in a
 * {@link WeightedEnsembleModel} that uses this trainer's combiner.
 * <p>
 * Each member is built by {@code trainSingleModel} with a seed drawn from an RNG split off this
 * trainer's RNG, and with an invocation count equal to the inner trainer's current invocation
 * count plus the member index.
 * @param examples The training dataset.
 * @param runProvenance Run-specific provenance stored in the ensemble provenance and passed to each member's training call.
 * @param invocationCount The invocation count to set on this trainer before training, unless it equals {@code INCREMENT_INVOCATION_COUNT}.
 * @return The trained ensemble.
 */
@Override
public EnsembleModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    // Under the lock: optionally reset the invocation count, split off a private RNG,
    // capture the trainer provenance, then bump the invocation counter.
    SplittableRandom localRNG;
    TrainerProvenance trainerProvenance;
    synchronized (this) {
        if (invocationCount != INCREMENT_INVOCATION_COUNT) {
            setInvocationCount(invocationCount);
        }
        localRNG = rng.split();
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }

    ImmutableFeatureMap featureIDs = examples.getFeatureIDMap();
    ImmutableOutputInfo<T> labelIDs = examples.getOutputIDInfo();

    // Members get consecutive invocation counts starting from the inner trainer's current one.
    final int firstMemberInvocation = innerTrainer.getInvocationCount();
    ArrayList<Model<T>> models = new ArrayList<>();
    for (int member = 0; member < numMembers; member++) {
        logger.info("Building model " + member);
        int memberSeed = localRNG.nextInt();
        Model<T> trained = trainSingleModel(examples, featureIDs, labelIDs, memberSeed, runProvenance, firstMemberInvocation + member);
        models.add(trained);
    }

    EnsembleModelProvenance provenance = new EnsembleModelProvenance(
            WeightedEnsembleModel.class.getName(),
            OffsetDateTime.now(),
            examples.getProvenance(),
            trainerProvenance,
            runProvenance,
            ListProvenance.createListProvenance(models));
    return new WeightedEnsembleModel<>(ensembleName(), provenance, featureIDs, labelIDs, models, combiner);
}
Also used : Model(org.tribuo.Model) ArrayList(java.util.ArrayList) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) SplittableRandom(java.util.SplittableRandom) TrainerProvenance(org.tribuo.provenance.TrainerProvenance) EnsembleModelProvenance(org.tribuo.provenance.EnsembleModelProvenance)

Example 4 with EnsembleModelProvenance

use of org.tribuo.provenance.EnsembleModelProvenance in project tribuo by oracle.

the class WeightedEnsembleModel method createEnsembleFromExistingModels.

/**
 * Assembles a weighted ensemble out of models that have already been trained.
 * <p>
 * The ensemble takes its feature and output domains from the first model in the list.
 * Each member keeps using the domains it already contains.
 * <p>
 * Throws {@link IllegalArgumentException} if fewer than two models are supplied, if the number of
 * weights differs from the number of models, if any model's output domain is not equal to the
 * first model's, or if any model's feature domain is not equal to the first model's.
 * <p>
 * The ensemble provenance takes its dataset provenance from the first model and uses a fresh
 * {@link TimestampedTrainerProvenance} as the trainer provenance.
 * @param name The ensemble name.
 * @param models The ensemble members.
 * @param combiner The combination function.
 * @param weights The model combination weights.
 * @param <T> The output type.
 * @return A weighted ensemble model.
 */
public static <T extends Output<T>> WeightedEnsembleModel<T> createEnsembleFromExistingModels(String name, List<Model<T>> models, EnsembleCombiner<T> combiner, float[] weights) {
    // Argument sanity checks.
    final int modelCount = models.size();
    if (modelCount < 2) {
        throw new IllegalArgumentException("Must supply at least 2 models, found " + modelCount);
    }
    if (weights.length != modelCount) {
        throw new IllegalArgumentException("Must supply one weight per model, models.size() = " + modelCount + ", weights.length = " + weights.length);
    }

    // The first model is the reference every other model's domains are compared against.
    final Model<T> reference = models.get(0);
    final ImmutableOutputInfo<T> referenceOutputInfo = reference.getOutputIDInfo();
    final ImmutableFeatureMap referenceFeatureMap = reference.getFeatureIDMap();
    final Set<T> referenceOutputDomain = referenceOutputInfo.getDomain();
    for (int idx = 1; idx < modelCount; idx++) {
        Model<T> candidate = models.get(idx);
        boolean sameOutputs = candidate.getOutputIDInfo().getDomain().equals(referenceOutputDomain);
        if (!sameOutputs) {
            throw new IllegalArgumentException("Model output domains are not equal.");
        }
        boolean sameFeatures = candidate.getFeatureIDMap().domainEquals(referenceFeatureMap);
        if (!sameFeatures) {
            throw new IllegalArgumentException("Model feature domains are not equal.");
        }
    }

    // Copy the list so the ensemble does not share the caller's list (the weights are copied in the constructor).
    List<Model<T>> modelList = new ArrayList<>(models);

    // Assemble the provenance describing this hand-built ensemble.
    TimestampedTrainerProvenance trainerProvenance = new TimestampedTrainerProvenance();
    EnsembleModelProvenance provenance = new EnsembleModelProvenance(
            WeightedEnsembleModel.class.getName(),
            OffsetDateTime.now(),
            reference.getProvenance().getDatasetProvenance(),
            trainerProvenance,
            ListProvenance.createListProvenance(models));
    return new WeightedEnsembleModel<>(name, provenance, referenceFeatureMap, referenceOutputInfo, modelList, combiner, weights);
}
Also used : TimestampedTrainerProvenance(org.tribuo.provenance.impl.TimestampedTrainerProvenance) Model(org.tribuo.Model) ArrayList(java.util.ArrayList) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) EnsembleModelProvenance(org.tribuo.provenance.EnsembleModelProvenance)

Example 5 with EnsembleModelProvenance

use of org.tribuo.provenance.EnsembleModelProvenance in project tribuo by oracle.

the class AdaBoostTrainer method train.

/**
 * Trains up to {@code numMembers} models in sequence, re-weighting the examples after each one,
 * and returns them as a {@link WeightedEnsembleModel} named {@code "boosted-ensemble"} that uses
 * a {@link VotingCombiner}.
 * <p>
 * If the inner trainer implements {@link WeightedExamples}, the dataset is copied and the current
 * weights are written onto the copy's examples before each training round. Otherwise each round
 * trains on a weighted bootstrap view of the dataset drawn with the current weights.
 * <p>
 * Each member's ensemble weight is {@code log(accuracy / error) + log(numClasses - 1)}. Training
 * stops early when a member's accuracy is within {@code 1e-10} of 1; that member's weight is then
 * replaced by 1 and only the members trained so far are returned.
 * @param examples The training dataset.
 * @param runProvenance Run-specific provenance stored in the ensemble provenance.
 * @param invocationCount The invocation count to set on this trainer before training, unless it equals {@code INCREMENT_INVOCATION_COUNT}.
 * @return The boosted ensemble.
 * @throws IllegalArgumentException If the dataset's output info reports any unknown outputs.
 */
@Override
public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    if (examples.getOutputInfo().getUnknownCount() > 0) {
        throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
    }
    // Under the lock: optionally reset the invocation count, split off a private RNG,
    // capture the trainer provenance, then bump the invocation counter.
    SplittableRandom localRNG;
    TrainerProvenance trainerProvenance;
    synchronized (this) {
        if (invocationCount != INCREMENT_INVOCATION_COUNT) {
            setInvocationCount(invocationCount);
        }
        localRNG = rng.split();
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }
    final boolean reweightExamples = innerTrainer instanceof WeightedExamples;
    ImmutableFeatureMap featureIDs = examples.getFeatureIDMap();
    ImmutableOutputInfo<Label> labelIDs = examples.getOutputIDInfo();
    int numClasses = labelIDs.size();
    logger.log(Level.INFO, "NumClasses = " + numClasses);
    ArrayList<Model<Label>> members = new ArrayList<>();
    float[] memberWeights = new float[numMembers];
    // Start from a uniform weighting over the examples.
    float[] exampleWeights = Util.generateUniformFloatVector(examples.size(), 1.0f / examples.size());
    if (!reweightExamples) {
        logger.info("Using sampling Adaboost.");
    } else {
        logger.info("Using weighted Adaboost.");
        // Work on a copy of the dataset; the weights are written onto the copy's examples.
        examples = ImmutableDataset.copyDataset(examples);
        for (int idx = 0; idx < examples.size(); idx++) {
            examples.getExample(idx).setWeight(exampleWeights[idx]);
        }
    }
    for (int round = 0; round < numMembers; round++) {
        logger.info("Building model " + round);
        Model<Label> member;
        if (reweightExamples) {
            member = innerTrainer.train(examples);
        } else {
            // Draw a bootstrap sample according to the current example weights.
            DatasetView<Label> bag = DatasetView.createWeightedBootstrapView(examples, examples.size(), localRNG.nextLong(), exampleWeights, featureIDs, labelIDs);
            member = innerTrainer.train(bag);
        }

        // Evaluate the new member on the dataset using the current example weights.
        List<Prediction<Label>> memberPredictions = member.predict(examples);
        float accuracy = accuracy(memberPredictions, examples, exampleWeights);
        float error = 1.0f - accuracy;
        float alpha = (float) (Math.log(accuracy / error) + Math.log(numClasses - 1));
        members.add(member);
        memberWeights[round] = alpha;

        if ((accuracy + 1e-10) > 1.0) {
            // Perfect accuracy, can no longer boost: return only the members trained so far.
            float[] truncatedWeights = Arrays.copyOf(memberWeights, members.size());
            // The last weight is infinite at this point, so replace it with 1.
            truncatedWeights[members.size() - 1] = 1.0f;
            logger.log(Level.FINE, "Perfect accuracy reached on iteration " + round + ", returning current model.");
            logger.log(Level.FINE, "Model weights:");
            Util.logVector(logger, Level.FINE, truncatedWeights);
            EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(members));
            return new WeightedEnsembleModel<>("boosted-ensemble", provenance, featureIDs, labelIDs, members, new VotingCombiner(), truncatedWeights);
        }

        // Scale up the weight of every example this member got wrong, then renormalise.
        for (int idx = 0; idx < memberPredictions.size(); idx++) {
            boolean correct = memberPredictions.get(idx).getOutput().equals(examples.getExample(idx).getOutput());
            if (correct) {
                continue;
            }
            exampleWeights[idx] *= Math.exp(alpha);
        }
        Util.inplaceNormalizeToDistribution(exampleWeights);
        if (reweightExamples) {
            // Write the updated weights back onto the examples for the next round.
            for (int idx = 0; idx < examples.size(); idx++) {
                examples.getExample(idx).setWeight(exampleWeights[idx]);
            }
        }
    }
    logger.log(Level.FINE, "Model weights:");
    Util.logVector(logger, Level.FINE, memberWeights);
    EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(members));
    return new WeightedEnsembleModel<>("boosted-ensemble", provenance, featureIDs, labelIDs, members, new VotingCombiner(), memberWeights);
}
Also used : Prediction(org.tribuo.Prediction) Label(org.tribuo.classification.Label) ArrayList(java.util.ArrayList) WeightedExamples(org.tribuo.WeightedExamples) EnsembleModelProvenance(org.tribuo.provenance.EnsembleModelProvenance) Model(org.tribuo.Model) WeightedEnsembleModel(org.tribuo.ensemble.WeightedEnsembleModel) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) WeightedEnsembleModel(org.tribuo.ensemble.WeightedEnsembleModel) SplittableRandom(java.util.SplittableRandom) TrainerProvenance(org.tribuo.provenance.TrainerProvenance)

Aggregations

EnsembleModelProvenance (org.tribuo.provenance.EnsembleModelProvenance)6 ArrayList (java.util.ArrayList)5 Model (org.tribuo.Model)5 ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap)4 TrainerProvenance (org.tribuo.provenance.TrainerProvenance)4 SplittableRandom (java.util.SplittableRandom)3 ListProvenance (com.oracle.labs.mlrg.olcut.provenance.ListProvenance)2 WeightedEnsembleModel (org.tribuo.ensemble.WeightedEnsembleModel)2 ModelProvenance (org.tribuo.provenance.ModelProvenance)2 ObjectProvenance (com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance)1 Provenance (com.oracle.labs.mlrg.olcut.provenance.Provenance)1 HashProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance)1 Method (java.lang.reflect.Method)1 OffsetDateTime (java.time.OffsetDateTime)1 Prediction (org.tribuo.Prediction)1 WeightedExamples (org.tribuo.WeightedExamples)1 Label (org.tribuo.classification.Label)1 EnsembleModel (org.tribuo.ensemble.EnsembleModel)1 DATASET (org.tribuo.json.StripProvenance.ProvenanceTypes.DATASET)1 MultiLabel (org.tribuo.multilabel.MultiLabel)1