Usage of `edu.cmu.tetrad.graph.Dag` in the tetrad project (cmu-phil):
class `Identifiability`, method `getJointMarginal`.
// //////////////////////////////////////////////////////////////
/**
 * Computes P_t(s), the probability of the joint assignment s under the
 * intervention t.
 *
 * <p>s is given by the variable/value pairs in the arguments; t is given by
 * the current evidence. s and t should each be only one combination of
 * variable values (i.e., no disjunctions of values of a variable, even though
 * those are allowed in Proposition). Only evidence variables restricted to a
 * single category contribute a fixed value for t.
 *
 * @param sVariables indices (in bayesIm) of the s variables
 * @param sValues    category indices (in bayesIm) for the s variables,
 *                   parallel to sVariables
 * @return P_t(s); 0.0 if s and t assign different values to the same
 *         variable; 1.0 if every s variable is also fixed by t to the same
 *         value; -1.0 if P_t(s) is not identifiable by this algorithm
 * @throws IllegalArgumentException if sVariables and sValues differ in length
 * @throws RuntimeException if no c-component of the full graph contains a
 *         given c-component of G_D (an internal inconsistency)
 */
public double getJointMarginal(int[] sVariables, int[] sValues) {
if (sVariables.length != sValues.length) {
throw new IllegalArgumentException("Values must match variables.");
}
// ////////////////////////////////////
// collect s and t variables
// the s variables: the bayesIm nodes at the requested indices
List<Node> sNodes = new ArrayList<>();
for (int sVariable : sVariables) {
sNodes.add(bayesIm.getNode(sVariable));
}
if (debug) {
System.out.println("\nsVariables: " + sNodes);
}
// the t variables: every variable mentioned in the evidence
List<Node> tNodesTmp = evidence.getVariablesInEvidence();
// No evidence variables means no intervention: P_t(s) reduces to P(s),
// computed directly from the joint distribution of the observed model.
// (Formerly delegated to a regular updater such as RowSummingExactUpdater;
// that version is kept below for reference.)
if (tNodesTmp.size() == 0) {
/*
ManipulatingBayesUpdater rowSumUpdater =
new RowSummingExactUpdater(bayesIm, Evidence.tautology(bayesIm));
return rowSumUpdater.getJointMarginal(sVariables, sValues);
*/
Proposition prop = Proposition.tautology(bayesIm);
for (int i = 0; i < sVariables.length; i++) {
// constrain variable sVariables[i] to the single category sValues[i]
prop.setCategory(sVariables[i], sValues[i]);
}
// restrict the proposition to only observed variables
// NOTE(review): this branch casts bayesIm to MlBayesImObs and would throw
// ClassCastException for any other BayesIm; presumably callers guarantee
// the type - confirm.
Proposition propObs = new Proposition(((MlBayesImObs) bayesIm).getBayesImObs(), prop);
return ((MlBayesImObs) bayesIm).getJPD().getProb(propObs);
}
// recast the t variables as the bayesIm nodes with the same names
// (lookup is by name because the evidence may hold distinct Node objects)
List<Node> tNodes = new ArrayList<>();
for (Node node1 : tNodesTmp) {
tNodes.add(bayesIm.getNode(node1.getName()));
}
if (debug) {
System.out.println("\ntVariables: " + tNodes);
}
// ////////////////////////////////////
// collect variable values that are set in S and T
// fixedVarValues[k] is the category fixed for bayesIm node k, or -1 if unset
int nNodes = bayesIm.getNumNodes();
int[] fixedVarValues = new int[nNodes];
for (int i = 0; i < nNodes; i++) {
// value not set
fixedVarValues[i] = -1;
}
// incorporate values of S
for (int i = 0; i < sVariables.length; i++) {
fixedVarValues[sVariables[i]] = sValues[i];
}
// assume all the variables are manipulated:
// fold each single-valued evidence variable of T into fixedVarValues,
// translating evidence node/category into bayesIm index/category by name
for (int i = 0; i < evidence.getNumNodes(); i++) {
int tNodeValue = evidence.getProposition().getSingleCategory(i);
if (// only consider variables with a single value
tNodeValue != -1) {
Node tNode = evidence.getNode(i);
String tNodeStr = evidence.getNode(i).getName();
Node tNodeInBayesIm = bayesIm.getNode(tNodeStr);
int tIndexInBayesIm = bayesIm.getNodeIndex(tNodeInBayesIm);
String tCategoryStr = evidence.getCategory(tNode, tNodeValue);
int tValueInBayesIm = bayesIm.getBayesPm().getCategoryIndex(tNodeInBayesIm, tCategoryStr);
int oldValue = fixedVarValues[tIndexInBayesIm];
if (// S has a value for this variable
oldValue != -1) {
if (oldValue == tValueInBayesIm) // remove the variable from S (S and T should be disjoint)
{
sNodes.remove(tNodeInBayesIm);
if (debug) {
System.out.println("sNode removed: index: " + tNodeInBayesIm + "; name: " + tNodeStr);
}
} else // S and T have different values for this same variable
{
// S and T inconsistent: Prob = 0
return 0.0;
}
} else // no value for this variable yet
{
fixedVarValues[tIndexInBayesIm] = tValueInBayesIm;
}
}
}
if (debug) {
System.out.print("fixedVarValues: ");
for (int i = 0; i < nNodes; i++) {
System.out.print(fixedVarValues[i] + " ");
}
System.out.println();
}
// if all nodes in S are removed (every s value is already fixed by t)
if (sNodes.size() == 0) {
// Prob = 1
return 1.0;
}
// ////////////////////////////////////
// get c-components of the full graph
int[] cComponents = getCComponents(bayesIm);
int nCComponents = nCComponents(cComponents);
// store the nodes in each c-component
List<List<Node>> cComponentNodes = new LinkedList<>();
for (int i = 0; i < nCComponents; i++) {
cComponentNodes.add(getCComponentNodes(bayesIm, cComponents, i));
}
if (debug) {
for (int i = 0; i < nCComponents; i++) {
System.out.println("c-component " + i + ": " + cComponentNodes.get(i));
}
}
// ////////////////////////////////////
// Q[V]: the joint probability of all the measured variables.
// probTermV marks measured variables with 1 and all others with 0.
int[] probTermV = new int[nNodes];
for (int i = 0; i < nNodes; i++) {
if (bayesIm.getNode(i).getNodeType() == NodeType.MEASURED) {
probTermV[i] = 1;
}
}
if (debug) {
System.out.print("probTermV: ");
for (int i = 0; i < nNodes; i++) {
System.out.print(probTermV[i] + " ");
}
System.out.println();
}
// get C-factors: Q[S_i] for each c-component S_i, by Q-decomposition of
// Q[V] over the whole graph
QList[] cFactors = new QList[nCComponents];
for (int i = 0; i < nCComponents; i++) {
// Q[V]: the joint over the measured variables (probTermV)
QList qV = new QList(nNodes, probTermV);
if (debug) {
System.out.println("cFactors " + i + " " + bayesIm.getDag().getNodes() + " " + cComponentNodes.get(i));
System.out.println("============== QList: qV ==============");
qV.printQList(0, 0);
}
cFactors[i] = qDecomposition(bayesIm, bayesIm.getDag().getNodes(), cComponentNodes.get(i), qV);
if (debug) {
System.out.println("============== QList: cFactors[" + i + "] ==============");
cFactors[i].printQList(0, 0);
}
}
// ////////////////////////////////////
// get D: the ancestors of S in the graph with the T nodes removed
// Note: "dag" is a new copy of the dag; otherwise modifications
// would be made to the one in the bayesIm
// Note: the ordering of the nodes may not be the same as in
// the original graph
Dag dag = new Dag(bayesIm.getDag());
if (debug) {
System.out.println("------ here1 -------------");
System.out.println(bayesIm.getDag());
// watch out! tNodes may be empty
// System.out.println(tNodes.get(0));
// System.out.println(dag.getNodes().get(1));
// System.out.println(tNodes.get(0).equals(dag.getNodes().get(1)));
}
dag.removeNodes(tNodes);
if (debug) {
// prints bayesIm's dag (not the copy), which should be unaffected
System.out.println("------ here2 -------------");
System.out.println(bayesIm.getDag());
}
List<Node> dNodes = dag.getAncestors(sNodes);
// create a Bayes IM with the dag G_dNodes (subgraph of the full dag
// induced by D). Its parameters are RANDOM; below only its graph
// structure is used (to compute the c-components of G_D).
Dag gD = new Dag(bayesIm.getDag().subgraph(dNodes));
BayesPm bayesPmD = new BayesPm(gD, bayesIm.getBayesPm());
BayesIm bayesImD = new MlBayesIm(bayesPmD, bayesIm, MlBayesIm.RANDOM);
if (debug) {
System.out.println("------ bayeIm.getDag() -------------");
System.out.println(bayesIm.getDag());
System.out.println("------ gD -------------");
System.out.println(gD);
System.out.println("------ bayeImD.getDag() -------------");
System.out.println(bayesImD.getDag());
System.out.println("bayesIm node 0: " + bayesIm.getNode(0));
System.out.println("bayesImD node 0: " + bayesImD.getNode(0));
}
// get c-components of gD
int[] cComponentsD = getCComponents(bayesImD);
int nCComponentsD = nCComponents(cComponentsD);
// Q[Di]: identified from the C-factor Q[Sj] of the c-component Sj that
// contains Di
QList[] qD = new QList[nCComponentsD];
for (int i = 0; i < nCComponentsD; i++) {
// Di
List<Node> cComponentNodesDi = getCComponentNodes(bayesImD, cComponentsD, i);
// Sj
// Find the index j of the c-component Sj in cComponentNodes
// which is a superset of cComponentNodesDi
//
int j = 0;
boolean flag = false;
while ((j < nCComponents) && !flag) {
List<Node> cComponentNodesSj = cComponentNodes.get(j);
if (cComponentNodesSj.containsAll(cComponentNodesDi)) {
if (debug) {
System.out.println("----- Di Sj --------");
System.out.println(i + " " + cComponentNodesDi + " " + cComponentNodesSj);
}
flag = true;
qD[i] = identify(cComponentNodesDi, cComponentNodesSj, cFactors[j]);
// fail: qD[i] not identifiable with this algorithm
if (qD[i] == null) {
if (debug) {
System.out.println("----- FAIL qD[" + i + "] --------");
}
// fail: P_t(s) not identifiable with this algorithm
return -1.0;
}
if (debug) {
System.out.println("======================== QList: qD[" + i + "] =================");
qD[i].printQList(0, 0);
}
}
j++;
}
if (// something is wrong
!flag) {
throw new RuntimeException("getJointMarginal: Sj not found");
}
}
// ////////////////////////////////////
// multiply the Q[Di]'s (all-zero sumOverVariables: nothing summed out)
QList qDProducts = new QList(nNodes);
int[] sumOverVariables = new int[nNodes];
for (int i = 0; i < nNodes; i++) {
sumOverVariables[i] = 0;
}
for (int i = 0; i < nCComponentsD; i++) {
qDProducts.add(qD[i], sumOverVariables, true);
}
// P_t(s): the product of the Q[Di]'s summed over the variables selected by
// sumList(nNodes, dNodes, sNodes) - presumably D \ S; confirm in sumList
QList qPTS = new QList(nNodes);
qPTS.add(qDProducts, sumList(nNodes, dNodes, sNodes), true);
if (debug) {
System.out.println("***************************** QList: qPTS *******************");
qPTS.printQList(0, 0);
}
// compute numeric value from the algebraic expression qPTS, with the
// S and T values fixed as collected in fixedVarValues
if (debug) {
System.out.println("***************************** computeValue *******************");
}
return qPTS.computeValue(bayesIm, fixedVarValues);
}
Usage of `edu.cmu.tetrad.graph.Dag` in the tetrad project (cmu-phil):
class `Identifiability`, method `qDecomposition`.
// ///////////////////////////////////////////////////////////////
/**
 * Generalized Q-decomposition (Tian and Pearl 2002, Lemma 4).
 *
 * <p>Given the Q-factor of a node set H and a c-component Hj of the subgraph
 * of graphWhole induced by H, builds the Q-factor of Hj. For each measured
 * node of Hj, two terms derived from qH are appended:
 * <ul>
 *   <li>qH with every measured node strictly after it in the causal order of
 *       G_H summed out (added with flag {@code true});</li>
 *   <li>the same summation plus the node itself (added with flag
 *       {@code false}).</li>
 * </ul>
 * The flag presumably distinguishes numerator and denominator terms in
 * QList - confirm there. Latent nodes of Hj contribute nothing.
 *
 * <p>NOTE(review): the callers visible here always pass the bayesIm field as
 * graphWhole, so the parameter may be redundant.
 *
 * @param graphWhole the full Bayes IM whose node indices index all QList arrays
 * @param h          the node set H
 * @param hj         a c-component of G_H
 * @param qH         the Q-factor Q[H]
 * @return the Q-factor Q[Hj]
 * @throws RuntimeException if a measured node of hj is not in the causal
 *         ordering of G_H
 */
