Search in sources :

Example 21 with BayesIm

use of edu.cmu.tetrad.bayes.BayesIm in project tetrad by cmu-phil.

From class SaveBayesImXmlAction, method actionPerformed.

/**
 * Prompts for a destination file and writes the editor's current Bayes IM to it as XML.
 *
 * <p>The document is serialized with a two-space indent and {@code "\n"} line separators. If no
 * file is chosen, nothing is written.
 *
 * @param e the triggering action event (not used)
 * @throws RuntimeException wrapping any {@link IOException} raised while writing the file
 */
public void actionPerformed(ActionEvent e) {
    File outfile = EditorUtils.getSaveFile("bayesim", "xml", this.bayesImEditor, false, "Save Bayes IM as XML...");
    // NOTE(review): assumes getSaveFile returns null when the user cancels — confirm in EditorUtils.
    // Previously a null file surfaced as an NPE from new FileOutputStream(null).
    if (outfile == null) {
        return;
    }
    BayesIm bayesIm = this.bayesImEditor.getWizard().getBayesIm();
    // Build the XML before touching the file so a rendering failure does not leave an empty file.
    Element element = BayesXmlRenderer.getElement(bayesIm);
    Document document = new Document(element);
    // try-with-resources guarantees the stream is closed even if serialization throws.
    try (FileOutputStream out = new FileOutputStream(outfile)) {
        Serializer serializer = new Serializer(out);
        serializer.setLineSeparator("\n");
        serializer.setIndent(2);
        serializer.write(document);
    } catch (IOException e1) {
        throw new RuntimeException(e1);
    }
}
Also used : BayesIm(edu.cmu.tetrad.bayes.BayesIm) FileOutputStream(java.io.FileOutputStream) Element(nu.xom.Element) IOException(java.io.IOException) Document(nu.xom.Document) File(java.io.File) Serializer(nu.xom.Serializer)

Example 22 with BayesIm

use of edu.cmu.tetrad.bayes.BayesIm in project tetrad by cmu-phil.

From class ConditionalGaussianSimulation, method simulate.

/**
 * Simulates a mixed (discrete and continuous) data set from the given graph.
 *
 * <p>Nodes are split by the parameter "percentDiscrete": that percentage of the nodes (taken from
 * the cached shuffledOrder) is given a number of categories chosen by pickNumCategories between
 * "minCategories" and "maxCategories"; the rest are recorded with 0 categories and are presumably
 * turned into continuous variables by makeMixedGraph — confirm. Values are then generated column by
 * column in the graph's causal order: discrete columns are sampled from a randomly parameterized
 * Bayes IM, continuous columns from a linear model with Gaussian noise whose parameters depend on
 * the current row's discrete-parent values.
 *
 * @param G          the graph to simulate from (reassigned locally to its mixed version)
 * @param parameters source of "percentDiscrete", "minCategories", "maxCategories", "sampleSize"
 *                   and "saveLatentVars"
 * @return the simulated data set; when "saveLatentVars" is false, the result of
 *         DataUtils.restrictToMeasured on it
 */
