Use of org.apache.spark.api.java.function.Function in project gatk by broadinstitute: class CoverageModelEMWorkspace, method updateSampleUnexplainedVariance.
/**
* E-step update of the sample-specific unexplained variance
*
* @return a {@link SubroutineSignal} containing the update size (key: "error_norm") and the average
* number of function evaluations per sample (key: "iterations")
*/
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
public SubroutineSignal updateSampleUnexplainedVariance() {
mapWorkers(cb -> cb.cloneWithUpdatedCachesByTag(CoverageModelEMComputeBlock.CoverageModelICGCacheTag.E_STEP_GAMMA));
cacheWorkers("after E-step for sample unexplained variance initialization");
/* create a compound objective function for simultaneous multi-sample queries */
final java.util.function.Function<Map<Integer, Double>, Map<Integer, Double>> objFunc = arg -> {
if (arg.isEmpty()) {
return Collections.emptyMap();
}
final int[] sampleIndices = arg.keySet().stream().mapToInt(i -> i).toArray();
final INDArray gammaValues = Nd4j.create(Arrays.stream(sampleIndices).mapToDouble(arg::get).toArray(), new int[] { sampleIndices.length, 1 });
final INDArray eval = mapWorkersAndReduce(cb -> cb.calculateSampleSpecificVarianceObjectiveFunctionMultiSample(sampleIndices, gammaValues), INDArray::add);
final Map<Integer, Double> output = new HashMap<>();
IntStream.range(0, sampleIndices.length).forEach(evalIdx -> output.put(sampleIndices[evalIdx], eval.getDouble(evalIdx)));
return output;
};
final java.util.function.Function<UnivariateSolverSpecifications, AbstractUnivariateSolver> solverFactory = spec -> new RobustBrentSolver(spec.getRelativeAccuracy(), spec.getAbsoluteAccuracy(), spec.getFunctionValueAccuracy(), null, config.getSampleSpecificVarianceSolverNumBisections(), config.getSampleSpecificVarianceSolverRefinementDepth());
/* instantiate a synchronized multi-sample root finder and add jobs */
final SynchronizedUnivariateSolver syncSolver = new SynchronizedUnivariateSolver(objFunc, solverFactory, numSamples);
IntStream.range(0, numSamples).forEach(si -> {
final double x0 = 0.5 * config.getSampleSpecificVarianceUpperLimit();
syncSolver.add(si, 0, config.getSampleSpecificVarianceUpperLimit(), x0, config.getSampleSpecificVarianceAbsoluteTolerance(), config.getSampleSpecificVarianceRelativeTolerance(), config.getSampleSpecificVarianceMaximumIterations());
});
/* solve and collect statistics */
final INDArray newSampleUnexplainedVariance = Nd4j.create(numSamples, 1);
final List<Integer> numberOfEvaluations = new ArrayList<>(numSamples);
try {
final Map<Integer, SynchronizedUnivariateSolver.UnivariateSolverSummary> newSampleSpecificVarianceMap = syncSolver.solve();
newSampleSpecificVarianceMap.entrySet().forEach(entry -> {
final int sampleIndex = entry.getKey();
final SynchronizedUnivariateSolver.UnivariateSolverSummary summary = entry.getValue();
double val = 0;
switch(summary.status) {
case SUCCESS:
val = summary.x;
break;
case TOO_MANY_EVALUATIONS:
logger.warn("Could not locate the root of gamma -- increase the maximum number of" + "function evaluations");
break;
}
newSampleUnexplainedVariance.put(sampleIndex, 0, val);
numberOfEvaluations.add(summary.evaluations);
});
} catch (final InterruptedException ex) {
throw new RuntimeException("The update of sample unexplained variance was interrupted -- can not continue");
}
/* admix */
final INDArray newSampleUnexplainedVarianceAdmixed = newSampleUnexplainedVariance.mul(config.getMeanFieldAdmixingRatio()).addi(sampleUnexplainedVariance.mul(1 - config.getMeanFieldAdmixingRatio()));
/* calculate the error */
final double errorNormInfinity = CoverageModelEMWorkspaceMathUtils.getINDArrayNormInfinity(newSampleUnexplainedVarianceAdmixed.sub(sampleUnexplainedVariance));
/* update local copy */
sampleUnexplainedVariance.assign(newSampleUnexplainedVarianceAdmixed);
/* push to workers */
pushToWorkers(newSampleUnexplainedVarianceAdmixed, (arr, cb) -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.gamma_s, newSampleUnexplainedVarianceAdmixed));
final int iterations = (int) (numberOfEvaluations.stream().mapToDouble(d -> d).sum() / numSamples);
return SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errorNormInfinity).put(StandardSubroutineSignals.ITERATIONS, iterations).build();
}
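The final "admix" and error-norm steps are what the returned SubroutineSignal reports on: the freshly solved variances are blended with the previous estimate using the mean-field admixing ratio, and the infinity norm of the change becomes the residual error. Below is a minimal standalone Nd4j sketch of just that arithmetic, with made-up toy values standing in for config.getMeanFieldAdmixingRatio() and the gamma vectors, and Transforms.abs(...).maxNumber() standing in for getINDArrayNormInfinity:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

public class AdmixErrorNormSketch {
    public static void main(String[] args) {
        final double admixingRatio = 0.75; // hypothetical stand-in for config.getMeanFieldAdmixingRatio()
        final INDArray oldGamma = Nd4j.create(new double[] {0.10, 0.20, 0.05}, new int[] {3, 1}); // toy values
        final INDArray newGamma = Nd4j.create(new double[] {0.12, 0.18, 0.07}, new int[] {3, 1}); // toy values
        /* blend the new solution with the previous estimate (the "admix" step) */
        final INDArray admixed = newGamma.mul(admixingRatio).addi(oldGamma.mul(1 - admixingRatio));
        /* infinity norm of the update, reported as the residual error norm */
        final double errorNormInfinity = Transforms.abs(admixed.sub(oldGamma)).maxNumber().doubleValue();
        System.out.println("error_norm = " + errorNormInfinity);
    }
}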
Use of org.apache.spark.api.java.function.Function in project gatk by broadinstitute: class ExampleReadWalkerWithReferenceSpark, method readFunction.
private Function<ReadWalkerContext, String> readFunction() {
return (Function<ReadWalkerContext, String>) context -> {
GATKRead read = context.getRead();
ReferenceContext referenceContext = context.getReferenceContext();
StringBuilder sb = new StringBuilder();
sb.append(String.format("Read at %s:%d-%d:\n%s\n", read.getContig(), read.getStart(), read.getEnd(), read.getBasesString()));
if (referenceContext.hasBackingDataSource())
sb.append("Reference Context:\n" + new String(referenceContext.getBases()) + "\n");
sb.append("\n");
return sb.toString();
};
}
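readFunction() returns Spark's serializable Function (not java.util.function.Function), so the result can be handed straight to JavaRDD.map. A minimal standalone sketch of the same factory-method pattern, with plain Strings in place of GATK's ReadWalkerContext (class, method, and app names here are invented for illustration):

import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;

public class ReadFunctionPatternSketch {
    /* same shape as readFunction(): a factory that returns a serializable Spark Function */
    private static Function<String, String> formatFunction() {
        return (Function<String, String>) s -> String.format("Record %s (length %d)", s, s.length());
    }

    public static void main(String[] args) {
        final SparkConf conf = new SparkConf().setAppName("read-function-sketch").setMaster("local[1]");
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
            final JavaRDD<String> records = sc.parallelize(Arrays.asList("ACGT", "TTGACC"));
            /* the walker applies its Function the same way: rdd.map(readFunction()) */
            records.map(formatFunction()).collect().forEach(System.out::println);
        }
    }
}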
Use of org.apache.spark.api.java.function.Function in project beijingThirdPeriod by weidongcao: class SparkOperateBcp, method run.
public static void run(TaskBean task) {
logger.info("开始处理 {} 的BCP数据", task.getContentType());
SparkConf conf = new SparkConf().setAppName(task.getContentType());
JavaSparkContext sc = new JavaSparkContext(conf);
JavaRDD<String> originalRDD = sc.textFile(task.getBcpPath());
// Do basic processing of the BCP file data and generate an ID (the HBase RowKey / Solr Sid)
JavaRDD<String[]> valueArrrayRDD = originalRDD.mapPartitions((FlatMapFunction<Iterator<String>, String[]>) iter -> {
List<String[]> list = new ArrayList<>();
while (iter.hasNext()) {
String str = iter.next();
String[] fields = str.split("\t");
list.add(fields);
}
return list.iterator();
});
/*
 * Filter the data.
 * The field-name array does not include the id field (the HBase RowKey / Solr Sid).
 * BCP files may have been upgraded with additional fields:
 * the FTP and IM_CHAT tables gained three fields: "service_code_out", "terminal_longitude", "terminal_latitude";
 * the HTTP table gained seven fields, three of which are the same as above,
 * and the other four are: "manufacturer_code", "zipname", "bcpname", "rownumber".
 * The filter below has to account for all of these cases.
 */
JavaRDD<String[]> filterValuesRDD;
filterValuesRDD = valueArrrayRDD.filter((Function<String[], Boolean>) (String[] strings) ->
        // BCP file without the new fields
        (task.getColumns().length + 1 == strings.length)
        // BCP file with only the three new fields
        || ((task.getColumns().length + 1) == (strings.length + 3))
        // HTTP BCP file with all seven new fields
        || (BigDataConstants.CONTENT_TYPE_HTTP.equalsIgnoreCase(task.getContentType()) && ((task.getColumns().length + 1) == (strings.length + 3 + 4))));
// Write the BCP file data into HBase
bcpWriteIntoHBase(filterValuesRDD, task);
sc.close();
}
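The mapPartitions step above only splits each line on a tab, so the same result can also be written with a plain map; mapPartitions pays off when there is per-partition setup work to amortize. A minimal local sketch of the simpler form (the sample records and app name are made up):

import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class TabSplitSketch {
    public static void main(String[] args) {
        final SparkConf conf = new SparkConf().setAppName("tab-split-sketch").setMaster("local[1]");
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
            /* stand-in for sc.textFile(task.getBcpPath()): two fake tab-delimited records */
            final JavaRDD<String> lines = sc.parallelize(Arrays.asList("a\tb\tc", "d\te\tf"));
            /* element-wise equivalent of the mapPartitions splitting step above */
            final JavaRDD<String[]> fields = lines.map(line -> line.split("\t"));
            fields.collect().forEach(arr -> System.out.println(Arrays.toString(arr)));
        }
    }
}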
Use of org.apache.spark.api.java.function.Function in project calcite by apache: class SparkRules, method main.
// Play area
public static void main(String[] args) {
final JavaSparkContext sc = new JavaSparkContext("local[1]", "calcite");
final JavaRDD<String> file = sc.textFile("/usr/share/dict/words");
System.out.println(file.map(new Function<String, Object>() {
@Override
public Object call(String s) throws Exception {
return s.substring(0, Math.min(s.length(), 1));
}
}).distinct().count());
file.cache();
String s = file.groupBy(new Function<String, String>() {
@Override
public String call(String s) throws Exception {
return s.substring(0, Math.min(s.length(), 1));
}
}).map(new Function<Tuple2<String, Iterable<String>>, Object>() {
@Override
public Object call(Tuple2<String, Iterable<String>> pair) {
return pair._1() + ":" + Iterables.size(pair._2());
}
}).collect().toString();
System.out.print(s);
final JavaRDD<Integer> rdd = sc.parallelize(new AbstractList<Integer>() {
final Random random = new Random();
@Override
public Integer get(int index) {
System.out.println("get(" + index + ")");
return random.nextInt(100);
}
@Override
public int size() {
System.out.println("size");
return 10;
}
});
System.out.println(rdd.groupBy(new Function<Integer, Integer>() {
public Integer call(Integer integer) {
return integer % 2;
}
}).collect().toString());
System.out.println(file.flatMap(new FlatMapFunction<String, Pair<String, Integer>>() {
public Iterator<Pair<String, Integer>> call(String x) {
if (!x.startsWith("a")) {
return Collections.emptyIterator();
}
return Collections.singletonList(Pair.of(x.toUpperCase(Locale.ROOT), x.length())).iterator();
}
}).take(5).toString());
}
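The anonymous Function and FlatMapFunction classes above predate lambdas; since each of Spark's Java function interfaces has a single call method, the same pipelines compress considerably in Java 8+. A behavior-preserving sketch of the first two pipelines (class and app names invented for illustration):

import com.google.common.collect.Iterables;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class SparkRulesLambdaSketch {
    public static void main(String[] args) {
        try (JavaSparkContext sc = new JavaSparkContext("local[1]", "calcite-lambda-sketch")) {
            final JavaRDD<String> file = sc.textFile("/usr/share/dict/words");
            /* distinct first letters, as in the first pipeline above */
            System.out.println(file.map(s -> s.substring(0, Math.min(s.length(), 1)))
                    .distinct()
                    .count());
            /* group by first letter and print group sizes, as in the second pipeline above */
            System.out.println(file.groupBy(s -> s.substring(0, Math.min(s.length(), 1)))
                    .map(pair -> pair._1() + ":" + Iterables.size(pair._2()))
                    .collect());
        }
    }
}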
Use of org.apache.spark.api.java.function.Function in project cdap by caskdata: class SparkPageRankProgram, method run.
@Override
public void run(JavaSparkExecutionContext sec) throws Exception {
JavaSparkContext jsc = new JavaSparkContext();
LOG.info("Processing backlinkURLs data");
JavaPairRDD<Long, String> backlinkURLs = sec.fromStream("backlinkURLStream", String.class);
int iterationCount = getIterationCount(sec);
LOG.info("Grouping data by key");
// Grouping backlinks by unique URL in key
JavaPairRDD<String, Iterable<String>> links = backlinkURLs.values().mapToPair(new PairFunction<String, String, String>() {
@Override
public Tuple2<String, String> call(String s) {
String[] parts = SPACES.split(s);
return new Tuple2<>(parts[0], parts[1]);
}
}).distinct().groupByKey().cache();
// Initialize default rank for each key URL
JavaPairRDD<String, Double> ranks = links.mapValues(new Function<Iterable<String>, Double>() {
@Override
public Double call(Iterable<String> rs) {
return 1.0;
}
});
// Calculates and updates URL ranks continuously using PageRank algorithm.
for (int current = 0; current < iterationCount; current++) {
LOG.debug("Processing data with PageRank algorithm. Iteration {}/{}", current + 1, (iterationCount));
// Calculates URL contributions to the rank of other URLs.
JavaPairRDD<String, Double> contribs = links.join(ranks).values().flatMapToPair(new PairFlatMapFunction<Tuple2<Iterable<String>, Double>, String, Double>() {
@Override
public Iterable<Tuple2<String, Double>> call(Tuple2<Iterable<String>, Double> s) {
LOG.debug("Processing {} with rank {}", s._1(), s._2());
int urlCount = Iterables.size(s._1());
List<Tuple2<String, Double>> results = new ArrayList<>();
for (String n : s._1()) {
results.add(new Tuple2<>(n, s._2() / urlCount));
}
return results;
}
});
// Re-calculates URL ranks based on backlink contributions.
ranks = contribs.reduceByKey(new Sum()).mapValues(new Function<Double, Double>() {
@Override
public Double call(Double sum) {
return 0.15 + sum * 0.85;
}
});
}
LOG.info("Writing ranks data");
final ServiceDiscoverer discoveryServiceContext = sec.getServiceDiscoverer();
final Metrics sparkMetrics = sec.getMetrics();
JavaPairRDD<byte[], Integer> ranksRaw = ranks.mapToPair(new PairFunction<Tuple2<String, Double>, byte[], Integer>() {
@Override
public Tuple2<byte[], Integer> call(Tuple2<String, Double> tuple) throws Exception {
LOG.debug("URL {} has rank {}", Arrays.toString(tuple._1().getBytes(Charsets.UTF_8)), tuple._2());
URL serviceURL = discoveryServiceContext.getServiceURL(SparkPageRankApp.SERVICE_HANDLERS);
if (serviceURL == null) {
throw new RuntimeException("Failed to discover service: " + SparkPageRankApp.SERVICE_HANDLERS);
}
try {
URLConnection connection = new URL(serviceURL, String.format("%s/%s", SparkPageRankApp.SparkPageRankServiceHandler.TRANSFORM_PATH, tuple._2().toString())).openConnection();
try (BufferedReader reader = new BufferedReader(new InputStreamReader(connection.getInputStream(), Charsets.UTF_8))) {
String pr = reader.readLine();
if ((Integer.parseInt(pr)) == POPULAR_PAGE_THRESHOLD) {
sparkMetrics.count(POPULAR_PAGES, 1);
} else if (Integer.parseInt(pr) <= UNPOPULAR_PAGE_THRESHOLD) {
sparkMetrics.count(UNPOPULAR_PAGES, 1);
} else {
sparkMetrics.count(REGULAR_PAGES, 1);
}
return new Tuple2<>(tuple._1().getBytes(Charsets.UTF_8), Integer.parseInt(pr));
}
} catch (Exception e) {
LOG.warn("Failed to read the Stream for service {}", SparkPageRankApp.SERVICE_HANDLERS, e);
throw Throwables.propagate(e);
}
}
});
// Store calculated results in output Dataset.
// All calculated results are stored in one row.
// Each result, the calculated URL rank based on backlink contributions, is an entry of the row.
// The value of the entry is the URL rank.
sec.saveAsDataset(ranksRaw, "ranks");
LOG.info("PageRanks successfuly computed and written to \"ranks\" dataset");
}
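The per-iteration arithmetic above is the standard damped PageRank update: each URL splits its rank evenly across its outgoing links, the contributions are summed per target URL, and the new rank is 0.15 + 0.85 * sum. A plain-Java sketch of one iteration on a made-up three-page graph (no Spark or CDAP involved; all names and numbers are illustrative):

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class PageRankArithmeticSketch {
    public static void main(String[] args) {
        /* toy backlink graph (who links to whom), mirroring the "links" pair RDD above */
        Map<String, List<String>> links = new HashMap<>();
        links.put("a", Arrays.asList("b", "c"));
        links.put("b", Arrays.asList("c"));
        links.put("c", Arrays.asList("a"));
        Map<String, Double> ranks = new HashMap<>();
        links.keySet().forEach(url -> ranks.put(url, 1.0));               // mapValues(rs -> 1.0)
        /* one iteration of the same contribution/update arithmetic used in the Spark job */
        Map<String, Double> contribs = new HashMap<>();
        links.forEach((url, outLinks) -> {
            double share = ranks.get(url) / outLinks.size();              // s._2() / urlCount
            outLinks.forEach(n -> contribs.merge(n, share, Double::sum)); // reduceByKey(new Sum())
        });
        Map<String, Double> newRanks = new HashMap<>();
        contribs.forEach((url, sum) -> newRanks.put(url, 0.15 + sum * 0.85)); // mapValues(sum -> 0.15 + sum * 0.85)
        System.out.println(newRanks); // e.g. {a=1.0, b=0.575, c=1.425} (iteration order may vary)
    }
}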