Use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
In class TestDataWrapper, method testDataModelList:
/**
 * Verifies that a {@link DataModelList} holding two data sets can be round-tripped through
 * {@link MarshalledObject} (Java serialization) and that the selected model survives the trip.
 *
 * <p>Serialization failures now fail the test. Previously they were caught and only printed,
 * which let the test pass silently.
 */
@Test
public void testDataModelList() {
    DataModelList modelList = new DataModelList();

    List<Node> variables1 = new ArrayList<>();
    for (int i = 0; i < 10; i++) {
        variables1.add(new ContinuousVariable("X" + i));
    }

    List<Node> variables2 = new ArrayList<>();
    for (int i = 0; i < 10; i++) {
        variables2.add(new ContinuousVariable("X" + i));
    }

    DataSet first = new ColtDataSet(10, variables1);
    first.setName("first");
    DataSet second = new ColtDataSet(10, variables2);
    second.setName("second");

    modelList.add(first);
    modelList.add(second);
    assertTrue(modelList.contains(first));
    assertTrue(modelList.contains(second));

    modelList.setSelectedModel(second);

    DataModelList modelList2;
    try {
        modelList2 = new MarshalledObject<>(modelList).get();
    } catch (java.io.IOException | ClassNotFoundException e) {
        // Surface the failure (with its cause) instead of printing it and passing.
        throw new AssertionError("DataModelList failed to round-trip through serialization", e);
    }
    assertEquals("second", modelList2.getSelectedModel().getName());
}
Use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
In class MNLRLikelihood, method getLik:
/**
 * Scores {@code child_index} given {@code parents} by partitioning the rows on the discrete
 * parents and, within each partition cell, regressing the child on a polynomial expansion of the
 * standardized continuous parents.
 *
 * <p>For every cell returned by {@code adTree.getCellLeaves(discreteParents)} that has more than
 * one row (r = cell size):
 * <ul>
 *   <li>each continuous parent is standardized with the cell's mean and population standard
 *       deviation;</li>
 *   <li>the design matrix has {@code p * degree + 1} columns: column {@code p * d + j} holds the
 *       standardized parent {@code j} raised to power {@code d + 1}, and the last column is an
 *       intercept of ones. {@code degree} is {@code fDegree}, or {@code floor(ln r)} when
 *       {@code fDegree < 1};</li>
 *   <li>a continuous child is scored with {@code multipleRegression(target, design)};</li>
 *   <li>a discrete child is coded as an {@code r x numCategories} matrix of -1 with +1 in the
 *       observed category's column, and scored with
 *       {@code MultinomialLogisticRegression(target, design)}.</li>
 * </ul>
 * Cells with at most one row contribute nothing. The per-cell scores are summed
 * (presumably log-likelihoods, since they are added; TODO confirm against the helpers).
 *
 * @param child_index index of the child in {@code variables}
 * @param parents     indices of the parents in {@code variables}; any parent that is not a
 *                    {@link ContinuousVariable} is cast to {@link DiscreteVariable}
 * @return the sum of the per-cell scores
 */