private DataSet simulate(Graph G, Parameters parameters) {
    // Node name -> number of categories; 0 marks a node that is not made discrete.
    HashMap<String, Integer> nd = new HashMap<>();
    List<Node> nodes = G.getNodes();
    // NOTE(review): the assignment below reads names from shuffledOrder, not from `nodes`, so this
    // shuffle only matters if getNodes() exposes G's internal list — confirm whether it is needed.
    Collections.shuffle(nodes);
    // The shuffled order is cached on the instance, so repeated calls keep the same
    // discrete/continuous assignment of node names.
    if (this.shuffledOrder == null) {
        List<Node> shuffledNodes = new ArrayList<>(nodes);
        Collections.shuffle(shuffledNodes);
        this.shuffledOrder = shuffledNodes;
    }
    for (int i = 0; i < nodes.size(); i++) {
        // The first percentDiscrete% of shuffledOrder become discrete.
        if (i < nodes.size() * parameters.getDouble("percentDiscrete") * 0.01) {
            final int minNumCategories = parameters.getInt("minCategories");
            final int maxNumCategories = parameters.getInt("maxCategories");
            final int value = pickNumCategories(minNumCategories, maxNumCategories);
            nd.put(shuffledOrder.get(i).getName(), value);
        } else {
            nd.put(shuffledOrder.get(i).getName(), 0);
        }
    }
    // Replace G with its typed (discrete/continuous) version; `nodes` now holds those variables
    // and fixes the column order of mixedData.
    G = makeMixedGraph(G, nd);
    nodes = G.getNodes();
    DataSet mixedData = new BoxDataSet(new MixedDataBox(nodes, parameters.getInt("sampleSize")), nodes);
    // X: continuous variables; A: everything else (the discrete variables).
    List<Node> X = new ArrayList<>();
    List<Node> A = new ArrayList<>();
    for (Node node : G.getNodes()) {
        if (node instanceof ContinuousVariable) {
            X.add(node);
        } else {
            A.add(node);
        }
    }
    // AG: discrete subgraph (extended below with ersatz nodes); XG: continuous subgraph for the SEM.
    Graph AG = G.subgraph(A);
    Graph XG = G.subgraph(X);
    // Each continuous parent of a discrete node gets one discrete stand-in ("ersatz") variable in
    // AG, so the discrete part can be modeled as a Bayes net. The reverse map is keyed by the
    // ersatz variable's name.
    Map<ContinuousVariable, DiscreteVariable> erstatzNodes = new HashMap<>();
    Map<String, ContinuousVariable> erstatzNodesReverse = new HashMap<>();
    for (Node y : A) {
        for (Node x : G.getParents(y)) {
            if (x instanceof ContinuousVariable) {
                DiscreteVariable ersatz = erstatzNodes.get(x);
                if (ersatz == null) {
                    // 2 + nextInt(3) categories (2-4 if nextInt(n) draws from [0, n) — confirm).
                    ersatz = new DiscreteVariable("Ersatz_" + x.getName(), RandomUtil.getInstance().nextInt(3) + 2);
                    erstatzNodes.put((ContinuousVariable) x, ersatz);
                    erstatzNodesReverse.put(ersatz.getName(), (ContinuousVariable) x);
                    AG.addNode(ersatz);
                }
                AG.addDirectedEdge(ersatz, y);
            }
        }
    }
    // Discrete CPTs get random values (MlBayesIm.RANDOM).
    BayesPm bayesPm = new BayesPm(AG);
    BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
    SemPm semPm = new SemPm(XG);
    // Continuous-model parameter values per combination of discrete-parent values; presumably
    // filled lazily by getParamValue — confirm.
    Map<Combination, Double> paramValues = new HashMap<>();
    // Column indices in causal order, so every parent column is generated before its children.
    List<Node> tierOrdering = G.getCausalOrdering();
    int[] tiers = new int[tierOrdering.size()];
    for (int t = 0; t < tierOrdering.size(); t++) {
        tiers[t] = nodes.indexOf(tierOrdering.get(t));
    }
    // Discretization breakpoints per continuous column referenced through an ersatz node,
    // computed once and reused for every row.
    Map<Integer, double[]> breakpointsMap = new HashMap<>();
    for (int mixedIndex : tiers) {
        for (int i = 0; i < parameters.getInt("sampleSize"); i++) {
            if (nodes.get(mixedIndex) instanceof DiscreteVariable) {
                // Discrete column: collect this row's parent categories, then sample from the CPT.
                int bayesIndex = bayesIm.getNodeIndex(nodes.get(mixedIndex));
                int[] bayesParents = bayesIm.getParents(bayesIndex);
                int[] parentValues = new int[bayesParents.length];
                for (int k = 0; k < parentValues.length; k++) {
                    int bayesParentColumn = bayesParents[k];
                    Node bayesParent = bayesIm.getVariables().get(bayesParentColumn);
                    DiscreteVariable _parent = (DiscreteVariable) bayesParent;
                    int value;
                    ContinuousVariable orig = erstatzNodesReverse.get(_parent.getName());
                    if (orig != null) {
                        // Ersatz parent: discretize the continuous value already generated for this
                        // row. The category is the index of the first breakpoint greater than the
                        // value, or breakpoints.length if there is none.
                        int mixedParentColumn = mixedData.getColumn(orig);
                        double d = mixedData.getDouble(i, mixedParentColumn);
                        double[] breakpoints = breakpointsMap.get(mixedParentColumn);
                        if (breakpoints == null) {
                            breakpoints = getBreakpoints(mixedData, _parent, mixedParentColumn);
                            breakpointsMap.put(mixedParentColumn, breakpoints);
                        }
                        value = breakpoints.length;
                        for (int j = 0; j < breakpoints.length; j++) {
                            if (d < breakpoints[j]) {
                                value = j;
                                break;
                            }
                        }
                    } else {
                        // Ordinary discrete parent: use its already-generated category.
                        int mixedColumn = mixedData.getColumn(bayesParent);
                        value = mixedData.getInt(i, mixedColumn);
                    }
                    parentValues[k] = value;
                }
                // Inverse-CDF sampling over the CPT row. The cell is preset to 0, which is kept if
                // the cumulative probability never reaches r (e.g. through rounding).
                int rowIndex = bayesIm.getRowIndex(bayesIndex, parentValues);
                double sum = 0.0;
                double r = RandomUtil.getInstance().nextDouble();
                mixedData.setInt(i, mixedIndex, 0);
                for (int k = 0; k < bayesIm.getNumColumns(bayesIndex); k++) {
                    double probability = bayesIm.getProbability(bayesIndex, rowIndex, k);
                    sum += probability;
                    if (sum >= r) {
                        mixedData.setInt(i, mixedIndex, k);
                        break;
                    }
                }
            } else {
                // Continuous column: noise + sum(coefficient * continuous parent) + mean.
                Node y = nodes.get(mixedIndex);
                Set<DiscreteVariable> discreteParents = new HashSet<>();
                Set<ContinuousVariable> continuousParents = new HashSet<>();
                for (Node node : G.getParents(y)) {
                    if (node instanceof DiscreteVariable) {
                        discreteParents.add((DiscreteVariable) node);
                    } else {
                        continuousParents.add((ContinuousVariable) node);
                    }
                }
                // Error and mean parameters are keyed by this row's discrete-parent values.
                Parameter varParam = semPm.getParameter(y, y);
                Parameter muParam = semPm.getMeanParameter(y);
                Combination varComb = new Combination(varParam);
                Combination muComb = new Combination(muParam);
                for (DiscreteVariable v : discreteParents) {
                    varComb.addParamValue(v, mixedData.getInt(i, mixedData.getColumn(v)));
                    muComb.addParamValue(v, mixedData.getInt(i, mixedData.getColumn(v)));
                }
                // NOTE(review): whether nextNormal's second argument is a standard deviation or a
                // variance is not visible here — confirm against RandomUtil.
                double value = RandomUtil.getInstance().nextNormal(0, getParamValue(varComb, paramValues));
                for (Node x : continuousParents) {
                    Parameter coefParam = semPm.getParameter(x, y);
                    Combination coefComb = new Combination(coefParam);
                    for (DiscreteVariable v : discreteParents) {
                        coefComb.addParamValue(v, mixedData.getInt(i, mixedData.getColumn(v)));
                    }
                    // mixedData was built from `nodes`, so a node's index there is its column.
                    int parent = nodes.indexOf(x);
                    double parentValue = mixedData.getDouble(i, parent);
                    double parentCoef = getParamValue(coefComb, paramValues);
                    value += parentValue * parentCoef;
                }
                value += getParamValue(muComb, paramValues);
                mixedData.setDouble(i, mixedIndex, value);
            }
        }
    }
    boolean saveLatentVars = parameters.getBoolean("saveLatentVars");
    return saveLatentVars ? mixedData : DataUtils.restrictToMeasured(mixedData);
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) BayesIm(edu.cmu.tetrad.bayes.BayesIm) RandomGraph(edu.cmu.tetrad.algcomparison.graph.RandomGraph) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Example 23 with BayesIm

