Search in sources :

Example 1 with MVImputeAgent

Use of org.apache.sysml.runtime.transform.MVImputeAgent in project incubator-systemml by apache.

In the class MultiReturnParameterizedBuiltinSPInstruction, method processInstruction.

/**
 * Executes a two-phase frame transform on Spark: first builds the transform
 * meta data from the input frame, then applies the resulting encoder to the
 * same input to produce the encoded matrix output.
 *
 * <p>Inputs and outputs as used by this method:
 * <ul>
 *   <li>{@code input1} - the input frame, read as a binary-block RDD of
 *       {@code FrameBlock}s</li>
 *   <li>{@code input2} - the transform specification, read as a scalar string</li>
 *   <li>{@code _outputs.get(0)} - receives the encoded matrix RDD</li>
 *   <li>{@code _outputs.get(1)} - receives the consolidated meta data frame; its
 *       file name is also used as the target for the intermediate text output</li>
 * </ul>
 *
 * @param ec execution context; cast to {@code SparkExecutionContext} without a check
 * @throws DMLRuntimeException declared by the signature; raised by the called
 *         runtime APIs
 * @throws RuntimeException wrapping any {@code IOException} raised inside the body
 */
@Override
@SuppressWarnings("unchecked")
public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
    SparkExecutionContext sec = (SparkExecutionContext) ec;
    try {
        //get input RDD and meta data
        FrameObject fo = sec.getFrameObject(input1.getName());
        FrameObject fometa = sec.getFrameObject(_outputs.get(1).getName());
        JavaPairRDD<Long, FrameBlock> in = (JavaPairRDD<Long, FrameBlock>) sec.getRDDHandleForFrameObject(fo, InputInfo.BinaryBlockInputInfo);
        String spec = ec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getStringValue();
        MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(input1.getName());
        MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
        //column names are only fetched when the spec is not an ID specification;
        //they are taken from the first block stored under key 1 (presumably the
        //first row block of the frame - TODO confirm), which runs a Spark lookup job
        String[] colnames = !TfMetaUtils.isIDSpecification(spec) ? in.lookup(1L).get(0).getColumnNames() : null;
        //step 1: build transform meta data
        //(encoder created without meta data, i.e., last argument is null)
        Encoder encoderBuild = EncoderFactory.createEncoder(spec, colnames, fo.getSchema(), (int) fo.getNumColumns(), null);
        MaxLongAccumulator accMax = registerMaxLongAccumulator(sec.getSparkContext());
        JavaRDD<String> rcMaps = in.mapPartitionsToPair(new TransformEncodeBuildFunction(encoderBuild)).distinct().groupByKey().flatMap(new TransformEncodeGroupFunction(accMax));
        //if the build encoder contains a missing-value impute encoder, its
        //meta data is built in a second pass and appended to the first result
        if (containsMVImputeEncoder(encoderBuild)) {
            MVImputeAgent mva = getMVImputeEncoder(encoderBuild);
            rcMaps = rcMaps.union(in.mapPartitionsToPair(new TransformEncodeBuild2Function(mva)).groupByKey().flatMap(new TransformEncodeGroup2Function(mva)));
        }
        //trigger eval: writing the text output is the action that runs the
        //build pipeline; accMax must only be read after this point
        rcMaps.saveAsTextFile(fometa.getFileName());
        //consolidate meta data frame (reuse multi-threaded reader, special handling missing values)
        //the accumulator value is used as the number of rows of the meta frame
        FrameReader reader = FrameReaderFactory.createFrameReader(InputInfo.TextCellInputInfo);
        FrameBlock meta = reader.readFrameFromHDFS(fometa.getFileName(), accMax.value(), fo.getNumColumns());
        //recompute num distinct items per column
        meta.recomputeColumnCardinality();
        //use the input column names if available, otherwise keep the meta frame's own names
        meta.setColumnNames((colnames != null) ? colnames : meta.getColumnNames());
        //step 2: transform apply (similar to spark transformapply)
        //compute omit offset map for block shifts
        //(only if the spec contains an omit specification; collects to the driver)
        TfOffsetMap omap = null;
        if (TfMetaUtils.containsOmitSpec(spec, colnames)) {
            omap = new TfOffsetMap(SparkUtils.toIndexedLong(in.mapToPair(new RDDTransformApplyOffsetFunction(spec, colnames)).collect()));
        }
        //create encoder broadcast (avoiding replication per task)
        //(this encoder is initialized with the consolidated meta data)
        Encoder encoder = EncoderFactory.createEncoder(spec, colnames, fo.getSchema(), (int) fo.getNumColumns(), meta);
        //output rows = input rows minus rows removed by omit; output cols from encoder
        mcOut.setDimension(mcIn.getRows() - ((omap != null) ? omap.getNumRmRows() : 0), encoder.getNumCols());
        Broadcast<Encoder> bmeta = sec.getSparkContext().broadcast(encoder);
        //the offset map broadcast stays null when no omit spec is present
        Broadcast<TfOffsetMap> bomap = (omap != null) ? sec.getSparkContext().broadcast(omap) : null;
        //execute transform apply
        JavaPairRDD<Long, FrameBlock> tmp = in.mapToPair(new RDDTransformApplyFunction(bmeta, bomap));
        //NOTE(review): mcOut is passed for both characteristics arguments - confirm this is intended
        JavaPairRDD<MatrixIndexes, MatrixBlock> out = FrameRDDConverterUtils.binaryBlockToMatrixBlock(tmp, mcOut, mcOut);
        //set output and maintain lineage/output characteristics
        sec.setRDDHandleForVariable(_outputs.get(0).getName(), out);
        sec.addLineageRDD(_outputs.get(0).getName(), input1.getName());
        sec.setFrameOutput(_outputs.get(1).getName(), meta);
    } catch (IOException ex) {
        //NOTE(review): rethrown as unchecked RuntimeException although the method
        //declares DMLRuntimeException - confirm callers do not rely on the latter
        throw new RuntimeException(ex);
    }
}
Also used : MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) FrameBlock(org.apache.sysml.runtime.matrix.data.FrameBlock) Encoder(org.apache.sysml.runtime.transform.encode.Encoder) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) SparkExecutionContext(org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext) RDDTransformApplyOffsetFunction(org.apache.sysml.runtime.instructions.spark.ParameterizedBuiltinSPInstruction.RDDTransformApplyOffsetFunction) MatrixIndexes(org.apache.sysml.runtime.matrix.data.MatrixIndexes) FrameObject(org.apache.sysml.runtime.controlprogram.caching.FrameObject) IOException(java.io.IOException) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics) RDDTransformApplyFunction(org.apache.sysml.runtime.instructions.spark.ParameterizedBuiltinSPInstruction.RDDTransformApplyFunction) TfOffsetMap(org.apache.sysml.runtime.transform.meta.TfOffsetMap) FrameReader(org.apache.sysml.runtime.io.FrameReader) MVImputeAgent(org.apache.sysml.runtime.transform.MVImputeAgent)

