Example 1 with TokenPipeline

Use of org.tribuo.data.text.impl.TokenPipeline in project tribuo by oracle.

From the class RowProcessorTest, method replaceNewlinesWithSpacesTest:

@Test
public void replaceNewlinesWithSpacesTest() {
    final Pattern BLANK_LINES = Pattern.compile("(\n[\\s-]*\n)+");
    final Function<CharSequence, CharSequence> newLiner = (CharSequence charSequence) -> {
        if (charSequence == null || charSequence.length() == 0) {
            return charSequence;
        }
        return BLANK_LINES.splitAsStream(charSequence).collect(Collectors.joining(" *\n\n"));
    };
    Tokenizer tokenizer = new MungingTokenizer(new BreakIteratorTokenizer(Locale.US), newLiner);
    TokenPipeline textPipeline = new TokenPipeline(tokenizer, 2, false);
    final Map<String, FieldProcessor> fieldProcessors = new HashMap<>();
    fieldProcessors.put("order_text", new TextFieldProcessor("order_text", textPipeline));
    MockResponseProcessor response = new MockResponseProcessor("Label");
    Map<String, String> row = new HashMap<>();
    row.put("order_text", "Jimmy\n\n\n\nHoffa");
    row.put("Label", "Sheep");
    RowProcessor<MockOutput> processor = new RowProcessor<>(Collections.emptyList(), null, response, fieldProcessors, Collections.emptyMap(), Collections.emptySet(), false);
    Example<MockOutput> example = processor.generateExample(row, true).get();
    // Check example is extracted correctly
    assertEquals(5, example.size());
    assertEquals("Sheep", example.getOutput().label);
    Iterator<Feature> featureIterator = example.iterator();
    Feature a = featureIterator.next();
    assertEquals("order_text@1-N=*", a.getName());
    assertEquals(1.0, a.getValue());
    a = featureIterator.next();
    assertEquals("order_text@1-N=Hoffa", a.getName());
    a = featureIterator.next();
    assertEquals("order_text@1-N=Jimmy", a.getName());
    a = featureIterator.next();
    assertEquals("order_text@2-N=*/Hoffa", a.getName());
    a = featureIterator.next();
    assertEquals("order_text@2-N=Jimmy/*", a.getName());
    assertFalse(featureIterator.hasNext());
    // The same input with replaceNewlinesWithSpaces=true (the default) produces different features
    processor = new RowProcessor<>(Collections.emptyList(), null, response, fieldProcessors, Collections.emptyMap(), Collections.emptySet(), true);
    example = processor.generateExample(row, true).get();
    // Check example is extracted correctly
    assertEquals(3, example.size());
    assertEquals("Sheep", example.getOutput().label);
    featureIterator = example.iterator();
    a = featureIterator.next();
    assertEquals("order_text@1-N=Hoffa", a.getName());
    assertEquals(1.0, a.getValue());
    a = featureIterator.next();
    assertEquals("order_text@1-N=Jimmy", a.getName());
    a = featureIterator.next();
    assertEquals("order_text@2-N=Jimmy/Hoffa", a.getName());
    assertFalse(featureIterator.hasNext());
}
Also used: Pattern (java.util.regex.Pattern), TextFieldProcessor (org.tribuo.data.columnar.processors.field.TextFieldProcessor), MockOutput (org.tribuo.test.MockOutput), HashMap (java.util.HashMap), Feature (org.tribuo.Feature), BreakIteratorTokenizer (org.tribuo.util.tokens.impl.BreakIteratorTokenizer), DoubleFieldProcessor (org.tribuo.data.columnar.processors.field.DoubleFieldProcessor), TokenPipeline (org.tribuo.data.text.impl.TokenPipeline), Tokenizer (org.tribuo.util.tokens.Tokenizer), Test (org.junit.jupiter.api.Test)
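For reference, a minimal standalone sketch (not part of the Tribuo tests) of the same pipeline configuration, showing where the feature names come from: with an empty tag the pipeline emits bare "1-N=..." and "2-N=..." names (compare Examples 2 and 3 below), so the "order_text@" prefix seen in the assertions above comes from the TextFieldProcessor wrapper.

import java.util.List;
import java.util.Locale;
import org.tribuo.Feature;
import org.tribuo.data.text.impl.TokenPipeline;
import org.tribuo.util.tokens.impl.BreakIteratorTokenizer;

public class BigramPipelineSketch {
    public static void main(String[] args) {
        // Same configuration as Example 1: bigrams, binary (non-counting) features.
        TokenPipeline pipeline = new TokenPipeline(new BreakIteratorTokenizer(Locale.US), 2, false);
        // An empty tag leaves feature names unprefixed; wrapping the pipeline in a
        // TextFieldProcessor is what prepends "fieldName@" to each name.
        List<Feature> features = pipeline.process("", "Jimmy Hoffa");
        for (Feature f : features) {
            // Expected: 1-N=Jimmy, 1-N=Hoffa, 2-N=Jimmy/Hoffa, each with value 1.0.
            System.out.println(f.getName() + " = " + f.getValue());
        }
    }
}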

Example 2 with TokenPipeline

Use of org.tribuo.data.text.impl.TokenPipeline in project tribuo by oracle.

From the class TextPipelineTest, method testTokenPipeline:

@Test
public void testTokenPipeline() {
    String input = "This is some input text.";
    TokenPipeline pipeline = new TokenPipeline(new BreakIteratorTokenizer(Locale.US), 2, true);
    List<Feature> featureList = pipeline.process("", input);
    // logger.log(Level.INFO,featureList.toString());
    assertTrue(featureList.contains(new Feature("1-N=This", 1.0)));
    assertTrue(featureList.contains(new Feature("1-N=is", 1.0)));
    assertTrue(featureList.contains(new Feature("1-N=some", 1.0)));
    assertTrue(featureList.contains(new Feature("1-N=input", 1.0)));
    assertTrue(featureList.contains(new Feature("1-N=text", 1.0)));
    assertTrue(featureList.contains(new Feature("2-N=This/is", 1.0)));
    assertTrue(featureList.contains(new Feature("2-N=is/some", 1.0)));
    assertTrue(featureList.contains(new Feature("2-N=some/input", 1.0)));
    assertTrue(featureList.contains(new Feature("2-N=input/text", 1.0)));
}
Also used: TokenPipeline (org.tribuo.data.text.impl.TokenPipeline), Feature (org.tribuo.Feature), BreakIteratorTokenizer (org.tribuo.util.tokens.impl.BreakIteratorTokenizer), Test (org.junit.jupiter.api.Test)

Example 3 with TokenPipeline

Use of org.tribuo.data.text.impl.TokenPipeline in project tribuo by oracle.

From the class TextPipelineTest, method testTokenPipelineTagging:

@Test
public void testTokenPipelineTagging() {
    String input = "This is some input text.";
    TokenPipeline pipeline = new TokenPipeline(new BreakIteratorTokenizer(Locale.US), 2, true);
    List<Feature> featureList = pipeline.process("Monkeys", input);
    // logger.log(Level.INFO,featureList.toString());
    assertTrue(featureList.contains(new Feature("Monkeys-1-N=This", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-1-N=is", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-1-N=some", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-1-N=input", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-1-N=text", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-2-N=This/is", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-2-N=is/some", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-2-N=some/input", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-2-N=input/text", 1.0)));
}
Also used: TokenPipeline (org.tribuo.data.text.impl.TokenPipeline), Feature (org.tribuo.Feature), BreakIteratorTokenizer (org.tribuo.util.tokens.impl.BreakIteratorTokenizer), Test (org.junit.jupiter.api.Test)
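Examples 2 and 3 differ only in the tag passed to process, which is hyphen-joined onto every feature name. Both set termCounting = true but never feed the pipeline a repeated token, so the flag's effect is invisible there. A hedged sketch of what it controls, assuming the flag behaves as its name suggests:

import java.util.Locale;
import org.tribuo.data.text.impl.TokenPipeline;
import org.tribuo.util.tokens.impl.BreakIteratorTokenizer;

public class TermCountingSketch {
    public static void main(String[] args) {
        // Unigram pipelines differing only in the termCounting flag.
        TokenPipeline counting = new TokenPipeline(new BreakIteratorTokenizer(Locale.US), 1, true);
        TokenPipeline binary = new TokenPipeline(new BreakIteratorTokenizer(Locale.US), 1, false);
        String text = "spam spam spam";
        // With counting, "1-N=spam" should carry the occurrence count (3.0) as its
        // value; with binary features it should stay at 1.0.
        System.out.println(counting.process("", text));
        System.out.println(binary.process("", text));
    }
}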

Example 4 with TokenPipeline

Use of org.tribuo.data.text.impl.TokenPipeline in project tribuo by oracle.

From the class DataOptions, method load:

/**
 * Loads the training and testing data from {@link #trainingPath} and {@link #testingPath}
 * according to the other parameters specified in this class.
 * @param outputFactory The output factory to use to process the inputs.
 * @param <T> The dataset output type.
 * @return A pair containing the training and testing datasets. The training dataset is element 'A' and the
 * testing dataset is element 'B'.
 * @throws IOException If the paths could not be loaded.
 */
public <T extends Output<T>> Pair<Dataset<T>, Dataset<T>> load(OutputFactory<T> outputFactory) throws IOException {
    logger.info(String.format("Loading data from %s", trainingPath));
    Dataset<T> train;
    Dataset<T> test;
    char separator;
    switch(inputFormat) {
        case SERIALIZED:
            // Load Tribuo serialised datasets.
            logger.info("Deserialising dataset from " + trainingPath);
            try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(trainingPath.toFile())));
                ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(testingPath.toFile())))) {
                @SuppressWarnings("unchecked") Dataset<T> tmp = (Dataset<T>) ois.readObject();
                train = tmp;
                if (minCount > 0) {
                    logger.info("Found " + train.getFeatureIDMap().size() + " features");
                    logger.info("Removing features that occur fewer than " + minCount + " times.");
                    train = new MinimumCardinalityDataset<>(train, minCount);
                }
                logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
                @SuppressWarnings("unchecked") Dataset<T> deserTest = (Dataset<T>) oits.readObject();
                test = new ImmutableDataset<>(deserTest, deserTest.getSourceProvenance(), deserTest.getOutputFactory(), train.getFeatureIDMap(), train.getOutputIDInfo(), true);
            } catch (ClassNotFoundException e) {
                throw new IllegalArgumentException("Unknown class in serialised files", e);
            }
            break;
        case LIBSVM:
            // Load the libsvm text-based data format.
            LibSVMDataSource<T> trainSVMSource = new LibSVMDataSource<>(trainingPath, outputFactory);
            train = new MutableDataset<>(trainSVMSource);
            boolean zeroIndexed = trainSVMSource.isZeroIndexed();
            int maxFeatureID = trainSVMSource.getMaxFeatureID();
            if (minCount > 0) {
                logger.info("Removing features that occur fewer than " + minCount + " times.");
                train = new MinimumCardinalityDataset<>(train, minCount);
            }
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            test = new ImmutableDataset<>(new LibSVMDataSource<>(testingPath, outputFactory, zeroIndexed, maxFeatureID), train.getFeatureIDMap(), train.getOutputIDInfo(), false);
            break;
        case TEXT:
            // Using a simple Java break iterator to generate ngram features.
            TextFeatureExtractor<T> extractor;
            if (hashDim > 0) {
                extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), ngram, termCounting, hashDim));
            } else {
                extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), ngram, termCounting));
            }
            TextDataSource<T> trainSource = new SimpleTextDataSource<>(trainingPath, outputFactory, extractor);
            train = new MutableDataset<>(trainSource);
            if (minCount > 0) {
                logger.info("Removing features that occur fewer than " + minCount + " times.");
                train = new MinimumCardinalityDataset<>(train, minCount);
            }
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            TextDataSource<T> testSource = new SimpleTextDataSource<>(testingPath, outputFactory, extractor);
            test = new ImmutableDataset<>(testSource, train.getFeatureIDMap(), train.getOutputIDInfo(), false);
            break;
        case CSV:
            // Load the data using the simple CSV loader
            if (csvResponseName == null) {
                throw new IllegalArgumentException("Please supply a response column name");
            }
            separator = delimiter.value;
            CSVLoader<T> loader = new CSVLoader<>(separator, outputFactory);
            train = new MutableDataset<>(loader.loadDataSource(trainingPath, csvResponseName));
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            test = new MutableDataset<>(loader.loadDataSource(testingPath, csvResponseName));
            break;
        case COLUMNAR:
            if (rowProcessor == null) {
                throw new IllegalArgumentException("Please supply a RowProcessor");
            }
            OutputFactory<?> rowOutputFactory = rowProcessor.getResponseProcessor().getOutputFactory();
            if (!rowOutputFactory.equals(outputFactory)) {
                throw new IllegalArgumentException("The RowProcessor doesn't use the same kind of OutputFactory as the one supplied. RowProcessor has " + rowOutputFactory.getClass().getSimpleName() + ", supplied " + outputFactory.getClass().getName());
            }
            // checked by the if statement above
            @SuppressWarnings("unchecked") RowProcessor<T> typedRowProcessor = (RowProcessor<T>) rowProcessor;
            separator = delimiter.value;
            train = new MutableDataset<>(new CSVDataSource<>(trainingPath, typedRowProcessor, true, separator, csvQuoteChar));
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            test = new MutableDataset<>(new CSVDataSource<>(testingPath, typedRowProcessor, true, separator, csvQuoteChar));
            break;
        default:
            throw new IllegalArgumentException("Unsupported input format " + inputFormat);
    }
    logger.info(String.format("Loaded %d testing examples", test.size()));
    if (scaleFeatures) {
        logger.info("Fitting feature scaling");
        TransformationMap map = new TransformationMap(Collections.singletonList(new LinearScalingTransformation()));
        TransformerMap transformers = train.createTransformers(map, scaleIncZeros);
        logger.info("Applying scaling to training dataset");
        train = transformers.transformDataset(train);
        logger.info("Applying scaling to testing dataset");
        test = transformers.transformDataset(test);
    }
    return new Pair<>(train, test);
}
Also used: TransformerMap (org.tribuo.transform.TransformerMap), CSVDataSource (org.tribuo.data.csv.CSVDataSource), SimpleTextDataSource (org.tribuo.data.text.impl.SimpleTextDataSource), TransformationMap (org.tribuo.transform.TransformationMap), LinearScalingTransformation (org.tribuo.transform.transformations.LinearScalingTransformation), BufferedInputStream (java.io.BufferedInputStream), LibSVMDataSource (org.tribuo.datasource.LibSVMDataSource), RowProcessor (org.tribuo.data.columnar.RowProcessor), Pair (com.oracle.labs.mlrg.olcut.util.Pair), CSVLoader (org.tribuo.data.csv.CSVLoader), ImmutableDataset (org.tribuo.ImmutableDataset), Dataset (org.tribuo.Dataset), MinimumCardinalityDataset (org.tribuo.dataset.MinimumCardinalityDataset), MutableDataset (org.tribuo.MutableDataset), FileInputStream (java.io.FileInputStream), BreakIteratorTokenizer (org.tribuo.util.tokens.impl.BreakIteratorTokenizer), TokenPipeline (org.tribuo.data.text.impl.TokenPipeline), ObjectInputStream (java.io.ObjectInputStream)
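The TEXT branch above chooses between two TokenPipeline constructors depending on hashDim. A minimal sketch of that choice in isolation; the dimension 1024 is an arbitrary illustrative value, not a Tribuo default:

import java.util.Locale;
import org.tribuo.data.text.impl.TokenPipeline;
import org.tribuo.util.tokens.impl.BreakIteratorTokenizer;

public class HashingChoiceSketch {
    public static void main(String[] args) {
        // Unhashed: one feature per distinct n-gram, so the feature space grows
        // with the vocabulary.
        TokenPipeline exact = new TokenPipeline(new BreakIteratorTokenizer(Locale.US), 2, false);
        // Hashed: the extra int argument bounds the feature space by hashing
        // n-grams into 1024 buckets, accepting collisions in exchange for a
        // fixed memory footprint.
        TokenPipeline hashed = new TokenPipeline(new BreakIteratorTokenizer(Locale.US), 2, false, 1024);
        System.out.println(exact.process("", "some sample input text"));
        System.out.println(hashed.process("", "some sample input text"));
    }
}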

Example 5 with TokenPipeline

Use of org.tribuo.data.text.impl.TokenPipeline in project tribuo by oracle.

From the class Test, method load:

/**
 * Loads in the model and the dataset from the options.
 * @param o The options.
 * @return The model and the dataset.
 * @throws IOException If either the model or dataset could not be read.
 */
// Suppress the unchecked cast warning for deserialising generically typed datasets.
@SuppressWarnings("unchecked")
public static Pair<Model<Label>, Dataset<Label>> load(ConfigurableTestOptions o) throws IOException {
    Path modelPath = o.modelPath;
    Path datasetPath = o.testingPath;
    logger.info(String.format("Loading model from %s", modelPath));
    Model<Label> model;
    try (ObjectInputStream mois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(modelPath.toFile())))) {
        model = (Model<Label>) mois.readObject();
        boolean valid = model.validate(Label.class);
        if (!valid) {
            throw new ClassCastException("Failed to cast deserialised Model to Model<Label>");
        }
    } catch (ClassNotFoundException e) {
        throw new IllegalArgumentException("Unknown class in serialised model", e);
    }
    logger.info(String.format("Loading data from %s", datasetPath));
    Dataset<Label> test;
    switch(o.inputFormat) {
        case SERIALIZED:
            // Load Tribuo serialised datasets.
            logger.info("Deserialising dataset from " + datasetPath);
            try (ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(datasetPath.toFile())))) {
                Dataset<Label> deserTest = (Dataset<Label>) oits.readObject();
                test = ImmutableDataset.copyDataset(deserTest, model.getFeatureIDMap(), model.getOutputIDInfo());
                logger.info(String.format("Loaded %d testing examples for %s", test.size(), test.getOutputs().toString()));
            } catch (ClassNotFoundException e) {
                throw new IllegalArgumentException("Unknown class in serialised dataset", e);
            }
            break;
        case LIBSVM:
            // Load the libsvm text-based data format.
            boolean zeroIndexed = o.zeroIndexed;
            int maxFeatureID = model.getFeatureIDMap().size() - 1;
            LibSVMDataSource<Label> testSVMSource = new LibSVMDataSource<>(datasetPath, new LabelFactory(), zeroIndexed, maxFeatureID);
            test = new ImmutableDataset<>(testSVMSource, model, true);
            logger.info(String.format("Loaded %d training examples for %s", test.size(), test.getOutputs().toString()));
            break;
        case TEXT:
            // Using a simple Java break iterator to generate ngram features.
            TextFeatureExtractor<Label> extractor;
            if (o.hashDim > 0) {
                extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), o.ngram, o.termCounting, o.hashDim));
            } else {
                extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), o.ngram, o.termCounting));
            }
            TextDataSource<Label> testSource = new SimpleTextDataSource<>(datasetPath, new LabelFactory(), extractor);
            test = new ImmutableDataset<>(testSource, model.getFeatureIDMap(), model.getOutputIDInfo(), true);
            logger.info(String.format("Loaded %d testing examples for %s", test.size(), test.getOutputs().toString()));
            break;
        case CSV:
            // Load the data using the simple CSV loader
            if (o.csvResponseName == null) {
                throw new IllegalArgumentException("Please supply a response column name");
            }
            CSVLoader<Label> loader = new CSVLoader<>(new LabelFactory());
            test = new ImmutableDataset<>(loader.loadDataSource(datasetPath, o.csvResponseName), model.getFeatureIDMap(), model.getOutputIDInfo(), true);
            logger.info(String.format("Loaded %d testing examples for %s", test.size(), test.getOutputs().toString()));
            break;
        default:
            throw new IllegalArgumentException("Unsupported input format " + o.inputFormat);
    }
    return new Pair<>(model, test);
}
Also used: Label (org.tribuo.classification.Label), SimpleTextDataSource (org.tribuo.data.text.impl.SimpleTextDataSource), BufferedInputStream (java.io.BufferedInputStream), LibSVMDataSource (org.tribuo.datasource.LibSVMDataSource), Pair (com.oracle.labs.mlrg.olcut.util.Pair), Path (java.nio.file.Path), CSVLoader (org.tribuo.data.csv.CSVLoader), ImmutableDataset (org.tribuo.ImmutableDataset), Dataset (org.tribuo.Dataset), FileInputStream (java.io.FileInputStream), BreakIteratorTokenizer (org.tribuo.util.tokens.impl.BreakIteratorTokenizer), LabelFactory (org.tribuo.classification.LabelFactory), TokenPipeline (org.tribuo.data.text.impl.TokenPipeline), ObjectInputStream (java.io.ObjectInputStream)
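A hedged sketch of what a caller might do with the pair returned by Test.load. LabelEvaluator is Tribuo's standard classification evaluator; its use here is an illustration, not part of Test.load itself:

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.classification.Label;
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.classification.evaluation.LabelEvaluator;

public class EvaluationSketch {
    public static void evaluate(Pair<Model<Label>, Dataset<Label>> loaded) {
        Model<Label> model = loaded.getA();
        Dataset<Label> test = loaded.getB();
        // Score the deserialised model on the loaded test set and print the summary.
        LabelEvaluation evaluation = new LabelEvaluator().evaluate(model, test);
        System.out.println(evaluation.toString());
    }
}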

Aggregations

TokenPipeline (org.tribuo.data.text.impl.TokenPipeline): 5 uses
BreakIteratorTokenizer (org.tribuo.util.tokens.impl.BreakIteratorTokenizer): 5 uses
Test (org.junit.jupiter.api.Test): 3 uses
Feature (org.tribuo.Feature): 3 uses
Pair (com.oracle.labs.mlrg.olcut.util.Pair): 2 uses
BufferedInputStream (java.io.BufferedInputStream): 2 uses
FileInputStream (java.io.FileInputStream): 2 uses
ObjectInputStream (java.io.ObjectInputStream): 2 uses
Dataset (org.tribuo.Dataset): 2 uses
ImmutableDataset (org.tribuo.ImmutableDataset): 2 uses
CSVLoader (org.tribuo.data.csv.CSVLoader): 2 uses
SimpleTextDataSource (org.tribuo.data.text.impl.SimpleTextDataSource): 2 uses
LibSVMDataSource (org.tribuo.datasource.LibSVMDataSource): 2 uses
Path (java.nio.file.Path): 1 use
HashMap (java.util.HashMap): 1 use
Pattern (java.util.regex.Pattern): 1 use
MutableDataset (org.tribuo.MutableDataset): 1 use
Label (org.tribuo.classification.Label): 1 use
LabelFactory (org.tribuo.classification.LabelFactory): 1 use
RowProcessor (org.tribuo.data.columnar.RowProcessor): 1 use