Search in sources:

Example 1 with DoubleFieldProcessor

use of org.tribuo.data.columnar.processors.field.DoubleFieldProcessor in project tribuo by oracle.

This example is from the class CSVLoader, method loadDataSource.

/**
 * Loads a DataSource from the specified csv path.
 * <p>
 * The {@code responseNames} set is traversed in iteration order to emit outputs,
 * and should be an ordered set to ensure reproducibility.
 *
 * @param csvPath       The csv to load from.
 * @param responseNames The names of the response variables. Must be non-empty and a subset of the headers.
 * @param header        The header of the CSV if it's not present in the file. May be null or empty,
 *                      in which case the headers are read from the first row of the file.
 * @return A datasource containing the csv data.
 * @throws IOException If the disk read failed.
 * @throws IllegalArgumentException If no headers are available, the supplied headers don't match the
 *                      file's field count, or a response name is missing from the headers.
 */
public DataSource<T> loadDataSource(URL csvPath, Set<String> responseNames, String[] header) throws IOException {
    List<String> headers = header == null || header.length == 0 ? Collections.emptyList() : Arrays.asList(header);
    // Convert the URL to a URI once up front; both the header scan and the returned
    // CSVDataSource need it.
    URI csvURI;
    try {
        csvURI = csvPath.toURI();
    } catch (URISyntaxException e) {
        FileNotFoundException fnfe = new FileNotFoundException("Failed to read from URL '" + csvPath + "' as it could not be converted to a URI");
        // FileNotFoundException has no (String, Throwable) constructor, so attach the cause explicitly.
        fnfe.initCause(e);
        throw fnfe;
    }
    // Extract the headers from the file and reconcile them with any supplied headers.
    try (CSVIterator itr = new CSVIterator(csvURI, separator, quote)) {
        List<String> extractedHeaders = itr.getFields();
        if (extractedHeaders.isEmpty() && headers.isEmpty()) {
            throw new IllegalArgumentException("Failed to read headers from CSV, and none were supplied.");
        }
        if (!headers.isEmpty()) {
            // Supplied headers must line up with the file's field count.
            if (extractedHeaders.size() != headers.size()) {
                throw new IllegalArgumentException("The csv contains " + extractedHeaders.size() + " fields, but only " + headers.size() + " headers were supplied.");
            }
        } else {
            headers = extractedHeaders;
        }
    }
    // Validate the responseNames
    if (responseNames.isEmpty()) {
        throw new IllegalArgumentException("At least one response name must be specified, but responseNames is empty.");
    }
    for (String s : responseNames) {
        if (!headers.contains(s)) {
            throw new IllegalArgumentException(String.format("Response %s not found in file %s", s, csvPath));
        }
    }
    // Construct the row processor: every non-response field is read as a double feature.
    Map<String, FieldProcessor> fieldProcessors = new HashMap<>();
    for (String field : headers) {
        if (!responseNames.contains(field)) {
            fieldProcessors.put(field, new DoubleFieldProcessor(field, true, true));
        }
    }
    // With multiple responses the response field name is included to disambiguate the outputs.
    boolean includeFieldName = responseNames.size() > 1;
    ResponseProcessor<T> responseProcessor = new FieldResponseProcessor<>(new ArrayList<>(responseNames), Collections.nCopies(responseNames.size(), ""), outputFactory, includeFieldName, false);
    RowProcessor<T> rowProcessor = new RowProcessor<>(responseProcessor, fieldProcessors);
    if (header != null) {
        // if headers are supplied then we assume they aren't present in the csv file.
        return new CSVDataSource<>(csvURI, rowProcessor, true, separator, quote, headers);
    } else {
        return new CSVDataSource<>(csvURI, rowProcessor, true, separator, quote);
    }
}
Also used : HashMap(java.util.HashMap) FileNotFoundException(java.io.FileNotFoundException) DoubleFieldProcessor(org.tribuo.data.columnar.processors.field.DoubleFieldProcessor) FieldResponseProcessor(org.tribuo.data.columnar.processors.response.FieldResponseProcessor) URISyntaxException(java.net.URISyntaxException) URI(java.net.URI) FieldProcessor(org.tribuo.data.columnar.FieldProcessor) DoubleFieldProcessor(org.tribuo.data.columnar.processors.field.DoubleFieldProcessor) RowProcessor(org.tribuo.data.columnar.RowProcessor)

Example 2 with DoubleFieldProcessor

use of org.tribuo.data.columnar.processors.field.DoubleFieldProcessor in project tribuo by oracle.

