Search in sources :

Example 6 with TTest

use of org.apache.commons.math3.stat.inference.TTest in project bayou by capergroup.

the class MetricCalculator method execute.

/**
 * Computes the metric selected on the command line over a data file and prints a summary.
 *
 * <p>Options read from {@code cmdLine}:
 * <ul>
 *   <li>{@code t} - number of top predicted ASTs to consider per data point (default 10)</li>
 *   <li>{@code m} - metric name; one of {@code equality-ast}, {@code jaccard-sequences},
 *       {@code jaccard-api-calls}, {@code num-control-structures}, {@code num-statements}</li>
 *   <li>{@code c} - corpus filter: 2 keeps only in-corpus data points, 3 keeps only
 *       out-of-corpus data points, any other value (default 1) keeps everything</li>
 *   <li>{@code a} - aggregate passed through to the metric (default {@code "min"})</li>
 *   <li>{@code f} - data file to evaluate</li>
 *   <li>{@code p} - optional second data file; when present a paired t-test p-value between
 *       the two sets of metric values is printed as well</li>
 * </ul>
 *
 * <p>Does nothing when {@code cmdLine} is null. An unknown or missing metric name is reported
 * on standard error and the method returns without computing anything.
 *
 * @throws IOException propagated from reading the data file(s)
 * @throws Error if the {@code p} file yields a different number of data points than the
 *     {@code f} file, since a paired test needs equally sized samples
 */
public void execute() throws IOException {
    if (cmdLine == null) {
        return;
    }
    int topk = cmdLine.hasOption("t") ? Integer.parseInt(cmdLine.getOptionValue("t")) : 10;
    String m = cmdLine.getOptionValue("m");
    if (m == null) {
        // A switch over a null String throws NullPointerException; treat a missing
        // metric name the same way as an unrecognised one.
        System.err.println("invalid metric: " + m);
        return;
    }
    Metric metric;
    switch (m) {
        case "equality-ast":
            metric = new EqualityASTMetric();
            break;
        case "jaccard-sequences":
            metric = new JaccardSequencesMetric();
            break;
        case "jaccard-api-calls":
            metric = new JaccardAPICallsMetric();
            break;
        case "num-control-structures":
            metric = new NumControlStructuresMetric();
            break;
        case "num-statements":
            metric = new NumStatementsMetric();
            break;
        default:
            System.err.println("invalid metric: " + m);
            return;
    }
    int inCorpus = cmdLine.hasOption("c") ? Integer.parseInt(cmdLine.getOptionValue("c")) : 1;
    String aggregate = cmdLine.hasOption("a") ? cmdLine.getOptionValue("a") : "min";
    boolean withPValue = cmdLine.hasOption("p");

    List<JSONInputFormat.DataPoint> data = loadCorpusFilteredData(cmdLine.getOptionValue("f"), inCorpus);
    List<Float> values = computeTopKMetricValues(data, metric, topk, aggregate);

    List<Float> values2 = new ArrayList<>();
    if (withPValue) {
        List<JSONInputFormat.DataPoint> data2 = loadCorpusFilteredData(cmdLine.getOptionValue("p"), inCorpus);
        values2 = computeTopKMetricValues(data2, metric, topk, aggregate);
        // Error (rather than a narrower exception) is kept so existing callers observe the same type.
        if (values.size() != values2.size()) {
            throw new Error("DATA files do not match in size. Cannot compute p-value.");
        }
    }

    float average = Metric.mean(values);
    float stdv = Metric.standardDeviation(values);
    if (withPValue) {
        double[] dValues = values.stream().mapToDouble(Float::doubleValue).toArray();
        double[] dValues2 = values2.stream().mapToDouble(Float::doubleValue).toArray();
        double pValue = new TTest().pairedTTest(dValues, dValues2);
        System.out.println(String.format("%s (%d data points, each aggregated with %s): average=%f, stdv=%f, pvalue=%e", m, data.size(), aggregate, average, stdv, pValue));
    } else {
        System.out.println(String.format("%s (%d data points, each aggregated with %s): average=%f, stdv=%f", m, data.size(), aggregate, average, stdv));
    }
}

/**
 * Reads data points from {@code file} and applies the corpus filter: 2 keeps points whose
 * {@code in_corpus} flag is set, 3 keeps points whose flag is clear, anything else keeps all.
 */
private List<JSONInputFormat.DataPoint> loadCorpusFilteredData(String file, int inCorpus) throws IOException {
    List<JSONInputFormat.DataPoint> data = JSONInputFormat.readData(file);
    if (inCorpus == 2) {
        return data.stream().filter(datapoint -> datapoint.in_corpus).collect(Collectors.toList());
    }
    if (inCorpus == 3) {
        return data.stream().filter(datapoint -> !datapoint.in_corpus).collect(Collectors.toList());
    }
    return data;
}

/**
 * Evaluates {@code metric} for every data point against its first {@code topk} predicted ASTs
 * (or all of them when fewer are available), returning one value per data point in input order.
 */
