Usage of gnu.trove.map.hash.TIntDoubleHashMap in project ProPPR by TeamCohen.
In class GradientFinderTest, method setup:
/**
 * Builds the SRW trainer under test and the serialized training examples.
 *
 * <p>Each example is one tab-separated line: the query name "r0", the query node id, one
 * positive node id ("b" + k), one negative node id ("r" + p), the graph's node count, its edge
 * count, the label-dependency count, the feature names joined by ':', and finally one
 * {@code u->v:featureId@weight,...} field per edge.
 */
@Before
public void setup() {
    super.setup();
    this.srw = new SRW();
    this.srw.setRegularizer(new RegularizationSchedule(this.srw, new RegularizeL2()));
    this.srw.setSquashingFunction(new ReLU<String>());
    this.initTrainer();
    query = new TIntDoubleHashMap();
    query.put(nodes.getId("r0"), 1.0);

    // The graph part of every example is identical (it does not depend on k or p),
    // so serialize it once here instead of once per (k, p) pair.
    StringBuilder graph = new StringBuilder();
    for (int i = 0; i < brGraph.getFeatureSet().size(); i++) {
        if (i > 0) {
            graph.append(":");
        }
        // NOTE(review): feature symbols are looked up with 1-based ids here -- confirm
        graph.append(brGraph.featureLibrary.getSymbol(i + 1));
    }
    int labelDependencies = 0;
    for (int u = 0; u < brGraph.node_hi; u++) {
        HashSet<Integer> outgoingFeatures = new HashSet<Integer>();
        for (int ec = brGraph.node_near_lo[u]; ec < brGraph.node_near_hi[u]; ec++) {
            int v = brGraph.edge_dest[ec];
            graph.append("\t").append(u).append("->").append(v).append(":");
            for (int lc = brGraph.edge_labels_lo[ec]; lc < brGraph.edge_labels_hi[ec]; lc++) {
                outgoingFeatures.add(brGraph.label_feature_id[lc]);
                if (lc > brGraph.edge_labels_lo[ec]) {
                    graph.append(",");
                }
                graph.append(brGraph.label_feature_id[lc]).append("@").append(brGraph.label_feature_weight[lc]);
            }
        }
        // Each outgoing edge of u counts once per distinct feature seen on u's outgoing edges.
        labelDependencies += outgoingFeatures.size() * (brGraph.node_near_hi[u] - brGraph.node_near_lo[u]);
    }
    String graphSuffix = labelDependencies + "\t" + graph;

    examples = new ArrayList<String>();
    for (int k = 0; k < this.magicNumber; k++) {
        for (int p = 0; p < this.magicNumber; p++) {
            StringBuilder serialized = new StringBuilder("r0")
                    .append("\t").append(nodes.getId("r0"))       // query
                    .append("\t").append(nodes.getId("b" + k))    // pos
                    .append("\t").append(nodes.getId("r" + p))    // neg
                    .append("\t").append(brGraph.nodeSize())      // nodes
                    .append("\t").append(brGraph.edgeSize())      // edges
                    // label dependencies + graph (original note: "waiting for .append(-1)")
                    .append("\t").append(graphSuffix);
            examples.add(serialized.toString());
        }
    }
}
Usage of gnu.trove.map.hash.TIntDoubleHashMap in project ProPPR by TeamCohen.
In class SRWTest, method myRWR (unweighted overload):
/**
 * Reference unweighted random walk: for {@code maxT} steps, each node's mass is split evenly
 * across its outgoing edges. Mass on a node with no outgoing edges is not propagated.
 *
 * @param startVec initial mass per node id; read but never modified
 * @param g graph whose adjacency is read via node_near_lo/node_near_hi and edge_dest
 * @param maxT number of propagation steps
 * @return the mass vector after {@code maxT} steps; {@code startVec} itself when
 *     {@code maxT <= 0} (previously {@code null})
 */
public TIntDoubleMap myRWR(TIntDoubleMap startVec, LearningGraph g, int maxT) {
    TIntDoubleMap vec = startVec;
    for (int t = 0; t < maxT; t++) {
        TIntDoubleMap nextVec = new TIntDoubleHashMap();
        for (int u : vec.keys()) {
            int degree = g.node_near_hi[u] - g.node_near_lo[u];
            // Every outgoing edge receives the same share; it is only used when degree > 0.
            double inc = vec.get(u) / degree;
            for (int eid = g.node_near_lo[u]; eid < g.node_near_hi[u]; eid++) {
                int v = g.edge_dest[eid];
                Dictionary.increment(nextVec, v, inc);
                log.debug("Incremented " + u + ", " + v + " by " + inc);
            }
        }
        vec = nextVec;
    }
    // Returning vec (not the last nextVec) makes a zero-step walk return the start vector.
    return vec;
}
Usage of gnu.trove.map.hash.TIntDoubleHashMap in project ProPPR by TeamCohen.
In class SRWTest, method testLearn1:
/**
 * Checks that one SRW training pass on the red/blue graph does not worsen the loss.
 * The walk starts from r0, the blue nodes are positives, and the red nodes are negatives.
 */
