Search in sources:

Example 1 with ModelProvenance

use of org.tribuo.provenance.ModelProvenance in project tribuo by oracle.

Source: class CRFTrainer, method train.

@Override
public CRFModel train(SequenceDataset<Label> sequenceExamples, Map<String, Provenance> runProvenance) {
    if (sequenceExamples.getOutputInfo().getUnknownCount() > 0) {
        throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
    }
    // Creates a new RNG, adds one to the invocation count, generates a local optimiser.
    // All mutable trainer state is touched under the lock so concurrent train() calls
    // each observe a consistent (rng, optimiser, provenance, counter) snapshot.
    SplittableRandom localRNG;
    TrainerProvenance trainerProvenance;
    StochasticGradientOptimiser localOptimiser;
    synchronized (this) {
        localRNG = rng.split();
        localOptimiser = optimiser.copy();
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }
    ImmutableOutputInfo<Label> labelIDMap = sequenceExamples.getOutputIDInfo();
    ImmutableFeatureMap featureIDMap = sequenceExamples.getFeatureIDMap();
    // Convert every sequence example into parallel arrays of feature vectors,
    // label ids and example weights for the SGD loop below.
    SGDVector[][] sgdFeatures = new SGDVector[sequenceExamples.size()][];
    int[][] sgdLabels = new int[sequenceExamples.size()][];
    double[] weights = new double[sequenceExamples.size()];
    int n = 0;
    for (SequenceExample<Label> example : sequenceExamples) {
        weights[n] = example.getWeight();
        Pair<int[], SGDVector[]> pair = CRFModel.convertToVector(example, featureIDMap, labelIDMap);
        sgdFeatures[n] = pair.getB();
        sgdLabels[n] = pair.getA();
        n++;
    }
    logger.info(String.format("Training SGD CRF with %d examples", n));
    CRFParameters crfParameters = new CRFParameters(featureIDMap.size(), labelIDMap.size());
    localOptimiser.initialise(crfParameters);
    double loss = 0.0;
    int iteration = 0;
    for (int i = 0; i < epochs; i++) {
        if (shuffle) {
            Util.shuffleInPlace(sgdFeatures, sgdLabels, weights, localRNG);
        }
        if (minibatchSize == 1) {
            /*
                 * Special case a minibatch of size 1. Directly updates the parameters after each
                 * example rather than aggregating.
                 */
            for (int j = 0; j < sgdFeatures.length; j++) {
                Pair<Double, Tensor[]> output = crfParameters.valueAndGradient(sgdFeatures[j], sgdLabels[j]);
                loss += output.getA() * weights[j];
                // Update the gradient with the current learning rates
                Tensor[] updates = localOptimiser.step(output.getB(), weights[j]);
                // Apply the update to the current parameters.
                crfParameters.update(updates);
                iteration++;
                if ((iteration % loggingInterval == 0) && (loggingInterval != -1)) {
                    logger.info("At iteration " + iteration + ", average loss = " + loss / loggingInterval);
                    loss = 0.0;
                }
            }
        } else {
            Tensor[][] gradients = new Tensor[minibatchSize][];
            for (int j = 0; j < sgdFeatures.length; j += minibatchSize) {
                double tempWeight = 0.0;
                int curSize = 0;
                // Aggregate the gradient updates for each example in the minibatch.
                // BUG FIX: the inner loop must index with k (the per-example index),
                // not j (the minibatch start) — otherwise every example in the batch
                // contributed the first example's gradient.
                for (int k = j; k < j + minibatchSize && k < sgdFeatures.length; k++) {
                    Pair<Double, Tensor[]> output = crfParameters.valueAndGradient(sgdFeatures[k], sgdLabels[k]);
                    loss += output.getA() * weights[k];
                    tempWeight += weights[k];
                    gradients[k - j] = output.getB();
                    curSize++;
                }
                // Merge the values into a single gradient update
                Tensor[] updates = crfParameters.merge(gradients, curSize);
                for (Tensor update : updates) {
                    update.scaleInPlace(minibatchSize);
                }
                tempWeight /= minibatchSize;
                // Update the gradient with the current learning rates
                updates = localOptimiser.step(updates, tempWeight);
                // Apply the gradient.
                crfParameters.update(updates);
                iteration++;
                if ((loggingInterval != -1) && (iteration % loggingInterval == 0)) {
                    logger.info("At iteration " + iteration + ", average loss = " + loss / loggingInterval);
                    loss = 0.0;
                }
            }
        }
    }
    localOptimiser.finalise();
    ModelProvenance provenance = new ModelProvenance(CRFModel.class.getName(), OffsetDateTime.now(), sequenceExamples.getProvenance(), trainerProvenance, runProvenance);
    CRFModel model = new CRFModel("crf-sgd-model", provenance, featureIDMap, labelIDMap, crfParameters);
    localOptimiser.reset();
    return model;
}
Also used : Tensor(org.tribuo.math.la.Tensor) ModelProvenance(org.tribuo.provenance.ModelProvenance) Label(org.tribuo.classification.Label) StochasticGradientOptimiser(org.tribuo.math.StochasticGradientOptimiser) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) SGDVector(org.tribuo.math.la.SGDVector) SplittableRandom(java.util.SplittableRandom) TrainerProvenance(org.tribuo.provenance.TrainerProvenance)

