Usage of claw.wani.language.ClawMapping in project claw-compiler by C2SM-RCM.
Class LoopExtraction, method transform:
/**
 * Apply the transformation. A loop extraction is applied in the following
 * steps:
 * <ol>
 * <li>Duplicate the function targeted by the transformation.</li>
 * <li>Extract the loop body in the duplicated function and remove the
 * loop.</li>
 * <li>Adapt the function call and demote array references in the duplicated
 * function body.</li>
 * <li>Optional: add the additional transformations requested by the
 * directive clauses (e.g. LoopFusion) to the transformations' queue.</li>
 * </ol>
 *
 * @param xcodeml        The XcodeML on which the transformations are applied.
 * @param translator     The translator used to apply the transformations.
 *                       It is cast to {@code ClawTranslator}.
 * @param transformation Only for dependent transformation. The other
 *                       transformation part of the transformation. Not used
 *                       by this transformation.
 * @throws IllegalTransformationException if the transformation cannot be
 *                                        applied: a mapping has more
 *                                        dimensions than the mapped argument,
 *                                        or a mapping/mapped variable has no
 *                                        declaration.
 * @throws Exception if any other step of the transformation fails.
 */
@Override
public void transform(XcodeProgram xcodeml, Translator translator,
                      Transformation transformation) throws Exception
{
  ClawTranslator ct = (ClawTranslator) translator;
  final Context context = ct.context();

  /*
   * DUPLICATE THE FUNCTION
   */

  // Duplicate function definition under a new name and a new type hash
  FfunctionDefinition clonedFctDef = _fctDefToExtract.cloneNode();
  String newFctTypeHash =
      xcodeml.getTypeTable().generateHash(FortranType.FUNCTION);
  String newFctName = clonedFctDef.getName() + ClawConstant.EXTRACTION_SUFFIX
      + translator.getNextTransformationCounter();
  clonedFctDef.name().setValue(newFctName);
  clonedFctDef.name().setType(newFctTypeHash);

  // Update the symbol table in the fct definition
  Xid fctId = clonedFctDef.getSymbolTable().get(_fctDefToExtract.getName());
  fctId.setType(newFctTypeHash);
  fctId.setName(newFctName);

  // Register a copy of the original function type under the new hash
  FfunctionType fctType =
      xcodeml.getTypeTable().getFunctionType(_fctDefToExtract);
  FfunctionType newFctType = fctType.cloneNode();
  newFctType.setType(newFctTypeHash);
  xcodeml.getTypeTable().add(newFctType);

  // If the fct is defined in the global symbol table, duplicate its entry
  Xid globalFctId =
      xcodeml.getGlobalSymbolsTable().get(_fctDefToExtract.getName());
  if (globalFctId != null) {
    Xid newFctId = globalFctId.cloneNode();
    newFctId.setType(newFctTypeHash);
    newFctId.setName(newFctName);
    xcodeml.getGlobalSymbolsTable().add(newFctId);
  }

  // Insert the duplicated function declaration right after the original
  _fctDefToExtract.insertAfter(clonedFctDef);

  // Find the loop that will be extracted
  Xnode loopInClonedFct = locateDoStatement(clonedFctDef);

  Message.debug(context,
      "loop-extract transformation: " + _claw.getPragma().value());
  Message.debug(context, " created subroutine: " + clonedFctDef.getName());

  /*
   * REMOVE BODY FROM THE LOOP AND DELETE THE LOOP
   */

  // 1. append body into fct body after loop
  Loop.extractBody(loopInClonedFct);
  // 2. delete loop
  loopInClonedFct.delete();

  /*
   * ADAPT FUNCTION CALL AND DEMOTE ARRAY REFERENCES IN THE BODY OF THE FUNCTION
   */

  // Wrap function call with loop
  Xnode extractedLoop = wrapCallWithLoop(xcodeml, _extractedLoop);

  // Redirect the call to the duplicated function
  Xnode calledFctName = _fctCall.matchDirectDescendant(Xcode.NAME);
  Message.debug(context, " call wrapped with loop: " + calledFctName.value()
      + " --> " + clonedFctDef.getName());
  calledFctName.setValue(newFctName);
  calledFctName.setType(newFctTypeHash);

  // Adapt function call parameters and function declaration
  XdeclTable fctDeclarations = clonedFctDef.getDeclarationTable();
  XsymbolTable fctSymbols = clonedFctDef.getSymbolTable();

  Message.debug(context,
      " Start to apply mapping: " + _claw.getMappings().size());

  for (ClawMapping mapping : _claw.getMappings()) {
    Message.debug(context,
        "Apply mapping (" + mapping.getMappedDimensions() + ") ");
    for (ClawMappingVar var : mapping.getMappedVariables()) {
      Message.debug(context, " Var: " + var);
      Optional<Xnode> argument = _fctCall.findArg(var.getArgMapping());
      if (!argument.isPresent()) {
        continue;
      }
      Xnode arg = argument.get();

      /*
       * Case 1: Var --> ArrayRef
       * 1. Check that each variable used as array index exists in the current
       *    scope (XdeclTable) and get its type. Create a Var element for it
       *    and wrap it in an arrayIndex element.
       * 2. Get the reference type of the base variable.
       *    2.1 Create the varRef element with the type of the base variable.
       *    2.2 Insert a clone of the base variable in varRef.
       * 3. Create the arrayRef element with varRef + arrayIndex elements.
       */
      if (arg.is(Xcode.VAR)) {
        FbasicType type = xcodeml.getTypeTable().getBasicType(arg);
        // Demotion cannot be applied as type dimension is smaller. A null
        // basic type is reported the same way instead of failing with NPE.
        if (type == null
            || type.getDimensions() < mapping.getMappedDimensions())
        {
          throw new IllegalTransformationException(
              "mapping dimensions too big. Mapping " + mapping.toString()
                  + " is wrong ...", _claw.getPragma().lineNo());
        }

        Xnode newArg = xcodeml.createNode(Xcode.F_ARRAY_REF);
        newArg.setType(type.getRef());
        Xnode varRef = xcodeml.createNode(Xcode.VAR_REF);
        varRef.setType(arg.getType());
        varRef.append(arg, true);
        newArg.append(varRef);

        // Create one arrayIndex per mapping variable
        for (ClawMappingVar mappingVar : mapping.getMappingVariables()) {
          // Find the mapping var in the local table (fct scope)
          Xnode mappingVarDecl =
              _fctDef.getDeclarationTable().get(mappingVar.getArgMapping());
          if (mappingVarDecl == null) {
            throw new IllegalTransformationException(
                "mapping variable " + mappingVar.getArgMapping()
                    + " is not declared in the current function scope",
                _claw.getPragma().lineNo());
          }
          Xnode arrayIndex = xcodeml.createNode(Xcode.ARRAY_INDEX);
          Xnode newMappingVar = xcodeml.createVar(mappingVarDecl.getType(),
              mappingVarDecl.matchSeq(Xcode.NAME).value(), Xscope.LOCAL);
          arrayIndex.append(newMappingVar);
          newArg.append(arrayIndex);
        }

        // Replace the scalar argument by the new array reference
        arg.insertAfter(newArg);
        arg.delete();
      }

      // Case 2: ArrayRef (n arrayIndex) --> ArrayRef (n+m arrayIndex)
      // Change variable declaration in extracted fct
      Xnode varDecl = fctDeclarations.get(var.getFctMapping());
      if (varDecl == null) {
        throw new IllegalTransformationException(
            "mapped variable " + var.getFctMapping()
                + " is not declared in the extracted function",
            _claw.getPragma().lineNo());
      }
      FbasicType varDeclType = xcodeml.getTypeTable().getBasicType(varDecl);
      // Variable is demoted to scalar: take the ref type
      if (varDeclType != null
          && varDeclType.getDimensions() == mapping.getMappedDimensions())
      {
        Xid id = fctSymbols.get(var.getFctMapping());
        if (id == null) {
          throw new IllegalTransformationException(
              "mapped variable " + var.getFctMapping()
                  + " has no symbol in the extracted function",
              _claw.getPragma().lineNo());
        }
        Xnode newVarDecl = xcodeml.createNode(Xcode.VAR_DECL);
        newVarDecl.append(
            xcodeml.createName(var.getFctMapping(), varDeclType.getRef()));
        fctDeclarations.replace(newVarDecl, var.getFctMapping());
        id.setType(varDeclType.getRef());
      }
    } // Loop mapped variables
  } // Loop over mapping clauses

  // Adapt array references in the function body. The list is a snapshot
  // taken before any reference is replaced.
  List<Xnode> arrayReferences =
      clonedFctDef.body().matchAll(Xcode.F_ARRAY_REF);
  for (Xnode ref : arrayReferences) {
    if (!Xnode.isOfCode(ref.matchSeq(Xcode.VAR_REF).child(0), Xcode.VAR)) {
      continue;
    }
    String mappedVar = ref.matchSeq(Xcode.VAR_REF, Xcode.VAR).value();
    ClawMapping mapping = _fctMappingMap.get(mappedVar);
    if (mapping != null && matchesMappingVariables(ref, mapping)) {
      // TODO Var ref should be extracted only if the reference can be
      // totally demoted
      ref.insertBefore(ref.matchSeq(Xcode.VAR_REF, Xcode.VAR).cloneNode());
      ref.delete();
    }
  }

  // Generate directive pragmas if needed
  Xnode grip = null;
  if (_claw.hasClause(ClawClause.ACC)) {
    /*
     * TODO see TODO in ExpandNotation
     * OpenACC and OpenMP loop construct are pretty different ...
     * have to look how to do that properly. See issue #22
     */
    grip = Directive.generateAcceleratorClause(xcodeml, extractedLoop,
        _claw.value(ClawClause.ACC));
  }

  if (_claw.hasClause(ClawClause.PARALLEL)) {
    Directive.generateParallelRegion(xcodeml,
        (grip == null) ? extractedLoop : grip, extractedLoop);
  }

  // TODO must be triggered by a clause
  // Directive.generateRoutineDirectives(_claw, xcodeml, clonedFctDef);

  // Add any additional transformation defined in the directive clauses
  ct.generateAdditionalTransformation(_claw, xcodeml, extractedLoop);

  removePragma();
  transformed();
}

