Search in sources :

Example 11 with Vector

use of org.apache.spark.ml.linalg.Vector in project incubator-systemml by apache.

the class MLContextTest method testDataFrameSumDMLVectorWithNoIDColumn.

/**
 * Feeds a DataFrame with a single ml.linalg vector column (no row-ID
 * column) into a DML script as matrix {@code M}, declared with
 * {@code MatrixFormat.DF_VECTOR}, and registers {@code "sum: 45.0"} as the
 * expected standard output of {@code sum(M)}.
 */
@Test
public void testDataFrameSumDMLVectorWithNoIDColumn() {
    System.out.println("MLContextTest - DataFrame sum DML, vector with no ID column");

    // Schema: one nullable vector column named "C1".
    List<StructField> columns = new ArrayList<StructField>();
    columns.add(DataTypes.createStructField("C1", new VectorUDT(), true));
    StructType vectorSchema = DataTypes.createStructType(columns);

    // Data: three dense rows holding the values 1..9 (which add up to 45).
    List<Vector> vectors = new ArrayList<Vector>();
    vectors.add(Vectors.dense(1.0, 2.0, 3.0));
    vectors.add(Vectors.dense(4.0, 5.0, 6.0));
    vectors.add(Vectors.dense(7.0, 8.0, 9.0));
    JavaRDD<Row> rows = sc.parallelize(vectors).map(new VectorRow());

    Dataset<Row> vectorFrame = spark.createDataFrame(rows, vectorSchema);

    // Bind the frame to M with an explicit vector-format hint.
    Script sumScript = dml("print('sum: ' + sum(M));")
            .in("M", vectorFrame, new MatrixMetadata(MatrixFormat.DF_VECTOR));

    setExpectedStdOut("sum: 45.0");
    ml.execute(sumScript);
}
Also used : Script(org.apache.sysml.api.mlcontext.Script) VectorUDT(org.apache.spark.ml.linalg.VectorUDT) StructType(org.apache.spark.sql.types.StructType) ArrayList(java.util.ArrayList) StructField(org.apache.spark.sql.types.StructField) Row(org.apache.spark.sql.Row) MatrixMetadata(org.apache.sysml.api.mlcontext.MatrixMetadata) Vector(org.apache.spark.ml.linalg.Vector) DenseVector(org.apache.spark.ml.linalg.DenseVector) Test(org.junit.Test)

Example 12 with Vector

use of org.apache.spark.ml.linalg.Vector in project incubator-systemml by apache.

the class MLContextTest method testDataFrameSumDMLVectorWithIDColumn.

/**
 * Feeds a DataFrame made of a double row-ID column plus an ml.linalg
 * vector column into a DML script as matrix {@code M}, declared with
 * {@code MatrixFormat.DF_VECTOR_WITH_INDEX}, and registers
 * {@code "sum: 45.0"} as the expected standard output of {@code sum(M)}.
 */
@Test
public void testDataFrameSumDMLVectorWithIDColumn() {
    System.out.println("MLContextTest - DataFrame sum DML, vector with ID column");

    // Schema: nullable double ID column followed by a nullable vector column "C1".
    List<StructField> columns = new ArrayList<StructField>();
    columns.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.DoubleType, true));
    columns.add(DataTypes.createStructField("C1", new VectorUDT(), true));
    StructType indexedSchema = DataTypes.createStructType(columns);

    // Data: (id, vector) pairs; the vector cells hold 1..9 (which add up to 45).
    List<Tuple2<Double, Vector>> indexedVectors = new ArrayList<Tuple2<Double, Vector>>();
    indexedVectors.add(new Tuple2<Double, Vector>(1.0, Vectors.dense(1.0, 2.0, 3.0)));
    indexedVectors.add(new Tuple2<Double, Vector>(2.0, Vectors.dense(4.0, 5.0, 6.0)));
    indexedVectors.add(new Tuple2<Double, Vector>(3.0, Vectors.dense(7.0, 8.0, 9.0)));
    JavaRDD<Row> rows = sc.parallelize(indexedVectors).map(new DoubleVectorRow());

    Dataset<Row> indexedFrame = spark.createDataFrame(rows, indexedSchema);

    // Bind the frame to M with an explicit "vector with index" format hint.
    Script sumScript = dml("print('sum: ' + sum(M));")
            .in("M", indexedFrame, new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_INDEX));

    setExpectedStdOut("sum: 45.0");
    ml.execute(sumScript);
}
Also used : Script(org.apache.sysml.api.mlcontext.Script) VectorUDT(org.apache.spark.ml.linalg.VectorUDT) StructType(org.apache.spark.sql.types.StructType) ArrayList(java.util.ArrayList) StructField(org.apache.spark.sql.types.StructField) Tuple2(scala.Tuple2) Row(org.apache.spark.sql.Row) MatrixMetadata(org.apache.sysml.api.mlcontext.MatrixMetadata) Vector(org.apache.spark.ml.linalg.Vector) DenseVector(org.apache.spark.ml.linalg.DenseVector) Test(org.junit.Test)

