Use of edu.cmu.tetrad.regression.Regression in the project tetrad by cmu-phil.
Class SampleVcpc, method search.
/**
 * Runs the sample VCPC search and returns the resulting graph.
 *
 * <p>The method proceeds in these phases:
 * <ol>
 *   <li>Adjacency search with {@code Vcfas} (a FAS variant that also reports the
 *       "apparently nonadjacent" edges).</li>
 *   <li>Optional orientation (background knowledge, unshielded triples, Meek rules),
 *       performed only when {@code isDoOrientation()} is true.</li>
 *   <li>Enumeration of every collider / noncollider assignment of the ambiguous
 *       triples; each assignment becomes a candidate pattern, and candidates that
 *       conflict with existing orientations or contain a directed cycle are dropped.</li>
 *   <li>For every apparently nonadjacent edge, independence checks against the
 *       boundary of each endpoint in every surviving pattern; edges that pass in all
 *       patterns are moved to the definite nonadjacencies.</li>
 *   <li>Per-node regression on {@code dataSet}, collected alongside the matching
 *       coefficients of {@code semIm}; the results are printed, not returned.</li>
 * </ol>
 *
 * @return the searched graph, with its nodes replaced by the variables of {@code dataSet}
 * @throws NullPointerException if no independence test has been set
 */