This example is from the class TestHdbscan, method testInvocationCounter.

/**
 * Checks that HdbscanTrainer's invocation counter increments once per train call,
 * can be reset via setInvocationCount, and respects the invocation count supplied
 * to the three-argument train overload.
 */
@Test
public void testInvocationCounter() {
    ClusteringFactory clusteringFactory = new ClusteringFactory();
    ResponseProcessor<ClusterID> emptyResponseProcessor = new EmptyResponseProcessor<>(clusteringFactory);
    Map<String, FieldProcessor> regexMappingProcessors = new HashMap<>();
    regexMappingProcessors.put("Feature1", new DoubleFieldProcessor("Feature1"));
    regexMappingProcessors.put("Feature2", new DoubleFieldProcessor("Feature2"));
    regexMappingProcessors.put("Feature3", new DoubleFieldProcessor("Feature3"));
    RowProcessor<ClusterID> rowProcessor = new RowProcessor<>(emptyResponseProcessor, regexMappingProcessors);
    CSVDataSource<ClusterID> csvSource = new CSVDataSource<>(Paths.get("src/test/resources/basic-gaussians.csv"), rowProcessor, false);
    Dataset<ClusterID> dataset = new MutableDataset<>(csvSource);
    HdbscanTrainer trainer = new HdbscanTrainer(7, DistanceType.L2, 7, 4, NeighboursQueryFactoryType.BRUTE_FORCE);
    // Only the invocation count is under test; the returned models are not needed.
    for (int i = 0; i < 5; i++) {
        trainer.train(dataset);
    }
    assertEquals(5, trainer.getInvocationCount());
    trainer.setInvocationCount(0);
    assertEquals(0, trainer.getInvocationCount());
    // Supplying an invocation count of 3 sets the counter before training, so one train makes it 4.
    trainer.train(dataset, Collections.emptyMap(), 3);
    assertEquals(4, trainer.getInvocationCount());
}
Also used : ClusterID(org.tribuo.clustering.ClusterID) HashMap(java.util.HashMap) DoubleFieldProcessor(org.tribuo.data.columnar.processors.field.DoubleFieldProcessor) CSVDataSource(org.tribuo.data.csv.CSVDataSource) FieldProcessor(org.tribuo.data.columnar.FieldProcessor) DoubleFieldProcessor(org.tribuo.data.columnar.processors.field.DoubleFieldProcessor) EmptyResponseProcessor(org.tribuo.data.columnar.processors.response.EmptyResponseProcessor) ClusteringFactory(org.tribuo.clustering.ClusteringFactory) RowProcessor(org.tribuo.data.columnar.RowProcessor) MutableDataset(org.tribuo.MutableDataset) Test(org.junit.jupiter.api.Test)

Example 3 with DoubleFieldProcessor

use of org.tribuo.data.columnar.processors.field.DoubleFieldProcessor in project tribuo by oracle.

This example is from the class JsonDataSourceTest, method buildRowProcessor.

/**
 * Builds the RowProcessor used by the JSON data source tests: a double field, a text
 * field, an identity field, a regex-mapped double field, a field response, and two
 * metadata extractors (an int id and a date timestamp).
 *
 * @return A configured RowProcessor over MockOutput.
 */
private static RowProcessor<MockOutput> buildRowProcessor() {
    // Response comes from the "disposition" field, defaulting to "UNK".
    ResponseProcessor<MockOutput> response = new FieldResponseProcessor<>("disposition", "UNK", new MockOutputFactory());

    // Exact-name field processors.
    Map<String, FieldProcessor> namedProcessors = new HashMap<>();
    namedProcessors.put("transport", new IdentityProcessor("transport"));
    namedProcessors.put("description", new TextFieldProcessor("description", new BasicPipeline(new BreakIteratorTokenizer(Locale.US), 2)));
    namedProcessors.put("height", new DoubleFieldProcessor("height"));

    // Regex-matched field processors: any field starting with "extra" is read as a double.
    Map<String, FieldProcessor> regexProcessors = new HashMap<>();
    regexProcessors.put("extra.*", new DoubleFieldProcessor("regex"));

    // Metadata extractors for the row id and timestamp.
    List<FieldExtractor<?>> metadata = new ArrayList<>();
    metadata.add(new IntExtractor("id"));
    metadata.add(new DateExtractor("timestamp", "timestamp", "dd/MM/yyyy HH:mm"));

    return new RowProcessor<>(metadata, null, response, namedProcessors, regexProcessors, Collections.emptySet());
}
Also used : TextFieldProcessor(org.tribuo.data.columnar.processors.field.TextFieldProcessor) IntExtractor(org.tribuo.data.columnar.extractors.IntExtractor) DateExtractor(org.tribuo.data.columnar.extractors.DateExtractor) MockOutput(org.tribuo.test.MockOutput) HashMap(java.util.HashMap) MockOutputFactory(org.tribuo.test.MockOutputFactory) BasicPipeline(org.tribuo.data.text.impl.BasicPipeline) DoubleFieldProcessor(org.tribuo.data.columnar.processors.field.DoubleFieldProcessor) FieldResponseProcessor(org.tribuo.data.columnar.processors.response.FieldResponseProcessor) ArrayList(java.util.ArrayList) FieldProcessor(org.tribuo.data.columnar.FieldProcessor) DoubleFieldProcessor(org.tribuo.data.columnar.processors.field.DoubleFieldProcessor) TextFieldProcessor(org.tribuo.data.columnar.processors.field.TextFieldProcessor) BreakIteratorTokenizer(org.tribuo.util.tokens.impl.BreakIteratorTokenizer) FieldExtractor(org.tribuo.data.columnar.FieldExtractor) IdentityProcessor(org.tribuo.data.columnar.processors.field.IdentityProcessor) RowProcessor(org.tribuo.data.columnar.RowProcessor)

