Usage of edu.cmu.tetrad.util.CombinationGenerator in the tetrad project (cmu-phil).
Class SampleVcpcFast, method search.
/**
 * Runs the sample-based VCPC search.
 *
 * <p>The search proceeds in stages:
 * <ol>
 * <li>An adjacency search with {@code Vcfas} (a modified FAS that also reports "apparent"
 * non-adjacencies), followed, if orientation is enabled, by background-knowledge orientation,
 * orientation of unshielded triples and Meek's rules.</li>
 * <li>Every collider / non-collider assignment to the remaining ambiguous triples is enumerated.
 * An assignment is discarded if a new collider conflicts with an existing orientation or the
 * resulting pattern contains a directed cycle; bidirected edges are downgraded to undirected.</li>
 * <li>An apparent non-adjacency is promoted to a definite non-adjacency unless some surviving
 * pattern contradicts it under a boundary/future Markov check.</li>
 * <li>For each node, edge coefficients are estimated by sample regression (or marked unknown or
 * zero) and paired with the coefficients of the configured SEM IM. Both maps are printed.</li>
 * </ol>
 *
 * @return the graph from the adjacency/orientation stage, with its nodes replaced by the variables
 *     of {@code dataSet}.
 * @throws NullPointerException if no independence test has been set.
 */
