
Example 1 with BasicPipeline

Use of org.tribuo.data.text.impl.BasicPipeline in project tribuo by oracle.

From the class TestXGBoost, method loadDataset.

private Dataset<Label> loadDataset(XGBoostModel<Label> model, Path path) throws IOException {
    TextFeatureExtractor<Label> extractor = new TextFeatureExtractorImpl<>(new BasicPipeline(new BreakIteratorTokenizer(Locale.US), 2));
    TextDataSource<Label> src = new SimpleTextDataSource<>(path, new LabelFactory(), extractor);
    return new ImmutableDataset<>(src, model.getFeatureIDMap(), model.getOutputIDInfo(), false);
}
Also used : LabelFactory(org.tribuo.classification.LabelFactory) TextFeatureExtractorImpl(org.tribuo.data.text.impl.TextFeatureExtractorImpl) BasicPipeline(org.tribuo.data.text.impl.BasicPipeline) Label(org.tribuo.classification.Label) ImmutableDataset(org.tribuo.ImmutableDataset) BreakIteratorTokenizer(org.tribuo.util.tokens.impl.BreakIteratorTokenizer) SimpleTextDataSource(org.tribuo.data.text.impl.SimpleTextDataSource)
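Passing `model.getFeatureIDMap()` and `model.getOutputIDInfo()` to `ImmutableDataset` keeps the loaded data in the same feature and output ID space as the trained model. As we understand it, features the model never saw during training are dropped rather than added. The sketch below illustrates that idea with plain Java collections only. `FeatureDomainSketch` and `toSparse` are hypothetical names, this is not Tribuo's implementation, and summing repeated occurrences of a feature is an assumption made for illustration.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: project extracted feature names onto a fixed,
// training-time ID space. Names missing from the ID map are skipped.
class FeatureDomainSketch {
    static Map<Integer, Double> toSparse(List<String> featureNames, Map<String, Integer> featureIds) {
        Map<Integer, Double> sparse = new LinkedHashMap<>();
        for (String name : featureNames) {
            Integer id = featureIds.get(name);
            if (id == null) {
                continue; // unseen at training time: drop it
            }
            sparse.merge(id, 1.0, Double::sum); // repeated occurrences are summed (assumption)
        }
        return sparse;
    }

    public static void main(String[] args) {
        Map<String, Integer> ids = Map.of("text-1-N=good", 0, "text-1-N=bad", 1);
        List<String> extracted = List.of("text-1-N=good", "text-1-N=unseen", "text-1-N=good");
        System.out.println(toSparse(extracted, ids)); // {0=2.0}
    }
}
```

The key design point is that the ID map is fixed before evaluation data is read, so a test set can never grow the model's feature space.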

Example 2 with BasicPipeline

Use of org.tribuo.data.text.impl.BasicPipeline in project tribuo by oracle.

From the class JsonDataSourceTest, method buildRowProcessor.

private static RowProcessor<MockOutput> buildRowProcessor() {
    Map<String, FieldProcessor> fieldProcessors = new HashMap<>();
    fieldProcessors.put("height", new DoubleFieldProcessor("height"));
    fieldProcessors.put("description", new TextFieldProcessor("description", new BasicPipeline(new BreakIteratorTokenizer(Locale.US), 2)));
    fieldProcessors.put("transport", new IdentityProcessor("transport"));
    Map<String, FieldProcessor> regexMappingProcessors = new HashMap<>();
    regexMappingProcessors.put("extra.*", new DoubleFieldProcessor("regex"));
    ResponseProcessor<MockOutput> responseProcessor = new FieldResponseProcessor<>("disposition", "UNK", new MockOutputFactory());
    List<FieldExtractor<?>> metadataExtractors = new ArrayList<>();
    metadataExtractors.add(new IntExtractor("id"));
    metadataExtractors.add(new DateExtractor("timestamp", "timestamp", "dd/MM/yyyy HH:mm"));
    return new RowProcessor<>(metadataExtractors, null, responseProcessor, fieldProcessors, regexMappingProcessors, Collections.emptySet());
}
Also used : TextFieldProcessor(org.tribuo.data.columnar.processors.field.TextFieldProcessor) IntExtractor(org.tribuo.data.columnar.extractors.IntExtractor) DateExtractor(org.tribuo.data.columnar.extractors.DateExtractor) MockOutput(org.tribuo.test.MockOutput) HashMap(java.util.HashMap) MockOutputFactory(org.tribuo.test.MockOutputFactory) BasicPipeline(org.tribuo.data.text.impl.BasicPipeline) DoubleFieldProcessor(org.tribuo.data.columnar.processors.field.DoubleFieldProcessor) FieldResponseProcessor(org.tribuo.data.columnar.processors.response.FieldResponseProcessor) ArrayList(java.util.ArrayList) FieldProcessor(org.tribuo.data.columnar.FieldProcessor) BreakIteratorTokenizer(org.tribuo.util.tokens.impl.BreakIteratorTokenizer) FieldExtractor(org.tribuo.data.columnar.FieldExtractor) IdentityProcessor(org.tribuo.data.columnar.processors.field.IdentityProcessor) RowProcessor(org.tribuo.data.columnar.RowProcessor)
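In `buildRowProcessor`, the `regexMappingProcessors` map is keyed by a regular expression (`"extra.*"`) rather than by a column name, and its `DoubleFieldProcessor` is named `"regex"` as a placeholder. Our reading is that `RowProcessor` expands each pattern against the actual column names and gives every matching column its own copy of the processor. The sketch below shows only that expansion step, using `java.util.regex`. `RegexExpansionSketch`, `expand`, and the column names `extra_a`/`extra_b` are hypothetical, and whether Tribuo uses full-string matching (`matches()`) as done here is an assumption.

```java
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.regex.Pattern;

// Hypothetical sketch of regex-based field mapping: each pattern is tested
// against every column name, and each match is bound to a per-column
// processor label (standing in for a copy of the template FieldProcessor).
class RegexExpansionSketch {
    static Map<String, String> expand(Map<String, String> regexToProcessor, List<String> columns) {
        Map<String, String> expanded = new TreeMap<>(); // sorted for stable output
        for (Map.Entry<String, String> e : regexToProcessor.entrySet()) {
            Pattern p = Pattern.compile(e.getKey());
            for (String column : columns) {
                if (p.matcher(column).matches()) { // full-string match (assumption)
                    expanded.put(column, e.getValue() + "[" + column + "]");
                }
            }
        }
        return expanded;
    }

    public static void main(String[] args) {
        List<String> columns = List.of("height", "description", "transport", "extra_a", "extra_b");
        System.out.println(expand(Map.of("extra.*", "DoubleFieldProcessor"), columns));
        // {extra_a=DoubleFieldProcessor[extra_a], extra_b=DoubleFieldProcessor[extra_b]}
    }
}
```

Mapping by pattern lets a schema with an open-ended family of columns (here, anything starting with `extra`) share one processor definition.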

Example 3 with BasicPipeline

Use of org.tribuo.data.text.impl.BasicPipeline in project tribuo by oracle.

From the class LIMEColumnarTest, method generateBinarisedDataset.

private Pair<RowProcessor<Label>, Dataset<Label>> generateBinarisedDataset() throws URISyntaxException {
    LabelFactory labelFactory = new LabelFactory();
    ResponseProcessor<Label> responseProcessor = new FieldResponseProcessor<>("Response", "N", labelFactory);
    Map<String, FieldProcessor> fieldProcessors = new HashMap<>();
    fieldProcessors.put("A", new IdentityProcessor("A"));
    fieldProcessors.put("B", new DoubleFieldProcessor("B"));
    fieldProcessors.put("C", new DoubleFieldProcessor("C"));
    fieldProcessors.put("D", new IdentityProcessor("D"));
    // `tokenizer` is defined elsewhere in LIMEColumnarTest (not shown in this excerpt)
    fieldProcessors.put("TextField", new TextFieldProcessor("TextField", new BasicPipeline(tokenizer, 2)));
    RowProcessor<Label> rp = new RowProcessor<>(responseProcessor, fieldProcessors);
    CSVDataSource<Label> source = new CSVDataSource<>(LIMEColumnarTest.class.getResource("/org/tribuo/classification/explanations/lime/test-columnar.csv").toURI(), rp, true);
    Dataset<Label> dataset = new MutableDataset<>(source);
    return new Pair<>(rp, dataset);
}
Also used : TextFieldProcessor(org.tribuo.data.columnar.processors.field.TextFieldProcessor) HashMap(java.util.HashMap) BasicPipeline(org.tribuo.data.text.impl.BasicPipeline) Label(org.tribuo.classification.Label) FieldResponseProcessor(org.tribuo.data.columnar.processors.response.FieldResponseProcessor) DoubleFieldProcessor(org.tribuo.data.columnar.processors.field.DoubleFieldProcessor) CSVDataSource(org.tribuo.data.csv.CSVDataSource) FieldProcessor(org.tribuo.data.columnar.FieldProcessor) LabelFactory(org.tribuo.classification.LabelFactory) IdentityProcessor(org.tribuo.data.columnar.processors.field.IdentityProcessor) RowProcessor(org.tribuo.data.columnar.RowProcessor) MutableDataset(org.tribuo.MutableDataset) Pair(com.oracle.labs.mlrg.olcut.util.Pair)
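`generateBinarisedDataset` leaves `IdentityProcessor` at its default feature type for columns `A` and `D`, whereas `generateCategoricalDataset` overrides `getFeatureType()` to return `GeneratedFeatureType.CATEGORICAL`. The method name suggests the default treats each observed value as a separate binary indicator. A plain-Java sketch of that binarised (one-hot) encoding follows. `OneHotSketch` is not a Tribuo class, and the `column@value` naming scheme is an assumption made for illustration.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;

// Hypothetical sketch of a binarised (one-hot) categorical encoding:
// each distinct value of a column becomes its own indicator feature.
class OneHotSketch {
    // Assigns an index to every "column@value" name (naming is an assumption),
    // in sorted order so the layout is deterministic.
    static Map<String, Integer> vocabulary(String column, List<String> values) {
        TreeSet<String> names = new TreeSet<>();
        for (String v : values) {
            names.add(column + "@" + v);
        }
        Map<String, Integer> vocab = new LinkedHashMap<>();
        for (String name : names) {
            vocab.put(name, vocab.size());
        }
        return vocab;
    }

    // Encodes one cell as a 0/1 vector; a value missing from the vocabulary sets no indicator.
    static double[] encode(String column, String value, Map<String, Integer> vocab) {
        double[] row = new double[vocab.size()];
        Integer idx = vocab.get(column + "@" + value);
        if (idx != null) {
            row[idx] = 1.0;
        }
        return row;
    }

    public static void main(String[] args) {
        Map<String, Integer> vocab = vocabulary("A", List.of("red", "blue", "red"));
        System.out.println(vocab);                                       // {A@blue=0, A@red=1}
        System.out.println(Arrays.toString(encode("A", "red", vocab)));  // [0.0, 1.0]
    }
}
```

With this encoding, each indicator is an independent numeric feature, so nothing in the representation itself says that `A@blue` and `A@red` are mutually exclusive.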

Example 4 with BasicPipeline

Use of org.tribuo.data.text.impl.BasicPipeline in project tribuo by oracle.

From the class LIMEColumnarTest, method generateCategoricalDataset.

private Pair<RowProcessor<Label>, Dataset<Label>> generateCategoricalDataset() throws URISyntaxException {
    LabelFactory labelFactory = new LabelFactory();
    ResponseProcessor<Label> responseProcessor = new FieldResponseProcessor<>("Response", "N", labelFactory);
    Map<String, FieldProcessor> fieldProcessors = new HashMap<>();
    fieldProcessors.put("A", new IdentityProcessor("A") {

        @Override
        public GeneratedFeatureType getFeatureType() {
            return GeneratedFeatureType.CATEGORICAL;
        }
    });
    fieldProcessors.put("B", new DoubleFieldProcessor("B"));
    fieldProcessors.put("C", new DoubleFieldProcessor("C"));
    fieldProcessors.put("D", new IdentityProcessor("D") {

        @Override
        public GeneratedFeatureType getFeatureType() {
            return GeneratedFeatureType.CATEGORICAL;
        }
    });
    // `tokenizer` is defined elsewhere in LIMEColumnarTest (not shown in this excerpt)
    fieldProcessors.put("TextField", new TextFieldProcessor("TextField", new BasicPipeline(tokenizer, 2)));
    RowProcessor<Label> rp = new RowProcessor<>(responseProcessor, fieldProcessors);
    CSVDataSource<Label> source = new CSVDataSource<>(LIMEColumnarTest.class.getResource("/org/tribuo/classification/explanations/lime/test-columnar.csv").toURI(), rp, true);
    Dataset<Label> dataset = new MutableDataset<>(source);
    return new Pair<>(rp, dataset);
}
Also used : TextFieldProcessor(org.tribuo.data.columnar.processors.field.TextFieldProcessor) HashMap(java.util.HashMap) BasicPipeline(org.tribuo.data.text.impl.BasicPipeline) Label(org.tribuo.classification.Label) FieldResponseProcessor(org.tribuo.data.columnar.processors.response.FieldResponseProcessor) DoubleFieldProcessor(org.tribuo.data.columnar.processors.field.DoubleFieldProcessor) CSVDataSource(org.tribuo.data.csv.CSVDataSource) FieldProcessor(org.tribuo.data.columnar.FieldProcessor) LabelFactory(org.tribuo.classification.LabelFactory) IdentityProcessor(org.tribuo.data.columnar.processors.field.IdentityProcessor) RowProcessor(org.tribuo.data.columnar.RowProcessor) MutableDataset(org.tribuo.MutableDataset) Pair(com.oracle.labs.mlrg.olcut.util.Pair)
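`generateCategoricalDataset` uses anonymous subclasses of `IdentityProcessor` to mark columns `A` and `D` as `GeneratedFeatureType.CATEGORICAL`. One plausible reason a LIME test distinguishes this from the binarised case is that a perturbation-based explainer can treat a categorical column as a single variable and resample its value as a whole, instead of flipping each indicator independently. The sketch below shows such a sampling step under that assumption: it draws a replacement value from the column's observed values. `CategoricalSampler` is hypothetical and is not LIMEColumnar's implementation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Hypothetical sketch: perturb a categorical column by drawing one whole value
// from the values observed in training, so exactly one category is active per sample.
class CategoricalSampler {
    // Training values with duplicates kept, so frequent values are drawn more often.
    private final List<String> observed;
    private final Random rng;

    CategoricalSampler(List<String> observed, long seed) {
        if (observed.isEmpty()) {
            throw new IllegalArgumentException("need at least one observed value");
        }
        this.observed = new ArrayList<>(observed);
        this.rng = new Random(seed); // seeded for reproducible perturbations
    }

    String sample() {
        return observed.get(rng.nextInt(observed.size()));
    }

    public static void main(String[] args) {
        CategoricalSampler s = new CategoricalSampler(List.of("red", "blue", "red", "green"), 42L);
        for (int i = 0; i < 5; i++) {
            System.out.println(s.sample());
        }
    }
}
```

Sampling the column as a unit avoids generating impossible rows in which two categories (or none) are active at once, which independent indicator flips can produce.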

Example 5 with BasicPipeline

Use of org.tribuo.data.text.impl.BasicPipeline in project tribuo by oracle.

From the class TextPipelineTest, method testBasicPipelineTagging.

@Test
public void testBasicPipelineTagging() {
    String input = "This is some input text.";
    BasicPipeline pipeline = new BasicPipeline(new BreakIteratorTokenizer(Locale.US), 2);
    List<Feature> featureList = pipeline.process("Monkeys", input);
    // logger.log(Level.INFO,featureList.toString());
    assertTrue(featureList.contains(new Feature("Monkeys-1-N=This", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-1-N=is", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-1-N=some", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-1-N=input", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-1-N=text", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-2-N=This/is", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-2-N=is/some", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-2-N=some/input", 1.0)));
    assertTrue(featureList.contains(new Feature("Monkeys-2-N=input/text", 1.0)));
}
Also used : BasicPipeline(org.tribuo.data.text.impl.BasicPipeline) Feature(org.tribuo.Feature) BreakIteratorTokenizer(org.tribuo.util.tokens.impl.BreakIteratorTokenizer) Test(org.junit.jupiter.api.Test)
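The assertions in `testBasicPipelineTagging` pin down the feature names `BasicPipeline` emits with an n-gram size of 2: each unigram appears as `<tag>-1-N=<token>` and each adjacent pair as `<tag>-2-N=<first>/<second>`, all with value 1.0. The sketch below reproduces those names using the JDK's `java.text.BreakIterator`. It is a hypothetical re-implementation, not Tribuo's code; in particular, dropping spans with no letters or digits (so the trailing `.` is not a token) is an assumption about `BreakIteratorTokenizer` that the test does not confirm.

```java
import java.text.BreakIterator;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

// Hypothetical sketch reproducing the n-gram feature names asserted in
// testBasicPipelineTagging. Not Tribuo's implementation.
class NgramSketch {
    // Word-boundary tokenization; keeps only spans containing a letter or digit (assumption).
    static List<String> tokenize(String text, Locale locale) {
        BreakIterator it = BreakIterator.getWordInstance(locale);
        it.setText(text);
        List<String> tokens = new ArrayList<>();
        int start = it.first();
        for (int end = it.next(); end != BreakIterator.DONE; start = end, end = it.next()) {
            String span = text.substring(start, end);
            if (span.codePoints().anyMatch(Character::isLetterOrDigit)) {
                tokens.add(span);
            }
        }
        return tokens;
    }

    // Emits "<tag>-<n>-N=<tok1>/.../<tokn>" for every n from 1 to maxN.
    static List<String> process(String tag, String text, int maxN) {
        List<String> tokens = tokenize(text, Locale.US);
        List<String> names = new ArrayList<>();
        for (int n = 1; n <= maxN; n++) {
            for (int i = 0; i + n <= tokens.size(); i++) {
                names.add(tag + "-" + n + "-N=" + String.join("/", tokens.subList(i, i + n)));
            }
        }
        return names;
    }

    public static void main(String[] args) {
        System.out.println(process("Monkeys", "This is some input text.", 2));
        // [Monkeys-1-N=This, Monkeys-1-N=is, Monkeys-1-N=some, Monkeys-1-N=input, Monkeys-1-N=text,
        //  Monkeys-2-N=This/is, Monkeys-2-N=is/some, Monkeys-2-N=some/input, Monkeys-2-N=input/text]
    }
}
```

Prefixing every name with the tag (here `Monkeys`) keeps n-grams from different text fields in separate feature namespaces.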

Aggregations

BasicPipeline (org.tribuo.data.text.impl.BasicPipeline): 8
BreakIteratorTokenizer (org.tribuo.util.tokens.impl.BreakIteratorTokenizer): 6
Label (org.tribuo.classification.Label): 5
LabelFactory (org.tribuo.classification.LabelFactory): 5
HashMap (java.util.HashMap): 3
ImmutableDataset (org.tribuo.ImmutableDataset): 3
FieldProcessor (org.tribuo.data.columnar.FieldProcessor): 3
RowProcessor (org.tribuo.data.columnar.RowProcessor): 3
DoubleFieldProcessor (org.tribuo.data.columnar.processors.field.DoubleFieldProcessor): 3
IdentityProcessor (org.tribuo.data.columnar.processors.field.IdentityProcessor): 3
TextFieldProcessor (org.tribuo.data.columnar.processors.field.TextFieldProcessor): 3
FieldResponseProcessor (org.tribuo.data.columnar.processors.response.FieldResponseProcessor): 3
SimpleTextDataSource (org.tribuo.data.text.impl.SimpleTextDataSource): 3
TextFeatureExtractorImpl (org.tribuo.data.text.impl.TextFeatureExtractorImpl): 3
Pair (com.oracle.labs.mlrg.olcut.util.Pair): 2
Test (org.junit.jupiter.api.Test): 2
Feature (org.tribuo.Feature): 2
MutableDataset (org.tribuo.MutableDataset): 2
CSVDataSource (org.tribuo.data.csv.CSVDataSource): 2
ArrayList (java.util.ArrayList): 1