private QList qDecomposition(BayesIm graphWhole, List<Node> h, List<Node> hj, QList qH) {
    // Causal (tier) order of G_H, expressed as node indices of the full IM.
    List<Node> ordered = new Dag(graphWhole.getDag().subgraph(h)).getCausalOrdering();
    int count = ordered.size();
    int[] order = new int[count];
    int pos = 0;
    for (Node n : ordered) {
        order[pos++] = graphWhole.getNodeIndex(n);
    }

    if (debug) {
        System.out.print("************************* QDecomposition: Tier ordering: ");
        for (int idx : order) {
            System.out.print(graphWhole.getNode(idx) + " ");
        }
        System.out.println();
    }

    int numNodes = graphWhole.getNumNodes();
    QList result = new QList(numNodes);

    for (Node member : hj) {
        int memberIndex = graphWhole.getNodeIndex(member);
        // latent variables contribute no factor
        if (graphWhole.getNode(memberIndex).getNodeType() != NodeType.MEASURED) {
            continue;
        }

        // position of this node within the tier ordering
        int rank = 0;
        while (rank < count && order[rank] != memberIndex) {
            rank++;
        }
        if (rank == count) {
            throw new RuntimeException("qDecomposition: index out of bound");
        }

        // Q[H^i]: sum out every measured node that comes after this one
        int[] summedOut = new int[numNodes];
        for (int later = rank + 1; later < count; later++) {
            int laterIndex = order[later];
            if (graphWhole.getNode(laterIndex).getNodeType() == NodeType.MEASURED) {
                summedOut[laterIndex] = 1;
            }
        }
        result.add(qH, summedOut, true);
        if (debug) {
            System.out.println("************* QDecomposition: Q[H^i], sumOverVariables: ");
            for (int mark : summedOut) {
                System.out.print(mark + " ");
            }
            System.out.println();
        }

        // Q[H^{i-1}]: additionally sum out this node itself
        // (the same array is updated in place, exactly as before)
        summedOut[order[rank]] = 1;
        result.add(qH, summedOut, false);
        if (debug) {
            System.out.println("************* QDecomposition: Q[H^{i-1}], sumOverVariables: ");
            for (int mark : summedOut) {
                System.out.print(mark + " ");
            }
            System.out.println();
        }
    }
    return result;
}
Usage of `edu.cmu.tetrad.graph.Dag` in the tetrad project (cmu-phil):
class `Identifiability`, method `identify`.
// ///////////////////////////////////////////////////////////////
/**
 * Attempts to identify the Q-factor Q[C] from a known Q[T] (the "identify"
 * step of Tian and Pearl's algorithm), where C is contained in T.
 *
 * <p>Let A be the ancestors of C in the subgraph of bayesIm's DAG induced by T.
 * <ul>
 *   <li>If A is contained in C, Q[C] is Q[T] with the variables selected by
 *       sumList(nNodes, nodesT, nodesC) summed out.</li>
 *   <li>Otherwise, if A contains all of T, identification fails.</li>
 *   <li>Otherwise (C subset A subset T):
 *     <ol>
 *       <li>Q[A] is obtained from Q[T] by summation.</li>
 *       <li>The c-component T2 of G_A that contains C is located.</li>
 *       <li>Q[T2] is computed by Q-decomposition of Q[A].</li>
 *       <li>The procedure recurses on (C, T2).</li>
 *     </ol>
 *   </li>
 * </ul>
 *
 * @param nodesC the node set C whose Q-factor is wanted
 * @param nodesT the node set T whose Q-factor is known
 * @param qT     the Q-factor Q[T]
 * @return Q[C], or {@code null} if C is not identifiable by this procedure
 * @throws RuntimeException if no c-component of G_A contains C
 */
