Usage of gnu.trove.map.hash.TIntDoubleHashMap in the ProPPR project (TeamCohen).
Class Dictionary, method increment.
/**
 * Adds {@code value} to the entry at ({@code key1}, {@code key2}) of a two-level sparse map,
 * or stores {@code value} there if the entry does not exist yet.
 *
 * <p>If {@code key1} has no inner map, a new {@link TIntDoubleHashMap} is created and
 * registered for it. Every value written is first passed through
 * {@code sanitize(..., key1, key2)}.
 *
 * @param map   outer map from {@code key1} to an inner map from {@code key2} to a value;
 *              modified in place
 * @param key1  outer key
 * @param key2  inner key
 * @param value amount to add to an existing entry, or the initial value of a new one.
 *              It is unboxed when the entry already exists, so it must not be null in that case.
 */
public static void increment(TIntObjectMap<TIntDoubleHashMap> map, int key1, int key2, Double value) {
    // One lookup on the outer map: Trove's get() returns null when key1 is absent.
    TIntDoubleHashMap inner = map.get(key1);
    if (inner == null) {
        inner = new TIntDoubleHashMap();
        map.put(key1, inner);
    }
    if (!inner.containsKey(key2)) {
        inner.put(key2, sanitize(value, key1, key2));
    } else {
        double newvalue = inner.get(key2) + value;
        inner.put(key2, sanitize(newvalue, key1, key2));
    }
}
Usage of gnu.trove.map.hash.TIntDoubleHashMap in the ProPPR project (TeamCohen).
Class RedBlueGraph, method colorPart.
/**
 * Restricts a sparse node vector to the nodes whose symbol is in the given color set.
 *
 * @param color symbols of the nodes to keep
 * @param vec   sparse vector keyed by node id; it is only read
 * @return a new map holding exactly those entries of {@code vec} whose node symbol
 *         (looked up through {@code nodes}) is contained in {@code color}
 */
public TIntDoubleMap colorPart(final Set<String> color, TIntDoubleMap vec) {
    TIntDoubleMap filtered = new TIntDoubleHashMap();
    for (int nodeId : vec.keys()) {
        if (!color.contains(nodes.getSymbol(nodeId))) {
            continue;
        }
        filtered.put(nodeId, vec.get(nodeId));
    }
    return filtered;
}
Usage of gnu.trove.map.hash.TIntDoubleHashMap in the ProPPR project (TeamCohen).
Class GradientFinderTest, method setup.
/**
 * Builds the test fixture.
 *
 * <p>The fixture consists of:
 * <ul>
 *   <li>an SRW learner with an L2 regularization schedule and a ReLU squashing function;</li>
 *   <li>a query vector with weight 1.0 on node {@code r0};</li>
 *   <li>{@code magicNumber * magicNumber} serialized examples over {@code brGraph}.</li>
 * </ul>
 *
 * <p>Each example is one tab-separated line with these fields:
 * <ol>
 *   <li>the literal {@code r0};</li>
 *   <li>the query node id;</li>
 *   <li>the positive node id ({@code b<k>});</li>
 *   <li>the negative node id ({@code r<p>});</li>
 *   <li>the node count and the edge count;</li>
 *   <li>the label-dependency count;</li>
 *   <li>the colon-joined feature names;</li>
 *   <li>one {@code u->v:feature@weight,...} field per edge.</li>
 * </ol>
 */
