use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.
the class TemplateCell method rConstructCplan.
/**
 * Recursively constructs the CNode for {@code hop} (and for the inputs that belong to the
 * same fused plan) and records it in {@code tmp} under the hop's ID.
 *
 * <p>Inputs that are not followed recursively are materialized as data nodes via
 * {@code TemplateUtils.createCNodeData} and collected in {@code inHops}. After the inputs
 * are handled, a CNode is built depending on the concrete hop class (unary, binary,
 * ternary, parameterized builtin, indexing, transpose, unary/binary aggregate). If none of
 * the handled hop classes matches, {@code null} is stored for this hop.
 *
 * @param hop             hop to construct a CNode for
 * @param memo            memo table queried for the best CELL entry of each hop
 * @param tmp             map from hop ID to the CNode constructed so far; doubles as the
 *                        memoization structure for this recursion
 * @param inHops          set collecting the hops that became data inputs of the plan
 * @param compileLiterals forwarded unchanged to {@code TemplateUtils.createCNodeData} and
 *                        {@code TemplateUtils.skipTranspose}
 */
protected void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, boolean compileLiterals) {
// memoization for common subexpression elimination and to avoid redundant work
if (tmp.containsKey(hop.getHopID()))
return;
MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.CELL);
// if the best entry is a ROW or OUTER template, do not descend: the hop itself
// becomes a data input of this plan
if (me != null && me.type.isIn(TemplateType.ROW, TemplateType.OUTER)) {
CNodeData cdata = TemplateUtils.createCNodeData(hop, compileLiterals);
tmp.put(hop.getHopID(), cdata);
inHops.add(hop);
return;
}
// recursively process required children; every input ends up in exactly one of
// three cases: (a) recurse into the input, (b) recurse into the input's first
// input (first operand of a matrix multiply), (c) register the input as data node
for (int i = 0; i < hop.getInput().size(); i++) {
Hop c = hop.getInput().get(i);
// (a) input is referenced by the plan, is no DataOp, and - for MAGG entries -
// has its own CELL entry in the memo table
if (me != null && me.isPlanRef(i) && !(c instanceof DataOp) && (me.type != TemplateType.MAGG || memo.contains(c.getHopID(), TemplateType.CELL)))
rConstructCplan(c, memo, tmp, inHops, compileLiterals);
else if (me != null && (me.type == TemplateType.MAGG || me.type == TemplateType.CELL) && HopRewriteUtils.isMatrixMultiply(hop) && // skip transpose
i == 0)
// (b) NOTE(review): assumes the first matrix-multiply operand has at least
// one input (presumably a transpose) - confirm against callers
rConstructCplan(c.getInput().get(0), memo, tmp, inHops, compileLiterals);
else {
// (c) materialize the input as plain data node
CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
tmp.put(c.getHopID(), cdata);
inHops.add(c);
}
}
// construct cnode for current hop
CNode out = null;
if (hop instanceof UnaryOp) {
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
// the unary type is resolved by the enum constant name of the hop operation
String primitiveOpName = ((UnaryOp) hop).getOp().name();
out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
} else if (hop instanceof BinaryOp) {
BinaryOp bop = (BinaryOp) hop;
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
String primitiveOpName = bop.getOp().name();
// add lookups if required
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
// construct binary cnode
out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
} else if (hop instanceof TernaryOp) {
TernaryOp top = (TernaryOp) hop;
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
// add lookups if required
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
// construct ternary cnode, primitive operation derived from OpOp3
out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().name()));
} else if (hop instanceof ParameterizedBuiltinOp) {
// built from the target hop and the "pattern" / "replacement" parameter hops
// NOTE(review): cdata1 is looked up via the target hop, but the lookup wrapper is
// given input 0 - confirm that the target is always the first input
CNode cdata1 = tmp.get(((ParameterizedBuiltinOp) hop).getTargetHop().getHopID());
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
CNode cdata2 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("pattern").getHopID());
CNode cdata3 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("replacement").getHopID());
// a literal pattern named "Double.NaN" selects the dedicated NaN replacement type
TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? TernaryType.REPLACE_NAN : TernaryType.REPLACE;
out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
} else if (hop instanceof IndexingOp) {
// LOOKUP_RC1 over the indexed input, a literal holding that input's second
// dimension, and the hop's fifth input (index 4)
// NOTE(review): input 4 is presumably one of the column bounds - confirm
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
out = new CNodeTernary(cdata1, TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), TemplateUtils.createCNodeData(hop.getInput().get(4), true), TernaryType.LOOKUP_RC1);
} else if (HopRewriteUtils.isTransposeOperation(hop)) {
// NOTE(review): tmp held no entry for this hop at method entry (see the
// memoization guard), so the first argument is presumably null here - confirm
// that skipTranspose tolerates that
out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()), hop, tmp, compileLiterals);
// correct indexing types of existing lookups
if (!HopRewriteUtils.containsOp(hop.getParent(), AggBinaryOp.class))
TemplateUtils.rFlipVectorLookups(out);
// maintain input hops
if (out instanceof CNodeData && !inHops.contains(hop.getInput().get(0)))
inHops.add(hop.getInput().get(0));
} else if (hop instanceof AggUnaryOp) {
// aggregation handled in template implementation (note: we do not compile
// ^2 of SUM_SQ into the operator to simplify the detection of single operators)
out = tmp.get(hop.getInput().get(0).getHopID());
} else if (hop instanceof AggBinaryOp) {
// (1) t(X)%*%X -> sum(X^2) and t(X) %*% Y -> sum(X*Y)
if (HopRewriteUtils.isTransposeOfItself(hop.getInput().get(0), hop.getInput().get(1))) {
// first case: square the second input (with a row lookup for column vectors)
CNode cdata1 = tmp.get(hop.getInput().get(1).getHopID());
if (TemplateUtils.isColVector(cdata1))
cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
out = new CNodeUnary(cdata1, UnaryType.POW2);
} else {
// second case: multiply the (transpose-skipped) first input with the second input
CNode cdata1 = TemplateUtils.skipTranspose(tmp.get(hop.getInput().get(0).getHopID()), hop.getInput().get(0), tmp, compileLiterals);
// if the skipped transpose resolved to a data node, register the transpose's
// own input as plan input
if (cdata1 instanceof CNodeData && !inHops.contains(hop.getInput().get(0).getInput().get(0)))
inHops.add(hop.getInput().get(0).getInput().get(0));
if (TemplateUtils.isColVector(cdata1))
cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
if (TemplateUtils.isColVector(cdata2))
cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
out = new CNodeBinary(cdata1, cdata2, BinType.MULT);
}
}
// register the result (null if none of the hop classes above matched)
tmp.put(hop.getHopID(), out);
}
use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.
the class TemplateOuterProduct method constructCplan.
/**
 * Builds the outer-product template for the plan rooted at {@code hop}.
 *
 * <p>The recursive construction fills a hop-ID-to-CNode map, the set of data input hops,
 * and the named inputs {@code "_X"}, {@code "_U"} and {@code "_V"}. The input hops are then
 * reordered so that X, U and V (in this order) come first. A named input that was not
 * found is still placed at its position as {@code null}: it is skipped when collecting the
 * CNode inputs, but it remains part of the returned hop array.
 *
 * @param hop             root hop of the plan
 * @param memo            memo table handed to the recursive construction
 * @param compileLiterals handed to the recursive construction unchanged
 * @return pair of the ordered input hops and the constructed template node
 */