@Test
public void testLearn1() {
    TIntDoubleMap query = new TIntDoubleHashMap();
    query.put(nodes.getId("r0"), 1.0);

    int[] pos = new int[blues.size()];
    int posCount = 0;
    for (String blue : blues) {
        pos[posCount++] = nodes.getId(blue);
    }

    int[] neg = new int[reds.size()];
    int negCount = 0;
    for (String red : reds) {
        neg[negCount++] = nodes.getId(red);
    }

    PosNegRWExample example = factory.makeExample("learn1", brGraph, query, pos, neg);

    // Train on a copy of the uniform parameters (presumably so the shared fixture is left as-is).
    ParamVector<String, ?> trainedParams = uniformParams.copy();
    double preLoss = makeLoss(trainedParams, example);
    srw.clearLoss();
    srw.trainOnExample(trainedParams, example, new StatusLogger());
    double postLoss = makeLoss(trainedParams, example);

    // A zero starting loss passes trivially; otherwise training must strictly lower the loss.
    String message = String.format("preloss %f >=? postloss %f", preLoss, postLoss);
    boolean lossImproved = preLoss == 0 || preLoss > postLoss;
    assertTrue(message, lossImproved);
}
Usage of gnu.trove.map.hash.TIntDoubleHashMap in project ProPPR by TeamCohen.
In class SRWTest, method myRWR (overload weighted by parameters and a squashing function):
/**
 * Reference weighted random walk: for {@code maxT} steps, each node's mass is split across its
 * outgoing edges in proportion to the squashed edge weight.
 *
 * <p>An edge's raw score is the sum over its labels of
 * {@code params[featureSymbol] * labelWeight}, with {@code scheme.defaultValue()} used for
 * missing parameters. That score is passed through {@code scheme.edgeWeight}.
 *
 * @param startVec initial mass per node id; read but never modified
 * @param g graph supplying adjacency, edge labels and label weights
 * @param maxT number of propagation steps
 * @param params parameter values keyed by feature symbol
 * @param scheme squashing function mapping a raw edge score to an edge weight
 * @return the mass vector after {@code maxT} steps; {@code startVec} itself when
 *     {@code maxT <= 0} (previously {@code null})
 */
public TIntDoubleMap myRWR(TIntDoubleMap startVec, LearningGraph g, int maxT, ParamVector<String, ?> params, SquashingFunction scheme) {
    TIntDoubleMap vec = startVec;
    for (int t = 0; t < maxT; t++) {
        TIntDoubleMap nextVec = new TIntDoubleHashMap();
        for (int u : vec.keys()) {
            int lo = g.node_near_lo[u];
            int hi = g.node_near_hi[u];
            // Compute each edge's squashed weight once. It feeds both the normalizer z
            // and the increments, and was previously computed twice.
            double[] edgeWeights = new double[Math.max(0, hi - lo)];
            double z = 0.0;
            for (int eid = lo; eid < hi; eid++) {
                double suv = 0.0;
                for (int fid = g.edge_labels_lo[eid]; fid < g.edge_labels_hi[eid]; fid++) {
                    suv += Dictionary.safeGet(params, (g.featureLibrary.getSymbol(g.label_feature_id[fid])), scheme.defaultValue()) * g.label_feature_weight[fid];
                }
                double ew = scheme.edgeWeight(suv);
                edgeWeights[eid - lo] = ew;
                z += ew;
            }
            // NOTE(review): if every outgoing weight squashes to 0, z == 0 and inc is NaN;
            // left unchanged to mirror the original reference computation.
            for (int eid = lo; eid < hi; eid++) {
                int v = g.edge_dest[eid];
                double inc = vec.get(u) * edgeWeights[eid - lo] / z;
                Dictionary.increment(nextVec, v, inc);
                log.debug("Incremented " + u + ", " + v + " by " + inc);
            }
        }
        vec = nextVec;
    }
    // Returning vec (not the last nextVec) makes a zero-step walk return the start vector.
    return vec;
}
Usage of gnu.trove.map.hash.TIntDoubleHashMap in project ProPPR by TeamCohen.
In class PathDprProverTest, method foo:
/**
 * Pins down Trove's {@code adjustOrPutValue(key, adjustAmount, putAmount)} semantics.
 * A key that is present is adjusted by {@code adjustAmount}; a missing key is inserted with
 * {@code putAmount}.
 *
 * <p>The original version only printed the values and could never fail. It now throws
 * {@link AssertionError}, which JUnit reports as a test failure.
 */
@Test
public void foo() {
    TIntDoubleHashMap foo = new TIntDoubleHashMap();
    foo.put(1, 0.0);
    foo.adjustOrPutValue(1, 0.5, 0.5); // key 1 present: 0.0 + 0.5
    foo.adjustOrPutValue(2, 0.5, 0.5); // key 2 absent: inserted as 0.5
    double one = foo.get(1);
    double two = foo.get(2);
    if (one != 0.5) {
        throw new AssertionError("1: expected 0.5 but was " + one);
    }
    if (two != 0.5) {
        throw new AssertionError("2: expected 0.5 but was " + two);
    }
}
Aggregations