Search in sources:

Example 1 with Token

Use of org.tribuo.util.tokens.Token in project Tribuo by Oracle.

In class LIMEColumnar, method sampleData:

/**
 * Samples a dataset based on the provided text, tokens and tabular features.
 *
 * The text features are sampled using the {@link LIMEText} sampling approach,
 * and the tabular features are sampled using the {@link LIMEBase} approach.
 *
 * The weight for each example combines two terms:
 * - the kernel term computed from the tabular distance;
 * - the text distance, which is the fraction of tokens dropped (a normalised hamming distance).
 * The two terms are averaged, weighted by the number of tabular features present in the
 * sample and the number of text tokens respectively. The weight is one minus that average.
 * If the sample contains neither tabular features nor tokens, only the tabular kernel term
 * is used, so the weight never becomes NaN.
 *
 * This weight calculation is subject to change, as it's not necessarily optimal.
 * @param tabularVector The tabular (i.e., non-text) features.
 * @param text A map from the field names to the field values for the text fields.
 * @param textTokens A map from the field names to lists of tokens for those fields.
 * @return A sampled dataset.
 */
private List<Example<Regressor>> sampleData(SparseVector tabularVector, Map<String, String> text, Map<String, List<Token>> textTokens) {
    List<Example<Regressor>> output = new ArrayList<>();
    Random innerRNG = new Random(rng.nextLong());
    for (int i = 0; i < numSamples; i++) {
        // Create the full example
        ListExample<Label> sampledExample = new ListExample<>(LabelFactory.UNKNOWN_LABEL);
        // Tabular features.
        List<Feature> tabularFeatures = new ArrayList<>();
        // Sample the categorical and real features
        for (VariableInfo info : tabularDomain) {
            int id = ((VariableIDInfo) info).getID();
            double inputValue = tabularVector.get(id);
            if (info instanceof CategoricalInfo) {
                // This one is tricksy as categorical info essentially implicitly includes a zero.
                CategoricalInfo catInfo = (CategoricalInfo) info;
                double sample = catInfo.frequencyBasedSample(innerRNG, numTrainingExamples);
                // If we didn't sample zero.
                if (Math.abs(sample) > 1e-10) {
                    Feature newFeature = new Feature(info.getName(), sample);
                    tabularFeatures.add(newFeature);
                }
            } else if (info instanceof RealInfo) {
                RealInfo realInfo = (RealInfo) info;
                // As realInfo is sparse we sample from the mixture distribution,
                // either 0 or N(inputValue,variance).
                // This assumes realInfo never observed a zero, which is enforced from v2.1
                // TODO check this makes sense. If the input value is zero do we still want to sample spike and slab?
                // If it's not zero do we want to?
                int count = realInfo.getCount();
                double threshold = count / ((double) numTrainingExamples);
                if (innerRNG.nextDouble() < threshold) {
                    double variance = realInfo.getVariance();
                    double sample = (innerRNG.nextGaussian() * Math.sqrt(variance)) + inputValue;
                    Feature newFeature = new Feature(info.getName(), sample);
                    tabularFeatures.add(newFeature);
                }
            } else {
                throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
            }
        }
        // Sample the binarised categorical features
        for (Map.Entry<String, double[]> e : binarisedCDFs.entrySet()) {
            // Sample from the CDF
            int sample = Util.sampleFromCDF(e.getValue(), innerRNG);
            // If the sample isn't zero (which is defined to be the last value to make the indices work)
            if (sample != (e.getValue().length - 1)) {
                VariableInfo info = binarisedInfos.get(e.getKey()).get(sample);
                Feature newFeature = new Feature(info.getName(), 1);
                tabularFeatures.add(newFeature);
            }
        }
        // Add the tabular features to the current example
        sampledExample.addAll(tabularFeatures);
        // Calculate tabular distance (sampledExample only holds tabular features at this point)
        double tabularDistance = measureDistance(tabularDomain, numTrainingExamples, tabularVector, SparseVector.createSparseVector(sampledExample, tabularDomain, false));
        // features are the full text features
        List<Feature> textFeatures = new ArrayList<>();
        // Perturbed features are the binarised tokens
        List<Feature> perturbedFeatures = new ArrayList<>();
        // Sample the text features
        double textDistance = 0.0;
        long numTokens = 0;
        for (Map.Entry<String, String> e : text.entrySet()) {
            String curText = e.getValue();
            List<Token> tokens = textTokens.get(e.getKey());
            numTokens += tokens.size();
            // Sample a new Example: each token is independently kept (1) or dropped (0).
            int[] activeFeatures = new int[tokens.size()];
            char[] sampledText = curText.toCharArray();
            for (int j = 0; j < activeFeatures.length; j++) {
                activeFeatures[j] = innerRNG.nextInt(2);
                if (activeFeatures[j] == 0) {
                    textDistance++;
                    Token curToken = tokens.get(j);
                    // Blank out the dropped token's characters, then strip the blanks below.
                    Arrays.fill(sampledText, curToken.start, curToken.end, '\0');
                }
            }
            String sampledString = new String(sampledText);
            sampledString = sampledString.replace("\0", "");
            textFeatures.addAll(textFields.get(e.getKey()).process(sampledString));
            for (int j = 0; j < activeFeatures.length; j++) {
                perturbedFeatures.add(new Feature(nameFeature(e.getKey(), tokens.get(j).text, j), activeFeatures[j]));
            }
        }
        // Add the text features to the current example
        sampledExample.addAll(textFeatures);
        // Text distance is the fraction of tokens dropped. Guard the 0/0 = NaN case when there
        // are no tokens, otherwise NaN propagates into the weight.
        double totalTextDistance = (numTokens == 0) ? 0.0 : textDistance / numTokens;
        // Label it using the full model.
        Prediction<Label> samplePrediction = innerModel.predict(sampledExample);
        int numTabular = tabularFeatures.size();
        int numText = perturbedFeatures.size();
        double totalLength = numTabular + numText;
        double tabularTerm = kernelDist(tabularDistance, kernelWidth);
        // Combine the distances and transform into a weight.
        // Each term is scaled by its own count before dividing by the total, so this is a
        // weighted average of the two terms based on their relative sizes.
        double weight;
        if (totalLength > 0) {
            weight = 1.0 - (((numTabular * tabularTerm) + (numText * totalTextDistance)) / totalLength);
        } else {
            // No sampled tabular features and no tokens: the average is undefined, so fall back
            // to the tabular kernel term alone rather than producing NaN.
            weight = 1.0 - tabularTerm;
        }
        // Generate the new sample with the appropriate label and weight.
        ArrayExample<Regressor> labelledSample = new ArrayExample<>(transformOutput(samplePrediction), (float) weight);
        labelledSample.addAll(tabularFeatures);
        labelledSample.addAll(perturbedFeatures);
        output.add(labelledSample);
    }
    return output;
}
Also used : ArrayList(java.util.ArrayList) Label(org.tribuo.classification.Label) Token(org.tribuo.util.tokens.Token) ColumnarFeature(org.tribuo.data.columnar.ColumnarFeature) Feature(org.tribuo.Feature) ArrayExample(org.tribuo.impl.ArrayExample) Random(java.util.Random) SplittableRandom(java.util.SplittableRandom) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) ListExample(org.tribuo.impl.ListExample) VariableIDInfo(org.tribuo.VariableIDInfo) Regressor(org.tribuo.regression.Regressor) RealInfo(org.tribuo.RealInfo) ListExample(org.tribuo.impl.ListExample) VariableInfo(org.tribuo.VariableInfo) CategoricalInfo(org.tribuo.CategoricalInfo) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) HashMap(java.util.HashMap) Map(java.util.Map)

