Use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.
The class LabelledDataGenerator, method sparseTrainTest.
/**
* Generates a pair of datasets, where the features are sparse,
* and unknown features appear in the test data. It has the same
* 4 classes {Foo,Bar,Baz,Quux}.
* @param negate Supply -1.0 to negate some values in this dataset.
* @return A pair of train and test datasets.
*/
public static Pair<Dataset<Label>, Dataset<Label>> sparseTrainTest(double negate) {
    DataSourceProvenance provenance = new SimpleDataSourceProvenance("TrainingData", OffsetDateTime.now(), labelFactory);
    MutableDataset<Label> train = new MutableDataset<>(provenance, labelFactory);
    String[] names = new String[] { "A", "B", "C", "D" };
    double[] values = new double[] { 1.0, 0.5, 1.0, negate * 1.0 };
    train.add(new ArrayExample<>(new Label("Foo"), names, values));
    names = new String[] { "B", "D", "F", "H" };
    values = new double[] { 1.5, 0.35, 1.3, negate * 1.2 };
    train.add(new ArrayExample<>(new Label("Foo"), names, values));
    names = new String[] { "A", "J", "D", "M" };
    values = new double[] { 1.2, 0.45, 1.5, negate * 1.0 };
    train.add(new ArrayExample<>(new Label("Foo"), names, values));
    names = new String[] { "C", "E", "F", "H" };
    values = new double[] { negate * 1.1, 0.55, negate * 1.5, 0.5 };
    train.add(new ArrayExample<>(new Label("Bar"), names, values));
    names = new String[] { "E", "G", "F", "I" };
    values = new double[] { negate * 1.5, 0.25, negate * 1, 0.125 };
    train.add(new ArrayExample<>(new Label("Bar"), names, values));
    names = new String[] { "J", "K", "C", "E" };
    values = new double[] { negate * 1, 0.5, negate * 1.123, 0.123 };
    train.add(new ArrayExample<>(new Label("Bar"), names, values));
    names = new String[] { "E", "A", "K", "J" };
    values = new double[] { 1.5, 5.0, 0.5, 4.5 };
    train.add(new ArrayExample<>(new Label("Baz"), names, values));
    names = new String[] { "B", "C", "E", "H" };
    values = new double[] { 1.234, 5.1235, 0.1235, 6.0 };
    train.add(new ArrayExample<>(new Label("Baz"), names, values));
    names = new String[] { "A", "M", "I", "J" };
    values = new double[] { 1.734, 4.5, 0.5123, 5.5 };
    train.add(new ArrayExample<>(new Label("Baz"), names, values));
    names = new String[] { "Z", "A", "B", "C" };
    values = new double[] { negate * 1, 0.25, 5, 10.0 };
    train.add(new ArrayExample<>(new Label("Quux"), names, values));
    names = new String[] { "K", "V", "E", "D" };
    values = new double[] { negate * 1.4, 0.55, 5.65, 12.0 };
    train.add(new ArrayExample<>(new Label("Quux"), names, values));
    names = new String[] { "B", "G", "E", "A" };
    values = new double[] { negate * 1.9, 0.25, 5.9, 15 };
    train.add(new ArrayExample<>(new Label("Quux"), names, values));
    DataSourceProvenance testProvenance = new SimpleDataSourceProvenance("TestingData", OffsetDateTime.now(), labelFactory);
    MutableDataset<Label> test = new MutableDataset<>(testProvenance, labelFactory);
    names = new String[] { "AA", "B", "C", "D" };
    values = new double[] { 2.0, 0.45, 3.5, negate * 2.0 };
    test.add(new ArrayExample<>(new Label("Foo"), names, values));
    names = new String[] { "B", "BB", "F", "E" };
    values = new double[] { negate * 2.0, 0.55, negate * 2.5, 2.5 };
    test.add(new ArrayExample<>(new Label("Bar"), names, values));
    names = new String[] { "B", "E", "G", "H" };
    values = new double[] { 1.75, 5.0, 1.0, 6.5 };
    test.add(new ArrayExample<>(new Label("Baz"), names, values));
    names = new String[] { "B", "CC", "DD", "EE" };
    values = new double[] { negate * 1.5, 0.25, 5.0, 20.0 };
    test.add(new ArrayExample<>(new Label("Quux"), names, values));
    return new Pair<>(train, test);
}
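As a rough usage sketch (not taken from the Tribuo sources), the returned pair can be unpacked into train and test datasets and handed to any classification trainer; the choice of LogisticRegressionTrainer below is purely an assumption for illustration.

// Sketch: unpack the generated pair and train/evaluate a classifier.
Pair<Dataset<Label>, Dataset<Label>> pair = LabelledDataGenerator.sparseTrainTest(-1.0);
Dataset<Label> trainData = pair.getA();
Dataset<Label> testData = pair.getB();
Trainer<Label> trainer = new LogisticRegressionTrainer(); // assumed trainer choice
Model<Label> model = trainer.train(trainData);
LabelEvaluation evaluation = new LabelEvaluator().evaluate(model, testData);
System.out.println(evaluation);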
Use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.
The class KDTree, method query.
@Override
public List<List<Pair<Integer, Double>>> query(SGDVector[] points, int k) {
    int numQueries = points.length;
    @SuppressWarnings("unchecked")
    List<Pair<Integer, Double>>[] indexDistancePairListArray = (List<Pair<Integer, Double>>[]) new List[numQueries];
    // When the number of threads is 1, the overhead of thread pools must be avoided
    if (numThreads == 1) {
        for (int point = 0; point < numQueries; point++) {
            indexDistancePairListArray[point] = query(points[point], k);
        }
    } else {
        // This runs each k-d tree neighbor query in a separate thread
        ExecutorService executorService = Executors.newFixedThreadPool(numThreads);
        for (int pointInd = 0; pointInd < numQueries; pointInd++) {
            executorService.execute(new SingleQueryRunnable(pointInd, points[pointInd], k, indexDistancePairListArray));
        }
        executorService.shutdown();
        try {
            boolean finished = executorService.awaitTermination(Long.MAX_VALUE, TimeUnit.MINUTES);
            if (!finished) {
                throw new RuntimeException("Parallel execution failed");
            }
        } catch (InterruptedException e) {
            throw new RuntimeException("Parallel execution failed", e);
        }
    }
    return Arrays.asList(indexDistancePairListArray);
}
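A hedged sketch of consuming the batch query; `tree` stands in for an already-constructed k-d tree neighbor query over some data points (the construction step is omitted because it goes through a factory whose configuration is not shown here).

// Sketch only: `tree` is an already-built k-d tree query object (assumption).
SGDVector[] queries = new SGDVector[] {
    DenseVector.createDenseVector(new double[] { 0.0, 0.0 }),
    DenseVector.createDenseVector(new double[] { 1.0, 1.0 })
};
List<List<Pair<Integer, Double>>> neighbours = tree.query(queries, 3);
for (int i = 0; i < neighbours.size(); i++) {
    for (Pair<Integer, Double> pair : neighbours.get(i)) {
        // getA() is the index of the neighboring point, getB() its distance from the query.
        System.out.println("query " + i + " -> point " + pair.getA() + " at distance " + pair.getB());
    }
}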
Use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.
The class LIMEColumnarTest, method generateBinarisedDataset.
private Pair<RowProcessor<Label>, Dataset<Label>> generateBinarisedDataset() throws URISyntaxException {
    LabelFactory labelFactory = new LabelFactory();
    ResponseProcessor<Label> responseProcessor = new FieldResponseProcessor<>("Response", "N", labelFactory);
    Map<String, FieldProcessor> fieldProcessors = new HashMap<>();
    fieldProcessors.put("A", new IdentityProcessor("A"));
    fieldProcessors.put("B", new DoubleFieldProcessor("B"));
    fieldProcessors.put("C", new DoubleFieldProcessor("C"));
    fieldProcessors.put("D", new IdentityProcessor("D"));
    fieldProcessors.put("TextField", new TextFieldProcessor("TextField", new BasicPipeline(tokenizer, 2)));
    RowProcessor<Label> rp = new RowProcessor<>(responseProcessor, fieldProcessors);
    CSVDataSource<Label> source = new CSVDataSource<>(LIMEColumnarTest.class.getResource("/org/tribuo/classification/explanations/lime/test-columnar.csv").toURI(), rp, true);
    Dataset<Label> dataset = new MutableDataset<>(source);
    return new Pair<>(rp, dataset);
}
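A small sketch of how the returned pair might be inspected in a test; it only uses the dataset and feature map accessors, and everything else is as in the method above.

// Sketch: build the pair, then look at what the row processor produced.
Pair<RowProcessor<Label>, Dataset<Label>> binarisedPair = generateBinarisedDataset();
Dataset<Label> data = binarisedPair.getB();
System.out.println("Examples: " + data.size());
for (VariableInfo info : data.getFeatureIDMap()) {
    // Identity and text processors emit one feature per observed value or token.
    System.out.println(info.getName() + " (count " + info.getCount() + ")");
}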
Use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.
The class LIMEColumnarTest, method generateCategoricalDataset.
private Pair<RowProcessor<Label>, Dataset<Label>> generateCategoricalDataset() throws URISyntaxException {
    LabelFactory labelFactory = new LabelFactory();
    ResponseProcessor<Label> responseProcessor = new FieldResponseProcessor<>("Response", "N", labelFactory);
    Map<String, FieldProcessor> fieldProcessors = new HashMap<>();
    fieldProcessors.put("A", new IdentityProcessor("A") {
        @Override
        public GeneratedFeatureType getFeatureType() {
            return GeneratedFeatureType.CATEGORICAL;
        }
    });
    fieldProcessors.put("B", new DoubleFieldProcessor("B"));
    fieldProcessors.put("C", new DoubleFieldProcessor("C"));
    fieldProcessors.put("D", new IdentityProcessor("D") {
        @Override
        public GeneratedFeatureType getFeatureType() {
            return GeneratedFeatureType.CATEGORICAL;
        }
    });
    fieldProcessors.put("TextField", new TextFieldProcessor("TextField", new BasicPipeline(tokenizer, 2)));
    RowProcessor<Label> rp = new RowProcessor<>(responseProcessor, fieldProcessors);
    CSVDataSource<Label> source = new CSVDataSource<>(LIMEColumnarTest.class.getResource("/org/tribuo/classification/explanations/lime/test-columnar.csv").toURI(), rp, true);
    Dataset<Label> dataset = new MutableDataset<>(source);
    return new Pair<>(rp, dataset);
}
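The only difference from generateBinarisedDataset is the overridden getFeatureType on the two IdentityProcessor fields, which presumably changes how LIME treats those columns when sampling rather than which features are generated. A speculative check of that assumption:

// Sketch: both generators are expected to produce the same feature domain,
// since only the declared GeneratedFeatureType differs (assumption, not asserted by the code above).
Pair<RowProcessor<Label>, Dataset<Label>> binarised = generateBinarisedDataset();
Pair<RowProcessor<Label>, Dataset<Label>> categorical = generateCategoricalDataset();
System.out.println(binarised.getB().getFeatureIDMap().size() == categorical.getB().getFeatureIDMap().size());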
Use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.
The class LibLinearClassificationModel, method innerGetExcuse.
/**
* The call to model.getFeatureWeights in the public methods copies the
* weights array, so this inner method exists to avoid repeating that copy in getExcuses.
* <p>
* If it becomes a problem then we could cache the feature weights in the
* model.
* @param e The example.
* @param allFeatureWeights The feature weights.
* @return An excuse for this example.
*/
@Override
protected Excuse<Label> innerGetExcuse(Example<Label> e, double[][] allFeatureWeights) {
    de.bwaldvogel.liblinear.Model model = models.get(0);
    double[] featureWeights = allFeatureWeights[0];
    int[] labels = model.getLabels();
    int numClasses = model.getNrClass();
    Prediction<Label> prediction = predict(e);
    Map<String, List<Pair<String, Double>>> weightMap = new HashMap<>();
    if (numClasses == 2) {
        List<Pair<String, Double>> posScores = new ArrayList<>();
        List<Pair<String, Double>> negScores = new ArrayList<>();
        for (Feature f : e) {
            int id = featureIDMap.getID(f.getName());
            if (id > -1) {
                double score = featureWeights[id] * f.getValue();
                posScores.add(new Pair<>(f.getName(), score));
                negScores.add(new Pair<>(f.getName(), -score));
            }
        }
        posScores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
        negScores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
        weightMap.put(outputIDInfo.getOutput(labels[0]).getLabel(), posScores);
        weightMap.put(outputIDInfo.getOutput(labels[1]).getLabel(), negScores);
    } else {
        for (int i = 0; i < labels.length; i++) {
            List<Pair<String, Double>> classScores = new ArrayList<>();
            for (Feature f : e) {
                int id = featureIDMap.getID(f.getName());
                if (id > -1) {
                    double score = featureWeights[id * numClasses + i] * f.getValue();
                    classScores.add(new Pair<>(f.getName(), score));
                }
            }
            classScores.sort((Pair<String, Double> o1, Pair<String, Double> o2) -> o2.getB().compareTo(o1.getB()));
            weightMap.put(outputIDInfo.getOutput(labels[i]).getLabel(), classScores);
        }
    }
    return new Excuse<>(e, prediction, weightMap);
}
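From the caller's side this path is reached through Model.getExcuse (or getExcuses); a minimal sketch of reading the per-class feature weights back out, where `model`, `example`, and the label name "Foo" are assumptions for illustration.

// Sketch: ask a trained LibLinear classification model to explain one example.
Optional<Excuse<Label>> maybeExcuse = model.getExcuse(example);
if (maybeExcuse.isPresent()) {
    // "Foo" is a hypothetical class name; use a label the model actually contains.
    List<Pair<String, Double>> scores = maybeExcuse.get().getScores("Foo");
    for (Pair<String, Double> featureScore : scores) {
        System.out.println(featureScore.getA() + " : " + featureScore.getB());
    }
}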