Example 13 with Vector

use of org.apache.spark.ml.linalg.Vector in project incubator-systemml by apache.

the class MLContextTest method testDataFrameSumPYDMLVectorWithIDColumnNoFormatSpecified.

/**
 * Feeds a DataFrame made of a double row-ID column plus an ml.linalg
 * vector column into a PYDML script as matrix {@code M} without passing
 * any {@code MatrixMetadata}, and registers {@code "sum: 45.0"} as the
 * expected standard output of {@code sum(M)}.
 *
 * <p>
 * The script is created with {@code pydml(...)} so that the test actually
 * runs the PYDML variant its name and banner announce.
 */
@Test
public void testDataFrameSumPYDMLVectorWithIDColumnNoFormatSpecified() {
    System.out.println("MLContextTest - DataFrame sum PYDML, vector with ID column, no format specified");
    // (id, vector) pairs; the vector cells hold 1..9 (which add up to 45).
    List<Tuple2<Double, Vector>> list = new ArrayList<Tuple2<Double, Vector>>();
    list.add(new Tuple2<Double, Vector>(1.0, Vectors.dense(1.0, 2.0, 3.0)));
    list.add(new Tuple2<Double, Vector>(2.0, Vectors.dense(4.0, 5.0, 6.0)));
    list.add(new Tuple2<Double, Vector>(3.0, Vectors.dense(7.0, 8.0, 9.0)));
    JavaRDD<Tuple2<Double, Vector>> javaRddTuple = sc.parallelize(list);
    JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleVectorRow());
    // Schema: nullable double ID column followed by a nullable vector column "C1".
    List<StructField> fields = new ArrayList<StructField>();
    fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.DoubleType, true));
    fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));
    StructType schema = DataTypes.createStructType(fields);
    Dataset<Row> dataFrame = spark.createDataFrame(javaRddRow, schema);
    // No MatrixMetadata is supplied: the input is bound with the frame alone.
    Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame);
    setExpectedStdOut("sum: 45.0");
    ml.execute(script);
}
Also used : Script(org.apache.sysml.api.mlcontext.Script) VectorUDT(org.apache.spark.ml.linalg.VectorUDT) StructType(org.apache.spark.sql.types.StructType) ArrayList(java.util.ArrayList) StructField(org.apache.spark.sql.types.StructField) Tuple2(scala.Tuple2) Row(org.apache.spark.sql.Row) Vector(org.apache.spark.ml.linalg.Vector) DenseVector(org.apache.spark.ml.linalg.DenseVector) Test(org.junit.Test)

Example 14 with Vector

use of org.apache.spark.ml.linalg.Vector in project incubator-systemml by apache.

the class MLContextTest method testDataFrameSumPYDMLMllibVectorWithNoIDColumn.

/**
 * Feeds a DataFrame with a single mllib.linalg vector column (no row-ID
 * column) into a PYDML script as matrix {@code M}, declared with
 * {@code MatrixFormat.DF_VECTOR}, and registers {@code "sum: 45.0"} as the
 * expected standard output of {@code sum(M)}.
 */