@Before
public void setup() {
    super.setup();
    this.srw = new SRW();
    this.srw.setRegularizer(new RegularizationSchedule(this.srw, new RegularizeL2()));
    this.srw.setSquashingFunction(new ReLU<String>());
    this.initTrainer();
    query = new TIntDoubleHashMap();
    query.put(nodes.getId("r0"), 1.0);

    // The graph part of each example (label-dependency count, feature names, edges) does not
    // depend on k or p, so build it once instead of magicNumber^2 times.
    int labelDependencies = 0;
    StringBuilder graph = new StringBuilder();
    for (int i = 0; i < brGraph.getFeatureSet().size(); i++) {
        if (i > 0) {
            graph.append(":");
        }
        graph.append(brGraph.featureLibrary.getSymbol(i + 1));
    }
    for (int u = 0; u < brGraph.node_hi; u++) {
        HashSet<Integer> outgoingFeatures = new HashSet<Integer>();
        for (int ec = brGraph.node_near_lo[u]; ec < brGraph.node_near_hi[u]; ec++) {
            int v = brGraph.edge_dest[ec];
            graph.append("\t").append(u).append("->").append(v).append(":");
            for (int lc = brGraph.edge_labels_lo[ec]; lc < brGraph.edge_labels_hi[ec]; lc++) {
                outgoingFeatures.add(brGraph.label_feature_id[lc]);
                if (lc > brGraph.edge_labels_lo[ec]) {
                    graph.append(",");
                }
                graph.append(brGraph.label_feature_id[lc]).append("@").append(brGraph.label_feature_weight[lc]);
            }
        }
        // Each distinct outgoing feature of u counts once per outgoing edge of u.
        labelDependencies += outgoingFeatures.size() * (brGraph.node_near_hi[u] - brGraph.node_near_lo[u]);
    }
    String graphPart = labelDependencies + "\t" + graph;

    examples = new ArrayList<String>();
    for (int k = 0; k < this.magicNumber; k++) {
        for (int p = 0; p < this.magicNumber; p++) {
            StringBuilder serialized = new StringBuilder("r0")
                    .append("\t").append(nodes.getId("r0"))     // query
                    .append("\t").append(nodes.getId("b" + k))  // pos
                    .append("\t").append(nodes.getId("r" + p))  // neg
                    .append("\t").append(brGraph.nodeSize())    // nodes
                    .append("\t").append(brGraph.edgeSize())    // edges
                    .append("\t").append(graphPart);            // label deps, features, edge list
            examples.add(serialized.toString());
        }
    }
}
Usage of gnu.trove.map.hash.TIntDoubleHashMap in the ProPPR project (TeamCohen).
Class SRWTest, method myRWR.
/**
 * Propagates a sparse node vector over {@code g} for {@code maxT} steps.
 *
 * <p>At each step, the mass of every node {@code u} in the current vector is split evenly
 * across u's outgoing edges. These are the edge ids {@code node_near_lo[u]} up to (but not
 * including) {@code node_near_hi[u]}, with targets in {@code edge_dest}. The shares are
 * summed into a fresh vector. Mass sitting on a node with no outgoing edges is not carried
 * into the next step.
 *
 * @param startVec initial vector keyed by node id; it is only read
 * @param g        graph whose adjacency arrays drive the propagation
 * @param maxT     number of propagation steps
 * @return the vector after {@code maxT} steps, or {@code startVec} itself when
 *         {@code maxT <= 0} (this method previously returned null in that case)
 */
public TIntDoubleMap myRWR(TIntDoubleMap startVec, LearningGraph g, int maxT) {
    TIntDoubleMap vec = startVec;
    for (int t = 0; t < maxT; t++) {
        TIntDoubleMap nextVec = new TIntDoubleHashMap();
        for (int u : vec.keys()) {
            int z = g.node_near_hi[u] - g.node_near_lo[u];
            if (z == 0) {
                // No outgoing edges: nothing to distribute (and avoid dividing by zero).
                continue;
            }
            // Each outgoing edge of u receives the same share, so compute it once.
            double inc = vec.get(u) / z;
            for (int eid = g.node_near_lo[u]; eid < g.node_near_hi[u]; eid++) {
                int v = g.edge_dest[eid];
                Dictionary.increment(nextVec, v, inc);
                if (log.isDebugEnabled()) {
                    log.debug("Incremented " + u + ", " + v + " by " + inc);
                }
            }
        }
        vec = nextVec;
    }
    return vec;
}
Usage of gnu.trove.map.hash.TIntDoubleHashMap in the ProPPR project (TeamCohen).
Class SRWTest, method testLearn1.
/**
 * Checks that one SRW training step on the red/blue graph lowers the loss.
 *
 * <p>The query starts at node {@code r0}; every blue node is a positive and every red node
 * is a negative. The loss must strictly drop after training, unless it was already zero.
 */
@Test
public void testLearn1() {
    TIntDoubleMap queryVec = new TIntDoubleHashMap();
    queryVec.put(nodes.getId("r0"), 1.0);

    // Positives: ids of all blue nodes, in iteration order of blues.
    int[] pos = new int[blues.size()];
    int posIdx = 0;
    for (String blue : blues) {
        pos[posIdx] = nodes.getId(blue);
        posIdx++;
    }

    // Negatives: ids of all red nodes, in iteration order of reds.
    int[] neg = new int[reds.size()];
    int negIdx = 0;
    for (String red : reds) {
        neg[negIdx] = nodes.getId(red);
        negIdx++;
    }

    PosNegRWExample example = factory.makeExample("learn1", brGraph, queryVec, pos, neg);

    // Train on a copy of the uniform starting parameters.
    ParamVector<String, ?> trainedParams = uniformParams.copy();
    double preLoss = makeLoss(trainedParams, example);
    srw.clearLoss();
    srw.trainOnExample(trainedParams, example, new StatusLogger());
    double postLoss = makeLoss(trainedParams, example);

    String message = String.format("preloss %f >=? postloss %f", preLoss, postLoss);
    boolean improved = preLoss == 0 || preLoss > postLoss;
    assertTrue(message, improved);
}
Aggregations