Use of edu.cmu.ml.proppr.learn.RegularizationSchedule in the ProPPR project by TeamCohen.
Class GradientFinderTest, method setup:
/**
 * Creates the {@code SRW} learner under test (a {@code RegularizeL2} regularizer wrapped in a
 * {@code RegularizationSchedule}, plus a {@code ReLU} squashing function), initializes the
 * trainer, sets {@code query} to node {@code r0} with weight 1.0, and fills {@code examples}
 * with {@code magicNumber * magicNumber} hand-serialized examples.
 *
 * <p>Each example is one tab-separated line:
 * {@code r0, queryId, posId, negId, nodeSize, edgeSize, labelDependencies, featureSymbols,
 * u->v:feat@weight,...} where the positive is node {@code b<k>} and the negative is node
 * {@code r<p>}.
 */
@Before
public void setup() {
    super.setup();
    this.srw = new SRW();
    this.srw.setRegularizer(new RegularizationSchedule(this.srw, new RegularizeL2()));
    this.srw.setSquashingFunction(new ReLU<String>());
    this.initTrainer();
    query = new TIntDoubleHashMap();
    query.put(nodes.getId("r0"), 1.0);

    // Everything after the negative id depends only on brGraph (which this method never
    // modifies), so serialize it once instead of once per (pos, neg) pair.
    int labelDependencies = 0;
    StringBuilder graph = new StringBuilder();
    for (int i = 0; i < brGraph.getFeatureSet().size(); i++) {
        if (i > 0) {
            graph.append(":");
        }
        // symbols are looked up with a +1 offset (getSymbol(i + 1))
        graph.append(brGraph.featureLibrary.getSymbol(i + 1));
    }
    for (int u = 0; u < brGraph.node_hi; u++) {
        HashSet<Integer> outgoingFeatures = new HashSet<Integer>();
        for (int ec = brGraph.node_near_lo[u]; ec < brGraph.node_near_hi[u]; ec++) {
            int v = brGraph.edge_dest[ec];
            graph.append("\t").append(u).append("->").append(v).append(":");
            for (int lc = brGraph.edge_labels_lo[ec]; lc < brGraph.edge_labels_hi[ec]; lc++) {
                outgoingFeatures.add(brGraph.label_feature_id[lc]);
                if (lc > brGraph.edge_labels_lo[ec]) {
                    graph.append(",");
                }
                graph.append(brGraph.label_feature_id[lc]).append("@").append(brGraph.label_feature_weight[lc]);
            }
        }
        // each node contributes (#distinct features on its out-edges) * (its out-edge count)
        labelDependencies += outgoingFeatures.size() * (brGraph.node_near_hi[u] - brGraph.node_near_lo[u]);
    }
    String graphSection = new StringBuilder()
            .append(brGraph.nodeSize()).append("\t")
            .append(brGraph.edgeSize()).append("\t")
            .append(labelDependencies).append("\t")
            .append(graph)
            .toString();

    examples = new ArrayList<String>();
    for (int k = 0; k < this.magicNumber; k++) {
        for (int p = 0; p < this.magicNumber; p++) {
            StringBuilder serialized = new StringBuilder("r0")
                    .append("\t").append(nodes.getId("r0"))     // query
                    .append("\t").append(nodes.getId("b" + k))  // pos
                    .append("\t").append(nodes.getId("r" + p))  // neg
                    .append("\t").append(graphSection);
            examples.add(serialized.toString());
        }
    }
}
Use of edu.cmu.ml.proppr.learn.RegularizationSchedule in the ProPPR project by TeamCohen.
Class TrainerTest, method setup:
/**
 * Creates the {@code SRW} learner under test (a {@code RegularizeL2} regularizer wrapped in a
 * {@code RegularizationSchedule}, plus a {@code ReLU} squashing function), initializes the
 * trainer, and populates {@code examples} with one serialized {@code PosNegRWExample} for every
 * combination of positive node {@code b<k>} and negative node {@code r<p>}; every example
 * shares the same {@code query}, which puts weight 1.0 on node {@code r0}.
 */
@Before
public void setup() {
    super.setup();

    SRW learner = new SRW();
    srw = learner;
    learner.setRegularizer(new RegularizationSchedule(learner, new RegularizeL2()));
    learner.setSquashingFunction(new ReLU<String>());
    initTrainer();

    int startNode = nodes.getId("r0");
    query = new TIntDoubleHashMap();
    query.put(startNode, 1.0);

    examples = new ArrayList<String>();
    for (int bIndex = 0; bIndex < magicNumber; bIndex++) {
        for (int rIndex = 0; rIndex < magicNumber; rIndex++) {
            // ids are looked up positive-first, matching argument evaluation order
            int[] positives = { nodes.getId("b" + bIndex) };
            int[] negatives = { nodes.getId("r" + rIndex) };
            PosNegRWExample example = new PosNegRWExample(brGraph, query, positives, negatives);
            examples.add(example.serialize());
        }
    }
}
Aggregations