Example 41 with Pair

Use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

The class DummyRegressionTrainer, method train.

@Override
public DummyRegressionModel train(Dataset<Regressor> examples, Map<String, Provenance> instanceProvenance, int invocationCount) {
    if (invocationCount != INCREMENT_INVOCATION_COUNT) {
        setInvocationCount(invocationCount);
    }
    ModelProvenance provenance = new ModelProvenance(DummyRegressionModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), getProvenance(), instanceProvenance);
    trainInvocationCounter++;
    ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo();
    Set<Regressor> domain = outputInfo.getDomain();
    double[][] outputs = new double[outputInfo.size()][examples.size()];
    int i = 0;
    for (Example<Regressor> e : examples) {
        for (Regressor.DimensionTuple r : e.getOutput()) {
            int id = outputInfo.getID(r);
            outputs[id][i] = r.getValue();
        }
        i++;
    }
    Regressor regressor;
    switch(dummyType) {
        case CONSTANT:
            {
                Regressor.DimensionTuple[] output = new Regressor.DimensionTuple[outputs.length];
                for (Regressor r : domain) {
                    int id = outputInfo.getID(r);
                    output[id] = new Regressor.DimensionTuple(r.getNames()[0], constantValue);
                }
                regressor = new Regressor(output);
                return new DummyRegressionModel(provenance, examples.getFeatureIDMap(), outputInfo, dummyType, regressor);
            }
        case MEAN:
            {
                Regressor.DimensionTuple[] output = new Regressor.DimensionTuple[outputs.length];
                for (Regressor r : domain) {
                    int id = outputInfo.getID(r);
                    output[id] = new Regressor.DimensionTuple(r.getNames()[0], Util.mean(outputs[id]));
                }
                regressor = new Regressor(output);
                return new DummyRegressionModel(provenance, examples.getFeatureIDMap(), outputInfo, dummyType, regressor);
            }
        case MEDIAN:
            {
                Regressor.DimensionTuple[] output = new Regressor.DimensionTuple[outputs.length];
                for (Regressor r : domain) {
                    int id = outputInfo.getID(r);
                    Arrays.sort(outputs[id]);
                    output[id] = new Regressor.DimensionTuple(r.getNames()[0], outputs[id][outputs[id].length / 2]);
                }
                regressor = new Regressor(output);
                return new DummyRegressionModel(provenance, examples.getFeatureIDMap(), outputInfo, dummyType, regressor);
            }
        case QUARTILE:
            {
                Regressor.DimensionTuple[] output = new Regressor.DimensionTuple[outputs.length];
                for (Regressor r : domain) {
                    int id = outputInfo.getID(r);
                    Arrays.sort(outputs[id]);
                    output[id] = new Regressor.DimensionTuple(r.getNames()[0], outputs[id][(int) (quartile * outputs[id].length)]);
                }
                regressor = new Regressor(output);
                return new DummyRegressionModel(provenance, examples.getFeatureIDMap(), outputInfo, dummyType, regressor);
            }
        case GAUSSIAN:
            {
                double[] means = new double[outputs.length];
                double[] variances = new double[outputs.length];
                String[] names = new String[outputs.length];
                for (Regressor r : domain) {
                    int id = outputInfo.getID(r);
                    names[id] = r.getNames()[0];
                    Pair<Double, Double> meanVariance = Util.meanAndVariance(outputs[id]);
                    means[id] = meanVariance.getA();
                    variances[id] = meanVariance.getB();
                }
                return new DummyRegressionModel(provenance, examples.getFeatureIDMap(), outputInfo, seed, means, variances, names);
            }
        default:
            throw new IllegalStateException("Unknown dummyType " + dummyType);
    }
}
Also used: ModelProvenance (org.tribuo.provenance.ModelProvenance), Regressor (org.tribuo.regression.Regressor), Pair (com.oracle.labs.mlrg.olcut.util.Pair)
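
The GAUSSIAN branch above leans on Util.meanAndVariance, which packs the mean and the variance of an array into a single Pair. A minimal standalone sketch of that pattern, assuming the Util referenced above is org.tribuo.util.Util (the class name and sample values below are made up for illustration):

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.util.Util; // assumed location of the Util used above

public final class MeanVarianceSketch {
    public static void main(String[] args) {
        double[] targets = {1.0, 2.0, 3.0, 4.0};
        // The mean lands in the A slot and the variance in the B slot, mirroring the GAUSSIAN case.
        Pair<Double, Double> meanVariance = Util.meanAndVariance(targets);
        System.out.printf("mean=%.3f variance=%.3f%n", meanVariance.getA(), meanVariance.getB());
    }
}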

Example 42 with Pair

Use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

The class EvaluationAggregationTests, method xval.

public static void xval() {
    Trainer<Label> trainer = DummyClassifierTrainer.createUniformTrainer(1L);
    Pair<Dataset<Label>, Dataset<Label>> datasets = LabelledDataGenerator.denseTrainTest();
    Dataset<Label> trainData = datasets.getA();
    Evaluator<Label, LabelEvaluation> evaluator = factory.getEvaluator();
    CrossValidation<Label, LabelEvaluation> xval = new CrossValidation<>(trainer, trainData, evaluator, 5);
    List<Pair<LabelEvaluation, Model<Label>>> results = xval.evaluate();
    List<LabelEvaluation> evals = results.stream().map(Pair::getA).collect(Collectors.toList());
    // Summarize across everything
    Map<MetricID<Label>, DescriptiveStats> summary = EvaluationAggregator.summarize(evals);
    List<MetricID<Label>> keys = new ArrayList<>(summary.keySet()).stream().sorted(Comparator.comparing(Pair::getB)).collect(Collectors.toList());
    for (MetricID<Label> key : keys) {
        DescriptiveStats stats = summary.get(key);
        out.printf("%-10s  %.5f (%.5f)%n", key, stats.getMean(), stats.getStandardDeviation());
    }
    // Summarize across macro F1s only
    DescriptiveStats macroF1Summary = EvaluationAggregator.summarize(evals, LabelEvaluation::macroAveragedF1);
    out.println(macroF1Summary);
    Pair<Integer, Double> argmax = EvaluationAggregator.argmax(evals, LabelEvaluation::macroAveragedF1);
    Model<Label> bestF1 = results.get(argmax.getA()).getB();
    LabelEvaluation testEval = evaluator.evaluate(bestF1, datasets.getB());
    System.out.println(testEval);
}
Also used: Dataset (org.tribuo.Dataset), Label (org.tribuo.classification.Label), MetricID (org.tribuo.evaluation.metrics.MetricID), DescriptiveStats (org.tribuo.evaluation.DescriptiveStats), CrossValidation (org.tribuo.evaluation.CrossValidation), Pair (com.oracle.labs.mlrg.olcut.util.Pair)
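
For reference, the same Pair plumbing condensed into a self-contained sketch. The imports, the class name XvalSketch, and the accuracy printout are assumptions; the trainer, data generator, and cross-validation wiring mirror the test above:

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.baseline.DummyClassifierTrainer;
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.classification.example.LabelledDataGenerator;
import org.tribuo.evaluation.CrossValidation;

public final class XvalSketch {
    public static void main(String[] args) {
        Trainer<Label> trainer = DummyClassifierTrainer.createUniformTrainer(1L);
        // denseTrainTest returns the train and test datasets packed into a Pair
        Pair<Dataset<Label>, Dataset<Label>> datasets = LabelledDataGenerator.denseTrainTest();
        CrossValidation<Label, LabelEvaluation> xval =
                new CrossValidation<>(trainer, datasets.getA(), new LabelFactory().getEvaluator(), 5);
        // each fold comes back as a Pair of its evaluation (A) and the model it produced (B)
        for (Pair<LabelEvaluation, Model<Label>> fold : xval.evaluate()) {
            System.out.println(fold.getA().accuracy());
        }
    }
}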

Example 43 with Pair

Use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

The class TreeFeature, method split.

/**
 * Splits this tree feature into two.
 *
 * @param leftIndices The indices to go in the left branch.
 * @param firstBuffer A scratch buffer; it is filled with the left indices and swapped with secondBuffer after each per-feature split.
 * @param secondBuffer A second scratch buffer, swapped with firstBuffer as the split proceeds.
 * @return A pair of TreeFeatures, the first element is the left branch, the second the right.
 */
public Pair<TreeFeature, TreeFeature> split(IntArrayContainer leftIndices, IntArrayContainer firstBuffer, IntArrayContainer secondBuffer) {
    if (!sorted) {
        throw new IllegalStateException("TreeFeature must be sorted before split is called");
    }
    List<InvertedFeature> leftFeatures = new ArrayList<>();
    List<InvertedFeature> rightFeatures = new ArrayList<>();
    firstBuffer.fill(leftIndices);
    for (InvertedFeature f : feature) {
        // Check if we've exhausted all the left side indices
        if (firstBuffer.size > 0) {
            Pair<InvertedFeature, InvertedFeature> split = f.split(firstBuffer, secondBuffer);
            IntArrayContainer tmp = secondBuffer;
            secondBuffer = firstBuffer;
            firstBuffer = tmp;
            InvertedFeature left = split.getA();
            InvertedFeature right = split.getB();
            if (left != null) {
                leftFeatures.add(left);
            }
            if (right != null) {
                rightFeatures.add(right);
            }
        } else {
            rightFeatures.add(f);
        }
    }
    return new Pair<>(new TreeFeature(id, numLabels, leftFeatures), new TreeFeature(id, numLabels, rightFeatures));
}
Also used: IntArrayContainer (org.tribuo.common.tree.impl.IntArrayContainer), ArrayList (java.util.ArrayList), Pair (com.oracle.labs.mlrg.olcut.util.Pair)
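
A short fragment consuming the Pair returned above. It is not standalone: building a sorted TreeFeature and sizing the IntArrayContainer buffers happens elsewhere in the tree builder, and the helper name routeToChildren is made up for illustration:

static void routeToChildren(TreeFeature feature, IntArrayContainer leftIndices,
                            IntArrayContainer firstBuffer, IntArrayContainer secondBuffer) {
    // The feature must already be sorted, otherwise split throws IllegalStateException (see the guard above).
    Pair<TreeFeature, TreeFeature> branches = feature.split(leftIndices, firstBuffer, secondBuffer);
    // Per the Javadoc above, A holds the left branch and B the right branch.
    TreeFeature leftChild = branches.getA();
    TreeFeature rightChild = branches.getB();
    // ... hand leftChild and rightChild to the corresponding child node builders.
}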

Example 44 with Pair

Use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

The class ClassificationTest, method generateImageData.

/**
 * Generates image data.
 * <p>
 * The data generating process is as follows:
 * - Compute the number of possible features which could be set. Features are set in a block based on the y
 * co-ordinate, which encodes the class label.
 * - Sample a class label y in the range [0, numClasses).
 * - For half the number of valid features:
 * -- Randomly sample the feature's y co-ordinate in the range [y*pixRange, (y+1)*pixRange).
 * -- Randomly sample the feature's x co-ordinate in the range [0, imageSize).
 * -- Randomly sample the feature's value in the range [pixelDepth/2, pixelDepth).
 * -- If this feature has not already been added, add it.
 * @param numExamples Number of examples to generate for train and test.
 * @param imageSize The image size in pixels, must be a multiple of the number of classes.
 * @param pixelDepth The number of valid pixel values, must be greater than 1.
 * @param numClasses The number of classes.
 * @param seed The RNG seed.
 * @return Training and test datasets.
 */
private static Pair<Dataset<Label>, Dataset<Label>> generateImageData(int numExamples, int imageSize, int pixelDepth, int numClasses, int seed) {
    if (imageSize % numClasses != 0) {
        throw new IllegalArgumentException("The data generating process needs imageSize to be a multiple of numClasses.");
    }
    if (pixelDepth < 2) {
        throw new IllegalArgumentException("Pixel depth must be greater than 1");
    }
    SplittableRandom rng = new SplittableRandom(seed);
    LabelFactory factory = new LabelFactory();
    String description = "(numExamples=" + numExamples + ",imageSize=" + imageSize + ",pixelDepth=" + pixelDepth + ",numClasses=" + numClasses + ",seed=" + seed + ")";
    int maxFeature = imageSize * imageSize;
    int width = ("" + maxFeature).length();
    String formatString = "%0" + width + "d";
    Map<Integer, String> featureNameMap = new HashMap<>(maxFeature);
    for (int i = 0; i < maxFeature; i++) {
        featureNameMap.put(i, String.format(formatString, i));
    }
    int halfDepth = pixelDepth / 2;
    int pixRange = imageSize / numClasses;
    int numValidFeatures = pixRange * imageSize;
    List<Example<Label>> trainList = new ArrayList<>();
    Set<String> names = new HashSet<>();
    List<Feature> featuresCache = new ArrayList<>();
    for (int i = 0; i < numExamples; i++) {
        names.clear();
        featuresCache.clear();
        int curLabelIdx = rng.nextInt(numClasses);
        Label curLabel = new Label("" + curLabelIdx);
        for (int j = 0; j < numValidFeatures / 2; j++) {
            int yValue = rng.nextInt(pixRange) + (curLabelIdx * pixRange);
            int xValue = rng.nextInt(imageSize);
            int value = rng.nextInt(halfDepth) + halfDepth;
            // feature index = x*imageSize + y, mapped to a zero-padded name via featureNameMap
            int featureIdx = xValue * imageSize + yValue;
            String featureName = featureNameMap.get(featureIdx);
            if (!names.contains(featureName)) {
                names.add(featureName);
                featuresCache.add(new Feature(featureName, value));
            }
        }
        trainList.add(new ArrayExample<>(curLabel, featuresCache));
    }
    ListDataSource<Label> trainListSource = new ListDataSource<>(trainList, factory, new SimpleDataSourceProvenance("Training " + description, factory));
    List<Example<Label>> testList = new ArrayList<>();
    for (int i = 0; i < numExamples; i++) {
        names.clear();
        featuresCache.clear();
        int curLabelIdx = rng.nextInt(numClasses);
        Label curLabel = new Label("" + curLabelIdx);
        for (int j = 0; j < numValidFeatures / 2; j++) {
            int yValue = rng.nextInt(pixRange) + (curLabelIdx * pixRange);
            int xValue = rng.nextInt(imageSize);
            int value = rng.nextInt(halfDepth) + halfDepth;
            // feature index = x*imageSize + y, mapped to a zero-padded name via featureNameMap
            int featureIdx = xValue * imageSize + yValue;
            String featureName = featureNameMap.get(featureIdx);
            if (!names.contains(featureName)) {
                names.add(featureName);
                featuresCache.add(new Feature(featureName, value));
            }
        }
        testList.add(new ArrayExample<>(curLabel, featuresCache));
    }
    ListDataSource<Label> testListSource = new ListDataSource<>(testList, factory, new SimpleDataSourceProvenance("Testing " + description, factory));
    return new Pair<>(new MutableDataset<>(trainListSource), new MutableDataset<>(testListSource));
}
Also used: HashMap (java.util.HashMap), SimpleDataSourceProvenance (org.tribuo.provenance.SimpleDataSourceProvenance), ArrayList (java.util.ArrayList), Label (org.tribuo.classification.Label), Feature (org.tribuo.Feature), ListDataSource (org.tribuo.datasource.ListDataSource), LabelFactory (org.tribuo.classification.LabelFactory), Example (org.tribuo.Example), ArrayExample (org.tribuo.impl.ArrayExample), SplittableRandom (java.util.SplittableRandom), HashSet (java.util.HashSet), Pair (com.oracle.labs.mlrg.olcut.util.Pair)
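
A fragment showing how the returned Pair of datasets would typically be consumed. generateImageData is private to the test class, so this assumes access to it; the argument values and the choice of a dummy trainer are arbitrary:

// imageSize (10) is a multiple of numClasses (5), and pixelDepth (255) is greater than 1.
Pair<Dataset<Label>, Dataset<Label>> data = generateImageData(512, 10, 255, 5, 42);
Trainer<Label> trainer = DummyClassifierTrainer.createUniformTrainer(1L);
Model<Label> model = trainer.train(data.getA());
LabelEvaluation eval = new LabelFactory().getEvaluator().evaluate(model, data.getB());
System.out.println(eval.accuracy());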

Example 45 with Pair

Use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

The class NeighboursBruteForce, method query.

@Override
public List<List<Pair<Integer, Double>>> query(SGDVector[] points, int k) {
    int numQueries = points.length;
    @SuppressWarnings("unchecked") List<Pair<Integer, Double>>[] indexDistancePairListArray = (List<Pair<Integer, Double>>[]) new List[numQueries];
    // When the number of threads is 1, the overhead of thread pools must be avoided
    if (numThreads == 1) {
        for (int point = 0; point < numQueries; point++) {
            indexDistancePairListArray[point] = query(points[point], k);
        }
    } else {
        // Execute the nearest neighbor queries across multiple threads
        ExecutorService executorService = Executors.newFixedThreadPool(numThreads);
        for (int pointInd = 0; pointInd < numQueries; pointInd++) {
            executorService.execute(new SingleQueryRunnable(pointInd, points[pointInd], k, indexDistancePairListArray));
        }
        executorService.shutdown();
        try {
            boolean finished = executorService.awaitTermination(Long.MAX_VALUE, TimeUnit.MINUTES);
            if (!finished) {
                throw new RuntimeException("Parallel execution failed");
            }
        } catch (InterruptedException e) {
            throw new RuntimeException("Parallel execution failed", e);
        }
    }
    return new ArrayList<>(Arrays.asList(indexDistancePairListArray));
}
Also used: ArrayList (java.util.ArrayList), ExecutorService (java.util.concurrent.ExecutorService), List (java.util.List), Pair (com.oracle.labs.mlrg.olcut.util.Pair)
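
A fragment consuming the nested result: the outer list has one entry per query point, and each inner Pair holds a neighbour's index into the original data (A) and its distance to the query (B). It assumes nbf is an already-constructed NeighboursBruteForce and queries is an SGDVector[]:

List<List<Pair<Integer, Double>>> neighbours = nbf.query(queries, 5);
for (int q = 0; q < queries.length; q++) {
    for (Pair<Integer, Double> hit : neighbours.get(q)) {
        System.out.printf("query %d -> point %d at distance %.4f%n", q, hit.getA(), hit.getB());
    }
}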

Aggregations (usage counts)

Pair (com.oracle.labs.mlrg.olcut.util.Pair): 59
ArrayList (java.util.ArrayList): 27
List (java.util.List): 21
HashMap (java.util.HashMap): 18
MutableDataset (org.tribuo.MutableDataset): 17
SimpleDataSourceProvenance (org.tribuo.provenance.SimpleDataSourceProvenance): 16
Label (org.tribuo.classification.Label): 14
Feature (org.tribuo.Feature): 11
Regressor (org.tribuo.regression.Regressor): 11
Prediction (org.tribuo.Prediction): 10
DenseVector (org.tribuo.math.la.DenseVector): 10
SparseVector (org.tribuo.math.la.SparseVector): 10
SGDVector (org.tribuo.math.la.SGDVector): 9
Map (java.util.Map): 7
Example (org.tribuo.Example): 7
ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap): 7
PriorityQueue (java.util.PriorityQueue): 6
Excuse (org.tribuo.Excuse): 5
Model (org.tribuo.Model): 5
LabelFactory (org.tribuo.classification.LabelFactory): 5