public double getLik(int child_index, int[] parents) {
    double lik = 0;
    Node c = variables.get(child_index);

    List<ContinuousVariable> continuous_parents = new ArrayList<>();
    List<DiscreteVariable> discrete_parents = new ArrayList<>();
    for (int p : parents) {
        Node parent = variables.get(p);
        if (parent instanceof ContinuousVariable) {
            continuous_parents.add((ContinuousVariable) parent);
        } else {
            discrete_parents.add((DiscreteVariable) parent);
        }
    }

    int p = continuous_parents.size();
    List<List<Integer>> cells = adTree.getCellLeaves(discrete_parents);

    int[] continuousCols = new int[p];
    for (int j = 0; j < p; j++) {
        continuousCols[j] = nodesHash.get(continuous_parents.get(j));
    }

    for (List<Integer> cell : cells) {
        int r = cell.size();
        if (r <= 1) {
            continue;
        }

        double[] mean = new double[p];
        double[] sd = new double[p];
        for (int i = 0; i < p; i++) {
            for (int j = 0; j < r; j++) {
                mean[i] += continuousData[continuousCols[i]][cell.get(j)];
            }
            mean[i] /= r;

            // Two-pass variance (sum of squared deviations). The previous E[x^2] - E[x]^2 form
            // suffers catastrophic cancellation, can go slightly negative, and then Math.sqrt
            // returns NaN.
            double sumSq = 0;
            for (int j = 0; j < r; j++) {
                double dev = continuousData[continuousCols[i]][cell.get(j)] - mean[i];
                sumSq += dev * dev;
            }
            sd[i] = Math.sqrt(sumSq / r);

            if (Double.isNaN(sd[i])) {
                // With the two-pass form this is reachable only for non-finite input data.
                System.out.println(sd[i]);
            }
            // NOTE(review): a constant parent within a cell gives sd == 0, so the division
            // below yields NaN/Infinity; confirm how the regression helpers handle that.
        }

        int degree = fDegree;
        if (fDegree < 1) {
            degree = (int) Math.floor(Math.log(r));
        }

        TetradMatrix subset = new TetradMatrix(r, p * degree + 1);
        for (int i = 0; i < r; i++) {
            subset.set(i, p * degree, 1); // intercept column
            for (int j = 0; j < p; j++) {
                for (int d = 0; d < degree; d++) {
                    subset.set(i, p * d + j,
                            Math.pow((continuousData[continuousCols[j]][cell.get(i)] - mean[j]) / sd[j], d + 1));
                }
            }
        }

        if (c instanceof ContinuousVariable) {
            TetradVector target = new TetradVector(r);
            for (int i = 0; i < r; i++) {
                target.set(i, continuousData[child_index][cell.get(i)]);
            }
            lik += multipleRegression(target, subset);
        } else {
            int numCategories = ((DiscreteVariable) c).getNumCategories();
            TetradMatrix target = new TetradMatrix(r, numCategories);
            for (int i = 0; i < r; i++) {
                for (int j = 0; j < numCategories; j++) {
                    target.set(i, j, -1);
                }
                target.set(i, discreteData[child_index][cell.get(i)], 1);
            }
            lik += MultinomialLogisticRegression(target, subset);
        }
    }
    return lik;
}
Use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
In class TestHistogram, method test1:
/**
 * Builds a small two-variable continuous data set (10 rows of 0/1 values) intended as input for
 * a histogram test. The histogram construction itself is still unfinished (see the trailing
 * comment), so no assertions are made yet.
 */
public void test1() {
    Node x1 = new ContinuousVariable("X1");
    Node x2 = new ContinuousVariable("X2");

    List<Node> nodes = new LinkedList<>();
    nodes.add(x1);
    nodes.add(x2);

    // X1: rows 0-4 are 0, rows 5-9 are 1.
    double[] x1Values = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1};
    // X2: rows 0 and 5-8 are 0, all other rows are 1.
    double[] x2Values = {0, 1, 1, 1, 1, 0, 0, 0, 0, 1};

    TetradMatrix dataMatrix = new TetradMatrix(x1Values.length, 2);
    for (int row = 0; row < x1Values.length; row++) {
        dataMatrix.set(row, 0, x1Values[row]);
        dataMatrix.set(row, 1, x2Values[row]);
    }

    DataSet dataSet = ColtDataSet.makeContinuousData(nodes, dataMatrix);
    // Histogram histogram = new Histogram(dataSet, );
}
Use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
In class TabularDataTable, method addColumnsOutTo:
/**
 * Appends continuous columns to the data set until JTable column {@code col} is backed by a
 * data column, then fires a {@code "modelChanged"} property change (even when nothing was added).
 *
 * <p>The col index here is a JTable index. It is compared against the data-set column count
 * plus {@code getNumLeadingCols()} (presumably the non-data columns shown before the data;
 * confirm).
 *
 * @param col JTable column index that should exist after this call
 */
private void addColumnsOutTo(int col) {
    int nextTableCol = dataSet.getNumColumns() + getNumLeadingCols();

    while (nextTableCol <= col) {
        // Each new variable is constructed with an empty name.
        ContinuousVariable added = new ContinuousVariable("");
        dataSet.addVariable(added);
        System.out.println("Adding " + added + " col " + dataSet.getColumn(added));
        nextTableCol++;
    }

    pcs.firePropertyChange("modelChanged", null, null);
}
Use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
In class TimeLagGraphEditor, method createGraphMenu:
/**
 * Builds the "Graph" menu for this editor.
 *
 * <p>The menu contains:
 * <ul>
 *   <li>graph-properties and paths actions;</li>
 *   <li>items that correlate or uncorrelate exogenous variables and repaint the workbench;</li>
 *   <li>a "Random Graph" generator driven by a {@code RandomGraphEditor} dialog;</li>
 *   <li>a "Random Multiple Indicator Model" generator driven by a {@code RandomMimParamsEditor}
 *       dialog and user preferences;</li>
 *   <li>actions that select bidirected or undirected edges.</li>
 * </ul>
 *
 * @return the newly built menu
 */
