Search in sources:

Example 11 with CNodeData

Use of org.apache.sysml.hops.codegen.cplan.CNodeData in project systemml by Apache.

Class SpoofCompiler, method cleanupCPlans.

/**
 * Cleanup generated cplans in order to remove unnecessary inputs created
 * during incremental construction. This is important as it avoids unnecessary
 * redundant computation.
 *
 * <p>For each entry of {@code cplans}, this method skips entries with any null
 * input hop, applies {@code CPlanOpRewriter.simplifyCPlan} and
 * {@code CPlanCSERewriter.eliminateCommonSubexpressions}, and restricts the
 * input hops (order-preserving) to those still referenced by the rewritten
 * template. It then drops the plan from the result again if it is invalid or not
 * worth fusing. Reasons include: invalid rc1 indexing on the main input, a row
 * template without aggregation over a scalar output, a row template with
 * ncol &gt; block size that would run on Spark, single-/no-operation plans, or
 * empty plans. The input map {@code cplans} is only read, never structurally
 * modified.
 *
 * <p>NOTE(review): {@code rFindAndRemoveLookup*} and {@code renameInputs} are
 * still invoked on plans that were already removed from the result, and the
 * {@code Pair} objects inside {@code cplans} are not replaced by the rewritten
 * template.
 *
 * @param memo memoization table (used to resolve the hop of a row template by key)
 * @param cplans set of cplans, keyed by a hop ID that is resolvable via
 *        {@code memo.getHopRefs()}
 * @return a new map with the cleaned-up cplans that passed all validity checks
 */
