
Example 1 with ImmutableSequenceDataset

Use of org.tribuo.sequence.ImmutableSequenceDataset in project Tribuo by Oracle.

Method train in class ViterbiTrainer.

```java
/**
 * The Viterbi train method is unusual because it delegates to a regular
 * {@link Trainer} train method, but first it adds features derived from
 * the preceding labels. The pipeline upstream of this call should not need
 * to know that these features are being added. We do not want upstream
 * logic to worry about which trainer will be used, or to contain
 * conditional logic that adds label-derived features only when the
 * ViterbiTrainer is in use. So these label-derived features, which are
 * specific to this trainer, are generated here and added to the sequence
 * examples of the supplied dataset. If you pass in a
 * MutableSequenceDataset, be aware that your dataset will be modified by
 * this method, so you should avoid passing it to other
 * SequenceTrainer.train methods afterwards. If you pass in an
 * ImmutableSequenceDataset, be aware that the entire dataset is copied
 * into a MutableSequenceDataset, which carries a memory penalty.
 * @param dataset The input dataset.
 * @param runProvenance Any additional information to record in the provenance.
 * @return A {@link SequenceModel} using Viterbi wrapped around an inner {@link Model}.
 */
@Override
public SequenceModel<Label> train(SequenceDataset<Label> dataset, Map<String, Provenance> runProvenance) {
    if (dataset.getOutputInfo().getUnknownCount() > 0) {
        throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
    }
    // Default stack size: the number of unique output values.
    if (stackSize == -1) {
        stackSize = dataset.getOutputIDInfo().size();
    }
    // Copy an immutable dataset into a mutable one (see the Javadoc note above).
    if (dataset instanceof ImmutableSequenceDataset) {
        dataset = new MutableSequenceDataset<>((ImmutableSequenceDataset<Label>) dataset);
    }
    if (!(dataset instanceof MutableSequenceDataset)) {
        throw new IllegalArgumentException("unable to handle sub-type of dataset: " + dataset.getClass().getName());
    }
    // Add label-derived features to each position, using only the labels that precede it.
    for (SequenceExample<Label> sequenceExample : dataset) {
        List<Label> labels = new ArrayList<>();
        for (Example<Label> example : sequenceExample) {
            List<Feature> labelFeatures = extractFeatures(labels, (MutableSequenceDataset<Label>) dataset, 1.0);
            example.addAll(labelFeatures);
            labels.add(example.getOutput());
        }
    }
    TrainerProvenance trainerProvenance = getProvenance();
    ModelProvenance provenance = new ModelProvenance(ViterbiModel.class.getName(), OffsetDateTime.now(), dataset.getProvenance(), trainerProvenance, runProvenance);
    trainInvocationCounter++;
    // Flatten the sequences into independent examples and train the inner model on them.
    Dataset<Label> flatData = dataset.getFlatDataset();
    Model<Label> model = trainer.train(flatData);
    return new ViterbiModel("viterbi+" + model.getName(), provenance, model, labelFeatureExtractor, stackSize, scoreAggregation);
}
```
Also used: MutableSequenceDataset (org.tribuo.sequence.MutableSequenceDataset), ModelProvenance (org.tribuo.provenance.ModelProvenance), Label (org.tribuo.classification.Label), ArrayList (java.util.ArrayList), Feature (org.tribuo.Feature), ImmutableSequenceDataset (org.tribuo.sequence.ImmutableSequenceDataset), TrainerProvenance (org.tribuo.provenance.TrainerProvenance)
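
The nested loop over each SequenceExample in train is where the label-derived features are built. The labels list grows as the inner loop advances, so the features added to a position are computed only from the gold labels before it; getFlatDataset then turns the sequences into independent examples for an ordinary Trainer. The plain-Java sketch below imitates that idea without Tribuo. The Token record, the labelFeatures extractor and its PrevLabel/PrevBigram feature names are all hypothetical and are not Tribuo's LabelFeatureExtractor. Unlike the real method, the sketch builds new tokens instead of modifying its input, which is closer in spirit to the copy made for an ImmutableSequenceDataset.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch, not Tribuo code: label-derived features plus flattening. */
public class LabelFeatureSketch {

    /** One position in a sequence: its features and its gold label (hypothetical type). */
    public record Token(List<String> features, String label) {}

    /** Hypothetical extractor: features built only from the labels already seen. */
    public static List<String> labelFeatures(List<String> precedingLabels) {
        int n = precedingLabels.size();
        List<String> feats = new ArrayList<>();
        // Previous label, or a start marker at the first position.
        String prev = (n == 0) ? "<START>" : precedingLabels.get(n - 1);
        feats.add("PrevLabel=" + prev);
        // Label bigram over the two preceding positions, when available.
        if (n >= 2) {
            feats.add("PrevBigram=" + precedingLabels.get(n - 2) + "|" + prev);
        }
        return feats;
    }

    /**
     * Adds label features to every token, then flattens all sequences into one
     * list of independent examples. New tokens are created, so the input is untouched.
     */
    public static List<Token> augmentAndFlatten(List<List<Token>> sequences) {
        List<Token> flat = new ArrayList<>();
        for (List<Token> sequence : sequences) {
            List<String> seen = new ArrayList<>(); // gold labels to the left of the current token
            for (Token token : sequence) {
                List<String> feats = new ArrayList<>(token.features());
                feats.addAll(labelFeatures(seen));
                flat.add(new Token(feats, token.label()));
                seen.add(token.label()); // this label becomes visible only to later tokens
            }
        }
        return flat;
    }

    public static void main(String[] args) {
        List<Token> sentence = List.of(
                new Token(List.of("w=John"), "B-PER"),
                new Token(List.of("w=Smith"), "I-PER"),
                new Token(List.of("w=spoke"), "O"));
        for (Token t : augmentAndFlatten(List.of(sentence))) {
            System.out.println(t.features() + " -> " + t.label());
        }
        // [w=John, PrevLabel=<START>] -> B-PER
        // [w=Smith, PrevLabel=B-PER] -> I-PER
        // [w=spoke, PrevLabel=I-PER, PrevBigram=B-PER|I-PER] -> O
    }
}
```

Because training sees gold previous labels while prediction can only use predicted ones, the trained inner model needs a decoder at prediction time to supply its label features, which is the job of the ViterbiModel returned by train.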

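
The train method ends by wrapping the inner model in a ViterbiModel together with stackSize and scoreAggregation. At prediction time gold labels are unavailable, so label features have to come from the model's own earlier predictions, and some search over label sequences is needed. How ViterbiModel carries out that search is not shown in this listing. As an illustration only, here is the classic first-order Viterbi algorithm, assuming a hypothetical Scorer that returns a log-score for a label given only the immediately preceding label, with scores combined by addition. If label features look further back than one position, the search state must carry more history or the search must be approximated, for example with a beam.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Sketch only: first-order Viterbi decoding over additive log-scores (not Tribuo's ViterbiModel). */
public class ViterbiSketch {

    /** Hypothetical scorer: log-score of label at position given previousLabel (null at the start). */
    @FunctionalInterface
    public interface Scorer {
        double score(int position, String previousLabel, String label);
    }

    /** Returns the highest-scoring label sequence of the given length. */
    public static List<String> decode(int length, List<String> labels, Scorer scorer) {
        if (length == 0 || labels.isEmpty()) {
            return new ArrayList<>();
        }
        int k = labels.size();
        double[][] best = new double[length][k]; // best score of any path ending in label j at position i
        int[][] back = new int[length][k];       // previous-label index on that best path
        for (int j = 0; j < k; j++) {
            best[0][j] = scorer.score(0, null, labels.get(j));
        }
        for (int i = 1; i < length; i++) {
            for (int j = 0; j < k; j++) {
                best[i][j] = Double.NEGATIVE_INFINITY;
                for (int p = 0; p < k; p++) {
                    // Adding log-scores corresponds to multiplying probabilities.
                    double s = best[i - 1][p] + scorer.score(i, labels.get(p), labels.get(j));
                    if (s > best[i][j]) {
                        best[i][j] = s;
                        back[i][j] = p;
                    }
                }
            }
        }
        int last = 0; // best label index at the final position
        for (int j = 1; j < k; j++) {
            if (best[length - 1][j] > best[length - 1][last]) {
                last = j;
            }
        }
        List<String> path = new ArrayList<>();
        for (int i = length - 1; i >= 0; i--) { // follow back-pointers from right to left
            path.add(labels.get(last));
            last = back[i][last];
        }
        Collections.reverse(path);
        return path;
    }

    /** Toy BIO scorer: "I" must follow "B" or "I", otherwise it is heavily penalized. */
    public static Scorer demoScorer() {
        return (i, prev, y) -> {
            if (y.equals("I") && !("B".equals(prev) || "I".equals(prev))) {
                return -10.0;
            }
            if (i == 0) {
                return y.equals("O") ? -0.5 : -1.0;
            }
            return y.equals("I") ? -0.1 : -2.0;
        };
    }

    public static void main(String[] args) {
        System.out.println(decode(2, List.of("B", "I", "O"), demoScorer())); // prints [B, I]
    }
}
```

With the toy scorer, a greedy left-to-right decoder would pick O first (-0.5) and could then only reach a total of -2.5, while Viterbi keeps the best path ending in every label and finds B, I with a total of -1.1. Keeping one best path per label is what makes the search exact for first-order dependencies without enumerating all label sequences.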