use of org.tribuo.util.tokens.Token in project tribuo by oracle.
the class LIMEColumnar method sampleData.
/**
 * Samples a dataset based on the provided text, tokens and tabular features.
 *
 * The text features are sampled using the {@link LIMEText} sampling approach,
 * and the tabular features are sampled using the {@link LIMEBase} approach.
 *
 * The weight for each example is based on the distance for the tabular features,
 * combined with the distance for the text features (which is a Hamming distance).
 * These distances are averaged using a weight function representing how many tokens
 * there are in the text fields, and how many tabular features there are.
 *
 * This weight calculation is subject to change, as it's not necessarily optimal.
 * @param tabularVector The tabular (i.e., non-text) features.
 * @param text A map from the field names to the field values for the text fields.
 * @param textTokens A map from the field names to lists of tokens for those fields.
 * @return A sampled dataset.
 */
private List<Example<Regressor>> sampleData(SparseVector tabularVector, Map<String, String> text, Map<String, List<Token>> textTokens) {
    List<Example<Regressor>> output = new ArrayList<>();
    Random innerRNG = new Random(rng.nextLong());
    for (int i = 0; i < numSamples; i++) {
        // Create the full example
        ListExample<Label> sampledExample = new ListExample<>(LabelFactory.UNKNOWN_LABEL);
        // Tabular features.
        List<Feature> tabularFeatures = new ArrayList<>();
        // Sample the categorical and real features
        for (VariableInfo info : tabularDomain) {
            int id = ((VariableIDInfo) info).getID();
            double inputValue = tabularVector.get(id);
            if (info instanceof CategoricalInfo) {
                // This one is tricksy as categorical info essentially implicitly includes a zero.
                CategoricalInfo catInfo = (CategoricalInfo) info;
                double sample = catInfo.frequencyBasedSample(innerRNG, numTrainingExamples);
                // If we didn't sample zero.
                if (Math.abs(sample) > 1e-10) {
                    Feature newFeature = new Feature(info.getName(), sample);
                    tabularFeatures.add(newFeature);
                }
            } else if (info instanceof RealInfo) {
                RealInfo realInfo = (RealInfo) info;
                // As realInfo is sparse we sample from the mixture distribution,
                // either 0 or N(inputValue,variance).
                // This assumes realInfo never observed a zero, which is enforced from v2.1
                // TODO check this makes sense. If the input value is zero do we still want to sample spike and slab?
                // If it's not zero do we want to?
                int count = realInfo.getCount();
                double threshold = count / ((double) numTrainingExamples);
                if (innerRNG.nextDouble() < threshold) {
                    double variance = realInfo.getVariance();
                    double sample = (innerRNG.nextGaussian() * Math.sqrt(variance)) + inputValue;
                    Feature newFeature = new Feature(info.getName(), sample);
                    tabularFeatures.add(newFeature);
                }
            } else {
                throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
            }
        }
        // Sample the binarised categorical features
        for (Map.Entry<String, double[]> e : binarisedCDFs.entrySet()) {
            // Sample from the CDF
            int sample = Util.sampleFromCDF(e.getValue(), innerRNG);
            // If the sample isn't zero (which is defined to be the last value to make the indices work)
            if (sample != (e.getValue().length - 1)) {
                VariableInfo info = binarisedInfos.get(e.getKey()).get(sample);
                Feature newFeature = new Feature(info.getName(), 1);
                tabularFeatures.add(newFeature);
            }
        }
        // Add the tabular features to the current example
        sampledExample.addAll(tabularFeatures);
        // Calculate tabular distance
        double tabularDistance = measureDistance(tabularDomain, numTrainingExamples, tabularVector, SparseVector.createSparseVector(sampledExample, tabularDomain, false));
        // Features are the full text features
        List<Feature> textFeatures = new ArrayList<>();
        // Perturbed features are the binarised tokens
        List<Feature> perturbedFeatures = new ArrayList<>();
        // Sample the text features
        double textDistance = 0.0;
        long numTokens = 0;
        for (Map.Entry<String, String> e : text.entrySet()) {
            String curText = e.getValue();
            List<Token> tokens = textTokens.get(e.getKey());
            numTokens += tokens.size();
            // Sample a new Example.
            int[] activeFeatures = new int[tokens.size()];
            char[] sampledText = curText.toCharArray();
            for (int j = 0; j < activeFeatures.length; j++) {
                activeFeatures[j] = innerRNG.nextInt(2);
                if (activeFeatures[j] == 0) {
                    textDistance++;
                    Token curToken = tokens.get(j);
                    Arrays.fill(sampledText, curToken.start, curToken.end, '\0');
                }
            }
            String sampledString = new String(sampledText);
            sampledString = sampledString.replace("\0", "");
            textFeatures.addAll(textFields.get(e.getKey()).process(sampledString));
            for (int j = 0; j < activeFeatures.length; j++) {
                perturbedFeatures.add(new Feature(nameFeature(e.getKey(), tokens.get(j).text, j), activeFeatures[j]));
            }
        }
        // Add the text features to the current example
        sampledExample.addAll(textFeatures);
        // Calculate text distance
        double totalTextDistance = textDistance / numTokens;
        // Label it using the full model.
        Prediction<Label> samplePrediction = innerModel.predict(sampledExample);
        double totalLength = tabularFeatures.size() + perturbedFeatures.size();
        // Combine the distances and transform into a weight
        // Currently this averages the two values based on their relative sizes.
        double weight = 1.0 - ((tabularFeatures.size() * kernelDist(tabularDistance, kernelWidth) + perturbedFeatures.size() * totalTextDistance) / totalLength);
        // Generate the new sample with the appropriate label and weight.
        ArrayExample<Regressor> labelledSample = new ArrayExample<>(transformOutput(samplePrediction), (float) weight);
        labelledSample.addAll(tabularFeatures);
        labelledSample.addAll(perturbedFeatures);
        output.add(labelledSample);
    }
    return output;
}
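The combined weight above is easiest to read as a standalone calculation: the kernelised tabular distance and the mean token Hamming distance are averaged, weighted by how many features of each kind contributed. A minimal sketch of that arithmetic (a hypothetical helper, not part of Tribuo):

// Hypothetical restatement of the weight combination in sampleData.
// numTabular and numPerturbed play the roles of tabularFeatures.size() and
// perturbedFeatures.size(); both distances are assumed to be precomputed.
static double sampleWeight(int numTabular, double tabularKernelDistance,
                           int numPerturbed, double meanTextDistance) {
    double totalLength = numTabular + numPerturbed;
    // Weighted average of the two distances, turned into a sample weight.
    return 1.0 - ((numTabular * tabularKernelDistance
            + numPerturbed * meanTextDistance) / totalLength);
}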
use of org.tribuo.util.tokens.Token in project tribuo by oracle.
the class LIMEText method explain.
@Override
public LIMEExplanation explain(String inputText) {
    Example<Label> trueExample = extractor.extract(LabelFactory.UNKNOWN_LABEL, inputText);
    Prediction<Label> prediction = innerModel.predict(trueExample);
    ArrayExample<Regressor> bowExample = new ArrayExample<>(transformOutput(prediction));
    List<Token> tokens = tokenizerThreadLocal.get().tokenize(inputText);
    for (int i = 0; i < tokens.size(); i++) {
        bowExample.add(nameFeature(tokens.get(i).text, i), 1.0);
    }
    // Sample a dataset.
    List<Example<Regressor>> sample = sampleData(inputText, tokens);
    // Generate a sparse model on the sampled data.
    SparseModel<Regressor> model = trainExplainer(bowExample, sample);
    // Test the sparse model against the predictions of the real model.
    List<Prediction<Regressor>> predictions = new ArrayList<>(model.predict(sample));
    predictions.add(model.predict(bowExample));
    RegressionEvaluation evaluation = evaluator.evaluate(model, predictions, new SimpleDataSourceProvenance("LIMEText sampled data", regressionFactory));
    return new LIMEExplanation(model, prediction, evaluation);
}
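For context, a call to this method might look like the sketch below. The constructor argument order follows recent Tribuo releases but should be checked against your version; model, sparseTrainer, extractor and tokenizer are assumed to exist already.

// Usage sketch (assumptions: model is a trained Model<Label>, sparseTrainer a
// SparseTrainer<Regressor>, extractor a TextFeatureExtractor<Label>, and
// tokenizer a Tokenizer; verify the constructor against your Tribuo version).
LIMEText lime = new LIMEText(new SplittableRandom(12345L), model, sparseTrainer,
        100, extractor, tokenizer);
LIMEExplanation explanation = lime.explain("The text to explain goes here.");
System.out.println(explanation);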
use of org.tribuo.util.tokens.Token in project tribuo by oracle.
the class SplitFunctionTokenizer method advance.
@Override
public boolean advance() {
    if (cs == null) {
        throw new IllegalStateException("SplitFunctionTokenizer has not been reset.");
    }
    if (nextToken != null) {
        currentToken = nextToken;
        nextToken = null;
        return true;
    }
    if (p >= cs.length()) {
        return false;
    }
    currentToken = null;
    SplitResult splitResult;
    SplitType splitType;
    TokenType tokenType;
    tokenSb.delete(0, tokenSb.length());
    while (p < cs.length()) {
        int codepoint = cs.codePointAt(p);
        splitResult = splitFunction.apply(codepoint, p, cs);
        splitType = splitResult.splitType;
        tokenType = splitResult.tokenType;
        // Each split type below must maintain the accumulated token text,
        // the start offset, and where the end of the token is.
        if (splitType == SplitType.NO_SPLIT) {
            if (tokenSb.length() == 0) {
                start = p;
            }
            p += Character.charCount(codepoint);
            tokenSb.appendCodePoint(codepoint);
            currentType = tokenType;
            continue;
        }
        if (splitType == SplitType.SPLIT_AT) {
            if (tokenSb.length() > 0) {
                currentToken = new Token(tokenSb.toString(), start, p, currentType);
            }
            p += Character.charCount(codepoint);
            start = p;
            tokenSb.delete(0, tokenSb.length());
        } else if (splitType == SplitType.SPLIT_BEFORE) {
            if (tokenSb.length() > 0) {
                currentToken = new Token(tokenSb.toString(), start, p, currentType);
            }
            start = p;
            tokenSb.delete(0, tokenSb.length());
            tokenSb.appendCodePoint(codepoint);
            p += Character.charCount(codepoint);
        } else if (splitType == SplitType.SPLIT_AFTER) {
            p += Character.charCount(codepoint);
            tokenSb.appendCodePoint(codepoint);
            // no need to check the length since we just added a code point
            currentToken = new Token(tokenSb.toString(), start, p, tokenType);
            tokenSb.delete(0, tokenSb.length());
            start = p;
        } else if (splitType == SplitType.SPLIT_BEFORE_AND_AFTER) {
            // Emit any accumulated token as the current token, then queue up
            // the next token, which consists of just this character.
            if (tokenSb.length() > 0) {
                currentToken = new Token(tokenSb.toString(), start, p, currentType);
                tokenSb.delete(0, tokenSb.length());
                start = p;
                p += Character.charCount(codepoint);
                tokenSb.appendCodePoint(codepoint);
                nextToken = new Token(tokenSb.toString(), start, p, tokenType);
                tokenSb.delete(0, tokenSb.length());
            } else {
                start = p;
                p += Character.charCount(codepoint);
                tokenSb.appendCodePoint(codepoint);
                currentToken = new Token(tokenSb.toString(), start, p, tokenType);
                tokenSb.delete(0, tokenSb.length());
            }
        }
        if (currentToken != null) {
            break;
        }
    }
    if (currentToken == null) {
        if (tokenSb.length() > 0) {
            currentToken = new Token(tokenSb.toString(), start, p, currentType);
        }
    }
    // We advanced if we have some stuff collected.
    if (currentToken != null) {
        ready = true;
        return true;
    } else {
        return false;
    }
}
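The advance() state machine is normally driven through the Tokenizer interface's reset/advance protocol. A minimal sketch, assuming SplitCharactersTokenizer (a SplitFunctionTokenizer subclass) with its default constructor:

// Drive the tokenizer manually; each advance() produces at most one token,
// with SPLIT_BEFORE_AND_AFTER queueing a second token for the next call.
Tokenizer tokenizer = new SplitCharactersTokenizer();
tokenizer.reset("a,b;c d");
while (tokenizer.advance()) {
    // getText/getStart/getEnd describe the token produced by the last advance().
    System.out.println(tokenizer.getText() + " [" + tokenizer.getStart()
            + "," + tokenizer.getEnd() + ")");
}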
use of org.tribuo.util.tokens.Token in project tribuo by oracle.
the class WordpieceTokenizer method getWordpieceTokens.
/**
* Generates the wordpiece tokens from the next token.
*/
private void getWordpieceTokens() {
    this.currentWordpieceTokens.clear();
    String text = currentToken.text;
    if (neverSplitTokens.contains(text)) {
        currentWordpieceTokens.add(currentToken);
        return;
    }
    List<Token> basicTokens = this.basicTokenizer.tokenize(text);
    for (Token basicToken : basicTokens) {
        text = basicToken.text;
        if (toLowerCase) {
            text = text.toLowerCase();
        }
        if (this.stripAccents) {
            text = normalize(text);
        }
        List<String> wordpieces = wordpiece.wordpiece(text);
        if (wordpieces.size() == 0) {
            return;
        } else if (wordpieces.size() == 1) {
            String wp = wordpieces.get(0);
            int start = basicToken.start + currentToken.start;
            int end = basicToken.end + currentToken.start;
            if (wp.equals(this.wordpiece.getUnknownToken())) {
                currentWordpieceTokens.add(new Token(wp, start, end, TokenType.UNKNOWN));
            } else {
                currentWordpieceTokens.add(new Token(wp, start, end, TokenType.WORD));
            }
        } else {
            int begin = currentToken.start + basicToken.start;
            for (String wp : wordpieces) {
                TokenType type = TokenType.PREFIX;
                int end = begin + wp.length();
                if (wp.startsWith("##")) {
                    end -= 2;
                    type = TokenType.SUFFIX;
                }
                currentWordpieceTokens.add(new Token(wp, begin, end, type));
                begin = end;
            }
        }
    }
}
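The span arithmetic in the multi-piece branch is easy to miss: a "##"-prefixed piece's text is two characters longer than the span it covers in the underlying token, so its end offset is pulled back by two. A self-contained illustration (not Tribuo code):

// Compute the end offset of one wordpiece's character span within the
// original token, given where the previous piece ended.
static int wordpieceEnd(int begin, String wordpiece) {
    int end = begin + wordpiece.length();
    if (wordpiece.startsWith("##")) {
        end -= 2; // the "##" continuation marker is absent from the underlying text
    }
    return end; // the next piece begins at this offset
}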
use of org.tribuo.util.tokens.Token in project tribuo by oracle.
the class LIMEColumnar method explainWithSamples.
protected Pair<LIMEExplanation, List<Example<Regressor>>> explainWithSamples(Map<String, String> input) {
    Optional<Example<Label>> optExample = generator.generateExample(input, false);
    if (optExample.isPresent()) {
        Example<Label> example = optExample.get();
        if ((textDomain.size() == 0) && (binarisedCDFs.size() == 0)) {
            // Short circuit if there are no text or binarised fields.
            return explainWithSamples(example);
        } else {
            Prediction<Label> prediction = innerModel.predict(example);
            // Build the input example with simplified text features
            ArrayExample<Regressor> labelledExample = new ArrayExample<>(transformOutput(prediction));
            // Add the tabular features
            for (Feature f : example) {
                if (tabularDomain.getID(f.getName()) != -1) {
                    labelledExample.add(f);
                }
            }
            // Extract the tabular features into a SparseVector for later
            SparseVector tabularVector = SparseVector.createSparseVector(labelledExample, tabularDomain, false);
            // Tokenize the text fields, and generate the perturbed text representation
            Map<String, String> exampleTextValues = new HashMap<>();
            Map<String, List<Token>> exampleTextTokens = new HashMap<>();
            for (Map.Entry<String, FieldProcessor> e : textFields.entrySet()) {
                String value = input.get(e.getKey());
                if (value != null) {
                    List<Token> tokens = tokenizerThreadLocal.get().tokenize(value);
                    for (int i = 0; i < tokens.size(); i++) {
                        labelledExample.add(nameFeature(e.getKey(), tokens.get(i).text, i), 1.0);
                    }
                    exampleTextValues.put(e.getKey(), value);
                    exampleTextTokens.put(e.getKey(), tokens);
                }
            }
            // Sample a dataset.
            List<Example<Regressor>> sample = sampleData(tabularVector, exampleTextValues, exampleTextTokens);
            // Generate a sparse model on the sampled data.
            SparseModel<Regressor> model = trainExplainer(labelledExample, sample);
            // Test the sparse model against the predictions of the real model.
            List<Prediction<Regressor>> predictions = new ArrayList<>(model.predict(sample));
            predictions.add(model.predict(labelledExample));
            RegressionEvaluation evaluation = evaluator.evaluate(model, predictions, new SimpleDataSourceProvenance("LIMEColumnar sampled data", regressionFactory));
            return new Pair<>(new LIMEExplanation(model, prediction, evaluation), sample);
        }
    } else {
        throw new IllegalArgumentException("Label not found in input " + input.toString());
    }
}
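A usage sketch for the columnar explainer, with hypothetical field names; the constructor argument order follows recent Tribuo releases and should be verified, and model, sparseTrainer, rowProcessor and tokenizer are assumed to exist:

// Explain one row of columnar data through the public explain method, which
// delegates to explainWithSamples above (field names here are hypothetical).
LIMEColumnar lime = new LIMEColumnar(new SplittableRandom(12345L), model,
        sparseTrainer, 100, rowProcessor, tokenizer);
Map<String, String> row = new HashMap<>();
row.put("age", "35");
row.put("notes", "customer asked about a refund");
LIMEExplanation explanation = lime.explain(row);
System.out.println(explanation);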