use of at.ac.tuwien.kr.alpha.api.programs.literals.AggregateLiteral in project Alpha by alpha-asp.
the class AggregateRewritingRuleAnalysis method analyzeRuleDependencies.
/**
 * Determines, for every aggregate literal registered in {@code globalVariablesPerAggregate}, which other
 * body literals of {@code rule} it depends on, and records the result in {@code dependenciesPerAggregate}.
 *
 * The variables an aggregate needs bound from elsewhere are its global variables plus, if the aggregate
 * binds no variables itself, its lower bound term when that term is a variable. The search for the
 * literals binding those variables is delegated to {@code findBindingLiterals}, which is handed the rule
 * body without the aggregate itself and an initially empty result set.
 */
private void analyzeRuleDependencies() {
    for (AggregateLiteral aggregate : globalVariablesPerAggregate.keySet()) {
        // Work on a copy so the registered global variables are not modified.
        Set<VariableTerm> varsNeedingBinding = new HashSet<>(globalVariablesPerAggregate.get(aggregate));

        Term lowerBound = aggregate.getAtom().getLowerBoundTerm();
        boolean aggregateBindsNothing = aggregate.getBindingVariables().isEmpty();
        if (aggregateBindsNothing && lowerBound instanceof VariableTerm) {
            // A variable lower bound that the aggregate does not bind must be bound by some other
            // literal, i.e. the aggregate depends on whichever literals bind it.
            varsNeedingBinding.add((VariableTerm) lowerBound);
        }

        // Candidate binders are all body literals except the aggregate under analysis.
        Set<Literal> otherBodyLiterals = SetUtils.difference(rule.getBody(), Collections.singleton(aggregate));
        Set<Literal> bindingLiterals = new HashSet<>();
        findBindingLiterals(varsNeedingBinding, new HashSet<>(), bindingLiterals, otherBodyLiterals, globalVariablesPerAggregate);
        dependenciesPerAggregate.put(aggregate, bindingLiterals);
    }
}
use of at.ac.tuwien.kr.alpha.api.programs.literals.AggregateLiteral in project Alpha by alpha-asp.
the class AggregateRewritingContext method registerRule.
/**
 * Registers a rule that may contain one or more {@link AggregateLiteral}s with this context.
 * The rule is analyzed via {@link AggregateRewritingRuleAnalysis}; if it contains aggregates, every
 * aggregate literal is registered together with its global variables, the dependencies found by the
 * analysis are attached to the corresponding {@code AggregateInfo}, and the rule itself is remembered
 * in {@code rulesWithAggregates}.
 *
 * @param rule the rule to inspect and, if it contains aggregates, to register
 * @return true if the given rule contains one or more aggregate literals, false otherwise
 */
public boolean registerRule(Rule<Head> rule) {
    AggregateRewritingRuleAnalysis analysis = AggregateRewritingRuleAnalysis.analyzeRuleDependencies(rule);
    if (analysis.aggregatesInRule.isEmpty()) {
        // Nothing to register for a rule without aggregates.
        return false;
    }

    // First pass: register every aggregate literal with its global variables, so that the
    // aggregate infos looked up in the second pass exist.
    for (Map.Entry<AggregateLiteral, Set<VariableTerm>> globalsEntry : analysis.globalVariablesPerAggregate.entrySet()) {
        AggregateLiteral aggregate = globalsEntry.getKey();
        Set<VariableTerm> globalVariables = globalsEntry.getValue();
        registerAggregateLiteral(aggregate, globalVariables);
    }

    // Second pass: attach dependencies. A dependency that is itself an aggregate literal is replaced
    // by a literal over that aggregate's output atom, carrying the same polarity.
    for (Map.Entry<AggregateLiteral, Set<Literal>> dependencyEntry : analysis.dependenciesPerAggregate.entrySet()) {
        AggregateInfo info = getAggregateInfo(dependencyEntry.getKey());
        for (Literal dependency : dependencyEntry.getValue()) {
            if (!(dependency instanceof AggregateLiteral)) {
                info.addDependency(dependency);
                continue;
            }
            AggregateInfo infoOfDependency = getAggregateInfo((AggregateLiteral) dependency);
            info.addDependency(infoOfDependency.getOutputAtom().toLiteral(!dependency.isNegated()));
        }
    }

    rulesWithAggregates.add(rule);
    return true;
}
use of at.ac.tuwien.kr.alpha.api.programs.literals.AggregateLiteral in project Alpha by alpha-asp.
the class AggregateOperatorNormalizationTest method assertAggregateBoundIncremented.
/**
 * Asserts that the rewritten rule contains a comparison literal that increments the bound of the
 * source rule's aggregate by one.
 *
 * Concretely, the checks are: the first term of the comparison literal equals the lower bound term of
 * the rewritten aggregate, and the second term is an {@link ArithmeticTerm} of the form
 * {@code <source bound> + 1}, where the source bound is the source aggregate's lower bound term if
 * present and its upper bound term otherwise.
 *
 * If a body holds several aggregate (or comparison) literals, the last one encountered during
 * iteration is the one that gets checked.
 *
 * @param sourceRule    the rule before normalization; must contain an aggregate literal
 * @param rewrittenRule the rule after normalization; must contain an aggregate literal and a comparison literal
 */