/**
 * Check whether the variable array indexes of an array reference match, in
 * order, the function-side names of the mapping variables of a mapping.
 * Array indexes whose first child is not a plain variable are skipped.
 *
 * @param ref     F_ARRAY_REF node taken from the extracted function body.
 * @param mapping Mapping registered for the referenced variable.
 * @return false as soon as a variable index differs from the next mapping
 * variable, or when the reference has more variable indexes than the mapping
 * declares (it cannot be demoted to the bare variable); true otherwise.
 */
private static boolean matchesMappingVariables(Xnode ref, ClawMapping mapping)
{
  int mappingIndex = 0;
  for (Xnode child : ref.children()) {
    if (!child.is(Xcode.ARRAY_INDEX)) {
      continue;
    }
    List<Xnode> indexChildren = child.children();
    if (indexChildren.isEmpty()
        || !Xnode.isOfCode(indexChildren.get(0), Xcode.VAR))
    {
      continue;
    }
    // Without this guard, a reference with more variable indexes than
    // mapping variables raised IndexOutOfBoundsException.
    if (mappingIndex >= mapping.getMappingVariables().size()) {
      return false;
    }
    String varName = child.matchSeq(Xcode.VAR).value();
    if (!varName.equals(
        mapping.getMappingVariables().get(mappingIndex).getFctMapping()))
    {
      return false;
    }
    ++mappingIndex;
  }
  return true;
}
Usage of claw.wani.language.ClawMapping in project claw-compiler by C2SM-RCM.
Class ClawPragmaTest, method extractTest:
/**
 * Test various input for the CLAW loop-extract directive: valid directives
 * combining the range, map, parallel, fusion, group, acc and target clauses,
 * and rejection of malformed directives.
 */
