use of edu.rice.cs.caper.bayou.core.dsl.DSubTree in project bayou by capergroup.
the class JaccardAPICallsMetric method compute.
/**
 * Computes the minimum Jaccard distance on the set of API calls
 * between the original and the predicted ASTs.
 *
 * <p>A baseline distance of 1 is always added to the candidate list, so the
 * result is 1 when {@code predictedASTs} is empty. When neither the original
 * nor a predicted AST contains any API call, the two sets are identical and
 * their distance is taken to be 0; the plain formula would otherwise compute
 * 0/0 and produce NaN.
 *
 * @param originalAST   the reference AST
 * @param predictedASTs the candidate ASTs compared against the reference
 * @return the result of {@code Metric.min} over the baseline and all per-candidate distances
 */
@Override
public float compute(DSubTree originalAST, List<DSubTree> predictedASTs) {
    List<Float> jaccard = new ArrayList<>();
    jaccard.add((float) 1);
    Set<DAPICall> A = originalAST.bagOfAPICalls();
    for (DSubTree predictedAST : predictedASTs) {
        Set<DAPICall> B = predictedAST.bagOfAPICalls();
        // A union B
        Set<DAPICall> AunionB = new HashSet<>(A);
        AunionB.addAll(B);
        if (AunionB.isEmpty()) {
            // Both sets are empty, hence equal: distance 0 instead of NaN from 0/0.
            jaccard.add(0f);
            continue;
        }
        // A intersect B
        Set<DAPICall> AinterB = new HashSet<>(A);
        AinterB.retainAll(B);
        jaccard.add(1 - ((float) AinterB.size()) / AunionB.size());
    }
    return Metric.min(jaccard);
}
use of edu.rice.cs.caper.bayou.core.dsl.DSubTree in project bayou by capergroup.
the class JaccardSequencesMetric method compute.
/**
 * Returns the smallest Jaccard distance between the set of API-call sequences
 * of the original AST and the set of sequences of each predicted AST.
 *
 * <p>A distance of 1 is always present as a baseline, so the result is 1 when
 * there are no predictions. If the original's sequences cannot be enumerated
 * within the limits, 1 is returned immediately. A prediction whose sequences
 * cannot be enumerated contributes a distance of 1.
 */
@Override
public float compute(DSubTree originalAST, List<DSubTree> predictedASTs) {
    List<Float> distances = new ArrayList<>();
    distances.add(1f);

    Set<Sequence> expected;
    try {
        List<Sequence> seeds = new ArrayList<>();
        seeds.add(new Sequence());
        originalAST.updateSequences(seeds, 999, 999);
        expected = new HashSet<>(seeds);
    } catch (DASTNode.TooManySequencesException | DASTNode.TooLongSequenceException e) {
        // The reference itself is too large to enumerate: report maximal distance.
        return 1f;
    }

    for (DSubTree candidate : predictedASTs) {
        Set<Sequence> actual;
        try {
            List<Sequence> seeds = new ArrayList<>();
            seeds.add(new Sequence());
            candidate.updateSequences(seeds, 999, 999);
            actual = new HashSet<>(seeds);
        } catch (DASTNode.TooManySequencesException | DASTNode.TooLongSequenceException e) {
            // An unenumerable candidate counts as completely dissimilar.
            distances.add(1f);
            continue;
        }

        Set<Sequence> shared = new HashSet<>(expected);
        shared.retainAll(actual);

        Set<Sequence> combined = new HashSet<>(expected);
        combined.addAll(actual);

        float similarity = (float) shared.size() / combined.size();
        distances.add(1 - similarity);
    }
    return Metric.min(distances);
}
use of edu.rice.cs.caper.bayou.core.dsl.DSubTree in project bayou by capergroup.
the class MetricCalculator method execute.
/**
 * Evaluates the metric named by the {@code m} option over every data point read
 * from the file named by the {@code f} option and prints
 * {@code "<mean>,<standard deviation>"} (two decimals each) to standard output.
 *
 * <p>Only the first {@code t} predicted ASTs of each data point are considered
 * (default 3). The pseudo-metric {@code "latency"} reports each data point's
 * recorded latency instead of comparing ASTs. An unknown or missing metric name,
 * or a negative {@code t}, is reported on standard error and nothing is computed.
 * Does nothing if no command line was parsed.
 *
 * @throws IOException if reading the data file fails
 */
