Use of edu.cmu.tetrad.util.ChoiceGenerator in project tetrad by cmu-phil.
Class SampleVcpcFast, method orientUnshieldedTriples.
/**
 * Classifies every unshielded triple {@code x - y - z} of {@code graph} (x and z both adjacent
 * to y but not to each other) with {@code SearchGraphUtils.getCpcTripleType} and records it.
 *
 * <p>Every visited triple is added to {@code getAllTriples()}. Depending on the classifier:
 * <ul>
 *   <li>COLLIDER: if {@code colliderAllowed} permits, arrowheads are set at y on both the
 *       x-y and z-y edges; the triple is always added to {@code colliderTriples}.</li>
 *   <li>AMBIGUOUS: the triple is added to {@code ambiguousTriples}, marked on the graph via
 *       {@code addAmbiguousTriple}, and an undirected x--z edge is added to
 *       {@code definitelyNonadjacencies}.</li>
 *   <li>anything else: the triple is added to {@code noncolliderTriples}.</li>
 * </ul>
 *
 * @param knowledge background knowledge consulted before orienting a collider
 * @param test      independence test passed through to the triple classifier
 * @param depth     depth argument passed through to the triple classifier
 */
private void orientUnshieldedTriples(IKnowledge knowledge, IndependenceTest test, int depth) {
    TetradLogger.getInstance().log("info", "Starting Collider Orientation:");

    // Start each run with empty classification buckets.
    colliderTriples = new HashSet<>();
    noncolliderTriples = new HashSet<>();
    ambiguousTriples = new HashSet<>();

    for (Node center : graph.getNodes()) {
        List<Node> neighbors = graph.getAdjacentNodes(center);

        // A triple centred here needs at least two neighbours.
        if (neighbors.size() >= 2) {
            ChoiceGenerator pairs = new ChoiceGenerator(neighbors.size(), 2);

            for (int[] pair = pairs.next(); pair != null; pair = pairs.next()) {
                Node left = neighbors.get(pair[0]);
                Node right = neighbors.get(pair[1]);

                // Only unshielded triples (outer nodes not adjacent) are classified.
                if (!graph.isAdjacentTo(left, right)) {
                    getAllTriples().add(new Triple(left, center, right));

                    SearchGraphUtils.CpcTripleType tripleType = SearchGraphUtils.getCpcTripleType(
                            left, center, right, test, depth, graph, verbose);
                    Triple triple = new Triple(left, center, right);

                    if (tripleType == SearchGraphUtils.CpcTripleType.AMBIGUOUS) {
                        ambiguousTriples.add(triple);
                        graph.addAmbiguousTriple(triple.getX(), triple.getY(), triple.getZ());
                        definitelyNonadjacencies.add(Edges.undirectedEdge(left, right));
                    } else if (tripleType == SearchGraphUtils.CpcTripleType.COLLIDER) {
                        if (colliderAllowed(left, center, right, knowledge)) {
                            graph.setEndpoint(left, center, Endpoint.ARROW);
                            graph.setEndpoint(right, center, Endpoint.ARROW);
                            TetradLogger.getInstance().log("colliderOrientations",
                                    SearchLogUtils.colliderOrientedMsg(left, center, right));
                        }
                        colliderTriples.add(triple);
                    } else {
                        noncolliderTriples.add(triple);
                    }
                }
            }
        }
    }

    TetradLogger.getInstance().log("info", "Finishing Collider Orientation.");
}
Use of edu.cmu.tetrad.util.ChoiceGenerator in project tetrad by cmu-phil.
Class SampleVcpcFast, method orientUnshieldedTriplesConcurrent.
// Sample Version of Step 3 of VCPC.
// private CpcTripleType getSampleTripleType(Node x, Node y, Node z, IndependenceTest test,
// int depth, Graph graph, boolean verbose) {
//
// if (verbose) {
// System.out.println("Checking " + x + " --- " + y + " --- " + z);
// }
//
// int numSepsetsContainingY = 0;
// int numSepsetsNotContainingY = 0;
//
// this.partialCorrs = getPartialCorrs();
//
//
// List<Node> _nodes = graph.getAdjacentNodes(x);
// _nodes.remove(z);
// int _depth = depth;
// if (_depth == -1) {
// _depth = 1000;
// }
// _depth = Math.min(_depth, _nodes.size());
//
//
// while (true) {
// for (int d = 0; d <= _depth; d++) {
// ChoiceGenerator cg1 = new ChoiceGenerator(_nodes.size(), d);
// int[] choice;
// while ((choice = cg1.next()) != null) {
// List<Node> cond = DataGraphUtils.asList(choice, _nodes);
// TetradMatrix submatrix = DataUtils.subMatrix(covMatrix, indexMap, x, z, cond);
// double r = StatUtils.partialCorrelation(submatrix);
// partialCorrs.put(cond, r);
//
// if (test.isIndependent(x, z, cond)) {
// if (verbose) {
// System.out.println("Indep: " + x + " _||_ " + z + " | " + cond);
// }
// if (cond.contains(y)) {
// numSepsetsContainingY++;
// } else {
// numSepsetsNotContainingY++;
// }
// }
//
// if (numSepsetsContainingY > 0 && numSepsetsNotContainingY > 0) {
// return CpcTripleType.AMBIGUOUS;
// }
// }
// }
//
// _nodes = graph.getAdjacentNodes(z);
// _nodes.remove(x);
// TetradLogger.getInstance().log("adjacencies", "Adjacents for " + x + "--" + y + "--" + z + " = " + _nodes);
//
// _depth = depth;
// if (_depth == -1) {
// _depth = 1000;
// }
// _depth = Math.min(_depth, _nodes.size());
//
// for (int d = 0; d <= _depth; d++) {
// ChoiceGenerator cg1 = new ChoiceGenerator(_nodes.size(), d);
// int[] choice;
// while ((choice = cg1.next()) != null) {
// List<Node> cond = DataGraphUtils.asList(choice, _nodes);
// TetradMatrix submatrix = DataUtils.subMatrix(covMatrix, indexMap, x, z, cond);
// double r = StatUtils.partialCorrelation(submatrix);
// partialCorrs.put(cond, r);
//
// if (test.isIndependent(x, z, cond)) {
//
// if (verbose) {
// System.out.println("Indep: " + x + " _||_ " + z + " | " + cond);
// }
//
// if (cond.contains(y)) {
// numSepsetsContainingY++;
// } else {
// numSepsetsNotContainingY++;
// }
// }
// if (numSepsetsContainingY > 0 && numSepsetsNotContainingY > 0) {
// return CpcTripleType.AMBIGUOUS;
// }
// }
// }
// break;
// }
//
// double L = 0.01;
// // System.out.println("L = " + L);
//
// if (numSepsetsContainingY > 0 && numSepsetsNotContainingY == 0) {
// for (List<Node> sepset1 : partialCorrs.keySet()) {
// if (sepset1.contains(y)) {
// double r1 = partialCorrs.get(sepset1);
// for (List<Node> sepset2 : partialCorrs.keySet()) {
// if (!sepset2.contains(y)) {
// double r2 = partialCorrs.get(sepset2);
// double M = Math.abs(r1 - r2);
//
// if (!(M >= L)) {
// return CpcTripleType.AMBIGUOUS;
// }
// }
// }
// return CpcTripleType.NONCOLLIDER;
// }
// }
// }
//
// if (numSepsetsNotContainingY > 0 && numSepsetsContainingY == 0) {
// for (List<Node> sepset1 : partialCorrs.keySet()) {
// if (!sepset1.contains(y)) {
// double r1 = partialCorrs.get(sepset1);
// for (List<Node> sepset2 : partialCorrs.keySet()) {
// if (sepset2.contains(y)) {
// double r2 = partialCorrs.get(sepset2);
// double M = Math.abs(r1 - r2);
// if (!(M >= L)) {
// return CpcTripleType.AMBIGUOUS;
// }
// }
// }
// return CpcTripleType.COLLIDER;
// }
// }
// }
// return null;
// }
// public enum CpcTripleType {
// COLLIDER, NONCOLLIDER, AMBIGUOUS
// }
/**
 * Concurrent variant of unshielded-triple orientation. Each unshielded triple {@code x - y - z}
 * is classified by {@code SearchGraphUtils.getCpcTripleType} on a fixed pool of
 * {@code NTHREDS} worker threads.
 *
 * <p>Adjacencies are enumerated on a private {@link EdgeListGraph} copy of {@code getGraph()},
 * and that same copy is handed to the classifier. Orientations and ambiguity marks are written
 * to {@code getGraph()}. Those writes only change endpoints, not adjacencies, so the copy stays
 * adjacency-equivalent to the live graph and workers never read a graph that is being written.
 * NOTE(review): this assumes getCpcTripleType only consults adjacencies — confirm.
 *
 * <p>All shared mutable state (the triple sets, {@code getAllTriples()},
 * {@code definitelyNonadjacencies}, {@code getGraph()} and the logger) is mutated only while
 * holding one lock. The triple sets are plain {@link HashSet}s, which are not synchronized.
 * The independence test is still invoked from several threads at once.
 * NOTE(review): {@code test} must tolerate concurrent use.
 *
 * <p>Ambiguous triples also add an undirected x--z edge to {@code definitelyNonadjacencies},
 * matching the sequential {@code orientUnshieldedTriples}.
 *
 * <p>Blocks until every worker has finished. If the calling thread is interrupted while
 * waiting, outstanding workers are cancelled with {@code shutdownNow()}, the thread's interrupt
 * status is restored, and the method returns with partial results.
 *
 * @param knowledge background knowledge consulted before orienting a collider
 * @param test      independence test passed through to the triple classifier
 * @param depth     depth argument passed through to the triple classifier
 */
