Use of java.rmi.MarshalledObject in project tetrad by cmu-phil.
In class BuildPureClustersRunner, method execute:
// ===================PUBLIC METHODS OVERRIDING ABSTRACT================//
/**
 * Runs the selected pure-clustering search and stores the result graph and
 * clusters on this runner.
 *
 * <p>If the {@code "BPCrDown"} parameter is true, {@code Washdown} is run
 * regardless of the selected algorithm. Otherwise the algorithm comes from the
 * {@code "bpcAlgorithmthmType"} parameter; only {@code TETRAD_PURIFY_WASHDOWN}
 * and {@code BUILD_PURE_CLUSTERS} are handled here.
 *
 * <p>When a SEM IM or a true graph is available, the latent variables of the
 * search graph are renamed via {@code ReidentifyVariables}. The stored result
 * graph is a serialization copy of the search graph with circle and
 * Fruchterman-Reingold layouts applied.
 *
 * @throws IllegalArgumentException if BUILD_PURE_CLUSTERS is selected and the
 *         data model is neither an {@code ICovarianceMatrix} nor a {@code DataSet}
 * @throws IllegalStateException if the selected algorithm is not handled here
 *         (this includes the default, FIND_ONE_FACTOR_CLUSTERS)
 * @throws RuntimeException wrapping any failure while copying, laying out, or
 *         storing the result graph
 */