Example 2 with Token

Use of org.tribuo.util.tokens.Token in project Tribuo by Oracle.

In class LIMEText, method explain:

/**
 * Produces a LIME explanation for the inner model's prediction on a piece of text.
 *
 * The text is run through the extractor and the inner model. It is then tokenized into a
 * bag-of-words example holding one feature per token (named via {@code nameFeature} from
 * the token text and its index). A perturbed dataset is sampled around that example, a
 * sparse surrogate model is trained on it, and the surrogate is evaluated on the sampled
 * data plus the bag-of-words example itself.
 *
 * @param inputText The text to explain.
 * @return The explanation holding the surrogate model, the inner model's prediction and the
 *         surrogate's evaluation.
 */
@Override
public LIMEExplanation explain(String inputText) {
    Example<Label> extracted = extractor.extract(LabelFactory.UNKNOWN_LABEL, inputText);
    Prediction<Label> innerPrediction = innerModel.predict(extracted);

    // Bag-of-words view of the input, labelled with the transformed inner prediction.
    ArrayExample<Regressor> bagOfWords = new ArrayExample<>(transformOutput(innerPrediction));
    List<Token> tokens = tokenizerThreadLocal.get().tokenize(inputText);
    int position = 0;
    for (Token token : tokens) {
        bagOfWords.add(nameFeature(token.text, position), 1.0);
        position++;
    }

    // Fit a sparse surrogate on perturbed samples around the input.
    List<Example<Regressor>> sampled = sampleData(inputText, tokens);
    SparseModel<Regressor> surrogate = trainExplainer(bagOfWords, sampled);

    // Score the surrogate on the samples and on the unperturbed input.
    List<Prediction<Regressor>> surrogatePredictions = new ArrayList<>(surrogate.predict(sampled));
    surrogatePredictions.add(surrogate.predict(bagOfWords));
    SimpleDataSourceProvenance provenance = new SimpleDataSourceProvenance("LIMEText sampled data", regressionFactory);
    RegressionEvaluation evaluation = evaluator.evaluate(surrogate, surrogatePredictions, provenance);
    return new LIMEExplanation(surrogate, innerPrediction, evaluation);
}
Also used : SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) Prediction(org.tribuo.Prediction) Label(org.tribuo.classification.Label) ArrayList(java.util.ArrayList) Token(org.tribuo.util.tokens.Token) ArrayExample(org.tribuo.impl.ArrayExample) RegressionEvaluation(org.tribuo.regression.evaluation.RegressionEvaluation) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) Regressor(org.tribuo.regression.Regressor)

