use of org.tribuo.data.columnar.processors.field.DoubleFieldProcessor in project tribuo by oracle.
the class CSVLoader method loadDataSource.
/**
* Loads a DataSource from the specified csv path.
* <p>
* The {@code responseNames} set is traversed in iteration order to emit outputs,
* and should be an ordered set to ensure reproducibility.
*
* @param csvPath The csv to load from.
* @param responseNames The names of the response variables.
* @param header The header of the CSV if it's not present in the file.
* @return A datasource containing the csv data.
* @throws IOException If the disk read failed.
*/
public DataSource<T> loadDataSource(URL csvPath, Set<String> responseNames, String[] header) throws IOException {
    List<String> headers = header == null || header.length == 0 ? Collections.emptyList() : Arrays.asList(header);
    URI csvURI;
    // Extract headers and convert to URI
    try (CSVIterator itr = new CSVIterator(csvPath.toURI(), separator, quote)) {
        List<String> extractedHeaders = itr.getFields();
        if (extractedHeaders.isEmpty() && headers.isEmpty()) {
            throw new IllegalArgumentException("Failed to read headers from CSV, and none were supplied.");
        }
        if (headers.size() != 0) {
            if (extractedHeaders.size() != headers.size()) {
                throw new IllegalArgumentException("The csv contains " + extractedHeaders.size() + " fields, but only " + headers.size() + " headers were supplied.");
            }
        } else {
            headers = extractedHeaders;
        }
        csvURI = csvPath.toURI();
    } catch (URISyntaxException e) {
        throw new FileNotFoundException("Failed to read from URL '" + csvPath + "' as it could not be converted to a URI");
    }
    // Validate the responseNames
    if (responseNames.isEmpty()) {
        throw new IllegalArgumentException("At least one response name must be specified, but responseNames is empty.");
    }
    if (!headers.containsAll(responseNames)) {
        for (String s : responseNames) {
            if (!headers.contains(s)) {
                throw new IllegalArgumentException(String.format("Response %s not found in file %s", s, csvPath));
            }
        }
    }
    // Construct the row processor
    Map<String, FieldProcessor> fieldProcessors = new HashMap<>();
    for (String field : headers) {
        if (!responseNames.contains(field)) {
            fieldProcessors.put(field, new DoubleFieldProcessor(field, true, true));
        }
    }
    boolean includeFieldName = responseNames.size() > 1;
    ResponseProcessor<T> responseProcessor = new FieldResponseProcessor<>(new ArrayList<>(responseNames), Collections.nCopies(responseNames.size(), ""), outputFactory, includeFieldName, false);
    RowProcessor<T> rowProcessor = new RowProcessor<>(responseProcessor, fieldProcessors);
    if (header != null) {
        // if headers are supplied then we assume they aren't present in the csv file.
        return new CSVDataSource<>(csvURI, rowProcessor, true, separator, quote, headers);
    } else {
        return new CSVDataSource<>(csvURI, rowProcessor, true, separator, quote);
    }
}
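A minimal usage sketch (not taken from the Tribuo sources) of how this method might be called for a single-response regression CSV. The file path "data/houses.csv", the response column "price", and the wrapping method are assumptions; passing a null header asks the loader to read the header row from the file itself.

// Hypothetical usage; the path and the "price" column name are made up.
public static Dataset<Regressor> loadHouses() throws IOException {
    CSVLoader<Regressor> loader = new CSVLoader<>(new RegressionFactory());
    URL csvUrl = Paths.get("data/houses.csv").toUri().toURL();
    // A null header means the header row is extracted from the file, and every
    // column other than "price" is processed into a double-valued feature.
    DataSource<Regressor> source = loader.loadDataSource(csvUrl, Collections.singleton("price"), (String[]) null);
    return new MutableDataset<>(source);
}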
use of org.tribuo.data.columnar.processors.field.DoubleFieldProcessor in project tribuo by oracle.
the class TestHdbscan method testInvocationCounter.
@Test
public void testInvocationCounter() {
    ClusteringFactory clusteringFactory = new ClusteringFactory();
    ResponseProcessor<ClusterID> emptyResponseProcessor = new EmptyResponseProcessor<>(clusteringFactory);
    Map<String, FieldProcessor> regexMappingProcessors = new HashMap<>();
    regexMappingProcessors.put("Feature1", new DoubleFieldProcessor("Feature1"));
    regexMappingProcessors.put("Feature2", new DoubleFieldProcessor("Feature2"));
    regexMappingProcessors.put("Feature3", new DoubleFieldProcessor("Feature3"));
    RowProcessor<ClusterID> rowProcessor = new RowProcessor<>(emptyResponseProcessor, regexMappingProcessors);
    CSVDataSource<ClusterID> csvSource = new CSVDataSource<>(Paths.get("src/test/resources/basic-gaussians.csv"), rowProcessor, false);
    Dataset<ClusterID> dataset = new MutableDataset<>(csvSource);
    HdbscanTrainer trainer = new HdbscanTrainer(7, DistanceType.L2, 7, 4, NeighboursQueryFactoryType.BRUTE_FORCE);
    for (int i = 0; i < 5; i++) {
        HdbscanModel model = trainer.train(dataset);
    }
    assertEquals(5, trainer.getInvocationCount());
    trainer.setInvocationCount(0);
    assertEquals(0, trainer.getInvocationCount());
    Model<ClusterID> model = trainer.train(dataset, Collections.emptyMap(), 3);
    assertEquals(4, trainer.getInvocationCount());
}
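For context, the final assertion depends on the three-argument train overload setting the trainer's invocation count to the supplied value before training, so a single call starting from 3 leaves the counter at 4. A hypothetical continuation of the same test, reusing the trainer and dataset above, illustrates the same behaviour with different numbers:

// Hypothetical continuation; the specific counts are illustrative.
trainer.setInvocationCount(10);                      // counter jumps straight to 10
trainer.train(dataset, Collections.emptyMap(), 20);  // counter is set to 20, then the call increments it
assertEquals(21, trainer.getInvocationCount());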
use of org.tribuo.data.columnar.processors.field.DoubleFieldProcessor in project tribuo by oracle.
the class JsonDataSourceTest method buildRowProcessor.
private static RowProcessor<MockOutput> buildRowProcessor() {
    Map<String, FieldProcessor> fieldProcessors = new HashMap<>();
    fieldProcessors.put("height", new DoubleFieldProcessor("height"));
    fieldProcessors.put("description", new TextFieldProcessor("description", new BasicPipeline(new BreakIteratorTokenizer(Locale.US), 2)));
    fieldProcessors.put("transport", new IdentityProcessor("transport"));
    Map<String, FieldProcessor> regexMappingProcessors = new HashMap<>();
    regexMappingProcessors.put("extra.*", new DoubleFieldProcessor("regex"));
    ResponseProcessor<MockOutput> responseProcessor = new FieldResponseProcessor<>("disposition", "UNK", new MockOutputFactory());
    List<FieldExtractor<?>> metadataExtractors = new ArrayList<>();
    metadataExtractors.add(new IntExtractor("id"));
    metadataExtractors.add(new DateExtractor("timestamp", "timestamp", "dd/MM/yyyy HH:mm"));
    return new RowProcessor<>(metadataExtractors, null, responseProcessor, fieldProcessors, regexMappingProcessors, Collections.emptySet());
}
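A sketch of wiring this row processor into a JsonDataSource; the test method and the resource path below are invented for illustration rather than taken from the actual test class.

@Test
public void loadsJsonWithRowProcessor() throws URISyntaxException {
    RowProcessor<MockOutput> rowProcessor = buildRowProcessor();
    // "/org/tribuo/json/example.json" is a made-up resource path.
    URI jsonFile = JsonDataSourceTest.class.getResource("/org/tribuo/json/example.json").toURI();
    JsonDataSource<MockOutput> source = new JsonDataSource<>(jsonFile, rowProcessor, true);
    MutableDataset<MockOutput> dataset = new MutableDataset<>(source);
    // The "extra.*" entry is expanded into one DoubleFieldProcessor per matching
    // field name once the fields are observed, so extra1, extra2, ... all become features.
    assertTrue(dataset.size() > 0);
}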
use of org.tribuo.data.columnar.processors.field.DoubleFieldProcessor in project tribuo by oracle.
the class LIMEColumnarTest method generateBinarisedDataset.
private Pair<RowProcessor<Label>, Dataset<Label>> generateBinarisedDataset() throws URISyntaxException {
    LabelFactory labelFactory = new LabelFactory();
    ResponseProcessor<Label> responseProcessor = new FieldResponseProcessor<>("Response", "N", labelFactory);
    Map<String, FieldProcessor> fieldProcessors = new HashMap<>();
    fieldProcessors.put("A", new IdentityProcessor("A"));
    fieldProcessors.put("B", new DoubleFieldProcessor("B"));
    fieldProcessors.put("C", new DoubleFieldProcessor("C"));
    fieldProcessors.put("D", new IdentityProcessor("D"));
    fieldProcessors.put("TextField", new TextFieldProcessor("TextField", new BasicPipeline(tokenizer, 2)));
    RowProcessor<Label> rp = new RowProcessor<>(responseProcessor, fieldProcessors);
    CSVDataSource<Label> source = new CSVDataSource<>(LIMEColumnarTest.class.getResource("/org/tribuo/classification/explanations/lime/test-columnar.csv").toURI(), rp, true);
    Dataset<Label> dataset = new MutableDataset<>(source);
    return new Pair<>(rp, dataset);
}
use of org.tribuo.data.columnar.processors.field.DoubleFieldProcessor in project tribuo by oracle.
the class LIMEColumnarTest method generateCategoricalDataset.
private Pair<RowProcessor<Label>, Dataset<Label>> generateCategoricalDataset() throws URISyntaxException {
    LabelFactory labelFactory = new LabelFactory();
    ResponseProcessor<Label> responseProcessor = new FieldResponseProcessor<>("Response", "N", labelFactory);
    Map<String, FieldProcessor> fieldProcessors = new HashMap<>();
    fieldProcessors.put("A", new IdentityProcessor("A") {
        @Override
        public GeneratedFeatureType getFeatureType() {
            return GeneratedFeatureType.CATEGORICAL;
        }
    });
    fieldProcessors.put("B", new DoubleFieldProcessor("B"));
    fieldProcessors.put("C", new DoubleFieldProcessor("C"));
    fieldProcessors.put("D", new IdentityProcessor("D") {
        @Override
        public GeneratedFeatureType getFeatureType() {
            return GeneratedFeatureType.CATEGORICAL;
        }
    });
    fieldProcessors.put("TextField", new TextFieldProcessor("TextField", new BasicPipeline(tokenizer, 2)));
    RowProcessor<Label> rp = new RowProcessor<>(responseProcessor, fieldProcessors);
    CSVDataSource<Label> source = new CSVDataSource<>(LIMEColumnarTest.class.getResource("/org/tribuo/classification/explanations/lime/test-columnar.csv").toURI(), rp, true);
    Dataset<Label> dataset = new MutableDataset<>(source);
    return new Pair<>(rp, dataset);
}
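The only difference from generateBinarisedDataset above is that A and D override getFeatureType() to return CATEGORICAL rather than IdentityProcessor's binarised default, which is presumably the distinction LIMEColumnar's sampling exercises across the two fixtures. As a hypothetical follow-on (the trainer choice is an assumption, not part of the test), the returned pair could be consumed like this:

// Hypothetical usage of the fixture; LogisticRegressionTrainer is an illustrative choice.
private Model<Label> trainOnCategoricalFixture() throws URISyntaxException {
    Pair<RowProcessor<Label>, Dataset<Label>> pair = generateCategoricalDataset();
    RowProcessor<Label> exampleGenerator = pair.getA(); // would be handed to the LIME explainer
    Dataset<Label> trainData = pair.getB();
    return new LogisticRegressionTrainer().train(trainData);
}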