Search in sources :

Example 6 with Pair

use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

the class LabelledDataGenerator method sparseTrainTest.

/**
 * Builds a small train/test pair of sparse classification datasets.
 * <p>
 * Every example carries four named features drawn from a larger pool, so
 * no example is dense. The test set also uses feature names ("AA", "BB",
 * "CC", "DD", "EE") which never occur in the training set. It has the same
 * 4 classes {Foo,Bar,Baz,Quux}.
 * @param negate Multiplier applied to selected feature values; supply -1.0 to negate them.
 * @return A pair whose first element is the training dataset and whose second is the test dataset.
 */
public static Pair<Dataset<Label>, Dataset<Label>> sparseTrainTest(double negate) {
    // ---- training data: three examples per class ----
    DataSourceProvenance trainProvenance = new SimpleDataSourceProvenance("TrainingData", OffsetDateTime.now(), labelFactory);
    MutableDataset<Label> trainSet = new MutableDataset<>(trainProvenance, labelFactory);

    trainSet.add(new ArrayExample<>(new Label("Foo"),
            new String[] { "A", "B", "C", "D" },
            new double[] { 1.0, 0.5, 1.0, negate * 1.0 }));
    trainSet.add(new ArrayExample<>(new Label("Foo"),
            new String[] { "B", "D", "F", "H" },
            new double[] { 1.5, 0.35, 1.3, negate * 1.2 }));
    trainSet.add(new ArrayExample<>(new Label("Foo"),
            new String[] { "A", "J", "D", "M" },
            new double[] { 1.2, 0.45, 1.5, negate * 1.0 }));

    trainSet.add(new ArrayExample<>(new Label("Bar"),
            new String[] { "C", "E", "F", "H" },
            new double[] { negate * 1.1, 0.55, negate * 1.5, 0.5 }));
    trainSet.add(new ArrayExample<>(new Label("Bar"),
            new String[] { "E", "G", "F", "I" },
            new double[] { negate * 1.5, 0.25, negate * 1, 0.125 }));
    trainSet.add(new ArrayExample<>(new Label("Bar"),
            new String[] { "J", "K", "C", "E" },
            new double[] { negate * 1, 0.5, negate * 1.123, 0.123 }));

    trainSet.add(new ArrayExample<>(new Label("Baz"),
            new String[] { "E", "A", "K", "J" },
            new double[] { 1.5, 5.0, 0.5, 4.5 }));
    trainSet.add(new ArrayExample<>(new Label("Baz"),
            new String[] { "B", "C", "E", "H" },
            new double[] { 1.234, 5.1235, 0.1235, 6.0 }));
    trainSet.add(new ArrayExample<>(new Label("Baz"),
            new String[] { "A", "M", "I", "J" },
            new double[] { 1.734, 4.5, 0.5123, 5.5 }));

    trainSet.add(new ArrayExample<>(new Label("Quux"),
            new String[] { "Z", "A", "B", "C" },
            new double[] { negate * 1, 0.25, 5, 10.0 }));
    trainSet.add(new ArrayExample<>(new Label("Quux"),
            new String[] { "K", "V", "E", "D" },
            new double[] { negate * 1.4, 0.55, 5.65, 12.0 }));
    trainSet.add(new ArrayExample<>(new Label("Quux"),
            new String[] { "B", "G", "E", "A" },
            new double[] { negate * 1.9, 0.25, 5.9, 15 }));

    // ---- test data: one example per class, including unseen feature names ----
    DataSourceProvenance heldOutProvenance = new SimpleDataSourceProvenance("TestingData", OffsetDateTime.now(), labelFactory);
    MutableDataset<Label> testSet = new MutableDataset<>(heldOutProvenance, labelFactory);

    testSet.add(new ArrayExample<>(new Label("Foo"),
            new String[] { "AA", "B", "C", "D" },
            new double[] { 2.0, 0.45, 3.5, negate * 2.0 }));
    testSet.add(new ArrayExample<>(new Label("Bar"),
            new String[] { "B", "BB", "F", "E" },
            new double[] { negate * 2.0, 0.55, negate * 2.5, 2.5 }));
    testSet.add(new ArrayExample<>(new Label("Baz"),
            new String[] { "B", "E", "G", "H" },
            new double[] { 1.75, 5.0, 1.0, 6.5 }));
    testSet.add(new ArrayExample<>(new Label("Quux"),
            new String[] { "B", "CC", "DD", "EE" },
            new double[] { negate * 1.5, 0.25, 5.0, 20.0 }));

    return new Pair<>(trainSet, testSet);
}
Also used : SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) Label(org.tribuo.classification.Label) MutableDataset(org.tribuo.MutableDataset) DataSourceProvenance(org.tribuo.provenance.DataSourceProvenance) SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Example 7 with Pair

