Use of com.oracle.labs.mlrg.olcut.util.Pair in the Tribuo project by Oracle.
The class DummyRegressionTrainer, method train.
/**
 * Trains a dummy regression model whose prediction ignores the input features.
 * <p>
 * The per-dimension statistic depends on {@code dummyType}:
 * <ul>
 *   <li>{@code CONSTANT} - every dimension predicts {@code constantValue};</li>
 *   <li>{@code MEAN} - the mean of the training targets;</li>
 *   <li>{@code MEDIAN} - the element at index {@code length / 2} of the sorted targets
 *       (the upper median when the count is even);</li>
 *   <li>{@code QUARTILE} - the element at index {@code (int) (quartile * length)} of the sorted
 *       targets, clamped to the last element so that {@code quartile == 1.0} returns the maximum;</li>
 *   <li>{@code GAUSSIAN} - per-dimension mean and variance, passed with {@code seed} to the model.</li>
 * </ul>
 *
 * @param examples The training data.
 * @param instanceProvenance Extra provenance recorded in the model provenance.
 * @param invocationCount If not equal to {@code INCREMENT_INVOCATION_COUNT}, the trainer's
 *                        invocation count is set to this value before training.
 * @return A dummy regression model.
 * @throws IllegalStateException If {@code dummyType} is not a recognised value.
 */
@Override
public DummyRegressionModel train(Dataset<Regressor> examples, Map<String, Provenance> instanceProvenance, int invocationCount) {
    if (invocationCount != INCREMENT_INVOCATION_COUNT) {
        setInvocationCount(invocationCount);
    }
    // Provenance is captured before the counter is incremented.
    ModelProvenance provenance = new ModelProvenance(DummyRegressionModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), getProvenance(), instanceProvenance);
    trainInvocationCounter++;
    ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo();
    Set<Regressor> domain = outputInfo.getDomain();
    // outputs[dimensionId][exampleIndex] holds the target value of each dimension.
    double[][] outputs = new double[outputInfo.size()][examples.size()];
    int i = 0;
    for (Example<Regressor> e : examples) {
        for (Regressor.DimensionTuple r : e.getOutput()) {
            int id = outputInfo.getID(r);
            outputs[id][i] = r.getValue();
        }
        i++;
    }
    Regressor regressor;
    switch (dummyType) {
        case CONSTANT: {
            Regressor.DimensionTuple[] output = new Regressor.DimensionTuple[outputs.length];
            for (Regressor r : domain) {
                int id = outputInfo.getID(r);
                output[id] = new Regressor.DimensionTuple(r.getNames()[0], constantValue);
            }
            regressor = new Regressor(output);
            return new DummyRegressionModel(provenance, examples.getFeatureIDMap(), outputInfo, dummyType, regressor);
        }
        case MEAN: {
            Regressor.DimensionTuple[] output = new Regressor.DimensionTuple[outputs.length];
            for (Regressor r : domain) {
                int id = outputInfo.getID(r);
                output[id] = new Regressor.DimensionTuple(r.getNames()[0], Util.mean(outputs[id]));
            }
            regressor = new Regressor(output);
            return new DummyRegressionModel(provenance, examples.getFeatureIDMap(), outputInfo, dummyType, regressor);
        }
        case MEDIAN: {
            Regressor.DimensionTuple[] output = new Regressor.DimensionTuple[outputs.length];
            for (Regressor r : domain) {
                int id = outputInfo.getID(r);
                double[] column = outputs[id];
                // Sorting in place is safe: 'outputs' is local to this call.
                Arrays.sort(column);
                output[id] = new Regressor.DimensionTuple(r.getNames()[0], column[column.length / 2]);
            }
            regressor = new Regressor(output);
            return new DummyRegressionModel(provenance, examples.getFeatureIDMap(), outputInfo, dummyType, regressor);
        }
        case QUARTILE: {
            Regressor.DimensionTuple[] output = new Regressor.DimensionTuple[outputs.length];
            for (Regressor r : domain) {
                int id = outputInfo.getID(r);
                double[] column = outputs[id];
                Arrays.sort(column);
                // quartile == 1.0 would index one past the end; clamp to the maximum element.
                int index = Math.min((int) (quartile * column.length), column.length - 1);
                output[id] = new Regressor.DimensionTuple(r.getNames()[0], column[index]);
            }
            regressor = new Regressor(output);
            return new DummyRegressionModel(provenance, examples.getFeatureIDMap(), outputInfo, dummyType, regressor);
        }
        case GAUSSIAN: {
            double[] means = new double[outputs.length];
            double[] variances = new double[outputs.length];
            String[] names = new String[outputs.length];
            for (Regressor r : domain) {
                int id = outputInfo.getID(r);
                names[id] = r.getNames()[0];
                Pair<Double, Double> meanVariance = Util.meanAndVariance(outputs[id]);
                means[id] = meanVariance.getA();
                variances[id] = meanVariance.getB();
            }
            return new DummyRegressionModel(provenance, examples.getFeatureIDMap(), outputInfo, seed, means, variances, names);
        }
        default:
            throw new IllegalStateException("Unknown dummyType " + dummyType);
    }
}
Use of com.oracle.labs.mlrg.olcut.util.Pair in the Tribuo project by Oracle.
The class EvaluationAggregationTests, method xval.
/**
 * Cross-validates a uniform dummy classifier over five folds and prints aggregated metrics.
 * <p>
 * First every metric is summarised across the folds, then only the macro-averaged F1. Finally
 * the fold model selected by {@code argmax} on macro F1 is evaluated on the held-out split.
 */
