Search in sources :

Example 1 with EqualityASTMetric

use of edu.rice.cs.caper.bayou.core.sketch_metric.EqualityASTMetric in project bayou by capergroup.

the class TrialsRunner method runTrails.

/**
 * Runs every trial through the given synthesizer, scores each synthesized result and reports progress
 * and the final point tallies to the given view.
 *
 * <p>For each trial the draft program is handed to the synthesizer. Each result is rewritten to carry a
 * unique class name and is then, in this order, compared against the trial's expected sketch (when the
 * trial declares one), compiled, and (when it compiled) run against a result specific test suite.
 * Trials whose synthesis throws a {@code SynthesizeException} are skipped and contribute no points; a
 * reminder is declared to the view after the point score in that case.
 *
 * @param trials the trials to run
 * @param synthesizer the synthesizer to use for mapping draft programs onto results
 * @param view the view used to track execution progress
 * @throws NullPointerException if any argument is null
 */
static void runTrails(List<Trial> trials, Synthesizer synthesizer, View view) {
    if (trials == null)
        throw new NullPointerException("trials");
    if (synthesizer == null)
        throw new NullPointerException("synthesizer");
    if (view == null)
        throw new NullPointerException("view");
    // unique number for the current trial.
    int currentTrialId = 0;
    int possibleCompilePointsAccum = 0;
    int obtainedCompilePointsAccum = 0;
    int possibleTestCasePointsAccum = 0;
    int obtainedTestCasePointsAccum = 0;
    int possibleSketchMatchPointsAccum = 0;
    int obtainedSketchMatchPointsAccum = 0;
    // track any thrown SynthesizeException from synthesizer.synthesize(...)
    boolean anySynthesizeFailed = false;
    for (Trial trial : trials) {
        currentTrialId++;
        String draftProgram = trial.getDraftProgramSource();
        view.declareStartOfTrial(trial.getDescription(), draftProgram);
        /*
         * Provide the draft program to the synthesizer and collect the synthesis results.
         */
        List<String> synthResults;
        try {
            synthResults = synthesizer.synthesize(draftProgram);
        } catch (SynthesizeException e) {
            view.declareSynthesisFailed();
            view.declareTrialResultSynthesisFailed();
            anySynthesizeFailed = true;
            continue;
        }
        view.declareNumberOfResults(synthResults.size());
        /*
         * Rewrite each result class source to have a unique class name.
         *
         * We will compile and class-load each result so each class needs to have a unique class name.
         */
        List<SourceClass> synthResultsWithUniqueClassNames = new LinkedList<>();
        {
            // floodgage assumes that the name of the synthesized class remains unchanged from the input class name.
            String assumedResultClassName = trial.getDraftProgramClassName();
            // replaceFirst interprets its first argument as a regex and '$' / '\' in the replacement as
            // group references, so both sides are quoted to treat the class names literally.
            String declarationPattern = java.util.regex.Pattern.quote("public class " + assumedResultClassName);
            int renamedResultId = 0;
            for (String result : synthResults) {
                renamedResultId++;
                String uniqueResultClassName = assumedResultClassName + "_" + currentTrialId + "_" + renamedResultId;
                String uniqueDeclaration = java.util.regex.Matcher.quoteReplacement("public class " + uniqueResultClassName);
                String classSource = result.replaceFirst(declarationPattern, uniqueDeclaration);
                synthResultsWithUniqueClassNames.add(new SourceClass(uniqueResultClassName, classSource));
            }
        }
        /*
         * If an expected sketch source was added to this trial, construct the sketch from the source
         * and count one possible point for some result matching that sketch.
         *
         * If no expected sketch is defined, expectedSketch stays null and no sketch testing is done.
         */
        final boolean sketchExpected = trial.containsSketchProgramSource();
        DSubTree expectedSketch = null;
        if (sketchExpected) {
            String expectedSketchSource = trial.tryGetSketchProgramSource(null);
            expectedSketch = makeSketchFromSource(expectedSketchSource, trial.getDraftProgramClassName() + ".java");
            // possible point for some result matching the sketch.
            possibleSketchMatchPointsAccum++;
        }
        // becomes true once any result of this trial matches the expected sketch.
        boolean anySketchMatch = false;
        /*
         * For each result:
         *
         *     1.) If the trial contains an expected sketch, check for sketch equivalency.
         *     2.) Check that the result compiles.
         *     3.) If compiles, check that the result passes the test suite.
         */
        int resultId = 0;
        for (SourceClass result : synthResultsWithUniqueClassNames) {
            resultId++;
            // possible point for result compiling.
            possibleCompilePointsAccum++;
            view.declareSynthesizeResult(resultId, result.classSource);
            /*
             * If a sketch expectation is declared for this trial, perform sketch comparison.
             */
            boolean sketchMatch = false;
            if (sketchExpected) {
                DSubTree resultSketch = makeSketchFromSource(result.classSource, result.className + ".java");
                if (expectedSketch == null && resultSketch == null) {
                    // both null: warn, but treat as a match
                    view.warnSketchesNull();
                    sketchMatch = true;
                } else if (expectedSketch != null && resultSketch != null) {
                    // both non-null: a metric value of exactly 1 counts as a match
                    EqualityASTMetric m = new EqualityASTMetric();
                    float equalityResult = m.compute(expectedSketch, Collections.singletonList(resultSketch), "");
                    view.declareSketchMetricResult(equalityResult);
                    if (equalityResult == 1)
                        sketchMatch = true;
                } else {
                    // exactly one is null
                    view.warnSketchesNullMismatch();
                }
                anySketchMatch = anySketchMatch || sketchMatch;
            }
            /*
             * Check that the result compiles.
             */
            boolean resultCompiled;
            try {
                // ensure any compiler errors start on new line
                System.out.println("");
                CompilerUtils.CACHED_COMPILER.loadFromJava(result.className, result.classSource);
                resultCompiled = true;
                // point for compiling
                obtainedCompilePointsAccum++;
            } catch (ClassNotFoundException e) {
                // loading failed, so the result is scored as not compiling
                resultCompiled = false;
            }
            /*
             * If result compiles, run test cases.
             *
             * (Construct a test suite for result and run the result against the suite.)
             */
            boolean testCasesPass;
            if (resultCompiled) {
                view.declareStartOfTestCases();
                Class resultSpecificTestSuite = makeResultSpecificTestSuite(result.className);
                TestSuiteRunner.RunResult runResult = TestSuiteRunner.runTestSuiteAgainst(resultSpecificTestSuite, view);
                possibleTestCasePointsAccum += runResult.TestCaseCount;
                obtainedTestCasePointsAccum += runResult.TestCaseCount - runResult.FailCount;
                testCasesPass = runResult.FailCount == 0;
            } else {
                // even if there are no test cases, uncompilable code is reported as failing them.
                testCasesPass = false;
                /*
                 * Count the number of test cases in TestSuite (since we didn't run JUnit) and add each
                 * to the possible test case points.
                 */
                Class<?> testSuiteClass;
                try {
                    testSuiteClass = Class.forName("TestSuite");
                } catch (ClassNotFoundException e) {
                    // TestSuite should be in the trialpack
                    throw new RuntimeException(e);
                }
                for (Method m : testSuiteClass.getMethods()) {
                    if (m.getAnnotation(org.junit.Test.class) != null)
                        possibleTestCasePointsAccum++;
                }
            }
            /*
             * Report results for the result. A null sketch match signals that no sketch was expected.
             */
            Boolean reportSketchMatch = sketchExpected ? Boolean.valueOf(sketchMatch) : null;
            view.declareSynthResultResult(resultCompiled, testCasesPass, reportSketchMatch);
        }
        /*
         * If any sketch matched among the results, award a point.
         */
        if (anySketchMatch)
            obtainedSketchMatchPointsAccum++;
    }
    /*
     * Report total tallies.
     */
    view.declarePointScore(possibleCompilePointsAccum, obtainedCompilePointsAccum, possibleTestCasePointsAccum, obtainedTestCasePointsAccum, possibleSketchMatchPointsAccum, obtainedSketchMatchPointsAccum);
    if (anySynthesizeFailed)
        // this can really influence the score by skipping points, so remind.
        view.declareSynthesisFailed();
}
Also used : Trial(edu.rice.cs.caper.floodgage.application.floodgage.model.plan.Trial) Method(java.lang.reflect.Method) EqualityASTMetric(edu.rice.cs.caper.bayou.core.sketch_metric.EqualityASTMetric) SynthesizeException(edu.rice.cs.caper.floodgage.application.floodgage.synthesizer.SynthesizeException)

