Use of org.apache.spark.api.java.JavaPairRDD in project gatk by broadinstitute.
The class AddContextDataToReadSpark, method addUsingOverlapsPartitioning.
/**
* Add context data ({@link ReadContextData}) to reads, using overlaps partitioning to avoid a shuffle.
* @param ctx the Spark context
* @param mappedReads the coordinate-sorted reads
* @param referenceSource the reference source
* @param variants the coordinate-sorted variants
* @param sequenceDictionary the sequence dictionary for the reads
* @param shardSize the maximum size of each shard, in bases
* @param shardPadding amount of extra context around each shard, in bases
* @return an RDD of read-context pairs, in coordinate-sorted order
*/
private static JavaPairRDD<GATKRead, ReadContextData> addUsingOverlapsPartitioning(final JavaSparkContext ctx, final JavaRDD<GATKRead> mappedReads, final ReferenceMultiSource referenceSource, final JavaRDD<GATKVariant> variants, final SAMSequenceDictionary sequenceDictionary, final int shardSize, final int shardPadding) {
final List<SimpleInterval> intervals = IntervalUtils.getAllIntervalsForReference(sequenceDictionary);
// use unpadded shards (padding is only needed for reference bases)
final List<ShardBoundary> intervalShards = intervals.stream().flatMap(interval -> Shard.divideIntervalIntoShards(interval, shardSize, 0, sequenceDictionary).stream()).collect(Collectors.toList());
final Broadcast<ReferenceMultiSource> bReferenceSource = ctx.broadcast(referenceSource);
final IntervalsSkipList<GATKVariant> variantSkipList = new IntervalsSkipList<>(variants.collect());
final Broadcast<IntervalsSkipList<GATKVariant>> variantsBroadcast = ctx.broadcast(variantSkipList);
int maxLocatableSize = Math.min(shardSize, shardPadding);
JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, mappedReads, GATKRead.class, sequenceDictionary, intervalShards, maxLocatableSize);
return shardedReads.flatMapToPair(new PairFlatMapFunction<Shard<GATKRead>, GATKRead, ReadContextData>() {
private static final long serialVersionUID = 1L;
@Override
public Iterator<Tuple2<GATKRead, ReadContextData>> call(Shard<GATKRead> shard) throws Exception {
// get reference bases for this shard (padded)
SimpleInterval paddedInterval = shard.getInterval().expandWithinContig(shardPadding, sequenceDictionary);
ReferenceBases referenceBases = bReferenceSource.getValue().getReferenceBases(null, paddedInterval);
final IntervalsSkipList<GATKVariant> intervalsSkipList = variantsBroadcast.getValue();
Iterator<Tuple2<GATKRead, ReadContextData>> transform = Iterators.transform(shard.iterator(), new Function<GATKRead, Tuple2<GATKRead, ReadContextData>>() {
@Nullable
@Override
public Tuple2<GATKRead, ReadContextData> apply(@Nullable GATKRead r) {
List<GATKVariant> overlappingVariants;
if (SimpleInterval.isValid(r.getContig(), r.getStart(), r.getEnd())) {
overlappingVariants = intervalsSkipList.getOverlapping(new SimpleInterval(r));
} else {
// Sometimes we have reads that do not form valid intervals (reads that do not consume any ref bases, e.g. CIGAR 61S90I).
// In those cases, we'll just say that nothing overlaps the read.
overlappingVariants = Collections.emptyList();
}
return new Tuple2<>(r, new ReadContextData(referenceBases, overlappingVariants));
}
});
// only include reads that start in the shard
return Iterators.filter(transform, r -> r._1().getStart() >= shard.getStart() && r._1().getStart() <= shard.getEnd());
}
});
}
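For comparison, here is a minimal, self-contained sketch of the same broadcast-plus-lookup pattern using plain Spark types. GATK-specific classes are replaced by simple placeholders; the class name, positions, and the 10-base window below are illustrative assumptions, not values from the GATK codebase.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;

public class OverlapPatternSketch {
    public static void main(String[] args) {
        JavaSparkContext ctx = new JavaSparkContext("local[*]", "overlap-sketch");
        // "reads" represented only by their start positions for illustration
        JavaRDD<Integer> reads = ctx.parallelize(Arrays.asList(5, 15, 25, 35));
        // broadcast a small lookup structure once, instead of shuffling it against the reads
        Broadcast<List<Integer>> variants = ctx.broadcast(Arrays.asList(10, 30));
        // pair every read with the "variants" that fall within 10 bases of it
        JavaPairRDD<Integer, List<Integer>> withContext = reads.mapToPair(r -> {
            List<Integer> overlapping = new ArrayList<>();
            for (Integer v : variants.getValue()) {
                if (Math.abs(v - r) <= 10) {
                    overlapping.add(v);
                }
            }
            return new Tuple2<>(r, overlapping);
        });
        withContext.collect().forEach(t -> System.out.println(t._1() + " -> " + t._2()));
        ctx.stop();
    }
}

Broadcasting the small side of the lookup to every executor avoids the shuffle that a join between the two RDDs would require, which is the same trade-off the GATK method above makes by broadcasting the IntervalsSkipList of variants.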
Use of org.apache.spark.api.java.JavaPairRDD in project gatk by broadinstitute.
The class CompareDuplicatesSpark, method runTool.
@Override
protected void runTool(final JavaSparkContext ctx) {
JavaRDD<GATKRead> firstReads = filteredReads(getReads(), readArguments.getReadFilesNames().get(0));
ReadsSparkSource readsSource2 = new ReadsSparkSource(ctx, readArguments.getReadValidationStringency());
JavaRDD<GATKRead> secondReads = filteredReads(readsSource2.getParallelReads(input2, null, getIntervals(), bamPartitionSplitSize), input2);
// Start by verifying that we have same number of reads and duplicates in each BAM.
long firstBamSize = firstReads.count();
long secondBamSize = secondReads.count();
if (firstBamSize != secondBamSize) {
throw new UserException("input bams have different numbers of mapped reads: " + firstBamSize + "," + secondBamSize);
}
System.out.println("processing bams with " + firstBamSize + " mapped reads");
long firstDupesCount = firstReads.filter(GATKRead::isDuplicate).count();
long secondDupesCount = secondReads.filter(GATKRead::isDuplicate).count();
if (firstDupesCount != secondDupesCount) {
System.out.println("BAMs have different number of total duplicates: " + firstDupesCount + "," + secondDupesCount);
}
System.out.println("first and second: " + firstDupesCount + "," + secondDupesCount);
Broadcast<SAMFileHeader> bHeader = ctx.broadcast(getHeaderForReads());
// Group the reads of each BAM by MarkDuplicates key, then pair up the reads for each BAM.
JavaPairRDD<String, GATKRead> firstKeyed = firstReads.mapToPair(read -> new Tuple2<>(ReadsKey.keyForFragment(bHeader.getValue(), read), read));
JavaPairRDD<String, GATKRead> secondKeyed = secondReads.mapToPair(read -> new Tuple2<>(ReadsKey.keyForFragment(bHeader.getValue(), read), read));
JavaPairRDD<String, Tuple2<Iterable<GATKRead>, Iterable<GATKRead>>> cogroup = firstKeyed.cogroup(secondKeyed, getRecommendedNumReducers());
// Produces an RDD of MatchTypes, e.g., EQUAL, DIFFERENT_REPRESENTATIVE_READ, etc. per MarkDuplicates key,
// which is approximately start position x strand.
JavaRDD<MatchType> tagged = cogroup.map(v1 -> {
SAMFileHeader header = bHeader.getValue();
Iterable<GATKRead> iFirstReads = v1._2()._1();
Iterable<GATKRead> iSecondReads = v1._2()._2();
return getDupes(iFirstReads, iSecondReads, header);
});
// TODO: We should also produce examples of reads that don't match to make debugging easier (#1263).
Map<MatchType, Integer> tagCountMap = tagged.mapToPair(v1 -> new Tuple2<>(v1, 1)).reduceByKey((v1, v2) -> v1 + v2).collectAsMap();
if (tagCountMap.get(MatchType.SIZE_UNEQUAL) != null) {
throw new UserException("The number of reads by the MarkDuplicates key were unequal, indicating that the BAMs are not the same");
}
if (tagCountMap.get(MatchType.READ_MISMATCH) != null) {
throw new UserException("The reads grouped by the MarkDuplicates key were not the same, indicating that the BAMs are not the same");
}
if (printSummary) {
MatchType[] values = MatchType.values();
Set<MatchType> matchTypes = Sets.newLinkedHashSet(Sets.newHashSet(values));
System.out.println("##############################");
matchTypes.forEach(s -> System.out.println(s + ": " + tagCountMap.getOrDefault(s, 0)));
}
if (throwOnDiff) {
for (MatchType s : MatchType.values()) {
if (s != MatchType.EQUAL) {
if (tagCountMap.get(s) != null)
throw new UserException("found difference between the two BAMs: " + s + " with count " + tagCountMap.get(s));
}
}
}
}
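A minimal sketch of the cogroup-and-count pattern used above, with plain String keys and values standing in for GATK reads and match types; the class name and sample data are illustrative.

import java.util.Arrays;
import java.util.Map;
import com.google.common.collect.Iterables;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import scala.Tuple2;

public class CogroupCompareSketch {
    public static void main(String[] args) {
        JavaSparkContext ctx = new JavaSparkContext("local[*]", "cogroup-sketch");
        JavaPairRDD<String, String> first = ctx.parallelizePairs(Arrays.asList(
                new Tuple2<>("key1", "a"), new Tuple2<>("key2", "b")));
        JavaPairRDD<String, String> second = ctx.parallelizePairs(Arrays.asList(
                new Tuple2<>("key1", "a"), new Tuple2<>("key2", "c")));
        // cogroup by key, classify each key's two groups, then count the classifications
        Map<String, Integer> counts = first.cogroup(second)
                .map(kv -> Iterables.elementsEqual(kv._2()._1(), kv._2()._2()) ? "EQUAL" : "DIFFERENT")
                .mapToPair(tag -> new Tuple2<>(tag, 1))
                .reduceByKey((a, b) -> a + b)
                .collectAsMap();
        System.out.println(counts);
        ctx.stop();
    }
}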
Use of org.apache.spark.api.java.JavaPairRDD in project grakn by graknlabs.
The class GraknSparkComputer, method submitWithExecutor.
@SuppressWarnings("PMD.UnusedFormalParameter")
private Future<ComputerResult> submitWithExecutor() {
jobGroupId = Integer.toString(ThreadLocalRandom.current().nextInt(Integer.MAX_VALUE));
String jobDescription = this.vertexProgram == null ? this.mapReducers.toString() : this.vertexProgram + "+" + this.mapReducers;
// Use different output locations
this.sparkConfiguration.setProperty(Constants.GREMLIN_HADOOP_OUTPUT_LOCATION, this.sparkConfiguration.getString(Constants.GREMLIN_HADOOP_OUTPUT_LOCATION) + "/" + jobGroupId);
updateConfigKeys(sparkConfiguration);
final Future<ComputerResult> result = computerService.submit(() -> {
final long startTime = System.currentTimeMillis();
// apache and hadoop configurations that are used throughout the graph computer computation
final org.apache.commons.configuration.Configuration graphComputerConfiguration = new HadoopConfiguration(this.sparkConfiguration);
if (!graphComputerConfiguration.containsKey(Constants.SPARK_SERIALIZER)) {
graphComputerConfiguration.setProperty(Constants.SPARK_SERIALIZER, GryoSerializer.class.getCanonicalName());
}
graphComputerConfiguration.setProperty(Constants.GREMLIN_HADOOP_GRAPH_WRITER_HAS_EDGES, this.persist.equals(GraphComputer.Persist.EDGES));
final Configuration hadoopConfiguration = ConfUtil.makeHadoopConfiguration(graphComputerConfiguration);
final Storage fileSystemStorage = FileSystemStorage.open(hadoopConfiguration);
final boolean inputFromHDFS = FileInputFormat.class.isAssignableFrom(hadoopConfiguration.getClass(Constants.GREMLIN_HADOOP_GRAPH_READER, Object.class));
final boolean inputFromSpark = PersistedInputRDD.class.isAssignableFrom(hadoopConfiguration.getClass(Constants.GREMLIN_HADOOP_GRAPH_READER, Object.class));
final boolean outputToHDFS = FileOutputFormat.class.isAssignableFrom(hadoopConfiguration.getClass(Constants.GREMLIN_HADOOP_GRAPH_WRITER, Object.class));
final boolean outputToSpark = PersistedOutputRDD.class.isAssignableFrom(hadoopConfiguration.getClass(Constants.GREMLIN_HADOOP_GRAPH_WRITER, Object.class));
final boolean skipPartitioner = graphComputerConfiguration.getBoolean(Constants.GREMLIN_SPARK_SKIP_PARTITIONER, false);
final boolean skipPersist = graphComputerConfiguration.getBoolean(Constants.GREMLIN_SPARK_SKIP_GRAPH_CACHE, false);
if (inputFromHDFS) {
String inputLocation = Constants.getSearchGraphLocation(hadoopConfiguration.get(Constants.GREMLIN_HADOOP_INPUT_LOCATION), fileSystemStorage).orElse(null);
if (null != inputLocation) {
try {
graphComputerConfiguration.setProperty(Constants.MAPREDUCE_INPUT_FILEINPUTFORMAT_INPUTDIR, FileSystem.get(hadoopConfiguration).getFileStatus(new Path(inputLocation)).getPath().toString());
hadoopConfiguration.set(Constants.MAPREDUCE_INPUT_FILEINPUTFORMAT_INPUTDIR, FileSystem.get(hadoopConfiguration).getFileStatus(new Path(inputLocation)).getPath().toString());
} catch (final IOException e) {
throw new IllegalStateException(e.getMessage(), e);
}
}
}
final InputRDD inputRDD;
final OutputRDD outputRDD;
final boolean filtered;
try {
inputRDD = InputRDD.class.isAssignableFrom(hadoopConfiguration.getClass(Constants.GREMLIN_HADOOP_GRAPH_READER, Object.class)) ? hadoopConfiguration.getClass(Constants.GREMLIN_HADOOP_GRAPH_READER, InputRDD.class, InputRDD.class).newInstance() : InputFormatRDD.class.newInstance();
outputRDD = OutputRDD.class.isAssignableFrom(hadoopConfiguration.getClass(Constants.GREMLIN_HADOOP_GRAPH_WRITER, Object.class)) ? hadoopConfiguration.getClass(Constants.GREMLIN_HADOOP_GRAPH_WRITER, OutputRDD.class, OutputRDD.class).newInstance() : OutputFormatRDD.class.newInstance();
// if the input class can filter on load, then set the filters
if (inputRDD instanceof InputFormatRDD && GraphFilterAware.class.isAssignableFrom(hadoopConfiguration.getClass(Constants.GREMLIN_HADOOP_GRAPH_READER, InputFormat.class, InputFormat.class))) {
GraphFilterAware.storeGraphFilter(graphComputerConfiguration, hadoopConfiguration, this.graphFilter);
filtered = false;
} else if (inputRDD instanceof GraphFilterAware) {
((GraphFilterAware) inputRDD).setGraphFilter(this.graphFilter);
filtered = false;
} else
filtered = this.graphFilter.hasFilter();
} catch (final InstantiationException | IllegalAccessException e) {
throw new IllegalStateException(e.getMessage(), e);
}
// create the spark context from the graph computer configuration
final JavaSparkContext sparkContext = new JavaSparkContext(Spark.create(hadoopConfiguration));
final Storage sparkContextStorage = SparkContextStorage.open();
sparkContext.setJobGroup(jobGroupId, jobDescription);
GraknSparkMemory memory = null;
// delete output location
final String outputLocation = hadoopConfiguration.get(Constants.GREMLIN_HADOOP_OUTPUT_LOCATION, null);
if (null != outputLocation) {
if (outputToHDFS && fileSystemStorage.exists(outputLocation)) {
fileSystemStorage.rm(outputLocation);
}
if (outputToSpark && sparkContextStorage.exists(outputLocation)) {
sparkContextStorage.rm(outputLocation);
}
}
// the Spark application name will always be set by SparkContextStorage,
// thus, INFO the name to make it easier to debug
logger.debug(Constants.GREMLIN_HADOOP_SPARK_JOB_PREFIX + (null == this.vertexProgram ? "No VertexProgram" : this.vertexProgram) + "[" + this.mapReducers + "]");
// add the project jars to the cluster
this.loadJars(hadoopConfiguration, sparkContext);
updateLocalConfiguration(sparkContext, hadoopConfiguration);
// create a message-passing friendly rdd from the input rdd
boolean partitioned = false;
JavaPairRDD<Object, VertexWritable> loadedGraphRDD = inputRDD.readGraphRDD(graphComputerConfiguration, sparkContext);
// if there are vertex or edge filters, filter the loaded graph rdd prior to partitioning and persisting
if (filtered) {
this.logger.debug("Filtering the loaded graphRDD: " + this.graphFilter);
loadedGraphRDD = GraknSparkExecutor.applyGraphFilter(loadedGraphRDD, this.graphFilter);
}
// if the loaded graphRDD already has a partitioner, use it; else partition it with HashPartitioner
if (loadedGraphRDD.partitioner().isPresent()) {
this.logger.debug("Using the existing partitioner associated with the loaded graphRDD: " + loadedGraphRDD.partitioner().get());
} else {
if (!skipPartitioner) {
final Partitioner partitioner = new HashPartitioner(this.workersSet ? this.workers : loadedGraphRDD.partitions().size());
this.logger.debug("Partitioning the loaded graphRDD: " + partitioner);
loadedGraphRDD = loadedGraphRDD.partitionBy(partitioner);
partitioned = true;
assert loadedGraphRDD.partitioner().isPresent();
} else {
// no easy way to test this with a test case
assert skipPartitioner == !loadedGraphRDD.partitioner().isPresent();
this.logger.debug("Partitioning has been skipped for the loaded graphRDD via " + Constants.GREMLIN_SPARK_SKIP_PARTITIONER);
}
}
// if the number of workers was not explicitly set, this coalesce/repartition will not take place
if (this.workersSet) {
// ensures that the loaded graphRDD does not have more partitions than workers
if (loadedGraphRDD.partitions().size() > this.workers) {
loadedGraphRDD = loadedGraphRDD.coalesce(this.workers);
} else {
// ensures that the loaded graphRDD does not have fewer partitions than workers
if (loadedGraphRDD.partitions().size() < this.workers) {
loadedGraphRDD = loadedGraphRDD.repartition(this.workers);
}
}
}
// persist the loaded graphRDD with the configured storage level, or else use the default cache() which is MEMORY_ONLY
if (!skipPersist && (!inputFromSpark || partitioned || filtered)) {
loadedGraphRDD = loadedGraphRDD.persist(StorageLevel.fromString(hadoopConfiguration.get(Constants.GREMLIN_SPARK_GRAPH_STORAGE_LEVEL, "MEMORY_ONLY")));
}
// final graph with view
// (for persisting and/or mapReducing -- may be null and thus, possible to save space/time)
JavaPairRDD<Object, VertexWritable> computedGraphRDD = null;
try {
// //////////////////////////////
if (null != this.vertexProgram) {
memory = new GraknSparkMemory(this.vertexProgram, this.mapReducers, sparkContext);
// if there is a registered VertexProgramInterceptor, use it to bypass the GraphComputer semantics
if (graphComputerConfiguration.containsKey(Constants.GREMLIN_HADOOP_VERTEX_PROGRAM_INTERCEPTOR)) {
try {
final GraknSparkVertexProgramInterceptor<VertexProgram> interceptor = (GraknSparkVertexProgramInterceptor) Class.forName(graphComputerConfiguration.getString(Constants.GREMLIN_HADOOP_VERTEX_PROGRAM_INTERCEPTOR)).newInstance();
computedGraphRDD = interceptor.apply(this.vertexProgram, loadedGraphRDD, memory);
} catch (final ClassNotFoundException | IllegalAccessException | InstantiationException e) {
throw new IllegalStateException(e.getMessage());
}
} else {
// standard GraphComputer semantics
// get a configuration that will be propagated to all workers
final HadoopConfiguration vertexProgramConfiguration = new HadoopConfiguration();
this.vertexProgram.storeState(vertexProgramConfiguration);
// set up the vertex program and wire up configurations
this.vertexProgram.setup(memory);
JavaPairRDD<Object, ViewIncomingPayload<Object>> viewIncomingRDD = null;
memory.broadcastMemory(sparkContext);
// execute the vertex program
while (true) {
if (Thread.interrupted()) {
sparkContext.cancelAllJobs();
throw new TraversalInterruptedException();
}
memory.setInExecute(true);
viewIncomingRDD = GraknSparkExecutor.executeVertexProgramIteration(loadedGraphRDD, viewIncomingRDD, memory, graphComputerConfiguration, vertexProgramConfiguration);
memory.setInExecute(false);
if (this.vertexProgram.terminate(memory)) {
break;
} else {
memory.incrIteration();
memory.broadcastMemory(sparkContext);
}
}
// then generate a view+graph
if ((null != outputRDD && !this.persist.equals(Persist.NOTHING)) || !this.mapReducers.isEmpty()) {
computedGraphRDD = GraknSparkExecutor.prepareFinalGraphRDD(loadedGraphRDD, viewIncomingRDD, this.vertexProgram.getVertexComputeKeys());
assert null != computedGraphRDD && computedGraphRDD != loadedGraphRDD;
} else {
// ensure that the computedGraphRDD was not created
assert null == computedGraphRDD;
}
}
// ///////////////
// drop all transient memory keys
memory.complete();
// write the computed graph to the respective output (rdd or output format)
if (null != outputRDD && !this.persist.equals(Persist.NOTHING)) {
// the logic holds that a computedGraphRDD must be created at this point
assert null != computedGraphRDD;
outputRDD.writeGraphRDD(graphComputerConfiguration, computedGraphRDD);
}
}
final boolean computedGraphCreated = computedGraphRDD != null && computedGraphRDD != loadedGraphRDD;
if (!computedGraphCreated) {
computedGraphRDD = loadedGraphRDD;
}
final Memory.Admin finalMemory = null == memory ? new MapMemory() : new MapMemory(memory);
// ////////////////////////////
if (!this.mapReducers.isEmpty()) {
// create a mapReduceRDD for executing the map reduce jobs on
JavaPairRDD<Object, VertexWritable> mapReduceRDD = computedGraphRDD;
if (computedGraphCreated && !outputToSpark) {
// drop all the edges of the graph as they are not used in mapReduce processing
mapReduceRDD = computedGraphRDD.mapValues(vertexWritable -> {
vertexWritable.get().dropEdges(Direction.BOTH);
return vertexWritable;
});
// if there is only one MapReduce to execute, don't bother wasting the clock cycles.
if (this.mapReducers.size() > 1) {
mapReduceRDD = mapReduceRDD.persist(StorageLevel.fromString(hadoopConfiguration.get(Constants.GREMLIN_SPARK_GRAPH_STORAGE_LEVEL, "MEMORY_ONLY")));
}
}
for (final MapReduce mapReduce : this.mapReducers) {
// execute the map reduce job
final HadoopConfiguration newApacheConfiguration = new HadoopConfiguration(graphComputerConfiguration);
mapReduce.storeState(newApacheConfiguration);
// map
final JavaPairRDD mapRDD = GraknSparkExecutor.executeMap(mapReduceRDD, mapReduce, newApacheConfiguration);
// combine
final JavaPairRDD combineRDD = mapReduce.doStage(MapReduce.Stage.COMBINE) ? GraknSparkExecutor.executeCombine(mapRDD, newApacheConfiguration) : mapRDD;
// reduce
final JavaPairRDD reduceRDD = mapReduce.doStage(MapReduce.Stage.REDUCE) ? GraknSparkExecutor.executeReduce(combineRDD, mapReduce, newApacheConfiguration) : combineRDD;
// write the map reduce output back to disk and computer result memory
if (null != outputRDD) {
mapReduce.addResultToMemory(finalMemory, outputRDD.writeMemoryRDD(graphComputerConfiguration, mapReduce.getMemoryKey(), reduceRDD));
}
}
// if the mapReduceRDD is not simply the computed graph, unpersist the mapReduceRDD
if (computedGraphCreated && !outputToSpark) {
assert loadedGraphRDD != computedGraphRDD;
assert mapReduceRDD != computedGraphRDD;
mapReduceRDD.unpersist();
} else {
assert mapReduceRDD == computedGraphRDD;
}
}
// if the graphRDD was loaded from Spark, but then partitioned or filtered, it's a different RDD
if (!inputFromSpark || partitioned || filtered) {
loadedGraphRDD.unpersist();
}
// if the computed graph is being output to Spark and persisted, don't unpersist the computedGraphRDD/loadedGraphRDD; otherwise release it
if ((!outputToSpark || this.persist.equals(GraphComputer.Persist.NOTHING)) && computedGraphCreated) {
computedGraphRDD.unpersist();
}
// delete any file system or rdd data if persist nothing
if (null != outputLocation && this.persist.equals(GraphComputer.Persist.NOTHING)) {
if (outputToHDFS) {
fileSystemStorage.rm(outputLocation);
}
if (outputToSpark) {
sparkContextStorage.rm(outputLocation);
}
}
// update runtime and return the newly computed graph
finalMemory.setRuntime(System.currentTimeMillis() - startTime);
// clear properties that should not be propagated in an OLAP chain
graphComputerConfiguration.clearProperty(Constants.GREMLIN_HADOOP_GRAPH_FILTER);
graphComputerConfiguration.clearProperty(Constants.GREMLIN_HADOOP_VERTEX_PROGRAM_INTERCEPTOR);
graphComputerConfiguration.clearProperty(Constants.GREMLIN_SPARK_SKIP_GRAPH_CACHE);
graphComputerConfiguration.clearProperty(Constants.GREMLIN_SPARK_SKIP_PARTITIONER);
return new DefaultComputerResult(InputOutputHelper.getOutputGraph(graphComputerConfiguration, this.resultGraph, this.persist), finalMemory.asImmutable());
} catch (Exception e) {
// So it throws the same exception as TinkerPop does
throw new RuntimeException(e);
}
});
computerService.shutdown();
return result;
}
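The graph computer above conditionally partitions, coalesces/repartitions, and persists the loaded graphRDD. Below is a minimal sketch of that step in isolation, assuming a toy JavaPairRDD; the worker count and storage-level string are illustrative values rather than anything read from the Grakn configuration.

import java.util.Arrays;
import org.apache.spark.HashPartitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.storage.StorageLevel;
import scala.Tuple2;

public class PartitionPersistSketch {
    public static void main(String[] args) {
        JavaSparkContext ctx = new JavaSparkContext("local[*]", "partition-sketch");
        JavaPairRDD<Long, String> graphRDD = ctx.parallelizePairs(Arrays.asList(
                new Tuple2<>(1L, "v1"), new Tuple2<>(2L, "v2"), new Tuple2<>(3L, "v3")));
        int workers = 2;
        // partition by key only if no partitioner is already associated with the RDD
        if (!graphRDD.partitioner().isPresent()) {
            graphRDD = graphRDD.partitionBy(new HashPartitioner(workers));
        }
        // keep the partition count aligned with the number of workers
        if (graphRDD.partitions().size() > workers) {
            graphRDD = graphRDD.coalesce(workers);
        } else if (graphRDD.partitions().size() < workers) {
            graphRDD = graphRDD.repartition(workers);
        }
        // persist with an explicit storage level (the computer above reads this from its configuration)
        graphRDD = graphRDD.persist(StorageLevel.fromString("MEMORY_ONLY"));
        System.out.println("partitions: " + graphRDD.partitions().size());
        ctx.stop();
    }
}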
Use of org.apache.spark.api.java.JavaPairRDD in project incubator-systemml by apache.
The class MLContextConversionUtil, method frameObjectToDataFrame.
/**
* Convert a {@code FrameObject} to a {@code DataFrame}.
*
* @param frameObject
* the {@code FrameObject}
* @param sparkExecutionContext
* the Spark execution context
* @return the {@code FrameObject} converted to a {@code DataFrame}
*/
public static Dataset<Row> frameObjectToDataFrame(FrameObject frameObject, SparkExecutionContext sparkExecutionContext) {
try {
@SuppressWarnings("unchecked") JavaPairRDD<Long, FrameBlock> binaryBlockFrame = (JavaPairRDD<Long, FrameBlock>) sparkExecutionContext.getRDDHandleForFrameObject(frameObject, InputInfo.BinaryBlockInputInfo);
MatrixCharacteristics mc = frameObject.getMatrixCharacteristics();
return FrameRDDConverterUtils.binaryBlockToDataFrame(spark(), binaryBlockFrame, mc, frameObject.getSchema());
} catch (DMLRuntimeException e) {
throw new MLContextException("DMLRuntimeException while converting frame object to DataFrame", e);
}
}
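The helper above delegates the binary-block-to-DataFrame conversion to FrameRDDConverterUtils. For comparison, here is a minimal sketch of turning a JavaPairRDD into a Dataset<Row> with plain Spark types; the schema, class name, and sample values are assumptions chosen for illustration.

import java.util.Arrays;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;

public class PairRDDToDataFrameSketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().master("local[*]").appName("pair-to-df").getOrCreate();
        JavaSparkContext ctx = new JavaSparkContext(spark.sparkContext());
        JavaPairRDD<Long, String> pairs = ctx.parallelizePairs(Arrays.asList(
                new Tuple2<>(1L, "alpha"), new Tuple2<>(2L, "beta")));
        // map each (key, value) pair to a Row, then attach an explicit schema
        JavaRDD<Row> rows = pairs.map(t -> RowFactory.create(t._1(), t._2()));
        StructType schema = DataTypes.createStructType(Arrays.asList(
                DataTypes.createStructField("id", DataTypes.LongType, false),
                DataTypes.createStructField("value", DataTypes.StringType, false)));
        Dataset<Row> df = spark.createDataFrame(rows, schema);
        df.show();
        spark.stop();
    }
}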
Use of org.apache.spark.api.java.JavaPairRDD in project incubator-systemml by apache.
The class MLContextUtil, method convertInputType.
/**
* Convert input types to internal SystemML representations
*
* @param parameterName
* The name of the input parameter
* @param parameterValue
* The value of the input parameter
* @param metadata
* matrix/frame metadata
* @return input in SystemML data representation
*/
public static Data convertInputType(String parameterName, Object parameterValue, Metadata metadata) {
String name = parameterName;
Object value = parameterValue;
boolean hasMetadata = (metadata != null) ? true : false;
boolean hasMatrixMetadata = hasMetadata && (metadata instanceof MatrixMetadata) ? true : false;
boolean hasFrameMetadata = hasMetadata && (metadata instanceof FrameMetadata) ? true : false;
if (name == null) {
throw new MLContextException("Input parameter name is null");
} else if (value == null) {
throw new MLContextException("Input parameter value is null for: " + parameterName);
} else if (value instanceof JavaRDD<?>) {
@SuppressWarnings("unchecked") JavaRDD<String> javaRDD = (JavaRDD<String>) value;
if (hasMatrixMetadata) {
MatrixMetadata matrixMetadata = (MatrixMetadata) metadata;
if (matrixMetadata.getMatrixFormat() == MatrixFormat.IJV) {
return MLContextConversionUtil.javaRDDStringIJVToMatrixObject(javaRDD, matrixMetadata);
} else {
return MLContextConversionUtil.javaRDDStringCSVToMatrixObject(javaRDD, matrixMetadata);
}
} else if (hasFrameMetadata) {
FrameMetadata frameMetadata = (FrameMetadata) metadata;
if (frameMetadata.getFrameFormat() == FrameFormat.IJV) {
return MLContextConversionUtil.javaRDDStringIJVToFrameObject(javaRDD, frameMetadata);
} else {
return MLContextConversionUtil.javaRDDStringCSVToFrameObject(javaRDD, frameMetadata);
}
} else if (!hasMetadata) {
String firstLine = javaRDD.first();
boolean isAllNumbers = isCSVLineAllNumbers(firstLine);
if (isAllNumbers) {
return MLContextConversionUtil.javaRDDStringCSVToMatrixObject(javaRDD);
} else {
return MLContextConversionUtil.javaRDDStringCSVToFrameObject(javaRDD);
}
}
} else if (value instanceof RDD<?>) {
@SuppressWarnings("unchecked") RDD<String> rdd = (RDD<String>) value;
if (hasMatrixMetadata) {
MatrixMetadata matrixMetadata = (MatrixMetadata) metadata;
if (matrixMetadata.getMatrixFormat() == MatrixFormat.IJV) {
return MLContextConversionUtil.rddStringIJVToMatrixObject(rdd, matrixMetadata);
} else {
return MLContextConversionUtil.rddStringCSVToMatrixObject(rdd, matrixMetadata);
}
} else if (hasFrameMetadata) {
FrameMetadata frameMetadata = (FrameMetadata) metadata;
if (frameMetadata.getFrameFormat() == FrameFormat.IJV) {
return MLContextConversionUtil.rddStringIJVToFrameObject(rdd, frameMetadata);
} else {
return MLContextConversionUtil.rddStringCSVToFrameObject(rdd, frameMetadata);
}
} else if (!hasMetadata) {
String firstLine = rdd.first();
boolean isAllNumbers = isCSVLineAllNumbers(firstLine);
if (isAllNumbers) {
return MLContextConversionUtil.rddStringCSVToMatrixObject(rdd);
} else {
return MLContextConversionUtil.rddStringCSVToFrameObject(rdd);
}
}
} else if (value instanceof MatrixBlock) {
MatrixBlock matrixBlock = (MatrixBlock) value;
return MLContextConversionUtil.matrixBlockToMatrixObject(name, matrixBlock, (MatrixMetadata) metadata);
} else if (value instanceof FrameBlock) {
FrameBlock frameBlock = (FrameBlock) value;
return MLContextConversionUtil.frameBlockToFrameObject(name, frameBlock, (FrameMetadata) metadata);
} else if (value instanceof Dataset<?>) {
@SuppressWarnings("unchecked") Dataset<Row> dataFrame = (Dataset<Row>) value;
dataFrame = MLUtils.convertVectorColumnsToML(dataFrame);
if (hasMatrixMetadata) {
return MLContextConversionUtil.dataFrameToMatrixObject(dataFrame, (MatrixMetadata) metadata);
} else if (hasFrameMetadata) {
return MLContextConversionUtil.dataFrameToFrameObject(dataFrame, (FrameMetadata) metadata);
} else if (!hasMetadata) {
boolean looksLikeMatrix = doesDataFrameLookLikeMatrix(dataFrame);
if (looksLikeMatrix) {
return MLContextConversionUtil.dataFrameToMatrixObject(dataFrame);
} else {
return MLContextConversionUtil.dataFrameToFrameObject(dataFrame);
}
}
} else if (value instanceof Matrix) {
Matrix matrix = (Matrix) value;
if ((matrix.hasBinaryBlocks()) && (!matrix.hasMatrixObject())) {
if (metadata == null) {
metadata = matrix.getMatrixMetadata();
}
JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks = matrix.toBinaryBlocks();
return MLContextConversionUtil.binaryBlocksToMatrixObject(binaryBlocks, (MatrixMetadata) metadata);
} else {
return matrix.toMatrixObject();
}
} else if (value instanceof Frame) {
Frame frame = (Frame) value;
if ((frame.hasBinaryBlocks()) && (!frame.hasFrameObject())) {
if (metadata == null) {
metadata = frame.getFrameMetadata();
}
JavaPairRDD<Long, FrameBlock> binaryBlocks = frame.toBinaryBlocks();
return MLContextConversionUtil.binaryBlocksToFrameObject(binaryBlocks, (FrameMetadata) metadata);
} else {
return frame.toFrameObject();
}
} else if (value instanceof double[][]) {
double[][] doubleMatrix = (double[][]) value;
return MLContextConversionUtil.doubleMatrixToMatrixObject(name, doubleMatrix, (MatrixMetadata) metadata);
} else if (value instanceof URL) {
URL url = (URL) value;
return MLContextConversionUtil.urlToMatrixObject(url, (MatrixMetadata) metadata);
} else if (value instanceof Integer) {
return new IntObject((Integer) value);
} else if (value instanceof Double) {
return new DoubleObject((Double) value);
} else if (value instanceof String) {
return new StringObject((String) value);
} else if (value instanceof Boolean) {
return new BooleanObject((Boolean) value);
}
return null;
}
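A hedged usage sketch for convertInputType: passing a CSV-formatted JavaRDD<String> with no metadata so that the method routes it through the all-numbers check above and produces a matrix object. The class name is hypothetical and the import path for MLContextUtil (org.apache.sysml.api.mlcontext) is an assumption based on this snippet, not verified against the SystemML codebase; the call shape follows the signature shown above.

import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.sysml.api.mlcontext.MLContextUtil;

public class ConvertInputTypeSketch {
    public static void main(String[] args) {
        JavaSparkContext ctx = new JavaSparkContext("local[*]", "convert-input-sketch");
        // a CSV RDD whose first line is all numbers, so the no-metadata branch above
        // should convert it to a matrix rather than a frame
        JavaRDD<String> csv = ctx.parallelize(Arrays.asList("1.0,2.0", "3.0,4.0"));
        Object converted = MLContextUtil.convertInputType("X", csv, null);
        System.out.println(converted);
        ctx.stop();
    }
}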