Example 2 with ModelProvenance

use of org.tribuo.provenance.ModelProvenance in project tribuo by oracle.

Source: class HdbscanTrainer, method train.

@Override
public HdbscanModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
    // Capture the provenance and bump the invocation count atomically.
    TrainerProvenance trainerProvenance;
    synchronized (this) {
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    // Vectorise the dataset: dense when every feature is present, sparse otherwise.
    SGDVector[] data = new SGDVector[examples.size()];
    int idx = 0;
    for (Example<ClusterID> example : examples) {
        data[idx] = (example.size() == featureMap.size())
                ? DenseVector.createDenseVector(example, featureMap, false)
                : SparseVector.createSparseVector(example, featureMap, false);
        idx++;
    }
    // Build the mutual-reachability structure used by the HDBSCAN* hierarchy.
    DenseVector coreDistances = calculateCoreDistances(data, k, neighboursQueryFactory);
    ExtendedMinimumSpanningTree emst = constructEMST(data, coreDistances, distType);
    // The levels at which each point becomes noise
    double[] pointNoiseLevels = new double[data.length];
    // The last label of each point before becoming noise
    int[] pointLastClusters = new int[data.length];
    // The HDBSCAN* hierarchy
    Map<Integer, int[]> hierarchy = new HashMap<>();
    List<HdbscanCluster> clusters = computeHierarchyAndClusterTree(emst, minClusterSize, pointNoiseLevels, pointLastClusters, hierarchy);
    propagateTree(clusters);
    List<Integer> clusterLabels = findProminentClusters(hierarchy, clusters, data.length);
    DenseVector outlierScoresVector = calculateOutlierScores(pointNoiseLevels, pointLastClusters, clusters);
    Map<Integer, List<Pair<Double, Integer>>> clusterAssignments = generateClusterAssignments(clusterLabels, outlierScoresVector);
    // Derive the output domain (cluster id -> membership count) from the assignments.
    Map<Integer, MutableLong> counts = new HashMap<>();
    clusterAssignments.forEach((clusterId, members) -> counts.put(clusterId, new MutableLong(members.size())));
    ImmutableOutputInfo<ClusterID> outputMap = new ImmutableClusteringInfo(counts);
    // Compute the cluster exemplars.
    List<ClusterExemplar> clusterExemplars = computeExemplars(data, clusterAssignments, distType);
    // Get the outlier score value for points that are predicted as noise points.
    double noisePointsOutlierScore = getNoisePointsOutlierScore(clusterAssignments);
    logger.log(Level.INFO, "Hdbscan is done.");
    ModelProvenance provenance = new ModelProvenance(HdbscanModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
    return new HdbscanModel("hdbscan-model", provenance, featureMap, outputMap, clusterLabels, outlierScoresVector, clusterExemplars, distType, noisePointsOutlierScore);
}
Also used : ClusterID(org.tribuo.clustering.ClusterID) HashMap(java.util.HashMap) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) SGDVector(org.tribuo.math.la.SGDVector) ArrayList(java.util.ArrayList) List(java.util.List) TrainerProvenance(org.tribuo.provenance.TrainerProvenance) ModelProvenance(org.tribuo.provenance.ModelProvenance) MutableLong(com.oracle.labs.mlrg.olcut.util.MutableLong) DenseVector(org.tribuo.math.la.DenseVector) ImmutableClusteringInfo(org.tribuo.clustering.ImmutableClusteringInfo)

Example 3 with ModelProvenance

use of org.tribuo.provenance.ModelProvenance in project tribuo by oracle.

Source: class XGBoostClassificationTrainer, method train.

