Search in sources :

Example 46 with Dag

use of edu.cmu.tetrad.graph.Dag in project tetrad by cmu-phil.

the class Identifiability method getJointMarginal.

/**
 * Computes P_t(s), where s is the variable-value assignment given by the
 * arguments and t is given by the current evidence.
 *
 * <p>t and s should each be a single combination of variable values (i.e.,
 * no disjunctions of values of a variable, even though they are allowed in
 * {@code Proposition}).
 *
 * <p>When the evidence contains no variables, the value is read directly
 * from the joint distribution of the Bayes IM; in that branch
 * {@code bayesIm} is cast to {@code MlBayesImObs}.
 *
 * @param sVariables indices (in {@code bayesIm}) of the variables in s
 * @param sValues    the value assigned to each variable in
 *                   {@code sVariables}, position by position
 * @return the computed value of P_t(s), or one of these special values:
 *         {@code 0.0} if s and t assign different values to the same
 *         variable; {@code 1.0} if every variable of s is already fixed to
 *         the same value by t; {@code -1.0} if P_t(s) is not identifiable
 *         with this algorithm
 * @throws IllegalArgumentException if the two arrays differ in length
 * @throws RuntimeException         if no c-component of the full graph
 *                                  contains a c-component of the subgraph
 *                                  over D
 */