public Graph search() {
    this.logger.log("info", "Starting VCCPC algorithm");

    // Fail fast: the test is logged and handed to the adjacency search immediately below.
    if (getIndependenceTest() == null) {
        throw new NullPointerException("Independence test must not be null.");
    }

    this.logger.log("info", "Independence test = " + getIndependenceTest() + ".");
    this.allTriples = new HashSet<>();
    this.ambiguousTriples = new HashSet<>();
    this.colliderTriples = new HashSet<>();
    this.noncolliderTriples = new HashSet<>();
    Vcfas fas = new Vcfas(getIndependenceTest());
    definitelyNonadjacencies = new HashSet<>();
    markovInAllPatterns = new HashSet<>();
    long startTime = System.currentTimeMillis();
    List<Node> allNodes = getIndependenceTest().getVariables();

    // Adjacency stage. The sepset map is deliberately not used by this search.
    fas.setKnowledge(getKnowledge());
    fas.setDepth(getDepth());
    fas.setVerbose(verbose);
    graph = fas.search();
    apparentlyNonadjacencies = fas.getApparentlyNonadjacencies();

    if (isDoOrientation()) {
        if (verbose) {
            System.out.println("CPC orientation...");
        }
        SearchGraphUtils.pcOrientbk(knowledge, graph, allNodes);
        orientUnshieldedTriples(knowledge, getIndependenceTest(), getDepth());
        MeekRules meekRules = new MeekRules();
        meekRules.setAggressivelyPreventCycles(this.aggressivelyPreventCycles);
        meekRules.setKnowledge(knowledge);
        meekRules.orientImplied(graph);
    }

    // Enumerate every assignment of collider (0) / non-collider (1) to the ambiguous triples.
    List<Triple> ambiguous = new ArrayList<>(graph.getAmbiguousTriples());
    int[] dims = new int[ambiguous.size()];
    Arrays.fill(dims, 2);

    List<Graph> candidates = new ArrayList<>();
    // Keyed by identity: graph equality may be structural, and candidates are mutated below.
    Map<Graph, List<Triple>> newColliders = new IdentityHashMap<>();
    Map<Graph, List<Triple>> newNonColliders = new IdentityHashMap<>();

    CombinationGenerator generator = new CombinationGenerator(dims);
    int[] combination;
    while ((combination = generator.next()) != null) {
        Graph candidate = new EdgeListGraph(graph);
        List<Triple> colliders = new ArrayList<>();
        List<Triple> nonColliders = new ArrayList<>();
        newColliders.put(candidate, colliders);
        newNonColliders.put(candidate, nonColliders);

        for (int k = 0; k < combination.length; k++) {
            Triple triple = ambiguous.get(k);
            candidate.removeAmbiguousTriple(triple.getX(), triple.getY(), triple.getZ());
            if (combination[k] == 0) {
                colliders.add(triple);
                candidate.setEndpoint(triple.getX(), triple.getY(), Endpoint.ARROW);
                candidate.setEndpoint(triple.getZ(), triple.getY(), Endpoint.ARROW);
            } else if (combination[k] == 1) {
                nonColliders.add(triple);
            }
        }
        candidates.add(candidate);
    }

    // Keep only consistent disambiguations. Survivors are collected into a fresh list instead of
    // calling List.remove(Object): that matches by equals() and could drop a structurally equal
    // sibling rather than the rejected instance (and costs a graph comparison per element).
    List<Graph> patterns = new ArrayList<>();
    GRAPH:
    for (Graph pattern : candidates) {
        List<Triple> colliders = newColliders.get(pattern);
        List<Triple> nonColliders = newNonColliders.get(pattern);

        // Reject if a chosen collider's x-y edge points toward x or its y-z edge toward z.
        for (Triple triple : colliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            if (pattern.getEdge(x, y).pointsTowards(x) || pattern.getEdge(y, z).pointsTowards(z)) {
                continue GRAPH;
            }
        }

        for (Triple triple : colliders) {
            pattern.setEndpoint(triple.getX(), triple.getY(), Endpoint.ARROW);
            pattern.setEndpoint(triple.getZ(), triple.getY(), Endpoint.ARROW);
        }

        // For a non-collider, an arrow into y on one side is propagated out of y on the other side.
        for (Triple triple : nonColliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            if (pattern.getEdge(x, y).pointsTowards(y)) {
                pattern.removeEdge(y, z);
                pattern.addDirectedEdge(y, z);
            }
            if (pattern.getEdge(y, z).pointsTowards(y)) {
                pattern.removeEdge(x, y);
                pattern.addDirectedEdge(y, x);
            }
        }

        // In this variant bidirected edges are downgraded to undirected rather than rejected.
        // Iterate over a snapshot because the loop body edits the edge set.
        for (Edge edge : new ArrayList<>(pattern.getEdges())) {
            if (Edges.isBidirectedEdge(edge)) {
                Node x = edge.getNode1();
                Node y = edge.getNode2();
                pattern.removeEdge(x, y);
                pattern.addUndirectedEdge(x, y);
            }
        }

        MeekRules rules = new MeekRules();
        rules.orientImplied(pattern);
        if (!pattern.existsDirectedCycle()) {
            patterns.add(pattern);
        }
    }

    // Promote an apparent non-adjacency x..y to definite unless some pattern refutes it. A pattern
    // refutes it when neither endpoint is in the other's boundary and, for an endpoint whose future
    // does not contain the other, the test reports dependence given that endpoint's boundary.
    MARKOV:
    for (Edge edge : apparentlyNonadjacencies.keySet()) {
        Node x = edge.getNode1();
        Node y = edge.getNode2();
        for (Graph pattern : patterns) {
            if (y == x) {
                continue;
            }
            List<Node> boundaryX = new ArrayList<>(boundary(x, pattern));
            List<Node> boundaryY = new ArrayList<>(boundary(y, pattern));
            List<Node> futureX = new ArrayList<>(future(x, pattern));
            List<Node> futureY = new ArrayList<>(future(y, pattern));
            if (boundaryX.contains(y) || boundaryY.contains(x)) {
                continue;
            }
            IndependenceTest test = independenceTest;
            if (!futureX.contains(y) && !test.isIndependent(x, y, boundaryX)) {
                continue MARKOV;
            }
            if (!futureY.contains(x) && !test.isIndependent(y, x, boundaryY)) {
                continue MARKOV;
            }
        }
        definitelyNonadjacencies.add(edge);
    }
    apparentlyNonadjacencies.keySet().removeAll(definitelyNonadjacencies);

    // NOTE(review): re-sets the current SEM IM on itself; kept in case the setter has side effects.
    setSemIm(semIm);

    // Edge estimation (Alg I): compare sample regression coefficients with the SEM IM's.
    Regression sampleRegression = new RegressionDataset(dataSet);
    System.out.println(sampleRegression.getGraph());
    graph = GraphUtils.replaceNodes(graph, dataSet.getVariables());
    Map<Edge, double[]> sampleRegress = new HashMap<>();
    Map<Edge, Double> edgeCoefs = new HashMap<>();

    // NOTE(review): node matching below uses ==; this assumes replaceNodes kept the same Node
    // instances that the non-adjacency edges refer to. Confirm against dataSet/test construction.
    ESTIMATION:
    for (Node z : graph.getNodes()) {
        Set<Edge> adj = getAdj(z, graph);

        // z touches a still-apparent non-adjacency: mark all of z's edges unknown (null) and skip z.
        for (Edge edge : apparentlyNonadjacencies.keySet()) {
            if (z == edge.getNode1() || z == edge.getNode2()) {
                for (Edge adjacency : adj) {
                    sampleRegress.put(adjacency, null);
                    putImEdgeCoef(adjacency, edgeCoefs);
                }
                continue ESTIMATION;
            }
        }

        // Definite non-adjacencies at z get a zero estimate.
        for (Edge nonadj : definitelyNonadjacencies) {
            if (nonadj.getNode1() == z || nonadj.getNode2() == z) {
                sampleRegress.put(nonadj, new double[] {0, 0});
                putImEdgeCoef(nonadj, edgeCoefs);
            }
        }

        // Any undirected edge at z marks all of z's edges unknown. Edges into z may still be
        // overwritten with regression estimates in the next loop.
        boolean hasUndirectedAdjacency = false;
        Set<Edge> parentsOfZ = new HashSet<>();
        for (Edge adjacency : adj) {
            if (!adjacency.isDirected()) {
                hasUndirectedAdjacency = true;
            }
            if (adjacency.pointsTowards(z)) {
                parentsOfZ.add(adjacency);
            }
        }
        if (hasUndirectedAdjacency) {
            for (Edge adjacency : adj) {
                sampleRegress.put(adjacency, null);
                putImEdgeCoef(adjacency, edgeCoefs);
            }
        }

        // NOTE(review): each edge into z is regressed on its own tail only, not on all of z's
        // parents jointly; confirm this is the intended estimator.
        for (Edge edge : parentsOfZ) {
            if (edge.pointsTowards(edge.getNode2())) {
                RegressionResult result = sampleRegression.regress(edge.getNode2(), edge.getNode1());
                System.out.println(result);
                sampleRegress.put(edge, result.getCoef());
                putImEdgeCoef(edge, edgeCoefs);
            }
        }
    }

    System.out.println("All IM: " + semIm + "Finish");
    System.out.println("Just IM coefs: " + semIm.getEdgeCoef());
    System.out.println("IM Coef Map: " + edgeCoefs);
    System.out.println("Regress Coef Map: " + sampleRegress);
    for (Map.Entry<Edge, double[]> entry : sampleRegress.entrySet()) {
        System.out.println(" Sample Regression: " + entry.getKey() + Arrays.toString(entry.getValue()));
    }
    for (Edge edge : graph.getEdges()) {
        System.out.println("Sample edge: " + Arrays.toString(sampleRegress.get(edge)));
    }

    System.out.println("Sample VCPC:");
    System.out.println("# of patterns: " + patterns.size());
    long endTime = System.currentTimeMillis();
    this.elapsedTime = endTime - startTime;
    // Floating-point division so the seconds figure is not truncated to a whole number.
    System.out.println("Search Time (seconds):" + elapsedTime / 1000.0 + " s");
    System.out.println("Search Time (milli):" + elapsedTime + " ms");
    System.out.println("# of Apparent Nonadj: " + apparentlyNonadjacencies.size());
    System.out.println("# of Definite Nonadj: " + definitelyNonadjacencies.size());

    TetradLogger.getInstance().log("apparentlyNonadjacencies", "\n Apparent Non-adjacencies" + apparentlyNonadjacencies);
    TetradLogger.getInstance().log("definitelyNonadjacencies", "\n Definite Non-adjacencies" + definitelyNonadjacencies);
    TetradLogger.getInstance().log("patterns", "Disambiguated Patterns: " + patterns);
    TetradLogger.getInstance().log("graph", "\nReturning this graph: " + graph);
    TetradLogger.getInstance().log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s");
    TetradLogger.getInstance().log("info", "Finishing CPC algorithm.");
    logTriples();
    TetradLogger.getInstance().flush();
    return graph;
}