private List<Float> computeTopKMetricValues(List<JSONInputFormat.DataPoint> data, Metric metric, int topk, String aggregate) {
    List<Float> values = new ArrayList<>();
    for (JSONInputFormat.DataPoint datapoint : data) {
        DSubTree originalAST = datapoint.ast;
        List<DSubTree> predictedASTs = datapoint.out_asts.subList(0, Math.min(topk, datapoint.out_asts.size()));
        values.add(metric.compute(originalAST, predictedASTs, aggregate));
    }
    return values;
}
Also used : List(java.util.List) edu.rice.cs.caper.bayou.core.dsl(edu.rice.cs.caper.bayou.core.dsl) Metric(edu.rice.cs.caper.bayou.core.sketch_metric.Metric) org.apache.commons.cli(org.apache.commons.cli) TTest(org.apache.commons.math3.stat.inference.TTest) IOException(java.io.IOException) EqualityASTMetric(edu.rice.cs.caper.bayou.core.sketch_metric.EqualityASTMetric) Collectors(java.util.stream.Collectors) ArrayList(java.util.ArrayList)

Example 7 with TTest

use of org.apache.commons.math3.stat.inference.TTest in project GDSC-SMLM by aherbert.

the class BaseFunctionSolverTest method canFitSingleGaussianBetter.

/**
 * Compares the fits produced by {@code solver2} against those of {@code solver} on simulated
 * single-Gaussian data and tallies how often the second solver comes out ahead.
 *
 * <p>For every signal level, 5 noisy data sets are drawn. Each data set is fitted by both
 * solvers from every combination of the {@code base}, {@code shift} and {@code factor} start
 * offsets, and the results are passed to {@code compare(...)} to fill paired statistics for
 * the Signal, X and Y parameters. Each pair is then put through a two-sided t-test at
 * alpha 0.05:
 * <ul>
 *   <li>if the test reports different means and the means are not within a relative error of
 *       0.1, the pair is scored on accuracy: {@code solver2} wins when its absolute mean is
 *       smaller (logged as "A*", otherwise "A");</li>
 *   <li>otherwise, if the standard deviations differ by more than a relative error of 0.05,
 *       the pair is scored on precision: {@code solver2} wins when its standard deviation is
 *       smaller (logged as "P*", otherwise "P").</li>
 * </ul>
 * The per-parameter and overall tallies, together with the summed evaluation counts of both
 * solvers, are finally handed to {@code test(...)}.
 *
 * @param seed seed used for the noise, the weights and the random generator
 * @param solver the reference solver
 * @param applyBounds whether to set parameter bounds on {@code solver}
 * @param solver2 the solver expected to do better
 * @param applyBounds2 whether to set parameter bounds on {@code solver2}
 * @param name display name of {@code solver}
 * @param name2 display name of {@code solver2}
 * @param noiseModel noise model used to create noise, weights and simulated data
 */
