Use of org.tribuo.VariableInfo in project tribuo by oracle.
From the class RowProcessor, method expandRegexMapping.
/**
* Uses similar logic to {@link org.tribuo.transform.TransformationMap#validateTransformations} to check the regexes
against the supplied feature map. Throws an IllegalArgumentException if any regexes overlap with
each other, or with the currently defined field processors in fieldProcessorMap.
* @param featureMap The feature map to use to expand the regexes.
*/
public void expandRegexMapping(ImmutableFeatureMap featureMap) {
    ArrayList<String> fieldNames = new ArrayList<>(featureMap.size());
    for (VariableInfo v : featureMap) {
        String[] split = FEATURE_NAME_PATTERN.split(v.getName(), 1);
        String fieldName = split[0];
        fieldNames.add(fieldName);
    }
    expandRegexMapping(fieldNames);
}
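For reference, the limit argument to java.util.regex.Pattern.split controls how many times the pattern is applied: a positive limit n applies it at most n - 1 times, so a limit of 1 returns the entire input as a single element. A minimal, self-contained sketch of those semantics (the '@' separator and the feature name below are hypothetical, not necessarily Tribuo's actual naming scheme):

    import java.util.Arrays;
    import java.util.regex.Pattern;

    public class SplitLimitDemo {
        public static void main(String[] args) {
            Pattern sep = Pattern.compile("@");
            String featureName = "age@value";
            // Limit 1: pattern applied zero times, whole input returned as one element.
            System.out.println(Arrays.toString(sep.split(featureName, 1))); // [age@value]
            // Limit 2: pattern applied at most once, splitting off the prefix.
            System.out.println(Arrays.toString(sep.split(featureName, 2))); // [age, value]
        }
    }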
Use of org.tribuo.VariableInfo in project tribuo by oracle.
From the class ClassificationTest, method classificationCNNTest.
@Test
public void classificationCNNTest() throws IOException {
    // Create the train and test data
    Pair<Dataset<Label>, Dataset<Label>> data = generateImageData(512, 10, 128, 5, 42);
    Dataset<Label> trainData = data.getA();
    Dataset<Label> testData = data.getB();
    // Build the CNN graph
    GraphDefTuple graphDefTuple = CNNExamples.buildLeNetGraph(INPUT_NAME, 10, 255, trainData.getOutputs().size());
    // Configure the trainer
    Map<String, Float> gradientParams = new HashMap<>();
    gradientParams.put("learningRate", 0.01f);
    gradientParams.put("initialAccumulatorValue", 0.1f);
    FeatureConverter imageConverter = new ImageConverter(INPUT_NAME, 10, 10, 1);
    OutputConverter<Label> outputConverter = new LabelConverter();
    // The trailing arguments are the train batch size, epoch count, test batch size, and logging interval (-1 disables logging).
    TensorFlowTrainer<Label> trainer = new TensorFlowTrainer<>(graphDefTuple.graphDef, graphDefTuple.outputName, GradientOptimiser.ADAGRAD, gradientParams, imageConverter, outputConverter, 16, 5, 16, -1);
    // Train the model
    TensorFlowModel<Label> model = trainer.train(trainData);
    // Make some predictions
    List<Prediction<Label>> predictions = model.predict(testData);
    // Run smoke test evaluation
    LabelEvaluation eval = new LabelEvaluator().evaluate(model, predictions, testData.getProvenance());
    Assertions.assertTrue(eval.accuracy() > 0.0);
    // Check Tribuo serialization
    Helpers.testModelSerialization(model, Label.class);
    // Check saved model bundle export
    Path outputPath = Files.createTempDirectory("tf-classification-cnn-test");
    model.exportModel(outputPath.toString());
    try (Stream<Path> f = Files.list(outputPath)) {
        List<Path> files = f.collect(Collectors.toList());
        Assertions.assertNotEquals(0, files.size());
    }
    // Create external model from bundle
    Map<Label, Integer> outputMapping = new HashMap<>();
    for (Pair<Integer, Label> p : model.getOutputIDInfo()) {
        outputMapping.put(p.getB(), p.getA());
    }
    Map<String, Integer> featureMapping = new HashMap<>();
    ImmutableFeatureMap featureIDMap = model.getFeatureIDMap();
    for (VariableInfo info : featureIDMap) {
        featureMapping.put(info.getName(), featureIDMap.getID(info.getName()));
    }
    TensorFlowSavedModelExternalModel<Label> externalModel = TensorFlowSavedModelExternalModel.createTensorflowModel(trainData.getOutputFactory(), featureMapping, outputMapping, model.getOutputName(), imageConverter, outputConverter, outputPath.toString());
    // Check predictions are equal
    List<Prediction<Label>> externalPredictions = externalModel.predict(testData);
    checkPredictionEquality(predictions, externalPredictions);
    // Cleanup saved model bundle
    externalModel.close();
    Files.walk(outputPath).sorted(Comparator.reverseOrder()).map(Path::toFile).forEach(File::delete);
    Assertions.assertFalse(Files.exists(outputPath));
    // Cleanup created model
    model.close();
}
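The helpers generateImageData and checkPredictionEquality live elsewhere in the test class and are not reproduced on this page. As a rough guide, a prediction-equality check along these lines would work; this is a minimal sketch, not the project's actual helper:

    private static void checkPredictionEquality(List<Prediction<Label>> expected, List<Prediction<Label>> actual) {
        Assertions.assertEquals(expected.size(), actual.size());
        for (int i = 0; i < expected.size(); i++) {
            // The predicted label should agree exactly between native and external models.
            Assertions.assertEquals(expected.get(i).getOutput().getLabel(), actual.get(i).getOutput().getLabel());
            // The top confidence score should agree up to floating point error.
            Assertions.assertEquals(expected.get(i).getOutput().getScore(), actual.get(i).getOutput().getScore(), 1e-5);
        }
    }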
Use of org.tribuo.VariableInfo in project tribuo by oracle.
From the class LIMEBase, method samplePoint.
/**
* Samples a single example from the supplied feature map and input vector.
* @param rng The rng to use.
* @param fMap The feature map describing the domain of the features.
* @param numTrainingExamples The number of training examples the fMap has seen.
* @param input The input sparse vector to use.
* @return An Example sampled from the supplied feature map and input vector.
*/
public static Example<Label> samplePoint(Random rng, ImmutableFeatureMap fMap, long numTrainingExamples, SparseVector input) {
    ArrayList<String> names = new ArrayList<>();
    ArrayList<Double> values = new ArrayList<>();
    for (VariableInfo info : fMap) {
        int id = ((VariableIDInfo) info).getID();
        double inputValue = input.get(id);
        if (info instanceof CategoricalInfo) {
            // This one is tricksy as categorical info essentially implicitly includes a zero.
            CategoricalInfo catInfo = (CategoricalInfo) info;
            double sample = catInfo.frequencyBasedSample(rng, numTrainingExamples);
            // If we didn't sample zero.
            if (Math.abs(sample) > 1e-10) {
                names.add(info.getName());
                values.add(sample);
            }
        } else if (info instanceof RealInfo) {
            RealInfo realInfo = (RealInfo) info;
            // As realInfo is sparse we sample from the mixture distribution,
            // either 0 or N(inputValue,variance).
            // This assumes realInfo never observed a zero, which is enforced from v2.1
            // TODO check this makes sense. If the input value is zero do we still want to sample spike and slab?
            // If it's not zero do we want to?
            int count = realInfo.getCount();
            double threshold = count / ((double) numTrainingExamples);
            if (rng.nextDouble() < threshold) {
                double variance = realInfo.getVariance();
                double sample = (rng.nextGaussian() * Math.sqrt(variance)) + inputValue;
                names.add(info.getName());
                values.add(sample);
            }
        } else {
            throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
        }
    }
    return new ArrayExample<>(LabelFactory.UNKNOWN_LABEL, names.toArray(new String[0]), Util.toPrimitiveDouble(values));
}
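The real-valued branch above is a spike-and-slab draw: with probability count/numTrainingExamples the feature is resampled from a Gaussian centred on the input value, and otherwise it stays at the implicit zero. A self-contained illustration of that draw for a single feature, using made-up counts and variance rather than values read from a real RealInfo:

    import java.util.Random;

    public class SpikeAndSlabDemo {
        public static void main(String[] args) {
            Random rng = new Random(42);
            long numTrainingExamples = 1000;  // hypothetical dataset size
            int count = 250;                  // hypothetical non-zero observations of this feature
            double variance = 4.0;            // hypothetical observed variance
            double inputValue = 3.5;          // the value in the example being explained
            // The feature was non-zero in 25% of training examples, so draw from
            // the Gaussian "slab" 25% of the time and the zero "spike" otherwise.
            double threshold = count / (double) numTrainingExamples;
            double sample = (rng.nextDouble() < threshold)
                    ? rng.nextGaussian() * Math.sqrt(variance) + inputValue
                    : 0.0;
            System.out.println("sampled value = " + sample);
        }
    }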
Use of org.tribuo.VariableInfo in project tribuo by oracle.
From the class LIMEColumnar, method sampleData.
/**
* Samples a dataset based on the provided text, tokens and tabular features.
*
* The text features are sampled using the {@link LIMEText} sampling approach,
* and the tabular features are sampled using the {@link LIMEBase} approach.
*
* The weight for each example is based on the distance for the tabular features,
combined with the distance for the text features (which is a Hamming distance).
* These distances are averaged using a weight function representing how many tokens
* there are in the text fields, and how many tabular features there are.
*
* This weight calculation is subject to change, as it's not necessarily optimal.
* @param tabularVector The tabular (i.e., non-text) features.
* @param text A map from the field names to the field values for the text fields.
* @param textTokens A map from the field names to lists of tokens for those fields.
* @return A sampled dataset.
*/
private List<Example<Regressor>> sampleData(SparseVector tabularVector, Map<String, String> text, Map<String, List<Token>> textTokens) {
    List<Example<Regressor>> output = new ArrayList<>();
    Random innerRNG = new Random(rng.nextLong());
    for (int i = 0; i < numSamples; i++) {
        // Create the full example
        ListExample<Label> sampledExample = new ListExample<>(LabelFactory.UNKNOWN_LABEL);
        // Tabular features.
        List<Feature> tabularFeatures = new ArrayList<>();
        // Sample the categorical and real features
        for (VariableInfo info : tabularDomain) {
            int id = ((VariableIDInfo) info).getID();
            double inputValue = tabularVector.get(id);
            if (info instanceof CategoricalInfo) {
                // This one is tricksy as categorical info essentially implicitly includes a zero.
                CategoricalInfo catInfo = (CategoricalInfo) info;
                double sample = catInfo.frequencyBasedSample(innerRNG, numTrainingExamples);
                // If we didn't sample zero.
                if (Math.abs(sample) > 1e-10) {
                    Feature newFeature = new Feature(info.getName(), sample);
                    tabularFeatures.add(newFeature);
                }
            } else if (info instanceof RealInfo) {
                RealInfo realInfo = (RealInfo) info;
                // As realInfo is sparse we sample from the mixture distribution,
                // either 0 or N(inputValue,variance).
                // This assumes realInfo never observed a zero, which is enforced from v2.1
                // TODO check this makes sense. If the input value is zero do we still want to sample spike and slab?
                // If it's not zero do we want to?
                int count = realInfo.getCount();
                double threshold = count / ((double) numTrainingExamples);
                if (innerRNG.nextDouble() < threshold) {
                    double variance = realInfo.getVariance();
                    double sample = (innerRNG.nextGaussian() * Math.sqrt(variance)) + inputValue;
                    Feature newFeature = new Feature(info.getName(), sample);
                    tabularFeatures.add(newFeature);
                }
            } else {
                throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
            }
        }
        // Sample the binarised categorical features
        for (Map.Entry<String, double[]> e : binarisedCDFs.entrySet()) {
            // Sample from the CDF
            int sample = Util.sampleFromCDF(e.getValue(), innerRNG);
            // If the sample isn't zero (which is defined to be the last value to make the indices work)
            if (sample != (e.getValue().length - 1)) {
                VariableInfo info = binarisedInfos.get(e.getKey()).get(sample);
                Feature newFeature = new Feature(info.getName(), 1);
                tabularFeatures.add(newFeature);
            }
        }
        // Add the tabular features to the current example
        sampledExample.addAll(tabularFeatures);
        // Calculate tabular distance
        double tabularDistance = measureDistance(tabularDomain, numTrainingExamples, tabularVector, SparseVector.createSparseVector(sampledExample, tabularDomain, false));
        // Features are the full text features
        List<Feature> textFeatures = new ArrayList<>();
        // Perturbed features are the binarised tokens
        List<Feature> perturbedFeatures = new ArrayList<>();
        // Sample the text features
        double textDistance = 0.0;
        long numTokens = 0;
        for (Map.Entry<String, String> e : text.entrySet()) {
            String curText = e.getValue();
            List<Token> tokens = textTokens.get(e.getKey());
            numTokens += tokens.size();
            // Sample a new Example.
            int[] activeFeatures = new int[tokens.size()];
            char[] sampledText = curText.toCharArray();
            for (int j = 0; j < activeFeatures.length; j++) {
                activeFeatures[j] = innerRNG.nextInt(2);
                if (activeFeatures[j] == 0) {
                    textDistance++;
                    Token curToken = tokens.get(j);
                    Arrays.fill(sampledText, curToken.start, curToken.end, '\0');
                }
            }
            String sampledString = new String(sampledText);
            sampledString = sampledString.replace("\0", "");
            textFeatures.addAll(textFields.get(e.getKey()).process(sampledString));
            for (int j = 0; j < activeFeatures.length; j++) {
                perturbedFeatures.add(new Feature(nameFeature(e.getKey(), tokens.get(j).text, j), activeFeatures[j]));
            }
        }
        // Add the text features to the current example
        sampledExample.addAll(textFeatures);
        // Calculate text distance
        double totalTextDistance = textDistance / numTokens;
        // Label it using the full model.
        Prediction<Label> samplePrediction = innerModel.predict(sampledExample);
        double totalLength = tabularFeatures.size() + perturbedFeatures.size();
        // Combine the distances and transform into a weight.
        // Currently this averages the two values based on their relative sizes.
        double weight = 1.0 - ((tabularFeatures.size() * kernelDist(tabularDistance, kernelWidth) + perturbedFeatures.size() * totalTextDistance) / totalLength);
        // Generate the new sample with the appropriate label and weight.
        ArrayExample<Regressor> labelledSample = new ArrayExample<>(transformOutput(samplePrediction), (float) weight);
        labelledSample.addAll(tabularFeatures);
        labelledSample.addAll(perturbedFeatures);
        output.add(labelledSample);
    }
    return output;
}
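To make the weight combination concrete, here is a small worked example with made-up sizes and distances. The kernel form below follows the standard LIME exponential kernel; treat its exact shape, and all the numbers, as assumptions for illustration rather than values from LIMEColumnar:

    public class LimeWeightDemo {
        // Exponential kernel turning a distance into a similarity in [0,1] (assumed form).
        static double kernelDist(double distance, double kernelWidth) {
            return Math.sqrt(Math.exp(-(distance * distance) / (kernelWidth * kernelWidth)));
        }

        public static void main(String[] args) {
            int numTabular = 10;           // hypothetical tabular feature count
            int numTokens = 40;            // hypothetical binarised token feature count
            double tabularDistance = 0.5;  // hypothetical tabular distance
            double textDistance = 0.25;    // fraction of tokens dropped (normalised Hamming distance)
            double kernelWidth = 0.75;     // hypothetical kernel width
            double totalLength = numTabular + numTokens;
            // Size-weighted average of the two per-modality terms, flipped into a weight.
            double weight = 1.0 - ((numTabular * kernelDist(tabularDistance, kernelWidth)
                    + numTokens * textDistance) / totalLength);
            System.out.println("sample weight = " + weight);
        }
    }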
Use of org.tribuo.VariableInfo in project tribuo by oracle.
From the class TestLibSVM, method testOnnxSerialization.
private static void testOnnxSerialization(Pair<Dataset<Regressor>, Dataset<Regressor>> datasetPair, LibSVMRegressionTrainer trainer) throws IOException, OrtException {
    LibSVMRegressionModel model = (LibSVMRegressionModel) trainer.train(datasetPair.getA());
    // Write out model
    Path onnxFile = Files.createTempFile("tribuo-libsvm-test", ".onnx");
    model.saveONNXModel("org.tribuo.regression.libsvm.test", 1, onnxFile);
    // Prep mappings
    Map<String, Integer> featureMapping = new HashMap<>();
    for (VariableInfo f : model.getFeatureIDMap()) {
        VariableIDInfo id = (VariableIDInfo) f;
        featureMapping.put(id.getName(), id.getID());
    }
    Map<Regressor, Integer> outputMapping = new HashMap<>();
    for (Pair<Integer, Regressor> l : model.getOutputIDInfo()) {
        outputMapping.put(l.getB(), l.getA());
    }
    String arch = System.getProperty("os.arch");
    if (arch.equalsIgnoreCase("amd64") || arch.equalsIgnoreCase("x86_64")) {
        // Initialise the OrtEnvironment to load the native library
        // (as OrtSession.SessionOptions doesn't trigger the static initializer).
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        env.close();
        // Load in via ORT
        ONNXExternalModel<Regressor> onnxModel = ONNXExternalModel.createOnnxModel(new RegressionFactory(), featureMapping, outputMapping, new DenseTransformer(), new RegressorTransformer(), new OrtSession.SessionOptions(), onnxFile, "input");
        // Generate predictions
        List<Prediction<Regressor>> nativePredictions = model.predict(datasetPair.getB());
        List<Prediction<Regressor>> onnxPredictions = onnxModel.predict(datasetPair.getB());
        // Assert the predictions are identical
        for (int i = 0; i < nativePredictions.size(); i++) {
            Prediction<Regressor> tribuo = nativePredictions.get(i);
            Prediction<Regressor> external = onnxPredictions.get(i);
            assertArrayEquals(tribuo.getOutput().getNames(), external.getOutput().getNames());
            if (model.isStandardized()) {
                // Standardized models are less numerically stable when cast into floats
                double[] tribuoValues = tribuo.getOutput().getValues();
                double[] externalValues = external.getOutput().getValues();
                assertEquals(tribuoValues.length, externalValues.length);
                for (int j = 0; j < tribuoValues.length; j++) {
                    // Compare with a relative tolerance of 1e-3, i.e. agreement to roughly three significant figures.
                    assertEquals(tribuoValues[j], externalValues[j], Math.abs(tribuoValues[j]) / 1e3);
                }
            } else {
                assertArrayEquals(tribuo.getOutput().getValues(), external.getOutput().getValues(), 1e-4);
            }
        }
        // Check that the provenance can be extracted and is the same
        ModelProvenance modelProv = model.getProvenance();
        Optional<ModelProvenance> optProv = onnxModel.getTribuoProvenance();
        assertTrue(optProv.isPresent());
        ModelProvenance onnxProv = optProv.get();
        assertNotSame(onnxProv, modelProv);
        assertEquals(modelProv, onnxProv);
        onnxModel.close();
    } else {
        logger.warning("ORT based tests only supported on x86_64, found " + arch);
    }
    onnxFile.toFile().delete();
}
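A caller passes this helper a train/test split and a configured LibSVM regression trainer. Something along these lines would work; the epsilon-SVR configuration and the generateDataset() helper here are illustrative assumptions, not the values or helpers the Tribuo test suite actually uses:

    // Illustrative invocation; the SVM configuration is an assumption, not taken from the test suite.
    SVMParameters<Regressor> params = new SVMParameters<>(new SVMRegressionType(SVMRegressionType.SVMMode.EPSILON_SVR), KernelType.RBF);
    LibSVMRegressionTrainer trainer = new LibSVMRegressionTrainer(params);
    // generateDataset() is a hypothetical helper returning the train/test Pair.
    testOnnxSerialization(generateDataset(), trainer);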