Example 2 with MVImputeAgent

Use of org.apache.sysml.runtime.transform.MVImputeAgent in project incubator-systemml by apache.

In the class EncoderFactory, method createEncoder.

/**
 * Creates a composite encoder for the given JSON transform specification.
 *
 * <p>The specification is parsed into per-method column ID lists (recode,
 * dummycode, binning, omit, impute). One encoder is created per non-empty
 * list, plus a pass-through encoder for all columns that are neither
 * recoded nor binned. The individual encoders are wrapped in an
 * {@code EncoderComposite} in the order: recode, pass-through, dummycode,
 * binning, omit, impute.
 *
 * @param spec     transform specification as JSON string
 * @param colnames column names handed to the specification parsers
 * @param schema   frame schema; its length is used as the number of columns
 * @param meta     meta data frame used to initialize the composite encoder,
 *                 or {@code null} to skip meta data initialization
 * @return the composite encoder
 * @throws DMLRuntimeException wrapping any exception raised while parsing
 *         the specification or constructing the encoders
 */
@SuppressWarnings("unchecked")
public static Encoder createEncoder(String spec, String[] colnames, ValueType[] schema, FrameBlock meta) throws DMLRuntimeException {
    // evaluated outside the try block, so a null schema is not wrapped
    final int numCols = schema.length;
    try {
        // parse transform specification
        JSONObject parsedSpec = new JSONObject(spec);
        List<Encoder> parts = new ArrayList<Encoder>();

        // basic id lists (recode, dummycode, pass-through);
        // any dummycode column requires recode as preparation, hence the union
        List<Integer> recodeCols = Arrays.asList(ArrayUtils.toObject(
            TfMetaUtils.parseJsonIDList(parsedSpec, colnames, TfUtils.TXMETHOD_RECODE)));
        List<Integer> dummyCols = Arrays.asList(ArrayUtils.toObject(
            TfMetaUtils.parseJsonIDList(parsedSpec, colnames, TfUtils.TXMETHOD_DUMMYCODE)));
        recodeCols = new ArrayList<Integer>(CollectionUtils.union(recodeCols, dummyCols));
        List<Integer> binCols = TfMetaUtils.parseBinningColIDs(parsedSpec, colnames);
        // pass-through = all columns 1..numCols that are neither recoded nor binned
        List<Integer> passCols = new ArrayList<Integer>(
            CollectionUtils.subtract(
                CollectionUtils.subtract(UtilFunctions.getSequenceList(1, numCols, 1), recodeCols),
                binCols));
        List<Integer> omitCols = Arrays.asList(ArrayUtils.toObject(
            TfMetaUtils.parseJsonIDList(parsedSpec, colnames, TfUtils.TXMETHOD_OMIT)));
        List<Integer> imputeCols = Arrays.asList(ArrayUtils.toObject(
            TfMetaUtils.parseJsonObjectIDList(parsedSpec, colnames, TfUtils.TXMETHOD_IMPUTE)));

        // create the individual encoders, one per non-empty id list
        if (!recodeCols.isEmpty()) {
            RecodeAgent recoder = new RecodeAgent(parsedSpec, colnames, numCols);
            Integer[] boxedRecode = recodeCols.toArray(new Integer[0]);
            recoder.setColList(ArrayUtils.toPrimitive(boxedRecode));
            parts.add(recoder);
        }
        if (!passCols.isEmpty()) {
            Integer[] boxedPass = passCols.toArray(new Integer[0]);
            parts.add(new EncoderPassThrough(ArrayUtils.toPrimitive(boxedPass), numCols));
        }
        if (!dummyCols.isEmpty()) {
            parts.add(new DummycodeAgent(parsedSpec, colnames, numCols));
        }
        if (!binCols.isEmpty()) {
            parts.add(new BinAgent(parsedSpec, colnames, numCols, true));
        }
        if (!omitCols.isEmpty()) {
            parts.add(new OmitAgent(parsedSpec, colnames, numCols));
        }
        if (!imputeCols.isEmpty()) {
            MVImputeAgent imputer = new MVImputeAgent(parsedSpec, colnames, numCols);
            imputer.initRecodeIDList(recodeCols);
            parts.add(imputer);
        }

        // create composite encoder of all created encoders
        // and initialize meta data (recode, dummy, bin, mv)
        Encoder composite = new EncoderComposite(parts);
        if (meta != null) {
            composite.initMetaData(meta);
        }
        return composite;
    } catch (Exception ex) {
        throw new DMLRuntimeException(ex);
    }
}
Also used : RecodeAgent(org.apache.sysml.runtime.transform.RecodeAgent) ArrayList(java.util.ArrayList) BinAgent(org.apache.sysml.runtime.transform.BinAgent) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) JSONObject(org.apache.wink.json4j.JSONObject) DummycodeAgent(org.apache.sysml.runtime.transform.DummycodeAgent) OmitAgent(org.apache.sysml.runtime.transform.OmitAgent) MVImputeAgent(org.apache.sysml.runtime.transform.MVImputeAgent)