private static HashMap<Long, Pair<Hop[], CNodeTpl>> cleanupCPlans(CPlanMemoTable memo, HashMap<Long, Pair<Hop[], CNodeTpl>> cplans) {
    HashMap<Long, Pair<Hop[], CNodeTpl>> cplans2 = new HashMap<>();
    CPlanOpRewriter rewriter = new CPlanOpRewriter();
    CPlanCSERewriter cse = new CPlanCSERewriter();
    for (Entry<Long, Pair<Hop[], CNodeTpl>> e : cplans.entrySet()) {
        CNodeTpl tpl = e.getValue().getValue();
        Hop[] inHops = e.getValue().getKey();
        // remove invalid plans with null inputs
        // (such plans are never added to cplans2, nor rewritten or renamed)
        if (Arrays.stream(inHops).anyMatch(h -> (h == null)))
            continue;
        // perform simplifications and cse rewrites
        tpl = rewriter.simplifyCPlan(tpl);
        tpl = cse.eliminateCommonSubexpressions(tpl);
        // update input hops (order-preserving)
        // note: the p != null test is redundant, since plans with null inputs
        // were already skipped above
        HashSet<Long> inputHopIDs = tpl.getInputHopIDs(false);
        inHops = Arrays.stream(inHops).filter(p -> p != null && inputHopIDs.contains(p.getHopID())).toArray(Hop[]::new);
        // optimistically register the plan; the checks below remove it again if invalid
        cplans2.put(e.getKey(), new Pair<>(inHops, tpl));
        // remove invalid plans with column indexing on main input
        if (tpl instanceof CNodeCell || tpl instanceof CNodeRow) {
            // the first template input is treated as the main input
            CNodeData in1 = (CNodeData) tpl.getInput().get(0);
            // rc1 lookups are included for cell templates but not for row templates
            boolean inclRC1 = !(tpl instanceof CNodeRow);
            if (rHasLookupRC1(tpl.getOutput(), in1, inclRC1) || isLookupRC1(tpl.getOutput(), in1, inclRC1)) {
                cplans2.remove(e.getKey());
                if (LOG.isTraceEnabled())
                    LOG.trace("Removed cplan due to invalid rc1 indexing on main input.");
            }
        } else if (tpl instanceof CNodeMultiAgg) {
            CNodeData in1 = (CNodeData) tpl.getInput().get(0);
            // note: no break after removal, so every offending output repeats the
            // (no-op after the first) removal and the trace message
            for (CNode output : ((CNodeMultiAgg) tpl).getOutputs()) if (rHasLookupRC1(output, in1, true) || isLookupRC1(output, in1, true)) {
                cplans2.remove(e.getKey());
                if (LOG.isTraceEnabled())
                    LOG.trace("Removed cplan due to invalid rc1 indexing on main input.");
            }
        }
        // remove invalid lookups on main input (all templates)
        // note: this runs even if the plan was already removed above
        CNodeData in1 = (CNodeData) tpl.getInput().get(0);
        if (tpl instanceof CNodeMultiAgg)
            rFindAndRemoveLookupMultiAgg((CNodeMultiAgg) tpl, in1);
        else
            rFindAndRemoveLookup(tpl.getOutput(), in1, !(tpl instanceof CNodeRow));
        // remove invalid row templates (e.g., unsatisfied blocksize constraint)
        if (tpl instanceof CNodeRow) {
            // check for invalid row cplan over column vector
            if (((CNodeRow) tpl).getRowType() == RowType.NO_AGG && tpl.getOutput().getDataType().isScalar()) {
                cplans2.remove(e.getKey());
                if (LOG.isTraceEnabled())
                    LOG.trace("Removed invalid row cplan w/o agg on column vector.");
            } else if (OptimizerUtils.isSparkExecutionMode()) {
                // the plan is removed only if it would run on Spark (forced SPARK
                // platform, or memory estimate above the local budget) AND the output
                // or any matrix input has more columns than its block size (for a
                // transposed output hop, its rows are compared instead)
                Hop hop = memo.getHopRefs().get(e.getKey());
                boolean isSpark = DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK || OptimizerUtils.getTotalMemEstimate(inHops, hop, true) > OptimizerUtils.getLocalMemBudget();
                boolean invalidNcol = hop.getDataType().isMatrix() && (HopRewriteUtils.isTransposeOperation(hop) ? hop.getDim1() > hop.getRowsInBlock() : hop.getDim2() > hop.getColsInBlock());
                for (Hop in : inHops) invalidNcol |= (in.getDataType().isMatrix() && in.getDim2() > in.getColsInBlock());
                if (isSpark && invalidNcol) {
                    cplans2.remove(e.getKey());
                    if (LOG.isTraceEnabled())
                        LOG.trace("Removed invalid row cplan w/ ncol>ncolpb.");
                }
            }
        }
        // remove cplan w/ single op and w/o agg
        // (cell templates with NO_AGG, row templates with NO_AGG, NO_AGG_B1 or
        // ROW_AGG that consist of a single operation, and any template without
        // operations; presumably fusion brings no benefit there)
        if ((tpl instanceof CNodeCell && ((CNodeCell) tpl).getCellType() == CellType.NO_AGG && TemplateUtils.hasSingleOperation(tpl)) || (tpl instanceof CNodeRow && (((CNodeRow) tpl).getRowType() == RowType.NO_AGG || ((CNodeRow) tpl).getRowType() == RowType.NO_AGG_B1 || ((CNodeRow) tpl).getRowType() == RowType.ROW_AGG) && TemplateUtils.hasSingleOperation(tpl)) || TemplateUtils.hasNoOperation(tpl)) {
            cplans2.remove(e.getKey());
            if (LOG.isTraceEnabled())
                LOG.trace("Removed cplan with single operation.");
        }
        // remove cplan if empty (its output is a plain data node, i.e., no computation)
        if (tpl.getOutput() instanceof CNodeData) {
            cplans2.remove(e.getKey());
            if (LOG.isTraceEnabled())
                LOG.trace("Removed empty cplan.");
        }
        // rename inputs (for codegen and plan caching)
        // note: applied to every processed template, including removed ones
        tpl.renameInputs();
    }
    return cplans2;
}
Also used : CNodeData(org.apache.sysml.hops.codegen.cplan.CNodeData) CNodeTpl(org.apache.sysml.hops.codegen.cplan.CNodeTpl) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) Hop(org.apache.sysml.hops.Hop) CPlanCSERewriter(org.apache.sysml.hops.codegen.template.CPlanCSERewriter) CNodeCell(org.apache.sysml.hops.codegen.cplan.CNodeCell) CNode(org.apache.sysml.hops.codegen.cplan.CNode) CNodeRow(org.apache.sysml.hops.codegen.cplan.CNodeRow) CNodeMultiAgg(org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg) CPlanOpRewriter(org.apache.sysml.hops.codegen.template.CPlanOpRewriter) Pair(org.apache.sysml.runtime.matrix.data.Pair)

Example 12 with CNodeData

Use of org.apache.sysml.hops.codegen.cplan.CNodeData in project systemml by Apache.

Class CPlanComparisonTest, method testEqualLiteral.

