use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project systemml by apache.
the class PlanSelectionFuseCostBasedV2 method rCollectDependentRowOps.
/**
 * Recursively walks the inputs of {@code hop} and adds to {@code blacklist} the IDs of
 * hops that are row operations or depend on them.
 *
 * <p>A hop ID is added to the blacklist when
 * <ul>
 *   <li>its selected entry is a ROW entry, the requested {@code type} is ROW, and a row
 *       operation was already found on the path ({@code foundRowOp}),</li>
 *   <li>{@code isRowAggOp} accepts the hop, or the partition has extended materialization
 *       points, the memo table holds a ROW entry for the hop, and
 *       {@code hasOnlyExactMatches(.., ROW, CELL)} is false, or</li>
 *   <li>(bottom-up) its entry is a ROW entry and a referenced or implicitly fused input
 *       is already blacklisted.</li>
 * </ul>
 *
 * @param hop        current hop
 * @param memo       memo table used to look up the best entry of each hop
 * @param part       plan partition; hops outside of {@code part.getPartition()} are skipped
 * @param blacklist  output set of collected hop IDs
 * @param visited    set of already processed (hop ID, state code) pairs
 * @param type       template type requested by the parent, or {@code null} for none
 * @param foundRowOp whether a row operation was found on the path to this hop
 */
private static void rCollectDependentRowOps(Hop hop, CPlanMemoTable memo, PlanPartition part, HashSet<Long> blacklist, HashSet<Pair<Long, Integer>> visited, TemplateType type, boolean foundRowOp) {
    final long hopId = hop.getHopID();

    // the visited key distinguishes the same hop reached with a different
    // row-op flag or a different requested template type
    int stateCode = foundRowOp ? Short.MAX_VALUE : 0;
    if (type != null) {
        stateCode += type.ordinal() + 1;
    }
    final Pair<Long, Integer> visitKey = Pair.of(hopId, stateCode);

    // skip nodes that were already processed in this state
    if (visited.contains(visitKey)) {
        return;
    }
    // skip nodes that do not belong to the partition
    if (!part.getPartition().contains(hopId)) {
        return;
    }

    // top-down processing of the node itself
    final MemoTableEntry entry;
    if (type == null) {
        entry = memo.getBest(hopId);
    } else {
        entry = memo.getBest(hopId, type);
    }
    final boolean rowUnderRow = entry != null
        && entry.type == TemplateType.ROW
        && type == TemplateType.ROW;
    // guard against differences between the ROW and CELL plans of this hop
    final boolean plansDiffer = part.getMatPointsExt().length > 0
        && memo.contains(hopId, TemplateType.ROW)
        && !memo.hasOnlyExactMatches(hopId, TemplateType.ROW, TemplateType.CELL);

    boolean rowOpSeen = foundRowOp;
    if (rowUnderRow && rowOpSeen) {
        blacklist.add(hopId);
    }
    if (isRowAggOp(hop, rowUnderRow) || plansDiffer) {
        blacklist.add(hopId);
        rowOpSeen = true;
    }

    // descend into all inputs; the row-op flag is only propagated along
    // referenced or implicitly fused inputs
    for (int pos = 0; pos < hop.getInput().size(); pos++) {
        final boolean childBelowRowOp = rowOpSeen
            && entry != null
            && (entry.isPlanRef(pos) || isImplicitlyFused(hop, pos, entry.type));
        final TemplateType childType = (entry != null) ? entry.type : null;
        rCollectDependentRowOps(hop.getInput().get(pos), memo, part, blacklist, visited, childType, childBelowRowOp);
    }

    // bottom-up processing: a ROW node inherits the blacklisting of its fused inputs
    if (!blacklist.contains(hopId) && entry != null && entry.type == TemplateType.ROW) {
        for (int pos = 0; pos < hop.getInput().size(); pos++) {
            final boolean fusedInput = entry.isPlanRef(pos) || isImplicitlyFused(hop, pos, entry.type);
            if (fusedInput && blacklist.contains(hop.getInput().get(pos).getHopID())) {
                blacklist.add(hopId);
            }
        }
    }

    visited.add(visitKey);
}
use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project systemml by apache.
the class PlanSelectionFuseNoRedundancy method rSelectPlans.
/**
 * Recursively selects one plan per hop, starting at {@code current}.
 *
 * <p>For the current hop, memo entries that reference an input with more than one parent
 * are removed first, then entries subsumed by another entry are removed. From the
 * remaining entries one is chosen and registered via {@code addBestPlan}; the inputs are
 * then processed with the chosen entry's type as preference where the entry references
 * them.
 *
 * @param memo        memo table holding the candidate entries; pruned in place
 * @param current     hop to select a plan for
 * @param currentType template type preferred by the parent, or {@code null} for none
 */