use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

the class KDTree method query.

/**
 * Runs a k-nearest-neighbour query for each of the supplied points.
 * <p>
 * With a single configured thread the queries run sequentially on the
 * calling thread. Otherwise a fixed-size pool of {@code numThreads} threads
 * is created for this call and one task is submitted per point; the pool is
 * always torn down before this method returns or throws.
 * @param points The query points.
 * @param k The number of neighbours requested for each point.
 * @return One list of (index, distance) pairs per query point, in the same order as {@code points}.
 * @throws RuntimeException if the calling thread is interrupted while waiting
 *         for the parallel queries (the thread's interrupt flag is restored), or
 *         if the pool fails to terminate.
 */
@Override
public List<List<Pair<Integer, Double>>> query(SGDVector[] points, int k) {
    int numQueries = points.length;
    @SuppressWarnings("unchecked") List<Pair<Integer, Double>>[] indexDistancePairListArray = (List<Pair<Integer, Double>>[]) new List[numQueries];
    // When the number of threads is 1, the overhead of thread pools must be avoided
    if (numThreads == 1) {
        for (int point = 0; point < numQueries; point++) {
            indexDistancePairListArray[point] = query(points[point], k);
        }
    } else {
        // This makes each k-d tree neighbor query in a separate thread
        ExecutorService executorService = Executors.newFixedThreadPool(numThreads);
        try {
            for (int pointInd = 0; pointInd < numQueries; pointInd++) {
                // NOTE(review): presumably each runnable stores its result at its own index of the shared array - confirm in SingleQueryRunnable
                executorService.execute(new SingleQueryRunnable(pointInd, points[pointInd], k, indexDistancePairListArray));
            }
            executorService.shutdown();
            boolean finished = executorService.awaitTermination(Long.MAX_VALUE, TimeUnit.MINUTES);
            if (!finished) {
                throw new RuntimeException("Parallel execution failed");
            }
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up the stack can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new RuntimeException("Parallel execution failed", e);
        } finally {
            // No-op after a clean termination; on any failure path it stops outstanding work
            // so the pool's threads are not left running.
            executorService.shutdownNow();
        }
    }
    return Arrays.asList(indexDistancePairListArray);
}
Also used : ExecutorService(java.util.concurrent.ExecutorService) List(java.util.List) ArrayList(java.util.ArrayList) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Example 8 with Pair

use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

the class LIMEColumnarTest method generateBinarisedDataset.

/**
 * Loads the columnar LIME test CSV using a row processor whose "A" and "D"
 * columns use plain {@link IdentityProcessor}s, "B" and "C" use
 * {@link DoubleFieldProcessor}s, and "TextField" goes through a text pipeline.
 * @return The row processor together with the dataset it produced.
 * @throws URISyntaxException If the test resource URL cannot be converted to a URI.
 */
private Pair<RowProcessor<Label>, Dataset<Label>> generateBinarisedDataset() throws URISyntaxException {
    LabelFactory factory = new LabelFactory();
    ResponseProcessor<Label> response = new FieldResponseProcessor<>("Response", "N", factory);

    // One processor per CSV column.
    Map<String, FieldProcessor> processorsByField = new HashMap<>();
    processorsByField.put("A", new IdentityProcessor("A"));
    processorsByField.put("B", new DoubleFieldProcessor("B"));
    processorsByField.put("C", new DoubleFieldProcessor("C"));
    processorsByField.put("D", new IdentityProcessor("D"));
    BasicPipeline textPipeline = new BasicPipeline(tokenizer, 2);
    processorsByField.put("TextField", new TextFieldProcessor("TextField", textPipeline));

    RowProcessor<Label> rowProcessor = new RowProcessor<>(response, processorsByField);

    URI csvLocation = LIMEColumnarTest.class.getResource("/org/tribuo/classification/explanations/lime/test-columnar.csv").toURI();
    CSVDataSource<Label> csvSource = new CSVDataSource<>(csvLocation, rowProcessor, true);
    Dataset<Label> loaded = new MutableDataset<>(csvSource);

    return new Pair<>(rowProcessor, loaded);
}
Also used : TextFieldProcessor(org.tribuo.data.columnar.processors.field.TextFieldProcessor) HashMap(java.util.HashMap) BasicPipeline(org.tribuo.data.text.impl.BasicPipeline) Label(org.tribuo.classification.Label) FieldResponseProcessor(org.tribuo.data.columnar.processors.response.FieldResponseProcessor) DoubleFieldProcessor(org.tribuo.data.columnar.processors.field.DoubleFieldProcessor) CSVDataSource(org.tribuo.data.csv.CSVDataSource) FieldProcessor(org.tribuo.data.columnar.FieldProcessor) DoubleFieldProcessor(org.tribuo.data.columnar.processors.field.DoubleFieldProcessor) TextFieldProcessor(org.tribuo.data.columnar.processors.field.TextFieldProcessor) LabelFactory(org.tribuo.classification.LabelFactory) IdentityProcessor(org.tribuo.data.columnar.processors.field.IdentityProcessor) RowProcessor(org.tribuo.data.columnar.RowProcessor) MutableDataset(org.tribuo.MutableDataset) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Example 9 with Pair