/**
 * Stores in {@code edgeCoefs} the coefficient that {@code semIm} assigns between the endpoints of
 * {@code edge}, or 0.0 when the SEM IM has no coefficient for that pair.
 *
 * @param edge      the edge whose endpoints are looked up in the SEM IM.
 * @param edgeCoefs the map that receives the coefficient, keyed by {@code edge}.
 */
private void putImEdgeCoef(Edge edge, Map<Edge, Double> edgeCoefs) {
    Node a = edge.getNode1();
    Node b = edge.getNode2();
    if (semIm.existsEdgeCoef(a, b)) {
        Double c = semIm.getEdgeCoef(a, b);
        edgeCoefs.put(edge, c);
    } else {
        edgeCoefs.put(edge, 0.0);
    }
}
Usage of edu.cmu.tetrad.util.CombinationGenerator in the tetrad project (cmu-phil).
Class TimeoutComparison, method compareFromSimulations.
/**
 * Compares algorithms on simulated data and writes a report to {@code resultsPath/outputFileName}.
 *
 * <p>Each simulation is expanded into its simulation wrappers and data is generated for each one.
 * Each algorithm is expanded into one wrapper per combination of values of its multi-valued
 * parameters. Every algorithm wrapper is paired with every simulation wrapper, and the pairs are
 * passed to {@code calcStats}. The report lists the statistics, simulations and algorithms, followed
 * by tables of averages, standard deviations and worst cases.
 *
 * @param resultsPath    directory in which the output file is written (created if missing).
 * @param simulations    the simulations used to generate graphs and data for the comparison.
 * @param outputFileName name of the report file inside {@code resultsPath}.
 * @param algorithms     the algorithms to be compared.
 * @param statistics     the statistics on which to compare the algorithms, and their utility weights.
 * @param parameters     parameter values; algorithm parameters with several values are expanded
 *                       combinatorially.
 * @param timeout        timeout passed to {@code calcStats}, in units of {@code unit}.
 * @param unit           time unit of {@code timeout}.
 * @throws RuntimeException if the output file cannot be created.
 */
