Search in sources:

Example 1 with WeightedExamples

Use of org.tribuo.WeightedExamples in the tribuo project by Oracle.

The train method of the AdaBoostTrainer class:

@Override
public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance, int invocationCount) {
    if (examples.getOutputInfo().getUnknownCount() > 0) {
        throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
    }
    // Creates a new RNG, adds one to the invocation count.
    SplittableRandom localRNG;
    TrainerProvenance trainerProvenance;
    synchronized (this) {
        if (invocationCount != INCREMENT_INVOCATION_COUNT) {
            setInvocationCount(invocationCount);
        }
        localRNG = rng.split();
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }
    boolean weighted = innerTrainer instanceof WeightedExamples;
    ImmutableFeatureMap featureIDs = examples.getFeatureIDMap();
    ImmutableOutputInfo<Label> labelIDs = examples.getOutputIDInfo();
    int numClasses = labelIDs.size();
    logger.log(Level.INFO, "NumClasses = " + numClasses);
    ArrayList<Model<Label>> models = new ArrayList<>();
    float[] modelWeights = new float[numMembers];
    float[] exampleWeights = Util.generateUniformFloatVector(examples.size(), 1.0f / examples.size());
    if (weighted) {
        logger.info("Using weighted Adaboost.");
        examples = ImmutableDataset.copyDataset(examples);
        for (int i = 0; i < examples.size(); i++) {
            Example<Label> e = examples.getExample(i);
            e.setWeight(exampleWeights[i]);
        }
    } else {
        logger.info("Using sampling Adaboost.");
    }
    for (int i = 0; i < numMembers; i++) {
        logger.info("Building model " + i);
        Model<Label> newModel;
        if (weighted) {
            newModel = innerTrainer.train(examples);
        } else {
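            // The inner trainer can't consume example weights directly, so draw a
            // bootstrap sample of examples.size() elements with probability
            // proportional to exampleWeights instead.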
            DatasetView<Label> bag = DatasetView.createWeightedBootstrapView(examples, examples.size(), localRNG.nextLong(), exampleWeights, featureIDs, labelIDs);
            newModel = innerTrainer.train(bag);
        }
        // 
        // Score this model
        List<Prediction<Label>> predictions = newModel.predict(examples);
        float accuracy = accuracy(predictions, examples, exampleWeights);
        float error = 1.0f - accuracy;
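        // SAMME stage weight: alpha = ln(accuracy / error) + ln(numClasses - 1),
        // which reduces to the classic two-class AdaBoost weight when numClasses == 2.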
        float alpha = (float) (Math.log(accuracy / error) + Math.log(numClasses - 1));
        models.add(newModel);
        modelWeights[i] = alpha;
        if ((accuracy + 1e-10) > 1.0) {
            // 
            // Perfect accuracy, can no longer boost.
            float[] newModelWeights = Arrays.copyOf(modelWeights, models.size());
            // Set the last weight to 1, as it's infinity.
            newModelWeights[models.size() - 1] = 1.0f;
            logger.log(Level.FINE, "Perfect accuracy reached on iteration " + i + ", returning current model.");
            logger.log(Level.FINE, "Model weights:");
            Util.logVector(logger, Level.FINE, newModelWeights);
            EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(models));
            return new WeightedEnsembleModel<>("boosted-ensemble", provenance, featureIDs, labelIDs, models, new VotingCombiner(), newModelWeights);
        }
        // Update the weights
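        // Misclassified examples are up-weighted by exp(alpha); the normalization
        // below keeps the weights summing to 1 for the next boosting round.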
        for (int j = 0; j < predictions.size(); j++) {
            if (!predictions.get(j).getOutput().equals(examples.getExample(j).getOutput())) {
                exampleWeights[j] *= Math.exp(alpha);
            }
        }
        Util.inplaceNormalizeToDistribution(exampleWeights);
        if (weighted) {
            for (int j = 0; j < examples.size(); j++) {
                examples.getExample(j).setWeight(exampleWeights[j]);
            }
        }
    }
    logger.log(Level.FINE, "Model weights:");
    Util.logVector(logger, Level.FINE, modelWeights);
    EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(models));
    return new WeightedEnsembleModel<>("boosted-ensemble", provenance, featureIDs, labelIDs, models, new VotingCombiner(), modelWeights);
}
Also used: Prediction (org.tribuo.Prediction), Label (org.tribuo.classification.Label), ArrayList (java.util.ArrayList), WeightedExamples (org.tribuo.WeightedExamples), EnsembleModelProvenance (org.tribuo.provenance.EnsembleModelProvenance), Model (org.tribuo.Model), WeightedEnsembleModel (org.tribuo.ensemble.WeightedEnsembleModel), ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap), SplittableRandom (java.util.SplittableRandom), TrainerProvenance (org.tribuo.provenance.TrainerProvenance)
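As a rough usage sketch for the trainer above: the snippet below wires an AdaBoostTrainer around an inner trainer and trains a 10-member ensemble. The helper class name, the choice of CARTClassificationTrainer, the ensemble size, and the already-loaded train dataset are illustrative assumptions rather than part of the example above; any Trainer<Label> can be supplied, and the weighted branch of train is only taken when that trainer implements WeightedExamples.

import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.dtree.CARTClassificationTrainer;
import org.tribuo.classification.ensemble.AdaBoostTrainer;

public final class AdaBoostUsageSketch {

    private AdaBoostUsageSketch() {}

    // Trains a 10-member boosted ensemble over an already-loaded dataset.
    // CARTClassificationTrainer is assumed to implement WeightedExamples, so
    // AdaBoostTrainer takes the weighted path and re-weights the copied
    // training examples in place on every boosting round.
    public static Model<Label> trainBoostedEnsemble(MutableDataset<Label> train) {
        Trainer<Label> innerTrainer = new CARTClassificationTrainer();
        AdaBoostTrainer boostingTrainer = new AdaBoostTrainer(innerTrainer, 10);
        return boostingTrainer.train(train);
    }
}

If the inner trainer does not implement WeightedExamples, the same call still works; as shown in the listing above, AdaBoost then falls back to the sampling path and draws weighted bootstrap bags instead of setting example weights.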

Example 2 with WeightedExamples

Use of org.tribuo.WeightedExamples in the tribuo project by Oracle.

The main method of the ConfigurableTrainTest class:

/**
 * @param args the command line arguments
 */
public static void main(String[] args) {
    // 
    // Use the labs format logging.
    LabsLogFormatter.setAllLogFormatters();
    ConfigurableTrainTestOptions o = new ConfigurableTrainTestOptions();
    ConfigurationManager cm;
    try {
        cm = new ConfigurationManager(args, o);
    } catch (UsageException e) {
        logger.info(e.getMessage());
        return;
    }
    if (o.general.trainingPath == null || o.general.testingPath == null) {
        logger.info(cm.usage());
        System.exit(1);
    }
    Pair<Dataset<Label>, Dataset<Label>> data = null;
    try {
        data = o.general.load(new LabelFactory());
    } catch (IOException e) {
        logger.log(Level.SEVERE, "Failed to load data", e);
        System.exit(1);
    }
    Dataset<Label> train = data.getA();
    Dataset<Label> test = data.getB();
    if (o.trainer == null) {
        logger.warning("No trainer supplied");
        logger.info(cm.usage());
        System.exit(1);
    }
    logger.info("Trainer is " + o.trainer.toString());
    if (o.weights != null) {
        Map<Label, Float> weightsMap = processWeights(o.weights);
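        // WeightedLabels trainers take per-class weights directly; for WeightedExamples
        // trainers the class weights are pushed onto the training examples instead.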
        if (o.trainer instanceof WeightedLabels) {
            ((WeightedLabels) o.trainer).setLabelWeights(weightsMap);
            logger.info("Setting label weights using " + weightsMap.toString());
        } else if (o.trainer instanceof WeightedExamples) {
            ((MutableDataset<Label>) train).setWeights(weightsMap);
            logger.info("Setting example weights using " + weightsMap.toString());
        } else {
            logger.warning("The selected trainer does not support weighted training. The chosen trainer is " + o.trainer.toString());
            logger.info(cm.usage());
            System.exit(1);
        }
    }
    logger.info("Labels are " + train.getOutputInfo().toReadableString());
    final long trainStart = System.currentTimeMillis();
    Model<Label> model = o.trainer.train(train);
    final long trainStop = System.currentTimeMillis();
    logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));
    LabelEvaluator labelEvaluator = new LabelEvaluator();
    final long testStart = System.currentTimeMillis();
    List<Prediction<Label>> predictions = model.predict(test);
    LabelEvaluation labelEvaluation = labelEvaluator.evaluate(model, predictions, test.getProvenance());
    final long testStop = System.currentTimeMillis();
    logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
    System.out.println(labelEvaluation.toString());
    ConfusionMatrix<Label> matrix = labelEvaluation.getConfusionMatrix();
    System.out.println(matrix.toString());
    if (model.generatesProbabilities()) {
        System.out.println("Average AUC = " + labelEvaluation.averageAUCROC(false));
        System.out.println("Average weighted AUC = " + labelEvaluation.averageAUCROC(true));
    }
    if (o.predictionPath != null) {
        try (BufferedWriter wrt = Files.newBufferedWriter(o.predictionPath)) {
            List<String> labels = model.getOutputIDInfo().getDomain().stream().map(Label::getLabel).sorted().collect(Collectors.toList());
            wrt.write("Label,");
            wrt.write(String.join(",", labels));
            wrt.newLine();
            for (Prediction<Label> pred : predictions) {
                Example<Label> ex = pred.getExample();
                wrt.write(ex.getOutput().getLabel() + ",");
                wrt.write(labels.stream().map(l -> Double.toString(pred.getOutputScores().get(l).getScore())).collect(Collectors.joining(",")));
                wrt.newLine();
            }
            wrt.flush();
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Error writing predictions", e);
        }
    }
    if (o.general.outputPath != null) {
        try {
            o.general.saveModel(model);
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Error writing model", e);
        }
    }
}
Also used: UsageException (com.oracle.labs.mlrg.olcut.config.UsageException), Label (org.tribuo.classification.Label), WeightedExamples (org.tribuo.WeightedExamples), BufferedWriter (java.io.BufferedWriter), WeightedLabels (org.tribuo.classification.WeightedLabels), LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation), ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager), Dataset (org.tribuo.Dataset), MutableDataset (org.tribuo.MutableDataset), Prediction (org.tribuo.Prediction), IOException (java.io.IOException), LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator), LabelFactory (org.tribuo.classification.LabelFactory)
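The weights branch of main is the part that touches WeightedExamples, so here is a minimal standalone sketch of that same logic. The class and method names (WeightTrainingSketch, applyClassWeights) are invented for illustration; the setLabelWeights and setWeights calls mirror the ones in the example above.

import java.util.Map;

import org.tribuo.MutableDataset;
import org.tribuo.Trainer;
import org.tribuo.WeightedExamples;
import org.tribuo.classification.Label;
import org.tribuo.classification.WeightedLabels;

public final class WeightTrainingSketch {

    private WeightTrainingSketch() {}

    // Applies per-class weights the same way ConfigurableTrainTest.main does:
    // directly on the trainer when it understands label weights, otherwise by
    // pushing the weights onto the training examples for trainers that only
    // consume example weights.
    public static void applyClassWeights(Trainer<Label> trainer, MutableDataset<Label> train, Map<Label, Float> weightsMap) {
        if (trainer instanceof WeightedLabels) {
            ((WeightedLabels) trainer).setLabelWeights(weightsMap);
        } else if (trainer instanceof WeightedExamples) {
            train.setWeights(weightsMap);
        } else {
            throw new IllegalArgumentException("Trainer " + trainer + " does not support weighted training.");
        }
    }
}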

Aggregations

Prediction (org.tribuo.Prediction): 2
WeightedExamples (org.tribuo.WeightedExamples): 2
Label (org.tribuo.classification.Label): 2
ConfigurationManager (com.oracle.labs.mlrg.olcut.config.ConfigurationManager): 1
UsageException (com.oracle.labs.mlrg.olcut.config.UsageException): 1
BufferedWriter (java.io.BufferedWriter): 1
IOException (java.io.IOException): 1
ArrayList (java.util.ArrayList): 1
SplittableRandom (java.util.SplittableRandom): 1
Dataset (org.tribuo.Dataset): 1
ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap): 1
Model (org.tribuo.Model): 1
MutableDataset (org.tribuo.MutableDataset): 1
LabelFactory (org.tribuo.classification.LabelFactory): 1
WeightedLabels (org.tribuo.classification.WeightedLabels): 1
LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation): 1
LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator): 1
WeightedEnsembleModel (org.tribuo.ensemble.WeightedEnsembleModel): 1
EnsembleModelProvenance (org.tribuo.provenance.EnsembleModelProvenance): 1
TrainerProvenance (org.tribuo.provenance.TrainerProvenance): 1