public double getJointMarginal(int[] sVariables, int[] sValues) {
    if (sVariables.length != sValues.length) {
        throw new IllegalArgumentException("Values must match variables.");
    }

    // ------------------------------------------------------------------
    // Collect the S and T variables.
    List<Node> sNodes = new ArrayList<>();
    for (int sVariable : sVariables) {
        sNodes.add(bayesIm.getNode(sVariable));
    }
    if (debug) {
        System.out.println("\nsVariables: " + sNodes);
    }

    List<Node> tNodesTmp = evidence.getVariablesInEvidence();

    // No variables in the evidence: look the probability up in the joint
    // distribution over the observed variables instead of running the
    // identification algorithm.
    if (tNodesTmp.isEmpty()) {
        Proposition prop = Proposition.tautology(bayesIm);
        for (int i = 0; i < sVariables.length; i++) {
            prop.setCategory(sVariables[i], sValues[i]);
        }
        // Restrict the proposition to only observed variables.
        MlBayesImObs bayesImObs = (MlBayesImObs) bayesIm;
        Proposition propObs = new Proposition(bayesImObs.getBayesImObs(), prop);
        return bayesImObs.getJPD().getProb(propObs);
    }

    // Recast the T variables as nodes of bayesIm (matched by name).
    List<Node> tNodes = new ArrayList<>();
    for (Node evidenceNode : tNodesTmp) {
        tNodes.add(bayesIm.getNode(evidenceNode.getName()));
    }
    if (debug) {
        System.out.println("\ntVariables: " + tNodes);
    }

    // ------------------------------------------------------------------
    // Collect the variable values that are set in S and T.
    // -1 marks "value not set".
    int nNodes = bayesIm.getNumNodes();
    int[] fixedVarValues = new int[nNodes];
    for (int i = 0; i < nNodes; i++) {
        fixedVarValues[i] = -1;
    }

    // Incorporate the values of S.
    for (int i = 0; i < sVariables.length; i++) {
        fixedVarValues[sVariables[i]] = sValues[i];
    }

    // Incorporate the values of T (all evidence variables are assumed to be
    // manipulated). Only variables with a single value are considered.
    for (int i = 0; i < evidence.getNumNodes(); i++) {
        int tNodeValue = evidence.getProposition().getSingleCategory(i);
        if (tNodeValue == -1) {
            continue;
        }

        // Translate the evidence node and its category into bayesIm terms.
        Node tNode = evidence.getNode(i);
        String tNodeStr = tNode.getName();
        Node tNodeInBayesIm = bayesIm.getNode(tNodeStr);
        int tIndexInBayesIm = bayesIm.getNodeIndex(tNodeInBayesIm);
        String tCategoryStr = evidence.getCategory(tNode, tNodeValue);
        int tValueInBayesIm = bayesIm.getBayesPm().getCategoryIndex(tNodeInBayesIm, tCategoryStr);

        int oldValue = fixedVarValues[tIndexInBayesIm];
        if (oldValue == -1) {
            // No value for this variable yet.
            fixedVarValues[tIndexInBayesIm] = tValueInBayesIm;
        } else if (oldValue == tValueInBayesIm) {
            // S already has the same value: remove the variable from S
            // (S and T should be disjoint).
            sNodes.remove(tNodeInBayesIm);
            if (debug) {
                System.out.println("sNode removed: index: " + tIndexInBayesIm + "; name: " + tNodeStr);
            }
        } else {
            // S and T are inconsistent on this variable: Prob = 0.
            return 0.0;
        }
    }
    if (debug) {
        System.out.print("fixedVarValues: ");
        for (int i = 0; i < nNodes; i++) {
            System.out.print(fixedVarValues[i] + "  ");
        }
        System.out.println();
    }

    // If all nodes in S were removed, Prob = 1.
    if (sNodes.isEmpty()) {
        return 1.0;
    }

    // ------------------------------------------------------------------
    // Get the c-components of the full graph and the nodes in each one.
    int[] cComponents = getCComponents(bayesIm);
    int nCComponents = nCComponents(cComponents);

    // Accessed by index below, so use a random-access list.
    List<List<Node>> cComponentNodes = new ArrayList<>();
    for (int i = 0; i < nCComponents; i++) {
        cComponentNodes.add(getCComponentNodes(bayesIm, cComponents, i));
    }
    if (debug) {
        for (int i = 0; i < nCComponents; i++) {
            System.out.println("c-component " + i + ": " + cComponentNodes.get(i));
        }
    }

    // ------------------------------------------------------------------
    // Indicator vector with a 1 for every measured variable; it is the
    // probability term used to build Q[V] below.
    int[] probTermV = new int[nNodes];
    for (int i = 0; i < nNodes; i++) {
        if (bayesIm.getNode(i).getNodeType() == NodeType.MEASURED) {
            probTermV[i] = 1;
        }
    }
    if (debug) {
        System.out.print("probTermV: ");
        for (int i = 0; i < nNodes; i++) {
            System.out.print(probTermV[i] + "   ");
        }
        System.out.println();
    }

    // Get the c-factor of each c-component by decomposing Q[V].
    QList[] cFactors = new QList[nCComponents];
    for (int i = 0; i < nCComponents; i++) {
        // Q[V], built from probTermV.
        QList qV = new QList(nNodes, probTermV);
        if (debug) {
            System.out.println("cFactors " + i + "   " + bayesIm.getDag().getNodes() + "   " + cComponentNodes.get(i));
            System.out.println("============== QList: qV ==============");
            qV.printQList(0, 0);
        }
        cFactors[i] = qDecomposition(bayesIm, bayesIm.getDag().getNodes(), cComponentNodes.get(i), qV);
        if (debug) {
            System.out.println("============== QList: cFactors[" + i + "] ==============");
            cFactors[i].printQList(0, 0);
        }
    }

    // ------------------------------------------------------------------
    // Get D: the ancestors of S in the graph with the T nodes removed.
    // Note: "dag" is a new copy of the dag; otherwise modifications
    // would be made to the one in the bayesIm.
    // Note: the ordering of the nodes may not be the same as in
    // the original graph.
    Dag dag = new Dag(bayesIm.getDag());
    if (debug) {
        System.out.println("------ here1 -------------");
        System.out.println(bayesIm.getDag());
    }
    dag.removeNodes(tNodes);
    if (debug) {
        System.out.println("------ here2 -------------");
        System.out.println(bayesIm.getDag());
    }
    List<Node> dNodes = dag.getAncestors(sNodes);

    // Create a Bayes IM over the subgraph G_D. In this method it is passed
    // only to the c-component helpers (and the debug output) below.
    Dag gD = new Dag(bayesIm.getDag().subgraph(dNodes));
    BayesPm bayesPmD = new BayesPm(gD, bayesIm.getBayesPm());
    BayesIm bayesImD = new MlBayesIm(bayesPmD, bayesIm, MlBayesIm.RANDOM);
    if (debug) {
        System.out.println("------ bayeIm.getDag() -------------");
        System.out.println(bayesIm.getDag());
        System.out.println("------ gD -------------");
        System.out.println(gD);
        System.out.println("------ bayeImD.getDag() -------------");
        System.out.println(bayesImD.getDag());
        System.out.println("bayesIm node 0: " + bayesIm.getNode(0));
        System.out.println("bayesImD node 0: " + bayesImD.getNode(0));
    }

    // Get the c-components of gD.
    int[] cComponentsD = getCComponents(bayesImD);
    int nCComponentsD = nCComponents(cComponentsD);

    // Q[Di] for every c-component Di of gD.
    QList[] qD = new QList[nCComponentsD];
    for (int i = 0; i < nCComponentsD; i++) {
        // Di
        List<Node> cComponentNodesDi = getCComponentNodes(bayesImD, cComponentsD, i);

        // Find the first c-component Sj of the full graph that is a
        // superset of Di, and identify Q[Di] from Q[Sj].
        boolean found = false;
        for (int j = 0; j < nCComponents; j++) {
            List<Node> cComponentNodesSj = cComponentNodes.get(j);
            if (!cComponentNodesSj.containsAll(cComponentNodesDi)) {
                continue;
            }
            if (debug) {
                System.out.println("----- Di   Sj --------");
                System.out.println(i + "   " + cComponentNodesDi + "    " + cComponentNodesSj);
            }
            found = true;
            qD[i] = identify(cComponentNodesDi, cComponentNodesSj, cFactors[j]);
            if (qD[i] == null) {
                // Q[Di] is not identifiable with this algorithm, hence
                // neither is P_t(s).
                if (debug) {
                    System.out.println("----- FAIL qD[" + i + "] --------");
                }
                return -1.0;
            }
            if (debug) {
                System.out.println("======================== QList: qD[" + i + "] =================");
                qD[i].printQList(0, 0);
            }
            break;
        }
        if (!found) {
            // Something is wrong: every Di is expected to lie inside some Sj.
            throw new RuntimeException("getJointMarginal: Sj not found");
        }
    }

    // ------------------------------------------------------------------
    // Multiply the Q[Di]'s. The all-zero array (Java zero-initializes new
    // int arrays) means no variable is summed over.
    QList qDProducts = new QList(nNodes);
    int[] sumOverVariables = new int[nNodes];
    for (int i = 0; i < nCComponentsD; i++) {
        qDProducts.add(qD[i], sumOverVariables, true);
    }

    // P_t(s): the product with the summation list derived from D and S.
    QList qPTS = new QList(nNodes);
    qPTS.add(qDProducts, sumList(nNodes, dNodes, sNodes), true);
    if (debug) {
        System.out.println("***************************** QList: qPTS *******************");
        qPTS.printQList(0, 0);
    }

    // Compute the numeric value from the algebraic expression qPTS.
    if (debug) {
        System.out.println("***************************** computeValue *******************");
    }
    return qPTS.computeValue(bayesIm, fixedVarValues);
}
Also used : Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) Dag(edu.cmu.tetrad.graph.Dag) LinkedList(java.util.LinkedList) List(java.util.List)