public void compareFromSimulations(String resultsPath, Simulations simulations, String outputFileName,
                                   Algorithms algorithms, Statistics statistics, Parameters parameters,
                                   long timeout, TimeUnit unit) {
    this.resultsPath = resultsPath;

    // Create output file.
    try {
        File dir = new File(resultsPath);
        dir.mkdirs();
        File file = new File(dir, outputFileName);
        this.out = new PrintStream(new FileOutputStream(file));
    } catch (Exception e) {
        throw new RuntimeException(e);
    }

    // Close the report stream even if data generation, an algorithm run or a statistic throws.
    try {
        out.println(new Date());

        // Set up simulations--create data and graphs, read in parameters. The parameters
        // are set in the parameters object.
        List<SimulationWrapper> simulationWrappers = new ArrayList<>();
        int numRuns = parameters.getInt("numRuns");
        for (Simulation simulation : simulations.getSimulations()) {
            for (SimulationWrapper wrapper : getSimulationWrappers(simulation, parameters)) {
                wrapper.createData(wrapper.getSimulationSpecificParameters());
                simulationWrappers.add(wrapper);
            }
        }

        // Set up the algorithms: one wrapper per combination of values of the algorithm's
        // multi-valued parameters, or a single wrapper if none of them vary.
        List<AlgorithmWrapper> algorithmWrappers = new ArrayList<>();
        for (Algorithm algorithm : algorithms.getAlgorithms()) {
            List<Integer> numValues = new ArrayList<>();
            List<String> varyingParameters = new ArrayList<>();
            for (String name : algorithm.getParameters()) {
                int n = parameters.getNumValues(name);
                if (n > 1) {
                    numValues.add(n);
                    varyingParameters.add(name);
                }
            }

            if (varyingParameters.isEmpty()) {
                algorithmWrappers.add(new AlgorithmWrapper(algorithm, parameters));
                continue;
            }

            int[] dims = new int[numValues.size()];
            for (int i = 0; i < dims.length; i++) {
                dims[i] = numValues.get(i);
            }

            CombinationGenerator gen = new CombinationGenerator(dims);
            int[] choice;
            while ((choice = gen.next()) != null) {
                AlgorithmWrapper wrapper = new AlgorithmWrapper(algorithm, parameters);
                for (int h = 0; h < dims.length; h++) {
                    String parameter = varyingParameters.get(h);
                    Object[] values = parameters.getValues(parameter);
                    wrapper.setValue(parameter, values[choice[h]]);
                }
                algorithmWrappers.add(wrapper);
            }
        }

        // Create the algorithm-simulation wrappers for every combination of algorithm and
        // simulation.
        List<AlgorithmSimulationWrapper> algorithmSimulationWrappers = new ArrayList<>();
        for (SimulationWrapper simulationWrapper : simulationWrappers) {
            for (AlgorithmWrapper algorithmWrapper : algorithmWrappers) {
                DataType algDataType = algorithmWrapper.getDataType();
                DataType simDataType = simulationWrapper.getDataType();
                if (!(algDataType == DataType.Mixed || (algDataType == simDataType))) {
                    System.out.println("Type mismatch: " + algorithmWrapper.getDescription() + " / " + simulationWrapper.getDescription());
                }

                if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
                    ExternalAlgorithm external = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
                    // Look up the wrapper itself. external.getSimulation() is not a
                    // SimulationWrapper, so searching for it here always yielded -1.
                    // NOTE(review): the algorithm instance is shared across simulation wrappers,
                    // so this value is overwritten on each pass; confirm downstream use.
                    external.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
                }

                algorithmSimulationWrappers.add(new AlgorithmSimulationWrapper(algorithmWrapper, simulationWrapper));
            }
        }

        // Run all of the algorithms and compile statistics.
        double[][][][] allStats = calcStats(algorithmSimulationWrappers, algorithmWrappers, simulationWrappers,
                statistics, numRuns, timeout, unit);

        // Print out the preliminary information for statistics types, etc.
        if (allStats != null) {
            out.println();
            out.println("Statistics:");
            out.println();

            for (Statistic stat : statistics.getStatistics()) {
                out.println(stat.getAbbreviation() + " = " + stat.getDescription());
            }
        }

        out.println();

        if (allStats != null) {
            int numTables = allStats.length;
            int numStats = allStats[0][0].length - 1;
            int numRows = algorithmSimulationWrappers.size();

            double[][][] statTables = calcStatTables(allStats, Mode.Average, numTables,
                    algorithmSimulationWrappers, numStats, statistics);
            double[] utilities = calcUtilities(statistics, algorithmSimulationWrappers, statTables[0]);

            // Add utilities to table as the last column.
            addUtilityColumn(statTables, utilities, numTables, numRows, numStats);

            int[] newOrder;
            if (isSortByUtility()) {
                newOrder = sort(algorithmSimulationWrappers, utilities);
            } else {
                newOrder = new int[numRows];
                for (int q = 0; q < numRows; q++) {
                    newOrder[q] = q;
                }
            }

            out.println("Simulations:");
            out.println();

            int i = 0;
            for (SimulationWrapper simulation : simulationWrappers) {
                out.print("Simulation " + (++i) + ": ");
                out.println(simulation.getDescription());
                out.println();
                printParameters(simulation.getParameters(), simulation.getSimulationSpecificParameters(), out);
                out.println();
            }

            out.println("Algorithms:");
            out.println();

            for (int t = 0; t < numRows; t++) {
                AlgorithmSimulationWrapper wrapper = algorithmSimulationWrappers.get(t);
                if (wrapper.getSimulationWrapper() == simulationWrappers.get(0)) {
                    out.println((t + 1) + ". " + wrapper.getAlgorithmWrapper().getDescription());
                }
            }

            if (isSortByUtility()) {
                out.println();
                out.println("Sorting by utility, high to low.");
            }

            if (isShowUtilities()) {
                out.println();
                out.println("Weighting of statistics:");
                out.println();
                out.println("U = ");

                for (Statistic stat : statistics.getStatistics()) {
                    String statName = stat.getAbbreviation();
                    double weight = statistics.getWeight(stat);
                    if (weight != 0.0) {
                        out.println("    " + weight + " * f(" + statName + ")");
                    }
                }

                out.println();
                out.println("...normed to range between 0 and 1.");

                out.println();
                out.println("Note that f for each statistic is a function that maps the statistic to the ");
                out.println("interval [0, 1], with higher being better.");
            }

            out.println();
            out.println("Graphs are being compared to the " + comparisonGraph.toString().replace("_", " ") + ".");
            out.println();

            // Print all of the tables. The Average table already carries the utility column.
            printStats(statTables, statistics, Mode.Average, newOrder, algorithmSimulationWrappers,
                    algorithmWrappers, simulationWrappers, utilities, parameters);

            // NOTE(review): the standard-deviation table does not get the utility column; this
            // mirrors the previous behavior. Confirm printStats does not read that column.
            statTables = calcStatTables(allStats, Mode.StandardDeviation, numTables,
                    algorithmSimulationWrappers, numStats, statistics);
            printStats(statTables, statistics, Mode.StandardDeviation, newOrder, algorithmSimulationWrappers,
                    algorithmWrappers, simulationWrappers, utilities, parameters);

            statTables = calcStatTables(allStats, Mode.WorstCase, numTables,
                    algorithmSimulationWrappers, numStats, statistics);
            addUtilityColumn(statTables, utilities, numTables, numRows, numStats);
            printStats(statTables, statistics, Mode.WorstCase, newOrder, algorithmSimulationWrappers,
                    algorithmWrappers, simulationWrappers, utilities, parameters);
        }
    } finally {
        out.close();
    }
}

/**
 * Writes {@code utilities[t]} into column {@code numStats} of row {@code t} for each of the first
 * {@code numTables} tables and the first {@code numRows} rows.
 *
 * @param statTables the tables to update, indexed [table][row][column].
 * @param utilities  one utility value per row.
 * @param numTables  number of tables to update.
 * @param numRows    number of rows to update in each table.
 * @param numStats   index of the utility column.
 */
private static void addUtilityColumn(double[][][] statTables, double[] utilities,
                                     int numTables, int numRows, int numStats) {
    for (int u = 0; u < numTables; u++) {
        for (int t = 0; t < numRows; t++) {
            statTables[u][t][numStats] = utilities[t];
        }
    }
}
Usage of edu.cmu.tetrad.util.CombinationGenerator in the tetrad project (cmu-phil).
Class VcpcAlt, method search.
/**
 * Runs the alternative VCPC search.
 *
 * <p>The search proceeds in stages:
 * <ol>
 * <li>An adjacency search with {@code Vcfas} (a modified FAS that also reports "apparent"
 * non-adjacencies), followed, if orientation is enabled, by background-knowledge orientation,
 * orientation of unshielded triples and Meek's rules.</li>
 * <li>Every collider / non-collider assignment to the remaining ambiguous triples is enumerated.
 * An assignment is discarded if a new collider conflicts with an existing orientation, or if the
 * resulting pattern contains a bidirected edge or a directed cycle.</li>
 * <li>An apparent non-adjacency is promoted to a definite non-adjacency as soon as one surviving
 * pattern passes the boundary/future Markov check for it.</li>
 * </ol>
 * The apparent and definite non-adjacencies and the surviving patterns are printed and logged.
 *
 * @return the graph from the adjacency/orientation stage.
 * @throws NullPointerException if no independence test has been set.
 */