public static void xval() {
    Trainer<Label> uniformTrainer = DummyClassifierTrainer.createUniformTrainer(1L);
    Pair<Dataset<Label>, Dataset<Label>> trainTest = LabelledDataGenerator.denseTrainTest();
    Dataset<Label> training = trainTest.getA();
    Evaluator<Label, LabelEvaluation> labelEvaluator = factory.getEvaluator();
    CrossValidation<Label, LabelEvaluation> crossValidation = new CrossValidation<>(uniformTrainer, training, labelEvaluator, 5);
    List<Pair<LabelEvaluation, Model<Label>>> foldResults = crossValidation.evaluate();

    // Pull the per-fold evaluations out of the (evaluation, model) pairs.
    List<LabelEvaluation> foldEvaluations = new ArrayList<>(foldResults.size());
    for (Pair<LabelEvaluation, Model<Label>> fold : foldResults) {
        foldEvaluations.add(fold.getA());
    }

    // Every metric, printed in metric-name order.
    Map<MetricID<Label>, DescriptiveStats> allMetrics = EvaluationAggregator.summarize(foldEvaluations);
    List<MetricID<Label>> orderedIds = new ArrayList<>(allMetrics.keySet()).stream()
            .sorted(Comparator.comparing(Pair::getB))
            .collect(Collectors.toList());
    for (MetricID<Label> metricId : orderedIds) {
        DescriptiveStats metricStats = allMetrics.get(metricId);
        out.printf("%-10s %.5f (%.5f)%n", metricId, metricStats.getMean(), metricStats.getStandardDeviation());
    }

    // Macro-averaged F1 only.
    DescriptiveStats macroF1Stats = EvaluationAggregator.summarize(foldEvaluations, LabelEvaluation::macroAveragedF1);
    out.println(macroF1Stats);

    // The first element of the argmax pair is used as the index of the winning fold.
    Pair<Integer, Double> bestFold = EvaluationAggregator.argmax(foldEvaluations, LabelEvaluation::macroAveragedF1);
    Model<Label> bestModel = foldResults.get(bestFold.getA()).getB();
    LabelEvaluation heldOutEvaluation = labelEvaluator.evaluate(bestModel, trainTest.getB());
    System.out.println(heldOutEvaluation);
}
Use of com.oracle.labs.mlrg.olcut.util.Pair in the Tribuo project by Oracle.
The class TreeFeature, method split.
/**
 * Partitions this tree feature into a left-branch and a right-branch {@link TreeFeature}.
 * <p>
 * Each {@link InvertedFeature} is split against {@code firstBuffer} while that buffer still holds
 * indices. After each split the two buffers swap roles. Presumably this is so the next split
 * reads the indices left over by the current one; that depends on {@code InvertedFeature.split}.
 * Once the buffer is empty, every remaining entry goes to the right branch unchanged.
 *
 * @param leftIndices The indices to go in the left branch.
 * @param firstBuffer A buffer to use; it is initialised from {@code leftIndices}.
 * @param secondBuffer Another buffer.
 * @return A pair of TreeFeatures, the first element is the left branch, the second the right.
 * @throws IllegalStateException If this feature has not been sorted.
 */