private JMenu createGraphMenu() {
    JMenu graph = new JMenu("Graph");
    graph.add(new GraphPropertiesAction(getWorkbench()));
    graph.add(new PathsAction(getWorkbench()));
    graph.addSeparator();

    JMenuItem correlateExogenous = new JMenuItem("Correlate Exogenous Variables");
    JMenuItem uncorrelateExogenous = new JMenuItem("Uncorrelate Exogenous Variables");
    graph.add(correlateExogenous);
    graph.add(uncorrelateExogenous);
    graph.addSeparator();

    correlateExogenous.addActionListener(new ActionListener() {
        @Override
        public void actionPerformed(ActionEvent e) {
            correlateExogenousVariables();
            getWorkbench().invalidate();
            getWorkbench().repaint();
        }
    });

    uncorrelateExogenous.addActionListener(new ActionListener() {
        @Override
        public void actionPerformed(ActionEvent e) {
            uncorrelationExogenousVariables();
            getWorkbench().invalidate();
            getWorkbench().repaint();
        }
    });

    JMenuItem randomGraph = new JMenuItem("Random Graph");
    graph.add(randomGraph);
    randomGraph.addActionListener(new ActionListener() {
        @Override
        public void actionPerformed(ActionEvent e) {
            RandomGraphEditor editor = new RandomGraphEditor(workbench.getGraph(), true, parameters);

            // The 4-argument overload's last parameter is the *option type*. Passing
            // PLAIN_MESSAGE there (same value as DEFAULT_OPTION) produced an OK-only dialog
            // with a question icon. Use the 5-argument overload, as the MIM dialog below does.
            int ret = JOptionPane.showConfirmDialog(TimeLagGraphEditor.this, editor,
                    "Edit Random DAG Parameters", JOptionPane.OK_CANCEL_OPTION,
                    JOptionPane.PLAIN_MESSAGE);
            if (ret != JOptionPane.OK_OPTION) {
                return;
            }

            Graph newGraph = null;
            Graph dag = new Dag();
            int numTrials = 0;

            // Retries (at most 99 passes) until a graph has been produced.
            while (newGraph == null && ++numTrials < 100) {
                if (editor.isRandomForward()) {
                    dag = GraphUtils.randomGraphRandomForwardEdges(getGraph().getNodes(), editor.getNumLatents(), editor.getMaxEdges(), 30, 15, 15, false, true);
                    GraphUtils.arrangeBySourceGraph(dag, getWorkbench().getGraph());
                    HashMap<String, PointXy> layout = GraphUtils.grabLayout(workbench.getGraph().getNodes());
                    GraphUtils.arrangeByLayout(dag, layout);
                } else if (editor.isUniformlySelected()) {
                    if (getGraph().getNumNodes() == editor.getNumNodes()) {
                        // Same node count: reuse the existing nodes and keep their layout.
                        HashMap<String, PointXy> layout = GraphUtils.grabLayout(workbench.getGraph().getNodes());
                        dag = GraphUtils.randomGraph(getGraph().getNodes(), editor.getNumLatents(), editor.getMaxEdges(), editor.getMaxDegree(), editor.getMaxIndegree(), editor.getMaxOutdegree(), editor.isConnected());
                        GraphUtils.arrangeBySourceGraph(dag, getWorkbench().getGraph());
                        GraphUtils.arrangeByLayout(dag, layout);
                    } else {
                        List<Node> nodes = new ArrayList<>();
                        for (int i = 0; i < editor.getNumNodes(); i++) {
                            nodes.add(new ContinuousVariable("X" + (i + 1)));
                        }
                        dag = GraphUtils.randomGraph(nodes, editor.getNumLatents(), editor.getMaxEdges(), editor.getMaxDegree(), editor.getMaxIndegree(), editor.getMaxOutdegree(), editor.isConnected());
                    }
                } else {
                    // NOTE(review): this loop repeats until the graph reaches getMaxEdges() edges;
                    // it cannot terminate if that count is unattainable. Confirm the editor bounds it.
                    do {
                        if (getGraph().getNumNodes() == editor.getNumNodes()) {
                            HashMap<String, PointXy> layout = GraphUtils.grabLayout(workbench.getGraph().getNodes());
                            dag = GraphUtils.randomDag(getGraph().getNodes(), editor.getNumLatents(), editor.getMaxEdges(), 30, 15, 15, editor.isConnected());
                            GraphUtils.arrangeByLayout(dag, layout);
                        } else {
                            // NOTE(review): this branch calls randomGraph while the branch above
                            // calls randomDag. Confirm the difference is intended.
                            List<Node> nodes = new ArrayList<>();
                            for (int i = 0; i < editor.getNumNodes(); i++) {
                                nodes.add(new ContinuousVariable("X" + (i + 1)));
                            }
                            dag = GraphUtils.randomGraph(nodes, editor.getNumLatents(), editor.getMaxEdges(), 30, 15, 15, editor.isConnected());
                        }
                    } while (dag.getNumEdges() < editor.getMaxEdges());
                }

                if (editor.isAddCycles()) {
                    // NOTE(review): the DAG generated above is discarded here. cyclicGraph2 builds
                    // a fresh graph from the node/edge counts. Confirm this is intended.
                    newGraph = GraphUtils.cyclicGraph2(editor.getNumNodes(), editor.getMaxEdges(), 8);
                    GraphUtils.addTwoCycles(newGraph, editor.getMinNumCycles());
                } else {
                    newGraph = new EdgeListGraph(dag);
                }
            }

            if (newGraph == null) {
                JOptionPane.showMessageDialog(TimeLagGraphEditor.this, "Could not find a graph that fits those constraints.");
                getWorkbench().setGraph(new EdgeListGraph(dag));
            } else {
                getWorkbench().setGraph(newGraph);
            }
        }
    });

    JMenuItem randomIndicatorModel = new JMenuItem("Random Multiple Indicator Model");
    graph.add(randomIndicatorModel);
    randomIndicatorModel.addActionListener(new ActionListener() {
        @Override
        public void actionPerformed(ActionEvent e) {
            RandomMimParamsEditor editor = new RandomMimParamsEditor(parameters);
            int ret = JOptionPane.showConfirmDialog(JOptionUtils.centeringComp(), editor, "Edit Random MIM Parameters", JOptionPane.OK_CANCEL_OPTION, JOptionPane.PLAIN_MESSAGE);
            if (ret != JOptionPane.OK_OPTION) {
                return;
            }

            // The MIM parameters are read from user preferences rather than from the editor.
            int numFactors = Preferences.userRoot().getInt("randomMimNumFactors", 1);
            int numStructuralNodes = Preferences.userRoot().getInt("numStructuralNodes", 3);
            int maxStructuralEdges = Preferences.userRoot().getInt("numStructuralEdges", 3);
            int measurementModelDegree = Preferences.userRoot().getInt("measurementModelDegree", 3);
            int numLatentMeasuredImpureParents = Preferences.userRoot().getInt("latentMeasuredImpureParents", 0);
            int numMeasuredMeasuredImpureParents = Preferences.userRoot().getInt("measuredMeasuredImpureParents", 0);
            int numMeasuredMeasuredImpureAssociations = Preferences.userRoot().getInt("measuredMeasuredImpureAssociations", 0);

            Graph mim;
            if (numFactors == 1) {
                mim = DataGraphUtils.randomSingleFactorModel(numStructuralNodes, maxStructuralEdges, measurementModelDegree, numLatentMeasuredImpureParents, numMeasuredMeasuredImpureParents, numMeasuredMeasuredImpureAssociations);
            } else if (numFactors == 2) {
                mim = DataGraphUtils.randomBifactorModel(numStructuralNodes, maxStructuralEdges, measurementModelDegree, numLatentMeasuredImpureParents, numMeasuredMeasuredImpureParents, numMeasuredMeasuredImpureAssociations);
            } else {
                throw new IllegalArgumentException("Can only make random MIMs for 1 or 2 factors, " + "sorry dude.");
            }
            getWorkbench().setGraph(mim);
        }
    });

    graph.addSeparator();
    graph.add(new JMenuItem(new SelectBidirectedAction(getWorkbench())));
    graph.add(new JMenuItem(new SelectUndirectedAction(getWorkbench())));
    return graph;
}
Aggregations