private void orientUnshieldedTriplesConcurrent(final IKnowledge knowledge, final IndependenceTest test, final int depth) {
    ExecutorService executor = Executors.newFixedThreadPool(NTHREDS);
    TetradLogger.getInstance().log("info", "Starting Collider Orientation:");

    // Read-only (after this line) copy used for enumeration and classification.
    final Graph snapshot = new EdgeListGraph(getGraph());
    // Serializes every write to shared, non-thread-safe state from the workers.
    final Object lock = new Object();

    colliderTriples = new HashSet<>();
    noncolliderTriples = new HashSet<>();
    ambiguousTriples = new HashSet<>();

    try {
        for (final Node y : snapshot.getNodes()) {
            List<Node> adjacentNodes = snapshot.getAdjacentNodes(y);
            if (adjacentNodes.size() < 2) {
                continue;
            }

            ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2);
            int[] combination;

            while ((combination = cg.next()) != null) {
                final Node x = adjacentNodes.get(combination[0]);
                final Node z = adjacentNodes.get(combination[1]);

                if (snapshot.isAdjacentTo(x, z)) {
                    continue;
                }

                executor.execute(new Runnable() {
                    @Override
                    public void run() {
                        // The expensive independence testing runs outside the lock.
                        SearchGraphUtils.CpcTripleType type = SearchGraphUtils.getCpcTripleType(
                                x, y, z, test, depth, snapshot, verbose);

                        synchronized (lock) {
                            getAllTriples().add(new Triple(x, y, z));

                            if (type == SearchGraphUtils.CpcTripleType.COLLIDER) {
                                if (colliderAllowed(x, y, z, knowledge)) {
                                    getGraph().setEndpoint(x, y, Endpoint.ARROW);
                                    getGraph().setEndpoint(z, y, Endpoint.ARROW);
                                    TetradLogger.getInstance().log("colliderOrientations",
                                            SearchLogUtils.colliderOrientedMsg(x, y, z));
                                }
                                colliderTriples.add(new Triple(x, y, z));
                            } else if (type == SearchGraphUtils.CpcTripleType.AMBIGUOUS) {
                                Triple triple = new Triple(x, y, z);
                                ambiguousTriples.add(triple);
                                getGraph().addAmbiguousTriple(triple.getX(), triple.getY(), triple.getZ());
                                definitelyNonadjacencies.add(Edges.undirectedEdge(x, z));
                            } else {
                                noncolliderTriples.add(new Triple(x, y, z));
                            }
                        }
                    }
                });
            }
        }
    } finally {
        // Accept no new tasks (even if submission failed); queued tasks still run.
        executor.shutdown();
    }

    try {
        // Wait until all workers are finished.
        executor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
        System.out.println("Finished all threads");
    } catch (InterruptedException e) {
        // Stop outstanding work and preserve the caller's interrupt request.
        executor.shutdownNow();
        Thread.currentThread().interrupt();
    }

    TetradLogger.getInstance().log("info", "Finishing Collider Orientation.");
}
Use of edu.cmu.tetrad.util.ChoiceGenerator in project tetrad by cmu-phil.
Class MbUtils, method trimEdgesAmongParents.
/**
 * Removes edges among the parents of the target.
 *
 * <p>For every unordered pair of parents of {@code target} that {@code graph.getEdge} reports
 * as connected, the pair's edges are removed with {@code graph.removeEdges(v, w)}. Targets with
 * fewer than two parents leave the graph unchanged.
 *
 * @param graph  graph modified in place
 * @param target node whose parents are pruned of mutual edges
 */