@Override
public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
    // state populated by the recursive plan construction
    HashSet<Hop> dataInputs = new HashSet<>();
    HashMap<String, Hop> namedInputs = new HashMap<>();
    HashMap<Long, CNode> cnodeByHopId = new HashMap<>();

    hop.resetVisitStatus();
    rConstructCplan(hop, memo, cnodeByHopId, dataInputs, namedInputs, compileLiterals);
    hop.resetVisitStatus();

    // reorder inputs: pushing V, then U, then X to the front yields X, U, V, <rest>
    Hop X = namedInputs.get("_X");
    Hop U = namedInputs.get("_U");
    Hop V = namedInputs.get("_V");
    LinkedList<Hop> sinHops = new LinkedList<>(dataInputs);
    for (Hop front : new Hop[] { V, U, X }) {
        sinHops.remove(front);
        sinHops.addFirst(front);
    }

    // collect the cnodes of all present inputs in their final order
    ArrayList<CNode> inputs = new ArrayList<>();
    for (Hop in : sinHops) {
        if (in == null) {
            continue;
        }
        inputs.add(cnodeByHopId.get(in.getHopID()));
    }

    // assemble the template node
    CNode output = cnodeByHopId.get(hop.getHopID());
    CNodeOuterProduct tpl = new CNodeOuterProduct(inputs, output);
    tpl.setOutProdType(TemplateUtils.getOuterProductType(X, U, V, hop));
    boolean rootIsTranspose = HopRewriteUtils.isTransposeOperation(hop);
    tpl.setTransposeOutput(!rootIsTranspose && tpl.getOutProdType() == OutProdType.LEFT_OUTER_PRODUCT);
    tpl.setBeginLine(hop.getBeginLine());

    Hop[] orderedHops = sinHops.toArray(new Hop[0]);
    return new Pair<>(orderedHops, tpl);
}
use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.
the class TemplateOuterProduct method rConstructCplan.
/**
 * Recursively constructs the CNode for {@code hop} within an outer-product plan and
 * records it in {@code tmp} under the hop's ID.
 *
 * <p>Inputs referenced by the best OUTER/CELL memo entry are processed recursively; all
 * other inputs become data nodes and are added to {@code inHops}. While building binary
 * and matrix-multiply nodes, the hops playing the roles {@code "_X"}, {@code "_U"} and
 * {@code "_V"} are remembered in {@code inHops2}. If the hop matches none of the handled
 * cases, {@code null} is stored for it.
 *
 * <p>NOTE(review): the memo entry is dereferenced without a null check; this relies on
 * {@code memo.getBest} always returning an entry for hops reached here - confirm.
 *
 * @param hop             hop to construct a CNode for
 * @param memo            memo table queried for the best OUTER or CELL entry
 * @param tmp             map from hop ID to constructed CNode, also used for memoization
 * @param inHops          set collecting hops that became data inputs
 * @param inHops2         map collecting the named inputs {@code "_X"}, {@code "_U"}, {@code "_V"}
 * @param compileLiterals forwarded unchanged to {@code TemplateUtils.createCNodeData} and
 *                        {@code TemplateUtils.skipTranspose}
 */