Aggregations

DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)2 MVImputeAgent (org.apache.sysml.runtime.transform.MVImputeAgent)2 IOException (java.io.IOException)1 ArrayList (java.util.ArrayList)1 JavaPairRDD (org.apache.spark.api.java.JavaPairRDD)1 FrameObject (org.apache.sysml.runtime.controlprogram.caching.FrameObject)1 SparkExecutionContext (org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext)1 RDDTransformApplyFunction (org.apache.sysml.runtime.instructions.spark.ParameterizedBuiltinSPInstruction.RDDTransformApplyFunction)1 RDDTransformApplyOffsetFunction (org.apache.sysml.runtime.instructions.spark.ParameterizedBuiltinSPInstruction.RDDTransformApplyOffsetFunction)1 FrameReader (org.apache.sysml.runtime.io.FrameReader)1 MatrixCharacteristics (org.apache.sysml.runtime.matrix.MatrixCharacteristics)1 FrameBlock (org.apache.sysml.runtime.matrix.data.FrameBlock)1 MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock)1 MatrixIndexes (org.apache.sysml.runtime.matrix.data.MatrixIndexes)1 BinAgent (org.apache.sysml.runtime.transform.BinAgent)1 DummycodeAgent (org.apache.sysml.runtime.transform.DummycodeAgent)1 OmitAgent (org.apache.sysml.runtime.transform.OmitAgent)1 RecodeAgent (org.apache.sysml.runtime.transform.RecodeAgent)1 Encoder (org.apache.sysml.runtime.transform.encode.Encoder)1 TfOffsetMap (org.apache.sysml.runtime.transform.meta.TfOffsetMap)1