Search in sources:

Example 1 with ClawMapping

Use of cx2x.translator.language.common.ClawMapping in project claw-compiler by C2SM-RCM.

Class LoopExtraction, method transform.

/**
   * Apply the transformation. A loop extraction is applied in the following
   * steps:
   * 1) Duplicate the function targeted by the transformation (new name, new
   *    function type hash, new symbol-table entries).
   * 2) Extract the loop body in the duplicated function and remove the loop.
   * 3) Wrap the function call with the extracted loop, redirect it to the
   *    duplicated function, and demote the mapped arguments/array references
   *    in the duplicated function body.
   * 4) Optional: add any additional transformation requested by the directive
   *    clauses (e.g. a LoopFusion) to the transformations' queue.
   *
   * @param xcodeml        The XcodeML on which the transformations are applied.
   * @param transformer    The transformer used to apply the transformations.
   * @param transformation Only for dependent transformation. The other
   *                       transformation part of the transformation. Not used
   *                       by this method.
   * @throws IllegalTransformationException if the transformation cannot be
   *                                        applied: a mapping has more
   *                                        dimensions than the mapped
   *                                        argument, or a mapped/mapping
   *                                        variable cannot be resolved.
   */
@Override
public void transform(XcodeProgram xcodeml, Transformer transformer, Transformation transformation) throws Exception {
    /*
     * DUPLICATE THE FUNCTION
     */
    // Duplicate function definition under a fresh name and function type hash
    XfunctionDefinition clonedFctDef = _fctDefToExtract.cloneNode();
    String newFctTypeHash = xcodeml.getTypeTable().generateFctTypeHash();
    String newFctName = clonedFctDef.getName().value() + ClawConstant.EXTRACTION_SUFFIX + transformer.getNextTransformationCounter();
    clonedFctDef.getName().setValue(newFctName);
    clonedFctDef.getName().setAttribute(Xattr.TYPE, newFctTypeHash);
    // Update the symbol table in the fct definition
    Xid fctId = clonedFctDef.getSymbolTable().get(_fctDefToExtract.getName().value());
    fctId.setType(newFctTypeHash);
    fctId.setName(newFctName);
    // Clone the original fctType in the type table under the new hash
    XfunctionType fctType = (XfunctionType) xcodeml.getTypeTable().get(_fctDefToExtract.getName().getAttribute(Xattr.TYPE));
    XfunctionType newFctType = fctType.cloneNode();
    newFctType.setType(newFctTypeHash);
    xcodeml.getTypeTable().add(newFctType);
    // If the fct is defined in the global symbol table, duplicate its entry
    Xid globalFctId = xcodeml.getGlobalSymbolsTable().get(_fctDefToExtract.getName().value());
    if (globalFctId != null) {
        Xid newFctId = globalFctId.cloneNode();
        newFctId.setType(newFctTypeHash);
        newFctId.setName(newFctName);
        xcodeml.getGlobalSymbolsTable().add(newFctId);
    }
    // Insert the duplicated function declaration
    _fctDefToExtract.insertAfter(clonedFctDef);
    // Find the loop that will be extracted
    Xnode loopInClonedFct = locateDoStatement(clonedFctDef);
    if (XmOption.isDebugOutput()) {
        System.out.println("loop-extract transformation: " + _claw.getPragma().value());
        System.out.println("  created subroutine: " + clonedFctDef.getName().value());
    }
    /*
     * REMOVE BODY FROM THE LOOP AND DELETE THE LOOP
     */
    // 1. append body into fct body after loop
    XnodeUtil.extractBody(loopInClonedFct);
    // 2. delete loop
    loopInClonedFct.delete();
    /*
     * ADAPT FUNCTION CALL AND DEMOTE ARRAY REFERENCES IN THE BODY
     * OF THE FUNCTION
     */
    // Wrap function call with loop
    Xnode extractedLoop = wrapCallWithLoop(xcodeml, _extractedLoop);
    if (XmOption.isDebugOutput()) {
        // Printed before the rename below, so this shows "old --> new"
        System.out.println("  call wrapped with loop: " + _fctCall.matchDirectDescendant(Xcode.NAME).value() + " --> " + clonedFctDef.getName().value());
    }
    // Change called fct name
    _fctCall.matchDirectDescendant(Xcode.NAME).setValue(newFctName);
    _fctCall.matchDirectDescendant(Xcode.NAME).setAttribute(Xattr.TYPE, newFctTypeHash);
    // Adapt function call parameters and function declaration
    XdeclTable fctDeclarations = clonedFctDef.getDeclarationTable();
    XsymbolTable fctSymbols = clonedFctDef.getSymbolTable();
    Utility.debug("  Start to apply mapping: " + _claw.getMappings().size());
    // Loop over mapping clauses
    for (ClawMapping mapping : _claw.getMappings()) {
        Utility.debug("Apply mapping (" + mapping.getMappedDimensions() + ") ");
        // Loop over mapped variables
        for (ClawMappingVar var : mapping.getMappedVariables()) {
            Utility.debug("  Var: " + var);
            Xnode argument = XnodeUtil.findArg(var.getArgMapping(), _fctCall);
            if (argument == null) {
                continue;
            }
            /* Case 1: Var --> ArrayRef
             * Var --> ArrayRef transformation
             * 1. Check that the variable used as array index exists in the
             *    current scope (XdeclTable). If so, get its type value. Create a
             *    Var element for the arrayIndex. Create the arrayIndex element
             *    with Var as child.
             *
             * 2. Get the reference type of the base variable.
             *    2.1 Create the varRef element with the type of base variable
             *    2.2 insert clone of base variable in varRef
             * 3. Create arrayRef element with varRef + arrayIndex
             */
            if (argument.opcode() == Xcode.VAR) {
                Object argType = xcodeml.getTypeTable().get(argument.getAttribute(Xattr.TYPE));
                // Only a basic (array) type entry can be indexed; report anything
                // else as an illegal transformation rather than failing on a cast.
                if (!(argType instanceof XbasicType)) {
                    throw new IllegalTransformationException("cannot apply mapping " + mapping.toString() + " to argument " + argument.value() + ": argument type is not an array type", _claw.getPragma().lineNo());
                }
                XbasicType type = (XbasicType) argType;
                // Demotion cannot be applied as type dimension is smaller
                if (type.getDimensions() < mapping.getMappedDimensions()) {
                    throw new IllegalTransformationException("mapping dimensions too big. Mapping " + mapping.toString() + " is wrong ...", _claw.getPragma().lineNo());
                }
                Xnode newArg = new Xnode(Xcode.FARRAYREF, xcodeml);
                newArg.setAttribute(Xattr.TYPE, type.getRef());
                Xnode varRef = new Xnode(Xcode.VARREF, xcodeml);
                varRef.setAttribute(Xattr.TYPE, argument.getAttribute(Xattr.TYPE));
                varRef.append(argument, true);
                newArg.append(varRef, false);
                //  create arrayIndex
                for (ClawMappingVar mappingVar : mapping.getMappingVariables()) {
                    Xnode arrayIndex = new Xnode(Xcode.ARRAYINDEX, xcodeml);
                    // Find the mapping var in the local table (fct scope)
                    Xdecl mappingVarDecl = _fctDef.getDeclarationTable().get(mappingVar.getArgMapping());
                    if (mappingVarDecl == null) {
                        throw new IllegalTransformationException("mapping variable " + mappingVar.getArgMapping() + " is not declared in the calling function", _claw.getPragma().lineNo());
                    }
                    // Add to arrayIndex
                    Xnode newMappingVar = new Xnode(Xcode.VAR, xcodeml);
                    newMappingVar.setAttribute(Xattr.SCLASS, Xscope.LOCAL.toString());
                    newMappingVar.setAttribute(Xattr.TYPE, mappingVarDecl.matchSeq(Xcode.NAME).getAttribute(Xattr.TYPE));
                    newMappingVar.setValue(mappingVarDecl.matchSeq(Xcode.NAME).value());
                    arrayIndex.append(newMappingVar, false);
                    newArg.append(arrayIndex, false);
                }
                argument.insertAfter(newArg);
                argument.delete();
            }
            // Case 2: ArrayRef (n arrayIndex) --> ArrayRef (n+m arrayIndex)
            // is not supported yet (TODO).

            // Change variable declaration in extracted fct
            Xdecl varDecl = fctDeclarations.get(var.getFctMapping());
            Xid id = fctSymbols.get(var.getFctMapping());
            if (varDecl == null) {
                throw new IllegalTransformationException("mapped variable " + var.getFctMapping() + " is not declared in function " + newFctName, _claw.getPragma().lineNo());
            }
            Object declType = xcodeml.getTypeTable().get(varDecl.matchSeq(Xcode.NAME).getAttribute(Xattr.TYPE));
            if (!(declType instanceof XbasicType)) {
                throw new IllegalTransformationException("mapped variable " + var.getFctMapping() + " in function " + newFctName + " is not an array", _claw.getPragma().lineNo());
            }
            XbasicType varDeclType = (XbasicType) declType;
            // Case 1: variable is demoted to scalar then take the ref type
            if (varDeclType.getDimensions() == mapping.getMappedDimensions()) {
                if (id == null) {
                    throw new IllegalTransformationException("mapped variable " + var.getFctMapping() + " is missing from the symbol table of function " + newFctName, _claw.getPragma().lineNo());
                }
                Xnode tempName = new Xnode(Xcode.NAME, xcodeml);
                tempName.setValue(var.getFctMapping());
                tempName.setAttribute(Xattr.TYPE, varDeclType.getRef());
                Xdecl newVarDecl = new Xdecl(new Xnode(Xcode.VARDECL, xcodeml).element());
                newVarDecl.append(tempName, false);
                fctDeclarations.replace(newVarDecl, var.getFctMapping());
                id.setType(varDeclType.getRef());
            }
            // Case 2: variable is not totally demoted, a new type would be
            // needed; not supported yet (TODO).
        }
    }
    // Adapt array reference in function body
    List<Xnode> arrayReferences = clonedFctDef.body().matchAll(Xcode.FARRAYREF);
    for (Xnode ref : arrayReferences) {
        Xnode baseVarRef = ref.matchSeq(Xcode.VARREF);
        if (baseVarRef == null || baseVarRef.child(0).opcode() != Xcode.VAR) {
            continue;
        }
        String mappedVar = ref.matchSeq(Xcode.VARREF, Xcode.VAR).value();
        if (!_fctMappingMap.containsKey(mappedVar)) {
            continue;
        }
        ClawMapping mapping = _fctMappingMap.get(mappedVar);
        boolean changeRef = true;
        int mappingIndex = 0;
        for (Xnode e : ref.children()) {
            if (e.opcode() != Xcode.ARRAYINDEX) {
                continue;
            }
            List<Xnode> children = e.children();
            if (children.size() > 0 && children.get(0).opcode() == Xcode.VAR) {
                String varName = e.matchSeq(Xcode.VAR).value();
                // The reference is only demoted when its variable indices match
                // the mapping variables in order; more variable indices than
                // mapping variables means it cannot be demoted by this mapping.
                if (mappingIndex < mapping.getMappingVariables().size() && varName.equals(mapping.getMappingVariables().get(mappingIndex).getFctMapping())) {
                    ++mappingIndex;
                } else {
                    changeRef = false;
                    break;
                }
            }
        }
        if (changeRef) {
            // TODO Var ref should be extracted only if the reference can be
            // totally demoted
            ref.insertBefore(ref.matchSeq(Xcode.VARREF, Xcode.VAR).cloneNode());
            ref.delete();
        }
    }
    // Generate accelerator pragmas if needed
    AcceleratorHelper.generateAdditionalDirectives(_claw, xcodeml, extractedLoop, extractedLoop);
    // TODO must be triggered by a clause
    //AcceleratorHelper.generateRoutineDirectives(_claw, xcodeml, clonedFctDef);
    // Add any additional transformation defined in the directive clauses
    TransformationHelper.generateAdditionalTransformation(_claw, xcodeml, transformer, extractedLoop);
    _claw.getPragma().delete();
    this.transformed();
}
Also used: ClawMappingVar (cx2x.translator.language.common.ClawMappingVar), IllegalTransformationException (cx2x.xcodeml.exception.IllegalTransformationException), ClawMapping (cx2x.translator.language.common.ClawMapping)