public void execute() throws IOException {
    if (cmdLine == null)
        return;
    int topk = cmdLine.hasOption("t") ? Integer.parseInt(cmdLine.getOptionValue("t")) : 3;
    if (topk < 0) {
        // subList(0, negative) would throw; reject the option up front instead.
        System.err.println("invalid top-k: " + topk);
        return;
    }
    String m = cmdLine.getOptionValue("m");
    if (m == null) {
        // switch on a null String throws NullPointerException; report it like any bad value.
        System.err.println("invalid metric: " + m);
        return;
    }
    Metric metric;
    switch (m) {
        case "equality-ast":
            metric = new EqualityASTMetric();
            break;
        case "jaccard-sequences":
            metric = new JaccardSequencesMetric();
            break;
        case "jaccard-api-calls":
            metric = new JaccardAPICallsMetric();
            break;
        case "num-control-structures":
            metric = new NumControlStructuresMetric();
            break;
        case "num-statements":
            metric = new NumStatementsMetric();
            break;
        case "latency":
            // Latency is read directly from each data point; no AST metric is needed.
            metric = null;
            break;
        default:
            System.err.println("invalid metric: " + m);
            return;
    }
    List<JSONInputFormat.DataPoint> data = JSONInputFormat.readData(cmdLine.getOptionValue("f"));
    List<Float> values = new ArrayList<>();
    for (JSONInputFormat.DataPoint datapoint : data) {
        DSubTree originalAST = datapoint.ast;
        List<DSubTree> predictedASTs = datapoint.out_asts.subList(0, Math.min(topk, datapoint.out_asts.size()));
        if (m.equals("latency"))
            values.add(datapoint.latency);
        else
            values.add(metric.compute(originalAST, predictedASTs));
    }
    float average = Metric.mean(values);
    float stdv = Metric.standardDeviation(values);
    System.out.println(String.format("%.2f,%.2f", average, stdv));
}
use of edu.rice.cs.caper.bayou.core.dsl.DSubTree in project bayou by capergroup.
the class JaccardSequencesMetric method compute.
/**
 * Aggregates, using {@code Metric.aggregate} with the given aggregate name, the
 * Jaccard distances between the set of API-call sequences of the original AST
 * and the set of sequences of each predicted AST.
 *
 * <p>A distance of 1 is always included as a baseline. If the original's
 * sequences cannot be enumerated within the limits, 1 is returned directly.
 * A prediction whose sequences cannot be enumerated contributes a distance of 1.
 * NOTE(review): the baseline is also passed to non-minimum aggregates, where it
 * shifts the result; confirm that this is intended.
 *
 * @param aggregate name of the aggregation forwarded to {@code Metric.aggregate}
 */
@Override
public float compute(DSubTree originalAST, List<DSubTree> predictedASTs, String aggregate) {
    List<Float> scores = new ArrayList<>();
    scores.add(1f);

    Set<Sequence> reference;
    try {
        List<Sequence> acc = new ArrayList<>();
        acc.add(new Sequence());
        originalAST.updateSequences(acc, 999, 999);
        reference = new HashSet<>(acc);
    } catch (DASTNode.TooManySequencesException | DASTNode.TooLongSequenceException e) {
        return 1f;
    }

    for (DSubTree prediction : predictedASTs) {
        Set<Sequence> predicted;
        try {
            List<Sequence> acc = new ArrayList<>();
            acc.add(new Sequence());
            prediction.updateSequences(acc, 999, 999);
            predicted = new HashSet<>(acc);
        } catch (DASTNode.TooManySequencesException | DASTNode.TooLongSequenceException e) {
            scores.add(1f);
            continue;
        }

        Set<Sequence> both = new HashSet<>(reference);
        both.retainAll(predicted);
        Set<Sequence> either = new HashSet<>(reference);
        either.addAll(predicted);

        scores.add(1 - (float) both.size() / either.size());
    }
    return Metric.aggregate(scores, aggregate);
}
use of edu.rice.cs.caper.bayou.core.dsl.DSubTree in project bayou by capergroup.
the class NumStatementsMetric method compute.
/**
 * Aggregates, using {@code Metric.aggregate} with the given aggregate name, the
 * absolute differences between the number of statements in the original AST and
 * in each predicted AST. The result is normalized by the original's statement count.
 *
 * <p>The original's statement count is always included as a baseline difference.
 * This equals the difference a prediction with no statements would have. If the
 * original has no statements, the denominator is taken as 1. Without that guard,
 * the ratio would be NaN (0/0) or Infinity.
 *
 * @param originalAST   the reference AST
 * @param predictedASTs the candidate ASTs compared against the reference
 * @param aggregate     name of the aggregation forwarded to {@code Metric.aggregate}
 * @return the aggregated difference divided by {@code max(original statements, 1)}
 */
@Override
public float compute(DSubTree originalAST, List<DSubTree> predictedASTs, String aggregate) {
    int original = originalAST.numStatements();
    List<Integer> diffs = new ArrayList<>();
    diffs.add(original);
    for (DSubTree predictedAST : predictedASTs) {
        int predicted = predictedAST.numStatements();
        int diff_predicted = Math.abs(predicted - original);
        diffs.add(diff_predicted);
    }
    float aggr_diff = Metric.aggregate(diffs, aggregate);
    // Avoid dividing by zero when the original AST has no statements.
    int denominator = Math.max(original, 1);
    return aggr_diff / denominator;
}
Aggregations