/**
 * Checks that two scalar data cnodes created over equal literal hops compare
 * equal (both {@code hashCode} and {@code equals}) in three states: as
 * constructed, after both are marked as literals, and after strict-equality
 * mode is enabled on both.
 */
@Test
public void testEqualLiteral() {
    final CNodeData left = new CNodeData(new LiteralOp(7), 0, 0, DataType.SCALAR);
    final CNodeData right = new CNodeData(new LiteralOp(7), 0, 0, DataType.SCALAR);
    // hash codes first, then equality, exactly as in every stage below
    final Runnable expectEqual = () -> {
        Assert.assertEquals(left.hashCode(), right.hashCode());
        Assert.assertEquals(left, right);
    };

    // stage 1: freshly constructed nodes
    expectEqual.run();

    // stage 2: both nodes flagged as literals
    left.setLiteral(true);
    right.setLiteral(true);
    expectEqual.run();

    // stage 3: strict equality enabled on both nodes
    left.setStrictEquals(true);
    right.setStrictEquals(true);
    expectEqual.run();
}
Also used : CNodeData(org.apache.sysml.hops.codegen.cplan.CNodeData) LiteralOp(org.apache.sysml.hops.LiteralOp) Test(org.junit.Test)

Example 13 with CNodeData

Use of org.apache.sysml.hops.codegen.cplan.CNodeData in project incubator-systemml by Apache.

Class CPlanOpRewriter, method rFindAndRemoveBinaryMS.

/**
 * Recursively rewrites the inputs of {@code node} in place. Every child that is a
 * binary operation of the given {@code type}, whose right operand is the literal
 * {@code lit} and whose left operand is a data node for the same hop as
 * {@code mainInput}, is replaced by a new literal data node holding
 * {@code replace}. Replaced children are not descended into; every other child is
 * searched recursively.
 *
 * @param node      root of the subtree whose inputs are rewritten
 * @param mainInput data node whose hop ID identifies the main input
 * @param type      binary operation type to match
 * @param lit       literal name expected as the right operand
 * @param replace   literal value of the replacement node
 */
private static void rFindAndRemoveBinaryMS(CNode node, CNodeData mainInput, BinType type, String lit, String replace) {
    for (int idx = 0; idx < node.getInput().size(); idx++) {
        CNode child = node.getInput().get(idx);
        boolean matchesPattern = TemplateUtils.isBinary(child, type)
            && child.getInput().get(1).isLiteral()
            && child.getInput().get(1).getVarname().equals(lit)
            && child.getInput().get(0) instanceof CNodeData
            && ((CNodeData) child.getInput().get(0)).getHopID() == mainInput.getHopID();
        if (!matchesPattern) {
            // no match at this level: keep searching inside this child
            rFindAndRemoveBinaryMS(child, mainInput, type, lit, replace);
            continue;
        }
        // substitute the matched operation with the literal replacement
        CNodeData literalNode = new CNodeData(new LiteralOp(replace));
        literalNode.setLiteral(true);
        node.getInput().set(idx, literalNode);
    }
}
Also used : CNode(org.apache.sysml.hops.codegen.cplan.CNode) CNodeData(org.apache.sysml.hops.codegen.cplan.CNodeData) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 14 with CNodeData

Use of org.apache.sysml.hops.codegen.cplan.CNodeData in project incubator-systemml by Apache.

Class TemplateMultiAgg, method constructCplan.

/**
 * Builds a multi-aggregation cplan for the roots referenced by the best MAGG memo
 * entry of {@code hop}. Each root is constructed with the cell-template
 * construction. Non-literal inputs are ordered so that the sparse-safe shared
 * input and larger inputs come first. Root outputs that are plain data nodes
 * other than the main input are wrapped in row / row-column lookups.
 *
 * @param hop             hop whose MAGG memo entry defines the roots
 * @param memo            memo table holding plan references and hop references
 * @param compileLiterals whether literals are compiled into the generated code
 * @return the ordered input hops paired with the constructed template
 */