public Graph search() {
    // Fail fast: the test is used by the logging and by the Vcfas constructor below,
    // so it has to be validated before anything else touches it.
    if (getIndependenceTest() == null) {
        throw new NullPointerException("The independence test must not be null.");
    }

    this.logger.log("info", "Starting VCCPC algorithm");
    this.logger.log("info", "Independence test = " + getIndependenceTest() + ".");

    this.allTriples = new HashSet<>();
    this.ambiguousTriples = new HashSet<>();
    this.colliderTriples = new HashSet<>();
    this.noncolliderTriples = new HashSet<>();
    Vcfas fas = new Vcfas(getIndependenceTest());
    definitelyNonadjacencies = new HashSet<>();
    markovInAllPatterns = new HashSet<>();

    long startTime = System.currentTimeMillis();
    List<Node> allNodes = getIndependenceTest().getVariables();

    // ---- Phase 1: adjacency search ----
    fas.setKnowledge(getKnowledge());
    fas.setDepth(getDepth());
    fas.setVerbose(verbose);
    graph = fas.search();
    apparentlyNonadjacencies = fas.getApparentlyNonadjacencies();

    // ---- Phase 2: optional orientation ----
    if (isDoOrientation()) {
        if (verbose) {
            System.out.println("CPC orientation...");
        }
        SearchGraphUtils.pcOrientbk(knowledge, graph, allNodes);
        orientUnshieldedTriples(knowledge, getIndependenceTest(), getDepth());
        MeekRules meekRules = new MeekRules();
        meekRules.setAggressivelyPreventCycles(this.aggressivelyPreventCycles);
        meekRules.setKnowledge(knowledge);
        meekRules.orientImplied(graph);
    }

    // ---- Phase 3: enumerate disambiguations of the ambiguous triples ----
    List<Triple> ambiguous = new ArrayList<>(graph.getAmbiguousTriples());

    // Every ambiguous triple has two options: 0 = collider, 1 = noncollider.
    int[] dims = new int[ambiguous.size()];
    for (int i = 0; i < dims.length; i++) {
        dims[i] = 2;
    }

    List<Graph> patterns = new ArrayList<>();
    // Identity maps: candidates are distinguished by instance, not by equals().
    Map<Graph, List<Triple>> newColliders = new IdentityHashMap<>();
    Map<Graph, List<Triple>> newNonColliders = new IdentityHashMap<>();

    CombinationGenerator generator = new CombinationGenerator(dims);
    int[] combination;
    while ((combination = generator.next()) != null) {
        Graph candidate = new EdgeListGraph(graph);
        newColliders.put(candidate, new ArrayList<Triple>());
        newNonColliders.put(candidate, new ArrayList<Triple>());

        for (int k = 0; k < combination.length; k++) {
            Triple triple = ambiguous.get(k);
            candidate.removeAmbiguousTriple(triple.getX(), triple.getY(), triple.getZ());

            if (combination[k] == 0) {
                newColliders.get(candidate).add(triple);
                Node x = triple.getX();
                Node y = triple.getY();
                Node z = triple.getZ();
                candidate.setEndpoint(x, y, Endpoint.ARROW);
                candidate.setEndpoint(z, y, Endpoint.ARROW);
            }
            if (combination[k] == 1) {
                newNonColliders.get(candidate).add(triple);
            }
        }
        patterns.add(candidate);
    }

    // Prune inconsistent candidates. Iterate over a copy because candidates are
    // removed from 'patterns' along the way.
    GRAPH:
    for (Graph pattern : new ArrayList<>(patterns)) {
        List<Triple> colliders = newColliders.get(pattern);
        List<Triple> nonColliders = newNonColliders.get(pattern);

        // A collider x -> y <- z is impossible if either edge already points away from y.
        for (Triple triple : colliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            if (pattern.getEdge(x, y).pointsTowards(x) || (pattern.getEdge(y, z).pointsTowards(z))) {
                patterns.remove(pattern);
                continue GRAPH;
            }
        }

        for (Triple triple : colliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            pattern.setEndpoint(x, y, Endpoint.ARROW);
            pattern.setEndpoint(z, y, Endpoint.ARROW);
        }

        // For a noncollider, an arrow into y on one side forces the other edge out of y.
        for (Triple triple : nonColliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            if (pattern.getEdge(x, y).pointsTowards(y)) {
                pattern.removeEdge(y, z);
                pattern.addDirectedEdge(y, z);
            }
            if (pattern.getEdge(y, z).pointsTowards(y)) {
                pattern.removeEdge(x, y);
                pattern.addDirectedEdge(y, x);
            }
        }

        // Bidirected edges are relaxed to undirected ones. The edges are snapshotted
        // first so that replacing them cannot disturb the iteration.
        for (Edge edge : new ArrayList<>(pattern.getEdges())) {
            Node x = edge.getNode1();
            Node y = edge.getNode2();
            if (Edges.isBidirectedEdge(edge)) {
                pattern.removeEdge(x, y);
                pattern.addUndirectedEdge(x, y);
            }
        }

        MeekRules rules = new MeekRules();
        rules.orientImplied(pattern);
        if (pattern.existsDirectedCycle()) {
            patterns.remove(pattern);
            continue GRAPH;
        }
    }

    // ---- Phase 4: classify apparent nonadjacencies ----
    // An edge becomes a definite nonadjacency only if no surviving pattern yields a
    // failed independence check for it (the first failure skips to the next edge).
    MARKOV:
    for (Edge edge : apparentlyNonadjacencies.keySet()) {
        Node x = edge.getNode1();
        Node y = edge.getNode2();
        for (Graph pattern : new ArrayList<>(patterns)) {
            List<Node> boundaryX = new ArrayList<>(boundary(x, pattern));
            List<Node> boundaryY = new ArrayList<>(boundary(y, pattern));
            List<Node> futureX = new ArrayList<>(future(x, pattern));
            List<Node> futureY = new ArrayList<>(future(y, pattern));
            if (y == x) {
                continue;
            }
            if (boundaryX.contains(y) || boundaryY.contains(x)) {
                continue;
            }
            IndependenceTest test = independenceTest;
            if (!futureX.contains(y)) {
                if (!test.isIndependent(x, y, boundaryX)) {
                    continue MARKOV;
                }
            }
            if (!futureY.contains(x)) {
                if (!test.isIndependent(y, x, boundaryY)) {
                    continue MARKOV;
                }
            }
        }
        definitelyNonadjacencies.add(edge);
    }

    // Removal is deferred to here so the key set is not modified while it is iterated above.
    apparentlyNonadjacencies.keySet().removeAll(definitelyNonadjacencies);

    // ---- Phase 5: edge-coefficient estimation ----
    setSemIm(semIm);
    System.out.println(semIm.getEdgeCoef());

    Regression sampleRegression = new RegressionDataset(dataSet);
    System.out.println(sampleRegression.getGraph());
    graph = GraphUtils.replaceNodes(graph, dataSet.getVariables());

    // A null value in sampleRegress marks an edge whose coefficient is reported as unknown.
    Map<Edge, double[]> sampleRegress = new HashMap<>();
    Map<Edge, Double> edgeCoefs = new HashMap<>();

    ESTIMATION:
    for (Node z : graph.getNodes()) {
        Set<Edge> adj = getAdj(z, graph);

        // z touches an apparent nonadjacency: everything adjacent to z is unknown.
        for (Edge edge : apparentlyNonadjacencies.keySet()) {
            if (z == edge.getNode1() || z == edge.getNode2()) {
                for (Edge adjacency : adj) {
                    sampleRegress.put(adjacency, null);
                    recordImCoef(edgeCoefs, adjacency);
                }
                continue ESTIMATION;
            }
        }

        // Definite nonadjacencies at z get a zero estimate.
        for (Edge nonadj : definitelyNonadjacencies) {
            if (nonadj.getNode1() == z || nonadj.getNode2() == z) {
                double[] d = { 0, 0 };
                sampleRegress.put(nonadj, d);
                recordImCoef(edgeCoefs, nonadj);
            }
        }

        Set<Edge> parentsOfZ = new HashSet<>();
        Set<Edge> _adj = getAdj(z, graph);
        for (Edge _adjacency : _adj) {
            if (!_adjacency.isDirected()) {
                // NOTE(review): unlike the apparent-nonadjacency case above, this branch
                // does not skip to the next node, so the parent regressions below can
                // overwrite these null entries - confirm that this is intended.
                for (Edge adjacency : adj) {
                    sampleRegress.put(adjacency, null);
                    recordImCoef(edgeCoefs, adjacency);
                }
            }
            if (_adjacency.pointsTowards(z)) {
                parentsOfZ.add(_adjacency);
            }
        }

        // NOTE(review): only parent edges stored as node1 -> node2 are regressed; parent
        // edges stored the other way round are skipped - confirm that this is intended.
        for (Edge edge : parentsOfZ) {
            if (edge.pointsTowards(edge.getNode2())) {
                RegressionResult result = sampleRegression.regress(edge.getNode2(), edge.getNode1());
                System.out.println(result);
                double[] d = result.getCoef();
                sampleRegress.put(edge, d);
                recordImCoef(edgeCoefs, edge);
            }
        }
    }

    // ---- Reporting ----
    System.out.println("All IM: " + semIm + "Finish");
    System.out.println("Just IM coefs: " + semIm.getEdgeCoef());
    System.out.println("IM Coef Map: " + edgeCoefs);
    System.out.println("Regress Coef Map: " + sampleRegress);
    for (Edge edge : sampleRegress.keySet()) {
        System.out.println(" Sample Regression: " + edge + java.util.Arrays.toString(sampleRegress.get(edge)));
    }
    for (Edge edge : graph.getEdges()) {
        System.out.println("Sample edge: " + java.util.Arrays.toString(sampleRegress.get(edge)));
    }

    System.out.println("Sample VCPC:");
    System.out.println("# of patterns: " + patterns.size());
    long endTime = System.currentTimeMillis();
    this.elapsedTime = endTime - startTime;
    System.out.println("Search Time (seconds):" + (elapsedTime) / 1000 + " s");
    System.out.println("Search Time (milli):" + elapsedTime + " ms");
    System.out.println("# of Apparent Nonadj: " + apparentlyNonadjacencies.size());
    System.out.println("# of Definite Nonadj: " + definitelyNonadjacencies.size());

    TetradLogger.getInstance().log("apparentlyNonadjacencies", "\n Apparent Non-adjacencies" + apparentlyNonadjacencies);
    TetradLogger.getInstance().log("definitelyNonadjacencies", "\n Definite Non-adjacencies" + definitelyNonadjacencies);
    TetradLogger.getInstance().log("patterns", "Disambiguated Patterns: " + patterns);
    TetradLogger.getInstance().log("graph", "\nReturning this graph: " + graph);
    TetradLogger.getInstance().log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s");
    TetradLogger.getInstance().log("info", "Finishing CPC algorithm.");
    logTriples();
    TetradLogger.getInstance().flush();
    return graph;
}