Example 3 with Token

Use of org.tribuo.util.tokens.Token in project Tribuo by Oracle.

In class SplitFunctionTokenizer, method advance:

/**
 * Advances the tokenizer to the next token, if one exists.
 *
 * Code points are read from {@code cs} starting at position {@code p}. For each one, the
 * split function decides how it relates to the token being built:
 * - {@code NO_SPLIT} appends it;
 * - {@code SPLIT_AT} drops it and ends the current token;
 * - {@code SPLIT_BEFORE} ends the current token and starts a new one with it;
 * - {@code SPLIT_AFTER} appends it and ends the token;
 * - {@code SPLIT_BEFORE_AND_AFTER} emits it as a token on its own.
 * When one step produces two tokens, the second is buffered in {@code nextToken} and
 * returned by the following call.
 *
 * @return True if a token was produced, false if the input is exhausted.
 * @throws IllegalStateException If the tokenizer has not been reset.
 */
@Override
public boolean advance() {
    if (cs == null) {
        throw new IllegalStateException("SplitFunctionTokenizer has not been reset.");
    }
    if (nextToken != null) {
        currentToken = nextToken;
        nextToken = null;
        return true;
    }
    if (p >= cs.length()) {
        return false;
    }
    currentToken = null;
    SplitResult splitResult;
    SplitType splitType;
    TokenType tokenType;
    // The buffer always starts empty, so no path below may leave a consumed code point in it
    // when a token is emitted.
    tokenSb.delete(0, tokenSb.length());
    while (p < cs.length()) {
        int codepoint = cs.codePointAt(p);
        splitResult = splitFunction.apply(codepoint, p, cs);
        splitType = splitResult.splitType;
        tokenType = splitResult.tokenType;
        // Keep the character: extend the current token and record its type.
        if (splitType == SplitType.NO_SPLIT) {
            if (tokenSb.length() == 0) {
                start = p;
            }
            p += Character.charCount(codepoint);
            tokenSb.appendCodePoint(codepoint);
            currentType = tokenType;
            continue;
        }
        if (splitType == SplitType.SPLIT_AT) {
            if (tokenSb.length() > 0) {
                currentToken = new Token(tokenSb.toString(), start, p, currentType);
            }
            p += Character.charCount(codepoint);
            start = p;
            tokenSb.delete(0, tokenSb.length());
        } else if (splitType == SplitType.SPLIT_BEFORE) {
            if (tokenSb.length() > 0) {
                // Emit the pending token but do NOT consume this code point. The buffer is
                // cleared at the start of every call, so a character stashed in it here would
                // be lost. The next call re-examines this position with an empty buffer and
                // starts the new token there.
                currentToken = new Token(tokenSb.toString(), start, p, currentType);
                tokenSb.delete(0, tokenSb.length());
                start = p;
            } else {
                start = p;
                tokenSb.appendCodePoint(codepoint);
                currentType = tokenType;
                p += Character.charCount(codepoint);
            }
        } else if (splitType == SplitType.SPLIT_AFTER) {
            if (tokenSb.length() == 0) {
                // start may be stale here (e.g. still pointing at a preceding
                // SPLIT_BEFORE_AND_AFTER token), so anchor it at this character.
                start = p;
            }
            p += Character.charCount(codepoint);
            tokenSb.appendCodePoint(codepoint);
            // no need to check the length since we just added a code point
            currentToken = new Token(tokenSb.toString(), start, p, tokenType);
            tokenSb.delete(0, tokenSb.length());
            start = p;
        } else if (splitType == SplitType.SPLIT_BEFORE_AND_AFTER) {
            if (tokenSb.length() > 0) {
                // Emit the pending token now and buffer the single-character token as the
                // next token.
                currentToken = new Token(tokenSb.toString(), start, p, currentType);
                tokenSb.delete(0, tokenSb.length());
                start = p;
                p += Character.charCount(codepoint);
                tokenSb.appendCodePoint(codepoint);
                nextToken = new Token(tokenSb.toString(), start, p, tokenType);
                tokenSb.delete(0, tokenSb.length());
            } else {
                start = p;
                p += Character.charCount(codepoint);
                tokenSb.appendCodePoint(codepoint);
                currentToken = new Token(tokenSb.toString(), start, p, tokenType);
                tokenSb.delete(0, tokenSb.length());
            }
        }
        if (currentToken != null) {
            break;
        }
    }
    // Reached the end of the input with characters still buffered: they form the final token.
    if (currentToken == null) {
        if (tokenSb.length() > 0) {
            currentToken = new Token(tokenSb.toString(), start, p, currentType);
        }
    }
    // We advanced if we have some stuff collected.
    if (currentToken != null) {
        ready = true;
        return true;
    } else {
        return false;
    }
}
Also used : TokenType(org.tribuo.util.tokens.Token.TokenType) Token(org.tribuo.util.tokens.Token)