@Override
public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
    // collect every root referenced by the best multi-aggregation entry
    final MemoTableEntry maggEntry = memo.getBest(hop.getHopID(), TemplateType.MAGG);
    final ArrayList<Hop> rootHops = new ArrayList<>();
    for (int pos = 0; pos < 3; pos++) {
        if (!maggEntry.isPlanRef(pos))
            continue;
        rootHops.add(memo._hopRefs.get(maggEntry.input(pos)));
    }
    Hop.resetVisitStatus(rootHops);

    // build the cnode DAG of each root, reusing the cell-template construction
    final HashSet<Hop> dataHops = new HashSet<>();
    final HashMap<Long, CNode> cnodes = new HashMap<>();
    for (Hop root : rootHops) {
        super.rConstructCplan(root, memo, cnodes, dataHops, compileLiterals);
    }
    Hop.resetVisitStatus(rootHops);

    // drop scalar literals and order the remaining inputs (matrices/vectors first,
    // then by number of cells and sparsity) so that a sparse input can serve as
    // the main input without conversion
    final Hop sharedInput = getSparseSafeSharedInput(rootHops, dataHops);
    final Hop[] orderedInputs = dataHops.stream()
        .filter(h -> !(h.getDataType().isScalar() && cnodes.get(h.getHopID()).isLiteral()))
        .sorted(new HopInputComparator(sharedInput))
        .toArray(Hop[]::new);

    // template inputs follow the chosen input order
    final ArrayList<CNode> tplInputs = new ArrayList<>();
    for (Hop inputHop : orderedInputs) {
        tplInputs.add(cnodes.get(inputHop.getHopID()));
    }

    // one output and aggregation type per root
    final ArrayList<CNode> tplOutputs = new ArrayList<>();
    final ArrayList<AggOp> tplAggOps = new ArrayList<>();
    for (Hop root : rootHops) {
        CNode output = cnodes.get(root.getHopID());
        // sideways data inputs (data nodes other than the main input) need indexing
        boolean isSidewaysData = output instanceof CNodeData
            && ((CNodeData) tplInputs.get(0)).getHopID() != ((CNodeData) output).getHopID();
        if (isSidewaysData) {
            UnaryType lookupType = (rootHops.get(0).getDim2() == 1) ? UnaryType.LOOKUP_R : UnaryType.LOOKUP_RC;
            output = new CNodeUnary(output, lookupType);
        }
        tplOutputs.add(output);
        tplAggOps.add(TemplateUtils.getAggOp(root));
    }

    // assemble the template node
    final CNodeMultiAgg tpl = new CNodeMultiAgg(tplInputs, tplOutputs);
    tpl.setAggOps(tplAggOps);
    tpl.setSparseSafe(isSparseSafe(rootHops, orderedInputs[0], tpl.getOutputs(), tpl.getAggOps(), true));
    tpl.setRootNodes(rootHops);
    tpl.setBeginLine(hop.getBeginLine());
    return new Pair<>(orderedInputs, tpl);
}
Also used : CNodeData(org.apache.sysml.hops.codegen.cplan.CNodeData) HashMap(java.util.HashMap) AggOp(org.apache.sysml.hops.Hop.AggOp) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) CNode(org.apache.sysml.hops.codegen.cplan.CNode) CNodeUnary(org.apache.sysml.hops.codegen.cplan.CNodeUnary) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) CNodeMultiAgg(org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg) HashSet(java.util.HashSet) Pair(org.apache.sysml.runtime.matrix.data.Pair)

Example 15 with CNodeData

Use of org.apache.sysml.hops.codegen.cplan.CNodeData in project incubator-systemml by Apache.

Class TemplateRow, method rConstructCplan.

/**
 * Recursively constructs the row-template cnode for {@code hop} and its
 * plan-referenced children, memoizing results in {@code tmp} by hop ID.
 *
 * <p>Children marked as plan references by the best ROW/CELL memo entry are
 * constructed recursively. All other children become data leaves
 * ({@code TemplateUtils.createCNodeData}) and are added to {@code inHops}.
 * Depending on the hop type, a matching unary/binary/ternary/n-ary cnode is then
 * built and stored in {@code tmp} under {@code hop}'s ID. Matrix-typed outputs
 * also receive the hop's dimensions.
 *
 * @param hop             hop to construct a cnode for
 * @param memo            memo table used to look up the best ROW/CELL entry
 * @param tmp             hop ID to constructed cnode (memoization and result)
 * @param inHops          collected leaf input hops; entries may be added or
 *                        removed, e.g., when a transpose is skipped
 * @param inHops2         named special inputs; this method fills the keys
 *                        {@code "X"} and {@code "B1"} -- presumably the main
 *                        input and a side matrix input of the row template
 *                        (confirm at the consumer)
 * @param compileLiterals passed through to data-leaf creation and transpose skipping
 * @throws RuntimeException if a unary/binary matrix operation is unsupported, or
 *                          if no branch produced a cnode for the hop
 */