public Pair<TreeFeature, TreeFeature> split(IntArrayContainer leftIndices, IntArrayContainer firstBuffer, IntArrayContainer secondBuffer) {
    if (!sorted) {
        throw new IllegalStateException("TreeFeature must be sorted before split is called");
    }
    List<InvertedFeature> leftBranch = new ArrayList<>();
    List<InvertedFeature> rightBranch = new ArrayList<>();
    firstBuffer.fill(leftIndices);
    for (InvertedFeature entry : feature) {
        // No left-side indices remain, so this entry belongs entirely to the right.
        if (firstBuffer.size <= 0) {
            rightBranch.add(entry);
            continue;
        }
        Pair<InvertedFeature, InvertedFeature> halves = entry.split(firstBuffer, secondBuffer);
        IntArrayContainer previousFirst = firstBuffer;
        firstBuffer = secondBuffer;
        secondBuffer = previousFirst;
        InvertedFeature leftHalf = halves.getA();
        if (leftHalf != null) {
            leftBranch.add(leftHalf);
        }
        InvertedFeature rightHalf = halves.getB();
        if (rightHalf != null) {
            rightBranch.add(rightHalf);
        }
    }
    TreeFeature leftTree = new TreeFeature(id, numLabels, leftBranch);
    TreeFeature rightTree = new TreeFeature(id, numLabels, rightBranch);
    return new Pair<>(leftTree, rightTree);
}
Use of com.oracle.labs.mlrg.olcut.util.Pair in the Tribuo project by Oracle.
The class ClassificationTest, method generateImageData.
/**
 * Generates image data.
 * <p>
 * The data generating process is as follows:
 * - Compute the number of possible features which could be set. Features are set in a block based on the y
 * co-ordinate which indicates the class label.
 * - Sample a class label y in the range [0, numClasses).
 * - Repeat numValidFeatures / 2 times, where numValidFeatures = (imageSize / numClasses) * imageSize:
 * -- Randomly sample a feature's y co-ordinate in the range [y*pixRange, (y+1)*pixRange)
 * -- Randomly sample the feature's x co-ordinate in the range [0, imageSize)
 * -- Randomly sample the feature's value in the range [pixelDepth/2, 2*(pixelDepth/2)) (integer division)
 * -- If this feature has not been added already, add it; duplicates are skipped, so an example
 *    may end up with fewer than numValidFeatures / 2 features.
 * The training examples are sampled before the test examples, from the same RNG.
 * @param numExamples Number of examples to generate for train and test.
 * @param imageSize The image size in pixels, must be a multiple of the number of classes.
 * @param pixelDepth The number of valid pixel values, must be greater than 1.
 * @param numClasses The number of classes, must be positive.
 * @param seed The RNG seed.
 * @return Training and test datasets.
 * @throws IllegalArgumentException If numClasses is not positive, imageSize is not a multiple of
 *                                  numClasses, or pixelDepth is less than 2.
 */
private static Pair<Dataset<Label>, Dataset<Label>> generateImageData(int numExamples, int imageSize, int pixelDepth, int numClasses, int seed) {
    if (numClasses < 1) {
        throw new IllegalArgumentException("numClasses must be positive, found " + numClasses);
    }
    if (imageSize % numClasses != 0) {
        throw new IllegalArgumentException("The data generating process needs imageSize to be a multiple of numClasses.");
    }
    // pixelDepth / 2 is used as an RNG bound below, so pixelDepth must be at least 2.
    if (pixelDepth < 2) {
        throw new IllegalArgumentException("Pixel depth must be greater than 1");
    }
    SplittableRandom rng = new SplittableRandom(seed);
    LabelFactory factory = new LabelFactory();
    String description = "(numExamples=" + numExamples + ",imageSize=" + imageSize + ",pixelDepth=" + pixelDepth + ",numClasses=" + numClasses + ",seed=" + seed + ")";
    int maxFeature = imageSize * imageSize;
    // Zero-pad feature names to a fixed width so they sort lexicographically in index order.
    String formatString = "%0" + Integer.toString(maxFeature).length() + "d";
    String[] featureNames = new String[maxFeature];
    for (int i = 0; i < maxFeature; i++) {
        featureNames[i] = String.format(formatString, i);
    }
    int halfDepth = pixelDepth / 2;
    int pixRange = imageSize / numClasses;
    int numValidFeatures = pixRange * imageSize;
    // Train must be sampled before test so the RNG sequence (and hence the data) is reproducible.
    List<Example<Label>> trainList = sampleImageExamples(rng, numExamples, imageSize, numClasses, halfDepth, pixRange, numValidFeatures, featureNames);
    ListDataSource<Label> trainListSource = new ListDataSource<>(trainList, factory, new SimpleDataSourceProvenance("Training " + description, factory));
    List<Example<Label>> testList = sampleImageExamples(rng, numExamples, imageSize, numClasses, halfDepth, pixRange, numValidFeatures, featureNames);
    ListDataSource<Label> testListSource = new ListDataSource<>(testList, factory, new SimpleDataSourceProvenance("Testing " + description, factory));
    return new Pair<>(new MutableDataset<>(trainListSource), new MutableDataset<>(testListSource));
}