Example 47 with Dag

use of edu.cmu.tetrad.graph.Dag in project tetrad by cmu-phil.

the class Identifiability method qDecomposition.

/**
 * Computes the generalized Q-decomposition (Tian and Pearl 2002, Lemma 4).
 *
 * <p>{@code hj} is a c-component in the subgraph of {@code graphWhole}
 * induced by {@code h}, and {@code qH} is the q-factor for {@code h}.
 * For every measured node of {@code hj} two terms built from {@code qH}
 * are appended to the result: one added with flag {@code true} that sums
 * over the measured nodes following it in the causal ordering, and one
 * added with flag {@code false} that additionally sums over the node
 * itself. Latent nodes of {@code hj} are skipped.
 *
 * <p>NOTE(review): graphWhole may not need to be passed as an argument.
 *
 * @param graphWhole the Bayes IM whose node indices are used throughout
 * @param h          the nodes inducing the subgraph
 * @param hj         the nodes of the c-component
 * @param qH         the q-factor for {@code h}
 * @return the q-factor of {@code hj}
 * @throws RuntimeException if a measured node of {@code hj} does not occur
 *                          in the causal ordering of the subgraph
 */
private QList qDecomposition(BayesIm graphWhole, List<Node> h, List<Node> hj, QList qH) {
    Dag subgraphH = new Dag(graphWhole.getDag().subgraph(h));

    // Causal (tier) ordering of the subgraph, stored as node indices of
    // the original graph from which the subgraph was obtained.
    List<Node> causalOrder = subgraphH.getCausalOrdering();
    int tierSize = causalOrder.size();
    int[] tiers = new int[tierSize];
    int slot = 0;
    for (Node ordered : causalOrder) {
        tiers[slot++] = graphWhole.getNodeIndex(ordered);
    }

    if (debug) {
        System.out.print("************************* QDecomposition: Tier ordering: ");
        for (int tierIndex : tiers) {
            System.out.print(graphWhole.getNode(tierIndex) + "  ");
        }
        System.out.println();
    }

    int nNodes = graphWhole.getNumNodes();
    QList qHj = new QList(nNodes);

    for (Node member : hj) {
        // Index of this node in the original graph.
        int memberIndex = graphWhole.getNodeIndex(member);

        // Skip latent variables.
        if (graphWhole.getNode(memberIndex).getNodeType() != NodeType.MEASURED) {
            continue;
        }

        // Locate the node's position in the tier ordering.
        int tierPos = 0;
        while (tierPos < tierSize && tiers[tierPos] != memberIndex) {
            tierPos++;
        }
        if (tierPos == tierSize) {
            throw new RuntimeException("qDecomposition: index out of bound");
        }

        // Q[H^i]: sum over every measured node that comes later in the
        // ordering. A new int[] starts out all zeros.
        int[] sumOverVariables = new int[nNodes];
        for (int later = tierPos + 1; later < tierSize; later++) {
            int laterIndex = tiers[later];
            if (graphWhole.getNode(laterIndex).getNodeType() == NodeType.MEASURED) {
                sumOverVariables[laterIndex] = 1;
            }
        }
        qHj.add(qH, sumOverVariables, true);
        if (debug) {
            System.out.println("************* QDecomposition: Q[H^i], sumOverVariables: ");
            for (int marker : sumOverVariables) {
                System.out.print(marker + "   ");
            }
            System.out.println();
        }

        // Q[H^{i-1}]: the same array, now also summing over the node
        // itself (tiers[tierPos] equals memberIndex after the search).
        sumOverVariables[memberIndex] = 1;
        qHj.add(qH, sumOverVariables, false);
        if (debug) {
            System.out.println("************* QDecomposition: Q[H^{i-1}], sumOverVariables: ");
            for (int marker : sumOverVariables) {
                System.out.print(marker + "   ");
            }
            System.out.println();
        }
    }
    return qHj;
}
Also used : Node(edu.cmu.tetrad.graph.Node) Dag(edu.cmu.tetrad.graph.Dag)

