Use of edu.rice.cs.caper.bayou.core.sketch_metric.EqualityASTMetric in project bayou by capergroup.
The class TrialsRunner, method runTrails.
/**
 * Execute the given trials, synthesizing result programs from the given Synthesizer, and report progress
 * to the given view.
 *
 * @param trials the trials to run
 * @param synthesizer the synthesizer to use for mapping draft programs onto results
 * @param view the view used to track execution progress
 */
static void runTrails(List<Trial> trials, Synthesizer synthesizer, View view) {
    if (trials == null)
        throw new NullPointerException("trials");
    if (synthesizer == null)
        throw new NullPointerException("synthesizer");
    if (view == null)
        throw new NullPointerException("view");

    // unique number for the current trial.
    int currentTrialId = 0;

    int possibleCompilePointsAccum = 0;
    int obtainedCompilePointsAccum = 0;
    int possibleTestCasePointsAccum = 0;
    int obtainedTestCasePointsAccum = 0;
    int possibleSketchMatchPointsAccum = 0;
    int obtainedSketchMatchPointsAccum = 0;

    // track any thrown SynthesizeException from synthesizer.synthesize(...)
    boolean anySynthesizeFailed = false;

    for (Trial trial : trials) {
        currentTrialId++;

        String draftProgram = trial.getDraftProgramSource();
        view.declareStartOfTrial(trial.getDescription(), draftProgram);

        /*
         * Provide the draft program to the synthesizer and collect the synthesis results.
         */
        List<String> synthResults;
        try {
            synthResults = synthesizer.synthesize(draftProgram);
        } catch (SynthesizeException e) {
            view.declareSynthesisFailed();
            view.declareTrialResultSynthesisFailed();
            anySynthesizeFailed = true;
            continue;
        }
        view.declareNumberOfResults(synthResults.size());

        /*
         * Rewrite each result class source to have a unique class name.
         *
         * We will compile and class-load each result, so each class needs to have a unique class name.
         */
        List<SourceClass> synthResultsWithUniqueClassNames = new LinkedList<>();
        {
            // floodgage assumes that the name of the synthesized class remains unchanged from the input class name.
            String assumedResultClassName = trial.getDraftProgramClassName();
            int resultId = 0;
            for (String result : synthResults) {
                resultId++;
                String uniqueResultClassName = assumedResultClassName + "_" + currentTrialId + "_" + resultId;
                String classSource = result.replaceFirst("public class " + assumedResultClassName,
                                                         "public class " + uniqueResultClassName);
                synthResultsWithUniqueClassNames.add(new SourceClass(uniqueResultClassName, classSource));
            }
        }

        /*
         * If an expected sketch source was added to this trial, construct the sketch from the source.
         * Also increment possibleSketchMatchPointsAccum if there is an expected sketch.
         *
         * If no expected sketch is defined, set expectedSketch to null.
         */
        Boolean anySketchMatch; // true: sketch expected and a match found in some result.
                                // false: sketch expected but no match found so far among any result.
                                // null: no expected sketch.
        DSubTree expectedSketch;
        {
            if (trial.containsSketchProgramSource()) {
                String expectedSketchSource = trial.tryGetSketchProgramSource(null);
                expectedSketch = makeSketchFromSource(expectedSketchSource, trial.getDraftProgramClassName() + ".java");
                anySketchMatch = false; // we will test for sketch matches, so init to none seen yet
            } else {
                expectedSketch = null;
                anySketchMatch = null; // we won't do sketch testing, so signal no expected sketch
            }
        }

        if (trial.containsSketchProgramSource())
            possibleSketchMatchPointsAccum++; // possible point for some result matching the sketch.

        /*
         * For each result:
         *
         * 1.) Check that the result compiles.
         * 2.) If it compiles, check that the result passes the test suite.
         * 3.) If it compiles and the trial contains an expected sketch, check for sketch equivalency.
         */
        int resultId = 0;
        for (SourceClass result : synthResultsWithUniqueClassNames) {
            resultId++;
            possibleCompilePointsAccum++; // possible point for the result compiling.
            view.declareSynthesizeResult(resultId, result.classSource);

            /*
             * If a sketch expectation is declared for this trial, perform sketch comparison.
             */
            boolean sketchMatch = false;
            if (anySketchMatch != null) { // if there is an expected sketch
                DSubTree resultSketch = makeSketchFromSource(result.classSource, result.className + ".java");
                if (expectedSketch == null && resultSketch == null) { // both null
                    view.warnSketchesNull();
                    sketchMatch = true;
                } else if (expectedSketch != null && resultSketch != null) { // both non-null
                    EqualityASTMetric m = new EqualityASTMetric();
                    float equalityResult = m.compute(expectedSketch, Collections.singletonList(resultSketch), "");
                    view.declareSketchMetricResult(equalityResult);
                    if (equalityResult == 1)
                        sketchMatch = true;
                } else { // one null
                    view.warnSketchesNullMismatch();
                }
                anySketchMatch = anySketchMatch || sketchMatch;
            }

            /*
             * Check that the result compiles.
             */
            boolean resultCompiled;
            try {
                // ensure any compiler errors start on a new line
                System.out.println("");
                CompilerUtils.CACHED_COMPILER.loadFromJava(result.className, result.classSource);
                resultCompiled = true;
                obtainedCompilePointsAccum++; // for compiling
            } catch (ClassNotFoundException e) {
                resultCompiled = false;
            }

            /*
             * If the result compiles, run test cases.
             *
             * (Construct a test suite for the result and run the result against the suite.)
             */
            boolean testCasesPass;
            {
                if (resultCompiled) {
                    view.declareStartOfTestCases();
                    Class resultSpecificTestSuite = makeResultSpecificTestSuite(result.className);
                    TestSuiteRunner.RunResult runResult = TestSuiteRunner.runTestSuiteAgainst(resultSpecificTestSuite, view);
                    possibleTestCasePointsAccum += runResult.TestCaseCount;
                    obtainedTestCasePointsAccum += runResult.TestCaseCount - runResult.FailCount;
                    testCasesPass = runResult.FailCount == 0;
                } else {
                    // even if there are no test cases, we say uncompilable code fails the null t.c.
                    testCasesPass = false;

                    /*
                     * Count the number of test cases in TestSuite (since we didn't run JUnit) and add each
                     * to the possibleTestCasePointsAccum count.
                     */
                    Class testSuiteClass;
                    try {
                        testSuiteClass = Class.forName("TestSuite");
                    } catch (ClassNotFoundException e) {
                        // TestSuite should be in the trialpack
                        throw new RuntimeException(e);
                    }
                    for (Method m : testSuiteClass.getMethods()) {
                        if (m.getAnnotation(org.junit.Test.class) != null)
                            possibleTestCasePointsAccum++;
                    }
                }
            }

            /*
             * Report results for the result.
             */
            Boolean reportSketchMatch = trial.containsSketchProgramSource() ? sketchMatch : null;
            view.declareSynthResultResult(resultCompiled, testCasesPass, reportSketchMatch);
        }

        /*
         * If any sketch matched among the results, award a point.
         */
        if (anySketchMatch != null && anySketchMatch)
            obtainedSketchMatchPointsAccum++;
    }

    /*
     * Report total tallies.
     */
    view.declarePointScore(possibleCompilePointsAccum, obtainedCompilePointsAccum,
                           possibleTestCasePointsAccum, obtainedTestCasePointsAccum,
                           possibleSketchMatchPointsAccum, obtainedSketchMatchPointsAccum);

    if (anySynthesizeFailed)
        view.declareSynthesisFailed(); // this can really influence the score by skipping points, so remind.
}
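For reference, the sketch comparison above reduces to a single EqualityASTMetric.compute call per candidate. Below is a minimal sketch of that check in isolation, assuming DSubTree is the same bayou sketch type produced by TrialsRunner's makeSketchFromSource helper; the sketchesMatch method name is illustrative only, not part of the project.

import java.util.Collections;

import edu.rice.cs.caper.bayou.core.sketch_metric.EqualityASTMetric;

// Compare an expected sketch against a single candidate sketch.
// TrialsRunner passes an empty aggregation string and treats a returned score of 1 as an exact match.
static boolean sketchesMatch(DSubTree expected, DSubTree candidate) {
    EqualityASTMetric metric = new EqualityASTMetric();
    float score = metric.compute(expected, Collections.singletonList(candidate), "");
    return score == 1;
}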
Use of edu.rice.cs.caper.bayou.core.sketch_metric.EqualityASTMetric in project bayou by capergroup.
The class MetricCalculator, method execute.
public void execute() throws IOException {
    if (cmdLine == null)
        return;

    int topk = cmdLine.hasOption("t") ? Integer.parseInt(cmdLine.getOptionValue("t")) : 10;

    Metric metric;
    String m = cmdLine.getOptionValue("m");
    switch (m) {
        case "equality-ast":
            metric = new EqualityASTMetric();
            break;
        case "jaccard-sequences":
            metric = new JaccardSequencesMetric();
            break;
        case "jaccard-api-calls":
            metric = new JaccardAPICallsMetric();
            break;
        case "num-control-structures":
            metric = new NumControlStructuresMetric();
            break;
        case "num-statements":
            metric = new NumStatementsMetric();
            break;
        default:
            System.err.println("invalid metric: " + cmdLine.getOptionValue("m"));
            return;
    }

    int inCorpus = cmdLine.hasOption("c") ? Integer.parseInt(cmdLine.getOptionValue("c")) : 1;
    String aggregate = cmdLine.hasOption("a") ? cmdLine.getOptionValue("a") : "min";

    List<JSONInputFormat.DataPoint> data = JSONInputFormat.readData(cmdLine.getOptionValue("f"));
    if (inCorpus == 2)
        data = data.stream().filter(datapoint -> datapoint.in_corpus).collect(Collectors.toList());
    else if (inCorpus == 3)
        data = data.stream().filter(datapoint -> !datapoint.in_corpus).collect(Collectors.toList());

    List<Float> values = new ArrayList<>();
    for (JSONInputFormat.DataPoint datapoint : data) {
        DSubTree originalAST = datapoint.ast;
        List<DSubTree> predictedASTs = datapoint.out_asts.subList(0, Math.min(topk, datapoint.out_asts.size()));
        values.add(metric.compute(originalAST, predictedASTs, aggregate));
    }

    List<Float> values2 = new ArrayList<>();
    if (cmdLine.hasOption("p")) {
        List<JSONInputFormat.DataPoint> data2 = JSONInputFormat.readData(cmdLine.getOptionValue("p"));
        if (inCorpus == 2)
            data2 = data2.stream().filter(datapoint -> datapoint.in_corpus).collect(Collectors.toList());
        else if (inCorpus == 3)
            data2 = data2.stream().filter(datapoint -> !datapoint.in_corpus).collect(Collectors.toList());
        for (JSONInputFormat.DataPoint datapoint : data2) {
            DSubTree originalAST = datapoint.ast;
            List<DSubTree> predictedASTs = datapoint.out_asts.subList(0, Math.min(topk, datapoint.out_asts.size()));
            values2.add(metric.compute(originalAST, predictedASTs, aggregate));
        }
        if (values.size() != values2.size())
            throw new Error("DATA files do not match in size. Cannot compute p-value.");
    }

    float average = Metric.mean(values);
    float stdv = Metric.standardDeviation(values);

    if (cmdLine.hasOption("p")) {
        double[] dValues = values.stream().mapToDouble(v -> v.floatValue()).toArray();
        double[] dValues2 = values2.stream().mapToDouble(v -> v.floatValue()).toArray();
        double pValue = new TTest().pairedTTest(dValues, dValues2);
        System.out.println(String.format("%s (%d data points, each aggregated with %s): average=%f, stdv=%f, pvalue=%e",
                m, data.size(), aggregate, average, stdv, pValue));
    } else {
        System.out.println(String.format("%s (%d data points, each aggregated with %s): average=%f, stdv=%f",
                m, data.size(), aggregate, average, stdv));
    }
}
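Stripped of the command-line plumbing (-f data file, -m metric name, -t top-k, -a aggregate, -c corpus filter, -p paired comparison file), the core of execute() is one compute call per data point followed by a mean and standard deviation. Here is a minimal sketch of that loop for the equality-ast metric, reusing the JSONInputFormat, DataPoint, Metric, and DSubTree types from MetricCalculator above; the reportEqualityAst method name and its arguments are illustrative only.

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import edu.rice.cs.caper.bayou.core.sketch_metric.EqualityASTMetric;

// Score the top-k predicted ASTs of each data point against its reference AST
// and print the mean and standard deviation of the per-point scores.
static void reportEqualityAst(String dataFile, int topk) throws IOException {
    Metric metric = new EqualityASTMetric();
    List<JSONInputFormat.DataPoint> data = JSONInputFormat.readData(dataFile);

    List<Float> values = new ArrayList<>();
    for (JSONInputFormat.DataPoint datapoint : data) {
        DSubTree originalAST = datapoint.ast;
        List<DSubTree> predictedASTs = datapoint.out_asts.subList(0, Math.min(topk, datapoint.out_asts.size()));
        values.add(metric.compute(originalAST, predictedASTs, "min")); // "min" is the default aggregate above
    }

    System.out.println(String.format("equality-ast: average=%f, stdv=%f",
            Metric.mean(values), Metric.standardDeviation(values)));
}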