private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) {
    // memoization for common subexpression elimination and to avoid redundant work
    if (tmp.containsKey(hop.getHopID()))
        return;

    // recursively process required children
    MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.OUTER, TemplateType.CELL);
    for (int i = 0; i < hop.getInput().size(); i++) {
        Hop c = hop.getInput().get(i);
        if (me.isPlanRef(i))
            rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
        else {
            // input is not part of the plan: materialize it as data node
            CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
            tmp.put(c.getHopID(), cdata);
            inHops.add(c);
        }
    }

    // construct cnode for current hop
    CNode out = null;
    if (hop instanceof UnaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        // valueOf requires the exact constant identifier, hence name() and not toString()
        String primitiveOpName = ((UnaryOp) hop).getOp().name();
        out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
    } else if (hop instanceof BinaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        // valueOf requires the exact constant identifier, hence name() and not toString()
        String primitiveOpName = ((BinaryOp) hop).getOp().name();
        // for sparse-safe binary operations, a matrix operand that is a plain data node
        // is remembered as X (the second operand wins if both qualify)
        if (HopRewriteUtils.isBinarySparseSafe(hop)) {
            if (TemplateUtils.isMatrix(hop.getInput().get(0)) && cdata1 instanceof CNodeData)
                inHops2.put("_X", hop.getInput().get(0));
            if (TemplateUtils.isMatrix(hop.getInput().get(1)) && cdata2 instanceof CNodeData)
                inHops2.put("_X", hop.getInput().get(1));
        }
        // add lookups if required
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
        out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
    } else if (hop instanceof AggBinaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        // handle transpose in outer or final product
        cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals);
        cdata2 = TemplateUtils.skipTranspose(cdata2, hop.getInput().get(1), tmp, compileLiterals);
        // outer product U%*%t(V), see open
        if (HopRewriteUtils.isOuterProductLikeMM(hop)) {
            // keep U and V for later reference (V without its transpose, if any)
            inHops2.put("_U", hop.getInput().get(0));
            if (HopRewriteUtils.isTransposeOperation(hop.getInput().get(1)))
                inHops2.put("_V", hop.getInput().get(1).getInput().get(0));
            else
                inHops2.put("_V", hop.getInput().get(1));
            out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
        } else {
            // final left/right matrix mult, see close; a scalar first operand is
            // moved to the second position
            if (cdata1.getDataType().isScalar())
                out = new CNodeBinary(cdata2, cdata1, BinType.VECT_MULT_ADD);
            else
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
        }
    } else if (HopRewriteUtils.isTransposeOperation(hop)) {
        // transpose is passed through: reuse the cnode of its input
        out = tmp.get(hop.getInput().get(0).getHopID());
    } else if (hop instanceof AggUnaryOp && ((AggUnaryOp) hop).getOp() == AggOp.SUM && ((AggUnaryOp) hop).getDirection() == Direction.RowCol) {
        // full sum is passed through: reuse the cnode of its input
        out = tmp.get(hop.getInput().get(0).getHopID());
    }

    // register the result (null if none of the cases above matched)
    tmp.put(hop.getHopID(), out);
}
use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.
the class TemplateUtils method getRowType.
/**
 * Derives the row type of a template from the dimensions of its output and inputs.
 *
 * <p>The first input is the main input X and the optional second input is the side input
 * B1. The checks are evaluated in order and the first match wins, so their sequence matters.
 * If X is absent (no inputs given, or a {@code null} first input), has unknown
 * dimensions, or has the same size as the output, the result is {@code NO_AGG}.
 *
 * @param output output hop of the template
 * @param inputs input hops; element 0 is X, element 1 (optional) is B1
 * @return the row type matching the dimension relationship
 * @throws RuntimeException if no row type matches
 */