use of edu.cmu.tetrad.bayes.BayesIm in project tetrad by cmu-phil.

From class XdslXmlParser, method buildIM.

/**
 * Builds a Bayes IM from the {@code <nodes>} element of an XDSL document.
 *
 * <p>Every child must be a {@code <cpt>} element; the value of its first attribute identifies the
 * node. When {@code displayNames} is non-null, identifiers are translated through it to obtain the
 * graph node names. Within a cpt:
 * <ul>
 *   <li>{@code <parents>} lists whitespace-separated parent identifiers;</li>
 *   <li>each {@code <state>} contributes one category, named by its first attribute;</li>
 *   <li>{@code <probabilities>} lists whitespace-separated CPT entries, filled row by row.</li>
 * </ul>
 *
 * @param element0     the {@code <nodes>} element
 * @param displayNames optional map from node identifier to display name; {@code null} to use the
 *                     identifiers themselves
 * @return the instantiated Bayes IM
 * @throws IllegalArgumentException if a child is not a {@code cpt}, a parent identifier is
 *                                  unknown, or a probability list is too short for its table
 * @throws NumberFormatException    if a probability entry is not a number
 */
private BayesIm buildIM(Element element0, Map<String, String> displayNames) {
    Elements elements = element0.getChildElements();
    for (int i = 0; i < elements.size(); i++) {
        if (!"cpt".equals(elements.get(i).getQualifiedName())) {
            throw new IllegalArgumentException("Expecting cpt element.");
        }
    }

    Dag dag = new Dag();

    // Get the nodes.
    for (int i = 0; i < elements.size(); i++) {
        String id = elements.get(i).getAttribute(0).getValue();
        dag.addNode(new GraphNode(graphNodeName(id, displayNames)));
    }

    // Get the edges. Every node was added above, so nodes are not re-added here.
    for (int i = 0; i < elements.size(); i++) {
        Element cpt = elements.get(i);
        String childId = cpt.getAttribute(0).getValue();
        Node child = dag.getNode(graphNodeName(childId, displayNames));
        Elements cptElements = cpt.getChildElements();
        for (int j = 0; j < cptElements.size(); j++) {
            Element cptElement = cptElements.get(j);
            if (!cptElement.getQualifiedName().equals("parents")) {
                continue;
            }
            for (String parentId : splitWhitespace(cptElement.getValue())) {
                Node parent = dag.getNode(graphNodeName(parentId, displayNames));
                if (parent == null) {
                    throw new IllegalArgumentException(
                            "Unknown parent '" + parentId + "' of node '" + childId + "'.");
                }
                dag.addDirectedEdge(parent, child);
            }
        }
    }

    // PM: one category per <state> element, in document order.
    BayesPm pm = new BayesPm(dag);
    for (int i = 0; i < elements.size(); i++) {
        Element cpt = elements.get(i);
        Node node = dag.getNode(graphNodeName(cpt.getAttribute(0).getValue(), displayNames));
        Elements cptElements = cpt.getChildElements();
        List<String> stateNames = new ArrayList<>();
        for (int j = 0; j < cptElements.size(); j++) {
            Element cptElement = cptElements.get(j);
            if (cptElement.getQualifiedName().equals("state")) {
                stateNames.add(cptElement.getAttribute(0).getValue());
            }
        }
        pm.setCategories(node, stateNames);
    }

    // IM
    BayesIm im = new MlBayesIm(pm);
    // NOTE(review): assumes the IM's node indices follow the order of the cpt elements (the order
    // in which nodes were added to the dag) — confirm.
    for (int nodeIndex = 0; nodeIndex < elements.size(); nodeIndex++) {
        Elements cptElements = elements.get(nodeIndex).getChildElements();
        for (int j = 0; j < cptElements.size(); j++) {
            Element cptElement = cptElements.get(j);
            if (!cptElement.getQualifiedName().equals("probabilities")) {
                continue;
            }
            // Parse every entry up front so any malformed number is reported.
            String[] tokens = splitWhitespace(cptElement.getValue());
            double[] probs = new double[tokens.length];
            for (int k = 0; k < tokens.length; k++) {
                probs[k] = Double.parseDouble(tokens[k]);
            }
            int numRows = im.getNumRows(nodeIndex);
            int numCols = im.getNumColumns(nodeIndex);
            if (probs.length < numRows * numCols) {
                throw new IllegalArgumentException("Node '" + elements.get(nodeIndex).getAttribute(0).getValue()
                        + "' needs " + (numRows * numCols) + " probabilities but has " + probs.length + ".");
            }
            int count = 0;
            for (int row = 0; row < numRows; row++) {
                for (int col = 0; col < numCols; col++) {
                    im.setProbability(nodeIndex, row, col, probs[count++]);
                }
            }
        }
    }
    return im;
}