Example 4 with Token

Use of org.tribuo.util.tokens.Token in project Tribuo by Oracle.

In class WordpieceTokenizer, method getWordpieceTokens:

/**
 * Generates the wordpiece tokens from the current token.
 *
 * Clears {@code currentWordpieceTokens} and refills it. If the current token's text is in
 * {@code neverSplitTokens}, it is passed through unchanged. Otherwise the text is split by
 * the basic tokenizer, and each basic token is optionally lower-cased and accent-normalised
 * before being broken into wordpieces. A single wordpiece becomes a WORD token, or an UNKNOWN
 * token if it equals the unknown token. Multiple wordpieces become PREFIX / SUFFIX tokens;
 * a piece is SUFFIX when it starts with "##". Offsets are computed by adding the current
 * token's start to the basic token's offsets.
 */
private void getWordpieceTokens() {
    this.currentWordpieceTokens.clear();
    String text = currentToken.text;
    if (neverSplitTokens.contains(text)) {
        currentWordpieceTokens.add(currentToken);
        return;
    }
    List<Token> basicTokens = this.basicTokenizer.tokenize(text);
    for (Token basicToken : basicTokens) {
        text = basicToken.text;
        if (toLowerCase) {
            // Locale-independent lower-casing, so e.g. a Turkish default locale does not map
            // 'I' to dotless 'ı'.
            text = text.toLowerCase(java.util.Locale.ROOT);
        }
        if (this.stripAccents) {
            text = normalize(text);
        }
        List<String> wordpieces = wordpiece.wordpiece(text);
        if (wordpieces.isEmpty()) {
            // Skip only this basic token; the remaining basic tokens still need processing.
            continue;
        }
        int tokenStart = currentToken.start + basicToken.start;
        if (wordpieces.size() == 1) {
            String wp = wordpieces.get(0);
            int end = basicToken.end + currentToken.start;
            if (wp.equals(this.wordpiece.getUnknownToken())) {
                currentWordpieceTokens.add(new Token(wp, tokenStart, end, TokenType.UNKNOWN));
            } else {
                currentWordpieceTokens.add(new Token(wp, tokenStart, end, TokenType.WORD));
            }
        } else {
            // NOTE(review): these offsets assume each piece's length (minus a "##" marker)
            // matches the source span. That may not hold when lower-casing or accent
            // normalisation changes the text length — confirm.
            int begin = tokenStart;
            for (String wp : wordpieces) {
                TokenType type = TokenType.PREFIX;
                int end = begin + wp.length();
                if (wp.startsWith("##")) {
                    end -= 2;
                    type = TokenType.SUFFIX;
                }
                currentWordpieceTokens.add(new Token(wp, begin, end, type));
                begin = end;
            }
        }
    }
}
Also used : TokenType(org.tribuo.util.tokens.Token.TokenType) Token(org.tribuo.util.tokens.Token)