public static void trimEdgesAmongParents(Graph graph, Node target) {
    List<Node> parents = graph.getParents(target);
    if (parents.size() < 2) {
        return;
    }

    ChoiceGenerator cg = new ChoiceGenerator(parents.size(), 2);
    int[] choice;
    while ((choice = cg.next()) != null) {
        Node v = parents.get(choice[0]);
        Node w = parents.get(choice[1]);
        if (graph.getEdge(v, w) != null) {
            graph.removeEdges(v, w);
        }
    }
}
Use of edu.cmu.tetrad.util.ChoiceGenerator in project tetrad by cmu-phil.
Class MbUtils, method trimEdgesAmongParentsOfChildren.
/**
 * Removes edges among the parents of children of the target.
 *
 * <p>The children are the nodes returned by {@code graph.getNodesOutTo(target, Endpoint.ARROW)}.
 * Their parents are pooled, then {@code target} itself and every node adjacent to
 * {@code target} are discarded. Any edge joining two of the remaining nodes is removed with
 * {@code graph.removeEdge(v, w)}.
 *
 * @param graph  graph modified in place
 * @param target node whose children's other (non-adjacent) parents are pruned
 */
public static void trimEdgesAmongParentsOfChildren(Graph graph, Node target) {
    Set<Node> coParents = new HashSet<>();
    for (Node child : graph.getNodesOutTo(target, Endpoint.ARROW)) {
        coParents.addAll(graph.getParents(child));
    }
    coParents.remove(target);
    coParents.removeAll(graph.getAdjacentNodes(target));

    List<Node> candidates = new ArrayList<>(coParents);
    if (candidates.size() < 2) {
        return;
    }

    ChoiceGenerator pairs = new ChoiceGenerator(candidates.size(), 2);
    for (int[] pair = pairs.next(); pair != null; pair = pairs.next()) {
        Node first = candidates.get(pair[0]);
        Node second = candidates.get(pair[1]);
        if (graph.getEdge(first, second) != null) {
            graph.removeEdge(first, second);
        }
    }
}
Use of edu.cmu.tetrad.util.ChoiceGenerator in project tetrad by cmu-phil.
Class MeekRules, method meekR3.
/**
 * Meek's rule R3, applied to the undirected edges incident to {@code a}.
 *
 * <p>In this method's variable names: for each neighbor {@code d} with an undirected edge
 * {@code d--a}, and each pair {@code b, c} of a's other neighbors, the edge is oriented
 * {@code d-->a} when all of the following hold:
 * <ul>
 *   <li>{@code isKite(a, d, b, c, graph)};</li>
 *   <li>{@code isArrowpointAllowed(d, a, knowledge)};</li>
 *   <li>{@code isUnshieldedNoncollider(c, d, b, graph)}.</li>
 * </ul>
 * In the textbook form of R3 that is: d--a, d--b, d--c, b-->a, c-->a with b and c
 * nonadjacent, then orient d-->a. NOTE(review): the exact kite conditions live in
 * {@code isKite}, which is not shown here — confirm it matches.
 *
 * <p>Once {@code d--a} has been oriented, the remaining pairs for that {@code d} are skipped:
 * R3 can only ever orient this edge as {@code d-->a}. Re-applying {@code direct} would only
 * repeat the orientation and its log message.
 *
 * @param a         node whose incident undirected edges are examined
 * @param graph     graph modified in place
 * @param knowledge background knowledge consulted before placing an arrowhead at {@code a}
 */
private void meekR3(Node a, Graph graph, IKnowledge knowledge) {
    List<Node> adjacentNodes = graph.getAdjacentNodes(a);
    if (adjacentNodes.size() < 3) {
        return;
    }

    for (Node d : adjacentNodes) {
        if (!Edges.isUndirectedEdge(graph.getEdge(a, d))) {
            continue;
        }

        List<Node> otherAdjacents = new ArrayList<>(adjacentNodes);
        otherAdjacents.remove(d);

        ChoiceGenerator cg = new ChoiceGenerator(otherAdjacents.size(), 2);
        int[] choice;
        while ((choice = cg.next()) != null) {
            Node b = otherAdjacents.get(choice[0]);
            Node c = otherAdjacents.get(choice[1]);

            // Same evaluation order as before: kite check, then knowledge, then noncollider.
            if (isKite(a, d, b, c, graph)
                    && isArrowpointAllowed(d, a, knowledge)
                    && isUnshieldedNoncollider(c, d, b, graph)) {
                direct(d, a, graph);
                log(SearchLogUtils.edgeOrientedMsg("Meek R3", graph.getEdge(d, a)));
                // d--a is now oriented; further pairs could only re-orient it identically.
                break;
            }
        }
    }
}
Aggregations