Usage of org.apache.spark.sql.Row in the Apache incubator-systemml project.
Class RDDConverterUtils, method dataFrameToBinaryBlock.
/**
 * Converts a DataFrame into a binary-block matrix RDD.
 *
 * <p>If {@code mc} does not already have known dimensions and non-zero count, they are
 * computed with extra passes over {@code df} and written back into {@code mc}. Block sizes
 * of 1 or less are replaced by {@code ConfigurationManager.getBlocksize()}.
 *
 * @param sc the Java Spark context (used to register the nnz accumulator)
 * @param df input DataFrame
 * @param mc matrix characteristics; updated in place with rows/cols/nnz and block size when needed
 * @param containsID whether {@code df} carries a row-id column ({@code DF_ID_COLUMN}); when set,
 *        the id column is excluded from the column count and the vector column is read at
 *        position 1 (so the id column is presumably first — confirm with the producer of df)
 * @param isVector whether each row's values are stored as a single {@code Vector} column
 * @return binary-block RDD whose partial blocks have been merged by block index
 */
public static JavaPairRDD<MatrixIndexes, MatrixBlock> dataFrameToBinaryBlock(JavaSparkContext sc, Dataset<Row> df, MatrixCharacteristics mc, boolean containsID, boolean isVector) {
    // determine unknown dimensions and sparsity if required
    if (!mc.dimsKnown(true)) {
        LongAccumulator aNnz = sc.sc().longAccumulator("nnz");
        JavaRDD<Row> tmp = df.javaRDD().map(new DataFrameAnalysisFunction(aNnz, containsID, isVector));
        long rlen = tmp.count();
        // Read the accumulator right after the single full pass. tmp is not cached, so the
        // tmp.first() call below re-executes the analysis map on the first partition(s) and
        // would add those non-zeros to aNnz a second time (accumulator updates inside
        // transformations are re-applied whenever the transformation is re-evaluated).
        long nnz = UtilFunctions.toLong(aNnz.value());
        long clen = !isVector ? df.columns().length - (containsID ? 1 : 0) : ((Vector) tmp.first().get(containsID ? 1 : 0)).size();
        mc.set(rlen, clen, mc.getRowsPerBlock(), mc.getColsPerBlock(), nnz);
    }
    // ensure valid blocksizes
    if (mc.getRowsPerBlock() <= 1 || mc.getColsPerBlock() <= 1) {
        mc.setBlockSize(ConfigurationManager.getBlocksize());
    }
    // construct or reuse row ids: use the existing id column if present, otherwise zip row index
    JavaPairRDD<Row, Long> prepinput = containsID
        ? df.javaRDD().mapToPair(new DataFrameExtractIDFunction(df.schema().fieldIndex(DF_ID_COLUMN)))
        : df.javaRDD().zipWithIndex();
    // convert dataframe rows to binary block rdd (w/ partial blocks)
    boolean sparse = requiresSparseAllocation(prepinput, mc);
    JavaPairRDD<MatrixIndexes, MatrixBlock> out = prepinput.mapPartitionsToPair(new DataFrameToBinaryBlockFunction(mc, sparse, containsID, isVector));
    // aggregate partial matrix blocks (w/ preferred number of output
    // partitions as the data is likely smaller in binary block format,
    // but also to bound the size of partitions for compressed inputs)
    int parts = SparkUtils.getNumPreferredPartitions(mc, out);
    return RDDAggregateUtils.mergeByKey(out, parts, false);
}
Usage of org.apache.spark.sql.Row in the Apache incubator-systemml project.
Class RDDConverterUtilsExt, method addIDToDataFrame.
/**
 * Add element indices as new column to DataFrame.
 *
 * <p>The result schema is every field of {@code df} followed by one non-nullable
 * {@code DoubleType} field named {@code nameOfCol}. Rows are paired with their index via
 * {@code zipWithIndex()} and then converted by {@code AddRowID} (presumably appending the
 * index as a double to match the new field — confirm against that class).
 *
 * @param df input data frame
 * @param sparkSession the Spark Session
 * @param nameOfCol name of index column
 * @return new data frame
 */
public static Dataset<Row> addIDToDataFrame(Dataset<Row> df, SparkSession sparkSession, String nameOfCol) {
    StructField[] oldSchema = df.schema().fields();
    // copy the existing fields, leaving one extra trailing slot for the index column
    StructField[] newSchema = java.util.Arrays.copyOf(oldSchema, oldSchema.length + 1);
    newSchema[oldSchema.length] = DataTypes.createStructField(nameOfCol, DataTypes.DoubleType, false);
    JavaRDD<Row> newRows = df.rdd().toJavaRDD().zipWithIndex().map(new AddRowID());
    return sparkSession.createDataFrame(newRows, new StructType(newSchema));
}
Usage of org.apache.spark.sql.Row in the Apache incubator-systemml project.
Class RDDConverterUtilsExtTest, method testStringDataFrameToVectorDataFrameNonNumbers.
/**
 * A vector literal whose elements are not numbers must make the string-to-vector conversion
 * fail with a {@code SparkException} once the result is evaluated.
 */