/** Returns the graph name for an XDSL identifier: the identifier itself, or its mapped display name. */
private static String graphNodeName(String id, Map<String, String> displayNames) {
    return displayNames == null ? id : displayNames.get(id);
}

/** Splits whitespace-separated XML text into tokens, ignoring leading, trailing and repeated whitespace. */
private static String[] splitWhitespace(String text) {
    String trimmed = text.trim();
    return trimmed.isEmpty() ? new String[0] : trimmed.split("\\s+");
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) Attribute(nu.xom.Attribute) Element(nu.xom.Element) GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) Dag(edu.cmu.tetrad.graph.Dag) Elements(nu.xom.Elements) BayesIm(edu.cmu.tetrad.bayes.BayesIm) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Example 24 with BayesIm

use of edu.cmu.tetrad.bayes.BayesIm in project tetrad by cmu-phil.

From class XdslXmlParser, method getBayesIm.

/**
 * Takes an XDSL representation of a Bayes IM and reinstantiates the IM.
 *
 * <p>The root must be a {@code <smile>} element containing a {@code <nodes>} child. If the root
 * also has an {@code <extensions>} child, its {@code <genie>} element is passed to
 * mapDisplayNames together with the useDisplayNames setting.
 *
 * @param element the XML of the IM
 * @return the BayesIM
 * @throws IllegalArgumentException if the root is not {@code smile} or has no {@code nodes} child
 */