/**
 * Samples {@code numExamples} labelled image examples from {@code rng}, following the process
 * described on {@code generateImageData}.
 */
private static List<Example<Label>> sampleImageExamples(SplittableRandom rng, int numExamples, int imageSize, int numClasses, int halfDepth, int pixRange, int numValidFeatures, String[] featureNames) {
    List<Example<Label>> examples = new ArrayList<>(numExamples);
    Set<String> names = new HashSet<>();
    List<Feature> featuresCache = new ArrayList<>();
    for (int i = 0; i < numExamples; i++) {
        names.clear();
        featuresCache.clear();
        int curLabelIdx = rng.nextInt(numClasses);
        Label curLabel = new Label(Integer.toString(curLabelIdx));
        for (int j = 0; j < numValidFeatures / 2; j++) {
            int yValue = rng.nextInt(pixRange) + (curLabelIdx * pixRange);
            int xValue = rng.nextInt(imageSize);
            int value = rng.nextInt(halfDepth) + halfDepth;
            // feature index = x*imageSize + y
            String featureName = featureNames[xValue * imageSize + yValue];
            // Set.add returns false for a duplicate, so each pixel is added at most once.
            if (names.add(featureName)) {
                featuresCache.add(new Feature(featureName, value));
            }
        }
        // featuresCache is reused across examples, as in the original loop.
        examples.add(new ArrayExample<>(curLabel, featuresCache));
    }
    return examples;
}
Use of com.oracle.labs.mlrg.olcut.util.Pair in the Tribuo project by Oracle.
The class NeighboursBruteForce, method query.
/**
 * Runs a {@code k}-nearest-neighbour query for every point in {@code points}.
 * <p>
 * With {@code numThreads == 1} the queries run sequentially on the calling thread, avoiding
 * thread-pool overhead. Otherwise each query is executed as a {@code SingleQueryRunnable} on a
 * fixed pool of {@code numThreads} threads, and this method blocks until the pool terminates.
 *
 * @param points The query points.
 * @param k The number of neighbours to return per point.
 * @return A mutable list with one entry per query point, in the same order as {@code points}.
 * @throws RuntimeException If the pool fails to terminate, or if the calling thread is interrupted
 *                          while waiting. In the interrupted case outstanding queries are
 *                          cancelled and the thread's interrupt flag is restored.
 */
@Override
public List<List<Pair<Integer, Double>>> query(SGDVector[] points, int k) {
    int numQueries = points.length;
    // Generic array creation requires an unchecked cast.
    @SuppressWarnings("unchecked") List<Pair<Integer, Double>>[] indexDistancePairListArray = (List<Pair<Integer, Double>>[]) new List[numQueries];
    // When the number of threads is 1, the overhead of thread pools must be avoided
    if (numThreads == 1) {
        for (int point = 0; point < numQueries; point++) {
            indexDistancePairListArray[point] = query(points[point], k);
        }
    } else {
        // This makes the nearest neighbor queries with multiple threads
        ExecutorService executorService = Executors.newFixedThreadPool(numThreads);
        try {
            for (int pointInd = 0; pointInd < numQueries; pointInd++) {
                executorService.execute(new SingleQueryRunnable(pointInd, points[pointInd], k, indexDistancePairListArray));
            }
        } finally {
            // Always shut down, even if submission fails, so pool threads cannot outlive this call.
            executorService.shutdown();
        }
        try {
            boolean finished = executorService.awaitTermination(Long.MAX_VALUE, TimeUnit.MINUTES);
            if (!finished) {
                throw new RuntimeException("Parallel execution failed");
            }
        } catch (InterruptedException e) {
            // Cancel outstanding queries and preserve the interrupt for callers up the stack.
            executorService.shutdownNow();
            Thread.currentThread().interrupt();
            throw new RuntimeException("Parallel execution failed", e);
        }
    }
    return new ArrayList<>(Arrays.asList(indexDistancePairListArray));
}
Aggregations