private void rSelectPlans(CPlanMemoTable memo, Hop current, TemplateType currentType) {
    final long hopId = current.getHopID();
    if (isVisited(hopId, currentType)) {
        return;
    }

    // step 0: drop plans that reference an input with multiple consumers
    if (memo.contains(hopId)) {
        HashSet<MemoTableEntry> sharedRefPlans = new HashSet<>();
        for (MemoTableEntry entry : memo.get(hopId)) {
            for (int pos = 0; pos < 3; pos++) {
                if (entry.isPlanRef(pos) && current.getInput().get(pos).getParent().size() > 1) {
                    sharedRefPlans.add(entry);
                }
            }
        }
        memo.remove(current, sharedRefPlans);
    }

    // step 1: drop plans that are subsumed by another plan of this hop
    if (memo.contains(hopId)) {
        HashSet<MemoTableEntry> subsumedPlans = new HashSet<>();
        List<MemoTableEntry> candidates = memo.get(hopId);
        for (MemoTableEntry outer : candidates) {
            for (MemoTableEntry inner : candidates) {
                if (outer != inner && outer.subsumes(inner)) {
                    subsumedPlans.add(inner);
                }
            }
        }
        memo.remove(current, subsumedPlans);
    }

    // step 2: pick the plan for the current path
    MemoTableEntry best = null;
    if (memo.contains(hopId)) {
        List<MemoTableEntry> remaining = memo.get(hopId);
        if (currentType != null) {
            // only entries of the preferred type or CELL qualify; the score is lower
            // (better) for the preferred type and for more plan references
            best = remaining.stream()
                .filter(p -> p.type == currentType || p.type == TemplateType.CELL)
                .min(Comparator.comparing(p -> 7 - ((p.type == currentType) ? 4 : 0) - p.countPlanRefs()))
                .orElse(null);
        } else {
            best = remaining.stream()
                .filter(MemoTableEntry::isValid)
                .min(new BasicPlanComparator())
                .orElse(null);
        }
        addBestPlan(hopId, best);
    }

    // step 3: recurse into the inputs, passing the selected type along plan references
    for (int pos = 0; pos < current.getInput().size(); pos++) {
        TemplateType childPreference = null;
        if (best != null && best.isPlanRef(pos)) {
            childPreference = best.type;
        }
        rSelectPlans(memo, current.getInput().get(pos), childPreference);
    }

    setVisited(hopId, currentType);
}
use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project systemml by apache.
the class TemplateCell method rConstructCplan.
/**
 * Recursively constructs the CNode for {@code hop} and registers it in {@code tmp}
 * under the hop's ID.
 *
 * <p>Inputs that are referenced by the hop's memo entry are constructed recursively;
 * all other inputs become data nodes and are added to {@code inHops}. If the best
 * entry of the hop itself is of type ROW or OUTER, the hop is registered as a data
 * node and not descended into.
 *
 * @param hop             hop to construct the CNode for
 * @param memo            memo table providing the best entry per hop
 * @param tmp             map from hop ID to constructed CNode; also used for memoization
 * @param inHops          output set of hops that act as inputs of the fused operator
 * @param compileLiterals forwarded to {@code TemplateUtils.createCNodeData}
 */