Example 4 with DoubleFieldProcessor

use of org.tribuo.data.columnar.processors.field.DoubleFieldProcessor in project tribuo by oracle.

This example is from the class LIMEColumnarTest, method generateBinarisedDataset.

/**
 * Generates the binarised columnar test dataset: identity-processed fields A and D,
 * double fields B and C, and a bigram text pipeline over TextField, with the "Response"
 * field as the label (defaulting to "N").
 *
 * @return A pair of the row processor and the dataset it produced.
 * @throws URISyntaxException If the test resource path cannot be converted to a URI.
 */
private Pair<RowProcessor<Label>, Dataset<Label>> generateBinarisedDataset() throws URISyntaxException {
    ResponseProcessor<Label> response = new FieldResponseProcessor<>("Response", "N", new LabelFactory());

    Map<String, FieldProcessor> processors = new HashMap<>();
    processors.put("TextField", new TextFieldProcessor("TextField", new BasicPipeline(tokenizer, 2)));
    processors.put("D", new IdentityProcessor("D"));
    processors.put("C", new DoubleFieldProcessor("C"));
    processors.put("B", new DoubleFieldProcessor("B"));
    processors.put("A", new IdentityProcessor("A"));

    RowProcessor<Label> rowProcessor = new RowProcessor<>(response, processors);
    URI csvURI = LIMEColumnarTest.class.getResource("/org/tribuo/classification/explanations/lime/test-columnar.csv").toURI();
    CSVDataSource<Label> csvSource = new CSVDataSource<>(csvURI, rowProcessor, true);
    return new Pair<>(rowProcessor, new MutableDataset<>(csvSource));
}
Also used : TextFieldProcessor(org.tribuo.data.columnar.processors.field.TextFieldProcessor) HashMap(java.util.HashMap) BasicPipeline(org.tribuo.data.text.impl.BasicPipeline) Label(org.tribuo.classification.Label) FieldResponseProcessor(org.tribuo.data.columnar.processors.response.FieldResponseProcessor) DoubleFieldProcessor(org.tribuo.data.columnar.processors.field.DoubleFieldProcessor) CSVDataSource(org.tribuo.data.csv.CSVDataSource) FieldProcessor(org.tribuo.data.columnar.FieldProcessor) DoubleFieldProcessor(org.tribuo.data.columnar.processors.field.DoubleFieldProcessor) TextFieldProcessor(org.tribuo.data.columnar.processors.field.TextFieldProcessor) LabelFactory(org.tribuo.classification.LabelFactory) IdentityProcessor(org.tribuo.data.columnar.processors.field.IdentityProcessor) RowProcessor(org.tribuo.data.columnar.RowProcessor) MutableDataset(org.tribuo.MutableDataset) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Example 5 with DoubleFieldProcessor

use of org.tribuo.data.columnar.processors.field.DoubleFieldProcessor in project tribuo by oracle.

This example is from the class LIMEColumnarTest, method generateCategoricalDataset.

/**
 * Generates the categorical columnar test dataset: fields A and D are identity-processed
 * but report their generated features as categorical, B and C are double fields, and
 * TextField uses a bigram text pipeline. The "Response" field is the label (default "N").
 *
 * @return A pair of the row processor and the dataset it produced.
 * @throws URISyntaxException If the test resource path cannot be converted to a URI.
 */