Example 48 with Dag

use of edu.cmu.tetrad.graph.Dag in project tetrad by cmu-phil.

the class Identifiability method identify.

/**
 * Tries to derive the q-factor of {@code nodesC} from the q-factor
 * {@code qT} of {@code nodesT}.
 *
 * <p>Let A be the ancestors of {@code nodesC} in the subgraph of the Bayes
 * IM's dag induced by {@code nodesT}. Then:
 * <ul>
 *   <li>if {@code nodesC} contains all of A, the result is {@code qT} with
 *       the summation list {@code sumList(nNodes, nodesT, nodesC)};</li>
 *   <li>otherwise, if A contains all of {@code nodesT}, the method fails
 *       and returns {@code null};</li>
 *   <li>otherwise the method recurses on the c-component of the subgraph
 *       over A that contains {@code nodesC}.</li>
 * </ul>
 * NOTE(review): this presumably follows the Identify procedure of Tian and
 * Pearl (2002), with subset checks standing in for the set-equality checks
 * of the paper — confirm against the reference.
 *
 * @param nodesC the nodes whose q-factor is wanted
 * @param nodesT the nodes whose q-factor is known
 * @param qT     the q-factor for {@code nodesT}
 * @return the q-factor for {@code nodesC}, or {@code null} if it cannot be
 *         identified with this algorithm
 * @throws RuntimeException if no c-component of the subgraph over A
 *                          contains {@code nodesC}
 */
