Use of org.tribuo.transform.TransformationMap in project Tribuo by Oracle.
The class DataOptions, method load.
/**
 * Loads the training and testing data from {@link #trainingPath} and {@link #testingPath}
 * according to the other parameters specified in this class.
 * <p>
 * For the {@code SERIALIZED}, {@code LIBSVM} and {@code TEXT} input formats, features occurring
 * fewer than {@code minCount} times are removed from the training dataset (when {@code minCount > 0}),
 * and the testing dataset is built against the training dataset's feature and output ID maps.
 * The {@code CSV} and {@code COLUMNAR} formats load the two datasets independently and do not
 * apply {@code minCount}.
 * <p>
 * If {@code scaleFeatures} is set, a {@link LinearScalingTransformation} is fitted on the training
 * dataset ({@code scaleIncZeros} controls whether implicit zero values are observed while fitting)
 * and then applied to both datasets.
 * @param outputFactory The output factory to use to process the inputs.
 * @param <T> The dataset output type.
 * @return A pair containing the training and testing datasets. The training dataset is element 'A' and the
 * testing dataset is element 'B'.
 * @throws IOException If the paths could not be loaded.
 * @throws IllegalArgumentException If a serialised file contains an unknown class, a required option for the
 * CSV or COLUMNAR formats is missing, the RowProcessor's output factory differs from {@code outputFactory},
 * or the input format is unsupported.
 */
public <T extends Output<T>> Pair<Dataset<T>, Dataset<T>> load(OutputFactory<T> outputFactory) throws IOException {
    logger.info(String.format("Loading data from %s", trainingPath));
    Dataset<T> train;
    Dataset<T> test;
    char separator;
    switch (inputFormat) {
        case SERIALIZED:
            //
            // Load Tribuo serialised datasets.
            logger.info("Deserialising dataset from " + trainingPath);
            try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(trainingPath.toFile())));
                 ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(testingPath.toFile())))) {
                // NOTE(review): Java-native deserialisation is only safe when these paths are trusted.
                @SuppressWarnings("unchecked")
                Dataset<T> deserTrain = (Dataset<T>) ois.readObject();
                if (minCount > 0) {
                    // Only this format reports the feature count before filtering.
                    logger.info("Found " + deserTrain.getFeatureIDMap().size() + " features");
                }
                train = removeRareFeatures(deserTrain);
                logTrainingSummary(train);
                @SuppressWarnings("unchecked")
                Dataset<T> deserTest = (Dataset<T>) oits.readObject();
                // Rebuild the test set against the (possibly filtered) training feature and output ID maps.
                test = new ImmutableDataset<>(deserTest, deserTest.getSourceProvenance(), deserTest.getOutputFactory(), train.getFeatureIDMap(), train.getOutputIDInfo(), true);
            } catch (ClassNotFoundException e) {
                throw new IllegalArgumentException("Unknown class in serialised files", e);
            }
            break;
        case LIBSVM:
            //
            // Load the libsvm text-based data format.
            LibSVMDataSource<T> trainSVMSource = new LibSVMDataSource<>(trainingPath, outputFactory);
            train = new MutableDataset<>(trainSVMSource);
            // Read the indexing scheme after the training data has been consumed, so the test
            // source is parsed consistently with the training source.
            boolean zeroIndexed = trainSVMSource.isZeroIndexed();
            int maxFeatureID = trainSVMSource.getMaxFeatureID();
            train = removeRareFeatures(train);
            logTrainingSummary(train);
            test = new ImmutableDataset<>(new LibSVMDataSource<>(testingPath, outputFactory, zeroIndexed, maxFeatureID), train.getFeatureIDMap(), train.getOutputIDInfo(), false);
            break;
        case TEXT:
            //
            // Using a simple Java break iterator to generate ngram features.
            TextFeatureExtractor<T> extractor;
            if (hashDim > 0) {
                extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), ngram, termCounting, hashDim));
            } else {
                extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), ngram, termCounting));
            }
            TextDataSource<T> trainSource = new SimpleTextDataSource<>(trainingPath, outputFactory, extractor);
            train = removeRareFeatures(new MutableDataset<>(trainSource));
            logTrainingSummary(train);
            TextDataSource<T> testSource = new SimpleTextDataSource<>(testingPath, outputFactory, extractor);
            test = new ImmutableDataset<>(testSource, train.getFeatureIDMap(), train.getOutputIDInfo(), false);
            break;
        case CSV:
            // Load the data using the simple CSV loader (minCount is not applied to this format).
            if (csvResponseName == null) {
                throw new IllegalArgumentException("Please supply a response column name");
            }
            separator = delimiter.value;
            CSVLoader<T> loader = new CSVLoader<>(separator, outputFactory);
            train = new MutableDataset<>(loader.loadDataSource(trainingPath, csvResponseName));
            logTrainingSummary(train);
            test = new MutableDataset<>(loader.loadDataSource(testingPath, csvResponseName));
            break;
        case COLUMNAR:
            // Load the data using the user supplied RowProcessor (minCount is not applied to this format).
            if (rowProcessor == null) {
                throw new IllegalArgumentException("Please supply a RowProcessor");
            }
            OutputFactory<?> rowOutputFactory = rowProcessor.getResponseProcessor().getOutputFactory();
            if (!rowOutputFactory.equals(outputFactory)) {
                throw new IllegalArgumentException("The RowProcessor doesn't use the same kind of OutputFactory as the one supplied. RowProcessor has " + rowOutputFactory.getClass().getSimpleName() + ", supplied " + outputFactory.getClass().getSimpleName());
            }
            // checked by the if statement above
            @SuppressWarnings("unchecked")
            RowProcessor<T> typedRowProcessor = (RowProcessor<T>) rowProcessor;
            separator = delimiter.value;
            train = new MutableDataset<>(new CSVDataSource<>(trainingPath, typedRowProcessor, true, separator, csvQuoteChar));
            logTrainingSummary(train);
            test = new MutableDataset<>(new CSVDataSource<>(testingPath, typedRowProcessor, true, separator, csvQuoteChar));
            break;
        default:
            throw new IllegalArgumentException("Unsupported input format " + inputFormat);
    }
    logger.info(String.format("Loaded %d testing examples", test.size()));
    if (scaleFeatures) {
        logger.info("Fitting feature scaling");
        TransformationMap map = new TransformationMap(Collections.singletonList(new LinearScalingTransformation()));
        TransformerMap transformers = train.createTransformers(map, scaleIncZeros);
        logger.info("Applying scaling to training dataset");
        train = transformers.transformDataset(train);
        logger.info("Applying scaling to testing dataset");
        test = transformers.transformDataset(test);
    }
    return new Pair<>(train, test);
}