private QList identify(List<Node> nodesC, List<Node> nodesT, QList qT) {
    Dag graphT = new Dag(bayesIm.getDag().subgraph(nodesT));
    List<Node> nodesA = graphT.getAncestors(nodesC);
    int nNodes = bayesIm.getNumNodes();
    if (debug) {
        System.out.println("-------------- identify -------------");
        System.out.println("----- bayesIm.getDag() -----");
        System.out.println(bayesIm.getDag());
        System.out.println("----- graphT -----");
        System.out.println(graphT);
        System.out.println("nodesC: " + nodesC);
        System.out.println("nodesT: " + nodesT);
        System.out.println("nodesA: " + nodesA);
        System.out.println("nodesC containsAll nodesA: " + nodesC.containsAll(nodesA));
        System.out.println("nodesA containsAll nodesT: " + nodesA.containsAll(nodesT));
    }

    // Case A = C (tested as containment): Q[C] is a marginal of Q[T].
    if (nodesC.containsAll(nodesA)) {
        QList qC = new QList(nNodes);
        qC.add(qT, sumList(nNodes, nodesT, nodesC), true);
        if (debug) {
            System.out.println("***************** identify: QList: qC *****************");
            qC.printQList(0, 0);
        }
        return qC;
    }

    // Case A = T (tested as containment): no further reduction possible; fail.
    if (nodesA.containsAll(nodesT)) {
        if (debug) {
            System.out.println("----- FAIL: identify -----");
        }
        return null;
    }

    // ///////////////////////////////
    // must be: nodesC subset A subset T
    Dag graphA = new Dag(bayesIm.getDag().subgraph(nodesA));
    // construct an IM with the dag graphA (only its structure is used, to
    // compute the c-components of G_A)
    BayesPm bayesPmA = new BayesPm(graphA, bayesIm.getBayesPm());
    BayesIm bayesImA = new MlBayesIm(bayesPmA, bayesIm, MlBayesIm.RANDOM);
    int[] cComponentsA = getCComponents(bayesImA);
    int nCComponentsA = nCComponents(cComponentsA);

    // Q[A] from Q[T]
    QList qA = new QList(nNodes);
    qA.add(qT, sumList(nNodes, nodesT, nodesA), true);
    if (debug) {
        System.out.println("***************** identify: QList: qA *****************");
        // print the Q[A] expression that was just built
        qA.printQList(0, 0);
    }

    // find the c-component T2 of G_A that contains C, then recurse on it
    for (int i = 0; i < nCComponentsA; i++) {
        List<Node> cComponentNodesT2 = getCComponentNodes(bayesImA, cComponentsA, i);
        if (debug) {
            System.out.println("identify Q[A]: i: " + i);
            System.out.println("cComponentNodesT2: " + cComponentNodesT2);
            System.out.println("cComponentNodesT2.containsAll(nodesC): " + cComponentNodesT2.containsAll(nodesC));
        }
        if (cComponentNodesT2.containsAll(nodesC)) {
            QList qT2 = qDecomposition(bayesIm, nodesA, cComponentNodesT2, qA);
            return identify(nodesC, cComponentNodesT2, qT2);
        }
    }
    throw new RuntimeException("identify: T2 not found");
}
Usage of `edu.cmu.tetrad.graph.Dag` in the tetrad project (cmu-phil):
class `ApproximateUpdater`, method `setEvidence`.
/**
 * Sets new evidence for the next update operation.
 *
 * <p>The following steps happen in order:
 * <ol>
 *   <li>A copy of the evidence is stored.</li>
 *   <li>The Bayes PM's DAG is turned into a manipulated graph.</li>
 *   <li>A PM and then an IM are created for that graph.</li>
 *   <li>The cached counts are cleared.</li>
 * </ol>
 *
 * @param evidence the new evidence; must be non-null and compatible with bayesIm
 * @throws NullPointerException     if {@code evidence} is null
 * @throws IllegalArgumentException if {@code evidence} is incompatible with bayesIm
 */