protected void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, boolean compileLiterals) {
    final long hopId = hop.getHopID();

    // nothing to do if this hop was constructed before
    if (tmp.containsKey(hopId)) {
        return;
    }

    final MemoTableEntry entry = memo.getBest(hopId, TemplateType.CELL);

    // a hop whose best entry is a ROW or OUTER plan is consumed as plain input data
    if (entry != null && entry.type.isIn(TemplateType.ROW, TemplateType.OUTER)) {
        tmp.put(hopId, TemplateUtils.createCNodeData(hop, compileLiterals));
        inHops.add(hop);
        return;
    }

    // construct the required inputs first
    for (int pos = 0; pos < hop.getInput().size(); pos++) {
        final Hop child = hop.getInput().get(pos);

        final boolean fusedChild = entry != null
            && entry.isPlanRef(pos)
            && !(child instanceof DataOp)
            && (entry.type != TemplateType.MAGG || memo.contains(child.getHopID(), TemplateType.CELL));
        if (fusedChild) {
            rConstructCplan(child, memo, tmp, inHops, compileLiterals);
            continue;
        }

        // left input of a matrix multiply: skip the transpose and construct its input
        final boolean skipLeftTranspose = entry != null
            && (entry.type == TemplateType.MAGG || entry.type == TemplateType.CELL)
            && HopRewriteUtils.isMatrixMultiply(hop)
            && pos == 0;
        if (skipLeftTranspose) {
            rConstructCplan(child.getInput().get(0), memo, tmp, inHops, compileLiterals);
            continue;
        }

        // everything else is an input of the fused operator
        tmp.put(child.getHopID(), TemplateUtils.createCNodeData(child, compileLiterals));
        inHops.add(child);
    }

    // construct the cnode of the current hop
    CNode result = null;
    if (hop instanceof UnaryOp) {
        final Hop in0 = hop.getInput().get(0);
        final CNode operand = TemplateUtils.wrapLookupIfNecessary(tmp.get(in0.getHopID()), in0);
        final UnaryType opType = UnaryType.valueOf(((UnaryOp) hop).getOp().name());
        result = new CNodeUnary(operand, opType);
    } else if (hop instanceof BinaryOp) {
        final Hop lhsHop = hop.getInput().get(0);
        final Hop rhsHop = hop.getInput().get(1);
        // lookups are added where required
        final CNode lhs = TemplateUtils.wrapLookupIfNecessary(tmp.get(lhsHop.getHopID()), lhsHop);
        final CNode rhs = TemplateUtils.wrapLookupIfNecessary(tmp.get(rhsHop.getHopID()), rhsHop);
        final BinType opType = BinType.valueOf(((BinaryOp) hop).getOp().name());
        result = new CNodeBinary(lhs, rhs, opType);
    } else if (hop instanceof TernaryOp) {
        // all three operands get lookups where required
        final CNode[] operands = new CNode[3];
        for (int pos = 0; pos < 3; pos++) {
            final Hop in = hop.getInput().get(pos);
            operands[pos] = TemplateUtils.wrapLookupIfNecessary(tmp.get(in.getHopID()), in);
        }
        // the primitive operation is derived from the name of the hop's operation
        final TernaryType opType = TernaryType.valueOf(((TernaryOp) hop).getOp().name());
        result = new CNodeTernary(operands[0], operands[1], operands[2], opType);
    } else if (hop instanceof ParameterizedBuiltinOp) {
        final ParameterizedBuiltinOp pbOp = (ParameterizedBuiltinOp) hop;
        final CNode target = TemplateUtils.wrapLookupIfNecessary(tmp.get(pbOp.getTargetHop().getHopID()), hop.getInput().get(0));
        final CNode pattern = tmp.get(pbOp.getParameterHop("pattern").getHopID());
        final CNode replacement = tmp.get(pbOp.getParameterHop("replacement").getHopID());
        // a literal NaN pattern selects the dedicated NaN replacement primitive
        final boolean nanPattern = pattern.isLiteral() && pattern.getVarname().equals("Double.NaN");
        result = new CNodeTernary(target, pattern, replacement, nanPattern ? TernaryType.REPLACE_NAN : TernaryType.REPLACE);
    } else if (hop instanceof IndexingOp) {
        final Hop source = hop.getInput().get(0);
        final CNode data = tmp.get(source.getHopID());
        final CNode dim2Literal = TemplateUtils.createCNodeData(new LiteralOp(source.getDim2()), true);
        final CNode fifthInput = TemplateUtils.createCNodeData(hop.getInput().get(4), true);
        result = new CNodeTernary(data, dim2Literal, fifthInput, TernaryType.LOOKUP_RC1);
    } else if (HopRewriteUtils.isTransposeOperation(hop)) {
        result = TemplateUtils.skipTranspose(tmp.get(hopId), hop, tmp, compileLiterals);
        // correct the indexing types of existing lookups unless a parent is a matrix multiply
        if (!HopRewriteUtils.containsOp(hop.getParent(), AggBinaryOp.class)) {
            TemplateUtils.rFlipVectorLookups(result);
        }
        // keep the transpose input registered as input hop
        final Hop transposed = hop.getInput().get(0);
        if (result instanceof CNodeData && !inHops.contains(transposed)) {
            inHops.add(transposed);
        }
    } else if (hop instanceof AggUnaryOp) {
        // the aggregation itself is handled by the template implementation (note: ^2 of
        // SUM_SQ is not compiled into the operator to simplify detecting single operators)
        result = tmp.get(hop.getInput().get(0).getHopID());
    } else if (hop instanceof AggBinaryOp) {
        final Hop left = hop.getInput().get(0);
        final Hop right = hop.getInput().get(1);
        // t(X)%*%X -> sum(X^2) and t(X)%*%Y -> sum(X*Y)
        if (HopRewriteUtils.isTransposeOfItself(left, right)) {
            CNode squared = tmp.get(right.getHopID());
            if (TemplateUtils.isColVector(squared)) {
                squared = new CNodeUnary(squared, UnaryType.LOOKUP_R);
            }
            result = new CNodeUnary(squared, UnaryType.POW2);
        } else {
            CNode lhs = TemplateUtils.skipTranspose(tmp.get(left.getHopID()), left, tmp, compileLiterals);
            if (lhs instanceof CNodeData && !inHops.contains(left.getInput().get(0))) {
                inHops.add(left.getInput().get(0));
            }
            if (TemplateUtils.isColVector(lhs)) {
                lhs = new CNodeUnary(lhs, UnaryType.LOOKUP_R);
            }
            CNode rhs = tmp.get(right.getHopID());
            if (TemplateUtils.isColVector(rhs)) {
                rhs = new CNodeUnary(rhs, UnaryType.LOOKUP_R);
            }
            result = new CNodeBinary(lhs, rhs, BinType.MULT);
        }
    }

    tmp.put(hopId, result);
}
use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project systemml by apache.
the class TemplateOuterProduct method rConstructCplan.
/**
 * Recursively constructs the CNode for {@code hop} of an outer-product plan and
 * registers it in {@code tmp} under the hop's ID.
 *
 * <p>Inputs referenced by the hop's memo entry are constructed recursively; all other
 * inputs become data nodes and are added to {@code inHops}. If the memo table has no
 * OUTER/CELL entry for the hop, every input is treated as a data node instead of
 * failing with a {@code NullPointerException}.
 *
 * <p>Hops that play a special role are recorded in {@code inHops2} under the keys
 * {@code "_X"}, {@code "_U"} and {@code "_V"}.
 *
 * @param hop             hop to construct the CNode for
 * @param memo            memo table providing the best entry per hop
 * @param tmp             map from hop ID to constructed CNode; also used for memoization
 * @param inHops          output set of hops that act as inputs of the fused operator
 * @param inHops2         output map of named special inputs
 * @param compileLiterals forwarded to {@code TemplateUtils.createCNodeData} and
 *                        {@code TemplateUtils.skipTranspose}
 */