private static void assertAggregateBoundIncremented(Rule<Head> sourceRule, Rule<Head> rewrittenRule) {
    AggregateLiteral sourceAggregate = null;
    for (Literal lit : sourceRule.getBody()) {
        if (lit instanceof AggregateLiteral) {
            sourceAggregate = (AggregateLiteral) lit;
        }
    }
    // Report a missing aggregate as a test failure instead of a NullPointerException further down.
    assertNotNull(sourceAggregate);

    AggregateLiteral rewrittenAggregate = null;
    ComparisonLiteral addedComparisonLiteral = null;
    for (Literal lit : rewrittenRule.getBody()) {
        if (lit instanceof AggregateLiteral) {
            rewrittenAggregate = (AggregateLiteral) lit;
        } else if (lit instanceof ComparisonLiteral) {
            addedComparisonLiteral = (ComparisonLiteral) lit;
        }
    }
    assertNotNull(rewrittenAggregate);
    assertNotNull(addedComparisonLiteral);

    // Left-hand side of the added comparison must be the rewritten aggregate's lower bound.
    assertEquals(addedComparisonLiteral.getAtom().getTerms().get(0), rewrittenAggregate.getAtom().getLowerBoundTerm());

    // Right-hand side must be "<source bound> + 1".
    Term comparisonRightHandTerm = addedComparisonLiteral.getAtom().getTerms().get(1);
    assertTrue(comparisonRightHandTerm instanceof ArithmeticTerm);
    ArithmeticTerm incrementTerm = (ArithmeticTerm) comparisonRightHandTerm;
    assertEquals(ArithmeticOperator.PLUS, incrementTerm.getOperator());
    assertEquals(Terms.newConstant(1), incrementTerm.getRightOperand());

    // The source bound is the lower bound if there is one, otherwise the upper bound.
    Term sourceLowerBound = sourceAggregate.getAtom().getLowerBoundTerm();
    Term sourceBound = sourceLowerBound != null ? sourceLowerBound : sourceAggregate.getAtom().getUpperBoundTerm();
    assertEquals(sourceBound, incrementTerm.getLeftOperand());
}
use of at.ac.tuwien.kr.alpha.api.programs.literals.AggregateLiteral in project Alpha by alpha-asp.
the class AggregateOperatorNormalizationTest method assertOperatorNormalized.
/**
 * Asserts that the aggregate literal in the body of the rewritten rule has the expected lower bound
 * operator and the expected polarity.
 *
 * If the body holds several aggregate literals, the last one encountered during iteration is checked.
 *
 * @param rewrittenRule                   the rule after normalization; must contain an aggregate literal
 * @param expectedRewrittenOperator       the operator expected as the aggregate's lower bound operator
 * @param expectedRewrittenLiteralPositive true if the aggregate literal is expected to be non-negated
 */
private static void assertOperatorNormalized(Rule<Head> rewrittenRule, ComparisonOperator expectedRewrittenOperator, boolean expectedRewrittenLiteralPositive) {
    AggregateLiteral rewrittenAggregate = null;
    for (Literal lit : rewrittenRule.getBody()) {
        if (lit instanceof AggregateLiteral) {
            rewrittenAggregate = (AggregateLiteral) lit;
        }
    }
    assertNotNull(rewrittenAggregate);
    assertEquals(expectedRewrittenOperator, rewrittenAggregate.getAtom().getLowerBoundOperator());
    // assertEquals rather than assertTrue(a == b), so a failure reports expected and actual polarity.
    assertEquals(expectedRewrittenLiteralPositive, !rewrittenAggregate.isNegated());
}
use of at.ac.tuwien.kr.alpha.api.programs.literals.AggregateLiteral in project Alpha by alpha-asp.
the class AggregateRewritingRuleAnalysisTest method bindingAggregateWithGlobals2.
/**
 * Analyzes {@code BINDING_AGGREGATE_WITH_GLOBALS_2} and checks the global variables and dependencies
 * reported for its two aggregates: a max aggregate binding DMAX and a count aggregate binding N,
 * where the count aggregate is expected to depend on the max aggregate.
 */