@Override
public synchronized XGBoostModel<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    // This trainer is supervised, so refuse datasets containing unknown outputs.
    if (examples.getOutputInfo().getUnknownCount() > 0) {
        throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
    }
    if (invocationCount != INCREMENT_INVOCATION_COUNT) {
        setInvocationCount(invocationCount);
    }
    TrainerProvenance trainerProvenance = getProvenance();
    trainInvocationCounter++;
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    ImmutableOutputInfo<Label> outputInfo = examples.getOutputIDInfo();
    parameters.put("num_class", outputInfo.size());
    // Maps each Label to its output id, cast to the float XGBoost expects.
    Function<Label, Float> responseExtractor = (Label l) -> (float) outputInfo.getID(l);
    Booster model;
    try {
        DMatrixTuple<Label> trainingData = convertExamples(examples, featureMap, responseExtractor);
        model = XGBoost.train(trainingData.data, parameters, numTrees, Collections.emptyMap(), null, null);
    } catch (XGBoostError e) {
        logger.log(Level.SEVERE, "XGBoost threw an error", e);
        throw new IllegalStateException(e);
    }
    ModelProvenance provenance = new ModelProvenance(XGBoostModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
    return createModel("xgboost-classification-model", provenance, featureMap, outputInfo, Collections.singletonList(model), new XGBoostClassificationConverter());
}
Also used : Booster(ml.dmlc.xgboost4j.java.Booster) ModelProvenance(org.tribuo.provenance.ModelProvenance) Label(org.tribuo.classification.Label) XGBoostError(ml.dmlc.xgboost4j.java.XGBoostError) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) TrainerProvenance(org.tribuo.provenance.TrainerProvenance) XGBoostModel(org.tribuo.common.xgboost.XGBoostModel)

Example 4 with ModelProvenance

use of org.tribuo.provenance.ModelProvenance in project tribuo by oracle.

Source: class XGBoostRegressionTrainer, method train.

@Override
public synchronized XGBoostModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    // This trainer is supervised, so refuse datasets containing unknown outputs.
    if (examples.getOutputInfo().getUnknownCount() > 0) {
        throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
    }
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo();
    int numOutputs = outputInfo.size();
    if (invocationCount != INCREMENT_INVOCATION_COUNT) {
        setInvocationCount(invocationCount);
    }
    TrainerProvenance trainerProvenance = getProvenance();
    trainInvocationCounter++;
    List<Booster> models = new ArrayList<>();
    try {
        // Use a null response extractor as we'll do the per dimension regression extraction later.
        DMatrixTuple<Regressor> trainingData = convertExamples(examples, featureMap, null);
        // Map the natural order into ids
        int[] dimensionIds = ((ImmutableRegressionInfo) outputInfo).getNaturalOrderToIDMapping();
        // Extract the weights and the regression targets, transposed so that each
        // output dimension is a contiguous float[] ready for XGBoost.
        float[][] regressionTargets = new float[numOutputs][examples.size()];
        float[] exampleWeights = new float[examples.size()];
        int row = 0;
        for (Example<Regressor> example : examples) {
            exampleWeights[row] = example.getWeight();
            double[] targetValues = example.getOutput().getValues();
            for (int dim = 0; dim < numOutputs; dim++) {
                regressionTargets[dimensionIds[dim]][row] = (float) targetValues[dim];
            }
            row++;
        }
        trainingData.data.setWeight(exampleWeights);
        // Finished setup, now train one model per dimension.
        for (int dim = 0; dim < numOutputs; dim++) {
            trainingData.data.setLabel(regressionTargets[dim]);
            models.add(XGBoost.train(trainingData.data, parameters, numTrees, Collections.emptyMap(), null, null));
        }
    } catch (XGBoostError e) {
        logger.log(Level.SEVERE, "XGBoost threw an error", e);
        throw new IllegalStateException(e);
    }
    ModelProvenance provenance = new ModelProvenance(XGBoostModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
    return createModel("xgboost-regression-model", provenance, featureMap, outputInfo, models, new XGBoostRegressionConverter());
}
Also used : Booster(ml.dmlc.xgboost4j.java.Booster) ModelProvenance(org.tribuo.provenance.ModelProvenance) ArrayList(java.util.ArrayList) XGBoostError(ml.dmlc.xgboost4j.java.XGBoostError) ImmutableRegressionInfo(org.tribuo.regression.ImmutableRegressionInfo) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) Regressor(org.tribuo.regression.Regressor) TrainerProvenance(org.tribuo.provenance.TrainerProvenance) XGBoostModel(org.tribuo.common.xgboost.XGBoostModel)

Example 5 with ModelProvenance

use of org.tribuo.provenance.ModelProvenance in project tribuo by oracle.

Source: class StripProvenance, method cleanEnsembleProvenance.