@Test
public void extractTest() {
  // Valid directives
  ClawPragma l = analyzeValidClawLoopExtract(
      "claw loop-extract range(i=istart,iend) map(i:j)",
      "i", "istart", "iend", null, null);
  assertSingleIToJMapping(l);

  l = analyzeValidClawLoopExtract(
      "claw loop-extract range(i=istart,iend,2) map(i:j)",
      "i", "istart", "iend", "2", null);
  assertSingleIToJMapping(l);

  l = analyzeValidClawLoopExtract(
      "claw loop-extract range(i=1,10) map(i:j)",
      "i", "1", "10", null, null);
  assertSingleIToJMapping(l);

  l = analyzeValidClawLoopExtract(
      "claw loop-extract range(i=1,10,2) map(i:j) parallel",
      "i", "1", "10", "2", null);
  assertSingleIToJMapping(l);
  assertTrue(l.hasClause(ClawClause.PARALLEL));

  l = analyzeValidClawLoopExtract(
      "claw loop-extract range(i=istart,iend) map(i:j) fusion",
      "i", "istart", "iend", null, null);
  assertSingleIToJMapping(l);
  assertTrue(l.hasClause(ClawClause.FUSION));
  assertFalse(l.hasClause(ClawClause.GROUP));
  assertFalse(l.hasClause(ClawClause.PARALLEL));

  l = analyzeValidClawLoopExtract(
      "claw loop-extract range(i=istart,iend) map(i:j) fusion group(j1)",
      "i", "istart", "iend", null, null);
  assertSingleIToJMapping(l);
  assertTrue(l.hasClause(ClawClause.FUSION));
  assertTrue(l.hasClause(ClawClause.GROUP));
  assertEquals("j1", l.value(ClawClause.GROUP));

  l = analyzeValidClawLoopExtract(
      "claw loop-extract range(i=istart,iend) map(i:j) fusion group(j1) "
          + "acc(loop gang vector)", "i", "istart", "iend", null, null);
  assertSingleIToJMapping(l);
  assertTrue(l.hasClause(ClawClause.FUSION));
  assertTrue(l.hasClause(ClawClause.GROUP));
  assertTrue(l.hasClause(ClawClause.ACC));
  assertEquals("loop gang vector", l.value(ClawClause.ACC));
  assertEquals("j1", l.value(ClawClause.GROUP));

  // Multiple map clauses, some with distinct argument/function names
  l = analyzeValidClawLoopExtract(
      "claw loop-extract range(j1=ki1sc,ki1ec) "
          + "map(pduh2oc,pduh2of:j1,ki3sc/j3) "
          + "map(pduco2,pduo3,palogp,palogt,podsc,podsf,podac,"
          + "podaf:j1,ki3sc/j3) "
          + "map(pbsff,pbsfc:j1,ki3sc/j3) "
          + "map(pa1c,pa1f,pa2c,pa2f,pa3c,pa3f:j1) "
          + "fusion group(coeth-j1) parallel acc(loop gang vector)",
      "j1", "ki1sc", "ki1ec", null, null);
  assertNotNull(l);
  assertEquals(4, l.getMappings().size());

  ClawMapping map1 = l.getMappings().get(0);
  assertNotNull(map1);
  assertIdentityMappedVars(map1, "pduh2oc", "pduh2of");
  assertEquals(2, map1.getMappingVariables().size());
  assertMappingVar(map1, 0, "j1", "j1");
  assertMappingVar(map1, 1, "ki3sc", "j3");

  ClawMapping map2 = l.getMappings().get(1);
  assertNotNull(map2);
  assertIdentityMappedVars(map2, "pduco2", "pduo3", "palogp", "palogt",
      "podsc", "podsf", "podac", "podaf");
  assertEquals(2, map2.getMappingVariables().size());
  assertMappingVar(map2, 0, "j1", "j1");
  assertMappingVar(map2, 1, "ki3sc", "j3");

  ClawMapping map3 = l.getMappings().get(2);
  assertNotNull(map3);
  assertIdentityMappedVars(map3, "pbsff", "pbsfc");
  assertEquals(2, map3.getMappingVariables().size());
  assertMappingVar(map3, 0, "j1", "j1");
  assertMappingVar(map3, 1, "ki3sc", "j3");

  ClawMapping map4 = l.getMappings().get(3);
  assertNotNull(map4);
  assertIdentityMappedVars(map4, "pa1c", "pa1f", "pa2c", "pa2f", "pa3c",
      "pa3f");
  assertEquals(1, map4.getMappingVariables().size());
  assertMappingVar(map4, 0, "j1", "j1");

  assertTrue(l.hasClause(ClawClause.FUSION));
  assertTrue(l.hasClause(ClawClause.GROUP));
  assertEquals("coeth-j1", l.value(ClawClause.GROUP));
  assertTrue(l.hasClause(ClawClause.PARALLEL));
  assertTrue(l.hasClause(ClawClause.ACC));
  assertEquals("loop gang vector", l.value(ClawClause.ACC));

  // The target clause may appear before or after the fusion/group clauses
  l = analyzeValidClawLoopExtract(
      "claw loop-extract range(i=istart,iend) map(i:j) target(gpu) fusion "
          + "group(j1)", "i", "istart", "iend", null,
      Collections.singletonList(Target.GPU));
  assertSingleIToJMapping(l);

  l = analyzeValidClawLoopExtract(
      "claw loop-extract range(i=istart,iend) map(i:j) fusion group(j1) "
          + "target(gpu)", "i", "istart", "iend", null,
      Collections.singletonList(Target.GPU));
  assertSingleIToJMapping(l);

  // Invalid directives
  analyzeInvalidClawLanguage("claw loop-extract");
  analyzeInvalidClawLanguage("claw loop - extract ");
}