use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

the class LIMEColumnarTest method generateCategoricalDataset.

/**
 * Loads the columnar LIME test CSV using a row processor in which the "A"
 * and "D" columns are identity processors overridden to report
 * {@link GeneratedFeatureType#CATEGORICAL}; "B" and "C" use
 * {@link DoubleFieldProcessor}s and "TextField" goes through a text pipeline.
 * @return The row processor together with the dataset it produced.
 * @throws URISyntaxException If the test resource URL cannot be converted to a URI.
 */
private Pair<RowProcessor<Label>, Dataset<Label>> generateCategoricalDataset() throws URISyntaxException {
    LabelFactory factory = new LabelFactory();
    ResponseProcessor<Label> response = new FieldResponseProcessor<>("Response", "N", factory);

    Map<String, FieldProcessor> processorsByField = new HashMap<>();

    // Identity processor for column "A", forced to report a categorical feature type.
    FieldProcessor categoricalA = new IdentityProcessor("A") {

        @Override
        public GeneratedFeatureType getFeatureType() {
            return GeneratedFeatureType.CATEGORICAL;
        }
    };
    processorsByField.put("A", categoricalA);
    processorsByField.put("B", new DoubleFieldProcessor("B"));
    processorsByField.put("C", new DoubleFieldProcessor("C"));

    // Identity processor for column "D", forced to report a categorical feature type.
    FieldProcessor categoricalD = new IdentityProcessor("D") {

        @Override
        public GeneratedFeatureType getFeatureType() {
            return GeneratedFeatureType.CATEGORICAL;
        }
    };
    processorsByField.put("D", categoricalD);

    BasicPipeline textPipeline = new BasicPipeline(tokenizer, 2);
    processorsByField.put("TextField", new TextFieldProcessor("TextField", textPipeline));

    RowProcessor<Label> rowProcessor = new RowProcessor<>(response, processorsByField);

    URI csvLocation = LIMEColumnarTest.class.getResource("/org/tribuo/classification/explanations/lime/test-columnar.csv").toURI();
    CSVDataSource<Label> csvSource = new CSVDataSource<>(csvLocation, rowProcessor, true);
    Dataset<Label> loaded = new MutableDataset<>(csvSource);

    return new Pair<>(rowProcessor, loaded);
}
Also used : TextFieldProcessor(org.tribuo.data.columnar.processors.field.TextFieldProcessor) HashMap(java.util.HashMap) BasicPipeline(org.tribuo.data.text.impl.BasicPipeline) Label(org.tribuo.classification.Label) FieldResponseProcessor(org.tribuo.data.columnar.processors.response.FieldResponseProcessor) DoubleFieldProcessor(org.tribuo.data.columnar.processors.field.DoubleFieldProcessor) CSVDataSource(org.tribuo.data.csv.CSVDataSource) FieldProcessor(org.tribuo.data.columnar.FieldProcessor) DoubleFieldProcessor(org.tribuo.data.columnar.processors.field.DoubleFieldProcessor) TextFieldProcessor(org.tribuo.data.columnar.processors.field.TextFieldProcessor) LabelFactory(org.tribuo.classification.LabelFactory) IdentityProcessor(org.tribuo.data.columnar.processors.field.IdentityProcessor) RowProcessor(org.tribuo.data.columnar.RowProcessor) MutableDataset(org.tribuo.MutableDataset) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Example 10 with Pair

use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

the class LibLinearClassificationModel method innerGetExcuse.