/**
 * Builds a copy of an ensemble provenance with the fields selected by the
 * options replaced by their empty equivalents.
 * @param old The old ensemble provenance.
 * @param memberProvenance The new member provenances.
 * @param provenanceHash The old ensemble provenance hash.
 * @param opt The program options.
 * @return The new ensemble provenance with the requested fields removed.
 */
private static EnsembleModelProvenance cleanEnsembleProvenance(EnsembleModelProvenance old, ListProvenance<ModelProvenance> memberProvenance, String provenanceHash, StripProvenanceOptions opt) {
    // ALL implies every category below is stripped.
    boolean stripAll = opt.removeProvenances.contains(ALL);
    // Dataset provenance: replaced with an empty stand-in when stripped.
    DatasetProvenance datasetProvenance = (stripAll || opt.removeProvenances.contains(DATASET))
            ? new EmptyDatasetProvenance()
            : old.getDatasetProvenance();
    // Trainer provenance: replaced with an empty stand-in when stripped.
    TrainerProvenance trainerProvenance = (stripAll || opt.removeProvenances.contains(TRAINER))
            ? new EmptyTrainerProvenance()
            : old.getTrainerProvenance();
    // Instance provenance: stripping also resets the training timestamp.
    OffsetDateTime time;
    Map<String, Provenance> instanceProvenance;
    if (stripAll || opt.removeProvenances.contains(INSTANCE)) {
        instanceProvenance = new HashMap<>();
        time = OffsetDateTime.MIN;
    } else {
        instanceProvenance = new HashMap<>(old.getInstanceProvenance().getMap());
        time = old.getTrainingTime();
    }
    // Optionally record a hash of the original provenance so it can be audited later.
    if (opt.storeHash) {
        logger.info("Writing provenance hash into instance map.");
        instanceProvenance.put("original-provenance-hash", new HashProvenance(opt.hashType, "original-provenance-hash", provenanceHash));
    }
    return new EnsembleModelProvenance(old.getClassName(), time, datasetProvenance, trainerProvenance, instanceProvenance, memberProvenance);
}
Also used : HashProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance) ObjectProvenance(com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance) ModelProvenance(org.tribuo.provenance.ModelProvenance) HashProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance) DatasetProvenance(org.tribuo.provenance.DatasetProvenance) EmptyTrainerProvenance(org.tribuo.provenance.impl.EmptyTrainerProvenance) EmptyDatasetProvenance(org.tribuo.provenance.impl.EmptyDatasetProvenance) Provenance(com.oracle.labs.mlrg.olcut.provenance.Provenance) TrainerProvenance(org.tribuo.provenance.TrainerProvenance) EnsembleModelProvenance(org.tribuo.provenance.EnsembleModelProvenance) ListProvenance(com.oracle.labs.mlrg.olcut.provenance.ListProvenance) EmptyTrainerProvenance(org.tribuo.provenance.impl.EmptyTrainerProvenance) OffsetDateTime(java.time.OffsetDateTime) DatasetProvenance(org.tribuo.provenance.DatasetProvenance) EmptyDatasetProvenance(org.tribuo.provenance.impl.EmptyDatasetProvenance) EmptyTrainerProvenance(org.tribuo.provenance.impl.EmptyTrainerProvenance) TrainerProvenance(org.tribuo.provenance.TrainerProvenance) EnsembleModelProvenance(org.tribuo.provenance.EnsembleModelProvenance) EmptyDatasetProvenance(org.tribuo.provenance.impl.EmptyDatasetProvenance)

Aggregations

ModelProvenance (org.tribuo.provenance.ModelProvenance)42 ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap)27 TrainerProvenance (org.tribuo.provenance.TrainerProvenance)24 SplittableRandom (java.util.SplittableRandom)13 HashMap (java.util.HashMap)11 DatasetProvenance (org.tribuo.provenance.DatasetProvenance)10 Regressor (org.tribuo.regression.Regressor)10 ArrayList (java.util.ArrayList)9 Label (org.tribuo.classification.Label)9 OffsetDateTime (java.time.OffsetDateTime)7 OrtEnvironment (ai.onnxruntime.OrtEnvironment)5 OrtSession (ai.onnxruntime.OrtSession)5 Provenance (com.oracle.labs.mlrg.olcut.provenance.Provenance)5 Path (java.nio.file.Path)5 ExternalDatasetProvenance (org.tribuo.interop.ExternalDatasetProvenance)5 SGDVector (org.tribuo.math.la.SGDVector)5 SparseVector (org.tribuo.math.la.SparseVector)5 IOException (java.io.IOException)4 Map (java.util.Map)4 Prediction (org.tribuo.Prediction)4