Use of org.apache.spark.ml.linalg.VectorUDT in the Apache incubator-systemml project: the stringDataFrameToVectorDataFrame method of the RDDConverterUtilsExt class. The imports the method relies on:

import java.util.ArrayList;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.mllib.util.NumericParser;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.sysml.runtime.DMLRuntimeException;

import scala.Tuple2;
/**
* Convert a dataframe of comma-separated string rows to a dataframe of
* ml.linalg.Vector rows.
*
* <p>
* Example input rows:<br>
*
* <code>
* ((1.2, 4.3, 3.4))<br>
* (1.2, 3.4, 2.2)<br>
* [[1.2, 34.3, 1.2, 1.25]]<br>
* [1.2, 3.4]<br>
* </code>
*
* @param sparkSession
* Spark Session
* @param inputDF
* dataframe of comma-separated row strings to convert to
* dataframe of ml.linalg.Vector rows
* @return dataframe of ml.linalg.Vector rows
*/
public static Dataset<Row> stringDataFrameToVectorDataFrame(SparkSession sparkSession, Dataset<Row> inputDF) {
StructField[] oldSchema = inputDF.schema().fields();
StructField[] newSchema = new StructField[oldSchema.length];
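// build the output schema: each input column keeps its name but becomes a nullable VectorUDT column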
for (int i = 0; i < oldSchema.length; i++) {
String colName = oldSchema[i].name();
newSchema[i] = DataTypes.createStructField(colName, new VectorUDT(), true);
}
// converter: maps each (row, index) pair to a row holding a single ml.linalg.Vector
class StringToVector implements Function<Tuple2<Row, Long>, Row> {
private static final long serialVersionUID = -4733816995375745659L;
@Override
public Row call(Tuple2<Row, Long> arg0) throws Exception {
Row oldRow = arg0._1;
int oldNumCols = oldRow.length();
if (oldNumCols > 1) {
throw new DMLRuntimeException("The row must have at most one column");
}
// parse the supported string formats, e.g.
// ((1.2, 4.3, 3.4)) or (1.2, 3.4, 2.2)
// [[1.2, 34.3, 1.2, 1.2]] or [1.2, 3.4]
ArrayList<Object> fieldsArr = new ArrayList<Object>();
for (int i = 0; i < oldRow.length(); i++) {
Object ci = oldRow.get(i);
if (ci == null) {
fieldsArr.add(null);
} else if (ci instanceof String) {
String cis = (String) ci;
StringBuffer sb = new StringBuffer(cis.trim());
// strip up to two levels of wrapping, e.g. ((...)) or [[...]]
for (int nid = 0; nid < 2; nid++) {
if ((sb.charAt(0) == '(' && sb.charAt(sb.length() - 1) == ')') || (sb.charAt(0) == '[' && sb.charAt(sb.length() - 1) == ']')) {
sb.deleteCharAt(0);
sb.setLength(sb.length() - 1);
}
}
// normalize separators and wrap in brackets so the whole string
// parses as a numeric array, e.g. "1.2, 4.3" -> "[1.2,4.3]"
String ncis = "[" + sb.toString().replaceAll(" *, *", ",") + "]";
try {
// the bracketed form of ncis always parses to a double[]
double[] doubles = (double[]) NumericParser.parse(ncis);
Vector dense = Vectors.dense(doubles);
fieldsArr.add(dense);
} catch (Exception e) {
// SparkException cannot be caught directly here, so catch Exception and rewrap
throw new DMLRuntimeException("Error converting to double array. " + e.getMessage(), e);
}
} else {
throw new DMLRuntimeException("Only String is supported");
}
}
return RowFactory.create(fieldsArr.toArray());
}
}
// zip each input row with its index, convert, and rebuild the output DataFrame with the Vector schema
JavaRDD<Row> newRows = inputDF.rdd().toJavaRDD().zipWithIndex().map(new StringToVector());
Dataset<Row> outDF = sparkSession.createDataFrame(newRows.rdd(), DataTypes.createStructType(newSchema));
return outDF;
}
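A minimal usage sketch follows. It is not part of the SystemML source; the session setup, the column name "c1", and the sample values are illustrative assumptions based on the Javadoc above.

// additional imports for this sketch
import java.util.Arrays;
import java.util.List;
import org.apache.spark.sql.types.StructType;

SparkSession spark = SparkSession.builder().appName("StringToVectorExample").master("local[*]").getOrCreate();

// a one-column string DataFrame holding two of the formats from the Javadoc ("c1" is a hypothetical name)
StructType strSchema = DataTypes.createStructType(new StructField[] {
DataTypes.createStructField("c1", DataTypes.StringType, true) });
List<Row> rows = Arrays.asList(RowFactory.create("((1.2, 4.3, 3.4))"), RowFactory.create("[1.2, 3.4]"));
Dataset<Row> strDF = spark.createDataFrame(rows, strSchema);

// convert; each output row now holds an ml.linalg dense Vector
Dataset<Row> vecDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(spark, strDF);
vecDF.show(false);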