Search in sources :

Example 1 with VariableInfo

use of org.tribuo.VariableInfo in project tribuo by oracle.

the class RowProcessor method expandRegexMapping.

/**
 * Expands the regex mappings against the features found in the supplied feature map.
 * <p>
 * One name per feature is derived by applying {@code FEATURE_NAME_PATTERN} to the feature's
 * name, and the collected names are passed to the overload of {@code expandRegexMapping}
 * which accepts them.
 * <p>
 * Uses similar logic to {@link org.tribuo.transform.TransformationMap#validateTransformations} to check the regexes
 * against the supplied feature map. Throws an IllegalArgumentException if any regexes overlap with
 * themselves, or with the currently defined set of fieldProcessorMap.
 * @param featureMap The feature map to use to expand the regexes.
 */
public void expandRegexMapping(ImmutableFeatureMap featureMap) {
    ArrayList<String> collectedNames = new ArrayList<>(featureMap.size());
    for (VariableInfo featureInfo : featureMap) {
        // NOTE(review): with a limit of 1 the pattern is applied zero times, so element 0 is
        // always the complete feature name. If the intent was to strip everything after the
        // first separator, the limit would need to be 2 - confirm against the overload that
        // consumes these names (it would then receive repeated field names).
        collectedNames.add(FEATURE_NAME_PATTERN.split(featureInfo.getName(), 1)[0]);
    }
    expandRegexMapping(collectedNames);
}
Also used : VariableInfo(org.tribuo.VariableInfo) ArrayList(java.util.ArrayList)

Example 2 with VariableInfo

use of org.tribuo.VariableInfo in project tribuo by oracle.

the class ClassificationTest method classificationCNNTest.

/**
 * Trains a LeNet style CNN on generated image data, smoke tests its accuracy, then checks
 * that Tribuo serialization and the saved model bundle export both work, and that an external
 * model loaded from the exported bundle produces the same predictions as the trained model.
 * @throws IOException If the temporary export directory cannot be created, listed or walked.
 */
@Test
public void classificationCNNTest() throws IOException {
    // Create the train and test data
    Pair<Dataset<Label>, Dataset<Label>> data = generateImageData(512, 10, 128, 5, 42);
    Dataset<Label> trainData = data.getA();
    Dataset<Label> testData = data.getB();
    // Build the CNN graph
    GraphDefTuple graphDefTuple = CNNExamples.buildLeNetGraph(INPUT_NAME, 10, 255, trainData.getOutputs().size());
    // Configure the trainer
    Map<String, Float> gradientParams = new HashMap<>();
    gradientParams.put("learningRate", 0.01f);
    gradientParams.put("initialAccumulatorValue", 0.1f);
    FeatureConverter imageConverter = new ImageConverter(INPUT_NAME, 10, 10, 1);
    OutputConverter<Label> outputConverter = new LabelConverter();
    TensorFlowTrainer<Label> trainer = new TensorFlowTrainer<>(graphDefTuple.graphDef, graphDefTuple.outputName, GradientOptimiser.ADAGRAD, gradientParams, imageConverter, outputConverter, 16, 5, 16, -1);
    // Train the model
    TensorFlowModel<Label> model = trainer.train(trainData);
    // Make some predictions
    List<Prediction<Label>> predictions = model.predict(testData);
    // Run smoke test evaluation
    LabelEvaluation eval = new LabelEvaluator().evaluate(model, predictions, testData.getProvenance());
    Assertions.assertTrue(eval.accuracy() > 0.0);
    // Check Tribuo serialization
    Helpers.testModelSerialization(model, Label.class);
    // Check saved model bundle export
    Path outputPath = Files.createTempDirectory("tf-classification-cnn-test");
    model.exportModel(outputPath.toString());
    try (Stream<Path> f = Files.list(outputPath)) {
        List<Path> files = f.collect(Collectors.toList());
        Assertions.assertNotEquals(0, files.size());
    }
    // Create external model from bundle
    Map<Label, Integer> outputMapping = new HashMap<>();
    for (Pair<Integer, Label> p : model.getOutputIDInfo()) {
        outputMapping.put(p.getB(), p.getA());
    }
    Map<String, Integer> featureMapping = new HashMap<>();
    ImmutableFeatureMap featureIDMap = model.getFeatureIDMap();
    for (VariableInfo info : featureIDMap) {
        featureMapping.put(info.getName(), featureIDMap.getID(info.getName()));
    }
    TensorFlowSavedModelExternalModel<Label> externalModel = TensorFlowSavedModelExternalModel.createTensorflowModel(trainData.getOutputFactory(), featureMapping, outputMapping, model.getOutputName(), imageConverter, outputConverter, outputPath.toString());
    // Check predictions are equal
    List<Prediction<Label>> externalPredictions = externalModel.predict(testData);
    checkPredictionEquality(predictions, externalPredictions);
    // Cleanup saved model bundle
    externalModel.close();
    // The stream returned by Files.walk keeps directories open until it is closed,
    // so it is scoped with try-with-resources like the Files.list call above.
    // Reverse ordering deletes children before their parent directories.
    try (Stream<Path> bundlePaths = Files.walk(outputPath)) {
        bundlePaths.sorted(Comparator.reverseOrder()).map(Path::toFile).forEach(File::delete);
    }
    Assertions.assertFalse(Files.exists(outputPath));
    // Cleanup created model
    model.close();
}
Also used : GraphDefTuple(org.tribuo.interop.tensorflow.example.GraphDefTuple) HashMap(java.util.HashMap) Label(org.tribuo.classification.Label) LabelEvaluation(org.tribuo.classification.evaluation.LabelEvaluation) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) Path(java.nio.file.Path) Dataset(org.tribuo.Dataset) MutableDataset(org.tribuo.MutableDataset) Prediction(org.tribuo.Prediction) VariableInfo(org.tribuo.VariableInfo) LabelEvaluator(org.tribuo.classification.evaluation.LabelEvaluator) File(java.io.File) Test(org.junit.jupiter.api.Test)