private QList identify(List<Node> nodesC, List<Node> nodesT, QList qT) {
    Dag graphT = new Dag(bayesIm.getDag().subgraph(nodesT));
    List<Node> nodesA = graphT.getAncestors(nodesC);
    int nNodes = bayesIm.getNumNodes();
    if (debug) {
        System.out.println("-------------- identify -------------");
        System.out.println("----- bayesIm.getDag() -----");
        System.out.println(bayesIm.getDag());
        System.out.println("----- graphT -----");
        System.out.println(graphT);
        System.out.println("nodesC: " + nodesC);
        System.out.println("nodesT: " + nodesT);
        System.out.println("nodesA: " + nodesA);
        System.out.println("nodesC containsAll nodesA: " + nodesC.containsAll(nodesA));
        System.out.println("nodesA containsAll nodesT: " + nodesA.containsAll(nodesT));
    }

    // Case 1: every ancestor of C within T is already in C. The check is a
    // subset test rather than an equality test.
    if (nodesC.containsAll(nodesA)) {
        QList qC = new QList(nNodes);
        qC.add(qT, sumList(nNodes, nodesT, nodesC), true);
        if (debug) {
            System.out.println("***************** identify: QList: qC *****************");
            qC.printQList(0, 0);
        }
        return qC;
    }

    // Case 2: the ancestors of C cover all of T (again a subset test):
    // fail, Q[C] cannot be obtained from Q[T] here.
    if (nodesA.containsAll(nodesT)) {
        if (debug) {
            System.out.println("----- FAIL: identify -----");
        }
        return null;
    }

    // Case 3: neither subset test held (the original note here reads
    // "must be: nodesC subset A subset T").
    Dag graphA = new Dag(bayesIm.getDag().subgraph(nodesA));

    // Construct an IM with the dag graphA; in this method it is passed only
    // to the c-component helpers below.
    BayesPm bayesPmA = new BayesPm(graphA, bayesIm.getBayesPm());
    BayesIm bayesImA = new MlBayesIm(bayesPmA, bayesIm, MlBayesIm.RANDOM);

    // Get the c-components of graphA.
    int[] cComponentsA = getCComponents(bayesImA);
    int nCComponentsA = nCComponents(cComponentsA);

    // Get Q[A] from Q[T].
    QList qA = new QList(nNodes);
    qA.add(qT, sumList(nNodes, nodesT, nodesA), true);
    if (debug) {
        System.out.println("***************** identify: QList: qA *****************");
        qA.printQList(0, 0);
    }

    // Recurse on the first c-component T2 of graphA that contains C.
    for (int i = 0; i < nCComponentsA; i++) {
        List<Node> cComponentNodesT2 = getCComponentNodes(bayesImA, cComponentsA, i);
        if (debug) {
            System.out.println("identify Q[A]: i: " + i);
            System.out.println("cComponentNodesT2: " + cComponentNodesT2);
            System.out.println("cComponentNodesT2.containsAll(nodesC): " + cComponentNodesT2.containsAll(nodesC));
        }
        if (cComponentNodesT2.containsAll(nodesC)) {
            // Get Q[T2] from Q[A], then make the recursive call.
            QList qT2 = qDecomposition(bayesIm, nodesA, cComponentNodesT2, qA);
            return identify(nodesC, cComponentNodesT2, qT2);
        }
    }
    throw new RuntimeException("identify: T2 not found");
}
Also used : Node(edu.cmu.tetrad.graph.Node) Dag(edu.cmu.tetrad.graph.Dag)