Example 5 with Token

Use of org.tribuo.util.tokens.Token in project Tribuo by Oracle.

In class LIMEColumnar, method explainWithSamples:

/**
 * Explains the inner model's prediction for one row of columnar input and also returns the
 * samples used to build the explanation.
 *
 * If there are no text fields and no binarised categorical fields, this delegates to the
 * {@code Example}-based overload. Otherwise it does the following:
 * - builds a labelled example from the tabular features;
 * - appends one feature per token of each present text field;
 * - samples perturbed data around the input and trains a sparse surrogate on it;
 * - evaluates the surrogate on the samples plus the input example.
 *
 * @param input Map from field name to raw field value.
 * @return The explanation paired with the sampled dataset.
 * @throws IllegalArgumentException If the generator produces no example for the input.
 */
protected Pair<LIMEExplanation, List<Example<Regressor>>> explainWithSamples(Map<String, String> input) {
    Optional<Example<Label>> generated = generator.generateExample(input, false);
    if (!generated.isPresent()) {
        throw new IllegalArgumentException("Label not found in input " + input.toString());
    }
    Example<Label> example = generated.get();

    boolean hasTextOrBinarised = (textDomain.size() != 0) || (binarisedCDFs.size() != 0);
    if (!hasTextOrBinarised) {
        // Short circuit if there are no text or binarised fields.
        return explainWithSamples(example);
    }

    Prediction<Label> prediction = innerModel.predict(example);
    // Input example with simplified text features, labelled with the transformed prediction.
    ArrayExample<Regressor> labelledExample = new ArrayExample<>(transformOutput(prediction));

    // Tabular features first: keep only those known to the tabular domain.
    for (Feature feature : example) {
        boolean isTabular = tabularDomain.getID(feature.getName()) != -1;
        if (isTabular) {
            labelledExample.add(feature);
        }
    }
    SparseVector tabularVector = SparseVector.createSparseVector(labelledExample, tabularDomain, false);

    // Then one feature per token for every text field present in the input.
    Map<String, String> textValues = new HashMap<>();
    Map<String, List<Token>> textTokens = new HashMap<>();
    for (Map.Entry<String, FieldProcessor> field : textFields.entrySet()) {
        String fieldName = field.getKey();
        String fieldValue = input.get(fieldName);
        if (fieldValue == null) {
            continue;
        }
        List<Token> tokens = tokenizerThreadLocal.get().tokenize(fieldValue);
        int position = 0;
        for (Token token : tokens) {
            labelledExample.add(nameFeature(fieldName, token.text, position), 1.0);
            position++;
        }
        textValues.put(fieldName, fieldValue);
        textTokens.put(fieldName, tokens);
    }

    // Sample, fit the sparse surrogate, and score it on the samples plus the input.
    List<Example<Regressor>> sample = sampleData(tabularVector, textValues, textTokens);
    SparseModel<Regressor> surrogate = trainExplainer(labelledExample, sample);
    List<Prediction<Regressor>> surrogatePredictions = new ArrayList<>(surrogate.predict(sample));
    surrogatePredictions.add(surrogate.predict(labelledExample));
    SimpleDataSourceProvenance provenance = new SimpleDataSourceProvenance("LIMEColumnar sampled data", regressionFactory);
    RegressionEvaluation evaluation = evaluator.evaluate(surrogate, surrogatePredictions, provenance);
    return new Pair<>(new LIMEExplanation(surrogate, prediction, evaluation), sample);
}
Also used : HashMap(java.util.HashMap) Label(org.tribuo.classification.Label) ArrayList(java.util.ArrayList) Token(org.tribuo.util.tokens.Token) SparseVector(org.tribuo.math.la.SparseVector) ColumnarFeature(org.tribuo.data.columnar.ColumnarFeature) Feature(org.tribuo.Feature) FieldProcessor(org.tribuo.data.columnar.FieldProcessor) ArrayExample(org.tribuo.impl.ArrayExample) RegressionEvaluation(org.tribuo.regression.evaluation.RegressionEvaluation) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) ListExample(org.tribuo.impl.ListExample) ArrayList(java.util.ArrayList) List(java.util.List) Regressor(org.tribuo.regression.Regressor) Pair(com.oracle.labs.mlrg.olcut.util.Pair) SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) Prediction(org.tribuo.Prediction) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) HashMap(java.util.HashMap) Map(java.util.Map)

Aggregations

Token (org.tribuo.util.tokens.Token): 6; ArrayList (java.util.ArrayList): 4; Example (org.tribuo.Example): 4; Label (org.tribuo.classification.Label): 4; ArrayExample (org.tribuo.impl.ArrayExample): 4; Regressor (org.tribuo.regression.Regressor): 4; HashMap (java.util.HashMap): 2; Map (java.util.Map): 2; Random (java.util.Random): 2; SplittableRandom (java.util.SplittableRandom): 2; Feature (org.tribuo.Feature): 2; ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap): 2; Prediction (org.tribuo.Prediction): 2; ColumnarFeature (org.tribuo.data.columnar.ColumnarFeature): 2; ListExample (org.tribuo.impl.ListExample): 2; SimpleDataSourceProvenance (org.tribuo.provenance.SimpleDataSourceProvenance): 2; RegressionEvaluation (org.tribuo.regression.evaluation.RegressionEvaluation): 2; TokenType (org.tribuo.util.tokens.Token.TokenType): 2; Pair (com.oracle.labs.mlrg.olcut.util.Pair): 1; List (java.util.List): 1