Example 3 with VariableInfo

use of org.tribuo.VariableInfo in project tribuo by oracle.

the class LIMEBase method samplePoint.

/**
 * Draws one perturbed example using the feature statistics in the feature map and the
 * values in the input vector.
 * <p>
 * Categorical features are drawn with {@code frequencyBasedSample} and kept only when the
 * draw is not (numerically) zero. Real features are kept with probability
 * {@code count / numTrainingExamples} and, when kept, are drawn from a gaussian centred on
 * the input's value for that feature with the feature's variance.
 * @param rng The rng to use.
 * @param fMap The feature map describing the domain of the features.
 * @param numTrainingExamples The number of training examples the fMap has seen.
 * @param input The input sparse vector to use.
 * @return An Example sampled from the supplied feature map and input vector.
 * @throws IllegalStateException If a feature is neither a CategoricalInfo nor a RealInfo.
 */
public static Example<Label> samplePoint(Random rng, ImmutableFeatureMap fMap, long numTrainingExamples, SparseVector input) {
    ArrayList<String> sampledNames = new ArrayList<>();
    ArrayList<Double> sampledValues = new ArrayList<>();
    for (VariableInfo featureInfo : fMap) {
        int featureID = ((VariableIDInfo) featureInfo).getID();
        double observedValue = input.get(featureID);
        if (featureInfo instanceof CategoricalInfo) {
            // Categorical info implicitly contains a zero, so a zero draw means "feature absent".
            double drawn = ((CategoricalInfo) featureInfo).frequencyBasedSample(rng, numTrainingExamples);
            if (Math.abs(drawn) > 1e-10) {
                sampledNames.add(featureInfo.getName());
                sampledValues.add(drawn);
            }
            continue;
        }
        if (!(featureInfo instanceof RealInfo)) {
            throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + featureInfo.getClass().getName());
        }
        RealInfo realInfo = (RealInfo) featureInfo;
        // Real features are sparse, so the draw comes from a mixture: either zero (feature
        // omitted) or N(observedValue, variance). This relies on realInfo never having
        // observed a zero, which is enforced from v2.1.
        // TODO check this makes sense. If the input value is zero do we still want to sample spike and slab?
        // If it's not zero do we want to?
        double presenceProbability = realInfo.getCount() / ((double) numTrainingExamples);
        if (rng.nextDouble() < presenceProbability) {
            double stdDev = Math.sqrt(realInfo.getVariance());
            double drawn = (rng.nextGaussian() * stdDev) + observedValue;
            sampledNames.add(featureInfo.getName());
            sampledValues.add(drawn);
        }
    }
    String[] nameArray = sampledNames.toArray(new String[0]);
    return new ArrayExample<>(LabelFactory.UNKNOWN_LABEL, nameArray, Util.toPrimitiveDouble(sampledValues));
}
Also used : VariableInfo(org.tribuo.VariableInfo) ArrayList(java.util.ArrayList) CategoricalInfo(org.tribuo.CategoricalInfo) ArrayExample(org.tribuo.impl.ArrayExample) VariableIDInfo(org.tribuo.VariableIDInfo) RealInfo(org.tribuo.RealInfo)

Example 4 with VariableInfo

use of org.tribuo.VariableInfo in project tribuo by oracle.