public void execute() {
    boolean rKey = getParams().getBoolean("BPCrDown", false);
    // NOTE(review): "bpcAlgorithmthmType" looks like a misspelling of "bpcAlgorithmType".
    // Confirm it matches the key the parameter editor writes; otherwise the default
    // (FIND_ONE_FACTOR_CLUSTERS) is always used and the search fails below.
    BpcAlgorithmType algorithm = (BpcAlgorithmType) getParams().get("bpcAlgorithmthmType", BpcAlgorithmType.FIND_ONE_FACTOR_CLUSTERS);
    double alpha = getParams().getDouble("alpha", 0.001);
    Graph searchGraph;

    if (rKey) {
        Washdown washdown;
        Object source = getData();
        if (source instanceof DataSet) {
            washdown = new Washdown((DataSet) source, alpha);
        } else {
            // NOTE(review): the other branches accept any ICovarianceMatrix; this cast
            // fails for covariance implementations that are not CovarianceMatrix.
            washdown = new Washdown((CovarianceMatrix) source, alpha);
        }
        searchGraph = washdown.search();
    } else {
        TestType tetradTestType = (TestType) getParams().get("tetradTestType", TestType.TETRAD_WISHART);

        if (algorithm == BpcAlgorithmType.TETRAD_PURIFY_WASHDOWN) {
            BpcTetradPurifyWashdown bpc;
            Object source = getData();
            if (source instanceof DataSet) {
                bpc = new BpcTetradPurifyWashdown((DataSet) source, tetradTestType, alpha);
            } else {
                bpc = new BpcTetradPurifyWashdown((ICovarianceMatrix) source, tetradTestType, alpha);
            }
            searchGraph = bpc.search();
        } else if (algorithm == BpcAlgorithmType.BUILD_PURE_CLUSTERS) {
            BuildPureClusters bpc;
            DataModel source = getData();
            TestType purifyType = TestType.TETRAD_BASED;
            if (source instanceof ICovarianceMatrix) {
                bpc = new BuildPureClusters((ICovarianceMatrix) source, alpha, tetradTestType, purifyType);
            } else if (source instanceof DataSet) {
                bpc = new BuildPureClusters((DataSet) source, alpha, tetradTestType, purifyType);
            } else {
                throw new IllegalArgumentException(
                        "Expecting a covariance matrix or a data set for BUILD_PURE_CLUSTERS; got: "
                                + (source == null ? "null" : source.getClass().getName()));
            }
            searchGraph = bpc.search();
        } else {
            // FIND_ONE_FACTOR_CLUSTERS (the default) and FIND_TWO_FACTOR_CLUSTERS used to be
            // handled here; those implementations are disabled, so they end up in this error.
            throw new IllegalStateException("Unsupported BPC algorithm: " + algorithm);
        }
    }

    if (semIm != null) {
        List<List<Node>> partition = MimUtils.convertToClusters2(searchGraph);
        // NOTE(review): this cast fails when the data model is a covariance matrix.
        List<String> variableNames = ReidentifyVariables.reidentifyVariables2(partition, trueGraph, (DataSet) getData());
        rename(searchGraph, partition, variableNames);
    } else if (trueGraph != null) {
        List<List<Node>> partition = MimUtils.convertToClusters2(searchGraph);
        List<String> variableNames = ReidentifyVariables.reidentifyVariables1(partition, trueGraph);
        rename(searchGraph, partition, variableNames);
    }

    System.out.println("Search Graph " + searchGraph);

    try {
        // Round-trip through serialization so the stored result graph is a copy, independent
        // of searchGraph.
        Graph graph = new MarshalledObject<>(searchGraph).get();
        GraphUtils.circleLayout(graph, 200, 200, 150);
        GraphUtils.fruchtermanReingoldLayout(graph);
        setResultGraph(graph);
        setClusters(MimUtils.convertToClusters(graph, getData().getVariables()));
    } catch (Exception e) {
        // The cause is preserved, so no separate printStackTrace() is needed.
        throw new RuntimeException("Failed to copy, lay out, or store the search graph", e);
    }
}
Use of java.rmi.MarshalledObject in project tetrad by cmu-phil.
In class TestMimbuild2, method changeLatentNames:
// private Graph condense(Graph mimStructure, Graph mimbuildStructure) {
// // System.out.println("Uncondensed: " + mimbuildStructure);
//
// Map<Node, Node> substitutions = new HashMap<Node, Node>();
//
// for (Node node : mimbuildStructure.getNodes()) {
// for (Node _node : mimStructure.getNodes()) {
// if (node.getNode().startsWith(_node.getNode())) {
// substitutions.put(node, _node);
// break;
// }
//
// substitutions.put(node, node);
// }
// }
//
// HashSet<Node> nodes = new HashSet<Node>(substitutions.values());
// Graph graph = new EdgeListGraph(new ArrayList<Node>(nodes));
//
// for (Edge edge : mimbuildStructure.getEdges()) {
// Node node1 = substitutions.get(edge.getNode1());
// Node node2 = substitutions.get(edge.getNode2());
//
// if (node1 == node2) continue;
//
// if (graph.isAdjacentTo(node1, node2)) continue;
//
// graph.addEdge(new Edge(node1, node2, edge.getEndpoint1(), edge.getEndpoint2()));
// }
//
// // System.out.println("Condensed: " + graph);
//
// return graph;
// }
// public void rtest2() {
// System.out.println("SHD\tP");
// // System.out.println("MB1\tMB2\tMB3\tMB4\tMB5\tMB6");
//
// Graph mim = DataGraphUtils.randomSingleFactorModel(5, 5, 10, 0, 0, 0);
//
// Graph mimStructure = structure(mim);
//
// SemPm pm = new SemPm(mim);
// Parameters params = new Parameters();
// params.setCoefRange(0.5, 1.5);
//
// NumberFormat nf = new DecimalFormat("0.0000");
//
// int totalError = 0;
// int errorCount = 0;
//
// for (int r = 0; r < 1; r++) {
// SemIm im = new SemIm(pm, params);
//
// DataSet data = im.simulateData(300, false);
//
// mim = GraphUtils.replaceNodes(mim, data.getVariable());
// List<List<Node>> trueClusters = MimUtils.convertToClusters2(mim);
//
// CovarianceMatrix _cov = new CovarianceMatrix(data);
//
// ICovarianceMatrix cov = DataUtils.reorderColumns(_cov);
//
// String algorithm = "FOFC";
// Graph searchGraph;
// List<List<Node>> partition;
//
// if (algorithm.equals("FOFC")) {
// FindOneFactorClusters fofc = new FindOneFactorClusters(cov, TestType.TETRAD_WISHART, 0.001);
// searchGraph = fofc.search();
// searchGraph = GraphUtils.replaceNodes(searchGraph, data.getVariable());
// partition = MimUtils.convertToClusters2(searchGraph);
// } else if (algorithm.equals("BPC")) {
// TestType testType = TestType.TETRAD_WISHART;
// TestType purifyType = TestType.TETRAD_BASED;
//
// BuildPureClusters bpc = new BuildPureClusters(
// data, 0.001,
// testType,
// purifyType);
// searchGraph = bpc.search();
//
// partition = MimUtils.convertToClusters2(searchGraph);
// } else {
// throw new IllegalStateException();
// }
//
// mimStructure = GraphUtils.replaceNodes(mimStructure, data.getVariable());
//
// List<String> latentVarList = reidentifyVariables(mim, data, partition, 2);
//
// Graph mimbuildStructure;
//
// Mimbuild2 mimbuild = new Mimbuild2();
// mimbuild.setAlternativePenalty(0.001);
// mimbuild.setMinClusterSize(4);
// // mimbuild.setFixOneLoadingPerCluster(true);
//
// try {
// mimbuildStructure = mimbuild.search(partition, latentVarList, cov);
// } catch (Exception e) {
// e.printStackTrace();
// continue;
// }
//
// mimbuildStructure = GraphUtils.replaceNodes(mimbuildStructure, data.getVariable());
//
// int shd = SearchGraphUtils.structuralHammingDistance(restrictToEmpiricalLatents(mimStructure, mimbuildStructure), mimbuildStructure);
// boolean impureCluster = containsImpureCluster(partition, trueClusters);
// double pValue = mimbuild.getpValue();
// // double pValue = pvalue(mimbuild.getClustering(), _cov);
// boolean pBelow05 = pValue < 0.05;
// boolean numClustersNe5 = partition.size() != 5;
// boolean error = false;
//
// // boolean condition = impureCluster || numClustersNe5 || pBelow05;
// // boolean condition = numClustersNe5 || pBelow05;
// // boolean condition = numClustered(partition) == 40;
// boolean condition = numClustersNe5;
//
// if (!condition) {
// totalError += shd;
// errorCount++;
// }
//
// System.out.print(shd + "\t" + nf.format(pValue) + " "
// // + (error ? 1 : 0) + " "
// // + (pBelow05 ? "P < 0.05 " : "")
// // + (impureCluster ? "Impure cluster " : "")
// + (numClustersNe5 ? "# Clusters = " + partition.size() + " " : "")
// // + clusterSizes(partition, trueClusters)
// // + numClustered(partition)
// + partition
// );
//
// System.out.println();
// }
//
// System.out.println("\nAvg SHD for not-flagged cases = " + (totalError / (double) errorCount));
// }
// private Graph restrictToEmpiricalLatents(Graph mimStructure, Graph mimbuildStructure) {
// Graph _mim = new EdgeListGraph(mimStructure);
//
// for (Node node : mimbuildStructure.getNodes()) {
// if (!mimbuildStructure.containsNode(node)) {
// _mim.removeNode(node);
// }
// }
//
// return _mim;
// }
// private String clusterSizes(List<List<Node>> partition, List<List<Node>> trueClusters) {
// String s = "";
//
// FOR:
// for (int i = 0; i < partition.size(); i++) {
// List<Node> cluster = partition.get(i);
// s += cluster.size();
//
// for (List<Node> trueCluster : trueClusters) {
// if (trueCluster.containsAll(cluster)) {
// // Collections.sort(trueCluster);
// // Collections.sort(cluster);
// // System.out.println(trueCluster + " " + cluster);
// s += "p";
//
// if (i < partition.size() - 1) {
// s += ",";
// }
//
// continue FOR;
// }
// }
//
// if (i < partition.size() - 1) {
// s += ",";
// }
// }
//
// return s;
// }
// private int numClustered(List<List<Node>> partition) {
// int sum = 0;
//
// for (int i = 0; i < partition.size(); i++) {
// List<Node> cluster = partition.get(i);
// sum += cluster.size();
// }
//
// return sum;
// }
// private boolean containsImpureCluster(List<List<Node>> partition, List<List<Node>> trueClusters) {
//
// FOR:
// for (int i = 0; i < partition.size(); i++) {
// List<Node> cluster = partition.get(i);
//
// for (List<Node> trueCluster : trueClusters) {
// if (trueCluster.containsAll(cluster)) {
// continue FOR;
// }
// }
//
// return true;
// }
//
// return false;
// }
// public void rtest3() {
// Node x = new GraphNode("X");
// Node y = new GraphNode("Y");
// Node z = new GraphNode("Z");
// Node w = new GraphNode("W");
//
// List<Node> nodes = new ArrayList<Node>();
// nodes.add(x);
// nodes.add(y);
// nodes.add(z);
// nodes.add(w);
//
// Graph g = new EdgeListGraph(nodes);
// g.addDirectedEdge(x, y);
// g.addDirectedEdge(x, z);
// g.addDirectedEdge(y, w);
// g.addDirectedEdge(z, w);
//
// Graph maxGraph = null;
// double maxPValue = -1.0;
// ICovarianceMatrix maxLatentCov = null;
//
// Graph mim = DataGraphUtils.randomMim(g, 8, 0, 0, 0, true);
// // Graph mim = DataGraphUtils.randomSingleFactorModel(5, 5, 8, 0, 0, 0);
// Graph mimStructure = structure(mim);
// SemPm pm = new SemPm(mim);
//
// System.out.println("\n\nTrue graph:");
// System.out.println(mimStructure);
//
// Parameters params = new Parameters();
// params.setCoefRange(0.5, 1.5);
//
// SemIm im = new SemIm(pm, params);
//
// int N = 1000;
//
// DataSet data = im.simulateData(N, false);
//
// CovarianceMatrix cov = new CovarianceMatrix(data);
//
// for (int i = 0; i < 1; i++) {
//
// ICovarianceMatrix _cov = DataUtils.reorderColumns(cov);
// List<List<Node>> partition;
//
// FindOneFactorClusters fofc = new FindOneFactorClusters(_cov, TestType.TETRAD_WISHART, .001);
// fofc.search();
// partition = fofc.getClusters();
// System.out.println(partition);
//
// List<String> latentVarList = reidentifyVariables(mim, data, partition, 2);
//
// Mimbuild2 mimbuild = new Mimbuild2();
//
// mimbuild.setAlternativePenalty(0.001);
// // mimbuild.setMinimumSize(5);
//
// // To test knowledge.
// // Knowledge knowledge = new Knowledge2();
// // knowledge.setEdgeForbidden("L.Y", "L.W", true);
// // knowledge.setEdgeRequired("L.Y", "L.Z", true);
// // mimbuild.setKnowledge(knowledge);
//
// Graph mimbuildStructure = mimbuild.search(partition, latentVarList, _cov);
//
// double pValue = mimbuild.getpValue();
// System.out.println(mimbuildStructure);
// System.out.println("P = " + pValue);
// System.out.println("Latent Cov = " + mimbuild.getLatentsCov());
//
// if (pValue > maxPValue) {
// maxPValue = pValue;
// maxGraph = new EdgeListGraph(mimbuildStructure);
// maxLatentCov = mimbuild.getLatentsCov();
// }
// }
//
// System.out.println("\n\nTrue graph:");
// System.out.println(mimStructure);
// System.out.println("\nBest graph:");
// System.out.println(maxGraph);
// System.out.println("P = " + maxPValue);
// System.out.println("Latent Cov = " + maxLatentCov);
// System.out.println();
// }
// public void rtest4() {
// System.out.println("SHD\tP");
// // System.out.println("MB1\tMB2\tMB3\tMB4\tMB5\tMB6");
//
// Graph mim = DataGraphUtils.randomSingleFactorModel(5, 5, 8, 0, 0, 0);
//
// Graph mimStructure = structure(mim);
//
// SemPm pm = new SemPm(mim);
// Parameters params = new Parameters();
// params.setCoefRange(0.5, 1.5);
//
// NumberFormat nf = new DecimalFormat("0.0000");
//
// int totalError = 0;
// int errorCount = 0;
//
// int maxScore = 0;
// int maxNumMeasures = 0;
// double maxP = 0.0;
//
// for (int r = 0; r < 1; r++) {
// SemIm im = new SemIm(pm, params);
//
// DataSet data = im.simulateData(1000, false);
//
// mim = GraphUtils.replaceNodes(mim, data.getVariable());
// List<List<Node>> trueClusters = MimUtils.convertToClusters2(mim);
//
// CovarianceMatrix _cov = new CovarianceMatrix(data);
//
// ICovarianceMatrix cov = DataUtils.reorderColumns(_cov);
//
// String algorithm = "FOFC";
// Graph searchGraph;
// List<List<Node>> partition;
//
// if (algorithm.equals("FOFC")) {
// FindOneFactorClusters fofc = new FindOneFactorClusters(cov, TestType.TETRAD_WISHART, 0.001f);
// searchGraph = fofc.search();
// searchGraph = GraphUtils.replaceNodes(searchGraph, data.getVariable());
// partition = MimUtils.convertToClusters2(searchGraph);
// } else if (algorithm.equals("BPC")) {
// TestType testType = TestType.TETRAD_WISHART;
// TestType purifyType = TestType.TETRAD_BASED;
//
// BuildPureClusters bpc = new BuildPureClusters(
// data, 0.001,
// testType,
// purifyType);
// searchGraph = bpc.search();
//
// partition = MimUtils.convertToClusters2(searchGraph);
// } else {
// throw new IllegalStateException();
// }
//
// mimStructure = GraphUtils.replaceNodes(mimStructure, data.getVariable());
//
// List<String> latentVarList = reidentifyVariables(mim, data, partition, 2);
//
// Graph mimbuildStructure;
//
// Mimbuild2 mimbuild = new Mimbuild2();
// mimbuild.setAlternativePenalty(0.001);
// mimbuild.setMinClusterSize(3);
//
// try {
// mimbuildStructure = mimbuild.search(partition, latentVarList, cov);
// } catch (Exception e) {
// e.printStackTrace();
// continue;
// }
//
// mimbuildStructure = GraphUtils.replaceNodes(mimbuildStructure, data.getVariable());
// mimbuildStructure = condense(mimStructure, mimbuildStructure);
//
// int shd = SearchGraphUtils.structuralHammingDistance(mimStructure, mimbuildStructure);
// boolean impureCluster = containsImpureCluster(partition, trueClusters);
// double pValue = mimbuild.getpValue();
// boolean pBelow05 = pValue < 0.05;
// boolean numClustersGreaterThan5 = partition.size() != 5;
// boolean error = false;
//
// // boolean condition = impureCluster || numClustersGreaterThan5 || pBelow05;
// // boolean condition = numClustersGreaterThan5 || pBelow05;
// boolean condition = numClustered(partition) == 40;
//
// if (!condition && (shd > 5)) {
// error = true;
// }
//
// if (!condition) {
// totalError += shd;
// errorCount++;
//
// }
//
// if (pValue > maxP) {
// maxScore = shd;
// maxP = mimbuild.getpValue();
// maxNumMeasures = numClustered(partition);
// System.out.println("maxNumMeasures = " + maxNumMeasures);
// System.out.println("maxScore = " + maxScore);
// System.out.println("maxP = " + maxP);
// System.out.println("clusters = " + clusterSizes(partition, trueClusters));
// }
//
// System.out.print(shd + "\t" + nf.format(pValue) + " "
// + numClustered(partition)
// );
//
// System.out.println();
// }
//
// System.out.println("\nAvg SHD for not-flagged cases = " + (totalError / (double) errorCount));
//
// System.out.println("maxNumMeasures = " + maxNumMeasures);
// System.out.println("maxScore = " + maxScore);
// System.out.println("maxP = " + maxP);
// }
/**
 * Returns a copy of {@code full} in which latent nodes are renamed after the
 * clusters they measure.
 *
 * <p>For cluster {@code i} of {@code measurements}, every latent node of
 * {@code full} whose non-latent children have exactly the cluster's variable
 * names is renamed, in the copy, to {@code latentVarList.get(i)}.
 *
 * @param full          the graph to copy and match against
 * @param measurements  clusters of measured-variable names; cluster {@code i}
 *                      is paired with {@code latentVarList.get(i)}
 * @param latentVarList replacement latent names, indexed like the clusters
 * @return the copied graph with matching latent nodes renamed
 * @throws IllegalStateException if {@code full} cannot be copied by
 *         serialization (previously this was swallowed and led to a
 *         NullPointerException)
 */
private Graph changeLatentNames(Graph full, Clusters measurements, List<String> latentVarList) {
    Graph g2;
    try {
        // Serialization round-trip: renaming nodes in g2 does not touch full's node objects.
        g2 = new MarshalledObject<>(full).get();
    } catch (IOException | ClassNotFoundException e) {
        throw new IllegalStateException("Could not copy graph", e);
    }

    for (int i = 0; i < measurements.getNumClusters(); i++) {
        // Built once per cluster rather than once per candidate node.
        HashSet<String> clusterNames = new HashSet<>(measurements.getCluster(i));
        String latentName = latentVarList.get(i);

        for (Node node : full.getNodes()) {
            if (node.getNodeType() != NodeType.LATENT) {
                continue;
            }

            // NOTE(review): assumes getChildren returns a fresh list; if it is a live view,
            // removeAll would modify full.
            List<Node> _children = full.getChildren(node);
            _children.removeAll(ReidentifyVariables.getLatents(full));
            List<String> childNames = getNames(_children);

            if (new HashSet<>(childNames).equals(clusterNames)) {
                g2.getNode(node.getName()).setName(latentName);
            }
        }
    }

    return g2;
}
Aggregations