Use of org.apache.commons.math3.stat.inference.TTest in project bayou by capergroup.
Class MetricCalculator, method execute.
/**
 * Computes the selected metric over every data point of a DATA file and prints the mean and
 * standard deviation of the per-data-point values.
 *
 * <p>Options read from {@code cmdLine}:
 * <ul>
 *   <li>{@code t} - number of leading predicted ASTs considered per data point (default 10)</li>
 *   <li>{@code m} - metric name; a missing or unknown name is reported on stderr and the
 *       method returns without reading any file</li>
 *   <li>{@code c} - corpus filter: 2 keeps only data points flagged {@code in_corpus}, 3 keeps
 *       only those not flagged, any other value keeps everything (default 1)</li>
 *   <li>{@code a} - aggregate name handed to {@code Metric.compute} (default "min")</li>
 *   <li>{@code f} - the DATA file to evaluate</li>
 *   <li>{@code p} - optional second DATA file; when given, a paired t-test p-value between
 *       the two files' metric values is printed as well</li>
 * </ul>
 *
 * @throws IOException declared for failures while reading a DATA file
 * @throws Error if {@code p} is given and the two files yield a different number of values
 */
public void execute() throws IOException {
    if (cmdLine == null) {
        return;
    }
    int topk = cmdLine.hasOption("t") ? Integer.parseInt(cmdLine.getOptionValue("t")) : 10;

    String m = cmdLine.getOptionValue("m");
    Metric metric = metricForName(m);
    if (metric == null) {
        // Covers both an unknown name and a missing option (m == null); switching on a
        // null String directly would throw a NullPointerException.
        System.err.println("invalid metric: " + m);
        return;
    }

    int inCorpus = cmdLine.hasOption("c") ? Integer.parseInt(cmdLine.getOptionValue("c")) : 1;
    String aggregate = cmdLine.hasOption("a") ? cmdLine.getOptionValue("a") : "min";
    boolean paired = cmdLine.hasOption("p");

    List<Float> values =
            computeMetricValues(cmdLine.getOptionValue("f"), inCorpus, topk, metric, aggregate);
    List<Float> values2 = new ArrayList<>();
    if (paired) {
        values2 = computeMetricValues(cmdLine.getOptionValue("p"), inCorpus, topk, metric, aggregate);
        // The paired t-test needs one value from each file per data point.
        if (values.size() != values2.size()) {
            throw new Error("DATA files do not match in size. Cannot compute p-value.");
        }
    }

    float average = Metric.mean(values);
    float stdv = Metric.standardDeviation(values);
    // One value is produced per retained data point, so values.size() is the data point count.
    int numPoints = values.size();

    if (paired) {
        double[] dValues = values.stream().mapToDouble(Float::doubleValue).toArray();
        double[] dValues2 = values2.stream().mapToDouble(Float::doubleValue).toArray();
        double pValue = new TTest().pairedTTest(dValues, dValues2);
        System.out.println(String.format("%s (%d data points, each aggregated with %s): average=%f, stdv=%f, pvalue=%e", m, numPoints, aggregate, average, stdv, pValue));
    } else {
        System.out.println(String.format("%s (%d data points, each aggregated with %s): average=%f, stdv=%f", m, numPoints, aggregate, average, stdv));
    }
}

/** Maps a metric name from the command line to its implementation; null if absent or unknown. */
private Metric metricForName(String name) {
    if (name == null) {
        return null;
    }
    switch (name) {
        case "equality-ast":
            return new EqualityASTMetric();
        case "jaccard-sequences":
            return new JaccardSequencesMetric();
        case "jaccard-api-calls":
            return new JaccardAPICallsMetric();
        case "num-control-structures":
            return new NumControlStructuresMetric();
        case "num-statements":
            return new NumStatementsMetric();
        default:
            return null;
    }
}