the class LIMEColumnar method sampleData.

/**
 * Samples a dataset based on the provided text, tokens and tabular features.
 *
 * The text features are sampled using the {@link LIMEText} sampling approach,
 * and the tabular features are sampled using the {@link LIMEBase} approach.
 *
 * The weight for each example is based on the distance for the tabular features,
 * combined with the distance for the text features (which is a hamming distance).
 * These distances are averaged using a weight function representing how many tokens
 * there are in the text fields, and how many tabular features there are.
 *
 * When there are no tokens at all the text distance is taken to be zero, so the
 * weight is driven by the tabular distance alone.
 *
 * This weight calculation is subject to change, as it's not necessarily optimal.
 * @param tabularVector The tabular (i.e., non-text) features.
 * @param text A map from the field names to the field values for the text fields.
 * @param textTokens A map from the field names to lists of tokens for those fields.
 * @return A sampled dataset.
 */
private List<Example<Regressor>> sampleData(SparseVector tabularVector, Map<String, String> text, Map<String, List<Token>> textTokens) {
    List<Example<Regressor>> output = new ArrayList<>();
    // A dedicated RNG for this call, seeded from the instance level rng.
    Random innerRNG = new Random(rng.nextLong());
    for (int i = 0; i < numSamples; i++) {
        // Create the full example
        ListExample<Label> sampledExample = new ListExample<>(LabelFactory.UNKNOWN_LABEL);
        // Tabular features.
        List<Feature> tabularFeatures = new ArrayList<>();
        // Sample the categorical and real features
        for (VariableInfo info : tabularDomain) {
            int id = ((VariableIDInfo) info).getID();
            double inputValue = tabularVector.get(id);
            if (info instanceof CategoricalInfo) {
                // This one is tricksy as categorical info essentially implicitly includes a zero.
                CategoricalInfo catInfo = (CategoricalInfo) info;
                double sample = catInfo.frequencyBasedSample(innerRNG, numTrainingExamples);
                // If we didn't sample zero.
                if (Math.abs(sample) > 1e-10) {
                    Feature newFeature = new Feature(info.getName(), sample);
                    tabularFeatures.add(newFeature);
                }
            } else if (info instanceof RealInfo) {
                RealInfo realInfo = (RealInfo) info;
                // As realInfo is sparse we sample from the mixture distribution,
                // either 0 or N(inputValue,variance).
                // This assumes realInfo never observed a zero, which is enforced from v2.1
                // TODO check this makes sense. If the input value is zero do we still want to sample spike and slab?
                // If it's not zero do we want to?
                int count = realInfo.getCount();
                double threshold = count / ((double) numTrainingExamples);
                if (innerRNG.nextDouble() < threshold) {
                    double variance = realInfo.getVariance();
                    double sample = (innerRNG.nextGaussian() * Math.sqrt(variance)) + inputValue;
                    Feature newFeature = new Feature(info.getName(), sample);
                    tabularFeatures.add(newFeature);
                }
            } else {
                throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
            }
        }
        // Sample the binarised categorical features
        for (Map.Entry<String, double[]> e : binarisedCDFs.entrySet()) {
            // Sample from the CDF
            int sample = Util.sampleFromCDF(e.getValue(), innerRNG);
            // If the sample isn't zero (which is defined to be the last value to make the indices work)
            if (sample != (e.getValue().length - 1)) {
                VariableInfo info = binarisedInfos.get(e.getKey()).get(sample);
                Feature newFeature = new Feature(info.getName(), 1);
                tabularFeatures.add(newFeature);
            }
        }
        // Add the tabular features to the current example
        sampledExample.addAll(tabularFeatures);
        // Calculate tabular distance
        double tabularDistance = measureDistance(tabularDomain, numTrainingExamples, tabularVector, SparseVector.createSparseVector(sampledExample, tabularDomain, false));
        // features are the full text features
        List<Feature> textFeatures = new ArrayList<>();
        // Perturbed features are the binarised tokens
        List<Feature> perturbedFeatures = new ArrayList<>();
        // Sample the text features
        double textDistance = 0.0;
        long numTokens = 0;
        for (Map.Entry<String, String> e : text.entrySet()) {
            String curText = e.getValue();
            List<Token> tokens = textTokens.get(e.getKey());
            numTokens += tokens.size();
            // Sample a new Example.
            int[] activeFeatures = new int[tokens.size()];
            char[] sampledText = curText.toCharArray();
            for (int j = 0; j < activeFeatures.length; j++) {
                // Each token is independently kept (1) or dropped (0).
                activeFeatures[j] = innerRNG.nextInt(2);
                if (activeFeatures[j] == 0) {
                    textDistance++;
                    Token curToken = tokens.get(j);
                    // Blank out the dropped token's characters, they are stripped below.
                    Arrays.fill(sampledText, curToken.start, curToken.end, '\0');
                }
            }
            String sampledString = new String(sampledText);
            sampledString = sampledString.replace("\0", "");
            textFeatures.addAll(textFields.get(e.getKey()).process(sampledString));
            for (int j = 0; j < activeFeatures.length; j++) {
                perturbedFeatures.add(new Feature(nameFeature(e.getKey(), tokens.get(j).text, j), activeFeatures[j]));
            }
        }
        // Add the text features to the current example
        sampledExample.addAll(textFeatures);
        // Calculate text distance, the fraction of tokens which were dropped.
        // With no tokens there is nothing to perturb, so the distance is zero rather than 0/0 = NaN.
        double totalTextDistance = numTokens == 0 ? 0.0 : textDistance / numTokens;
        // Label it using the full model.
        Prediction<Label> samplePrediction = innerModel.predict(sampledExample);
        double totalLength = tabularFeatures.size() + perturbedFeatures.size();
        // Combine the distances and transform into a weight
        // Currently this averages the two values based on their relative sizes.
        // Each term is scaled by its own feature count before dividing by the total count;
        // previously the tabular count multiplied the sum of both terms.
        // NOTE(review): totalLength is zero when no tabular features were sampled and there are
        // no tokens, which makes the weight NaN - confirm how that case should be weighted.
        double tabularTerm = tabularFeatures.size() * kernelDist(tabularDistance, kernelWidth);
        double textTerm = perturbedFeatures.size() * totalTextDistance;
        double weight = 1.0 - ((tabularTerm + textTerm) / totalLength);
        // Generate the new sample with the appropriate label and weight.
        ArrayExample<Regressor> labelledSample = new ArrayExample<>(transformOutput(samplePrediction), (float) weight);
        labelledSample.addAll(tabularFeatures);
        labelledSample.addAll(perturbedFeatures);
        output.add(labelledSample);
    }
    return output;
}
Also used : ArrayList(java.util.ArrayList) Label(org.tribuo.classification.Label) Token(org.tribuo.util.tokens.Token) ColumnarFeature(org.tribuo.data.columnar.ColumnarFeature) Feature(org.tribuo.Feature) ArrayExample(org.tribuo.impl.ArrayExample) Random(java.util.Random) SplittableRandom(java.util.SplittableRandom) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) ListExample(org.tribuo.impl.ListExample) VariableIDInfo(org.tribuo.VariableIDInfo) Regressor(org.tribuo.regression.Regressor) RealInfo(org.tribuo.RealInfo) ListExample(org.tribuo.impl.ListExample) VariableInfo(org.tribuo.VariableInfo) CategoricalInfo(org.tribuo.CategoricalInfo) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) HashMap(java.util.HashMap) Map(java.util.Map)

