Use of `edu.cmu.tetrad.data.BootstrapSampler` in the Tetrad project (cmu-phil).
The `search` method of class `StabilitySelection`.
/**
 * Runs stability selection over the wrapped {@code algorithm}.
 *
 * <p>Draws {@code numSubsamples} subsamples of the data without replacement, each of size
 * {@code (int) (percentSubsampleSize * numRows)}. It runs {@code algorithm.search} on each
 * subsample in parallel on the shared fork/join pool and tallies the edges of every resulting
 * graph with {@code increment}. The returned graph keeps an edge only when its tally is strictly
 * greater than {@code percentStability * numSubsamples}.
 *
 * @param dataSet the data to search; it is cast to {@code DataSet}, so any other
 *     {@code DataModel} implementation causes a {@code ClassCastException}
 * @param parameters supplies {@code "percentSubsampleSize"}, {@code "numSubsamples"} and
 *     {@code "percentStability"}; it is also passed unchanged to every {@code algorithm.search}
 *     call
 * @return the graph of stable edges over {@code dataSet}'s variables; the same object is also
 *     stored in {@code initialGraph}
 */
@Override
public Graph search(DataModel dataSet, Parameters parameters) {
    DataSet _dataSet = (DataSet) dataSet;
    double percentageB = parameters.getDouble("percentSubsampleSize");
    int numSubsamples = parameters.getInt("numSubsamples");
    Map<Edge, Integer> counts = new HashMap<>();

    // Subsample searches run concurrently on the fork/join pool and all append to this list, so
    // it must be thread-safe: a plain ArrayList can lose elements or throw under concurrent add().
    List<Graph> graphs = java.util.Collections.synchronizedList(new ArrayList<>());
    final ForkJoinPool pool = ForkJoinPoolInstance.getInstance().getPool();

    // Splits the index range [from, to) in half until it is at most `chunk` wide, then runs one
    // subsample search per index in that range.
    class StabilityAction extends RecursiveAction {
        private final int chunk;
        private final int from;
        private final int to;

        private StabilityAction(int chunk, int from, int to) {
            this.chunk = chunk;
            this.from = from;
            this.to = to;
        }

        @Override
        protected void compute() {
            if (to - from <= chunk) {
                for (int s = from; s < to; s++) {
                    BootstrapSampler sampler = new BootstrapSampler();
                    sampler.setWithoutReplacements(true);
                    DataSet sample = sampler.sample(_dataSet, (int) (percentageB * _dataSet.getNumRows()));
                    Graph graph = algorithm.search(sample, parameters);
                    graphs.add(graph);
                }
            } else {
                // Unsigned shift gives an overflow-safe midpoint.
                final int mid = (to + from) >>> 1;
                StabilityAction left = new StabilityAction(chunk, from, mid);
                StabilityAction right = new StabilityAction(chunk, mid, to);
                // The left half may run on another worker while this thread computes the right half.
                left.fork();
                right.compute();
                left.join();
            }
        }
    }

    final int chunk = 2;
    // invoke() blocks until the whole task tree has completed, so `graphs` is fully populated
    // and no longer being modified once it returns.
    pool.invoke(new StabilityAction(chunk, 0, numSubsamples));

    for (Graph graph : graphs) {
        for (Edge edge : graph.getEdges()) {
            increment(edge, counts);
        }
    }

    initialGraph = new EdgeListGraph(dataSet.getVariables());
    double percentStability = parameters.getDouble("percentStability");
    for (Map.Entry<Edge, Integer> entry : counts.entrySet()) {
        if (entry.getValue() > percentStability * numSubsamples) {
            initialGraph.addEdge(entry.getKey());
        }
    }
    return initialGraph;
}
Aggregations