private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) {
    // memoization for common subexpression elimination and to avoid redundant work
    if (tmp.containsKey(hop.getHopID())) {
        return;
    }

    // recursively process required children
    MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.OUTER, TemplateType.CELL);
    for (int i = 0; i < hop.getInput().size(); i++) {
        Hop c = hop.getInput().get(i);
        // getBest may yield no entry; without an entry nothing is referenced
        if (me != null && me.isPlanRef(i)) {
            rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
        } else {
            CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
            tmp.put(c.getHopID(), cdata);
            inHops.add(c);
        }
    }

    // construct cnode for current hop
    CNode out = null;
    if (hop instanceof UnaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        // Enum.valueOf expects the constant's identifier, i.e. name(), not toString()
        String primitiveOpName = ((UnaryOp) hop).getOp().name();
        out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
    } else if (hop instanceof BinaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        // Enum.valueOf expects the constant's identifier, i.e. name(), not toString()
        String primitiveOpName = ((BinaryOp) hop).getOp().name();
        if (HopRewriteUtils.isBinarySparseSafe(hop)) {
            // remember the matrix data input of a sparse-safe operation as "_X";
            // if both inputs qualify, the second one wins
            if (TemplateUtils.isMatrix(hop.getInput().get(0)) && cdata1 instanceof CNodeData) {
                inHops2.put("_X", hop.getInput().get(0));
            }
            if (TemplateUtils.isMatrix(hop.getInput().get(1)) && cdata2 instanceof CNodeData) {
                inHops2.put("_X", hop.getInput().get(1));
            }
        }
        // add lookups if required
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
        out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
    } else if (hop instanceof AggBinaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        // handle transpose in outer or final product
        cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals);
        cdata2 = TemplateUtils.skipTranspose(cdata2, hop.getInput().get(1), tmp, compileLiterals);
        if (HopRewriteUtils.isOuterProductLikeMM(hop)) {
            // outer product U%*%t(V), see open: keep U and V for later reference
            inHops2.put("_U", hop.getInput().get(0));
            if (HopRewriteUtils.isTransposeOperation(hop.getInput().get(1))) {
                inHops2.put("_V", hop.getInput().get(1).getInput().get(0));
            } else {
                inHops2.put("_V", hop.getInput().get(1));
            }
            out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
        } else {
            // final left/right matrix mult, see close: a scalar operand goes second
            if (cdata1.getDataType().isScalar()) {
                out = new CNodeBinary(cdata2, cdata1, BinType.VECT_MULT_ADD);
            } else {
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
            }
        }
    } else if (HopRewriteUtils.isTransposeOperation(hop)) {
        // the transpose is passed through to its input
        out = tmp.get(hop.getInput().get(0).getHopID());
    } else if (hop instanceof AggUnaryOp && ((AggUnaryOp) hop).getOp() == AggOp.SUM && ((AggUnaryOp) hop).getDirection() == Direction.RowCol) {
        // the full sum is passed through to its input
        out = tmp.get(hop.getInput().get(0).getHopID());
    }

    tmp.put(hop.getHopID(), out);
}
use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project systemml by apache.
the class TemplateRow method rConstructCplan.
/**
 * Recursively constructs the CNode for {@code hop} of a row-template plan and registers
 * it in {@code tmp} under the hop's ID.
 *
 * <p>Inputs referenced by the hop's memo entry are constructed recursively; all other
 * inputs (or all inputs if there is no entry) become data nodes and are added to
 * {@code inHops}. Hops that play a special role are recorded in {@code inHops2} under
 * the keys {@code "X"} and {@code "B1"}.
 *
 * <p>Primitive types are resolved via {@code Enum.valueOf} from the hop operation's
 * {@code name()} throughout, since {@code valueOf} matches on the constant identifier.
 *
 * @param hop             hop to construct the CNode for
 * @param memo            memo table providing the best entry per hop
 * @param tmp             map from hop ID to constructed CNode; also used for memoization
 * @param inHops          output set of hops that act as inputs of the fused operator
 * @param inHops2         output map of named special inputs
 * @param compileLiterals forwarded to {@code TemplateUtils.createCNodeData} and
 *                        {@code TemplateUtils.skipTranspose}
 * @throws RuntimeException if a matrix unary/binary operation is not supported, or if
 *                          no CNode could be constructed for the hop
 */