/** Reads a DATA file, applies the corpus filter, and computes the metric for each retained data point. */
private List<Float> computeMetricValues(String dataFile, int inCorpus, int topk, Metric metric,
        String aggregate) throws IOException {
    List<JSONInputFormat.DataPoint> data = JSONInputFormat.readData(dataFile);
    if (inCorpus == 2) {
        data = data.stream().filter(datapoint -> datapoint.in_corpus).collect(Collectors.toList());
    } else if (inCorpus == 3) {
        data = data.stream().filter(datapoint -> !datapoint.in_corpus).collect(Collectors.toList());
    }

    List<Float> values = new ArrayList<>();
    for (JSONInputFormat.DataPoint datapoint : data) {
        DSubTree originalAST = datapoint.ast;
        // Only the first topk predictions (or fewer, if not that many exist) are scored.
        List<DSubTree> predictedASTs =
                datapoint.out_asts.subList(0, Math.min(topk, datapoint.out_asts.size()));
        values.add(metric.compute(originalAST, predictedASTs, aggregate));
    }
    return values;
}
Use of org.apache.commons.math3.stat.inference.TTest in project GDSC-SMLM by aherbert.
Class BaseFunctionSolverTest, method canFitSingleGaussianBetter.
/**
 * Compares how well two solvers fit simulated single-Gaussian data and reports the tallies
 * through {@code test}.
 *
 * <p>For every value in {@code signal} the data is drawn several times. Each draw is fitted by
 * both solvers from a grid of starting parameters built from {@code base}, {@code shift} and
 * {@code factor}. The values collected by {@code compare} for Signal, X and Y are then compared
 * with a two-sample t-test at alpha = 0.05. When the test reports a difference and the means
 * are not within 10% of each other, the result with the smaller absolute mean is counted as
 * more accurate; otherwise, when the standard deviations differ by more than 5%, the result
 * with the smaller standard deviation is counted as more precise. A "better" count is
 * incremented when the second store of a pair wins.
 *
 * @param seed the seed used for the noise, the weights and the random generator
 * @param solver the first solver
 * @param applyBounds whether to set bounds on the first solver
 * @param solver2 the second solver
 * @param applyBounds2 whether to set bounds on the second solver
 * @param name the label for the first solver in log and test output
 * @param name2 the label for the second solver in log and test output
 * @param noiseModel the noise model used to generate noise, weights and data
 */