@Test
public void testDataFrameSumPYDMLMllibVectorWithNoIDColumn() {
    System.out.println("MLContextTest - DataFrame sum PYDML, mllib vector with no ID column");

    // Schema: one nullable mllib vector column named "C1".
    List<StructField> columns = new ArrayList<StructField>();
    columns.add(DataTypes.createStructField("C1", new org.apache.spark.mllib.linalg.VectorUDT(), true));
    StructType vectorSchema = DataTypes.createStructType(columns);

    // Data: three dense mllib rows holding the values 1..9 (which add up to 45).
    List<org.apache.spark.mllib.linalg.Vector> vectors = new ArrayList<org.apache.spark.mllib.linalg.Vector>();
    vectors.add(org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0));
    vectors.add(org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 6.0));
    vectors.add(org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 9.0));
    JavaRDD<Row> rows = sc.parallelize(vectors).map(new MllibVectorRow());

    Dataset<Row> vectorFrame = spark.createDataFrame(rows, vectorSchema);

    // Bind the frame to M with an explicit vector-format hint.
    Script sumScript = pydml("print('sum: ' + sum(M))")
            .in("M", vectorFrame, new MatrixMetadata(MatrixFormat.DF_VECTOR));

    setExpectedStdOut("sum: 45.0");
    ml.execute(sumScript);
}
Also used : Script(org.apache.sysml.api.mlcontext.Script) VectorUDT(org.apache.spark.ml.linalg.VectorUDT) StructType(org.apache.spark.sql.types.StructType) ArrayList(java.util.ArrayList) StructField(org.apache.spark.sql.types.StructField) Row(org.apache.spark.sql.Row) MatrixMetadata(org.apache.sysml.api.mlcontext.MatrixMetadata) Vector(org.apache.spark.ml.linalg.Vector) DenseVector(org.apache.spark.ml.linalg.DenseVector) Test(org.junit.Test)

Example 15 with Vector

use of org.apache.spark.ml.linalg.Vector in project incubator-systemml by apache.

the class RDDConverterUtilsExt method stringDataFrameToVectorDataFrame.

/**
 * Convert a dataframe of comma-separated string rows to a dataframe of
 * ml.linalg.Vector rows.
 *
 * <p>
 * Example input rows:<br>
 *
 * <code>
 * ((1.2, 4.3, 3.4))<br>
 * (1.2, 3.4, 2.2)<br>
 * [[1.2, 34.3, 1.2, 1.25]]<br>
 * [1.2, 3.4]<br>
 * </code>
 *
 * <p>
 * Up to two levels of enclosing {@code ( )} or {@code [ ]} are stripped
 * from each cell before it is parsed. Null cells are passed through as
 * null. The per-row converter throws a {@code DMLRuntimeException} when a
 * row has more than one column, when a cell is neither null nor a String,
 * or when the cell text cannot be parsed into doubles.
 *
 * @param sparkSession
 *            Spark Session
 * @param inputDF
 *            dataframe of comma-separated row strings to convert to
 *            dataframe of ml.linalg.Vector rows
 * @return dataframe of ml.linalg.Vector rows
 */