public Graph search() {
    this.logger.log("info", "Starting VCCPC algorithm");

    // Fail fast: the test is logged and handed to the adjacency search immediately below.
    if (getIndependenceTest() == null) {
        throw new NullPointerException("Independence test must not be null.");
    }

    this.logger.log("info", "Independence test = " + getIndependenceTest() + ".");
    this.allTriples = new HashSet<>();
    this.ambiguousTriples = new HashSet<>();
    this.colliderTriples = new HashSet<>();
    this.noncolliderTriples = new HashSet<>();
    Vcfas fas = new Vcfas(getIndependenceTest());
    definitelyNonadjacencies = new HashSet<>();
    markovInAllPatterns = new HashSet<>();
    long startTime = System.currentTimeMillis();
    List<Node> allNodes = getIndependenceTest().getVariables();

    // Adjacency stage. The sepset map is deliberately not used by this search.
    fas.setKnowledge(getKnowledge());
    fas.setDepth(getDepth());
    fas.setVerbose(verbose);
    graph = fas.search();
    apparentlyNonadjacencies = fas.getApparentlyNonadjacencies();

    if (isDoOrientation()) {
        if (verbose) {
            System.out.println("CPC orientation...");
        }
        SearchGraphUtils.pcOrientbk(knowledge, graph, allNodes);
        orientUnshieldedTriples(knowledge, getIndependenceTest(), getDepth());
        MeekRules meekRules = new MeekRules();
        meekRules.setAggressivelyPreventCycles(this.aggressivelyPreventCycles);
        meekRules.setKnowledge(knowledge);
        meekRules.orientImplied(graph);
    }

    // Enumerate every assignment of collider (0) / non-collider (1) to the ambiguous triples.
    List<Triple> ambiguous = new ArrayList<>(graph.getAmbiguousTriples());
    int[] dims = new int[ambiguous.size()];
    Arrays.fill(dims, 2);

    List<Graph> candidates = new ArrayList<>();
    // Keyed by identity: graph equality may be structural, and candidates are mutated below.
    Map<Graph, List<Triple>> newColliders = new IdentityHashMap<>();
    Map<Graph, List<Triple>> newNonColliders = new IdentityHashMap<>();

    CombinationGenerator generator = new CombinationGenerator(dims);
    int[] combination;
    while ((combination = generator.next()) != null) {
        Graph candidate = new EdgeListGraph(graph);
        List<Triple> colliders = new ArrayList<>();
        List<Triple> nonColliders = new ArrayList<>();
        newColliders.put(candidate, colliders);
        newNonColliders.put(candidate, nonColliders);

        for (int k = 0; k < combination.length; k++) {
            Triple triple = ambiguous.get(k);
            candidate.removeAmbiguousTriple(triple.getX(), triple.getY(), triple.getZ());
            if (combination[k] == 0) {
                colliders.add(triple);
                candidate.setEndpoint(triple.getX(), triple.getY(), Endpoint.ARROW);
                candidate.setEndpoint(triple.getZ(), triple.getY(), Endpoint.ARROW);
            } else if (combination[k] == 1) {
                nonColliders.add(triple);
            }
        }
        candidates.add(candidate);
    }

    // Keep only consistent disambiguations. Survivors are collected into a fresh list instead of
    // calling List.remove(Object): that matches by equals() and could drop a structurally equal
    // sibling rather than the rejected instance (and costs a graph comparison per element).
    List<Graph> patterns = new ArrayList<>();
    GRAPH:
    for (Graph pattern : candidates) {
        List<Triple> colliders = newColliders.get(pattern);
        List<Triple> nonColliders = newNonColliders.get(pattern);

        // Reject if a chosen collider's x-y edge points toward x or its y-z edge toward z.
        for (Triple triple : colliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            if (pattern.getEdge(x, y).pointsTowards(x) || pattern.getEdge(y, z).pointsTowards(z)) {
                continue GRAPH;
            }
        }

        for (Triple triple : colliders) {
            pattern.setEndpoint(triple.getX(), triple.getY(), Endpoint.ARROW);
            pattern.setEndpoint(triple.getZ(), triple.getY(), Endpoint.ARROW);
        }

        // For a non-collider, an arrow into y on one side is propagated out of y on the other side.
        for (Triple triple : nonColliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            if (pattern.getEdge(x, y).pointsTowards(y)) {
                pattern.removeEdge(y, z);
                pattern.addDirectedEdge(y, z);
            }
            if (pattern.getEdge(y, z).pointsTowards(y)) {
                pattern.removeEdge(x, y);
                pattern.addDirectedEdge(y, x);
            }
        }

        // In this variant any bidirected edge rejects the pattern.
        for (Edge edge : pattern.getEdges()) {
            if (Edges.isBidirectedEdge(edge)) {
                continue GRAPH;
            }
        }

        MeekRules rules = new MeekRules();
        rules.orientImplied(pattern);
        if (!pattern.existsDirectedCycle()) {
            patterns.add(pattern);
        }
    }

    // Promote an apparent non-adjacency x..y to definite as soon as one pattern passes the check.
    // A pattern passes if neither endpoint is in the other's boundary, neither is in the other's
    // future, and each is independent of the other given its own boundary.
    // NOTE(review): this is an "exists a pattern" rule, whereas SampleVcpcFast requires that no
    // pattern refutes the non-adjacency; confirm which semantics this variant should use.
    MARKOV:
    for (Edge edge : apparentlyNonadjacencies.keySet()) {
        Node x = edge.getNode1();
        Node y = edge.getNode2();
        for (Graph pattern : patterns) {
            if (y == x) {
                continue;
            }
            List<Node> boundaryX = new ArrayList<>(boundary(x, pattern));
            List<Node> boundaryY = new ArrayList<>(boundary(y, pattern));
            List<Node> futureX = new ArrayList<>(future(x, pattern));
            List<Node> futureY = new ArrayList<>(future(y, pattern));
            if (boundaryX.contains(y) || boundaryY.contains(x)) {
                continue;
            }
            IndependenceTest test = independenceTest;
            if (!futureX.contains(y) && test.isIndependent(x, y, boundaryX)
                    && !futureY.contains(x) && test.isIndependent(y, x, boundaryY)) {
                definitelyNonadjacencies.add(edge);
                continue MARKOV;
            }
        }
    }
    apparentlyNonadjacencies.keySet().removeAll(definitelyNonadjacencies);

    System.out.println("Definitely Nonadjacencies:");
    for (Edge edge : definitelyNonadjacencies) {
        System.out.println(edge);
    }
    // NOTE(review): markovInAllPatterns is reset above but never populated by this method.
    System.out.println("markov in all patterns:" + markovInAllPatterns);
    System.out.println("patterns:" + patterns);
    System.out.println("Apparently Nonadjacencies:");
    for (Edge edge : apparentlyNonadjacencies.keySet()) {
        System.out.println(edge);
    }
    System.out.println("Definitely Nonadjacencies:");
    for (Edge edge : definitelyNonadjacencies) {
        System.out.println(edge);
    }

    TetradLogger.getInstance().log("apparentlyNonadjacencies", "\n Apparent Non-adjacencies" + apparentlyNonadjacencies);
    TetradLogger.getInstance().log("definitelyNonadjacencies", "\n Definite Non-adjacencies" + definitelyNonadjacencies);
    TetradLogger.getInstance().log("patterns", "Disambiguated Patterns: " + patterns);
    TetradLogger.getInstance().log("graph", "\nReturning this graph: " + graph);
    long endTime = System.currentTimeMillis();
    this.elapsedTime = endTime - startTime;
    TetradLogger.getInstance().log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s");
    TetradLogger.getInstance().log("info", "Finishing CPC algorithm.");
    logTriples();
    TetradLogger.getInstance().flush();
    return graph;
}
Usage of edu.cmu.tetrad.util.CombinationGenerator in the tetrad project (cmu-phil).
Class SampleVcpc, method search.
/**
* Runs PC starting with a fully connected graph over all of the variables in the domain of the independence test.
* See PC for caveats. The number of possible cycles and bidirected edges is far less with CPC than with PC.
*/
// public final Graph search() {
// return search(independenceTest.getVariable());
// }
// // public Graph search(List<Node> nodes) {
// //
// //// return search(new FasICov2(getIndependenceTest()), nodes);
// //// return search(new Fas(getIndependenceTest()), nodes);
// // return search(new Fas(getIndependenceTest()), nodes);
// }
/**
 * Runs the VCCPC search with sample-based edge estimation.
 *
 * <p>Steps:
 * <ol>
 *   <li>Adjacency search with {@link Vcfas} (a modified FAS that also reports apparent
 *       non-adjacencies).</li>
 *   <li>Optional CPC-style orientation (knowledge, unshielded triples, Meek rules).</li>
 *   <li>Enumeration of one candidate pattern per collider/non-collider assignment to the
 *       ambiguous triples. Candidates that contradict a chosen collider, or that end up
 *       with a directed cycle, are discarded.</li>
 *   <li>A check over the surviving patterns that promotes apparent non-adjacencies to
 *       definite ones.</li>
 *   <li>Per-node regression estimates of edge coefficients from {@code dataSet}. These
 *       are recorded alongside the corresponding {@code semIm} coefficients and printed.</li>
 * </ol>
 *
 * <p>Side effects:
 * <ul>
 *   <li>Resets the triple sets, {@code definitelyNonadjacencies} and
 *       {@code markovInAllPatterns}.</li>
 *   <li>Reassigns {@code graph} and {@code apparentlyNonadjacencies}, and removes every
 *       definite non-adjacency from {@code apparentlyNonadjacencies}.</li>
 *   <li>Sets {@code elapsedTime}.</li>
 *   <li>Prints diagnostics to {@code System.out} and logs through {@link TetradLogger}.</li>
 * </ul>
 *
 * @return the adjacency/orientation graph, with its nodes replaced by the variables of
 *     {@code dataSet}
 * @throws NullPointerException if the independence test is null
 */