@Test
public void bindingAggregateWithGlobals2() {
    AggregateRewritingRuleAnalysis analysis = analyze(BINDING_AGGREGATE_WITH_GLOBALS_2);

    // Expected literal graph(G), a dependency of both aggregates.
    Literal graphLiteral = Literals.fromAtom(Atoms.newBasicAtom(Predicates.getPredicate("graph", 1), Terms.newVariable("G")), true);

    // Expected max aggregate: DMAX = #max { DV : graph_vertex_degree(G, V, DV) }
    List<Term> degreeElementTerms = Collections.singletonList(Terms.newVariable("DV"));
    Literal degreeLiteral = Literals.fromAtom(Atoms.newBasicAtom(Predicates.getPredicate("graph_vertex_degree", 3), Terms.newVariable("G"), Terms.newVariable("V"), Terms.newVariable("DV")), true);
    List<Literal> degreeElementLiterals = Collections.singletonList(degreeLiteral);
    AggregateElement degreeElement = Atoms.newAggregateElement(degreeElementTerms, degreeElementLiterals);
    AggregateLiteral maxAggregate = Literals.fromAtom(Atoms.newAggregateAtom(ComparisonOperators.EQ, Terms.newVariable("DMAX"), AggregateFunctionSymbol.MAX, Collections.singletonList(degreeElement)), true);

    // Expected count aggregate: N = #count { V : graph_vertex_degree(G, V, DMAX) }
    List<Term> maxDegreeElementTerms = Collections.singletonList(Terms.newVariable("V"));
    Literal maxDegreeLiteral = Literals.fromAtom(Atoms.newBasicAtom(Predicates.getPredicate("graph_vertex_degree", 3), Terms.newVariable("G"), Terms.newVariable("V"), Terms.newVariable("DMAX")), true);
    List<Literal> maxDegreeElementLiterals = Collections.singletonList(maxDegreeLiteral);
    AggregateElement maxDegreeElement = Atoms.newAggregateElement(maxDegreeElementTerms, maxDegreeElementLiterals);
    AggregateLiteral countAggregate = Literals.fromAtom(Atoms.newAggregateAtom(ComparisonOperators.EQ, Terms.newVariable("N"), AggregateFunctionSymbol.COUNT, Collections.singletonList(maxDegreeElement)), true);

    // Both aggregates must have been picked up by the analysis.
    assertEquals(2, analysis.globalVariablesPerAggregate.size());
    assertEquals(2, analysis.dependenciesPerAggregate.size());

    // Max aggregate: G is its only global variable, graph(G) its only dependency.
    assertTrue(analysis.globalVariablesPerAggregate.containsKey(maxAggregate));
    Set<VariableTerm> globalsOfMax = analysis.globalVariablesPerAggregate.get(maxAggregate);
    assertEquals(1, globalsOfMax.size());
    assertTrue(globalsOfMax.contains(Terms.newVariable("G")));
    assertTrue(analysis.dependenciesPerAggregate.containsKey(maxAggregate));
    Set<Literal> dependenciesOfMax = analysis.dependenciesPerAggregate.get(maxAggregate);
    assertEquals(1, dependenciesOfMax.size());
    assertTrue(dependenciesOfMax.contains(graphLiteral));

    // Count aggregate: globals are G and DMAX; it depends on graph(G) and on the max aggregate.
    assertTrue(analysis.globalVariablesPerAggregate.containsKey(countAggregate));
    Set<VariableTerm> globalsOfCount = analysis.globalVariablesPerAggregate.get(countAggregate);
    assertEquals(2, globalsOfCount.size());
    assertTrue(globalsOfCount.contains(Terms.newVariable("G")));
    assertTrue(globalsOfCount.contains(Terms.newVariable("DMAX")));
    assertTrue(analysis.dependenciesPerAggregate.containsKey(countAggregate));
    Set<Literal> dependenciesOfCount = analysis.dependenciesPerAggregate.get(countAggregate);
    assertEquals(2, dependenciesOfCount.size());
    assertTrue(dependenciesOfCount.contains(graphLiteral));
    assertTrue(dependenciesOfCount.contains(maxAggregate));
}
Aggregations