Search in sources :

Example 16 with VectorUDT

Use of org.apache.spark.ml.linalg.VectorUDT in project incubator-systemml by apache.

From the class RDDConverterUtilsExt, method stringDataFrameToVectorDataFrame.

/**
 * Convert a dataframe of comma-separated string rows to a dataframe of
 * ml.linalg.Vector rows.
 *
 * <p>
 * Each input row may hold at most one column. The column value is either
 * {@code null} (passed through as {@code null}) or a string of
 * comma-separated numbers, optionally wrapped in up to two levels of
 * matching {@code ( )} or {@code [ ]} delimiters.
 *
 * <p>
 * Example input rows:<br>
 *
 * <code>
 * ((1.2, 4.3, 3.4))<br>
 * (1.2, 3.4, 2.2)<br>
 * [[1.2, 34.3, 1.2, 1.25]]<br>
 * [1.2, 3.4]<br>
 * </code>
 *
 * <p>
 * The per-row conversion runs inside a function mapped over the input RDD,
 * so a {@code DMLRuntimeException} raised for a bad row (more than one
 * column, a non-string value, or text that cannot be parsed as numbers) is
 * thrown from that function rather than directly from this method.
 *
 * @param sparkSession
 *            Spark Session
 * @param inputDF
 *            dataframe of comma-separated row strings to convert to
 *            dataframe of ml.linalg.Vector rows
 * @return dataframe of ml.linalg.Vector rows
 */
public static Dataset<Row> stringDataFrameToVectorDataFrame(SparkSession sparkSession, Dataset<Row> inputDF) {
    // The output schema keeps the column names but retypes every column as a nullable vector.
    StructField[] oldSchema = inputDF.schema().fields();
    StructField[] newSchema = new StructField[oldSchema.length];
    for (int i = 0; i < oldSchema.length; i++) {
        String colName = oldSchema[i].name();
        newSchema[i] = DataTypes.createStructField(colName, new VectorUDT(), true);
    }
    // converter
    class StringToVector implements Function<Row, Row> {

        private static final long serialVersionUID = -4733816995375745659L;

        @Override
        public Row call(Row oldRow) throws Exception {
            int oldNumCols = oldRow.length();
            if (oldNumCols > 1) {
                throw new DMLRuntimeException("The row must have at most one column");
            }
            // parse the various strings. i.e
            // ((1.2, 4.3, 3.4)) or (1.2, 3.4, 2.2)
            // [[1.2, 34.3, 1.2, 1.2]] or [1.2, 3.4]
            Object[] fields = new Object[oldNumCols];
            for (int i = 0; i < oldNumCols; i++) {
                Object ci = oldRow.get(i);
                if (ci == null) {
                    fields[i] = null;
                } else if (ci instanceof String) {
                    StringBuilder sb = new StringBuilder(((String) ci).trim());
                    // Strip at most two levels of matching outer delimiters. The
                    // loop uses its own counter (the original advanced the column
                    // index here) and checks the length first so that "", "()" or
                    // "[]" cannot index past the end of the buffer.
                    for (int nid = 0; nid < 2; nid++) {
                        if (sb.length() < 2) {
                            break;
                        }
                        char first = sb.charAt(0);
                        char last = sb.charAt(sb.length() - 1);
                        boolean wrapped = (first == '(' && last == ')') || (first == '[' && last == ']');
                        if (!wrapped) {
                            // nothing to strip now means nothing to strip on the next pass either
                            break;
                        }
                        sb.deleteCharAt(0);
                        sb.setLength(sb.length() - 1);
                    }
                    // Normalize the separators and re-wrap in a single pair of
                    // brackets, the array form handed to NumericParser.
                    String ncis = "[" + sb.toString().replaceAll(" *, *", ",") + "]";
                    try {
                        // ncis [ ] will always result in double array return type
                        double[] doubles = (double[]) NumericParser.parse(ncis);
                        fields[i] = Vectors.dense(doubles);
                    } catch (Exception e) {
                        // can't catch SparkException here in Java apparently
                        throw new DMLRuntimeException("Error converting to double array. " + e.getMessage(), e);
                    }
                } else {
                    throw new DMLRuntimeException("Only String is supported");
                }
            }
            return RowFactory.create(fields);
        }
    }
    // output DF
    // The row index was never read, so the rows are mapped directly without zipWithIndex().
    JavaRDD<Row> newRows = inputDF.rdd().toJavaRDD().map(new StringToVector());
    return sparkSession.createDataFrame(newRows.rdd(), DataTypes.createStructType(newSchema));
}
Also used : VectorUDT(org.apache.spark.ml.linalg.VectorUDT) ArrayList(java.util.ArrayList) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) IOException(java.io.IOException) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) PairFlatMapFunction(org.apache.spark.api.java.function.PairFlatMapFunction) Function(org.apache.spark.api.java.function.Function) StructField(org.apache.spark.sql.types.StructField) Tuple2(scala.Tuple2) Row(org.apache.spark.sql.Row) Vector(org.apache.spark.ml.linalg.Vector)

Aggregations

VectorUDT (org.apache.spark.ml.linalg.VectorUDT)16 StructField (org.apache.spark.sql.types.StructField)16 Row (org.apache.spark.sql.Row)14 StructType (org.apache.spark.sql.types.StructType)13 ArrayList (java.util.ArrayList)12 DenseVector (org.apache.spark.ml.linalg.DenseVector)11 Vector (org.apache.spark.ml.linalg.Vector)11 Script (org.apache.sysml.api.mlcontext.Script)9 Test (org.junit.Test)9 Tuple2 (scala.Tuple2)5 MatrixMetadata (org.apache.sysml.api.mlcontext.MatrixMetadata)4 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)3 JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)2 DataType (org.apache.spark.sql.types.DataType)2 MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock)2 IOException (java.io.IOException)1 VectorBuilder (net.jgp.labs.spark.x.udf.VectorBuilder)1 Function (org.apache.spark.api.java.function.Function)1 PairFlatMapFunction (org.apache.spark.api.java.function.PairFlatMapFunction)1 LabeledPoint (org.apache.spark.ml.feature.LabeledPoint)1