public Graph search() {
    this.logger.log("info", "Starting VCCPC algorithm");
    this.logger.log("info", "Independence test = " + getIndependenceTest() + ".");
    this.allTriples = new HashSet<>();
    this.ambiguousTriples = new HashSet<>();
    this.colliderTriples = new HashSet<>();
    this.noncolliderTriples = new HashSet<>();

    // Fail fast, before the test is handed to Vcfas below.
    if (getIndependenceTest() == null) {
        throw new NullPointerException("Independence test must not be null.");
    }

    Vcfas fas = new Vcfas(getIndependenceTest());
    definitelyNonadjacencies = new HashSet<>();
    markovInAllPatterns = new HashSet<>();
    long startTime = System.currentTimeMillis();
    List<Node> allNodes = getIndependenceTest().getVariables();

    fas.setKnowledge(getKnowledge());
    fas.setDepth(getDepth());
    fas.setVerbose(verbose);
    // Note that we are ignoring the sepset map returned by this method
    // on purpose; it is not used in this search.
    graph = fas.search();
    apparentlyNonadjacencies = fas.getApparentlyNonadjacencies();

    if (isDoOrientation()) {
        if (verbose) {
            System.out.println("CPC orientation...");
        }
        SearchGraphUtils.pcOrientbk(knowledge, graph, allNodes);
        orientUnshieldedTriples(knowledge, getIndependenceTest(), getDepth());
        MeekRules meekRules = new MeekRules();
        meekRules.setAggressivelyPreventCycles(this.aggressivelyPreventCycles);
        meekRules.setKnowledge(knowledge);
        meekRules.orientImplied(graph);
    }

    // Note: this local shadows the ambiguousTriples field. Position k in this list is
    // the triple that combination[k] refers to below.
    List<Triple> ambiguousTriples = new ArrayList<>(graph.getAmbiguousTriples());
    int[] dims = new int[ambiguousTriples.size()];
    // Two choices per ambiguous triple: 0 = collider, 1 = non-collider.
    java.util.Arrays.fill(dims, 2);

    List<Graph> patterns = new ArrayList<>();
    // Keyed by identity: each candidate is mutated after insertion, and distinct
    // candidates must stay distinct keys.
    Map<Graph, List<Triple>> newColliders = new IdentityHashMap<>();
    Map<Graph, List<Triple>> newNonColliders = new IdentityHashMap<>();

    // Build one candidate pattern per combination of ambiguous triples disambiguated
    // into colliders and non-colliders.
    CombinationGenerator generator = new CombinationGenerator(dims);
    int[] combination;
    while ((combination = generator.next()) != null) {
        Graph candidate = new EdgeListGraph(graph);
        List<Triple> colliders = new ArrayList<>();
        List<Triple> nonColliders = new ArrayList<>();
        newColliders.put(candidate, colliders);
        newNonColliders.put(candidate, nonColliders);
        for (int k = 0; k < combination.length; k++) {
            Triple triple = ambiguousTriples.get(k);
            candidate.removeAmbiguousTriple(triple.getX(), triple.getY(), triple.getZ());
            if (combination[k] == 0) {
                colliders.add(triple);
                candidate.setEndpoint(triple.getX(), triple.getY(), Endpoint.ARROW);
                candidate.setEndpoint(triple.getZ(), triple.getY(), Endpoint.ARROW);
            }
            if (combination[k] == 1) {
                nonColliders.add(triple);
            }
        }
        patterns.add(candidate);
    }

    // Enforce each candidate's choices and keep only consistent candidates.
    // Survivors are collected by identity: List.remove(Object) would remove the first
    // *equal* graph, which may be a different, already-validated candidate.
    List<Graph> survivors = new ArrayList<>();
    PATTERN:
    for (Graph pattern : patterns) {
        List<Triple> colliders = newColliders.get(pattern);
        List<Triple> nonColliders = newNonColliders.get(pattern);

        // Reject the candidate if a chosen collider x -> y <- z has an edge pointing
        // away from y.
        for (Triple triple : colliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            if (pattern.getEdge(x, y).pointsTowards(x) || pattern.getEdge(y, z).pointsTowards(z)) {
                continue PATTERN;
            }
        }
        for (Triple triple : colliders) {
            pattern.setEndpoint(triple.getX(), triple.getY(), Endpoint.ARROW);
            pattern.setEndpoint(triple.getZ(), triple.getY(), Endpoint.ARROW);
        }

        // Non-collider x - y - z: x -> y forces y -> z, and z -> y forces y -> x.
        for (Triple triple : nonColliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            if (pattern.getEdge(x, y).pointsTowards(y)) {
                pattern.removeEdge(y, z);
                pattern.addDirectedEdge(y, z);
            }
            if (pattern.getEdge(y, z).pointsTowards(y)) {
                pattern.removeEdge(x, y);
                pattern.addDirectedEdge(y, x);
            }
        }

        // Replace every bidirected edge with an undirected one. Iterate over a copy
        // because the graph is modified inside the loop.
        for (Edge edge : new ArrayList<>(pattern.getEdges())) {
            if (Edges.isBidirectedEdge(edge)) {
                Node x = edge.getNode1();
                Node y = edge.getNode2();
                pattern.removeEdge(x, y);
                pattern.addUndirectedEdge(x, y);
            }
        }

        MeekRules rules = new MeekRules();
        rules.orientImplied(pattern);
        if (pattern.existsDirectedCycle()) {
            continue PATTERN;
        }
        survivors.add(pattern);
    }
    patterns = survivors;

    // An apparent non-adjacency x..y becomes definite only if the following holds in
    // every surviving pattern where neither node is in the other's boundary:
    //   - x _||_ y | boundary(x), checked when y is not in future(x);
    //   - y _||_ x | boundary(y), checked when x is not in future(y).
    IndependenceTest test = independenceTest;
    MARKOV:
    for (Edge edge : apparentlyNonadjacencies.keySet()) {
        Node x = edge.getNode1();
        Node y = edge.getNode2();
        for (Graph pattern : patterns) {
            List<Node> boundaryX = new ArrayList<>(boundary(x, pattern));
            List<Node> boundaryY = new ArrayList<>(boundary(y, pattern));
            List<Node> futureX = new ArrayList<>(future(x, pattern));
            List<Node> futureY = new ArrayList<>(future(y, pattern));
            if (y == x) {
                continue;
            }
            if (boundaryX.contains(y) || boundaryY.contains(x)) {
                continue;
            }
            if (!futureX.contains(y) && !test.isIndependent(x, y, boundaryX)) {
                continue MARKOV;
            }
            if (!futureY.contains(x) && !test.isIndependent(y, x, boundaryY)) {
                continue MARKOV;
            }
        }
        definitelyNonadjacencies.add(edge);
    }
    // Definite non-adjacencies are no longer merely apparent.
    apparentlyNonadjacencies.keySet().removeAll(definitelyNonadjacencies);

    setSemIm(semIm);
    System.out.println(semIm.getEdgeCoef());

    // Edge estimation (Alg I).
    Regression sampleRegression = new RegressionDataset(dataSet);
    System.out.println(sampleRegression.getGraph());
    graph = GraphUtils.replaceNodes(graph, dataSet.getVariables());
    // Per edge: sample regression coefficients, or null when the estimate is unknown.
    Map<Edge, double[]> sampleRegress = new HashMap<>();
    // Per edge: the matching semIm coefficient, or 0.0 when semIm has none.
    Map<Edge, Double> edgeCoefs = new HashMap<>();

    ESTIMATION:
    for (Node z : graph.getNodes()) {
        Set<Edge> adj = getAdj(z, graph);

        // If z touches any apparent non-adjacency, mark all of z's edges unknown and
        // move on to the next z.
        // NOTE(review): graph's nodes were just replaced by dataSet's variables; these ==
        // comparisons assume the non-adjacency edges share those Node instances — confirm.
        for (Edge edge : apparentlyNonadjacencies.keySet()) {
            if (z == edge.getNode1() || z == edge.getNode2()) {
                for (Edge adjacency : adj) {
                    sampleRegress.put(adjacency, null);
                    recordImEdgeCoef(adjacency, edgeCoefs);
                }
                continue ESTIMATION;
            }
        }

        // Definite non-adjacencies touching z get a zero estimate.
        for (Edge nonadj : definitelyNonadjacencies) {
            if (nonadj.getNode1() == z || nonadj.getNode2() == z) {
                sampleRegress.put(nonadj, new double[] {0, 0});
                recordImEdgeCoef(nonadj, edgeCoefs);
            }
        }

        // Collect z's parents, and note whether any of z's edges is undirected.
        Set<Edge> parentsOfZ = new HashSet<>();
        boolean hasUndirected = false;
        for (Edge adjacency : adj) {
            if (!adjacency.isDirected()) {
                hasUndirected = true;
            }
            if (adjacency.pointsTowards(z)) {
                parentsOfZ.add(adjacency);
            }
        }
        // Any undirected edge at z marks all of z's edges unknown. The parent regressions
        // below then overwrite the entries for z's parent edges.
        if (hasUndirected) {
            for (Edge adjacency : adj) {
                sampleRegress.put(adjacency, null);
                recordImEdgeCoef(adjacency, edgeCoefs);
            }
        }

        // Regress edge.getNode2() on edge.getNode1() for each parent edge.
        // NOTE(review): only parent edges stored as node1 -> node2 are regressed; a parent
        // edge stored with z as node1 is skipped — confirm this is intended.
        for (Edge edge : parentsOfZ) {
            if (edge.pointsTowards(edge.getNode2())) {
                RegressionResult result = sampleRegression.regress(edge.getNode2(), edge.getNode1());
                System.out.println(result);
                sampleRegress.put(edge, result.getCoef());
                recordImEdgeCoef(edge, edgeCoefs);
            }
        }
    }

    System.out.println("All IM: " + semIm + "Finish");
    System.out.println("Just IM coefs: " + semIm.getEdgeCoef());
    System.out.println("IM Coef Map: " + edgeCoefs);
    System.out.println("Regress Coef Map: " + sampleRegress);
    for (Map.Entry<Edge, double[]> entry : sampleRegress.entrySet()) {
        System.out.println(" Sample Regression: " + entry.getKey() + java.util.Arrays.toString(entry.getValue()));
    }
    for (Edge edge : graph.getEdges()) {
        System.out.println("Sample edge: " + java.util.Arrays.toString(sampleRegress.get(edge)));
    }

    System.out.println("Sample VCPC:");
    System.out.println("# of patterns: " + patterns.size());
    long endTime = System.currentTimeMillis();
    this.elapsedTime = endTime - startTime;
    // Floating-point division so fractional seconds are not truncated.
    System.out.println("Search Time (seconds):" + elapsedTime / 1000.0 + " s");
    System.out.println("Search Time (milli):" + elapsedTime + " ms");
    System.out.println("# of Apparent Nonadj: " + apparentlyNonadjacencies.size());
    System.out.println("# of Definite Nonadj: " + definitelyNonadjacencies.size());

    TetradLogger.getInstance().log("apparentlyNonadjacencies", "\n Apparent Non-adjacencies" + apparentlyNonadjacencies);
    TetradLogger.getInstance().log("definitelyNonadjacencies", "\n Definite Non-adjacencies" + definitelyNonadjacencies);
    TetradLogger.getInstance().log("patterns", "Disambiguated Patterns: " + patterns);
    TetradLogger.getInstance().log("graph", "\nReturning this graph: " + graph);
    TetradLogger.getInstance().log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s");
    TetradLogger.getInstance().log("info", "Finishing CPC algorithm.");
    logTriples();
    TetradLogger.getInstance().flush();
    return graph;
}