public static Dataset<Row> stringDataFrameToVectorDataFrame(SparkSession sparkSession, Dataset<Row> inputDF) {
    // The output keeps every input column name but retypes it as a nullable vector.
    StructField[] oldSchema = inputDF.schema().fields();
    StructField[] newSchema = new StructField[oldSchema.length];
    for (int i = 0; i < oldSchema.length; i++) {
        String colName = oldSchema[i].name();
        newSchema[i] = DataTypes.createStructField(colName, new VectorUDT(), true);
    }
    // converter; only the Row half of each (row, index) pair is read
    class StringToVector implements Function<Tuple2<Row, Long>, Row> {

        private static final long serialVersionUID = -4733816995375745659L;

        @Override
        public Row call(Tuple2<Row, Long> arg0) throws Exception {
            Row oldRow = arg0._1;
            int oldNumCols = oldRow.length();
            if (oldNumCols > 1) {
                throw new DMLRuntimeException("The row must have at most one column");
            }
            // parse the various strings. i.e
            // ((1.2, 4.3, 3.4)) or (1.2, 3.4, 2.2)
            // [[1.2, 34.3, 1.2, 1.2]] or [1.2, 3.4]
            ArrayList<Object> fieldsArr = new ArrayList<Object>();
            for (int i = 0; i < oldNumCols; i++) {
                Object ci = oldRow.get(i);
                if (ci == null) {
                    fieldsArr.add(null);
                } else if (ci instanceof String) {
                    StringBuilder sb = new StringBuilder(((String) ci).trim());
                    // Strip at most two nesting levels. The counter is local to this
                    // loop so the column index of the enclosing loop is left untouched.
                    for (int nid = 0; nid < 2; nid++) {
                        int last = sb.length() - 1;
                        if (last < 1) {
                            // fewer than two characters: no bracket pair can remain
                            break;
                        }
                        char first = sb.charAt(0);
                        char end = sb.charAt(last);
                        if ((first == '(' && end == ')') || (first == '[' && end == ']')) {
                            sb.deleteCharAt(0);
                            sb.setLength(sb.length() - 1);
                        }
                    }
                    // normalize separators and wrap in [ ] for the parser
                    String ncis = "[" + sb.toString().replaceAll(" *, *", ",") + "]";
                    try {
                        // ncis [ ] will always result in double array return type
                        double[] doubles = (double[]) NumericParser.parse(ncis);
                        Vector dense = Vectors.dense(doubles);
                        fieldsArr.add(dense);
                    } catch (Exception e) {
                        // can't catch SparkException here in Java apparently
                        throw new DMLRuntimeException("Error converting to double array. " + e.getMessage(), e);
                    }
                } else {
                    throw new DMLRuntimeException("Only String is supported");
                }
            }
            return RowFactory.create(fieldsArr.toArray());
        }
    }
    // output DF
    JavaRDD<Row> newRows = inputDF.rdd().toJavaRDD().zipWithIndex().map(new StringToVector());
    Dataset<Row> outDF = sparkSession.createDataFrame(newRows.rdd(), DataTypes.createStructType(newSchema));
    return outDF;
}
Also used : VectorUDT(org.apache.spark.ml.linalg.VectorUDT) ArrayList(java.util.ArrayList) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) IOException(java.io.IOException) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) PairFlatMapFunction(org.apache.spark.api.java.function.PairFlatMapFunction) Function(org.apache.spark.api.java.function.Function) StructField(org.apache.spark.sql.types.StructField) Tuple2(scala.Tuple2) Row(org.apache.spark.sql.Row) Vector(org.apache.spark.ml.linalg.Vector)

Aggregations

Vector (org.apache.spark.ml.linalg.Vector)15 VectorUDT (org.apache.spark.ml.linalg.VectorUDT)15 Row (org.apache.spark.sql.Row)15 StructField (org.apache.spark.sql.types.StructField)15 StructType (org.apache.spark.sql.types.StructType)14 ArrayList (java.util.ArrayList)13 DenseVector (org.apache.spark.ml.linalg.DenseVector)13 Script (org.apache.sysml.api.mlcontext.Script)13 Test (org.junit.Test)13 MatrixMetadata (org.apache.sysml.api.mlcontext.MatrixMetadata)8 Tuple2 (scala.Tuple2)7 IOException (java.io.IOException)1 VectorBuilder (net.jgp.labs.spark.x.udf.VectorBuilder)1 Function (org.apache.spark.api.java.function.Function)1 PairFlatMapFunction (org.apache.spark.api.java.function.PairFlatMapFunction)1 LinearRegression (org.apache.spark.ml.regression.LinearRegression)1 LinearRegressionModel (org.apache.spark.ml.regression.LinearRegressionModel)1 LinearRegressionTrainingSummary (org.apache.spark.ml.regression.LinearRegressionTrainingSummary)1 SparkSession (org.apache.spark.sql.SparkSession)1 DoubleType (org.apache.spark.sql.types.DoubleType)1