/**
 * Assert that a parsed loop-extract pragma holds exactly one mapping equal
 * to {@code map(i:j)}: one mapped variable "i" and one mapping variable "j",
 * each with the same argument-side and function-side name.
 *
 * @param pragma Parsed pragma to check.
 */
private static void assertSingleIToJMapping(ClawPragma pragma) {
  assertNotNull(pragma);
  assertEquals(1, pragma.getMappings().size());
  ClawMapping map = pragma.getMappings().get(0);
  assertNotNull(map);
  assertIdentityMappedVars(map, "i");
  assertFalse(map.getMappedVariables().get(0).hasDifferentMapping());
  assertEquals(1, map.getMappingVariables().size());
  assertMappingVar(map, 0, "j", "j");
  assertFalse(map.getMappingVariables().get(0).hasDifferentMapping());
}

/**
 * Assert that the mapped variables of a mapping are exactly the given names,
 * in order, each with identical argument-side and function-side names.
 *
 * @param map   Mapping to check.
 * @param names Expected mapped variable names.
 */
private static void assertIdentityMappedVars(ClawMapping map,
                                             String... names)
{
  assertEquals(names.length, map.getMappedVariables().size());
  for (int idx = 0; idx < names.length; ++idx) {
    assertEquals(names[idx], map.getMappedVariables().get(idx).getArgMapping());
    assertEquals(names[idx], map.getMappedVariables().get(idx).getFctMapping());
  }
}

/**
 * Assert the argument-side and function-side names of one mapping variable.
 *
 * @param map     Mapping to check.
 * @param index   Position of the mapping variable.
 * @param argName Expected argument-side name.
 * @param fctName Expected function-side name.
 */
private static void assertMappingVar(ClawMapping map, int index,
                                     String argName, String fctName)
{
  assertEquals(argName, map.getMappingVariables().get(index).getArgMapping());
  assertEquals(fctName, map.getMappingVariables().get(index).getFctMapping());
}
Aggregations