Example 2 with EqualityASTMetric

use of edu.rice.cs.caper.bayou.core.sketch_metric.EqualityASTMetric in project bayou by capergroup.

the class MetricCalculator method execute.

/**
 * Computes the metric selected by command line option "m" over the data file given by option "f" and
 * prints its mean and standard deviation to standard output.
 *
 * <p>Options read: "t" (number of leading predicted ASTs considered per data point, default 10),
 * "c" (corpus filter: 2 keeps only in-corpus data points, 3 keeps only out-of-corpus data points,
 * any other value keeps all; default 1), "a" (aggregate name passed to the metric, default "min") and
 * "p" (optional second data file; when present a paired t-test p-value against it is also printed).
 *
 * <p>Does nothing when no command line is available. A missing or unknown metric name is reported on
 * standard error and the method returns without computing anything.
 *
 * @throws IOException declared for reading the data files
 * @throws Error if option "p" is given and the two data files yield a different number of values
 */
public void execute() throws IOException {
    if (cmdLine == null)
        return;
    int topk = cmdLine.hasOption("t") ? Integer.parseInt(cmdLine.getOptionValue("t")) : 10;
    String m = cmdLine.getOptionValue("m");
    if (m == null) {
        // switching on a null String would throw NullPointerException; report it like any other bad name.
        System.err.println("invalid metric: " + m);
        return;
    }
    Metric metric;
    switch(m) {
        case "equality-ast":
            metric = new EqualityASTMetric();
            break;
        case "jaccard-sequences":
            metric = new JaccardSequencesMetric();
            break;
        case "jaccard-api-calls":
            metric = new JaccardAPICallsMetric();
            break;
        case "num-control-structures":
            metric = new NumControlStructuresMetric();
            break;
        case "num-statements":
            metric = new NumStatementsMetric();
            break;
        default:
            System.err.println("invalid metric: " + m);
            return;
    }
    int inCorpus = cmdLine.hasOption("c") ? Integer.parseInt(cmdLine.getOptionValue("c")) : 1;
    String aggregate = cmdLine.hasOption("a") ? cmdLine.getOptionValue("a") : "min";
    List<Float> values = computeMetricValues(cmdLine.getOptionValue("f"), inCorpus, topk, metric, aggregate);
    boolean paired = cmdLine.hasOption("p");
    List<Float> values2 = new ArrayList<>();
    if (paired) {
        values2 = computeMetricValues(cmdLine.getOptionValue("p"), inCorpus, topk, metric, aggregate);
        if (values.size() != values2.size())
            throw new Error("DATA files do not match in size. Cannot compute p-value.");
    }
    float average = Metric.mean(values);
    float stdv = Metric.standardDeviation(values);
    // one value is produced per data point, so values.size() is the number of data points used.
    if (paired) {
        double[] dValues = values.stream().mapToDouble(Float::doubleValue).toArray();
        double[] dValues2 = values2.stream().mapToDouble(Float::doubleValue).toArray();
        double pValue = new TTest().pairedTTest(dValues, dValues2);
        System.out.println(String.format("%s (%d data points, each aggregated with %s): average=%f, stdv=%f, pvalue=%e", m, values.size(), aggregate, average, stdv, pValue));
    } else {
        System.out.println(String.format("%s (%d data points, each aggregated with %s): average=%f, stdv=%f", m, values.size(), aggregate, average, stdv));
    }
}