public BayesIm getBayesIm(Element element) {
    if (!"smile".equals(element.getQualifiedName())) {
        throw new IllegalArgumentException("Expecting " + "smile" + " element.");
    }
    Elements elements = element.getChildElements();
    Element nodesElement = null;
    Element genieElement = null;
    // If a child appears more than once, the last occurrence wins.
    for (int i = 0; i < elements.size(); i++) {
        Element child = elements.get(i);
        if ("nodes".equals(child.getQualifiedName())) {
            nodesElement = child;
        }
        if ("extensions".equals(child.getQualifiedName())) {
            genieElement = child.getFirstChildElement("genie");
        }
    }
    // Without <nodes>, buildIM would fail with a NullPointerException; report it clearly instead.
    if (nodesElement == null) {
        throw new IllegalArgumentException("Expecting " + "nodes" + " element.");
    }
    Map<String, String> displayNames = mapDisplayNames(genieElement, useDisplayNames);
    return buildIM(nodesElement, displayNames);
}
Also used : BayesIm(edu.cmu.tetrad.bayes.BayesIm) MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) Element(nu.xom.Element) Elements(nu.xom.Elements)

Example 25 with BayesIm

use of edu.cmu.tetrad.bayes.BayesIm in project tetrad by cmu-phil.

From class TestEvidence, method sampleBayesIm2.

/**
 * Builds a small fixed Bayes IM over the DAG a -> b, a -> c, b -> c for evidence tests.
 *
 * <p>b is given three categories. The table shapes below imply that a and c have two each, and
 * that c's CPT has six rows, one per joint value of its parents.
 *
 * @return the Bayes IM with every CPT entry set explicitly
 */
private static BayesIm sampleBayesIm2() {
    Node a = new GraphNode("a");
    Node b = new GraphNode("b");
    Node c = new GraphNode("c");

    Dag graph = new Dag();
    for (Node vertex : new Node[]{a, b, c}) {
        graph.addNode(vertex);
    }
    graph.addDirectedEdge(a, b);
    graph.addDirectedEdge(a, c);
    graph.addDirectedEdge(b, c);

    BayesPm pm = new BayesPm(graph);
    pm.setNumCategories(b, 3);

    // cpt[nodeIndex][rowIndex][columnIndex]; node indices 0, 1, 2 presumably correspond to a, b, c.
    double[][][] cpt = {
            {{.3, .7}},
            {{.3, .4, .3}, {.6, .1, .3}},
            {{.9, .1}, {.1, .9}, {.5, .5}, {.2, .8}, {.6, .4}, {.7, .3}}
    };

    BayesIm im = new MlBayesIm(pm);
    for (int nodeIdx = 0; nodeIdx < cpt.length; nodeIdx++) {
        double[][] table = cpt[nodeIdx];
        for (int rowIdx = 0; rowIdx < table.length; rowIdx++) {
            double[] rowProbs = table[rowIdx];
            for (int colIdx = 0; colIdx < rowProbs.length; colIdx++) {
                im.setProbability(nodeIdx, rowIdx, colIdx, rowProbs[colIdx]);
            }
        }
    }
    return im;
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) BayesIm(edu.cmu.tetrad.bayes.BayesIm) GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node) Dag(edu.cmu.tetrad.graph.Dag) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Aggregations

BayesIm (edu.cmu.tetrad.bayes.BayesIm)36 MlBayesIm (edu.cmu.tetrad.bayes.MlBayesIm)21 BayesPm (edu.cmu.tetrad.bayes.BayesPm)18 Test (org.junit.Test)14 Graph (edu.cmu.tetrad.graph.Graph)7 Node (edu.cmu.tetrad.graph.Node)7 DataSet (edu.cmu.tetrad.data.DataSet)6 Dag (edu.cmu.tetrad.graph.Dag)5 Algorithm (edu.cmu.tetrad.algcomparison.algorithm.Algorithm)3 GraphNode (edu.cmu.tetrad.graph.GraphNode)3 Parameters (edu.cmu.tetrad.util.Parameters)3 GeneralBootstrapTest (edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest)3 File (java.io.File)3 IOException (java.io.IOException)3 ArrayList (java.util.ArrayList)3 Element (nu.xom.Element)3 RandomGraph (edu.cmu.tetrad.algcomparison.graph.RandomGraph)2 ChiSquare (edu.cmu.tetrad.algcomparison.independence.ChiSquare)2 IndependenceWrapper (edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper)2 BdeuScore (edu.cmu.tetrad.algcomparison.score.BdeuScore)2