private Pair<RowProcessor<Label>, Dataset<Label>> generateCategoricalDataset() throws URISyntaxException {
    LabelFactory labelFactory = new LabelFactory();
    ResponseProcessor<Label> responseProcessor = new FieldResponseProcessor<>("Response", "N", labelFactory);
    Map<String, FieldProcessor> fieldProcessors = new HashMap<>();
    // A and D share the same categorical identity processing; built via a helper to avoid
    // duplicating the anonymous subclass.
    fieldProcessors.put("A", categoricalIdentityProcessor("A"));
    fieldProcessors.put("B", new DoubleFieldProcessor("B"));
    fieldProcessors.put("C", new DoubleFieldProcessor("C"));
    fieldProcessors.put("D", categoricalIdentityProcessor("D"));
    fieldProcessors.put("TextField", new TextFieldProcessor("TextField", new BasicPipeline(tokenizer, 2)));
    RowProcessor<Label> rp = new RowProcessor<>(responseProcessor, fieldProcessors);
    CSVDataSource<Label> source = new CSVDataSource<>(LIMEColumnarTest.class.getResource("/org/tribuo/classification/explanations/lime/test-columnar.csv").toURI(), rp, true);
    Dataset<Label> dataset = new MutableDataset<>(source);
    return new Pair<>(rp, dataset);
}

/**
 * Builds an IdentityProcessor for the named field which reports its generated
 * features as CATEGORICAL rather than the IdentityProcessor default.
 *
 * @param fieldName The field to process.
 * @return An IdentityProcessor whose feature type is CATEGORICAL.
 */
private static IdentityProcessor categoricalIdentityProcessor(String fieldName) {
    return new IdentityProcessor(fieldName) {

        @Override
        public GeneratedFeatureType getFeatureType() {
            return GeneratedFeatureType.CATEGORICAL;
        }
    };
}
Also used : TextFieldProcessor(org.tribuo.data.columnar.processors.field.TextFieldProcessor) HashMap(java.util.HashMap) BasicPipeline(org.tribuo.data.text.impl.BasicPipeline) Label(org.tribuo.classification.Label) FieldResponseProcessor(org.tribuo.data.columnar.processors.response.FieldResponseProcessor) DoubleFieldProcessor(org.tribuo.data.columnar.processors.field.DoubleFieldProcessor) CSVDataSource(org.tribuo.data.csv.CSVDataSource) FieldProcessor(org.tribuo.data.columnar.FieldProcessor) DoubleFieldProcessor(org.tribuo.data.columnar.processors.field.DoubleFieldProcessor) TextFieldProcessor(org.tribuo.data.columnar.processors.field.TextFieldProcessor) LabelFactory(org.tribuo.classification.LabelFactory) IdentityProcessor(org.tribuo.data.columnar.processors.field.IdentityProcessor) RowProcessor(org.tribuo.data.columnar.RowProcessor) MutableDataset(org.tribuo.MutableDataset) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Aggregations

HashMap (java.util.HashMap)9 DoubleFieldProcessor (org.tribuo.data.columnar.processors.field.DoubleFieldProcessor)9 FieldProcessor (org.tribuo.data.columnar.FieldProcessor)8 RowProcessor (org.tribuo.data.columnar.RowProcessor)8 MutableDataset (org.tribuo.MutableDataset)6 CSVDataSource (org.tribuo.data.csv.CSVDataSource)6 Test (org.junit.jupiter.api.Test)5 ClusterID (org.tribuo.clustering.ClusterID)4 ClusteringFactory (org.tribuo.clustering.ClusteringFactory)4 IdentityProcessor (org.tribuo.data.columnar.processors.field.IdentityProcessor)4 TextFieldProcessor (org.tribuo.data.columnar.processors.field.TextFieldProcessor)4 EmptyResponseProcessor (org.tribuo.data.columnar.processors.response.EmptyResponseProcessor)4 FieldResponseProcessor (org.tribuo.data.columnar.processors.response.FieldResponseProcessor)4 BasicPipeline (org.tribuo.data.text.impl.BasicPipeline)3 Pair (com.oracle.labs.mlrg.olcut.util.Pair)2 ArrayList (java.util.ArrayList)2 Prediction (org.tribuo.Prediction)2 Label (org.tribuo.classification.Label)2 LabelFactory (org.tribuo.classification.LabelFactory)2 DateExtractor (org.tribuo.data.columnar.extractors.DateExtractor)2