Example 5 with VariableInfo

use of org.tribuo.VariableInfo in project tribuo by oracle.

the class TestLibSVM method testOnnxSerialization.

/**
 * Trains a LibSVM regression model, exports it to a temporary ONNX file, and on x86_64
 * platforms reloads it through ONNX Runtime to check that the predictions and the embedded
 * provenance match the native model. On other architectures only the export is exercised
 * and a warning is logged.
 * <p>
 * The temporary file and the ONNX model are released even when an assertion fails.
 * @param datasetPair The train (A) and test (B) datasets.
 * @param trainer The trainer used to build the model.
 * @throws IOException If the temporary file cannot be created or the model cannot be written.
 * @throws OrtException If ONNX Runtime fails.
 */
private static void testOnnxSerialization(Pair<Dataset<Regressor>, Dataset<Regressor>> datasetPair, LibSVMRegressionTrainer trainer) throws IOException, OrtException {
    LibSVMRegressionModel model = (LibSVMRegressionModel) trainer.train(datasetPair.getA());
    Path onnxFile = Files.createTempFile("tribuo-libsvm-test", ".onnx");
    try {
        // Write out model
        model.saveONNXModel("org.tribuo.regression.libsvm.test", 1, onnxFile);
        // Prep mappings
        Map<String, Integer> featureMapping = new HashMap<>();
        for (VariableInfo f : model.getFeatureIDMap()) {
            VariableIDInfo id = (VariableIDInfo) f;
            featureMapping.put(id.getName(), id.getID());
        }
        Map<Regressor, Integer> outputMapping = new HashMap<>();
        for (Pair<Integer, Regressor> l : model.getOutputIDInfo()) {
            outputMapping.put(l.getB(), l.getA());
        }
        String arch = System.getProperty("os.arch");
        if (arch.equalsIgnoreCase("amd64") || arch.equalsIgnoreCase("x86_64")) {
            // Initialise the OrtEnvironment to load the native library
            // (as OrtSession.SessionOptions doesn't trigger the static initializer).
            OrtEnvironment env = OrtEnvironment.getEnvironment();
            env.close();
            // Load in via ORT
            ONNXExternalModel<Regressor> onnxModel = ONNXExternalModel.createOnnxModel(new RegressionFactory(), featureMapping, outputMapping, new DenseTransformer(), new RegressorTransformer(), new OrtSession.SessionOptions(), onnxFile, "input");
            try {
                // Generate predictions
                List<Prediction<Regressor>> nativePredictions = model.predict(datasetPair.getB());
                List<Prediction<Regressor>> onnxPredictions = onnxModel.predict(datasetPair.getB());
                // Both models must predict for every example, otherwise the pairwise
                // comparison below would either overrun or silently skip predictions.
                assertEquals(nativePredictions.size(), onnxPredictions.size());
                // Assert the predictions are identical
                for (int i = 0; i < nativePredictions.size(); i++) {
                    Prediction<Regressor> tribuo = nativePredictions.get(i);
                    Prediction<Regressor> external = onnxPredictions.get(i);
                    assertArrayEquals(tribuo.getOutput().getNames(), external.getOutput().getNames());
                    if (model.isStandardized()) {
                        // Standardized models are less numerically stable when cast into floats
                        double[] tribuoValues = tribuo.getOutput().getValues();
                        double[] externalValues = external.getOutput().getValues();
                        assertEquals(tribuoValues.length, externalValues.length);
                        for (int j = 0; j < tribuoValues.length; j++) {
                            // compute sf comparison, the tolerance is relative to the native value
                            assertEquals(tribuoValues[j], externalValues[j], Math.abs(tribuoValues[j]) / 1e3);
                        }
                    } else {
                        assertArrayEquals(tribuo.getOutput().getValues(), external.getOutput().getValues(), 1e-4);
                    }
                }
                // Check that the provenance can be extracted and is the same
                ModelProvenance modelProv = model.getProvenance();
                Optional<ModelProvenance> optProv = onnxModel.getTribuoProvenance();
                assertTrue(optProv.isPresent());
                ModelProvenance onnxProv = optProv.get();
                assertNotSame(onnxProv, modelProv);
                assertEquals(modelProv, onnxProv);
            } finally {
                // Release the ONNX model even if one of the assertions above failed.
                onnxModel.close();
            }
        } else {
            logger.warning("ORT based tests only supported on x86_64, found " + arch);
        }
    } finally {
        // Best-effort removal of the temporary file, a failed delete must not mask a test failure.
        onnxFile.toFile().delete();
    }
}
Also used : RegressorTransformer(org.tribuo.interop.onnx.RegressorTransformer) HashMap(java.util.HashMap) RegressionFactory(org.tribuo.regression.RegressionFactory) OrtEnvironment(ai.onnxruntime.OrtEnvironment) DenseTransformer(org.tribuo.interop.onnx.DenseTransformer) VariableIDInfo(org.tribuo.VariableIDInfo) Regressor(org.tribuo.regression.Regressor) Path(java.nio.file.Path) VariableInfo(org.tribuo.VariableInfo) Prediction(org.tribuo.Prediction) ModelProvenance(org.tribuo.provenance.ModelProvenance) OrtSession(ai.onnxruntime.OrtSession)

Aggregations

VariableInfo (org.tribuo.VariableInfo)15 HashMap (java.util.HashMap)9 ArrayList (java.util.ArrayList)8 VariableIDInfo (org.tribuo.VariableIDInfo)8 Prediction (org.tribuo.Prediction)5 OrtEnvironment (ai.onnxruntime.OrtEnvironment)4 OrtSession (ai.onnxruntime.OrtSession)4 Map (java.util.Map)4 Feature (org.tribuo.Feature)4 ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap)4 Label (org.tribuo.classification.Label)4 ModelProvenance (org.tribuo.provenance.ModelProvenance)4 Regressor (org.tribuo.regression.Regressor)3 Path (java.nio.file.Path)2 List (java.util.List)2 CategoricalInfo (org.tribuo.CategoricalInfo)2 Dataset (org.tribuo.Dataset)2 RealInfo (org.tribuo.RealInfo)2 LabelEvaluation (org.tribuo.classification.evaluation.LabelEvaluation)2 LabelEvaluator (org.tribuo.classification.evaluation.LabelEvaluator)2