public final void setEvidence(Evidence evidence) {
    if (evidence == null) {
        throw new NullPointerException();
    }
    if (evidence.isIncompatibleWith(bayesIm)) {
        throw new IllegalArgumentException("The variables for the given "
                + "evidence must be compatible with the Bayes IM being updated.");
    }

    // Store a defensive copy first: building the manipulated graph reads it.
    this.evidence = new Evidence(evidence);

    Graph sourceDag = bayesIm.getBayesPm().getDag();
    Dag manipulated = createManipulatedGraph(sourceDag);
    BayesPm manipulatedPm = createUpdatedBayesPm(manipulated);
    this.manipulatedBayesIm = createdUpdatedBayesIm(manipulatedPm);

    // invalidate anything computed under the previous evidence
    this.counts = null;
}
Usage of `edu.cmu.tetrad.graph.Dag` in the tetrad project (cmu-phil):
class `RowSummingExactUpdater`, method `createManipulatedGraph`.
/**
 * Returns a copy of {@code graph} in which every manipulated evidence variable
 * has the edges to all of its parents removed. The input graph is not modified.
 *
 * @param graph the graph to copy; evidence nodes are matched to its nodes by name
 * @return the manipulated copy
 */
private Dag createManipulatedGraph(Graph graph) {
    Dag updatedGraph = new Dag(graph);
    // alters graph for manipulated evidenceItems
    for (int i = 0; i < evidence.getNumNodes(); ++i) {
        if (!evidence.isManipulated(i)) {
            continue;
        }
        // look the node up by name: the evidence may hold distinct Node objects
        Node node = updatedGraph.getNode(evidence.getNode(i).getName());
        // the parent list is obtained once, before any edge is removed
        List<Node> parents = updatedGraph.getParents(node);
        for (Node parent : parents) {
            // cut the edge between the manipulated node and this parent
            updatedGraph.removeEdge(node, parent);
        }
    }
    return updatedGraph;
}
Aggregations