Example 2 with ClawMapping

Use of cx2x.translator.language.common.ClawMapping in project claw-compiler by C2SM-RCM.

Class ClawLanguageTest, method extractTest.

/**
 * Test various inputs for the CLAW loop-extract directive: range forms, single
 * and multiple map clauses (including an arg/fct split such as {@code ki3sc/j3}),
 * and the optional fusion, group, parallel, acc and target clauses.
 */
@Test
public void extractTest() {
    // Valid directives: a single map(i:j) clause with various range forms
    ClawLanguage l = analyzeValidClawLoopExtract("claw loop-extract range(i=istart,iend) map(i:j)", "i", "istart", "iend", null, null);
    assertSingleIdentityMapping(l, "i", "j");
    l = analyzeValidClawLoopExtract("claw loop-extract range(i=istart,iend,2) map(i:j)", "i", "istart", "iend", "2", null);
    assertSingleIdentityMapping(l, "i", "j");
    l = analyzeValidClawLoopExtract("claw loop-extract range(i=1,10) map(i:j)", "i", "1", "10", null, null);
    assertSingleIdentityMapping(l, "i", "j");
    l = analyzeValidClawLoopExtract("claw loop-extract range(i=1,10,2) map(i:j) parallel", "i", "1", "10", "2", null);
    assertSingleIdentityMapping(l, "i", "j");
    assertTrue(l.hasParallelClause());

    // Valid directives: fusion, group and acc clauses
    l = analyzeValidClawLoopExtract("claw loop-extract range(i=istart,iend) map(i:j) fusion", "i", "istart", "iend", null, null);
    assertSingleIdentityMapping(l, "i", "j");
    assertTrue(l.hasFusionClause());
    assertFalse(l.hasGroupClause());
    assertFalse(l.hasParallelClause());
    l = analyzeValidClawLoopExtract("claw loop-extract range(i=istart,iend) map(i:j) fusion group(j1)", "i", "istart", "iend", null, null);
    assertSingleIdentityMapping(l, "i", "j");
    assertTrue(l.hasFusionClause());
    assertTrue(l.hasGroupClause());
    assertEquals("j1", l.getGroupValue());
    l = analyzeValidClawLoopExtract("claw loop-extract range(i=istart,iend) map(i:j) fusion group(j1) " + "acc(loop gang vector)", "i", "istart", "iend", null, null);
    assertSingleIdentityMapping(l, "i", "j");
    assertTrue(l.hasFusionClause());
    assertTrue(l.hasGroupClause());
    assertTrue(l.hasAcceleratorClause());
    assertEquals("loop gang vector", l.getAcceleratorClauses());
    assertEquals("j1", l.getGroupValue());

    // Valid directive: several map clauses, some with an arg/fct split (ki3sc/j3)
    l = analyzeValidClawLoopExtract("claw loop-extract range(j1=ki1sc,ki1ec) " + "map(pduh2oc,pduh2of:j1,ki3sc/j3) " + "map(pduco2,pduo3,palogp,palogt,podsc,podsf,podac,podaf:j1,ki3sc/j3) " + "map(pbsff,pbsfc:j1,ki3sc/j3) map(pa1c,pa1f,pa2c,pa2f,pa3c,pa3f:j1) " + "fusion group(coeth-j1) parallel acc(loop gang vector)", "j1", "ki1sc", "ki1ec", null, null);
    assertNotNull(l);
    assertEquals(4, l.getMappings().size());

    ClawMapping map1 = l.getMappings().get(0);
    assertNotNull(map1);
    assertIdentityMappedVariables(map1, "pduh2oc", "pduh2of");
    assertEquals(2, map1.getMappingVariables().size());
    assertMappingVariable(map1, 0, "j1", "j1");
    assertMappingVariable(map1, 1, "ki3sc", "j3");

    ClawMapping map2 = l.getMappings().get(1);
    assertNotNull(map2);
    assertIdentityMappedVariables(map2, "pduco2", "pduo3", "palogp", "palogt", "podsc", "podsf", "podac", "podaf");
    assertEquals(2, map2.getMappingVariables().size());
    assertMappingVariable(map2, 0, "j1", "j1");
    assertMappingVariable(map2, 1, "ki3sc", "j3");

    ClawMapping map3 = l.getMappings().get(2);
    assertNotNull(map3);
    assertIdentityMappedVariables(map3, "pbsff", "pbsfc");
    assertEquals(2, map3.getMappingVariables().size());
    assertMappingVariable(map3, 0, "j1", "j1");
    assertMappingVariable(map3, 1, "ki3sc", "j3");

    ClawMapping map4 = l.getMappings().get(3);
    assertNotNull(map4);
    assertIdentityMappedVariables(map4, "pa1c", "pa1f", "pa2c", "pa2f", "pa3c", "pa3f");
    assertEquals(1, map4.getMappingVariables().size());
    assertMappingVariable(map4, 0, "j1", "j1");

    assertTrue(l.hasFusionClause());
    assertTrue(l.hasGroupClause());
    assertEquals("coeth-j1", l.getGroupValue());
    assertTrue(l.hasParallelClause());
    assertTrue(l.hasAcceleratorClause());
    assertEquals("loop gang vector", l.getAcceleratorClauses());

    // Valid directives: target clause before or after fusion/group
    analyzeValidClawLoopExtract("claw loop-extract range(i=istart,iend) map(i:j) target(gpu) fusion " + "group(j1)", "i", "istart", "iend", null, Collections.singletonList(Target.GPU));
    analyzeValidClawLoopExtract("claw loop-extract range(i=istart,iend) map(i:j) fusion group(j1) " + "target(gpu)", "i", "istart", "iend", null, Collections.singletonList(Target.GPU));

    // Invalid directives
    analyzeUnvalidClawLanguage("claw loop-extract");
    analyzeUnvalidClawLanguage("claw loop   -   extract ");
}

