Use of org.tribuo.WeightedExamples in the Oracle Tribuo project: the train method of the AdaBoostTrainer class.
@Override
public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    // This is a supervised trainer: reject datasets containing unlabelled (unknown) outputs.
    if (examples.getOutputInfo().getUnknownCount() > 0) {
        throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
    }
    // Snapshot RNG state and provenance atomically, then bump the invocation count,
    // so concurrent train() calls each get an independent RNG split.
    SplittableRandom localRNG;
    TrainerProvenance trainerProvenance;
    synchronized (this) {
        if (invocationCount != INCREMENT_INVOCATION_COUNT) {
            setInvocationCount(invocationCount);
        }
        localRNG = rng.split();
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }
    // If the inner trainer honours per-example weights we boost by reweighting the
    // dataset directly; otherwise we boost by weighted bootstrap resampling.
    boolean weighted = innerTrainer instanceof WeightedExamples;
    ImmutableFeatureMap featureIDs = examples.getFeatureIDMap();
    ImmutableOutputInfo<Label> labelIDs = examples.getOutputIDInfo();
    int numClasses = labelIDs.size();
    logger.log(Level.INFO, "NumClasses = " + numClasses);
    ArrayList<Model<Label>> models = new ArrayList<>();
    float[] modelWeights = new float[numMembers];
    // Start from the uniform distribution over examples (each weight = 1/n).
    float[] exampleWeights = Util.generateUniformFloatVector(examples.size(), 1.0f / examples.size());
    if (weighted) {
        logger.info("Using weighted Adaboost.");
        // Copy the dataset before mutating example weights, so the caller's dataset is untouched.
        examples = ImmutableDataset.copyDataset(examples);
        for (int i = 0; i < examples.size(); i++) {
            Example<Label> e = examples.getExample(i);
            e.setWeight(exampleWeights[i]);
        }
    } else {
        logger.info("Using sampling Adaboost.");
    }
    for (int i = 0; i < numMembers; i++) {
        logger.info("Building model " + i);
        Model<Label> newModel;
        if (weighted) {
            // Weighted path: the inner trainer sees the current weights via the examples themselves.
            newModel = innerTrainer.train(examples);
        } else {
            // Sampling path: draw a bootstrap sample proportional to the current example weights.
            DatasetView<Label> bag = DatasetView.createWeightedBootstrapView(examples, examples.size(), localRNG.nextLong(), exampleWeights, featureIDs, labelIDs);
            newModel = innerTrainer.train(bag);
        }
        //
        // Score this model on the full (weighted) training set.
        List<Prediction<Label>> predictions = newModel.predict(examples);
        float accuracy = accuracy(predictions, examples, exampleWeights);
        float error = 1.0f - accuracy;
        // Member weight: log(acc/err) + log(K-1) — the multi-class AdaBoost alpha
        // (SAMME-style; reduces to classic AdaBoost when numClasses == 2).
        float alpha = (float) (Math.log(accuracy / error) + Math.log(numClasses - 1));
        models.add(newModel);
        modelWeights[i] = alpha;
        // Epsilon guard against float round-off when accuracy is effectively 1.0.
        if ((accuracy + 1e-10) > 1.0) {
            //
            // Perfect accuracy, can no longer boost: error == 0 makes alpha infinite.
            // Return the ensemble built so far.
            float[] newModelWeights = Arrays.copyOf(modelWeights, models.size());
            // Set the last weight to 1, as the computed alpha is infinity.
            newModelWeights[models.size() - 1] = 1.0f;
            logger.log(Level.FINE, "Perfect accuracy reached on iteration " + i + ", returning current model.");
            logger.log(Level.FINE, "Model weights:");
            Util.logVector(logger, Level.FINE, newModelWeights);
            EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(models));
            return new WeightedEnsembleModel<>("boosted-ensemble", provenance, featureIDs, labelIDs, models, new VotingCombiner(), newModelWeights);
        }
        // Update the weights: upweight misclassified examples by exp(alpha).
        for (int j = 0; j < predictions.size(); j++) {
            if (!predictions.get(j).getOutput().equals(examples.getExample(j).getOutput())) {
                exampleWeights[j] *= Math.exp(alpha);
            }
        }
        // Renormalize so exampleWeights remains a distribution.
        Util.inplaceNormalizeToDistribution(exampleWeights);
        // In the weighted path, push the new weights back onto the examples for the next round.
        if (weighted) {
            for (int j = 0; j < examples.size(); j++) {
                examples.getExample(j).setWeight(exampleWeights[j]);
            }
        }
    }
    logger.log(Level.FINE, "Model weights:");
    Util.logVector(logger, Level.FINE, modelWeights);
    EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(models));
    return new WeightedEnsembleModel<>("boosted-ensemble", provenance, featureIDs, labelIDs, models, new VotingCombiner(), modelWeights);
}
Use of org.tribuo.WeightedExamples in the Oracle Tribuo project: the main method of the ConfigurableTrainTest class.
/**
 * Entry point: loads train/test data from the supplied configuration, optionally applies
 * label or example weights, trains the configured trainer, evaluates on the test set,
 * and optionally writes per-example predictions and the serialized model.
 * @param args the command line arguments
 */
