
Example 1 with TrainerProvenance

Use of org.tribuo.provenance.TrainerProvenance in project tribuo by oracle.

In the class CRFTrainer, the train method:

@Override
public CRFModel train(SequenceDataset<Label> sequenceExamples, Map<String, Provenance> runProvenance) {
    if (sequenceExamples.getOutputInfo().getUnknownCount() > 0) {
        throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
    }
    // Split the RNG, copy the optimiser, and capture the trainer provenance while incrementing the invocation count.
    SplittableRandom localRNG;
    TrainerProvenance trainerProvenance;
    StochasticGradientOptimiser localOptimiser;
    synchronized (this) {
        localRNG = rng.split();
        localOptimiser = optimiser.copy();
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }
    ImmutableOutputInfo<Label> labelIDMap = sequenceExamples.getOutputIDInfo();
    ImmutableFeatureMap featureIDMap = sequenceExamples.getFeatureIDMap();
    SGDVector[][] sgdFeatures = new SGDVector[sequenceExamples.size()][];
    int[][] sgdLabels = new int[sequenceExamples.size()][];
    double[] weights = new double[sequenceExamples.size()];
    int n = 0;
    for (SequenceExample<Label> example : sequenceExamples) {
        weights[n] = example.getWeight();
        Pair<int[], SGDVector[]> pair = CRFModel.convertToVector(example, featureIDMap, labelIDMap);
        sgdFeatures[n] = pair.getB();
        sgdLabels[n] = pair.getA();
        n++;
    }
    logger.info(String.format("Training SGD CRF with %d examples", n));
    CRFParameters crfParameters = new CRFParameters(featureIDMap.size(), labelIDMap.size());
    localOptimiser.initialise(crfParameters);
    double loss = 0.0;
    int iteration = 0;
    for (int i = 0; i < epochs; i++) {
        if (shuffle) {
            Util.shuffleInPlace(sgdFeatures, sgdLabels, weights, localRNG);
        }
        if (minibatchSize == 1) {
            /*
             * Special case a minibatch of size 1. Directly updates the parameters after each
             * example rather than aggregating.
             */
            for (int j = 0; j < sgdFeatures.length; j++) {
                Pair<Double, Tensor[]> output = crfParameters.valueAndGradient(sgdFeatures[j], sgdLabels[j]);
                loss += output.getA() * weights[j];
                // Update the gradient with the current learning rates
                Tensor[] updates = localOptimiser.step(output.getB(), weights[j]);
                // Apply the update to the current parameters.
                crfParameters.update(updates);
                iteration++;
                if ((loggingInterval != -1) && (iteration % loggingInterval == 0)) {
                    logger.info("At iteration " + iteration + ", average loss = " + loss / loggingInterval);
                    loss = 0.0;
                }
            }
        } else {
            Tensor[][] gradients = new Tensor[minibatchSize][];
            for (int j = 0; j < sgdFeatures.length; j += minibatchSize) {
                double tempWeight = 0.0;
                int curSize = 0;
                // Aggregate the gradient updates for each example in the minibatch
                for (int k = j; k < j + minibatchSize && k < sgdFeatures.length; k++) {
                    Pair<Double, Tensor[]> output = crfParameters.valueAndGradient(sgdFeatures[k], sgdLabels[k]);
                    loss += output.getA() * weights[k];
                    tempWeight += weights[k];
                    gradients[k - j] = output.getB();
                    curSize++;
                }
                // Merge the values into a single gradient update
                Tensor[] updates = crfParameters.merge(gradients, curSize);
                for (Tensor update : updates) {
                    update.scaleInPlace(minibatchSize);
                }
                tempWeight /= minibatchSize;
                // Update the gradient with the current learning rates
                updates = localOptimiser.step(updates, tempWeight);
                // Apply the gradient.
                crfParameters.update(updates);
                iteration++;
                if ((loggingInterval != -1) && (iteration % loggingInterval == 0)) {
                    logger.info("At iteration " + iteration + ", average loss = " + loss / loggingInterval);
                    loss = 0.0;
                }
            }
        }
    }
    localOptimiser.finalise();
    // Construct the model provenance and wrap the trained parameters in a CRFModel.
    ModelProvenance provenance = new ModelProvenance(CRFModel.class.getName(), OffsetDateTime.now(), sequenceExamples.getProvenance(), trainerProvenance, runProvenance);
    CRFModel model = new CRFModel("crf-sgd-model", provenance, featureIDMap, labelIDMap, crfParameters);
    localOptimiser.reset();
    return model;
}
Also used: Tensor(org.tribuo.math.la.Tensor) ModelProvenance(org.tribuo.provenance.ModelProvenance) Label(org.tribuo.classification.Label) StochasticGradientOptimiser(org.tribuo.math.StochasticGradientOptimiser) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) SGDVector(org.tribuo.math.la.SGDVector) SplittableRandom(java.util.SplittableRandom) TrainerProvenance(org.tribuo.provenance.TrainerProvenance)
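
The TrainerProvenance captured inside CRFTrainer.train ends up nested in the trained model's ModelProvenance. The following is a minimal sketch, not taken from the Tribuo sources (the InspectCrfProvenance class name is illustrative), of how that provenance might be read back from a trained sequence model:

import org.tribuo.classification.Label;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.sequence.SequenceModel;

public final class InspectCrfProvenance {
    public static void printTrainerProvenance(SequenceModel<Label> model) {
        // The ModelProvenance constructed at the end of train() wraps the TrainerProvenance
        // snapshotted inside the synchronized block.
        ModelProvenance modelProvenance = model.getProvenance();
        TrainerProvenance trainerProvenance = modelProvenance.getTrainerProvenance();
        System.out.println(trainerProvenance);
    }
}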

Example 2 with TrainerProvenance

Use of org.tribuo.provenance.TrainerProvenance in project tribuo by oracle.

In the class HdbscanTrainer, the train method:

@Override
public HdbscanModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
    // Capture the trainer provenance and increment the invocation count.
    TrainerProvenance trainerProvenance;
    synchronized (this) {
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    SGDVector[] data = new SGDVector[examples.size()];
    int n = 0;
    for (Example<ClusterID> example : examples) {
        if (example.size() == featureMap.size()) {
            data[n] = DenseVector.createDenseVector(example, featureMap, false);
        } else {
            data[n] = SparseVector.createSparseVector(example, featureMap, false);
        }
        n++;
    }
    DenseVector coreDistances = calculateCoreDistances(data, k, neighboursQueryFactory);
    ExtendedMinimumSpanningTree emst = constructEMST(data, coreDistances, distType);
    // The levels at which each point becomes noise
    double[] pointNoiseLevels = new double[data.length];
    // The last label of each point before becoming noise
    int[] pointLastClusters = new int[data.length];
    // The HDBSCAN* hierarchy
    Map<Integer, int[]> hierarchy = new HashMap<>();
    List<HdbscanCluster> clusters = computeHierarchyAndClusterTree(emst, minClusterSize, pointNoiseLevels, pointLastClusters, hierarchy);
    propagateTree(clusters);
    List<Integer> clusterLabels = findProminentClusters(hierarchy, clusters, data.length);
    DenseVector outlierScoresVector = calculateOutlierScores(pointNoiseLevels, pointLastClusters, clusters);
    Map<Integer, List<Pair<Double, Integer>>> clusterAssignments = generateClusterAssignments(clusterLabels, outlierScoresVector);
    // Use the cluster assignments to establish the clustering info
    Map<Integer, MutableLong> counts = new HashMap<>();
    for (Entry<Integer, List<Pair<Double, Integer>>> e : clusterAssignments.entrySet()) {
        counts.put(e.getKey(), new MutableLong(e.getValue().size()));
    }
    ImmutableOutputInfo<ClusterID> outputMap = new ImmutableClusteringInfo(counts);
    // Compute the cluster exemplars.
    List<ClusterExemplar> clusterExemplars = computeExemplars(data, clusterAssignments, distType);
    // Get the outlier score value for points that are predicted as noise points.
    double noisePointsOutlierScore = getNoisePointsOutlierScore(clusterAssignments);
    logger.log(Level.INFO, "Hdbscan is done.");
    ModelProvenance provenance = new ModelProvenance(HdbscanModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
    return new HdbscanModel("hdbscan-model", provenance, featureMap, outputMap, clusterLabels, outlierScoresVector, clusterExemplars, distType, noisePointsOutlierScore);
}
Also used: ClusterID(org.tribuo.clustering.ClusterID) HashMap(java.util.HashMap) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) SGDVector(org.tribuo.math.la.SGDVector) ArrayList(java.util.ArrayList) List(java.util.List) TrainerProvenance(org.tribuo.provenance.TrainerProvenance) ModelProvenance(org.tribuo.provenance.ModelProvenance) MutableLong(com.oracle.labs.mlrg.olcut.util.MutableLong) DenseVector(org.tribuo.math.la.DenseVector) ImmutableClusteringInfo(org.tribuo.clustering.ImmutableClusteringInfo)
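
As a hedged sketch of how the provenance attached by HdbscanTrainer.train might be displayed, the snippet below pretty-prints the full provenance tree using OLCUT's ProvenanceUtil.formattedProvenanceString; the PrintHdbscanProvenance class name is illustrative only:

import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import org.tribuo.Model;
import org.tribuo.clustering.ClusterID;

public final class PrintHdbscanProvenance {
    public static void print(Model<ClusterID> model) {
        // Formats the dataset, trainer and instance provenance recorded in train().
        System.out.println(ProvenanceUtil.formattedProvenanceString(model.getProvenance()));
    }
}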

Example 3 with TrainerProvenance

Use of org.tribuo.provenance.TrainerProvenance in project tribuo by oracle.

In the class XGBoostClassificationTrainer, the train method:

@Override
public synchronized XGBoostModel<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    if (examples.getOutputInfo().getUnknownCount() > 0) {
        throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
    }
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    ImmutableOutputInfo<Label> outputInfo = examples.getOutputIDInfo();
    if (invocationCount != INCREMENT_INVOCATION_COUNT) {
        setInvocationCount(invocationCount);
    }
    TrainerProvenance trainerProvenance = getProvenance();
    trainInvocationCounter++;
    parameters.put("num_class", outputInfo.size());
    Booster model;
    Function<Label, Float> responseExtractor = (Label l) -> (float) outputInfo.getID(l);
    try {
        DMatrixTuple<Label> trainingData = convertExamples(examples, featureMap, responseExtractor);
        model = XGBoost.train(trainingData.data, parameters, numTrees, Collections.emptyMap(), null, null);
    } catch (XGBoostError e) {
        logger.log(Level.SEVERE, "XGBoost threw an error", e);
        throw new IllegalStateException(e);
    }
    ModelProvenance provenance = new ModelProvenance(XGBoostModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
    XGBoostModel<Label> xgModel = createModel("xgboost-classification-model", provenance, featureMap, outputInfo, Collections.singletonList(model), new XGBoostClassificationConverter());
    return xgModel;
}
Also used: Booster(ml.dmlc.xgboost4j.java.Booster) ModelProvenance(org.tribuo.provenance.ModelProvenance) Label(org.tribuo.classification.Label) XGBoostError(ml.dmlc.xgboost4j.java.XGBoostError) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) TrainerProvenance(org.tribuo.provenance.TrainerProvenance) XGBoostModel(org.tribuo.common.xgboost.XGBoostModel)
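
The three-argument overload above lets callers set the invocation count explicitly. Below is a minimal sketch of invoking it, assuming trainData is an existing Dataset<Label>; passing Trainer.INCREMENT_INVOCATION_COUNT keeps the default increment-by-one behaviour:

import java.util.Collections;
import org.tribuo.Dataset;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.xgboost.XGBoostClassificationTrainer;
import org.tribuo.common.xgboost.XGBoostModel;

public final class TrainXGBoostClassifier {
    public static XGBoostModel<Label> train(Dataset<Label> trainData) {
        // 50 boosting rounds; other hyperparameters are left at their defaults.
        XGBoostClassificationTrainer trainer = new XGBoostClassificationTrainer(50);
        // No extra run provenance, and no override of the invocation count.
        return trainer.train(trainData, Collections.emptyMap(), Trainer.INCREMENT_INVOCATION_COUNT);
    }
}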

Example 4 with TrainerProvenance

Use of org.tribuo.provenance.TrainerProvenance in project tribuo by oracle.

In the class XGBoostRegressionTrainer, the train method:

@Override
public synchronized XGBoostModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    if (examples.getOutputInfo().getUnknownCount() > 0) {
        throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
    }
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo();
    int numOutputs = outputInfo.size();
    if (invocationCount != INCREMENT_INVOCATION_COUNT) {
        setInvocationCount(invocationCount);
    }
    TrainerProvenance trainerProvenance = getProvenance();
    trainInvocationCounter++;
    List<Booster> models = new ArrayList<>();
    try {
        // Use a null response extractor as we'll do the per-dimension regression extraction later.
        DMatrixTuple<Regressor> trainingData = convertExamples(examples, featureMap, null);
        // Map the natural order into ids
        int[] dimensionIds = ((ImmutableRegressionInfo) outputInfo).getNaturalOrderToIDMapping();
        // Extract the weights and the regression targets.
        float[][] outputs = new float[numOutputs][examples.size()];
        float[] weights = new float[examples.size()];
        int i = 0;
        for (Example<Regressor> e : examples) {
            weights[i] = e.getWeight();
            double[] curOutputs = e.getOutput().getValues();
            // Transpose them for easy training.
            for (int j = 0; j < numOutputs; j++) {
                outputs[dimensionIds[j]][i] = (float) curOutputs[j];
            }
            i++;
        }
        trainingData.data.setWeight(weights);
        // Finished setup, now train one model per dimension.
        for (i = 0; i < numOutputs; i++) {
            trainingData.data.setLabel(outputs[i]);
            models.add(XGBoost.train(trainingData.data, parameters, numTrees, Collections.emptyMap(), null, null));
        }
    } catch (XGBoostError e) {
        logger.log(Level.SEVERE, "XGBoost threw an error", e);
        throw new IllegalStateException(e);
    }
    ModelProvenance provenance = new ModelProvenance(XGBoostModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
    XGBoostModel<Regressor> xgModel = createModel("xgboost-regression-model", provenance, featureMap, outputInfo, models, new XGBoostRegressionConverter());
    return xgModel;
}
Also used: Booster(ml.dmlc.xgboost4j.java.Booster) ModelProvenance(org.tribuo.provenance.ModelProvenance) ArrayList(java.util.ArrayList) XGBoostError(ml.dmlc.xgboost4j.java.XGBoostError) ImmutableRegressionInfo(org.tribuo.regression.ImmutableRegressionInfo) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) Regressor(org.tribuo.regression.Regressor) TrainerProvenance(org.tribuo.provenance.TrainerProvenance) XGBoostModel(org.tribuo.common.xgboost.XGBoostModel)
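
Each of these train methods also accepts a runProvenance map, which is folded into the instance provenance of the resulting ModelProvenance. The sketch below is illustrative only (the RunProvenanceExample class and the "experiment-id" key are not Tribuo API) and shows one way such a map might be built before calling train:

import java.util.HashMap;
import java.util.Map;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;

public final class RunProvenanceExample {
    public static Map<String, Provenance> buildRunProvenance() {
        // Arbitrary per-run metadata, stored alongside the dataset and trainer provenance.
        Map<String, Provenance> runProvenance = new HashMap<>();
        runProvenance.put("experiment-id", new StringProvenance("experiment-id", "regression-run-1"));
        return runProvenance;
    }
}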

Example 5 with TrainerProvenance

Use of org.tribuo.provenance.TrainerProvenance in project tribuo by oracle.

In the class StripProvenance, the cleanEnsembleProvenance method:

/**
 * Creates a new ensemble provenance with the requested information removed.
 * @param old The old ensemble provenance.
 * @param memberProvenance The new member provenances.
 * @param provenanceHash The old ensemble provenance hash.
 * @param opt The program options.
 * @return The new ensemble provenance with the requested fields removed.
 */
private static EnsembleModelProvenance cleanEnsembleProvenance(EnsembleModelProvenance old, ListProvenance<ModelProvenance> memberProvenance, String provenanceHash, StripProvenanceOptions opt) {
    // Dataset provenance
    DatasetProvenance datasetProvenance;
    if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(DATASET)) {
        datasetProvenance = new EmptyDatasetProvenance();
    } else {
        datasetProvenance = old.getDatasetProvenance();
    }
    // Trainer provenance
    TrainerProvenance trainerProvenance;
    if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(TRAINER)) {
        trainerProvenance = new EmptyTrainerProvenance();
    } else {
        trainerProvenance = old.getTrainerProvenance();
    }
    // Instance provenance
    OffsetDateTime time;
    Map<String, Provenance> instanceProvenance;
    if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(INSTANCE)) {
        instanceProvenance = new HashMap<>();
        time = OffsetDateTime.MIN;
    } else {
        instanceProvenance = new HashMap<>(old.getInstanceProvenance().getMap());
        time = old.getTrainingTime();
    }
    if (opt.storeHash) {
        logger.info("Writing provenance hash into instance map.");
        instanceProvenance.put("original-provenance-hash", new HashProvenance(opt.hashType, "original-provenance-hash", provenanceHash));
    }
    return new EnsembleModelProvenance(old.getClassName(), time, datasetProvenance, trainerProvenance, instanceProvenance, memberProvenance);
}
Also used: HashProvenance(com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance) ObjectProvenance(com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance) ModelProvenance(org.tribuo.provenance.ModelProvenance) DatasetProvenance(org.tribuo.provenance.DatasetProvenance) EmptyTrainerProvenance(org.tribuo.provenance.impl.EmptyTrainerProvenance) EmptyDatasetProvenance(org.tribuo.provenance.impl.EmptyDatasetProvenance) Provenance(com.oracle.labs.mlrg.olcut.provenance.Provenance) TrainerProvenance(org.tribuo.provenance.TrainerProvenance) EnsembleModelProvenance(org.tribuo.provenance.EnsembleModelProvenance) ListProvenance(com.oracle.labs.mlrg.olcut.provenance.ListProvenance) OffsetDateTime(java.time.OffsetDateTime)
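
For completeness, here is a minimal sketch (the InspectEnsembleProvenance class is hypothetical, and it assumes EnsembleModelProvenance.getMemberProvenance is available) of reading back the pieces that cleanEnsembleProvenance reassembles, starting from an existing model's provenance:

import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.provenance.EnsembleModelProvenance;
import org.tribuo.provenance.ModelProvenance;

public final class InspectEnsembleProvenance {
    public static <T extends Output<T>> void print(Model<T> model) {
        ModelProvenance provenance = model.getProvenance();
        if (provenance instanceof EnsembleModelProvenance) {
            EnsembleModelProvenance ensembleProvenance = (EnsembleModelProvenance) provenance;
            System.out.println("Trainer: " + ensembleProvenance.getTrainerProvenance());
            System.out.println("Dataset: " + ensembleProvenance.getDatasetProvenance());
            // One ModelProvenance per ensemble member, matching the memberProvenance parameter above.
            System.out.println("Members: " + ensembleProvenance.getMemberProvenance());
        }
    }
}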

Aggregations

TrainerProvenance (org.tribuo.provenance.TrainerProvenance): 27 usages
ModelProvenance (org.tribuo.provenance.ModelProvenance): 24 usages
ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap): 22 usages
SplittableRandom (java.util.SplittableRandom): 16 usages
ArrayList (java.util.ArrayList): 10 usages
Label (org.tribuo.classification.Label): 8 usages
Regressor (org.tribuo.regression.Regressor): 6 usages
HashMap (java.util.HashMap): 5 usages
Model (org.tribuo.Model): 5 usages
SGDVector (org.tribuo.math.la.SGDVector): 5 usages
SparseVector (org.tribuo.math.la.SparseVector): 5 usages
EnsembleModelProvenance (org.tribuo.provenance.EnsembleModelProvenance): 5 usages
Feature (org.tribuo.Feature): 3 usages
DatasetProvenance (org.tribuo.provenance.DatasetProvenance): 3 usages
ListProvenance (com.oracle.labs.mlrg.olcut.provenance.ListProvenance): 2 usages
ObjectProvenance (com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance): 2 usages
Provenance (com.oracle.labs.mlrg.olcut.provenance.Provenance): 2 usages
HashProvenance (com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance): 2 usages
MutableLong (com.oracle.labs.mlrg.olcut.util.MutableLong): 2 usages
OffsetDateTime (java.time.OffsetDateTime): 2 usages