/**
 * Samples a dataset based on the provided text, tokens and tabular features.
 * The text features are sampled using the {@link LIMEText} sampling approach,
 * and the tabular features are sampled using the {@link LIMEBase} approach.
 * The weight for each example is based on the distance for the tabular features,
 * combined with the distance for the text features (which is a Hamming distance).
 * These distances are averaged using a weight function representing how many tokens
 * there are in the text fields, and how many tabular features there are.
 * This weight calculation is subject to change, as it's not necessarily optimal.
 * @param tabularVector The tabular (i.e., non-text) features.
 * @param text A map from the field names to the field values for the text fields.
 * @param textTokens A map from the field names to lists of tokens for those fields.
 * @return A sampled dataset.
 */
private List<Example<Regressor>> sampleData(SparseVector tabularVector, Map<String, String> text, Map<String, List<Token>> textTokens) {
    // NOTE(review): as it appears here this method has no closing braces, several values are
    // built but never consumed (newFeature, sampledString, textFeatures, labelledSample),
    // textDistance is never incremented and output is never appended to. The listing has
    // presumably lost lines - confirm against the upstream file before changing any logic.
    List<Example<Regressor>> output = new ArrayList<>();
    // Every random draw below uses this local RNG, seeded from a single draw of the shared rng.
    Random innerRNG = new Random(rng.nextLong());
    for (int i = 0; i < numSamples; i++) {
        // Create the full example
        ListExample<Label> sampledExample = new ListExample<>(LabelFactory.UNKNOWN_LABEL);
        // Tabular features.
        List<Feature> tabularFeatures = new ArrayList<>();
        // Sample the categorical and real features
        for (VariableInfo info : tabularDomain) {
            // The cast assumes every entry of tabularDomain is a VariableIDInfo;
            // any other type fails here with a ClassCastException.
            int id = ((VariableIDInfo) info).getID();
            double inputValue = tabularVector.get(id);
            if (info instanceof CategoricalInfo) {
                // This one is tricksy as categorical info essentially implicitly includes a zero.
                CategoricalInfo catInfo = (CategoricalInfo) info;
                double sample = catInfo.frequencyBasedSample(innerRNG, numTrainingExamples);
                // If we didn't sample zero.
                if (Math.abs(sample) > 1e-10) {
                    Feature newFeature = new Feature(info.getName(), sample);
            } else if (info instanceof RealInfo) {
                RealInfo realInfo = (RealInfo) info;
                // As realInfo is sparse we sample from the mixture distribution,
                // either 0 or N(inputValue,variance).
                // This assumes realInfo never observed a zero, which is enforced from v2.1
                // TODO check this makes sense. If the input value is zero do we still want to sample spike and slab?
                // If it's not zero do we want to?
                int count = realInfo.getCount();
                // Probability of emitting a non-zero value: the feature's count as a
                // fraction of numTrainingExamples.
                double threshold = count / ((double) numTrainingExamples);
                if (innerRNG.nextDouble() < threshold) {
                    double variance = realInfo.getVariance();
                    double sample = (innerRNG.nextGaussian() * Math.sqrt(variance)) + inputValue;
                    Feature newFeature = new Feature(info.getName(), sample);
            } else {
                throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
        // Sample the binarised categorical features
        for (Map.Entry<String, double[]> e : binarisedCDFs.entrySet()) {
            // Sample from the CDF
            int sample = Util.sampleFromCDF(e.getValue(), innerRNG);
            // If the sample isn't zero (which is defined to be the last value to make the indices work)
            if (sample != (e.getValue().length - 1)) {
                VariableInfo info = binarisedInfos.get(e.getKey()).get(sample);
                Feature newFeature = new Feature(info.getName(), 1);
        // Add the tabular features to the current example
        // NOTE(review): no visible statement adds tabularFeatures to sampledExample before the
        // distance below is computed from it - presumably one of the lost lines.
        // Calculate tabular distance
        double tabularDistance = measureDistance(tabularDomain, numTrainingExamples, tabularVector, SparseVector.createSparseVector(sampledExample, tabularDomain, false));
        // features are the full text features
        List<Feature> textFeatures = new ArrayList<>();
        // Perturbed features are the binarised tokens
        List<Feature> perturbedFeatures = new ArrayList<>();
        // Sample the text features
        double textDistance = 0.0;
        long numTokens = 0;
        for (Map.Entry<String, String> e : text.entrySet()) {
            String curText = e.getValue();
            // Assumes every key of text is also a key of textTokens; a missing key gives a
            // null list and a NullPointerException on the next line.
            List<Token> tokens = textTokens.get(e.getKey());
            numTokens += tokens.size();
            // Sample a new Example.
            int[] activeFeatures = new int[tokens.size()];
            char[] sampledText = curText.toCharArray();
            for (int j = 0; j < activeFeatures.length; j++) {
                // nextInt(2) yields 0 or 1, so each token is independently dropped or kept.
                activeFeatures[j] = innerRNG.nextInt(2);
                if (activeFeatures[j] == 0) {
                    // A dropped token has its character range [start, end) overwritten with NULs.
                    Token curToken = tokens.get(j);
                    Arrays.fill(sampledText, curToken.start, curToken.end, '\0');
            String sampledString = new String(sampledText);
            // Strip the NUL placeholders; any NUL already present in the input text is removed too.
            sampledString = sampledString.replace("\0", "");
            for (int j = 0; j < activeFeatures.length; j++) {
                // One 0/1 feature per token, named from the field, the token text and its index.
                perturbedFeatures.add(new Feature(nameFeature(e.getKey(), tokens.get(j).text, j), activeFeatures[j]));
        // Add the text features to the current example
        // Calculate text distance
        // NOTE(review): if no field contributes any tokens, numTokens is 0 and this is 0.0 / 0,
        // i.e. NaN, which then propagates into the weight.
        double totalTextDistance = textDistance / numTokens;
        // Label it using the full model.
        Prediction<Label> samplePrediction = innerModel.predict(sampledExample);
        double totalLength = tabularFeatures.size() + perturbedFeatures.size();
        // Combine the distances and transform into a weight
        // Currently this averages the two values based on their relative sizes.
        // NOTE(review): as parenthesised, tabularFeatures.size() multiplies the whole sum
        // (kernelDist(...) + perturbedFeatures.size() * totalTextDistance) before the division by
        // totalLength. A size-weighted average would instead be
        // (tabularFeatures.size() * kernelDist(...) + perturbedFeatures.size() * totalTextDistance) / totalLength
        // - confirm which is intended. totalLength can also be zero when both lists are empty.
        double weight = 1.0 - ((tabularFeatures.size() * (kernelDist(tabularDistance, kernelWidth) + perturbedFeatures.size() * totalTextDistance) / totalLength));
        // Generate the new sample with the appropriate label and weight.
        // The weight is narrowed from double to float here.
        ArrayExample<Regressor> labelledSample = new ArrayExample<>(transformOutput(samplePrediction), (float) weight);
    return output;
Also used : ArrayList(java.util.ArrayList) Label(org.tribuo.classification.Label) Token(org.tribuo.util.tokens.Token) ColumnarFeature(org.tribuo.data.columnar.ColumnarFeature) Feature(org.tribuo.Feature) ArrayExample(org.tribuo.impl.ArrayExample) Random(java.util.Random) SplittableRandom(java.util.SplittableRandom) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) ListExample(org.tribuo.impl.ListExample) VariableIDInfo(org.tribuo.VariableIDInfo) Regressor(org.tribuo.regression.Regressor) RealInfo(org.tribuo.RealInfo) ListExample(org.tribuo.impl.ListExample) VariableInfo(org.tribuo.VariableInfo) CategoricalInfo(org.tribuo.CategoricalInfo) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) HashMap(java.util.HashMap) Map(java.util.Map)

/**
 * Explains the inner model's prediction for a single piece of text.
 * <p>
 * The text is passed through {@code extractor} and {@code innerModel} to obtain the prediction
 * being explained. It is also tokenized, and one feature with value 1.0 is added per token
 * (named by {@code nameFeature} from the token text and its index) to form the example the
 * explainer is trained around. A dataset is then sampled from the text and tokens, a sparse
 * regression model is fitted to it, and that model is evaluated on its predictions for the
 * same sampled data.
 * <p>
 * NOTE(review): as it appears here the closing braces of the loop and of the method are
 * missing; the listing has presumably lost lines.
 * @param inputText The text to explain.
 * @return The explanation, wrapping the sparse model, the original prediction and the evaluation.
 */
public LIMEExplanation explain(String inputText) {
    // The prediction being explained comes from the full feature extraction pipeline.
    Example<Label> trueExample = extractor.extract(LabelFactory.UNKNOWN_LABEL, inputText);
    Prediction<Label> prediction = innerModel.predict(trueExample);
    // The explainer's example is labelled with the transformed prediction and holds one feature per token.
    ArrayExample<Regressor> bowExample = new ArrayExample<>(transformOutput(prediction));
    List<Token> tokens = tokenizerThreadLocal.get().tokenize(inputText);
    for (int i = 0; i < tokens.size(); i++) {
        bowExample.add(nameFeature(tokens.get(i).text, i), 1.0);
    // Sample a dataset.
    List<Example<Regressor>> sample = sampleData(inputText, tokens);
    // Generate a sparse model on the sampled data.
    SparseModel<Regressor> model = trainExplainer(bowExample, sample);
    // Test the sparse model against the predictions of the real model.
    List<Prediction<Regressor>> predictions = new ArrayList<>(model.predict(sample));
    RegressionEvaluation evaluation = evaluator.evaluate(model, predictions, new SimpleDataSourceProvenance("LIMEText sampled data", regressionFactory));
    return new LIMEExplanation(model, prediction, evaluation);
Also used : SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) Prediction(org.tribuo.Prediction) Label(org.tribuo.classification.Label) ArrayList(java.util.ArrayList) Token(org.tribuo.util.tokens.Token) ArrayExample(org.tribuo.impl.ArrayExample) RegressionEvaluation(org.tribuo.regression.evaluation.RegressionEvaluation) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) Regressor(org.tribuo.regression.Regressor)

/**
 * Moves the tokenizer on to the next token of {@code cs}.
 * <p>
 * A token queued in {@code nextToken} by an earlier call is returned first. Otherwise code
 * points are read starting at position {@code p}, and {@code splitFunction} classifies each
 * one to decide where tokens begin and end.
 * <p>
 * NOTE(review): as it appears here the method has no closing braces and no statement that
 * appends to {@code tokenSb}; the listing has presumably lost lines, so confirm against the
 * upstream file before changing any logic.
 * @return True if {@code currentToken} now holds a token, false if no further token could be produced.
 * @throws IllegalStateException If {@code cs} is null.
 */
public boolean advance() {
    if (cs == null) {
        throw new IllegalStateException("SplitFunctionTokenizer has not been reset.");
    // A token queued by a previous SPLIT_BEFORE_AND_AFTER is handed out before scanning further.
    if (nextToken != null) {
        currentToken = nextToken;
        nextToken = null;
        return true;
    // Nothing left to scan.
    if (p >= cs.length()) {
        return false;
    currentToken = null;
    SplitResult splitResult;
    SplitType splitType;
    TokenType tokenType;
    // Empty the token buffer before scanning.
    tokenSb.delete(0, tokenSb.length());
    while (p < cs.length()) {
        // Work in code points so that a surrogate pair is consumed as a single unit
        // (p is advanced by Character.charCount in every branch).
        int codepoint = cs.codePointAt(p);
        splitResult = splitFunction.apply(codepoint, p, cs);
        splitType = splitResult.splitType;
        tokenType = splitResult.tokenType;
        // where the end of the token is.
        if (splitType == SplitType.NO_SPLIT) {
            // The first code point of a token fixes its start offset.
            // NOTE(review): nothing visible appends the code point to tokenSb, although later
            // branches test tokenSb.length() - the append appears to be missing from this listing.
            if (tokenSb.length() == 0) {
                start = p;
            p += Character.charCount(codepoint);
            currentType = tokenType;
        if (splitType == SplitType.SPLIT_AT) {
            // Emit the buffered text (ending before this code point), then step over the
            // split code point and start the next token after it.
            if (tokenSb.length() > 0) {
                currentToken = new Token(tokenSb.toString(), start, p, currentType);
            p += Character.charCount(codepoint);
            start = p;
            tokenSb.delete(0, tokenSb.length());
        } else if (splitType == SplitType.SPLIT_BEFORE) {
            // Emit the buffered text; the next token starts at this code point.
            if (tokenSb.length() > 0) {
                currentToken = new Token(tokenSb.toString(), start, p, currentType);
            start = p;
            tokenSb.delete(0, tokenSb.length());
            p += Character.charCount(codepoint);
        } else if (splitType == SplitType.SPLIT_AFTER) {
            // p is advanced first so the emitted token's end offset includes this code point.
            p += Character.charCount(codepoint);
            // no need to check the length since we just added a code point
            currentToken = new Token(tokenSb.toString(), start, p, tokenType);
            tokenSb.delete(0, tokenSb.length());
            start = p;
        } else if (splitType == SplitType.SPLIT_BEFORE_AND_AFTER) {
            // the next token which consists of just the character
            if (tokenSb.length() > 0) {
                // Buffered text becomes the current token; the token for this code point is
                // queued in nextToken and returned by the following call.
                currentToken = new Token(tokenSb.toString(), start, p, currentType);
                tokenSb.delete(0, tokenSb.length());
                start = p;
                p += Character.charCount(codepoint);
                nextToken = new Token(tokenSb.toString(), start, p, tokenType);
                tokenSb.delete(0, tokenSb.length());
            } else {
                // Nothing buffered, so the token for this code point is the current token.
                start = p;
                p += Character.charCount(codepoint);
                currentToken = new Token(tokenSb.toString(), start, p, tokenType);
                tokenSb.delete(0, tokenSb.length());
        // NOTE(review): empty body as shown; presumably this leaves the loop once a token
        // has been produced - confirm.
        if (currentToken != null) {
    // No token was emitted during the scan: turn whatever is buffered into the final token.
    if (currentToken == null) {
        if (tokenSb.length() > 0) {
            currentToken = new Token(tokenSb.toString(), start, p, currentType);
    // We advanced if we have some stuff collected.
    if (currentToken != null) {
        ready = true;
        return true;
    } else {
        return false;
Also used : TokenType(org.tribuo.util.tokens.Token.TokenType) Token(org.tribuo.util.tokens.Token)

/**
 * Generates the wordpiece tokens from the next token.
 */
private void getWordpieceTokens() {
    // Splits currentToken into basic tokens, optionally lowercases / normalizes each one, runs
    // it through the wordpiece splitter and appends the resulting tokens to currentWordpieceTokens.
    // NOTE(review): as it appears here the method has no closing braces and two if-bodies are
    // empty; the listing has presumably lost lines - confirm against the upstream file.
    String text = currentToken.text;
    // NOTE(review): empty body as shown; presumably tokens in neverSplitTokens are passed
    // through unsplit and the method returns early - confirm.
    if (neverSplitTokens.contains(text)) {
    List<Token> basicTokens = this.basicTokenizer.tokenize(text);
    for (Token basicToken : basicTokens) {
        text = basicToken.text;
        if (toLowerCase) {
            // NOTE(review): toLowerCase() with no Locale argument uses the JVM default locale,
            // so the result can vary between machines - confirm whether Locale.ROOT is intended.
            text = text.toLowerCase();
        if (this.stripAccents) {
            text = normalize(text);
        List<String> wordpieces = wordpiece.wordpiece(text);
        // NOTE(review): empty body as shown for the no-wordpieces case - confirm what is intended.
        if (wordpieces.size() == 0) {
        } else if (wordpieces.size() == 1) {
            // A single piece covers the whole basic token. The basic token's offsets are relative
            // to currentToken.text (that is what was tokenized), so currentToken.start is added.
            String wp = wordpieces.get(0);
            int start = basicToken.start + currentToken.start;
            int end = basicToken.end + currentToken.start;
            // The splitter's unknown token is typed UNKNOWN, anything else WORD.
            if (wp.equals(this.wordpiece.getUnknownToken())) {
                currentWordpieceTokens.add(new Token(wp, start, end, TokenType.UNKNOWN));
            } else {
                currentWordpieceTokens.add(new Token(wp, start, end, TokenType.WORD));
        } else {
            // Several pieces: lay them out back to back starting at the basic token's offset.
            int begin = currentToken.start + basicToken.start;
            for (String wp : wordpieces) {
                TokenType type = TokenType.PREFIX;
                int end = begin + wp.length();
                // Pieces starting with "##" are typed SUFFIX; the two marker characters are
                // excluded from the span. All other pieces are typed PREFIX.
                // NOTE(review): the spans are derived from the piece lengths after lowercasing /
                // normalization, so they assume those steps preserve the text length - confirm.
                if (wp.startsWith("##")) {
                    end -= 2;
                    type = TokenType.SUFFIX;
                currentWordpieceTokens.add(new Token(wp, begin, end, type));
                begin = end;
Also used : TokenType(org.tribuo.util.tokens.Token.TokenType) Token(org.tribuo.util.tokens.Token)

/**
 * Explains the inner model's prediction for one columnar input and also returns the sampled
 * dataset the explainer was fitted on.
 * <p>
 * When {@code textDomain} and {@code binarisedCDFs} are both empty this delegates to
 * {@code explainWithSamples(example)}. Otherwise the text fields found in the input are
 * tokenized, one feature with value 1.0 is added per token, a dataset is sampled from the
 * tabular vector and the tokenized text, a sparse regression model is fitted to it, and that
 * model is evaluated on its predictions for the same sampled data.
 * <p>
 * NOTE(review): as it appears here several closing braces are missing and one if-body is
 * empty; the listing has presumably lost lines, so confirm against the upstream file.
 * @param input The raw input; text values are looked up in it using the keys of {@code textFields}.
 * @return A pair of the explanation and the sampled examples.
 * @throws IllegalArgumentException If {@code generator} produces no example for the input.
 */
protected Pair<LIMEExplanation, List<Example<Regressor>>> explainWithSamples(Map<String, String> input) {
    Optional<Example<Label>> optExample = generator.generateExample(input, false);
    if (optExample.isPresent()) {
        Example<Label> example = optExample.get();
        if ((textDomain.size() == 0) && (binarisedCDFs.size() == 0)) {
            // Short circuit if there are no text or binarised fields.
            return explainWithSamples(example);
        } else {
            Prediction<Label> prediction = innerModel.predict(example);
            // Build the input example with simplified text features
            ArrayExample<Regressor> labelledExample = new ArrayExample<>(transformOutput(prediction));
            // Add the tabular features
            for (Feature f : example) {
                // NOTE(review): empty body as shown; presumably features known to tabularDomain
                // (getID not returning -1) are copied into labelledExample here - confirm.
                if (tabularDomain.getID(f.getName()) != -1) {
            // Extract the tabular features into a SparseVector for later
            SparseVector tabularVector = SparseVector.createSparseVector(labelledExample, tabularDomain, false);
            // Tokenize the text fields, and generate the perturbed text representation
            Map<String, String> exampleTextValues = new HashMap<>();
            Map<String, List<Token>> exampleTextTokens = new HashMap<>();
            for (Map.Entry<String, FieldProcessor> e : textFields.entrySet()) {
                String value = input.get(e.getKey());
                // Text fields absent from the input are skipped.
                if (value != null) {
                    List<Token> tokens = tokenizerThreadLocal.get().tokenize(value);
                    for (int i = 0; i < tokens.size(); i++) {
                        // One feature per token, named from the field, the token text and its index.
                        labelledExample.add(nameFeature(e.getKey(), tokens.get(i).text, i), 1.0);
                    // Both maps receive the same key, so sampleData can look tokens up by field name.
                    exampleTextValues.put(e.getKey(), value);
                    exampleTextTokens.put(e.getKey(), tokens);
            // Sample a dataset.
            List<Example<Regressor>> sample = sampleData(tabularVector, exampleTextValues, exampleTextTokens);
            // Generate a sparse model on the sampled data.
            SparseModel<Regressor> model = trainExplainer(labelledExample, sample);
            // Test the sparse model against the predictions of the real model.
            List<Prediction<Regressor>> predictions = new ArrayList<>(model.predict(sample));
            RegressionEvaluation evaluation = evaluator.evaluate(model, predictions, new SimpleDataSourceProvenance("LIMEColumnar sampled data", regressionFactory));
            return new Pair<>(new LIMEExplanation(model, prediction, evaluation), sample);
    } else {
        // Reached whenever the generator returns an empty Optional.
        // NOTE(review): the message attributes this to a missing label - confirm that is the only cause.
        throw new IllegalArgumentException("Label not found in input " + input.toString());
Also used : HashMap(java.util.HashMap) Label(org.tribuo.classification.Label) ArrayList(java.util.ArrayList) Token(org.tribuo.util.tokens.Token) SparseVector(org.tribuo.math.la.SparseVector) ColumnarFeature(org.tribuo.data.columnar.ColumnarFeature) Feature(org.tribuo.Feature) FieldProcessor(org.tribuo.data.columnar.FieldProcessor) ArrayExample(org.tribuo.impl.ArrayExample) RegressionEvaluation(org.tribuo.regression.evaluation.RegressionEvaluation) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) ListExample(org.tribuo.impl.ListExample) ArrayList(java.util.ArrayList) List(java.util.List) Regressor(org.tribuo.regression.Regressor) Pair(com.oracle.labs.mlrg.olcut.util.Pair) SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) Prediction(org.tribuo.Prediction) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) HashMap(java.util.HashMap) Map(java.util.Map)