/**
 * Reads the data points from the given file, applies the corpus filter and computes the metric for
 * each remaining data point against its first {@code topk} predicted ASTs.
 *
 * @param dataFile the data file handed to {@code JSONInputFormat.readData}
 * @param inCorpus 2 keeps only in-corpus data points, 3 keeps only out-of-corpus ones, otherwise all
 * @param topk the maximum number of leading predicted ASTs considered per data point
 * @param metric the metric to compute
 * @param aggregate the aggregate name passed through to the metric
 * @return one metric value per retained data point, in data order
 * @throws IOException declared for reading the data file
 */
private List<Float> computeMetricValues(String dataFile, int inCorpus, int topk, Metric metric, String aggregate) throws IOException {
    List<JSONInputFormat.DataPoint> data = JSONInputFormat.readData(dataFile);
    if (inCorpus == 2)
        data = data.stream().filter(datapoint -> datapoint.in_corpus).collect(Collectors.toList());
    else if (inCorpus == 3)
        data = data.stream().filter(datapoint -> !datapoint.in_corpus).collect(Collectors.toList());
    List<Float> values = new ArrayList<>();
    for (JSONInputFormat.DataPoint datapoint : data) {
        DSubTree originalAST = datapoint.ast;
        List<DSubTree> predictedASTs = datapoint.out_asts.subList(0, Math.min(topk, datapoint.out_asts.size()));
        values.add(metric.compute(originalAST, predictedASTs, aggregate));
    }
    return values;
}
Also used : List(java.util.List) edu.rice.cs.caper.bayou.core.dsl(edu.rice.cs.caper.bayou.core.dsl) Metric(edu.rice.cs.caper.bayou.core.sketch_metric.Metric) org.apache.commons.cli(org.apache.commons.cli) TTest(org.apache.commons.math3.stat.inference.TTest) IOException(java.io.IOException) EqualityASTMetric(edu.rice.cs.caper.bayou.core.sketch_metric.EqualityASTMetric) Collectors(java.util.stream.Collectors) ArrayList(java.util.ArrayList)

Aggregations

EqualityASTMetric (edu.rice.cs.caper.bayou.core.sketch_metric.EqualityASTMetric)2 edu.rice.cs.caper.bayou.core.dsl (edu.rice.cs.caper.bayou.core.dsl)1 Metric (edu.rice.cs.caper.bayou.core.sketch_metric.Metric)1 Trial (edu.rice.cs.caper.floodgage.application.floodgage.model.plan.Trial)1 SynthesizeException (edu.rice.cs.caper.floodgage.application.floodgage.synthesizer.SynthesizeException)1 IOException (java.io.IOException)1 Method (java.lang.reflect.Method)1 ArrayList (java.util.ArrayList)1 List (java.util.List)1 Collectors (java.util.stream.Collectors)1 org.apache.commons.cli (org.apache.commons.cli)1 TTest (org.apache.commons.math3.stat.inference.TTest)1