/**
 * Removes features which occur fewer than {@code minCount} times, if {@code minCount} is positive.
 * @param train The training dataset.
 * @param <T> The dataset output type.
 * @return The filtered dataset, or {@code train} itself when {@code minCount <= 0}.
 */
private <T extends Output<T>> Dataset<T> removeRareFeatures(Dataset<T> train) {
    if (minCount <= 0) {
        return train;
    }
    logger.info("Removing features that occur fewer than " + minCount + " times.");
    return new MinimumCardinalityDataset<>(train, minCount);
}

/**
 * Logs the example count, outputs, feature count and response dimensionality of the training dataset.
 * @param train The training dataset.
 */
private void logTrainingSummary(Dataset<?> train) {
    logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
    logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
}
Use of org.tribuo.transform.TransformationMap in project Tribuo by Oracle.
The class Dataset, method createTransformers.
/**
 * Takes a {@link TransformationMap} and converts it into a {@link TransformerMap} by
 * observing all the values in this dataset.
 * <p>
 * Does not mutate the dataset, if you wish to apply the TransformerMap, use
 * {@link MutableDataset#transform} or {@link TransformerMap#transformDataset}.
 * <p>
 * TransformerMaps operate on feature values which are present, sparse values
 * are ignored and not transformed. If the zeros should be transformed, call
 * {@link MutableDataset#densify} on the datasets before applying a transformer.
 * See {@link org.tribuo.transform} for a more detailed discussion of densify and includeImplicitZeroFeatures.
 * <p>
 * For each feature, the feature specific transformations (if any) are fitted first, followed by the
 * global transformations. The dataset is iterated once per stage of the longest such chain; each pass
 * fits the next stage on values which have already been passed through the transformers fitted in
 * earlier passes. Feature specific entries whose transformation list is empty are ignored.
 * <p>
 * Throws {@link IllegalArgumentException} if the TransformationMap object has
 * regexes which apply to multiple features.
 * @param transformations The transformations to fit.
 * @param includeImplicitZeroFeatures Use the implicit zero feature values to construct the transformations.
 * @return A TransformerMap which can apply the transformations to a dataset.
 * @throws IllegalArgumentException If a feature name is matched by more than one feature specific regex.
 */
