Use of org.tribuo.datasource.LibSVMDataSource in project tribuo by oracle.
In the class TestXGBoostExternalModel, the method testMNIST.
@Test
public void testMNIST() throws IOException, URISyntaxException {
    LabelFactory labelFactory = new LabelFactory();
    // Loads regular MNIST
    URL data = TestXGBoostExternalModel.class.getResource("/org/tribuo/classification/xgboost/mnist_test_head.libsvm");
    DataSource<Label> transposedMNIST = new LibSVMDataSource<>(data, labelFactory, false, 784);
    Dataset<Label> dataset = new MutableDataset<>(transposedMNIST);
    Map<String, Integer> featureMapping = new HashMap<>();
    for (int i = 0; i < 784; i++) {
        // This MNIST model has the feature indices transposed to test a non-trivial mapping.
        int id = (783 - i);
        featureMapping.put(String.format("%03d", i), id);
    }
    Map<Label, Integer> outputMapping = new HashMap<>();
    for (Label l : dataset.getOutputInfo().getDomain()) {
        outputMapping.put(l, Integer.parseInt(l.getLabel()));
    }
    XGBoostClassificationConverter labelConverter = new XGBoostClassificationConverter();
    Path testResource = Paths.get(TestXGBoostExternalModel.class.getResource("/org/tribuo/classification/xgboost/xgb_mnist.xgb").toURI());
    XGBoostExternalModel<Label> transposedMNISTXGB = XGBoostExternalModel.createXGBoostModel(labelFactory, featureMapping, outputMapping, labelConverter, testResource);
    LabelEvaluation evaluation = labelFactory.getEvaluator().evaluate(transposedMNISTXGB, transposedMNIST);
    assertEquals(1.0, evaluation.accuracy(), 1e-6);
    assertEquals(0.0, evaluation.balancedErrorRate(), 1e-6);
    Helpers.testModelSerialization(transposedMNISTXGB, Label.class);
}
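The four-argument LibSVMDataSource constructor used above fixes the feature space up front: the boolean flags whether the file's feature indices are zero-indexed, and the final argument pins the maximum feature id so every pixel of the 28x28 grid gets a feature slot even if it never fires in this small test file. A minimal standalone sketch of the same load (the class name is illustrative; the resource path is taken from the test above):

import java.net.URL;
import org.tribuo.DataSource;
import org.tribuo.MutableDataset;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.datasource.LibSVMDataSource;

public class LoadMNISTHead {
    public static void main(String[] args) {
        LabelFactory labelFactory = new LabelFactory();
        URL data = LoadMNISTHead.class.getResource("/org/tribuo/classification/xgboost/mnist_test_head.libsvm");
        // zeroIndexed = false: the file's feature indices are not zero-indexed;
        // maxFeatureID = 784 pins the feature space so train and test data agree.
        DataSource<Label> source = new LibSVMDataSource<>(data, labelFactory, false, 784);
        MutableDataset<Label> dataset = new MutableDataset<>(source);
        System.out.println("Examples: " + dataset.size() + ", features: " + dataset.getFeatureMap().size());
    }
}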
Use of org.tribuo.datasource.LibSVMDataSource in project tribuo by oracle.
In the class TestOnnxRuntime, the method testTransposedMNIST.
/**
 * This test checks that the model works with the identity feature mapping when the data is transposed.
 * @throws IOException If it failed to read the file.
 * @throws OrtException If onnx-runtime failed.
 * @throws URISyntaxException If the URL failed to parse.
 */
@Test
public void testTransposedMNIST() throws IOException, OrtException, URISyntaxException {
    LabelFactory labelFactory = new LabelFactory();
    try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        // Loads transposed MNIST
        URL data = TestOnnxRuntime.class.getResource("/org/tribuo/interop/onnx/transposed_mnist_test_head.libsvm");
        DataSource<Label> transposedMNIST = new LibSVMDataSource<>(data, labelFactory, true, 784);
        Dataset<Label> dataset = new MutableDataset<>(transposedMNIST);
        Map<String, Integer> featureMapping = new HashMap<>();
        for (int i = 0; i < 784; i++) {
            featureMapping.put(String.format("%03d", i), i);
        }
        Map<Label, Integer> outputMapping = new HashMap<>();
        for (Label l : dataset.getOutputInfo().getDomain()) {
            outputMapping.put(l, Integer.parseInt(l.getLabel()));
        }
        DenseTransformer denseTransformer = new DenseTransformer();
        LabelTransformer labelTransformer = new LabelTransformer();
        Path testResource = Paths.get(TestOnnxRuntime.class.getResource("/org/tribuo/interop/onnx/lr_mnist.onnx").toURI());
        ONNXExternalModel<Label> transposedMNISTLR = ONNXExternalModel.createOnnxModel(labelFactory, featureMapping, outputMapping, denseTransformer, labelTransformer, sessionOptions, testResource, "float_input");
        // This model doesn't have a free batch size parameter on the output
        transposedMNISTLR.setBatchSize(1);
        LabelEvaluation evaluation = labelFactory.getEvaluator().evaluate(transposedMNISTLR, transposedMNIST);
        assertEquals(0.967741, evaluation.accuracy(), 1e-6);
        assertEquals(0.024285, evaluation.balancedErrorRate(), 1e-6);
    }
}
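All four tests in this listing build their feature and output mappings with the same two loops. A hypothetical helper collecting that boilerplate (the class and method names are illustrative, not part of Tribuo; the loop bodies are lifted directly from the tests above):

import java.util.HashMap;
import java.util.Map;
import org.tribuo.Dataset;
import org.tribuo.classification.Label;

// Hypothetical helpers collecting the mapping boilerplate from the tests above.
public final class MNISTMappings {
    private MNISTMappings() {}

    // Identity mapping: feature "042" in the libsvm data feeds model input 42.
    public static Map<String, Integer> identityFeatureMapping(int numFeatures) {
        Map<String, Integer> featureMapping = new HashMap<>();
        for (int i = 0; i < numFeatures; i++) {
            featureMapping.put(String.format("%03d", i), i);
        }
        return featureMapping;
    }

    // MNIST labels are the digit strings "0".."9", so parsing the label text
    // recovers the output index the external model was trained with.
    public static Map<Label, Integer> digitOutputMapping(Dataset<Label> dataset) {
        Map<Label, Integer> outputMapping = new HashMap<>();
        for (Label l : dataset.getOutputInfo().getDomain()) {
            outputMapping.put(l, Integer.parseInt(l.getLabel()));
        }
        return outputMapping;
    }
}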
Use of org.tribuo.datasource.LibSVMDataSource in project tribuo by oracle.
In the class TestOnnxRuntime, the method testCNNMNIST.
/**
 * This test checks that the ImageTransformer works and that float matrices can be processed through the LabelTransformer.
 * @throws IOException If it failed to read the file.
 * @throws OrtException If onnx-runtime failed.
 * @throws URISyntaxException If the URL failed to parse.
 */
@Test
public void testCNNMNIST() throws IOException, OrtException, URISyntaxException {
    LabelFactory labelFactory = new LabelFactory();
    try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        // Loads regular MNIST
        URL data = TestOnnxRuntime.class.getResource("/org/tribuo/interop/onnx/mnist_test_head.libsvm");
        DataSource<Label> mnistTest = new LibSVMDataSource<>(data, labelFactory, false, 784);
        Dataset<Label> dataset = new MutableDataset<>(mnistTest);
        Map<String, Integer> featureMapping = new HashMap<>();
        for (int i = 0; i < 784; i++) {
            featureMapping.put(String.format("%03d", i), i);
        }
        Map<Label, Integer> outputMapping = new HashMap<>();
        for (Label l : dataset.getOutputInfo().getDomain()) {
            outputMapping.put(l, Integer.parseInt(l.getLabel()));
        }
        ImageTransformer imageTransformer = new ImageTransformer(1, 28, 28);
        LabelTransformer labelTransformer = new LabelTransformer();
        Path testResource = Paths.get(TestOnnxRuntime.class.getResource("/org/tribuo/interop/onnx/cnn_mnist.onnx").toURI());
        ONNXExternalModel<Label> cnnModel = ONNXExternalModel.createOnnxModel(labelFactory, featureMapping, outputMapping, imageTransformer, labelTransformer, sessionOptions, testResource, "input_image");
        LabelEvaluation evaluation = labelFactory.getEvaluator().evaluate(cnnModel, mnistTest);
        // CNNs are good at MNIST
        assertEquals(1.0, evaluation.accuracy(), 1e-6);
        assertEquals(0.0, evaluation.balancedErrorRate(), 1e-6);
    }
}
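The three ImageTransformer constructor arguments above are the channel, height, and width dimensions of the CNN's input tensor; MNIST digits are single-channel 28x28 images, so the dimensions multiply out to the 784 features in the mapping. A sketch of that invariant, assuming the argument order shown in the test (constant names are illustrative):

import org.tribuo.interop.onnx.ImageTransformer;

public class ImageDims {
    public static void main(String[] args) {
        int channels = 1;       // greyscale MNIST
        int height = 28;
        int width = 28;
        int numFeatures = 784;  // size of the feature mapping built in the test
        // The reshape from a flat feature vector into a (channels, height, width)
        // tensor only works when the dimensions multiply out to the feature count.
        if (channels * height * width != numFeatures) {
            throw new IllegalStateException("Image dims don't match the feature count");
        }
        ImageTransformer imageTransformer = new ImageTransformer(channels, height, width);
        System.out.println("Constructed transformer for a " + channels + "x" + height + "x" + width + " input");
    }
}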
Use of org.tribuo.datasource.LibSVMDataSource in project tribuo by oracle.
In the class TestOnnxRuntime, the method testMNIST.
/**
 * This test checks that the model works when using the feature mapping logic, as the model was trained with
 * a transposed feature mapping but the data is loaded using the standard mapping.
 * @throws IOException If it failed to read the file.
 * @throws OrtException If onnx-runtime failed.
 * @throws URISyntaxException If the URL failed to parse.
 */
@Test
public void testMNIST() throws IOException, OrtException, URISyntaxException {
    LabelFactory labelFactory = new LabelFactory();
    try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        // Loads regular MNIST
        URL data = TestOnnxRuntime.class.getResource("/org/tribuo/interop/onnx/mnist_test_head.libsvm");
        DataSource<Label> transposedMNIST = new LibSVMDataSource<>(data, labelFactory, false, 784);
        Dataset<Label> dataset = new MutableDataset<>(transposedMNIST);
        Map<String, Integer> featureMapping = new HashMap<>();
        for (int i = 0; i < 784; i++) {
            // This MNIST model has the feature indices transposed to test a non-trivial mapping.
            int id = (783 - i);
            featureMapping.put(String.format("%03d", i), id);
        }
        Map<Label, Integer> outputMapping = new HashMap<>();
        for (Label l : dataset.getOutputInfo().getDomain()) {
            outputMapping.put(l, Integer.parseInt(l.getLabel()));
        }
        DenseTransformer denseTransformer = new DenseTransformer();
        LabelTransformer labelTransformer = new LabelTransformer();
        Path testResource = Paths.get(TestOnnxRuntime.class.getResource("/org/tribuo/interop/onnx/lr_mnist.onnx").toURI());
        ONNXExternalModel<Label> transposedMNISTLR = ONNXExternalModel.createOnnxModel(labelFactory, featureMapping, outputMapping, denseTransformer, labelTransformer, sessionOptions, testResource, "float_input");
        // This model doesn't have a free batch size parameter on the output
        transposedMNISTLR.setBatchSize(1);
        LabelEvaluation evaluation = labelFactory.getEvaluator().evaluate(transposedMNISTLR, transposedMNIST);
        assertEquals(0.967741, evaluation.accuracy(), 1e-6);
        assertEquals(0.024285, evaluation.balancedErrorRate(), 1e-6);
        Helpers.testModelSerialization(transposedMNISTLR, Label.class);
    }
}
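The two ONNX logistic-regression tests are mirror images: testTransposedMNIST transposes the data and uses the identity mapping, while this testMNIST keeps the data regular and transposes the mapping, and both reach the same accuracy against the same lr_mnist.onnx weights. That works because the transposition i -> 783 - i is its own inverse, so a single mapping serves in either direction; a throwaway check (illustrative only):

public class TranspositionCheck {
    public static void main(String[] args) {
        // The index transposition used in the tests above, i -> 783 - i, is an
        // involution: applying it twice returns the original index.
        for (int i = 0; i < 784; i++) {
            int once = 783 - i;
            int twice = 783 - once;
            if (twice != i) {
                throw new AssertionError("Transposition is not self-inverse at " + i);
            }
        }
        System.out.println("i -> 783 - i is an involution on [0, 783]");
    }
}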
Use of org.tribuo.datasource.LibSVMDataSource in project tribuo by oracle.
In the class DataOptions, the method load.
/**
 * Loads the training and testing data from {@link #trainingPath} and {@link #testingPath}
 * according to the other parameters specified in this class.
 * @param outputFactory The output factory to use to process the inputs.
 * @param <T> The dataset output type.
 * @return A pair containing the training and testing datasets. The training dataset is element 'A' and the
 * testing dataset is element 'B'.
 * @throws IOException If the paths could not be loaded.
 */
public <T extends Output<T>> Pair<Dataset<T>, Dataset<T>> load(OutputFactory<T> outputFactory) throws IOException {
    logger.info(String.format("Loading data from %s", trainingPath));
    Dataset<T> train;
    Dataset<T> test;
    char separator;
    switch (inputFormat) {
        case SERIALIZED:
            //
            // Load Tribuo serialised datasets.
            logger.info("Deserialising dataset from " + trainingPath);
            try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(trainingPath.toFile())));
                 ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(testingPath.toFile())))) {
                @SuppressWarnings("unchecked") Dataset<T> tmp = (Dataset<T>) ois.readObject();
                train = tmp;
                if (minCount > 0) {
                    logger.info("Found " + train.getFeatureIDMap().size() + " features");
                    logger.info("Removing features that occur fewer than " + minCount + " times.");
                    train = new MinimumCardinalityDataset<>(train, minCount);
                }
                logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
                @SuppressWarnings("unchecked") Dataset<T> deserTest = (Dataset<T>) oits.readObject();
                test = new ImmutableDataset<>(deserTest, deserTest.getSourceProvenance(), deserTest.getOutputFactory(), train.getFeatureIDMap(), train.getOutputIDInfo(), true);
            } catch (ClassNotFoundException e) {
                throw new IllegalArgumentException("Unknown class in serialised files", e);
            }
            break;
        case LIBSVM:
            //
            // Load the libsvm text-based data format.
            LibSVMDataSource<T> trainSVMSource = new LibSVMDataSource<>(trainingPath, outputFactory);
            train = new MutableDataset<>(trainSVMSource);
            boolean zeroIndexed = trainSVMSource.isZeroIndexed();
            int maxFeatureID = trainSVMSource.getMaxFeatureID();
            if (minCount > 0) {
                logger.info("Removing features that occur fewer than " + minCount + " times.");
                train = new MinimumCardinalityDataset<>(train, minCount);
            }
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            test = new ImmutableDataset<>(new LibSVMDataSource<>(testingPath, outputFactory, zeroIndexed, maxFeatureID), train.getFeatureIDMap(), train.getOutputIDInfo(), false);
            break;
        case TEXT:
            //
            // Using a simple Java break iterator to generate ngram features.
            TextFeatureExtractor<T> extractor;
            if (hashDim > 0) {
                extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), ngram, termCounting, hashDim));
            } else {
                extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), ngram, termCounting));
            }
            TextDataSource<T> trainSource = new SimpleTextDataSource<>(trainingPath, outputFactory, extractor);
            train = new MutableDataset<>(trainSource);
            if (minCount > 0) {
                logger.info("Removing features that occur fewer than " + minCount + " times.");
                train = new MinimumCardinalityDataset<>(train, minCount);
            }
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            TextDataSource<T> testSource = new SimpleTextDataSource<>(testingPath, outputFactory, extractor);
            test = new ImmutableDataset<>(testSource, train.getFeatureIDMap(), train.getOutputIDInfo(), false);
            break;
        case CSV:
            // Load the data using the simple CSV loader
            if (csvResponseName == null) {
                throw new IllegalArgumentException("Please supply a response column name");
            }
            separator = delimiter.value;
            CSVLoader<T> loader = new CSVLoader<>(separator, outputFactory);
            train = new MutableDataset<>(loader.loadDataSource(trainingPath, csvResponseName));
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            test = new MutableDataset<>(loader.loadDataSource(testingPath, csvResponseName));
            break;
        case COLUMNAR:
            if (rowProcessor == null) {
                throw new IllegalArgumentException("Please supply a RowProcessor");
            }
            OutputFactory<?> rowOutputFactory = rowProcessor.getResponseProcessor().getOutputFactory();
            if (!rowOutputFactory.equals(outputFactory)) {
                throw new IllegalArgumentException("The RowProcessor doesn't use the same kind of OutputFactory as the one supplied. RowProcessor has " + rowOutputFactory.getClass().getSimpleName() + ", supplied " + outputFactory.getClass().getName());
            }
            // checked by the if statement above
            @SuppressWarnings("unchecked") RowProcessor<T> typedRowProcessor = (RowProcessor<T>) rowProcessor;
            separator = delimiter.value;
            train = new MutableDataset<>(new CSVDataSource<>(trainingPath, typedRowProcessor, true, separator, csvQuoteChar));
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            test = new MutableDataset<>(new CSVDataSource<>(testingPath, typedRowProcessor, true, separator, csvQuoteChar));
            break;
        default:
            throw new IllegalArgumentException("Unsupported input format " + inputFormat);
    }
    logger.info(String.format("Loaded %d testing examples", test.size()));
    if (scaleFeatures) {
        logger.info("Fitting feature scaling");
        TransformationMap map = new TransformationMap(Collections.singletonList(new LinearScalingTransformation()));
        TransformerMap transformers = train.createTransformers(map, scaleIncZeros);
        logger.info("Applying scaling to training dataset");
        train = transformers.transformDataset(train);
        logger.info("Applying scaling to testing dataset");
        test = transformers.transformDataset(test);
    }
    return new Pair<>(train, test);
}
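A minimal sketch of driving load from a command line entry point, assuming the usual Tribuo CLI convention of populating the options object with OLCUT's ConfigurationManager (the class name and overall wiring here are illustrative, not taken from this listing):

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.data.DataOptions;

public class LoadDemo {
    public static void main(String[] args) throws Exception {
        DataOptions options = new DataOptions();
        // Parses the command line and fills in fields such as trainingPath,
        // testingPath and inputFormat before load is called.
        ConfigurationManager cm = new ConfigurationManager(args, options);
        Pair<Dataset<Label>, Dataset<Label>> pair = options.load(new LabelFactory());
        System.out.println("Train: " + pair.getA().size() + " examples, test: " + pair.getB().size() + " examples");
    }
}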