void canFitSingleGaussianBetter(RandomSeed seed, FunctionSolver solver, boolean applyBounds, FunctionSolver solver2, boolean applyBounds2, String name, String name2, NoiseModel noiseModel) {
    final double[] noise = getNoise(seed, noiseModel);
    if (solver.isWeighted()) {
        solver.setWeights(getWeights(seed, noiseModel));
    }

    final int repeats = 5;
    // Significance level for the two-sided t-test
    final double alpha = 0.05;
    final String msg = "%s vs %s : %.1f (%s) %s %f +/- %f vs %f +/- %f  (N=%d) %b %s";
    final String[] statName = { "Signal", "X", "Y" };

    final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
    // Even slots hold the statistics passed to compare(...) first, odd slots those passed second
    final StoredDataStatistics[] stats = new StoredDataStatistics[6];
    final int[] betterPrecision = new int[3];
    final int[] totalPrecision = new int[3];
    final int[] betterAccuracy = new int[3];
    final int[] totalAccuracy = new int[3];

    int evaluations1 = 0;
    int evaluations2 = 0;

    for (final double s : signal) {
        final double[] expected = createParams(1, s, 0, 0, 1);
        if (applyBounds || applyBounds2) {
            final double[] lower = createParams(0, s * 0.5, -0.3, -0.3, 0.8);
            final double[] upper = createParams(3, s * 2, 0.3, 0.3, 1.2);
            if (applyBounds) {
                solver.setBounds(lower, upper);
            }
            if (applyBounds2) {
                solver2.setBounds(lower, upper);
            }
        }

        for (int repeat = 0; repeat < repeats; repeat++) {
            final double[] data = drawGaussian(expected, noise, noiseModel, rng);

            // Fresh statistics for each simulated data set
            for (int k = 0; k < stats.length; k++) {
                stats[k] = new StoredDataStatistics();
            }

            // Fit from every combination of starting offsets with both solvers
            for (final double db : base) {
                for (final double dx : shift) {
                    for (final double dy : shift) {
                        for (final double dsx : factor) {
                            final double[] start = createParams(db, s, dx, dy, dsx);
                            final double[] fit1 = fitGaussian(solver, data, start, expected);
                            evaluations1 += solver.getEvaluations();
                            final double[] fit2 = fitGaussian(solver2, data, start, expected);
                            evaluations2 += solver2.getEvaluations();
                            // Collect the data for the mean and sd (the fit precision)
                            compare(fit1, expected, fit2, expected, Gaussian2DFunction.SIGNAL, stats[0], stats[1]);
                            compare(fit1, expected, fit2, expected, Gaussian2DFunction.X_POSITION, stats[2], stats[3]);
                            compare(fit1, expected, fit2, expected, Gaussian2DFunction.Y_POSITION, stats[4], stats[5]);
                        }
                    }
                }
            }

            for (int index = 0; index < statName.length; index++) {
                final StoredDataStatistics first = stats[2 * index];
                final StoredDataStatistics second = stats[2 * index + 1];
                final double u1 = first.getMean();
                final double u2 = second.getMean();
                final double sd1 = first.getStandardDeviation();
                final double sd2 = second.getStandardDeviation();
                final TTest tt = new TTest();
                final boolean diff = tt.tTest(first.getValues(), second.getValues(), alpha);
                final Object[] args = new Object[] { name2, name, s, noiseModel, statName[index], u2, sd2, u1, sd1, first.getN(), diff, "" };
                final int labelSlot = args.length - 1;

                // Means count as equivalent when the t-test finds no difference, or when it
                // does but they still agree to within the relative tolerance.
                final boolean sameMeans = !diff || DoubleEquality.almostEqualRelativeOrAbsolute(u1, u2, 0.1, 0);

                if (!sameMeans) {
                    // Check which is more accurate (closer to zero)
                    final boolean secondCloser = Math.abs(u2) < Math.abs(u1);
                    if (secondCloser) {
                        betterAccuracy[index]++;
                    }
                    args[labelSlot] = secondCloser ? "A*" : "A";
                    logger.log(TestLogUtils.getRecord(Level.FINE, msg, args));
                    totalAccuracy[index]++;
                } else if (!DoubleEquality.almostEqualRelativeOrAbsolute(sd1, sd2, 0.05, 0)) {
                    // Check which is more precise (smaller spread)
                    final boolean secondTighter = sd2 < sd1;
                    if (secondTighter) {
                        betterPrecision[index]++;
                    }
                    args[labelSlot] = secondTighter ? "P*" : "P";
                    logger.log(TestLogUtils.getRecord(Level.FINE, msg, args));
                    totalPrecision[index]++;
                }
            }
        }
    }

    int better = 0;
    int total = 0;
    for (int index = 0; index < statName.length; index++) {
        better += betterPrecision[index] + betterAccuracy[index];
        total += totalPrecision[index] + totalAccuracy[index];
        test(name2, name, statName[index] + " P", betterPrecision[index], totalPrecision[index], Level.FINE);
        test(name2, name, statName[index] + " A", betterAccuracy[index], totalAccuracy[index], Level.FINE);
    }
    test(name2, name, String.format("All (eval [%d] [%d]) : ", evaluations1, evaluations2), better, total, Level.INFO);
}
Also used : TTest(org.apache.commons.math3.stat.inference.TTest) StoredDataStatistics(uk.ac.sussex.gdsc.core.utils.StoredDataStatistics) UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider)

Aggregations

TTest (org.apache.commons.math3.stat.inference.TTest)7 SummaryStatistics (org.apache.commons.math3.stat.descriptive.SummaryStatistics)2 DataByteArray (org.apache.pig.data.DataByteArray)2 Tuple (org.apache.pig.data.Tuple)2 ArrayOfDoublesSketch (com.yahoo.sketches.tuple.ArrayOfDoublesSketch)1 ArrayOfDoublesUpdatableSketch (com.yahoo.sketches.tuple.ArrayOfDoublesUpdatableSketch)1 ArrayOfDoublesUpdatableSketchBuilder (com.yahoo.sketches.tuple.ArrayOfDoublesUpdatableSketchBuilder)1 edu.rice.cs.caper.bayou.core.dsl (edu.rice.cs.caper.bayou.core.dsl)1 EqualityASTMetric (edu.rice.cs.caper.bayou.core.sketch_metric.EqualityASTMetric)1 Metric (edu.rice.cs.caper.bayou.core.sketch_metric.Metric)1 StoredDataStatistics (gdsc.core.utils.StoredDataStatistics)1 IOException (java.io.IOException)1 ArrayList (java.util.ArrayList)1 List (java.util.List)1 Random (java.util.Random)1 Collectors (java.util.stream.Collectors)1 org.apache.commons.cli (org.apache.commons.cli)1 UniformRandomProvider (org.apache.commons.rng.UniformRandomProvider)1 ArrayOfDoublesSketch (org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesSketch)1 IAE (org.apache.druid.java.util.common.IAE)1