Example 49 with Dag

use of edu.cmu.tetrad.graph.Dag in project tetrad by cmu-phil.

the class ApproximateUpdater method setEvidence.

/**
 * Installs the evidence to be used by the next update operation.
 *
 * <p>A copy of the given evidence is stored, the manipulated Bayes IM is
 * rebuilt from the Bayes PM's dag, and {@code counts} is reset to
 * {@code null}.
 *
 * @param evidence the new evidence; must not be null and must be
 *                 compatible with the Bayes IM being updated
 * @throws NullPointerException     if {@code evidence} is null
 * @throws IllegalArgumentException if {@code evidence} is incompatible
 *                                  with the Bayes IM
 */
public final void setEvidence(Evidence evidence) {
    if (evidence == null) {
        throw new NullPointerException();
    }

    if (evidence.isIncompatibleWith(bayesIm)) {
        throw new IllegalArgumentException(
                "The variables for the given " + "evidence must be compatible with the Bayes IM being updated.");
    }

    // Store a copy first: the field is assigned before the manipulated
    // graph is built from it.
    this.evidence = new Evidence(evidence);

    // Rebuild the manipulated model: graph -> PM -> IM.
    Graph sourceGraph = bayesIm.getBayesPm().getDag();
    Dag manipulated = createManipulatedGraph(sourceGraph);
    BayesPm updatedPm = createUpdatedBayesPm(manipulated);
    this.manipulatedBayesIm = createdUpdatedBayesIm(updatedPm);

    this.counts = null;
}
Also used : Graph(edu.cmu.tetrad.graph.Graph) Dag(edu.cmu.tetrad.graph.Dag)

Example 50 with Dag

use of edu.cmu.tetrad.graph.Dag in project tetrad by cmu-phil.

the class RowSummingExactUpdater method createManipulatedGraph.

/**
 * Builds a copy of the given graph in which every manipulated evidence
 * variable has the edges to its parents removed.
 *
 * @param graph the graph to copy
 * @return the new, altered {@code Dag}
 */
private Dag createManipulatedGraph(Graph graph) {
    Dag result = new Dag(graph);

    for (int index = 0; index < evidence.getNumNodes(); index++) {
        if (!evidence.isManipulated(index)) {
            continue;
        }

        // Look the evidence node up by name in the copied graph and cut it
        // off from each of its parents.
        Node target = result.getNode(evidence.getNode(index).getName());
        for (Node parent : result.getParents(target)) {
            result.removeEdge(target, parent);
        }
    }

    return result;
}
Also used : Node(edu.cmu.tetrad.graph.Node) Dag(edu.cmu.tetrad.graph.Dag)

Aggregations

Dag (edu.cmu.tetrad.graph.Dag)53 Node (edu.cmu.tetrad.graph.Node)41 Graph (edu.cmu.tetrad.graph.Graph)21 GraphNode (edu.cmu.tetrad.graph.GraphNode)21 Test (org.junit.Test)18 ArrayList (java.util.ArrayList)12 SemIm (edu.cmu.tetrad.sem.SemIm)10 SemPm (edu.cmu.tetrad.sem.SemPm)10 DataSet (edu.cmu.tetrad.data.DataSet)7 BayesPm (edu.cmu.tetrad.bayes.BayesPm)6 MlBayesIm (edu.cmu.tetrad.bayes.MlBayesIm)6 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)6 BayesIm (edu.cmu.tetrad.bayes.BayesIm)5 DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable)3 FruchtermanReingoldLayout (edu.cmu.tetrad.graph.FruchtermanReingoldLayout)2 TetradMatrix (edu.cmu.tetrad.util.TetradMatrix)2 LinkedList (java.util.LinkedList)2 StoredCellProbs (edu.cmu.tetrad.bayes.StoredCellProbs)1 ColtDataSet (edu.cmu.tetrad.data.ColtDataSet)1 CovarianceMatrixOnTheFly (edu.cmu.tetrad.data.CovarianceMatrixOnTheFly)1