Search in sources:

Example 1 with LibSVMDataSource

Use of org.tribuo.datasource.LibSVMDataSource in project tribuo by oracle.

From the class TestXGBoostExternalModel, method testMNIST.

@Test
public void testMNIST() throws IOException, URISyntaxException {
    LabelFactory labelFactory = new LabelFactory();
    // Loads regular MNIST
    URL data = TestXGBoostExternalModel.class.getResource("/org/tribuo/classification/xgboost/mnist_test_head.libsvm");
    DataSource<Label> transposedMNIST = new LibSVMDataSource<>(data, labelFactory, false, 784);
    Dataset<Label> dataset = new MutableDataset<>(transposedMNIST);
    Map<String, Integer> featureMapping = new HashMap<>();
    for (int i = 0; i < 784; i++) {
        // This MNIST model has the feature indices transposed to test a non-trivial mapping.
        int id = (783 - i);
        featureMapping.put(String.format("%03d", i), id);
    }
    Map<Label, Integer> outputMapping = new HashMap<>();
    for (Label l : dataset.getOutputInfo().getDomain()) {
        outputMapping.put(l, Integer.parseInt(l.getLabel()));
    }
    XGBoostClassificationConverter labelConverter = new XGBoostClassificationConverter();
    Path testResource = Paths.get(TestXGBoostExternalModel.class.getResource("/org/tribuo/classification/xgboost/xgb_mnist.xgb").toURI());
    XGBoostExternalModel<Label> transposedMNISTXGB = XGBoostExternalModel.createXGBoostModel(labelFactory, featureMapping, outputMapping, labelConverter, testResource);
    LabelEvaluation evaluation = labelFactory.getEvaluator().evaluate(transposedMNISTXGB, transposedMNIST);
    assertEquals(1.0, evaluation.accuracy(), 1e-6);
    assertEquals(0.0, evaluation.balancedErrorRate(), 1e-6);
    Helpers.testModelSerialization(transposedMNISTXGB, Label.class);
}
Also used: Path (java.nio.file.Path), HashMap (java.util.HashMap), Label (org.tribuo.classification.Label), URL (java.net.URL), LabelFactory (org.tribuo.classification.LabelFactory), LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation), LibSVMDataSource (org.tribuo.datasource.LibSVMDataSource), MutableDataset (org.tribuo.MutableDataset), Test (org.junit.jupiter.api.Test)
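The featureMapping keys on Tribuo's feature names and maps each one to the external model's input index; the String.format("%03d", i) calls suggest the libsvm loader names its features as zero-padded indices, and the values run backwards here because this XGBoost model was trained on transposed features. For contrast, a minimal sketch of the same load with an identity mapping, assuming a hypothetical local file data/mnist_test_head.libsvm:

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;

import org.tribuo.DataSource;
import org.tribuo.Dataset;
import org.tribuo.MutableDataset;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.datasource.LibSVMDataSource;

public class FeatureMappingSketch {
    public static void main(String[] args) throws IOException {
        LabelFactory labelFactory = new LabelFactory();
        // Hypothetical local copy of an MNIST-style libsvm file.
        Path dataPath = Paths.get("data", "mnist_test_head.libsvm");
        // One-indexed features (zeroIndexed = false), 784 dimensions, as in the test above.
        DataSource<Label> source = new LibSVMDataSource<>(dataPath, labelFactory, false, 784);
        Dataset<Label> dataset = new MutableDataset<>(source);
        // Identity mapping: feature name "000" -> model input 0, ..., "783" -> 783.
        // The test above uses (783 - i) instead because its model was trained
        // with the pixel order reversed.
        Map<String, Integer> featureMapping = new HashMap<>();
        for (int i = 0; i < 784; i++) {
            featureMapping.put(String.format("%03d", i), i);
        }
        System.out.println("Loaded " + dataset.size() + " examples, "
                + dataset.getFeatureIDMap().size() + " observed features, mapping size "
                + featureMapping.size());
    }
}

With the identity mapping, pixel i feeds model input i; the reversed mapping above instead presents pixels in the order the transposed model expects.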

Example 2 with LibSVMDataSource

Use of org.tribuo.datasource.LibSVMDataSource in project tribuo by oracle.

From the class TestOnnxRuntime, method testTransposedMNIST.

/**
 * This test checks that the model works with the identity feature mapping when the data is transposed.
 * @throws IOException If it failed to read the file.
 * @throws OrtException If onnx-runtime failed.
 * @throws URISyntaxException If the URL failed to parse.
 */
@Test
public void testTransposedMNIST() throws IOException, OrtException, URISyntaxException {
    LabelFactory labelFactory = new LabelFactory();
    try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        // Loads transposed MNIST
        URL data = TestOnnxRuntime.class.getResource("/org/tribuo/interop/onnx/transposed_mnist_test_head.libsvm");
        DataSource<Label> transposedMNIST = new LibSVMDataSource<>(data, labelFactory, true, 784);
        Dataset<Label> dataset = new MutableDataset<>(transposedMNIST);
        Map<String, Integer> featureMapping = new HashMap<>();
        for (int i = 0; i < 784; i++) {
            featureMapping.put(String.format("%03d", i), i);
        }
        Map<Label, Integer> outputMapping = new HashMap<>();
        for (Label l : dataset.getOutputInfo().getDomain()) {
            outputMapping.put(l, Integer.parseInt(l.getLabel()));
        }
        DenseTransformer denseTransformer = new DenseTransformer();
        LabelTransformer labelTransformer = new LabelTransformer();
        Path testResource = Paths.get(TestOnnxRuntime.class.getResource("/org/tribuo/interop/onnx/lr_mnist.onnx").toURI());
        ONNXExternalModel<Label> transposedMNISTLR = ONNXExternalModel.createOnnxModel(labelFactory, featureMapping, outputMapping, denseTransformer, labelTransformer, sessionOptions, testResource, "float_input");
        // This model doesn't have a free batch size parameter on the output
        transposedMNISTLR.setBatchSize(1);
        LabelEvaluation evaluation = labelFactory.getEvaluator().evaluate(transposedMNISTLR, transposedMNIST);
        assertEquals(0.967741, evaluation.accuracy(), 1e-6);
        assertEquals(0.024285, evaluation.balancedErrorRate(), 1e-6);
    }
}
Also used: Path (java.nio.file.Path), HashMap (java.util.HashMap), Label (org.tribuo.classification.Label), URL (java.net.URL), OrtEnvironment (ai.onnxruntime.OrtEnvironment), OrtSession (ai.onnxruntime.OrtSession), LabelFactory (org.tribuo.classification.LabelFactory), LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation), LibSVMDataSource (org.tribuo.datasource.LibSVMDataSource), MutableDataset (org.tribuo.MutableDataset), Test (org.junit.jupiter.api.Test)
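The two extra constructor arguments matter here: the boolean says whether the file's feature indices start from zero (true for this transposed file) and 784 supplies the maximum feature id so the loaded feature space matches a model expecting exactly 784 inputs. When a test set must share a training set's feature space, both values can be read back from the training source, as DataOptions.load in Example 5 below does. A minimal sketch of that pattern, assuming hypothetical train.libsvm and test.libsvm files:

import java.io.IOException;
import java.nio.file.Paths;

import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.datasource.LibSVMDataSource;

public class ZeroIndexedSketch {
    public static void main(String[] args) throws IOException {
        LabelFactory labelFactory = new LabelFactory();
        // Let the training source infer the indexing convention and the
        // maximum feature id from the file itself.
        LibSVMDataSource<Label> trainSource =
                new LibSVMDataSource<>(Paths.get("train.libsvm"), labelFactory);
        boolean zeroIndexed = trainSource.isZeroIndexed();
        int maxFeatureID = trainSource.getMaxFeatureID();
        // Pin the test source to the same conventions so the two feature
        // spaces line up, mirroring DataOptions.load in Example 5.
        LibSVMDataSource<Label> testSource = new LibSVMDataSource<>(
                Paths.get("test.libsvm"), labelFactory, zeroIndexed, maxFeatureID);
        System.out.println("zeroIndexed=" + zeroIndexed + ", maxFeatureID=" + maxFeatureID);
    }
}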

Example 3 with LibSVMDataSource

Use of org.tribuo.datasource.LibSVMDataSource in project tribuo by oracle.

From the class TestOnnxRuntime, method testCNNMNIST.

/**
 * This test checks that the ImageTransformer works and that float matrices can be processed through the LabelTransformer.
 * @throws IOException If it failed to read the file.
 * @throws OrtException If onnx-runtime failed.
 * @throws URISyntaxException If the URL failed to parse.
 */
@Test
public void testCNNMNIST() throws IOException, OrtException, URISyntaxException {
    LabelFactory labelFactory = new LabelFactory();
    try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        // Loads regular MNIST
        URL data = TestOnnxRuntime.class.getResource("/org/tribuo/interop/onnx/mnist_test_head.libsvm");
        DataSource<Label> mnistTest = new LibSVMDataSource<>(data, labelFactory, false, 784);
        Dataset<Label> dataset = new MutableDataset<>(mnistTest);
        Map<String, Integer> featureMapping = new HashMap<>();
        for (int i = 0; i < 784; i++) {
            featureMapping.put(String.format("%03d", i), i);
        }
        Map<Label, Integer> outputMapping = new HashMap<>();
        for (Label l : dataset.getOutputInfo().getDomain()) {
            outputMapping.put(l, Integer.parseInt(l.getLabel()));
        }
        ImageTransformer imageTransformer = new ImageTransformer(1, 28, 28);
        LabelTransformer labelTransformer = new LabelTransformer();
        Path testResource = Paths.get(TestOnnxRuntime.class.getResource("/org/tribuo/interop/onnx/cnn_mnist.onnx").toURI());
        ONNXExternalModel<Label> cnnModel = ONNXExternalModel.createOnnxModel(labelFactory, featureMapping, outputMapping, imageTransformer, labelTransformer, sessionOptions, testResource, "input_image");
        LabelEvaluation evaluation = labelFactory.getEvaluator().evaluate(cnnModel, mnistTest);
        // CNNs are good at MNIST
        assertEquals(1.0, evaluation.accuracy(), 1e-6);
        assertEquals(0.0, evaluation.balancedErrorRate(), 1e-6);
    }
}
Also used: Path (java.nio.file.Path), HashMap (java.util.HashMap), Label (org.tribuo.classification.Label), URL (java.net.URL), OrtEnvironment (ai.onnxruntime.OrtEnvironment), OrtSession (ai.onnxruntime.OrtSession), LabelFactory (org.tribuo.classification.LabelFactory), LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation), LibSVMDataSource (org.tribuo.datasource.LibSVMDataSource), MutableDataset (org.tribuo.MutableDataset), Test (org.junit.jupiter.api.Test)
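The ImageTransformer arguments describe how each 784-element feature vector is reshaped into an image for the CNN; assuming the order is (channels, height, width), which fits 1x28x28 MNIST, the product of the three must equal the flat feature count. A small illustrative sketch of that invariant (the shape check is our own, not a Tribuo API):

import org.tribuo.interop.onnx.DenseTransformer;
import org.tribuo.interop.onnx.ImageTransformer;

public class TransformerShapeSketch {
    public static void main(String[] args) {
        int channels = 1;
        int height = 28;
        int width = 28;
        // Every flat libsvm feature must land somewhere in the image tensor.
        if (channels * height * width != 784) {
            throw new IllegalStateException("shape does not cover 784 features");
        }
        // Image-shaped input for the CNN, as in testCNNMNIST above.
        ImageTransformer imageInput = new ImageTransformer(channels, height, width);
        // Flat 784-element input, as used by the logistic regression tests.
        DenseTransformer flatInput = new DenseTransformer();
        System.out.println("created " + imageInput + " and " + flatInput);
    }
}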

Example 4 with LibSVMDataSource

Use of org.tribuo.datasource.LibSVMDataSource in project tribuo by oracle.

From the class TestOnnxRuntime, method testMNIST.

/**
 * This test checks the feature mapping logic: the model was trained with a transposed
 * feature mapping, but the data is loaded using the standard mapping.
 * @throws IOException If it failed to read the file.
 * @throws OrtException If onnx-runtime failed.
 * @throws URISyntaxException If the URL failed to parse.
 */
@Test
public void testMNIST() throws IOException, OrtException, URISyntaxException {
    LabelFactory labelFactory = new LabelFactory();
    try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        // Loads regular MNIST
        URL data = TestOnnxRuntime.class.getResource("/org/tribuo/interop/onnx/mnist_test_head.libsvm");
        DataSource<Label> transposedMNIST = new LibSVMDataSource<>(data, labelFactory, false, 784);
        Dataset<Label> dataset = new MutableDataset<>(transposedMNIST);
        Map<String, Integer> featureMapping = new HashMap<>();
        for (int i = 0; i < 784; i++) {
            // This MNIST model has the feature indices transposed to test a non-trivial mapping.
            int id = (783 - i);
            featureMapping.put(String.format("%03d", i), id);
        }
        Map<Label, Integer> outputMapping = new HashMap<>();
        for (Label l : dataset.getOutputInfo().getDomain()) {
            outputMapping.put(l, Integer.parseInt(l.getLabel()));
        }
        DenseTransformer denseTransformer = new DenseTransformer();
        LabelTransformer labelTransformer = new LabelTransformer();
        Path testResource = Paths.get(TestOnnxRuntime.class.getResource("/org/tribuo/interop/onnx/lr_mnist.onnx").toURI());
        ONNXExternalModel<Label> transposedMNISTLR = ONNXExternalModel.createOnnxModel(labelFactory, featureMapping, outputMapping, denseTransformer, labelTransformer, sessionOptions, testResource, "float_input");
        // This model doesn't have a free batch size parameter on the output
        transposedMNISTLR.setBatchSize(1);
        LabelEvaluation evaluation = labelFactory.getEvaluator().evaluate(transposedMNISTLR, transposedMNIST);
        assertEquals(0.967741, evaluation.accuracy(), 1e-6);
        assertEquals(0.024285, evaluation.balancedErrorRate(), 1e-6);
        Helpers.testModelSerialization(transposedMNISTLR, Label.class);
    }
}
Also used: Path (java.nio.file.Path), HashMap (java.util.HashMap), Label (org.tribuo.classification.Label), URL (java.net.URL), OrtEnvironment (ai.onnxruntime.OrtEnvironment), OrtSession (ai.onnxruntime.OrtSession), LabelFactory (org.tribuo.classification.LabelFactory), LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation), LibSVMDataSource (org.tribuo.datasource.LibSVMDataSource), MutableDataset (org.tribuo.MutableDataset), Test (org.junit.jupiter.api.Test)
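Note that this test and Example 2 exercise the same lr_mnist.onnx model from opposite directions and land on identical metrics (accuracy 0.967741): Example 2 feeds pre-transposed data through an identity mapping, while this one feeds standard data through a reversing mapping. That works because the reversal is self-inverse; a quick standalone check:

public class ReversalCheck {
    public static void main(String[] args) {
        // 783 - (783 - i) == i, so applying the reversal to already-reversed
        // data is the identity: pre-transposed data + identity mapping
        // (Example 2) and standard data + reversing mapping (this test)
        // present the model with the same input layout.
        for (int i = 0; i < 784; i++) {
            int id = 783 - i;
            if (783 - id != i) {
                throw new AssertionError("reversal is not self-inverse at " + i);
            }
        }
        System.out.println("Reversal mapping is self-inverse over all 784 indices.");
    }
}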

Example 5 with LibSVMDataSource

Use of org.tribuo.datasource.LibSVMDataSource in project tribuo by oracle.

From the class DataOptions, method load.

/**
 * Loads the training and testing data from {@link #trainingPath} and {@link #testingPath}
 * according to the other parameters specified in this class.
 * @param outputFactory The output factory to use to process the inputs.
 * @param <T> The dataset output type.
 * @return A pair containing the training and testing datasets. The training dataset is element 'A' and the
 * testing dataset is element 'B'.
 * @throws IOException If the paths could not be loaded.
 */
public <T extends Output<T>> Pair<Dataset<T>, Dataset<T>> load(OutputFactory<T> outputFactory) throws IOException {
    logger.info(String.format("Loading data from %s", trainingPath));
    Dataset<T> train;
    Dataset<T> test;
    char separator;
    switch(inputFormat) {
        case SERIALIZED:
            // 
            // Load Tribuo serialised datasets.
            logger.info("Deserialising dataset from " + trainingPath);
            try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(trainingPath.toFile())));
                ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(testingPath.toFile())))) {
                @SuppressWarnings("unchecked") Dataset<T> tmp = (Dataset<T>) ois.readObject();
                train = tmp;
                if (minCount > 0) {
                    logger.info("Found " + train.getFeatureIDMap().size() + " features");
                    logger.info("Removing features that occur fewer than " + minCount + " times.");
                    train = new MinimumCardinalityDataset<>(train, minCount);
                }
                logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
                @SuppressWarnings("unchecked") Dataset<T> deserTest = (Dataset<T>) oits.readObject();
                test = new ImmutableDataset<>(deserTest, deserTest.getSourceProvenance(), deserTest.getOutputFactory(), train.getFeatureIDMap(), train.getOutputIDInfo(), true);
            } catch (ClassNotFoundException e) {
                throw new IllegalArgumentException("Unknown class in serialised files", e);
            }
            break;
        case LIBSVM:
            // 
            // Load the libsvm text-based data format.
            LibSVMDataSource<T> trainSVMSource = new LibSVMDataSource<>(trainingPath, outputFactory);
            train = new MutableDataset<>(trainSVMSource);
            boolean zeroIndexed = trainSVMSource.isZeroIndexed();
            int maxFeatureID = trainSVMSource.getMaxFeatureID();
            if (minCount > 0) {
                logger.info("Removing features that occur fewer than " + minCount + " times.");
                train = new MinimumCardinalityDataset<>(train, minCount);
            }
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            test = new ImmutableDataset<>(new LibSVMDataSource<>(testingPath, outputFactory, zeroIndexed, maxFeatureID), train.getFeatureIDMap(), train.getOutputIDInfo(), false);
            break;
        case TEXT:
            // 
            // Using a simple Java break iterator to generate ngram features.
            TextFeatureExtractor<T> extractor;
            if (hashDim > 0) {
                extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), ngram, termCounting, hashDim));
            } else {
                extractor = new TextFeatureExtractorImpl<>(new TokenPipeline(new BreakIteratorTokenizer(Locale.US), ngram, termCounting));
            }
            TextDataSource<T> trainSource = new SimpleTextDataSource<>(trainingPath, outputFactory, extractor);
            train = new MutableDataset<>(trainSource);
            if (minCount > 0) {
                logger.info("Removing features that occur fewer than " + minCount + " times.");
                train = new MinimumCardinalityDataset<>(train, minCount);
            }
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            TextDataSource<T> testSource = new SimpleTextDataSource<>(testingPath, outputFactory, extractor);
            test = new ImmutableDataset<>(testSource, train.getFeatureIDMap(), train.getOutputIDInfo(), false);
            break;
        case CSV:
            // Load the data using the simple CSV loader
            if (csvResponseName == null) {
                throw new IllegalArgumentException("Please supply a response column name");
            }
            separator = delimiter.value;
            CSVLoader<T> loader = new CSVLoader<>(separator, outputFactory);
            train = new MutableDataset<>(loader.loadDataSource(trainingPath, csvResponseName));
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            test = new MutableDataset<>(loader.loadDataSource(testingPath, csvResponseName));
            break;
        case COLUMNAR:
            if (rowProcessor == null) {
                throw new IllegalArgumentException("Please supply a RowProcessor");
            }
            OutputFactory<?> rowOutputFactory = rowProcessor.getResponseProcessor().getOutputFactory();
            if (!rowOutputFactory.equals(outputFactory)) {
                throw new IllegalArgumentException("The RowProcessor doesn't use the same kind of OutputFactory as the one supplied. RowProcessor has " + rowOutputFactory.getClass().getSimpleName() + ", supplied " + outputFactory.getClass().getName());
            }
            // checked by the if statement above
            @SuppressWarnings("unchecked") RowProcessor<T> typedRowProcessor = (RowProcessor<T>) rowProcessor;
            separator = delimiter.value;
            train = new MutableDataset<>(new CSVDataSource<>(trainingPath, typedRowProcessor, true, separator, csvQuoteChar));
            logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
            logger.info("Found " + train.getFeatureIDMap().size() + " features, and " + train.getOutputInfo().size() + " response dimensions");
            test = new MutableDataset<>(new CSVDataSource<>(testingPath, typedRowProcessor, true, separator, csvQuoteChar));
            break;
        default:
            throw new IllegalArgumentException("Unsupported input format " + inputFormat);
    }
    logger.info(String.format("Loaded %d testing examples", test.size()));
    if (scaleFeatures) {
        logger.info("Fitting feature scaling");
        TransformationMap map = new TransformationMap(Collections.singletonList(new LinearScalingTransformation()));
        TransformerMap transformers = train.createTransformers(map, scaleIncZeros);
        logger.info("Applying scaling to training dataset");
        train = transformers.transformDataset(train);
        logger.info("Applying scaling to testing dataset");
        test = transformers.transformDataset(test);
    }
    return new Pair<>(train, test);
}
Also used: TransformerMap (org.tribuo.transform.TransformerMap), CSVDataSource (org.tribuo.data.csv.CSVDataSource), SimpleTextDataSource (org.tribuo.data.text.impl.SimpleTextDataSource), TransformationMap (org.tribuo.transform.TransformationMap), LinearScalingTransformation (org.tribuo.transform.transformations.LinearScalingTransformation), BufferedInputStream (java.io.BufferedInputStream), LibSVMDataSource (org.tribuo.datasource.LibSVMDataSource), RowProcessor (org.tribuo.data.columnar.RowProcessor), Pair (com.oracle.labs.mlrg.olcut.util.Pair), CSVLoader (org.tribuo.data.csv.CSVLoader), ImmutableDataset (org.tribuo.ImmutableDataset), Dataset (org.tribuo.Dataset), MinimumCardinalityDataset (org.tribuo.dataset.MinimumCardinalityDataset), MutableDataset (org.tribuo.MutableDataset), FileInputStream (java.io.FileInputStream), BreakIteratorTokenizer (org.tribuo.util.tokens.impl.BreakIteratorTokenizer), TokenPipeline (org.tribuo.data.text.impl.TokenPipeline), ObjectInputStream (java.io.ObjectInputStream)
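A minimal sketch of driving load directly. This assumes DataOptions can be constructed by hand and that its public option fields (trainingPath, testingPath, and inputFormat, all referenced in the method above) can be set in code, although in practice OLCUT's command-line parser populates them; the InputFormat values match the switch cases above, and the file paths are hypothetical:

import java.io.IOException;
import java.nio.file.Paths;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.data.DataOptions;

public class LoadSketch {
    public static void main(String[] args) throws IOException {
        DataOptions options = new DataOptions();
        // Normally populated from the command line via OLCUT.
        options.trainingPath = Paths.get("train.libsvm");
        options.testingPath = Paths.get("test.libsvm");
        options.inputFormat = DataOptions.InputFormat.LIBSVM;

        Pair<Dataset<Label>, Dataset<Label>> data = options.load(new LabelFactory());
        Dataset<Label> train = data.getA(); // element 'A' is the training set
        Dataset<Label> test = data.getB();  // element 'B' is the testing set
        System.out.println("train=" + train.size() + ", test=" + test.size());
    }
}

Per the method's javadoc, the pair's 'A' element is the training dataset and 'B' the testing dataset.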

Aggregations

LibSVMDataSource (org.tribuo.datasource.LibSVMDataSource): 8 uses
MutableDataset (org.tribuo.MutableDataset): 7 uses
Label (org.tribuo.classification.Label): 7 uses
LabelFactory (org.tribuo.classification.LabelFactory): 7 uses
Path (java.nio.file.Path): 6 uses
LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation): 6 uses
URL (java.net.URL): 5 uses
HashMap (java.util.HashMap): 5 uses
Test (org.junit.jupiter.api.Test): 5 uses
OrtEnvironment (ai.onnxruntime.OrtEnvironment): 3 uses
OrtSession (ai.onnxruntime.OrtSession): 3 uses
ImmutableDataset (org.tribuo.ImmutableDataset): 3 uses
Pair (com.oracle.labs.mlrg.olcut.util.Pair): 2 uses
BufferedInputStream (java.io.BufferedInputStream): 2 uses
FileInputStream (java.io.FileInputStream): 2 uses
ObjectInputStream (java.io.ObjectInputStream): 2 uses
Dataset (org.tribuo.Dataset): 2 uses
CSVLoader (org.tribuo.data.csv.CSVLoader): 2 uses
SimpleTextDataSource (org.tribuo.data.text.impl.SimpleTextDataSource): 2 uses
TokenPipeline (org.tribuo.data.text.impl.TokenPipeline): 2 uses