private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) {
    // memoization for common subexpression elimination and to avoid redundant work
    if (tmp.containsKey(hop.getHopID()))
        return;
    // recursively process required children
    // (me may be null, in which case every child becomes a data leaf)
    MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.ROW, TemplateType.CELL);
    for (int i = 0; i < hop.getInput().size(); i++) {
        Hop c = hop.getInput().get(i);
        if (me != null && me.isPlanRef(i))
            rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
        else {
            CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
            tmp.put(c.getHopID(), cdata);
            inHops.add(c);
        }
    }
    // construct cnode for current hop
    CNode out = null;
    if (hop instanceof AggUnaryOp) {
        // unary aggregates: row aggregates, column sums, and full sums
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        if (((AggUnaryOp) hop).getDirection() == Direction.Row && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG)) {
            // single-column input: the row aggregate of one value is a plain lookup
            if (hop.getInput().get(0).getDim2() == 1)
                out = (cdata1.getDataType() == DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
            else {
                // e.g., ROW_SUMS, ROW_MAXS derived from the aggregation op name
                String opcode = "ROW_" + ((AggUnaryOp) hop).getOp().name().toUpperCase() + "S";
                out = new CNodeUnary(cdata1, UnaryType.valueOf(opcode));
                if (cdata1 instanceof CNodeData && !inHops2.containsKey("X"))
                    inHops2.put("X", hop.getInput().get(0));
            }
        } else if (((AggUnaryOp) hop).getDirection() == Direction.Col && ((AggUnaryOp) hop).getOp() == AggOp.SUM) {
            // vector add without temporary copy
            if (cdata1 instanceof CNodeBinary && ((CNodeBinary) cdata1).getType().isVectorScalarPrimitive())
                out = new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1), ((CNodeBinary) cdata1).getType().getVectorAddPrimitive());
            else
                out = cdata1;
        } else if (((AggUnaryOp) hop).getDirection() == Direction.RowCol && ((AggUnaryOp) hop).getOp() == AggOp.SUM) {
            // full sum: row sums for matrix inputs, scalars pass through
            out = (cdata1.getDataType().isMatrix()) ? new CNodeUnary(cdata1, UnaryType.ROW_SUMS) : cdata1;
        }
    } else if (hop instanceof AggBinaryOp) {
        // matrix multiplication
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        if (HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))) {
            // correct input under transpose
            // (the transposed hop is replaced by its own input in inHops)
            cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals);
            inHops.remove(hop.getInput().get(0));
            if (cdata1 instanceof CNodeData)
                inHops.add(hop.getInput().get(0).getInput().get(0));
            // note: vectorMultAdd applicable to vector-scalar, and vector-vector
            if (hop.getInput().get(1).getDim2() == 1)
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
            else {
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_OUTERMULT_ADD);
                if (!inHops2.containsKey("B1")) {
                    // incl modification of X for consistency
                    if (cdata1 instanceof CNodeData)
                        inHops2.put("X", hop.getInput().get(0).getInput().get(0));
                    inHops2.put("B1", hop.getInput().get(1));
                }
            }
            if (!inHops2.containsKey("X"))
                inHops2.put("X", hop.getInput().get(0).getInput().get(0));
        } else {
            // both inputs single-column: scalar product of the looked-up values
            if (hop.getInput().get(0).getDim2() == 1 && hop.getInput().get(1).getDim2() == 1)
                out = new CNodeBinary((cdata1.getDataType() == DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP0), (cdata2.getDataType() == DataType.SCALAR) ? cdata2 : new CNodeUnary(cdata2, UnaryType.LOOKUP0), BinType.MULT);
            else if (hop.getInput().get(1).getDim2() == 1) {
                // single-column right-hand side: DOT_PRODUCT
                // note: unlike other branches, X is overwritten unconditionally here
                out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
                inHops2.put("X", hop.getInput().get(0));
            } else {
                // general case: VECT_MATRIXMULT with the right-hand side registered as B1
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MATRIXMULT);
                inHops2.put("X", hop.getInput().get(0));
                inHops2.put("B1", hop.getInput().get(1));
            }
        }
    } else if (HopRewriteUtils.isTransposeOperation(hop)) {
        // NOTE(review): tmp.get(hop.getHopID()) is always null at this point (the
        // method returns early if it were present and only stores it at the end);
        // this appears to rely on skipTranspose ignoring its first argument for
        // transpose hops -- confirm in TemplateUtils.skipTranspose
        out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()), hop, tmp, compileLiterals);
        if (out instanceof CNodeData && !inHops.contains(hop.getInput().get(0)))
            inHops.add(hop.getInput().get(0));
    } else if (hop instanceof UnaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        // if one input is a matrix then we need to do vector by scalar operations
        if (hop.getInput().get(0).getDim1() >= 1 && hop.getInput().get(0).getDim2() > 1 || (!hop.dimsKnown() && cdata1.getDataType() == DataType.MATRIX)) {
            if (HopRewriteUtils.isUnary(hop, SUPPORTED_VECT_UNARY)) {
                String opname = "VECT_" + ((UnaryOp) hop).getOp().name();
                out = new CNodeUnary(cdata1, UnaryType.valueOf(opname));
                if (cdata1 instanceof CNodeData && !inHops2.containsKey("X"))
                    inHops2.put("X", hop.getInput().get(0));
            } else
                throw new RuntimeException("Unsupported unary matrix " + "operation: " + ((UnaryOp) hop).getOp().name());
        } else // general scalar case
        {
            cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
            String primitiveOpName = ((UnaryOp) hop).getOp().toString();
            out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
        }
    } else if (HopRewriteUtils.isBinary(hop, OpOp2.CBIND)) {
        // special case for cbind with zeros
        // (a constant-valued datagen right input is replaced by a literal of its value)
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = null;
        if (HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1))) {
            cdata2 = TemplateUtils.createCNodeData(HopRewriteUtils.getDataGenOpConstantValue(hop.getInput().get(1)), true);
            // rm 0-matrix
            inHops.remove(hop.getInput().get(1));
        } else {
            cdata2 = tmp.get(hop.getInput().get(1).getHopID());
            cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
        }
        out = new CNodeBinary(cdata1, cdata2, BinType.VECT_CBIND);
        if (cdata1 instanceof CNodeData && !inHops2.containsKey("X"))
            inHops2.put("X", hop.getInput().get(0));
    } else if (hop instanceof BinaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        // if one input is a matrix then we need to do vector by scalar operations
        if ((hop.getInput().get(0).getDim1() >= 1 && hop.getInput().get(0).getDim2() > 1) || (hop.getInput().get(1).getDim1() >= 1 && hop.getInput().get(1).getDim2() > 1) || (!(hop.dimsKnown() && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown()) && // not a known vector output
        (hop.getDim2() != 1) && (cdata1.getDataType().isMatrix() || cdata2.getDataType().isMatrix()))) {
            if (HopRewriteUtils.isBinary(hop, SUPPORTED_VECT_BINARY)) {
                // matrix op matrix/row-vector uses VECT_<OP>; otherwise VECT_<OP>_SCALAR
                // with column-vector operands wrapped in row lookups
                if (TemplateUtils.isMatrix(cdata1) && (TemplateUtils.isMatrix(cdata2) || TemplateUtils.isRowVector(cdata2))) {
                    String opname = "VECT_" + ((BinaryOp) hop).getOp().name();
                    out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname));
                } else {
                    String opname = "VECT_" + ((BinaryOp) hop).getOp().name() + "_SCALAR";
                    if (TemplateUtils.isColVector(cdata1))
                        cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
                    if (TemplateUtils.isColVector(cdata2))
                        cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
                    out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname));
                }
                if (cdata1 instanceof CNodeData && !inHops2.containsKey("X") && !(cdata1.getDataType() == DataType.SCALAR)) {
                    inHops2.put("X", hop.getInput().get(0));
                }
            } else
                throw new RuntimeException("Unsupported binary matrix " + "operation: " + ((BinaryOp) hop).getOp().name());
        } else // one input is a vector/scalar other is a scalar
        {
            String primitiveOpName = ((BinaryOp) hop).getOp().toString();
            if (TemplateUtils.isColVector(cdata1))
                cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
            if (// vector or vector can be inferred from lhs
            TemplateUtils.isColVector(cdata2) || (TemplateUtils.isColVector(hop.getInput().get(0)) && cdata2 instanceof CNodeData && hop.getInput().get(1).getDataType().isMatrix()))
                cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
            out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
        }
    } else if (hop instanceof TernaryOp) {
        TernaryOp top = (TernaryOp) hop;
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
        // add lookups if required
        // note: only the first and third operands are wrapped; the second is used as-is
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
        // construct ternary cnode, primitive operation derived from OpOp3
        out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().toString()));
    } else if (HopRewriteUtils.isNary(hop, OpOpN.CBIND)) {
        // n-ary cbind: vector inputs are wrapped in lookups if necessary; only the
        // first input may be registered as X
        CNode[] inputs = new CNode[hop.getInput().size()];
        for (int i = 0; i < hop.getInput().size(); i++) {
            Hop c = hop.getInput().get(i);
            CNode cdata = tmp.get(c.getHopID());
            if (TemplateUtils.isColVector(cdata) || TemplateUtils.isRowVector(cdata))
                cdata = TemplateUtils.wrapLookupIfNecessary(cdata, c);
            inputs[i] = cdata;
            if (i == 0 && cdata instanceof CNodeData && !inHops2.containsKey("X"))
                inHops2.put("X", c);
        }
        out = new CNodeNary(inputs, NaryType.VECT_CBIND);
    } else if (hop instanceof ParameterizedBuiltinOp) {
        // replace(target, pattern, replacement); a NaN pattern selects REPLACE_NAN
        // NOTE(review): cdata1 comes from getTargetHop() but is wrapped using
        // hop.getInput().get(0); this assumes the target is input 0 -- confirm
        CNode cdata1 = tmp.get(((ParameterizedBuiltinOp) hop).getTargetHop().getHopID());
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        CNode cdata2 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("pattern").getHopID());
        CNode cdata3 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("replacement").getHopID());
        TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? TernaryType.REPLACE_NAN : TernaryType.REPLACE;
        out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
    } else if (hop instanceof IndexingOp) {
        // right indexing: ternary lookup using the input's column count and input 4
        // (presumably the column lower bound -- confirm against IndexingOp's input
        // layout); LOOKUP_RVECT1 unless the output has exactly one column
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        out = new CNodeTernary(cdata1, TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), TemplateUtils.createCNodeData(hop.getInput().get(4), true), (hop.getDim2() != 1) ? TernaryType.LOOKUP_RVECT1 : TernaryType.LOOKUP_RC1);
    }
    // no branch produced a cnode (unsupported hop or unsupported aggregate variant)
    if (out == null) {
        throw new RuntimeException(hop.getHopID() + " " + hop.getOpString());
    }
    // propagate the hop's dimensions to matrix-typed outputs
    if (out.getDataType().isMatrix()) {
        out.setNumRows(hop.getDim1());
        out.setNumCols(hop.getDim2());
    }
    tmp.put(hop.getHopID(), out);
}
Also used : TernaryType(org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType) CNodeData(org.apache.sysml.hops.codegen.cplan.CNodeData) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) CNodeTernary(org.apache.sysml.hops.codegen.cplan.CNodeTernary) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) CNodeBinary(org.apache.sysml.hops.codegen.cplan.CNodeBinary) CNodeNary(org.apache.sysml.hops.codegen.cplan.CNodeNary) TernaryOp(org.apache.sysml.hops.TernaryOp) CNode(org.apache.sysml.hops.codegen.cplan.CNode) ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) CNodeUnary(org.apache.sysml.hops.codegen.cplan.CNodeUnary) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) IndexingOp(org.apache.sysml.hops.IndexingOp) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Aggregations

CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData)23 CNode (org.apache.sysml.hops.codegen.cplan.CNode)17 Hop (org.apache.sysml.hops.Hop)13 LiteralOp (org.apache.sysml.hops.LiteralOp)12 CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary)8 MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry)8 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)6 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)6 BinaryOp (org.apache.sysml.hops.BinaryOp)6 UnaryOp (org.apache.sysml.hops.UnaryOp)6 CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary)6 Test (org.junit.Test)6 HashMap (java.util.HashMap)5 CNodeMultiAgg (org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg)5 Pair (org.apache.sysml.runtime.matrix.data.Pair)5 IndexingOp (org.apache.sysml.hops.IndexingOp)4 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)4 TernaryOp (org.apache.sysml.hops.TernaryOp)4 CNodeTernary (org.apache.sysml.hops.codegen.cplan.CNodeTernary)4 TernaryType (org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType)4