/**
 * Stores in {@code edgeCoefs} the coefficient that {@code semIm} holds for the given
 * edge (looked up as node1, node2), or 0.0 when {@code semIm} has no such coefficient.
 */
private void recordImCoef(Map<Edge, Double> edgeCoefs, Edge edge) {
    Node a = edge.getNode1();
    Node b = edge.getNode2();
    if (semIm.existsEdgeCoef(a, b)) {
        edgeCoefs.put(edge, semIm.getEdgeCoef(a, b));
    } else {
        edgeCoefs.put(edge, 0.0);
    }
}
Use of edu.cmu.tetrad.regression.Regression in the project tetrad by cmu-phil.
Class RegressionRunner, method execute.
// =================PUBLIC METHODS OVERRIDING ABSTRACT=================//
/**
 * Runs the regression of the target variable on the selected regressors and
 * publishes the resulting graph.
 *
 * <p>When there are no regressors, no target, or the target is itself listed as a
 * regressor, {@code outGraph} is reset to an empty graph and the method returns
 * without calling {@code setResultGraph}. When the data model is neither a
 * {@code DataSet} nor an {@code ICovarianceMatrix}, no regression is run and the
 * current {@code outGraph} is published unchanged.
 */
public void execute() {
    if (regressorNames.isEmpty() || targetName == null || regressorNames.contains(targetName)) {
        outGraph = new EdgeListGraph();
        return;
    }

    Regression regression = null;
    Node target = null;
    List<Node> regressors = new LinkedList<>();

    // Only the construction of the regression and the variable lookup depend on the
    // kind of data model; the regression itself is run once below.
    if (getDataModel() instanceof DataSet) {
        DataSet tabular = (DataSet) getDataModel();
        regression = new RegressionDataset(tabular);
        target = tabular.getVariable(targetName);
        for (String name : regressorNames) {
            regressors.add(tabular.getVariable(name));
        }
    } else if (getDataModel() instanceof ICovarianceMatrix) {
        ICovarianceMatrix covariances = (ICovarianceMatrix) getDataModel();
        regression = new RegressionCovariance(covariances);
        target = covariances.getVariable(targetName);
        for (String name : regressorNames) {
            regressors.add(covariances.getVariable(name));
        }
    }

    if (regression != null) {
        double alpha = params.getDouble("alpha", 0.001);
        regression.setAlpha(alpha);
        result = regression.regress(target, regressors);
        outGraph = regression.getGraph();
    }

    setResultGraph(outGraph);
}
Use of edu.cmu.tetrad.regression.Regression in the project tetrad by cmu-phil.
Class BffBeam, method removeZeroEdges.
/**
 * Returns a copy of {@code bestGraph} from which edges with a statistically
 * insignificant coefficient have been removed.
 *
 * <p>The copy is scored repeatedly. For each coefficient parameter of the estimated
 * SEM, the child is regressed on its parents in the working copy; if the p-value for
 * the parent exceeds {@code getHighPValueAlpha()}, the edge is removed unless the
 * knowledge marks it as required. Passes repeat until one removes nothing.
 *
 * @param bestGraph the graph to prune; it is copied, not modified
 * @return the pruned copy
 */