public static void main(String[] args) {
    //
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    ConfigurableTrainTestOptions options = new ConfigurableTrainTestOptions();
    ConfigurationManager configManager;
    try {
        configManager = new ConfigurationManager(args, options);
    } catch (UsageException e) {
        // Bad command line — print the usage message and bail out.
        logger.info(e.getMessage());
        return;
    }
    // Both a training and a testing path are mandatory.
    if (options.general.trainingPath == null || options.general.testingPath == null) {
        logger.info(configManager.usage());
        System.exit(1);
    }
    Pair<Dataset<Label>, Dataset<Label>> datasets = null;
    try {
        datasets = options.general.load(new LabelFactory());
    } catch (IOException e) {
        logger.log(Level.SEVERE, "Failed to load data", e);
        System.exit(1);
    }
    Dataset<Label> trainSet = datasets.getA();
    Dataset<Label> testSet = datasets.getB();
    if (options.trainer == null) {
        logger.warning("No trainer supplied");
        logger.info(configManager.usage());
        System.exit(1);
    }
    logger.info("Trainer is " + options.trainer.toString());
    if (options.weights != null) {
        // Apply weights through whichever mechanism the trainer supports:
        // label weights on the trainer, or example weights on the dataset.
        Map<Label, Float> labelWeights = processWeights(options.weights);
        if (options.trainer instanceof WeightedLabels) {
            ((WeightedLabels) options.trainer).setLabelWeights(labelWeights);
            logger.info("Setting label weights using " + labelWeights.toString());
        } else if (options.trainer instanceof WeightedExamples) {
            ((MutableDataset<Label>) trainSet).setWeights(labelWeights);
            logger.info("Setting example weights using " + labelWeights.toString());
        } else {
            logger.warning("The selected trainer does not support weighted training. The chosen trainer is " + options.trainer.toString());
            logger.info(configManager.usage());
            System.exit(1);
        }
    }
    logger.info("Labels are " + trainSet.getOutputInfo().toReadableString());
    // Train, timing the run.
    final long trainStart = System.currentTimeMillis();
    Model<Label> model = options.trainer.train(trainSet);
    final long trainStop = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));
    // Evaluate on the held-out test set, timing the run.
    LabelEvaluator evaluator = new LabelEvaluator();
    final long testStart = System.currentTimeMillis();
    List<Prediction<Label>> predictions = model.predict(testSet);
    LabelEvaluation evaluation = evaluator.evaluate(model, predictions, testSet.getProvenance());
    final long testStop = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
    System.out.println(evaluation.toString());
    ConfusionMatrix<Label> confusion = evaluation.getConfusionMatrix();
    System.out.println(confusion.toString());
    if (model.generatesProbabilities()) {
        // AUC is only meaningful for probabilistic models.
        System.out.println("Average AUC = " + evaluation.averageAUCROC(false));
        System.out.println("Average weighted AUC = " + evaluation.averageAUCROC(true));
    }
    if (options.predictionPath != null) {
        // Emit a CSV: header row of sorted label names, then one row per prediction
        // with the true label followed by the score for each label.
        try (BufferedWriter writer = Files.newBufferedWriter(options.predictionPath)) {
            List<String> labelNames = model.getOutputIDInfo().getDomain().stream().map(Label::getLabel).sorted().collect(Collectors.toList());
            writer.write("Label,");
            writer.write(String.join(",", labelNames));
            writer.newLine();
            for (Prediction<Label> prediction : predictions) {
                Example<Label> example = prediction.getExample();
                writer.write(example.getOutput().getLabel() + ",");
                StringBuilder scoreRow = new StringBuilder();
                for (String labelName : labelNames) {
                    if (scoreRow.length() > 0) {
                        scoreRow.append(',');
                    }
                    scoreRow.append(Double.toString(prediction.getOutputScores().get(labelName).getScore()));
                }
                writer.write(scoreRow.toString());
                writer.newLine();
            }
            writer.flush();
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Error writing predictions", e);
        }
    }
    if (options.general.outputPath != null) {
        try {
            options.general.saveModel(model);
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Error writing model", e);
        }
    }
}
Aggregations