private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) {
    // memoization for common subexpression elimination and to avoid redundant work
    if (tmp.containsKey(hop.getHopID())) {
        return;
    }

    // recursively process required children
    MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.ROW, TemplateType.CELL);
    for (int i = 0; i < hop.getInput().size(); i++) {
        Hop c = hop.getInput().get(i);
        if (me != null && me.isPlanRef(i)) {
            rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
        } else {
            CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
            tmp.put(c.getHopID(), cdata);
            inHops.add(c);
        }
    }

    // construct cnode for current hop
    CNode out = null;
    if (hop instanceof AggUnaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        if (((AggUnaryOp) hop).getDirection() == Direction.Row && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG)) {
            if (hop.getInput().get(0).getDim2() == 1) {
                // single-column input: only a lookup is needed (or nothing for scalars)
                out = (cdata1.getDataType() == DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
            } else {
                String opcode = "ROW_" + ((AggUnaryOp) hop).getOp().name().toUpperCase() + "S";
                out = new CNodeUnary(cdata1, UnaryType.valueOf(opcode));
                if (cdata1 instanceof CNodeData && !inHops2.containsKey("X")) {
                    inHops2.put("X", hop.getInput().get(0));
                }
            }
        } else if (((AggUnaryOp) hop).getDirection() == Direction.Col && ((AggUnaryOp) hop).getOp() == AggOp.SUM) {
            // vector add without temporary copy
            if (cdata1 instanceof CNodeBinary && ((CNodeBinary) cdata1).getType().isVectorScalarPrimitive()) {
                out = new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1), ((CNodeBinary) cdata1).getType().getVectorAddPrimitive());
            } else {
                out = cdata1;
            }
        } else if (((AggUnaryOp) hop).getDirection() == Direction.RowCol && ((AggUnaryOp) hop).getOp() == AggOp.SUM) {
            out = (cdata1.getDataType().isMatrix()) ? new CNodeUnary(cdata1, UnaryType.ROW_SUMS) : cdata1;
        }
    } else if (hop instanceof AggBinaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        if (HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))) {
            // correct input under transpose
            cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals);
            inHops.remove(hop.getInput().get(0));
            if (cdata1 instanceof CNodeData) {
                inHops.add(hop.getInput().get(0).getInput().get(0));
            }
            // note: vectorMultAdd applicable to vector-scalar, and vector-vector
            if (hop.getInput().get(1).getDim2() == 1) {
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
            } else {
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_OUTERMULT_ADD);
                if (!inHops2.containsKey("B1")) {
                    // incl modification of X for consistency
                    if (cdata1 instanceof CNodeData) {
                        inHops2.put("X", hop.getInput().get(0).getInput().get(0));
                    }
                    inHops2.put("B1", hop.getInput().get(1));
                }
            }
            if (!inHops2.containsKey("X")) {
                inHops2.put("X", hop.getInput().get(0).getInput().get(0));
            }
        } else {
            if (hop.getInput().get(0).getDim2() == 1 && hop.getInput().get(1).getDim2() == 1) {
                // both inputs have a single column: plain multiply of the looked-up values
                out = new CNodeBinary(
                    (cdata1.getDataType() == DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP0),
                    (cdata2.getDataType() == DataType.SCALAR) ? cdata2 : new CNodeUnary(cdata2, UnaryType.LOOKUP0),
                    BinType.MULT);
            } else if (hop.getInput().get(1).getDim2() == 1) {
                out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
                inHops2.put("X", hop.getInput().get(0));
            } else {
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MATRIXMULT);
                inHops2.put("X", hop.getInput().get(0));
                inHops2.put("B1", hop.getInput().get(1));
            }
        }
    } else if (HopRewriteUtils.isTransposeOperation(hop)) {
        out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()), hop, tmp, compileLiterals);
        if (out instanceof CNodeData && !inHops.contains(hop.getInput().get(0))) {
            inHops.add(hop.getInput().get(0));
        }
    } else if (hop instanceof UnaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        // if one input is a matrix then we need to do vector by scalar operations
        if ((hop.getInput().get(0).getDim1() >= 1 && hop.getInput().get(0).getDim2() > 1)
            || (!hop.dimsKnown() && cdata1.getDataType() == DataType.MATRIX)) {
            if (HopRewriteUtils.isUnary(hop, SUPPORTED_VECT_UNARY)) {
                String opname = "VECT_" + ((UnaryOp) hop).getOp().name();
                out = new CNodeUnary(cdata1, UnaryType.valueOf(opname));
                if (cdata1 instanceof CNodeData && !inHops2.containsKey("X")) {
                    inHops2.put("X", hop.getInput().get(0));
                }
            } else {
                throw new RuntimeException("Unsupported unary matrix " + "operation: " + ((UnaryOp) hop).getOp().name());
            }
        } else {
            // general scalar case
            cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
            // name() is the identifier valueOf matches on (same as the vector case above)
            String primitiveOpName = ((UnaryOp) hop).getOp().name();
            out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
        }
    } else if (HopRewriteUtils.isBinary(hop, OpOp2.CBIND)) {
        // special case for cbind with zeros
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = null;
        if (HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1))) {
            cdata2 = TemplateUtils.createCNodeData(HopRewriteUtils.getDataGenOpConstantValue(hop.getInput().get(1)), true);
            // rm 0-matrix
            inHops.remove(hop.getInput().get(1));
        } else {
            cdata2 = tmp.get(hop.getInput().get(1).getHopID());
            cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
        }
        out = new CNodeBinary(cdata1, cdata2, BinType.VECT_CBIND);
        if (cdata1 instanceof CNodeData && !inHops2.containsKey("X")) {
            inHops2.put("X", hop.getInput().get(0));
        }
    } else if (hop instanceof BinaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        // if one input is a matrix then we need to do vector by scalar operations
        if ((hop.getInput().get(0).getDim1() >= 1 && hop.getInput().get(0).getDim2() > 1)
            || (hop.getInput().get(1).getDim1() >= 1 && hop.getInput().get(1).getDim2() > 1)
            || (!(hop.dimsKnown() && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown())
                // not a known vector output
                && (hop.getDim2() != 1)
                && (cdata1.getDataType().isMatrix() || cdata2.getDataType().isMatrix()))) {
            if (HopRewriteUtils.isBinary(hop, SUPPORTED_VECT_BINARY)) {
                if (TemplateUtils.isMatrix(cdata1) && (TemplateUtils.isMatrix(cdata2) || TemplateUtils.isRowVector(cdata2))) {
                    String opname = "VECT_" + ((BinaryOp) hop).getOp().name();
                    out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname));
                } else {
                    String opname = "VECT_" + ((BinaryOp) hop).getOp().name() + "_SCALAR";
                    if (TemplateUtils.isColVector(cdata1)) {
                        cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
                    }
                    if (TemplateUtils.isColVector(cdata2)) {
                        cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
                    }
                    out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname));
                }
                if (cdata1 instanceof CNodeData && !inHops2.containsKey("X") && !(cdata1.getDataType() == DataType.SCALAR)) {
                    inHops2.put("X", hop.getInput().get(0));
                }
            } else {
                throw new RuntimeException("Unsupported binary matrix " + "operation: " + ((BinaryOp) hop).getOp().name());
            }
        } else {
            // one input is a vector/scalar other is a scalar
            // name() is the identifier valueOf matches on (same as the vector case above)
            String primitiveOpName = ((BinaryOp) hop).getOp().name();
            if (TemplateUtils.isColVector(cdata1)) {
                cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
            }
            // vector or vector can be inferred from lhs
            if (TemplateUtils.isColVector(cdata2)
                || (TemplateUtils.isColVector(hop.getInput().get(0)) && cdata2 instanceof CNodeData && hop.getInput().get(1).getDataType().isMatrix())) {
                cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
            }
            out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
        }
    } else if (hop instanceof TernaryOp) {
        TernaryOp top = (TernaryOp) hop;
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
        // add lookups if required (only for the first and third operand)
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
        // construct ternary cnode, primitive operation derived from OpOp3
        out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().name()));
    } else if (HopRewriteUtils.isNary(hop, OpOpN.CBIND)) {
        CNode[] inputs = new CNode[hop.getInput().size()];
        for (int i = 0; i < hop.getInput().size(); i++) {
            Hop c = hop.getInput().get(i);
            CNode cdata = tmp.get(c.getHopID());
            if (TemplateUtils.isColVector(cdata) || TemplateUtils.isRowVector(cdata)) {
                cdata = TemplateUtils.wrapLookupIfNecessary(cdata, c);
            }
            inputs[i] = cdata;
            // the first input, if it is a data node, is recorded as "X"
            if (i == 0 && cdata instanceof CNodeData && !inHops2.containsKey("X")) {
                inHops2.put("X", c);
            }
        }
        out = new CNodeNary(inputs, NaryType.VECT_CBIND);
    } else if (hop instanceof ParameterizedBuiltinOp) {
        CNode cdata1 = tmp.get(((ParameterizedBuiltinOp) hop).getTargetHop().getHopID());
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        CNode cdata2 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("pattern").getHopID());
        CNode cdata3 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("replacement").getHopID());
        // a literal NaN pattern selects the dedicated NaN replacement primitive
        TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? TernaryType.REPLACE_NAN : TernaryType.REPLACE;
        out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
    } else if (hop instanceof IndexingOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        out = new CNodeTernary(
            cdata1,
            TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true),
            TemplateUtils.createCNodeData(hop.getInput().get(4), true),
            (hop.getDim2() != 1) ? TernaryType.LOOKUP_RVECT1 : TernaryType.LOOKUP_RC1);
    }

    // fail early if no branch produced a cnode for this hop
    if (out == null) {
        throw new RuntimeException(hop.getHopID() + " " + hop.getOpString());
    }

    // matrix outputs carry the dimensions of the hop
    if (out.getDataType().isMatrix()) {
        out.setNumRows(hop.getDim1());
        out.setNumCols(hop.getDim2());
    }

    tmp.put(hop.getHopID(), out);
}
Aggregations