/**
 * Records the coefficient that {@code semIm} assigns to an edge.
 *
 * <p>The coefficient is looked up for the ordered pair ({@code edge.getNode1()},
 * {@code edge.getNode2()}). The value stored in {@code edgeCoefs} is that coefficient,
 * or 0.0 when {@code semIm} has no such coefficient.
 *
 * @param edge the edge whose coefficient is looked up
 * @param edgeCoefs the map that receives the coefficient, keyed by {@code edge}
 */
private void recordImEdgeCoef(Edge edge, Map<Edge, Double> edgeCoefs) {
    Node a = edge.getNode1();
    Node b = edge.getNode2();
    if (semIm.existsEdgeCoef(a, b)) {
        Double c = semIm.getEdgeCoef(a, b);
        edgeCoefs.put(edge, c);
    } else {
        edgeCoefs.put(edge, 0.0);
    }
}
Use of edu.cmu.tetrad.util.CombinationGenerator in project tetrad by cmu-phil.
The class TestCombinationGenerator, method test1.
/**
 * Checks that a generator over dimensions {5, 3} produces 5 * 3 = 15 combinations
 * before {@code next()} returns null.
 */
@Test
public void test1() {
    int[] dims = { 5, 3 };
    CombinationGenerator gen = new CombinationGenerator(dims);
    int seen = 0;
    for (int[] combo = gen.next(); combo != null; combo = gen.next()) {
        seen++;
    }
    assertEquals(5 * 3, seen);
}
Aggregations