Usage of org.tribuo.sequence.ImmutableSequenceDataset in the Tribuo project by Oracle.
Class ViterbiTrainer, method train.
/**
 * Trains a Viterbi sequence model by delegating to the wrapped (non-sequence)
 * trainer, after first adding features derived from preceding labels to every
 * example. Code upstream of this call should not need to know that these
 * features are being added. It should not have to check which kind of trainer
 * will be used, or contain conditional logic that adds label-derived features
 * only when a ViterbiTrainer is in use. Instead, the label-derived features are
 * generated here and added to the sequence examples of the supplied dataset.
 * <p>
 * If you pass in a {@code MutableSequenceDataset}, be aware that your dataset
 * is modified in place by this method. Subsequent calls to other sequence
 * trainers' {@code train} methods with that dataset should therefore be avoided.
 * If you pass in an {@code ImmutableSequenceDataset}, the entire dataset is
 * first copied into a {@code MutableSequenceDataset}, so there is a memory
 * penalty.
 * <p>
 * If the trainer was configured with a stack size of {@code -1}, the stack size
 * used for the returned model is the number of distinct output values in the
 * supplied dataset. The trainer's configured value is not modified, so later
 * calls derive it again from their own dataset.
 * @param dataset The input dataset.
 * @param runProvenance Any additional information to record in the provenance.
 * @return A {@link SequenceModel} using Viterbi wrapped around an inner {@link Model}.
 * @throws IllegalArgumentException If the dataset contains unknown outputs, or is
 *                                  neither an {@code ImmutableSequenceDataset} nor a
 *                                  {@code MutableSequenceDataset}.
 */
@Override
public SequenceModel<Label> train(SequenceDataset<Label> dataset, Map<String, Provenance> runProvenance) {
    if (dataset.getOutputInfo().getUnknownCount() > 0) {
        throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
    }

    // A stackSize of -1 means "use the number of unique output values". Resolve it into a
    // local rather than overwriting the field; otherwise the first dataset's label count
    // would silently be reused by every later call to train.
    final int effectiveStackSize = (stackSize == -1) ? dataset.getOutputIDInfo().size() : stackSize;

    // Label-derived features are added to the examples in place, so an immutable dataset
    // is first copied into a mutable one (see the note in the Javadoc above).
    final MutableSequenceDataset<Label> mutableDataset;
    if (dataset instanceof ImmutableSequenceDataset) {
        mutableDataset = new MutableSequenceDataset<>((ImmutableSequenceDataset<Label>) dataset);
    } else if (dataset instanceof MutableSequenceDataset) {
        mutableDataset = (MutableSequenceDataset<Label>) dataset;
    } else {
        throw new IllegalArgumentException("unable to handle sub-type of dataset: " + dataset.getClass().getName());
    }

    for (SequenceExample<Label> sequenceExample : mutableDataset) {
        // Outputs of the examples already visited in this sequence, in order; each
        // example's label features are extracted from the labels that precede it.
        List<Label> precedingLabels = new ArrayList<>();
        for (Example<Label> example : sequenceExample) {
            // NOTE(review): 1.0 is passed straight to extractFeatures; presumably the value
            // assigned to each label-derived feature -- confirm against extractFeatures.
            List<Feature> labelFeatures = extractFeatures(precedingLabels, mutableDataset, 1.0);
            example.addAll(labelFeatures);
            precedingLabels.add(example.getOutput());
        }
    }

    TrainerProvenance trainerProvenance = getProvenance();
    ModelProvenance provenance = new ModelProvenance(ViterbiModel.class.getName(), OffsetDateTime.now(), mutableDataset.getProvenance(), trainerProvenance, runProvenance);
    trainInvocationCounter++;
    // Train the wrapped trainer on the flattened (per-token) view of the augmented dataset.
    Dataset<Label> flatData = mutableDataset.getFlatDataset();
    Model<Label> model = trainer.train(flatData);
    return new ViterbiModel("viterbi+" + model.getName(), provenance, model, labelFeatureExtractor, effectiveStackSize, scoreAggregation);
}
Aggregations