
Example 1 with ColumnarStructureX

Use of edu.sdsc.mmtf.spark.utils.ColumnarStructureX in the mm-dev project by sbl-sdsc.

From the class StructureAligner, method getAllVsAllAlignments.

/**
 * Calculates all vs. all structural alignments of protein chains using the
 * specified alignment algorithm. The input structures must contain single
 * protein chains.
 *
 * @param targets structures containing single protein chains
 * @param alignmentAlgorithm name of the algorithm
 * @return dataset with alignment metrics
 */
public static Dataset<Row> getAllVsAllAlignments(JavaPairRDD<String, StructureDataInterface> targets, String alignmentAlgorithm) {
    SparkSession session = SparkSession.builder().getOrCreate();
    JavaSparkContext sc = new JavaSparkContext(session.sparkContext());
    // create a list of chainName / C-alpha coordinate pairs
    List<Tuple2<String, Point3d[]>> chains = targets.mapValues(s -> new ColumnarStructureX(s, true).getcAlphaCoordinates()).collect();
    // create an RDD of all pair indices (0,1), (0,2), ..., (1,2), (1,3), ...
    JavaRDD<Tuple2<Integer, Integer>> pairs = getPairs(sc, chains.size());
    // calculate structural alignments for all pairs.
    // broadcast (copy) chains to all worker nodes for efficient processing.
    // for each pair there can be zero or more solutions, therefore we flatmap the pairs.
    JavaRDD<Row> rows = pairs.flatMap(new StructuralAlignmentMapper(sc.broadcast(chains), alignmentAlgorithm));
    // convert rows to a dataset
    return session.createDataFrame(rows, getSchema());
}
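The helper getPairs used above is not shown in this example. A minimal, Spark-free sketch of the upper-triangular pair enumeration it presumably performs (before handing the list to sc.parallelize) is given below; the class name PairEnumerator and the plain List return type are illustrative, not part of the mm-dev API:

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative sketch: enumerate all unique index pairs for an all-vs-all comparison. */
public class PairEnumerator {

    /**
     * Returns the n*(n-1)/2 upper-triangular pairs
     * (0,1), (0,2), ..., (0,n-1), (1,2), (1,3), ..., (n-2,n-1),
     * so each pair of chains is aligned exactly once.
     */
    public static List<int[]> allVsAllPairs(int n) {
        List<int[]> pairs = new ArrayList<>(n * (n - 1) / 2);
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                pairs.add(new int[] { i, j });
            }
        }
        return pairs;
    }

    public static void main(String[] args) {
        // For 4 chains there are 4*3/2 = 6 unique pairs.
        System.out.println(allVsAllPairs(4).size());
    }
}
```

In the actual method these index pairs are distributed as a JavaRDD, and the coordinate list itself is broadcast once rather than shipped with every pair.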
Also used : IntStream(java.util.stream.IntStream) DataTypes(org.apache.spark.sql.types.DataTypes) StructField(org.apache.spark.sql.types.StructField) StructType(org.apache.spark.sql.types.StructType) Iterator(java.util.Iterator) Dataset(org.apache.spark.sql.Dataset) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) ColumnarStructureX(edu.sdsc.mmtf.spark.utils.ColumnarStructureX) Row(org.apache.spark.sql.Row) Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) Serializable(java.io.Serializable) ArrayList(java.util.ArrayList) List(java.util.List) StructureDataInterface(org.rcsb.mmtf.api.StructureDataInterface) Point3d(javax.vecmath.Point3d) JavaRDD(org.apache.spark.api.java.JavaRDD) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) SparkSession(org.apache.spark.sql.SparkSession)

Example 2 with ColumnarStructureX

Use of edu.sdsc.mmtf.spark.utils.ColumnarStructureX in the mm-dev project by sbl-sdsc.

From the class StructureAligner, method getQueryVsAllAlignments.

/**
 * Calculates structural alignments between a query set and a target set of protein
 * chains using the specified alignment algorithm. The input structures must contain
 * single protein chains.
 *
 * @param queries query structures containing single protein chains
 * @param targets target structures containing single protein chains
 * @param alignmentAlgorithm name of the algorithm
 * @return dataset with alignment metrics
 */
public static Dataset<Row> getQueryVsAllAlignments(JavaPairRDD<String, StructureDataInterface> queries, JavaPairRDD<String, StructureDataInterface> targets, String alignmentAlgorithm) {
    SparkSession session = SparkSession.builder().getOrCreate();
    // spark context should not be closed here
    @SuppressWarnings("resource") JavaSparkContext sc = new JavaSparkContext(session.sparkContext());
    List<Tuple2<String, Point3d[]>> chains = new ArrayList<>();
    // create a list of chainName / C-alpha coordinate pairs for the query chains
    chains.addAll(queries.mapValues(s -> new ColumnarStructureX(s, true).getcAlphaCoordinates()).collect());
    int querySize = chains.size();
    // create a list of chainName / C-alpha coordinate pairs for the target chains
    chains.addAll(targets.mapValues(s -> new ColumnarStructureX(s, true).getcAlphaCoordinates()).collect());
    // create an RDD with indices for all query - target pairs (q, t)
    List<Tuple2<Integer, Integer>> pairList = new ArrayList<>(chains.size());
    for (int q = 0; q < querySize; q++) {
        for (int t = querySize; t < chains.size(); t++) {
            pairList.add(new Tuple2<Integer, Integer>(q, t));
        }
    }
    JavaRDD<Tuple2<Integer, Integer>> pairs = sc.parallelize(pairList, NUM_TASKS * sc.defaultParallelism());
    // calculate structural alignments for all pairs.
    // the chains are broadcast (copied) to all worker nodes for efficient processing
    JavaRDD<Row> rows = pairs.flatMap(new StructuralAlignmentMapper(sc.broadcast(chains), alignmentAlgorithm));
    // convert rows to a dataset
    return session.createDataFrame(rows, getSchema());
}
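In contrast to getAllVsAllAlignments, the pairing here crosses every query index with every target index rather than enumerating a triangle: queries occupy indices [0, querySize) of the combined list and targets occupy [querySize, chains.size()). The same loop can be isolated as a small Spark-free sketch; the class and method names below are illustrative, not part of the mm-dev API:

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative sketch of the query-vs-target index pairing used above. */
public class QueryTargetPairs {

    /**
     * Queries occupy indices [0, querySize), targets occupy [querySize, totalSize).
     * Every query index is paired with every target index, giving
     * querySize * (totalSize - querySize) pairs.
     */
    public static List<int[]> crossPairs(int querySize, int totalSize) {
        List<int[]> pairs = new ArrayList<>(querySize * (totalSize - querySize));
        for (int q = 0; q < querySize; q++) {
            for (int t = querySize; t < totalSize; t++) {
                pairs.add(new int[] { q, t });
            }
        }
        return pairs;
    }

    public static void main(String[] args) {
        // 2 queries against 3 targets (5 chains total) -> 2 * 3 = 6 pairs.
        System.out.println(crossPairs(2, 5).size());
    }
}
```

Because a pair index (q, t) addresses positions in the single combined chains list, one broadcast of that list serves both sides of every alignment.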
Also used : SparkSession(org.apache.spark.sql.SparkSession) ArrayList(java.util.ArrayList) ColumnarStructureX(edu.sdsc.mmtf.spark.utils.ColumnarStructureX) Tuple2(scala.Tuple2) Point3d(javax.vecmath.Point3d) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Row(org.apache.spark.sql.Row)

Aggregations

ColumnarStructureX (edu.sdsc.mmtf.spark.utils.ColumnarStructureX)2 ArrayList (java.util.ArrayList)2 Point3d (javax.vecmath.Point3d)2 JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)2 Row (org.apache.spark.sql.Row)2 SparkSession (org.apache.spark.sql.SparkSession)2 Tuple2 (scala.Tuple2)2 Serializable (java.io.Serializable)1 Iterator (java.util.Iterator)1 List (java.util.List)1 Collectors (java.util.stream.Collectors)1 IntStream (java.util.stream.IntStream)1 JavaPairRDD (org.apache.spark.api.java.JavaPairRDD)1 JavaRDD (org.apache.spark.api.java.JavaRDD)1 FlatMapFunction (org.apache.spark.api.java.function.FlatMapFunction)1 Dataset (org.apache.spark.sql.Dataset)1 DataTypes (org.apache.spark.sql.types.DataTypes)1 StructField (org.apache.spark.sql.types.StructField)1 StructType (org.apache.spark.sql.types.StructType)1 StructureDataInterface (org.rcsb.mmtf.api.StructureDataInterface)1