public static RowType getRowType(Hop output, Hop... inputs) {
    // guard the varargs access: a call without inputs is treated like a null main input
    Hop X = (inputs != null && inputs.length > 0) ? inputs[0] : null;
    Hop B1 = (inputs != null && inputs.length > 1) ? inputs[1] : null;
    if ((X != null && HopRewriteUtils.isEqualSize(output, X)) || X == null || !X.dimsKnown())
        return RowType.NO_AGG;
    // output has rows of X and columns of B1, or is a column-range indexing; excluded
    // when the output is a matrix multiply whose first input is the transpose of X
    else if (((B1 != null && output.getDim1() == X.getDim1() && output.getDim2() == B1.getDim2()) || (output instanceof IndexingOp && HopRewriteUtils.isColumnRangeIndexing((IndexingOp) output))) && !(output instanceof AggBinaryOp && HopRewriteUtils.isTransposeOfItself(output.getInput().get(0), X)))
        return RowType.NO_AGG_B1;
    // one output value per row of X (same exclusion as above)
    else if (output.getDim1() == X.getDim1() && (output.getDim2() == 1) && !(output instanceof AggBinaryOp && HopRewriteUtils.isTransposeOfItself(output.getInput().get(0), X)))
        return RowType.ROW_AGG;
    // unary aggregate over rows and columns
    else if (output instanceof AggUnaryOp && ((AggUnaryOp) output).getDirection() == Direction.RowCol)
        return RowType.FULL_AGG;
    // column vector with one entry per column of X
    else if (output.getDim1() == X.getDim2() && output.getDim2() == 1)
        return RowType.COL_AGG_T;
    // row vector with one entry per column of X
    else if (output.getDim1() == 1 && output.getDim2() == X.getDim2())
        return RowType.COL_AGG;
    // ncol(X) x ncol(B1)
    else if (B1 != null && output.getDim1() == X.getDim2() && output.getDim2() == B1.getDim2())
        return RowType.COL_AGG_B1_T;
    // ncol(B1) x ncol(X)
    else if (B1 != null && output.getDim1() == B1.getDim2() && output.getDim2() == X.getDim2())
        return RowType.COL_AGG_B1;
    // row vector with one entry per column of B1
    else if (B1 != null && output.getDim1() == 1 && B1.getDim2() == output.getDim2())
        return RowType.COL_AGG_B1R;
    // same number of rows as X but a different number of columns
    else if (X.getDim1() == output.getDim1() && X.getDim2() != output.getDim2())
        return RowType.NO_AGG_CONST;
    // single row with a number of columns different from X
    else if (output.getDim1() == 1 && X.getDim2() != output.getDim2())
        return RowType.COL_AGG_CONST;
    else
        throw new RuntimeException("Unknown row type for hop " + output.getHopID() + ".");
}
use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.
the class TemplateUtils method isBinaryMatrixColVector.
/**
 * Checks whether {@code hop} is a binary operation over two matrices with known
 * dimensions where the left operand has more columns than the right operand.
 *
 * @param hop hop to inspect
 * @return true if all of the above conditions hold, false otherwise
 */
public static boolean isBinaryMatrixColVector(Hop hop) {
    if (hop instanceof BinaryOp) {
        Hop lhs = hop.getInput().get(0);
        Hop rhs = hop.getInput().get(1);
        // both operands need known dimensions ...
        if (!lhs.dimsKnown() || !rhs.dimsKnown()) {
            return false;
        }
        // ... and both need to be matrices
        if (!lhs.getDataType().isMatrix() || !rhs.getDataType().isMatrix()) {
            return false;
        }
        return lhs.getDim2() > rhs.getDim2();
    }
    return false;
}
Aggregations