Use of org.apache.sysml.hops.DataGenOp in the Apache SystemML project.
Class HopRewriteUtils, method copyDataGenOp.
/**
 * Builds a new RAND data generator from an existing one, with the value
 * range linearly transformed: {@code newMin = min * scale + shift} and
 * {@code newMax = max * scale + shift}.
 * <p>
 * The rows, cols, pdf, lambda, sparsity and seed inputs of the source
 * operator are reused as-is (the same hop instances are wired into the
 * new operator); only min and max are replaced by fresh literals.
 * <p>
 * Min and max of the source operator must be literal ops. If either one
 * is not, no operator is created and {@code null} is returned, so callers
 * have to handle that case.
 *
 * @param inputGen input data gen op
 * @param scale the scale
 * @param shift the shift
 * @return the new data gen op, or {@code null} if min or max is not a literal op
 */
public static DataGenOp copyDataGenOp(DataGenOp inputGen, double scale, double shift) {
    final HashMap<String, Integer> index = inputGen.getParamIndexMap();

    // resolve all rand parameters up front, before any validation
    final Hop rowsHop = randParamInput(inputGen, index, DataExpression.RAND_ROWS);
    final Hop colsHop = randParamInput(inputGen, index, DataExpression.RAND_COLS);
    final Hop minHop = randParamInput(inputGen, index, DataExpression.RAND_MIN);
    final Hop maxHop = randParamInput(inputGen, index, DataExpression.RAND_MAX);
    final Hop pdfHop = randParamInput(inputGen, index, DataExpression.RAND_PDF);
    final Hop lambdaHop = randParamInput(inputGen, index, DataExpression.RAND_LAMBDA);
    final Hop sparsityHop = randParamInput(inputGen, index, DataExpression.RAND_SPARSITY);
    final Hop seedHop = randParamInput(inputGen, index, DataExpression.RAND_SEED);

    // the range can only be transformed if both bounds are known literals
    if (!(minHop instanceof LiteralOp && maxHop instanceof LiteralOp)) {
        return null;
    }

    // apply scale and shift to both bounds
    final double newMin = getDoubleValue((LiteralOp) minHop) * scale + shift;
    final double newMax = getDoubleValue((LiteralOp) maxHop) * scale + shift;

    // assemble the parameter map of the copy
    final HashMap<String, Hop> copyParams = new HashMap<>();
    copyParams.put(DataExpression.RAND_ROWS, rowsHop);
    copyParams.put(DataExpression.RAND_COLS, colsHop);
    copyParams.put(DataExpression.RAND_MIN, new LiteralOp(newMin));
    copyParams.put(DataExpression.RAND_MAX, new LiteralOp(newMax));
    copyParams.put(DataExpression.RAND_PDF, pdfHop);
    copyParams.put(DataExpression.RAND_LAMBDA, lambdaHop);
    copyParams.put(DataExpression.RAND_SPARSITY, sparsityHop);
    copyParams.put(DataExpression.RAND_SEED, seedHop);

    // NOTE: per the original comment, size information is refreshed
    // internally when the operator is constructed
    final DataGenOp copy = new DataGenOp(DataGenMethod.RAND, new DataIdentifier("tmp"), copyParams);
    copy.setOutputBlocksizes(inputGen.getRowsInBlock(), inputGen.getColsInBlock());
    copyLineNumbers(inputGen, copy);

    // a [0,0] value range is recorded as zero non-zeros
    final boolean zeroRange = (newMin == 0 && newMax == 0);
    if (zeroRange) {
        copy.setNnz(0);
    }

    return copy;
}

/**
 * Returns the input hop of the given data gen op that is registered under
 * the given parameter name in the supplied parameter index map.
 *
 * @param gen data gen op whose inputs are accessed
 * @param index mapping from parameter name to input position
 * @param paramName name of the parameter to look up
 * @return input hop at the position mapped to the parameter name
 */
private static Hop randParamInput(DataGenOp gen, HashMap<String, Integer> index, String paramName) {
    return gen.getInput().get(index.get(paramName));
}
Aggregations