Use of edu.cmu.tetrad.util.ForkJoinPoolInstance in the project tetrad by cmu-phil.
Class GraphUtils, method edgeMisclassificationCounts:
/**
 * Tallies how each connection in the union of two graphs' edges is classified in
 * {@code leftGraph} versus {@code topGraph}.
 *
 * <p>For every edge in the union, the edges between the same two endpoints are looked up in
 * both graphs (presumably {@code null} when a graph has no such edge — confirm against
 * {@code Graph.getEdge}). The row of the cell to increment is computed by
 * {@code getTypeLeft(left, top)} and the column by {@code getTypeTop(top)}. The edge list is
 * split into chunks of 500 and tallied on the shared {@link ForkJoinPoolInstance} pool; the
 * partial tallies are merged with {@code Counts.addAll}.
 *
 * <p>NOTE(review): the union is de-duplicated via {@code Edge.equals}. If the two graphs hold
 * different (non-equal) edges between the same pair of nodes, both survive and that node pair
 * is tallied twice. Confirm whether this is intended.
 *
 * @param leftGraph the graph whose edges, together with {@code topGraph}'s, determine the row
 * @param topGraph  the graph whose edges determine the column
 * @param print     if {@code true}, prints edge-count summaries and progress every 1000 edges
 *                  to standard output
 * @return the tally matrix produced by {@code Counts.countArray()}; the meaning of rows and
 *     columns is defined by {@code getTypeLeft} and {@code getTypeTop}
 */
public static int[][] edgeMisclassificationCounts(Graph leftGraph, Graph topGraph, boolean print) {

    // Divide-and-conquer tally over edges[from, to): ranges of at most `chunk` edges are counted
    // directly into a fresh Counts; larger ranges are split in half and the results merged.
    class CountTask extends RecursiveTask<Counts> {
        private final int chunk;
        private final int from;
        private final int to;
        private final List<Edge> edges;
        private final Graph leftGraph;
        private final Graph topGraph;
        private final boolean print;
        // Shared by every subtask, which may run on different pool threads, so it must be atomic
        // (a plain int[] counter loses increments under contention).
        private final java.util.concurrent.atomic.AtomicInteger processed;

        CountTask(int chunk, int from, int to, List<Edge> edges, Graph leftGraph, Graph topGraph,
                  boolean print, java.util.concurrent.atomic.AtomicInteger processed) {
            this.chunk = chunk;
            this.from = from;
            this.to = to;
            this.edges = edges;
            this.leftGraph = leftGraph;
            this.topGraph = topGraph;
            this.print = print;
            this.processed = processed;
        }

        @Override
        protected Counts compute() {
            if (to - from <= chunk) {
                Counts counts = new Counts();
                for (int i = from; i < to; i++) {
                    int done = processed.incrementAndGet();
                    // Progress output honors the caller's print flag like the summary output does.
                    if (print && done % 1000 == 0) {
                        System.out.println("Counted " + done);
                    }
                    Edge edge = edges.get(i);
                    Node x = edge.getNode1();
                    Node y = edge.getNode2();
                    Edge left = leftGraph.getEdge(x, y);
                    Edge top = topGraph.getEdge(x, y);
                    counts.increment(getTypeLeft(left, top), getTypeTop(top));
                }
                return counts;
            }

            // Unsigned shift avoids int overflow in the midpoint computation.
            int mid = (from + to) >>> 1;
            CountTask leftTask = new CountTask(chunk, from, mid, edges, leftGraph, topGraph, print, processed);
            CountTask rightTask = new CountTask(chunk, mid, to, edges, leftGraph, topGraph, print, processed);
            leftTask.fork();
            Counts rightAnswer = rightTask.compute();
            Counts leftAnswer = leftTask.join();
            leftAnswer.addAll(rightAnswer);
            return leftAnswer;
        }
    }

    Set<Edge> edgeSet = new HashSet<>();
    edgeSet.addAll(topGraph.getEdges());
    edgeSet.addAll(leftGraph.getEdges());

    if (print) {
        System.out.println("Top graph " + topGraph.getEdges().size());
        System.out.println("Left graph " + leftGraph.getEdges().size());
        System.out.println("All edges " + edgeSet.size());
    }

    List<Edge> edges = new ArrayList<>(edgeSet);

    ForkJoinPoolInstance pool = ForkJoinPoolInstance.getInstance();
    CountTask task = new CountTask(500, 0, edges.size(), edges, leftGraph, topGraph, print,
            new java.util.concurrent.atomic.AtomicInteger());
    Counts counts = pool.getPool().invoke(task);
    return counts.countArray();
}
Aggregations