Use of org.tribuo.provenance.TrainerProvenance in project tribuo by oracle.
The class CRFTrainer, method train.
@Override
public CRFModel train(SequenceDataset<Label> sequenceExamples, Map<String, Provenance> runProvenance) {
    if (sequenceExamples.getOutputInfo().getUnknownCount() > 0) {
        throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
    }
    // Creates a new RNG, adds one to the invocation count, generates a local optimiser.
    SplittableRandom localRNG;
    TrainerProvenance trainerProvenance;
    StochasticGradientOptimiser localOptimiser;
    synchronized (this) {
        localRNG = rng.split();
        localOptimiser = optimiser.copy();
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }
    ImmutableOutputInfo<Label> labelIDMap = sequenceExamples.getOutputIDInfo();
    ImmutableFeatureMap featureIDMap = sequenceExamples.getFeatureIDMap();
    SGDVector[][] sgdFeatures = new SGDVector[sequenceExamples.size()][];
    int[][] sgdLabels = new int[sequenceExamples.size()][];
    double[] weights = new double[sequenceExamples.size()];
    int n = 0;
    for (SequenceExample<Label> example : sequenceExamples) {
        weights[n] = example.getWeight();
        Pair<int[], SGDVector[]> pair = CRFModel.convertToVector(example, featureIDMap, labelIDMap);
        sgdFeatures[n] = pair.getB();
        sgdLabels[n] = pair.getA();
        n++;
    }
    logger.info(String.format("Training SGD CRF with %d examples", n));
    CRFParameters crfParameters = new CRFParameters(featureIDMap.size(), labelIDMap.size());
    localOptimiser.initialise(crfParameters);
    double loss = 0.0;
    int iteration = 0;
    for (int i = 0; i < epochs; i++) {
        if (shuffle) {
            Util.shuffleInPlace(sgdFeatures, sgdLabels, weights, localRNG);
        }
        if (minibatchSize == 1) {
            /*
             * Special case a minibatch of size 1. Directly updates the parameters after each
             * example rather than aggregating.
             */
            for (int j = 0; j < sgdFeatures.length; j++) {
                Pair<Double, Tensor[]> output = crfParameters.valueAndGradient(sgdFeatures[j], sgdLabels[j]);
                loss += output.getA() * weights[j];
                // Update the gradient with the current learning rates.
                Tensor[] updates = localOptimiser.step(output.getB(), weights[j]);
                // Apply the update to the current parameters.
                crfParameters.update(updates);
                iteration++;
                if ((iteration % loggingInterval == 0) && (loggingInterval != -1)) {
                    logger.info("At iteration " + iteration + ", average loss = " + loss / loggingInterval);
                    loss = 0.0;
                }
            }
        } else {
            Tensor[][] gradients = new Tensor[minibatchSize][];
            for (int j = 0; j < sgdFeatures.length; j += minibatchSize) {
                double tempWeight = 0.0;
                int curSize = 0;
                // Aggregate the gradient updates for each example in the minibatch.
                for (int k = j; k < j + minibatchSize && k < sgdFeatures.length; k++) {
                    // Index with k (the minibatch element), not j (the minibatch start).
                    Pair<Double, Tensor[]> output = crfParameters.valueAndGradient(sgdFeatures[k], sgdLabels[k]);
                    loss += output.getA() * weights[k];
                    tempWeight += weights[k];
                    gradients[k - j] = output.getB();
                    curSize++;
                }
                // Merge the values into a single gradient update.
                Tensor[] updates = crfParameters.merge(gradients, curSize);
                for (Tensor update : updates) {
                    update.scaleInPlace(minibatchSize);
                }
                tempWeight /= minibatchSize;
                // Update the gradient with the current learning rates.
                updates = localOptimiser.step(updates, tempWeight);
                // Apply the gradient.
                crfParameters.update(updates);
                iteration++;
                if ((loggingInterval != -1) && (iteration % loggingInterval == 0)) {
                    logger.info("At iteration " + iteration + ", average loss = " + loss / loggingInterval);
                    loss = 0.0;
                }
            }
        }
    }
    localOptimiser.finalise();
    ModelProvenance provenance = new ModelProvenance(CRFModel.class.getName(), OffsetDateTime.now(), sequenceExamples.getProvenance(), trainerProvenance, runProvenance);
    CRFModel model = new CRFModel("crf-sgd-model", provenance, featureIDMap, labelIDMap, crfParameters);
    localOptimiser.reset();
    return model;
}
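The synchronized block above is the pattern this page is tracking: getProvenance() snapshots the trainer configuration as a TrainerProvenance at the same moment the invocation counter is bumped and the RNG is split, so concurrent train calls each record a consistent state. Below is a minimal caller-side sketch of that flow; the trainSequences variable, the AdaGrad learning rate, and the epoch count are hypothetical, though the constructor and accessor shapes follow the Tribuo 4.x API.

// Sketch: train a CRF and read back the TrainerProvenance captured in the synchronized block.
CRFTrainer trainer = new CRFTrainer(new AdaGrad(0.1), 5, Trainer.DEFAULT_SEED);
CRFModel model = trainer.train(trainSequences, Collections.emptyMap()); // trainSequences: a SequenceDataset<Label>, assumed in scope
System.out.println(model.getProvenance().getTrainerProvenance());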
Use of org.tribuo.provenance.TrainerProvenance in project tribuo by oracle.
The class HdbscanTrainer, method train.
@Override
public HdbscanModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
    // Increment the invocation count.
    TrainerProvenance trainerProvenance;
    synchronized (this) {
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    SGDVector[] data = new SGDVector[examples.size()];
    int n = 0;
    for (Example<ClusterID> example : examples) {
        if (example.size() == featureMap.size()) {
            data[n] = DenseVector.createDenseVector(example, featureMap, false);
        } else {
            data[n] = SparseVector.createSparseVector(example, featureMap, false);
        }
        n++;
    }
    DenseVector coreDistances = calculateCoreDistances(data, k, neighboursQueryFactory);
    ExtendedMinimumSpanningTree emst = constructEMST(data, coreDistances, distType);
    // The levels at which each point becomes noise.
    double[] pointNoiseLevels = new double[data.length];
    // The last label of each point before becoming noise.
    int[] pointLastClusters = new int[data.length];
    // The HDBSCAN* hierarchy.
    Map<Integer, int[]> hierarchy = new HashMap<>();
    List<HdbscanCluster> clusters = computeHierarchyAndClusterTree(emst, minClusterSize, pointNoiseLevels, pointLastClusters, hierarchy);
    propagateTree(clusters);
    List<Integer> clusterLabels = findProminentClusters(hierarchy, clusters, data.length);
    DenseVector outlierScoresVector = calculateOutlierScores(pointNoiseLevels, pointLastClusters, clusters);
    Map<Integer, List<Pair<Double, Integer>>> clusterAssignments = generateClusterAssignments(clusterLabels, outlierScoresVector);
    // Use the cluster assignments to establish the clustering info.
    Map<Integer, MutableLong> counts = new HashMap<>();
    for (Entry<Integer, List<Pair<Double, Integer>>> e : clusterAssignments.entrySet()) {
        counts.put(e.getKey(), new MutableLong(e.getValue().size()));
    }
    ImmutableOutputInfo<ClusterID> outputMap = new ImmutableClusteringInfo(counts);
    // Compute the cluster exemplars.
    List<ClusterExemplar> clusterExemplars = computeExemplars(data, clusterAssignments, distType);
    // Get the outlier score value for points that are predicted as noise points.
    double noisePointsOutlierScore = getNoisePointsOutlierScore(clusterAssignments);
    logger.log(Level.INFO, "Hdbscan is done.");
    ModelProvenance provenance = new ModelProvenance(HdbscanModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
    return new HdbscanModel("hdbscan-model", provenance, featureMap, outputMap, clusterLabels, outlierScoresVector, clusterExemplars, distType, noisePointsOutlierScore);
}
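HdbscanTrainer follows the same snapshot-then-increment pattern but has no RNG or optimiser state to copy, so only the provenance snapshot and the counter sit inside the lock. A minimal usage sketch under the same caveats (the minClusterSize of 5 and the clusterData variable are hypothetical):

// Sketch: cluster a dataset; the returned model carries the TrainerProvenance snapshot.
HdbscanTrainer trainer = new HdbscanTrainer(5);
HdbscanModel model = trainer.train(clusterData, Collections.emptyMap()); // clusterData: a Dataset<ClusterID>, assumed in scope
System.out.println(model.getProvenance().getTrainerProvenance());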
Use of org.tribuo.provenance.TrainerProvenance in project tribuo by oracle.
The class XGBoostClassificationTrainer, method train.
@Override
public synchronized XGBoostModel<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    if (examples.getOutputInfo().getUnknownCount() > 0) {
        throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
    }
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    ImmutableOutputInfo<Label> outputInfo = examples.getOutputIDInfo();
    if (invocationCount != INCREMENT_INVOCATION_COUNT) {
        setInvocationCount(invocationCount);
    }
    TrainerProvenance trainerProvenance = getProvenance();
    trainInvocationCounter++;
    parameters.put("num_class", outputInfo.size());
    Booster model;
    Function<Label, Float> responseExtractor = (Label l) -> (float) outputInfo.getID(l);
    try {
        DMatrixTuple<Label> trainingData = convertExamples(examples, featureMap, responseExtractor);
        model = XGBoost.train(trainingData.data, parameters, numTrees, Collections.emptyMap(), null, null);
    } catch (XGBoostError e) {
        logger.log(Level.SEVERE, "XGBoost threw an error", e);
        throw new IllegalStateException(e);
    }
    ModelProvenance provenance = new ModelProvenance(XGBoostModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
    XGBoostModel<Label> xgModel = createModel("xgboost-classification-model", provenance, featureMap, outputInfo, Collections.singletonList(model), new XGBoostClassificationConverter());
    return xgModel;
}
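Unlike the two trainers above, this overload takes an explicit invocationCount, so a caller can pin the trainer's invocation state before training, which is what makes retraining reproducible from a recorded TrainerProvenance. A sketch of both call styles (the tree count of 50 and the labelData variable are hypothetical; INCREMENT_INVOCATION_COUNT is the sentinel defined on org.tribuo.Trainer):

// Sketch: train normally, then train again with the counter pinned to a recorded value.
XGBoostClassificationTrainer trainer = new XGBoostClassificationTrainer(50);
XGBoostModel<Label> first = trainer.train(labelData, Collections.emptyMap(), Trainer.INCREMENT_INVOCATION_COUNT);
XGBoostModel<Label> replay = trainer.train(labelData, Collections.emptyMap(), 0); // pin to invocation 0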
Use of org.tribuo.provenance.TrainerProvenance in project tribuo by oracle.
The class XGBoostRegressionTrainer, method train.
@Override
public synchronized XGBoostModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    if (examples.getOutputInfo().getUnknownCount() > 0) {
        throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
    }
    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo();
    int numOutputs = outputInfo.size();
    if (invocationCount != INCREMENT_INVOCATION_COUNT) {
        setInvocationCount(invocationCount);
    }
    TrainerProvenance trainerProvenance = getProvenance();
    trainInvocationCounter++;
    List<Booster> models = new ArrayList<>();
    try {
        // Use a null response extractor as we'll do the per-dimension regression extraction later.
        DMatrixTuple<Regressor> trainingData = convertExamples(examples, featureMap, null);
        // Map the natural order into ids.
        int[] dimensionIds = ((ImmutableRegressionInfo) outputInfo).getNaturalOrderToIDMapping();
        // Extract the weights and the regression targets.
        float[][] outputs = new float[numOutputs][examples.size()];
        float[] weights = new float[examples.size()];
        int i = 0;
        for (Example<Regressor> e : examples) {
            weights[i] = e.getWeight();
            double[] curOutputs = e.getOutput().getValues();
            // Transpose them for easy training.
            for (int j = 0; j < numOutputs; j++) {
                outputs[dimensionIds[j]][i] = (float) curOutputs[j];
            }
            i++;
        }
        trainingData.data.setWeight(weights);
        // Finished setup, now train one model per dimension.
        for (i = 0; i < numOutputs; i++) {
            trainingData.data.setLabel(outputs[i]);
            models.add(XGBoost.train(trainingData.data, parameters, numTrees, Collections.emptyMap(), null, null));
        }
    } catch (XGBoostError e) {
        logger.log(Level.SEVERE, "XGBoost threw an error", e);
        throw new IllegalStateException(e);
    }
    ModelProvenance provenance = new ModelProvenance(XGBoostModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
    XGBoostModel<Regressor> xgModel = createModel("xgboost-regression-model", provenance, featureMap, outputInfo, models, new XGBoostRegressionConverter());
    return xgModel;
}
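The regression trainer records a single ModelProvenance even though it trains one Booster per output dimension; the ids from getNaturalOrderToIDMapping() keep the transposed target matrix aligned with the output domain. A sketch with a multi-dimensional target (the tree count and the regressionData variable are hypothetical):

// Sketch: a two-dimensional Regressor target yields two Boosters inside one XGBoostModel.
XGBoostRegressionTrainer trainer = new XGBoostRegressionTrainer(50);
XGBoostModel<Regressor> model = trainer.train(regressionData, Collections.emptyMap(), Trainer.INCREMENT_INVOCATION_COUNT); // regressionData: a Dataset<Regressor>, assumed in scope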
Use of org.tribuo.provenance.TrainerProvenance in project tribuo by oracle.
The class StripProvenance, method cleanEnsembleProvenance.
/**
* Creates a new ensemble provenance with the requested information removed.
* @param old The old ensemble provenance.
* @param memberProvenance The new member provenances.
* @param provenanceHash The old ensemble provenance hash.
* @param opt The program options.
* @return The new ensemble provenance with the requested fields removed.
*/
private static EnsembleModelProvenance cleanEnsembleProvenance(EnsembleModelProvenance old, ListProvenance<ModelProvenance> memberProvenance, String provenanceHash, StripProvenanceOptions opt) {
    // Dataset provenance
    DatasetProvenance datasetProvenance;
    if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(DATASET)) {
        datasetProvenance = new EmptyDatasetProvenance();
    } else {
        datasetProvenance = old.getDatasetProvenance();
    }
    // Trainer provenance
    TrainerProvenance trainerProvenance;
    if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(TRAINER)) {
        trainerProvenance = new EmptyTrainerProvenance();
    } else {
        trainerProvenance = old.getTrainerProvenance();
    }
    // Instance provenance
    OffsetDateTime time;
    Map<String, Provenance> instanceProvenance;
    if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(INSTANCE)) {
        instanceProvenance = new HashMap<>();
        time = OffsetDateTime.MIN;
    } else {
        instanceProvenance = new HashMap<>(old.getInstanceProvenance().getMap());
        time = old.getTrainingTime();
    }
    if (opt.storeHash) {
        logger.info("Writing provenance hash into instance map.");
        instanceProvenance.put("original-provenance-hash", new HashProvenance(opt.hashType, "original-provenance-hash", provenanceHash));
    }
    return new EnsembleModelProvenance(old.getClassName(), time, datasetProvenance, trainerProvenance, instanceProvenance, memberProvenance);
}
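Because EmptyTrainerProvenance stands in for a stripped TrainerProvenance, the cleaned model still reports a well-formed provenance chain; only the recorded hyperparameters are gone. A sketch of reading back the stored hash from within StripProvenance, using only the map key shown above (the surrounding variables are assumed in scope):

// Sketch: the original provenance hash, if opt.storeHash was set, lives in the instance map.
EnsembleModelProvenance cleaned = cleanEnsembleProvenance(old, memberProvenance, provenanceHash, opt);
Provenance storedHash = cleaned.getInstanceProvenance().getMap().get("original-provenance-hash");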