void canFitSingleGaussianBetter(RandomSeed seed, FunctionSolver solver, boolean applyBounds,
    FunctionSolver solver2, boolean applyBounds2, String name, String name2,
    NoiseModel noiseModel) {
  final double[] noise = getNoise(seed, noiseModel);
  if (solver.isWeighted()) {
    // Weights are applied to the first solver only; solver2 is not touched here.
    solver.setWeights(getWeights(seed, noiseModel));
  }

  final int repeats = 5;
  final UniformRandomProvider rng = RngUtils.create(seed.getSeed());

  // Two stores per parameter, laid out as consecutive pairs in the order they are passed to
  // compare(...); presumably the first of each pair belongs to solver and the second to solver2.
  final StoredDataStatistics[] results = new StoredDataStatistics[6];
  final String[] paramNames = {"Signal", "X", "Y"};

  final int[] precisionWins = new int[3];
  final int[] precisionTrials = new int[3];
  final int[] accuracyWins = new int[3];
  final int[] accuracyTrials = new int[3];

  final String format = "%s vs %s : %.1f (%s) %s %f +/- %f vs %f +/- %f (N=%d) %b %s";

  int evaluations1 = 0;
  int evaluations2 = 0;

  for (final double s : signal) {
    final double[] expected = createParams(1, s, 0, 0, 1);

    // Bounds are only built when at least one solver needs them.
    final boolean needBounds = applyBounds || applyBounds2;
    final double[] lower = needBounds ? createParams(0, s * 0.5, -0.3, -0.3, 0.8) : null;
    final double[] upper = needBounds ? createParams(3, s * 2, 0.3, 0.3, 1.2) : null;
    if (applyBounds) {
      solver.setBounds(lower, upper);
    }
    if (applyBounds2) {
      solver2.setBounds(lower, upper);
    }

    for (int repeat = 0; repeat < repeats; repeat++) {
      final double[] data = drawGaussian(expected, noise, noiseModel, rng);

      // Fresh stores for every simulated data set.
      for (int k = 0; k < results.length; k++) {
        results[k] = new StoredDataStatistics();
      }

      // Fit both solvers from every starting point in the grid.
      for (final double db : base) {
        for (final double dx : shift) {
          for (final double dy : shift) {
            for (final double dsx : factor) {
              final double[] start = createParams(db, s, dx, dy, dsx);

              final double[] fit1 = fitGaussian(solver, data, start, expected);
              evaluations1 += solver.getEvaluations();

              final double[] fit2 = fitGaussian(solver2, data, start, expected);
              evaluations2 += solver2.getEvaluations();

              compare(fit1, expected, fit2, expected, Gaussian2DFunction.SIGNAL, results[0],
                  results[1]);
              compare(fit1, expected, fit2, expected, Gaussian2DFunction.X_POSITION, results[2],
                  results[3]);
              compare(fit1, expected, fit2, expected, Gaussian2DFunction.Y_POSITION, results[4],
                  results[5]);
            }
          }
        }
      }

      // Two-sided test on each pair of stores.
      final double alpha = 0.05;
      for (int index = 0; index < results.length / 2; index++) {
        final StoredDataStatistics first = results[2 * index];
        final StoredDataStatistics second = results[2 * index + 1];

        final double mean1 = first.getMean();
        final double mean2 = second.getMean();
        final double sd1 = first.getStandardDeviation();
        final double sd2 = second.getStandardDeviation();

        final boolean diff = new TTest().tTest(first.getValues(), second.getValues(), alpha);

        // The last slot holds the outcome tag and is filled in just before logging.
        final Object[] args = {name2, name, s, noiseModel, paramNames[index], mean2, sd2, mean1,
            sd1, first.getN(), diff, ""};
        final int tagSlot = args.length - 1;

        if (diff && !DoubleEquality.almostEqualRelativeOrAbsolute(mean1, mean2, 0.1, 0)) {
          // Clearly different means: the one closer to zero is the more accurate.
          final boolean secondWins = Math.abs(mean2) < Math.abs(mean1);
          if (secondWins) {
            accuracyWins[index]++;
          }
          args[tagSlot] = secondWins ? "A*" : "A";
          logger.log(TestLogUtils.getRecord(Level.FINE, format, args));
          accuracyTrials[index]++;
        } else if (!DoubleEquality.almostEqualRelativeOrAbsolute(sd1, sd2, 0.05, 0)) {
          // Means are the same (or roughly so): the smaller spread is the more precise.
          final boolean secondWins = sd2 < sd1;
          if (secondWins) {
            precisionWins[index]++;
          }
          args[tagSlot] = secondWins ? "P*" : "P";
          logger.log(TestLogUtils.getRecord(Level.FINE, format, args));
          precisionTrials[index]++;
        }
      }
    }
  }

  // Report per-parameter tallies, then the overall tally with the evaluation counts.
  int wins = 0;
  int trials = 0;
  for (int index = 0; index < paramNames.length; index++) {
    wins += precisionWins[index] + accuracyWins[index];
    trials += precisionTrials[index] + accuracyTrials[index];
    test(name2, name, paramNames[index] + " P", precisionWins[index], precisionTrials[index],
        Level.FINE);
    test(name2, name, paramNames[index] + " A", accuracyWins[index], accuracyTrials[index],
        Level.FINE);
  }
  test(name2, name, String.format("All (eval [%d] [%d]) : ", evaluations1, evaluations2), wins,
      trials, Level.INFO);
}
Aggregations