public TransformerMap createTransformers(TransformationMap transformations, boolean includeImplicitZeroFeatures) {
    ArrayList<String> featureNames = new ArrayList<>(getFeatureMap().keySet());

    // Validate the map by checking that no feature name is matched by more than one regex.
    logger.fine(String.format("Processing %d feature specific transforms", transformations.getFeatureTransformations().size()));
    Map<String, List<Transformation>> featureTransformations = new HashMap<>();
    for (Map.Entry<String, List<Transformation>> entry : transformations.getFeatureTransformations().entrySet()) {
        Pattern pattern = Pattern.compile(entry.getKey());
        for (String name : featureNames) {
            if (pattern.matcher(name).matches()) {
                List<Transformation> oldTransformations = featureTransformations.put(name, entry.getValue());
                // A previous mapping means an earlier regex already claimed this feature.
                if (oldTransformations != null) {
                    throw new IllegalArgumentException("Feature name '" + name + "' matches multiple regexes, at least one of which was '" + entry.getKey() + "'.");
                }
            }
        }
    }

    // Queue of statistics still to be fitted for each feature, in application order.
    Map<String, Queue<TransformStatistics>> featureStats = new HashMap<>();
    // sparseCount starts at the number of examples and is decremented for every observed
    // occurrence of the feature, leaving the number of implicit (sparse) zeros.
    Map<String, MutableLong> sparseCount = new HashMap<>();
    for (Map.Entry<String, List<Transformation>> entry : featureTransformations.entrySet()) {
        if (entry.getValue().isEmpty()) {
            // An empty queue would make the fitting loop below dereference a null TransformStatistics.
            continue;
        }
        Queue<TransformStatistics> stats = new LinkedList<>();
        for (Transformation t : entry.getValue()) {
            stats.add(t.createStats());
        }
        featureStats.put(entry.getKey(), stats);
        sparseCount.put(entry.getKey(), new MutableLong(data.size()));
    }

    if (!transformations.getGlobalTransformations().isEmpty()) {
        // Append all the global transformations after any feature specific ones.
        int ntransform = featureNames.size();
        logger.fine(String.format("Starting %,d global transformations", ntransform));
        int ndone = 0;
        for (String v : featureNames) {
            Queue<TransformStatistics> stats = featureStats.computeIfAbsent(v, (k) -> new LinkedList<>());
            for (Transformation t : transformations.getGlobalTransformations()) {
                stats.add(t.createStats());
            }
            // Sparse count initialised to the number of examples.
            sparseCount.putIfAbsent(v, new MutableLong(data.size()));
            ndone++;
            if (logger.isLoggable(Level.FINE) && ndone % 10000 == 0) {
                logger.fine(String.format("Completed %,d of %,d global transformations", ndone, ntransform));
            }
        }
    }

    Map<String, List<Transformer>> output = new LinkedHashMap<>();
    boolean initialisedSparseCounts = false;
    // Iterate through the dataset once per stage of the longest transformation queue.
    while (!featureStats.isEmpty()) {
        for (Example<T> example : data) {
            for (Feature f : example) {
                Queue<TransformStatistics> stats = featureStats.get(f.getName());
                if (stats != null) {
                    // Every feature is still present on the first pass, so counting then is sufficient.
                    if (!initialisedSparseCounts) {
                        sparseCount.get(f.getName()).decrement();
                    }
                    // Apply the transformers fitted in earlier passes, then observe the result.
                    List<Transformer> curTransformers = output.get(f.getName());
                    double fValue = TransformerMap.applyTransformerList(f.getValue(), curTransformers);
                    stats.peek().observeValue(fValue);
                }
            }
        }
        initialisedSparseCounts = true;

        // Emit the new transformers onto the end of the list in the output map.
        for (Map.Entry<String, Queue<TransformStatistics>> entry : featureStats.entrySet()) {
            TransformStatistics currentStats = entry.getValue().poll();
            if (includeImplicitZeroFeatures) {
                // Observe all the sparse feature values.
                int unobservedFeatures = sparseCount.get(entry.getKey()).intValue();
                currentStats.observeSparse(unobservedFeatures);
            }
            List<Transformer> transformerList = output.computeIfAbsent(entry.getKey(), (k) -> new ArrayList<>());
            transformerList.add(currentStats.generateTransformer());
        }
        // Drop fully fitted features, ensuring featureStats is eventually empty.
        featureStats.values().removeIf(Queue::isEmpty);
    }
    return new TransformerMap(output, getProvenance(), transformations.getProvenance());
}
Aggregations