@Test(expected = SparkException.class)
public void testStringDataFrameToVectorDataFrameNonNumbers() {
    List<String> rawValues = new ArrayList<String>();
    rawValues.add("[cheeseburger,fries]");
    JavaRDD<Row> rows = sc.parallelize(rawValues).map(new StringToRow());
    SparkSession session = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
    List<StructField> columns = new ArrayList<StructField>();
    columns.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
    Dataset<Row> input = session.createDataFrame(rows, DataTypes.createStructType(columns));
    Dataset<Row> converted = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(session, input);
    // the conversion is lazy; collecting forces evaluation so the exception surfaces here
    converted.collectAsList();
}
Usage of org.apache.spark.sql.Row in the Apache incubator-systemml project.
Class RDDConverterUtilsExtTest, method testStringDataFrameToVectorDataFrameNull.
/**
 * A null input string must convert to a null vector cell while a valid literal converts
 * normally; the output must contain exactly the expected rows (in any order).
 */
@Test
public void testStringDataFrameToVectorDataFrameNull() {
    List<String> list = new ArrayList<String>();
    list.add("[1.2, 3.4]");
    list.add(null);
    JavaRDD<String> javaRddString = sc.parallelize(list);
    JavaRDD<Row> javaRddRow = javaRddString.map(new StringToRow());
    SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
    List<StructField> fields = new ArrayList<StructField>();
    fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
    StructType schema = DataTypes.createStructType(fields);
    Dataset<Row> inDF = sparkSession.createDataFrame(javaRddRow, schema);
    Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sparkSession, inDF);
    List<String> expectedResults = new ArrayList<String>();
    expectedResults.add("[[1.2,3.4]]");
    expectedResults.add("[null]");
    // Compare the complete (sorted) list of rendered rows. A per-row "is contained in expected"
    // check passes vacuously on an empty result and misses missing or duplicated rows.
    List<String> actualResults = new ArrayList<String>();
    for (Row row : outDF.collectAsList()) {
        actualResults.add(row.toString());
    }
    java.util.Collections.sort(expectedResults);
    java.util.Collections.sort(actualResults);
    assertTrue("Expected results " + expectedResults + " but got " + actualResults, expectedResults.equals(actualResults));
}
Usage of org.apache.spark.sql.Row in the Apache incubator-systemml project.
Class RDDConverterUtilsExtTest, method testStringDataFrameToVectorDataFrame.
/**
 * Vector literals written with parentheses or brackets (single or doubled) and with spaces
 * after commas must all convert to vector cells; the output must contain exactly the
 * expected rows (in any order).
 */
@Test
public void testStringDataFrameToVectorDataFrame() {
    List<String> list = new ArrayList<String>();
    list.add("((1.2, 4.3, 3.4))");
    list.add("(1.2, 3.4, 2.2)");
    list.add("[[1.2, 34.3, 1.2, 1.25]]");
    list.add("[1.2, 3.4]");
    JavaRDD<String> javaRddString = sc.parallelize(list);
    JavaRDD<Row> javaRddRow = javaRddString.map(new StringToRow());
    SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
    List<StructField> fields = new ArrayList<StructField>();
    fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
    StructType schema = DataTypes.createStructType(fields);
    Dataset<Row> inDF = sparkSession.createDataFrame(javaRddRow, schema);
    Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sparkSession, inDF);
    List<String> expectedResults = new ArrayList<String>();
    expectedResults.add("[[1.2,4.3,3.4]]");
    expectedResults.add("[[1.2,3.4,2.2]]");
    expectedResults.add("[[1.2,34.3,1.2,1.25]]");
    expectedResults.add("[[1.2,3.4]]");
    // Compare the complete (sorted) list of rendered rows. A per-row "is contained in expected"
    // check passes vacuously on an empty result and misses missing or duplicated rows.
    List<String> actualResults = new ArrayList<String>();
    for (Row row : outDF.collectAsList()) {
        actualResults.add(row.toString());
    }
    java.util.Collections.sort(expectedResults);
    java.util.Collections.sort(actualResults);
    assertTrue("Expected results " + expectedResults + " but got " + actualResults, expectedResults.equals(actualResults));
}
Aggregations