/**
 * Assert that {@code l} holds exactly one map clause of the form
 * {@code map(mapped:mapping)}, in which both variables have identical
 * argument and function mappings.
 */
private static void assertSingleIdentityMapping(ClawLanguage l, String mapped, String mapping) {
    assertNotNull(l);
    assertEquals(1, l.getMappings().size());
    assertNotNull(l.getMappings().get(0));
    ClawMapping map = l.getMappings().get(0);
    assertEquals(1, map.getMappedVariables().size());
    assertEquals(1, map.getMappingVariables().size());
    assertEquals(mapped, map.getMappedVariables().get(0).getArgMapping());
    assertEquals(mapped, map.getMappedVariables().get(0).getFctMapping());
    assertFalse(map.getMappedVariables().get(0).hasDifferentMapping());
    assertEquals(mapping, map.getMappingVariables().get(0).getArgMapping());
    assertEquals(mapping, map.getMappingVariables().get(0).getFctMapping());
    assertFalse(map.getMappingVariables().get(0).hasDifferentMapping());
}

/**
 * Assert that the mapped variables of {@code map} are exactly {@code names},
 * in order, each with identical argument and function mappings.
 */
private static void assertIdentityMappedVariables(ClawMapping map, String... names) {
    assertEquals(names.length, map.getMappedVariables().size());
    for (int i = 0; i < names.length; ++i) {
        assertEquals(names[i], map.getMappedVariables().get(i).getArgMapping());
        assertEquals(names[i], map.getMappedVariables().get(i).getFctMapping());
    }
}

/**
 * Assert the argument and function mapping of the mapping variable at
 * position {@code index} in {@code map}.
 */
private static void assertMappingVariable(ClawMapping map, int index, String argMapping, String fctMapping) {
    assertEquals(argMapping, map.getMappingVariables().get(index).getArgMapping());
    assertEquals(fctMapping, map.getMappingVariables().get(index).getFctMapping());
}
Also used: ClawMapping (cx2x.translator.language.common.ClawMapping), ClawLanguage (cx2x.translator.language.base.ClawLanguage), Test (org.junit.Test)

Aggregations

ClawMapping (cx2x.translator.language.common.ClawMapping): 2; ClawLanguage (cx2x.translator.language.base.ClawLanguage): 1; ClawMappingVar (cx2x.translator.language.common.ClawMappingVar): 1; IllegalTransformationException (cx2x.xcodeml.exception.IllegalTransformationException): 1; Test (org.junit.Test): 1