Use of org.tribuo.data.text.impl.TokenPipeline in project tribuo by oracle.
Class RowProcessorTest, method replaceNewlinesWithSpacesTest.
@Test
public void replaceNewlinesWithSpacesTest() {
    final Pattern BLANK_LINES = Pattern.compile("(\n[\\s-]*\n)+");
    final Function<CharSequence, CharSequence> newLiner = (CharSequence charSequence) -> {
        if (charSequence == null || charSequence.length() == 0) {
            return charSequence;
        }
        return BLANK_LINES.splitAsStream(charSequence).collect(Collectors.joining(" *\n\n"));
    };
    Tokenizer tokenizer = new MungingTokenizer(new BreakIteratorTokenizer(Locale.US), newLiner);
    TokenPipeline textPipeline = new TokenPipeline(tokenizer, 2, false);
    final Map<String, FieldProcessor> fieldProcessors = new HashMap<>();
    fieldProcessors.put("order_text", new TextFieldProcessor("order_text", textPipeline));
    MockResponseProcessor response = new MockResponseProcessor("Label");
    Map<String, String> row = new HashMap<>();
    row.put("order_text", "Jimmy\n\n\n\nHoffa");
    row.put("Label", "Sheep");
    RowProcessor<MockOutput> processor = new RowProcessor<>(Collections.emptyList(), null, response, fieldProcessors, Collections.emptyMap(), Collections.emptySet(), false);
    Example<MockOutput> example = processor.generateExample(row, true).get();
    // Check example is extracted correctly
    assertEquals(5, example.size());
    assertEquals("Sheep", example.getOutput().label);
    Iterator<Feature> featureIterator = example.iterator();
    Feature a = featureIterator.next();
    assertEquals("order_text@1-N=*", a.getName());
    assertEquals(1.0, a.getValue());
    a = featureIterator.next();
    assertEquals("order_text@1-N=Hoffa", a.getName());
    a = featureIterator.next();
    assertEquals("order_text@1-N=Jimmy", a.getName());
    a = featureIterator.next();
    assertEquals("order_text@2-N=*/Hoffa", a.getName());
    a = featureIterator.next();
    assertEquals("order_text@2-N=Jimmy/*", a.getName());
    assertFalse(featureIterator.hasNext());
    // Same input with replaceNewlinesWithSpaces=true (the default) produces different features.
    processor = new RowProcessor<>(Collections.emptyList(), null, response, fieldProcessors, Collections.emptyMap(), Collections.emptySet(), true);
    example = processor.generateExample(row, true).get();
    // Check example is extracted correctly
    assertEquals(3, example.size());
    assertEquals("Sheep", example.getOutput().label);
    featureIterator = example.iterator();
    a = featureIterator.next();
    assertEquals("order_text@1-N=Hoffa", a.getName());
    assertEquals(1.0, a.getValue());
    a = featureIterator.next();
    assertEquals("order_text@1-N=Jimmy", a.getName());
    a = featureIterator.next();
    assertEquals("order_text@2-N=Jimmy/Hoffa", a.getName());
    assertFalse(featureIterator.hasNext());
}
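The "*" features asserted above come from the newLiner function: it splits the text on runs of blank (or dash-only) lines and rejoins the pieces with " *\n\n", so the tokenizer sees an extra "*" token between paragraphs. Below is a minimal sketch of that transform on its own, using the same pattern and joining string as the test; the class name and main method are illustrative scaffolding, not part of the test.

import java.util.regex.Pattern;
import java.util.stream.Collectors;

public final class BlankLineMarker {

    // Same pattern and joining string as the test above.
    private static final Pattern BLANK_LINES = Pattern.compile("(\n[\\s-]*\n)+");

    public static CharSequence mark(CharSequence text) {
        if (text == null || text.length() == 0) {
            return text;
        }
        // Each run of blank lines collapses to a single " *" marker followed by a paragraph break.
        return BLANK_LINES.splitAsStream(text).collect(Collectors.joining(" *\n\n"));
    }

    public static void main(String[] args) {
        // "Jimmy\n\n\n\nHoffa" becomes "Jimmy *\n\nHoffa", which the test's tokenizer turns into
        // the tokens Jimmy, *, Hoffa and hence the */Hoffa and Jimmy/* bigrams seen above.
        System.out.println(mark("Jimmy\n\n\n\nHoffa"));
    }
}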
Use of org.tribuo.data.text.impl.TokenPipeline in project tribuo by oracle.
Class TextPipelineTest, method testTokenPipeline.
@Test
public void testTokenPipeline() {
    String input = "This is some input text.";
    TokenPipeline pipeline = new TokenPipeline(new BreakIteratorTokenizer(Locale.US), 2, true);
    List<Feature> featureList = pipeline.process("", input);
    // logger.log(Level.INFO,featureList.toString());
    assertTrue(featureList.contains(new Feature("1-N=This", 1.0)));
    assertTrue(featureList.contains(new Feature("1-N=is", 1.0)));
    assertTrue(featureList.contains(new Feature("1-N=some", 1.0)));
    assertTrue(featureList.contains(new Feature("1-N=input", 1.0)));
    assertTrue(featureList.contains(new Feature("1-N=text", 1.0)));
    assertTrue(featureList.contains(new Feature("2-N=This/is", 1.0)));
    assertTrue(featureList.contains(new Feature("2-N=is/some", 1.0)));
    assertTrue(featureList.contains(new Feature("2-N=some/input", 1.0)));
    assertTrue(featureList.contains(new Feature("2-N=input/text", 1.0)));
}
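The same pipeline can be exercised outside JUnit; the sketch below just prints the generated features. The import paths are assumed to follow the usual Tribuo packages, and the surrounding class and main method are illustrative scaffolding rather than project code.

import java.util.List;
import java.util.Locale;
import org.tribuo.Feature;
import org.tribuo.data.text.impl.TokenPipeline;
import org.tribuo.util.tokens.impl.BreakIteratorTokenizer;

public final class TokenPipelineDemo {
    public static void main(String[] args) {
        // Unigrams and bigrams (ngram = 2) with term counting enabled, as in the test above.
        TokenPipeline pipeline = new TokenPipeline(new BreakIteratorTokenizer(Locale.US), 2, true);
        List<Feature> featureList = pipeline.process("", "This is some input text.");
        for (Feature f : featureList) {
            // Prints names like "1-N=This" and "2-N=This/is" with their values.
            System.out.println(f.getName() + " = " + f.getValue());
        }
    }
}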
Use of org.tribuo.data.text.impl.TokenPipeline in project tribuo by oracle.
Class TextPipelineTest, method testTokenPipelineTagging.
@Test
public void testTokenPipelineTagging() {
    String input = "This is some input text.";
    TokenPipeline pipeline = new TokenPipeline(new BreakIteratorTokenizer(Locale.US), 2, true);
    List<Feature> featureList = pipeline.process("Monkeys", input);
    // logger.log(Level.INFO,featureList.toString());
    assertTrue(featureList.contains(new Feature("Monkeys-1-N=This", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-1-N=is", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-1-N=some", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-1-N=input", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-1-N=text", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-2-N=This/is", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-2-N=is/some", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-2-N=some/input", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-2-N=input/text", 1.0)));
}
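The only difference from the previous test is the non-empty tag, which is prepended to every feature name. That matters when one pipeline processes several text fields: tagging each call with the field name keeps the feature namespaces apart. A minimal sketch under that assumption, reusing the imports from the sketch above; the field names and strings are made up.

// One pipeline reused across two fields; the tag keeps the namespaces distinct,
// producing e.g. "subject-1-N=Hello" and "body-1-N=Hello" as separate features.
TokenPipeline pipeline = new TokenPipeline(new BreakIteratorTokenizer(Locale.US), 2, true);
List<Feature> subjectFeatures = pipeline.process("subject", "Hello world");
List<Feature> bodyFeatures = pipeline.process("body", "Hello again world");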
Use of org.tribuo.data.text.impl.TokenPipeline in project tribuo by oracle.
Class DataOptions, method load.
/**
 * Loads the training and testing data from {@link #trainingPath} and {@link #testingPath}
 * according to the other parameters specified in this class.
 * @param outputFactory The output factory to use to process the inputs.
 * @param <T> The dataset output type.
 * @return A pair containing the training and testing datasets. The training dataset is element 'A' and the
 * testing dataset is element 'B'.
 * @throws IOException If the paths could not be loaded.
 */
public <T extends Output<T>> Pair<Dataset<T>, Dataset<T>> load(OutputFactory<T> outputFactory) throws IOException {
    logger.info(String.format("Loading data from %s", trainingPath));
    Dataset<T> train;
    Dataset<T> test;
    char separator;
    switch (inputFormat) {
        case SERIALIZED:
            //
            // Load Tribuo serialised datasets.
            logger.info("Deserialising dataset from " + trainingPath);
            try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(trainingPath.toFile())));
                 ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(testingPath.toFile())))) {
                @SuppressWarnings("unchecked") Dataset<T> tmp = (Dataset<T>) ois.readObject();
                train = tmp;
                if (minCount > 0) {
                    logger.info("Found " + train.getFeatureIDMap().size() + " features");
                    logger.info("Removing features that occur fewer than " + minCount + " times.");
                    train = new MinimumCardinalityDataset<>(train, minCount);
                }
                logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
                @SuppressWarnings("unchecked") Dataset<T> deserTest = (Dataset<T>) oits.readObject();
                test = new ImmutableDataset<>(deserTest, deserTest.getSourceProvenance(), deserTest.getOutputFactory(), train.getFeatureIDMap(), train.getOutputIDInfo(), true);
            } catch (ClassNotFoundException e) {
                throw new IllegalArgumentException("Unknown class in serialised files", e);
            }
            break;
        case LIBSVM:
            //
            // Load the libsvm text-based data format.
            LibSVMDataSource<T> trainSVMSource = new LibSVMDataSource<>(trainingPath, outputFactory);
            train = new MutableDataset<>(trainSVMSource);
            boolean zeroIndexed = trainSVMSource.isZeroIndexed();
            int maxFeatureID = trainSVMSource.getMaxFeatureID();
            if (minCount > 0) {
                logger.info("Removing features that occur fewer than " + minCount + " times.");
                train = new MinimumCardinalityDataset<>(train, minCount);
            }
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            test = new ImmutableDataset<>(new LibSVMDataSource<>(testingPath, outputFactory, zeroIndexed, maxFeatureID), train.getFeatureIDMap(), train.getOutputIDInfo(), false);
            break;
        case TEXT:
            //
            // Using a simple Java break iterator to generate ngram features.
            TextFeatureExtractor<T> extractor;
            if (hashDim > 0) {
                extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), ngram, termCounting, hashDim));
            } else {
                extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), ngram, termCounting));
            }
            TextDataSource<T> trainSource = new SimpleTextDataSource<>(trainingPath, outputFactory, extractor);
            train = new MutableDataset<>(trainSource);
            if (minCount > 0) {
                logger.info("Removing features that occur fewer than " + minCount + " times.");
                train = new MinimumCardinalityDataset<>(train, minCount);
            }
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            TextDataSource<T> testSource = new SimpleTextDataSource<>(testingPath, outputFactory, extractor);
            test = new ImmutableDataset<>(testSource, train.getFeatureIDMap(), train.getOutputIDInfo(), false);
            break;
        case CSV:
            // Load the data using the simple CSV loader
            if (csvResponseName == null) {
                throw new IllegalArgumentException("Please supply a response column name");
            }
            separator = delimiter.value;
            CSVLoader<T> loader = new CSVLoader<>(separator, outputFactory);
            train = new MutableDataset<>(loader.loadDataSource(trainingPath, csvResponseName));
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            test = new MutableDataset<>(loader.loadDataSource(testingPath, csvResponseName));
            break;
        case COLUMNAR:
            if (rowProcessor == null) {
                throw new IllegalArgumentException("Please supply a RowProcessor");
            }
            OutputFactory<?> rowOutputFactory = rowProcessor.getResponseProcessor().getOutputFactory();
            if (!rowOutputFactory.equals(outputFactory)) {
                throw new IllegalArgumentException("The RowProcessor doesn't use the same kind of OutputFactory as the one supplied. RowProcessor has " + rowOutputFactory.getClass().getSimpleName() + ", supplied " + outputFactory.getClass().getName());
            }
            // checked by the if statement above
            @SuppressWarnings("unchecked") RowProcessor<T> typedRowProcessor = (RowProcessor<T>) rowProcessor;
            separator = delimiter.value;
            train = new MutableDataset<>(new CSVDataSource<>(trainingPath, typedRowProcessor, true, separator, csvQuoteChar));
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            test = new MutableDataset<>(new CSVDataSource<>(testingPath, typedRowProcessor, true, separator, csvQuoteChar));
            break;
        default:
            throw new IllegalArgumentException("Unsupported input format " + inputFormat);
    }
    logger.info(String.format("Loaded %d testing examples", test.size()));
    if (scaleFeatures) {
        logger.info("Fitting feature scaling");
        TransformationMap map = new TransformationMap(Collections.singletonList(new LinearScalingTransformation()));
        TransformerMap transformers = train.createTransformers(map, scaleIncZeros);
        logger.info("Applying scaling to training dataset");
        train = transformers.transformDataset(train);
        logger.info("Applying scaling to testing dataset");
        test = transformers.transformDataset(test);
    }
    return new Pair<>(train, test);
}
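Only the TEXT branch of this method touches TokenPipeline. Below is a stripped-down sketch of just that path, assuming a classification task (LabelFactory), bigram features without term counting or hashing, and hypothetical file paths in the plain format SimpleTextDataSource expects; the import paths are assumed to follow the usual Tribuo packages.

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Locale;
import org.tribuo.ImmutableDataset;
import org.tribuo.MutableDataset;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.data.text.TextDataSource;
import org.tribuo.data.text.TextFeatureExtractor;
import org.tribuo.data.text.impl.SimpleTextDataSource;
import org.tribuo.data.text.impl.TextFeatureExtractorImpl;
import org.tribuo.data.text.impl.TokenPipeline;
import org.tribuo.util.tokens.impl.BreakIteratorTokenizer;

public final class TextLoadSketch {
    public static void main(String[] args) throws IOException {
        // Hypothetical paths; substitute real files.
        Path trainingPath = Paths.get("train.txt");
        Path testingPath = Paths.get("test.txt");

        // Bigram pipeline without term counting, mirroring the ngram/termCounting options above.
        // DataOptions.load uses the four-argument constructor (..., hashDim) instead when hashDim > 0.
        TextFeatureExtractor<Label> extractor =
            new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), 2, false));

        TextDataSource<Label> trainSource = new SimpleTextDataSource<>(trainingPath, new LabelFactory(), extractor);
        MutableDataset<Label> train = new MutableDataset<>(trainSource);

        // The test set shares the training feature and output maps, as in DataOptions.load.
        TextDataSource<Label> testSource = new SimpleTextDataSource<>(testingPath, new LabelFactory(), extractor);
        ImmutableDataset<Label> test =
            new ImmutableDataset<>(testSource, train.getFeatureIDMap(), train.getOutputIDInfo(), false);

        System.out.println("train=" + train.size() + ", test=" + test.size());
    }
}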
Use of org.tribuo.data.text.impl.TokenPipeline in project tribuo by oracle.
Class Test, method load.
/**
 * Loads in the model and the dataset from the options.
 * @param o The options.
 * @return The model and the dataset.
 * @throws IOException If either the model or dataset could not be read.
 */
// Deserialising generically typed datasets.
@SuppressWarnings("unchecked")
public static Pair<Model<Label>, Dataset<Label>> load(ConfigurableTestOptions o) throws IOException {
    Path modelPath = o.modelPath;
    Path datasetPath = o.testingPath;
    logger.info(String.format("Loading model from %s", modelPath));
    Model<Label> model;
    try (ObjectInputStream mois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(modelPath.toFile())))) {
        model = (Model<Label>) mois.readObject();
        boolean valid = model.validate(Label.class);
        if (!valid) {
            throw new ClassCastException("Failed to cast deserialised Model to Model<Label>");
        }
    } catch (ClassNotFoundException e) {
        throw new IllegalArgumentException("Unknown class in serialised model", e);
    }
    logger.info(String.format("Loading data from %s", datasetPath));
    Dataset<Label> test;
    switch (o.inputFormat) {
        case SERIALIZED:
            //
            // Load Tribuo serialised datasets.
            logger.info("Deserialising dataset from " + datasetPath);
            try (ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(datasetPath.toFile())))) {
                Dataset<Label> deserTest = (Dataset<Label>) oits.readObject();
                test = ImmutableDataset.copyDataset(deserTest, model.getFeatureIDMap(), model.getOutputIDInfo());
                logger.info(String.format("Loaded %d testing examples for %s", test.size(), test.getOutputs().toString()));
            } catch (ClassNotFoundException e) {
                throw new IllegalArgumentException("Unknown class in serialised dataset", e);
            }
            break;
        case LIBSVM:
            //
            // Load the libsvm text-based data format.
            boolean zeroIndexed = o.zeroIndexed;
            int maxFeatureID = model.getFeatureIDMap().size() - 1;
            LibSVMDataSource<Label> testSVMSource = new LibSVMDataSource<>(datasetPath, new LabelFactory(), zeroIndexed, maxFeatureID);
            test = new ImmutableDataset<>(testSVMSource, model, true);
logger.info(String.format("Loaded %d training examples for %s", test.size(), test.getOutputs().toString()));
            break;
        case TEXT:
            //
            // Using a simple Java break iterator to generate ngram features.
            TextFeatureExtractor<Label> extractor;
            if (o.hashDim > 0) {
                extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), o.ngram, o.termCounting, o.hashDim));
            } else {
                extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), o.ngram, o.termCounting));
            }
            TextDataSource<Label> testSource = new SimpleTextDataSource<>(datasetPath, new LabelFactory(), extractor);
            test = new ImmutableDataset<>(testSource, model.getFeatureIDMap(), model.getOutputIDInfo(), true);
            logger.info(String.format("Loaded %d testing examples for %s", test.size(), test.getOutputs().toString()));
            break;
        case CSV:
            // Load the data using the simple CSV loader
            if (o.csvResponseName == null) {
                throw new IllegalArgumentException("Please supply a response column name");
            }
            CSVLoader<Label> loader = new CSVLoader<>(new LabelFactory());
            test = new ImmutableDataset<>(loader.loadDataSource(datasetPath, o.csvResponseName), model.getFeatureIDMap(), model.getOutputIDInfo(), true);
            logger.info(String.format("Loaded %d testing examples for %s", test.size(), test.getOutputs().toString()));
            break;
        default:
            throw new IllegalArgumentException("Unsupported input format " + o.inputFormat);
    }
    return new Pair<>(model, test);
}
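Test.load only returns the model and dataset; evaluation happens in the caller. Below is a minimal follow-up sketch, assuming Tribuo's standard classification evaluator (LabelEvaluator) and an already-parsed ConfigurableTestOptions; the wrapper class and method are illustrative scaffolding.

import java.io.IOException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.classification.Label;
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.classification.evaluation.LabelEvaluator;

public final class EvaluateSketch {
    // 'o' is assumed to be populated elsewhere (e.g. by the options parser).
    public static void evaluate(ConfigurableTestOptions o) throws IOException {
        Pair<Model<Label>, Dataset<Label>> pair = Test.load(o);
        // Evaluate the deserialised model on the freshly loaded test set and print the summary.
        LabelEvaluation evaluation = new LabelEvaluator().evaluate(pair.getA(), pair.getB());
        System.out.println(evaluation);
    }
}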