/**
 * Builds an excuse for a single example from an already-extracted weight matrix.
 * <p>
 * The public methods obtain the weights via model.getFeatureWeights, which
 * copies the weights array; taking the weights as an argument here lets
 * getExcuses avoid repeating that copy. If that ever becomes a problem the
 * feature weights could be cached in the model.
 * <p>
 * Only the first row of {@code allFeatureWeights} is used. With exactly two
 * classes, each known feature contributes {@code weight * value} to the first
 * label and the negation of that to the second. Otherwise the weight for
 * feature {@code id} and class index {@code i} is read from position
 * {@code id * numClasses + i}. Features whose id is not greater than -1 are
 * skipped, and each label's list is sorted by descending score.
 * @param e The example.
 * @param allFeatureWeights The feature weights.
 * @return An excuse for this example.
 */
@Override
protected Excuse<Label> innerGetExcuse(Example<Label> e, double[][] allFeatureWeights) {
    de.bwaldvogel.liblinear.Model liblinearModel = models.get(0);
    double[] weights = allFeatureWeights[0];
    int[] labelIds = liblinearModel.getLabels();
    int classCount = liblinearModel.getNrClass();
    Prediction<Label> prediction = predict(e);
    Map<String, List<Pair<String, Double>>> scoresByLabel = new HashMap<>();

    if (classCount != 2) {
        // Multiclass: one score list per label, weights interleaved by class.
        for (int classIdx = 0; classIdx < labelIds.length; classIdx++) {
            List<Pair<String, Double>> perClass = new ArrayList<>();
            for (Feature feature : e) {
                int featureId = featureIDMap.getID(feature.getName());
                if (featureId <= -1) {
                    continue;
                }
                double contribution = weights[featureId * classCount + classIdx] * feature.getValue();
                perClass.add(new Pair<>(feature.getName(), contribution));
            }
            perClass.sort((left, right) -> right.getB().compareTo(left.getB()));
            weightMap(scoresByLabel, labelIds[classIdx], perClass);
        }
    } else {
        // Binary: a single weight vector; the second label gets the negated scores.
        List<Pair<String, Double>> firstLabelScores = new ArrayList<>();
        List<Pair<String, Double>> secondLabelScores = new ArrayList<>();
        for (Feature feature : e) {
            int featureId = featureIDMap.getID(feature.getName());
            if (featureId <= -1) {
                continue;
            }
            double contribution = weights[featureId] * feature.getValue();
            firstLabelScores.add(new Pair<>(feature.getName(), contribution));
            secondLabelScores.add(new Pair<>(feature.getName(), -contribution));
        }
        firstLabelScores.sort((left, right) -> right.getB().compareTo(left.getB()));
        secondLabelScores.sort((left, right) -> right.getB().compareTo(left.getB()));
        weightMap(scoresByLabel, labelIds[0], firstLabelScores);
        weightMap(scoresByLabel, labelIds[1], secondLabelScores);
    }

    return new Excuse<>(e, prediction, scoresByLabel);
}

/**
 * Stores a score list under the label name that the output info maps to the given id.
 * @param target The map being populated.
 * @param labelId The label id as reported by the LibLinear model.
 * @param scores The sorted feature scores for that label.
 */
private void weightMap(Map<String, List<Pair<String, Double>>> target, int labelId, List<Pair<String, Double>> scores) {
    target.put(outputIDInfo.getOutput(labelId).getLabel(), scores);
}
Also used : ONNXNode(org.tribuo.util.onnx.ONNXNode) FeatureNode(de.bwaldvogel.liblinear.FeatureNode) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) Label(org.tribuo.classification.Label) ArrayList(java.util.ArrayList) Feature(org.tribuo.Feature) ArrayList(java.util.ArrayList) List(java.util.List) Pair(com.oracle.labs.mlrg.olcut.util.Pair) Excuse(org.tribuo.Excuse)

Aggregations

Pair (com.oracle.labs.mlrg.olcut.util.Pair)59 ArrayList (java.util.ArrayList)27 List (java.util.List)21 HashMap (java.util.HashMap)18 MutableDataset (org.tribuo.MutableDataset)17 SimpleDataSourceProvenance (org.tribuo.provenance.SimpleDataSourceProvenance)16 Label (org.tribuo.classification.Label)14 Feature (org.tribuo.Feature)11 Regressor (org.tribuo.regression.Regressor)11 Prediction (org.tribuo.Prediction)10 DenseVector (org.tribuo.math.la.DenseVector)10 SparseVector (org.tribuo.math.la.SparseVector)10 SGDVector (org.tribuo.math.la.SGDVector)9 Map (java.util.Map)7 Example (org.tribuo.Example)7 ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap)7 PriorityQueue (java.util.PriorityQueue)6 Excuse (org.tribuo.Excuse)5 Model (org.tribuo.Model)5 LabelFactory (org.tribuo.classification.LabelFactory)5