public Graph removeZeroEdges(Graph bestGraph) {
    Graph graph = new EdgeListGraph(bestGraph);
    // The covariance matrix does not change, so a single regression object is enough.
    Regression regression = new RegressionCovariance(cov);
    boolean changed = true;

    while (changed) {
        changed = false;
        Score score = scoreGraph(graph);
        SemIm estSem = score.getEstimatedSem();

        for (Parameter param : estSem.getSemPm().getParameters()) {
            if (param.getType() != ParamType.COEF) {
                continue;
            }

            Node nodeA = param.getNodeA();
            Node nodeB = param.getNodeB();
            Node parent;
            Node child;

            // Orientation is read from the working copy (not from the 'graph' field it
            // shadows), so that it agrees with getParents() and getEdge() below.
            if (graph.isParentOf(nodeA, nodeB)) {
                parent = nodeA;
                child = nodeB;
            } else {
                parent = nodeB;
                child = nodeA;
            }

            List<Node> parents = graph.getParents(child);
            int parentIndex = parents.indexOf(parent);
            if (parentIndex < 0) {
                // The working copy has no parent -> child edge for this parameter.
                // Indexing with -1 + 1 would silently read the p-value at position 0
                // instead of this parent's, so the parameter is skipped.
                continue;
            }

            RegressionResult result = regression.regress(child, parents);
            // Offset by one: presumably position 0 of getP() belongs to the intercept -
            // TODO confirm against RegressionResult.
            double p = result.getP()[parentIndex + 1];

            if (p > getHighPValueAlpha()) {
                Edge edge = graph.getEdge(nodeA, nodeB);
                if (getKnowledge().isRequired(edge.getNode1().getName(), edge.getNode2().getName())) {
                    System.out.println("Not removing " + edge + " because it is required.");
                    TetradLogger.getInstance().log("details", "Not removing " + edge + " because it is required.");
                    continue;
                }
                System.out.println("Removing edge " + edge + " because it has p = " + p);
                TetradLogger.getInstance().log("details", "Removing edge " + edge + " because it has p = " + p);
                graph.removeEdge(edge);
                changed = true;
            }
        }
    }
    return graph;
}
Use of edu.cmu.tetrad.regression.Regression in the project tetrad by cmu-phil.
Class TestRegression, method testTabular.
/**
 * Regression test: regresses the first variable of the data set on all remaining
 * variables and checks the first five coefficients against previously recorded
 * values, each within a tolerance of 0.01.
 */
@Test
public void testTabular() {
    setUp();
    RandomUtil.getInstance().setSeed(3848283L);

    List<Node> variables = data.getVariables();
    Node response = variables.get(0);
    // Every variable after the first one acts as a regressor.
    List<Node> predictors = new ArrayList<>(variables.subList(1, variables.size()));

    Regression regression = new RegressionDataset(data);
    RegressionResult fit = regression.regress(response, predictors);
    double[] actual = fit.getCoef();

    double[] expected = {.08, -.05, .035, 0.019, -.003};
    for (int i = 0; i < expected.length; i++) {
        assertEquals(expected[i], actual[i], 0.01);
    }
}
Aggregations