Search in sources:

Example 6 with TIntDoubleHashMap

Use of gnu.trove.map.hash.TIntDoubleHashMap in project ProPPR by TeamCohen.

Class GradientFinderTest, method setup.

/**
 * Configures an SRW with L2 regularization and ReLU squashing, initializes the trainer, and
 * builds {@code magicNumber * magicNumber} serialized training examples over {@code brGraph}.
 *
 * <p>Example (k, p) uses query node {@code r0}, positive node {@code b<k>} and negative node
 * {@code r<p>}. Each example is one tab-separated line:
 * query symbol, query id, pos id, neg id, node count, edge count, label-dependency count,
 * colon-separated feature symbols, then one {@code u->v:fid@weight,...} field per edge.
 *
 * <p>Everything from the label-dependency count onward depends only on {@code brGraph}, so it is
 * serialized once and shared by all examples.
 */
@Before
public void setup() {
    super.setup();
    this.srw = new SRW();
    this.srw.setRegularizer(new RegularizationSchedule(this.srw, new RegularizeL2()));
    this.srw.setSquashingFunction(new ReLU<String>());
    this.initTrainer();
    query = new TIntDoubleHashMap();
    query.put(nodes.getId("r0"), 1.0);
    // Loop-invariant: identical for every (k, p) pair.
    String graphSuffix = serializeBrGraphSuffix();
    examples = new ArrayList<String>();
    for (int k = 0; k < this.magicNumber; k++) {
        for (int p = 0; p < this.magicNumber; p++) {
            StringBuilder serialized = new StringBuilder("r0")
                .append("\t").append(nodes.getId("r0"))      // query
                .append("\t").append(nodes.getId("b" + k))   // pos
                .append("\t").append(nodes.getId("r" + p))   // neg
                .append("\t").append(brGraph.nodeSize())     // nodes
                .append("\t").append(brGraph.edgeSize())     // edges
                .append("\t")
                .append(graphSuffix);                        // label deps, features, edges
            examples.add(serialized.toString());
        }
    }
}

/**
 * Serializes the trailing fields of an example line for {@code brGraph}.
 *
 * <p>The fields are, tab-separated: the label-dependency count, the colon-separated feature
 * symbols, and one {@code u->v:fid@weight,...} field per edge.
 *
 * @return the serialized suffix (without a leading tab)
 */
private String serializeBrGraphSuffix() {
    int labelDependencies = 0;
    StringBuilder sb = new StringBuilder();
    for (int i = 0; i < brGraph.getFeatureSet().size(); i++) {
        if (i > 0) {
            sb.append(":");
        }
        // Symbol lookup is offset by one (feature ids presumably start at 1).
        sb.append(brGraph.featureLibrary.getSymbol(i + 1));
    }
    for (int u = 0; u < brGraph.node_hi; u++) {
        HashSet<Integer> outgoingFeatures = new HashSet<Integer>();
        for (int ec = brGraph.node_near_lo[u]; ec < brGraph.node_near_hi[u]; ec++) {
            int v = brGraph.edge_dest[ec];
            sb.append("\t").append(u).append("->").append(v).append(":");
            for (int lc = brGraph.edge_labels_lo[ec]; lc < brGraph.edge_labels_hi[ec]; lc++) {
                outgoingFeatures.add(brGraph.label_feature_id[lc]);
                if (lc > brGraph.edge_labels_lo[ec]) {
                    sb.append(",");
                }
                sb.append(brGraph.label_feature_id[lc]).append("@").append(brGraph.label_feature_weight[lc]);
            }
        }
        // Each distinct feature on u's out-edges counts once per out-edge of u.
        labelDependencies += outgoingFeatures.size() * (brGraph.node_near_hi[u] - brGraph.node_near_lo[u]);
    }
    return labelDependencies + "\t" + sb;
}
Also used : RegularizationSchedule(edu.cmu.ml.proppr.learn.RegularizationSchedule) TIntDoubleHashMap(gnu.trove.map.hash.TIntDoubleHashMap) RegularizeL2(edu.cmu.ml.proppr.learn.RegularizeL2) SRW(edu.cmu.ml.proppr.learn.SRW) HashSet(java.util.HashSet) Before(org.junit.Before)

Example 7 with TIntDoubleHashMap

Use of gnu.trove.map.hash.TIntDoubleHashMap in project ProPPR by TeamCohen.

Class SRWTest, method myRWR.

/**
 * Runs {@code maxT} steps of an unweighted walk over {@code g}.
 *
 * <p>At each step the mass on every node {@code u} is split evenly across u's outgoing edges
 * (edge ids {@code node_near_lo[u]} up to, but excluding, {@code node_near_hi[u]}). Mass on a
 * node with no outgoing edges contributes nothing to the next vector.
 *
 * @param startVec initial node-to-mass vector; only read, never modified
 * @param g graph to walk over
 * @param maxT number of propagation steps
 * @return the vector after {@code maxT} steps; when {@code maxT <= 0} this is {@code startVec}
 *     itself rather than {@code null}
 */
public TIntDoubleMap myRWR(TIntDoubleMap startVec, LearningGraph g, int maxT) {
    TIntDoubleMap vec = startVec;
    for (int t = 0; t < maxT; t++) {
        TIntDoubleMap nextVec = new TIntDoubleHashMap();
        for (int u : vec.keys()) {
            int lo = g.node_near_lo[u];
            int hi = g.node_near_hi[u];
            // Every out-edge of u gets the same share; compute it once per node.
            // (When u has no out-edges this is unused, so the division by zero is harmless.)
            double inc = vec.get(u) / (hi - lo);
            for (int eid = lo; eid < hi; eid++) {
                int v = g.edge_dest[eid];
                Dictionary.increment(nextVec, v, inc);
                log.debug("Incremented " + u + ", " + v + " by " + inc);
            }
        }
        vec = nextVec;
    }
    // Returning vec (not the last nextVec) makes maxT <= 0 yield the start vector, not null.
    return vec;
}
Also used : TIntDoubleHashMap(gnu.trove.map.hash.TIntDoubleHashMap) TIntDoubleMap(gnu.trove.map.TIntDoubleMap)

Example 8 with TIntDoubleHashMap

Use of gnu.trove.map.hash.TIntDoubleHashMap in project ProPPR by TeamCohen.

Class SRWTest, method testLearn1.

/**
 * Checks that one SRW training step on the red/blue graph does not make the loss worse.
 *
 * <p>The query puts all mass on node {@code r0}. Every blue node is a positive answer and every
 * red node a negative one. The test passes when the loss before training is exactly zero, or
 * when one call to {@code trainOnExample} strictly lowers it.
 */
@Test
public void testLearn1() {
    TIntDoubleMap query = new TIntDoubleHashMap();
    query.put(nodes.getId("r0"), 1.0);

    int[] pos = new int[blues.size()];
    int posIdx = 0;
    for (String blue : blues) {
        pos[posIdx] = nodes.getId(blue);
        posIdx++;
    }

    int[] neg = new int[reds.size()];
    int negIdx = 0;
    for (String red : reds) {
        neg[negIdx] = nodes.getId(red);
        negIdx++;
    }

    PosNegRWExample example = factory.makeExample("learn1", brGraph, query, pos, neg);

    // Train on a copy so the shared uniform parameters stay untouched.
    ParamVector<String, ?> trainedParams = uniformParams.copy();
    double preLoss = makeLoss(trainedParams, example);
    srw.clearLoss();
    srw.trainOnExample(trainedParams, example, new StatusLogger());
    double postLoss = makeLoss(trainedParams, example);

    boolean improved = preLoss == 0 || preLoss > postLoss;
    assertTrue(String.format("preloss %f >=? postloss %f", preLoss, postLoss), improved);
}
Also used : StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) TIntDoubleHashMap(gnu.trove.map.hash.TIntDoubleHashMap) TIntDoubleMap(gnu.trove.map.TIntDoubleMap) PosNegRWExample(edu.cmu.ml.proppr.examples.PosNegRWExample) Test(org.junit.Test)

Example 9 with TIntDoubleHashMap

Use of gnu.trove.map.hash.TIntDoubleHashMap in project ProPPR by TeamCohen.

Class SRWTest, method myRWR (weighted overload).

/**
 * Runs {@code maxT} steps of a feature-weighted walk over {@code g}.
 *
 * <p>For each edge, the raw score is the sum over its labels of
 * {@code params[featureSymbol] * labelWeight}. A feature missing from {@code params} uses
 * {@code scheme.defaultValue()}. The score is mapped through {@code scheme.edgeWeight}. Mass on
 * node {@code u} is then distributed over u's out-edges in proportion to those edge weights.
 *
 * @param startVec initial node-to-mass vector; only read, never modified
 * @param g graph to walk over
 * @param maxT number of propagation steps
 * @param params feature-name to weight mapping
 * @param scheme squashing function that turns a raw edge score into an edge weight
 * @return the vector after {@code maxT} steps; when {@code maxT <= 0} this is {@code startVec}
 *     itself rather than {@code null}
 */
public TIntDoubleMap myRWR(TIntDoubleMap startVec, LearningGraph g, int maxT, ParamVector<String, ?> params, SquashingFunction scheme) {
    TIntDoubleMap vec = startVec;
    for (int t = 0; t < maxT; t++) {
        TIntDoubleMap nextVec = new TIntDoubleHashMap();
        for (int u : vec.keys()) {
            int lo = g.node_near_lo[u];
            int hi = g.node_near_hi[u];
            // Compute each out-edge weight once; z (the normalizer) is their sum.
            double[] edgeWeights = new double[Math.max(0, hi - lo)];
            double z = 0.0;
            for (int eid = lo; eid < hi; eid++) {
                double suv = 0.0;
                for (int fid = g.edge_labels_lo[eid]; fid < g.edge_labels_hi[eid]; fid++) {
                    suv += Dictionary.safeGet(params, (g.featureLibrary.getSymbol(g.label_feature_id[fid])), scheme.defaultValue()) * g.label_feature_weight[fid];
                }
                double ew = scheme.edgeWeight(suv);
                edgeWeights[eid - lo] = ew;
                z += ew;
            }
            // NOTE(review): if u has out-edges but all their weights are 0, z is 0 and inc is NaN.
            // Left unchanged so this reference walk keeps its original arithmetic; confirm
            // against the production SRW behavior.
            double mass = vec.get(u);
            for (int eid = lo; eid < hi; eid++) {
                int v = g.edge_dest[eid];
                double inc = mass * edgeWeights[eid - lo] / z;
                Dictionary.increment(nextVec, v, inc);
                log.debug("Incremented " + u + ", " + v + " by " + inc);
            }
        }
        vec = nextVec;
    }
    // Returning vec (not the last nextVec) makes maxT <= 0 yield the start vector, not null.
    return vec;
}
Also used : TIntDoubleHashMap(gnu.trove.map.hash.TIntDoubleHashMap) TIntDoubleMap(gnu.trove.map.TIntDoubleMap)

Example 10 with TIntDoubleHashMap

Use of gnu.trove.map.hash.TIntDoubleHashMap in project ProPPR by TeamCohen.

Class PathDprProverTest, method foo.

/**
 * Sanity-checks Trove's {@code adjustOrPutValue(key, adjustAmount, putAmount)} on a
 * {@link TIntDoubleHashMap}.
 *
 * <p>Key 1 is already present (value 0.0), so its value is adjusted by 0.5. Key 2 is absent,
 * so the put amount 0.5 is stored. Both keys must therefore map to 0.5. The values are still
 * printed for visibility, and a mismatch now fails the test instead of passing silently.
 */
@Test
public void foo() {
    TIntDoubleHashMap foo = new TIntDoubleHashMap();
    foo.put(1, 0.0);
    foo.adjustOrPutValue(1, 0.5, 0.5);
    foo.adjustOrPutValue(2, 0.5, 0.5);
    System.out.println("1: " + foo.get(1));
    System.out.println("2: " + foo.get(2));
    // Exact comparison is safe: 0.0 + 0.5 and the literal 0.5 are both exactly representable.
    if (foo.get(1) != 0.5) {
        throw new AssertionError("key 1: expected 0.0 adjusted by 0.5 = 0.5, got " + foo.get(1));
    }
    if (foo.get(2) != 0.5) {
        throw new AssertionError("key 2: expected put value 0.5, got " + foo.get(2));
    }
}
Also used : TIntDoubleHashMap(gnu.trove.map.hash.TIntDoubleHashMap) Test(org.junit.Test) GrounderTest(edu.cmu.ml.proppr.GrounderTest) SparseGraphPluginTest(edu.cmu.ml.proppr.prove.wam.plugins.SparseGraphPluginTest)

Aggregations

TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap)18 TIntDoubleMap (gnu.trove.map.TIntDoubleMap)9 PosNegRWExample (edu.cmu.ml.proppr.examples.PosNegRWExample)3 PprExample (edu.cmu.ml.proppr.examples.PprExample)2 RegularizationSchedule (edu.cmu.ml.proppr.learn.RegularizationSchedule)2 RegularizeL2 (edu.cmu.ml.proppr.learn.RegularizeL2)2 SRW (edu.cmu.ml.proppr.learn.SRW)2 Feature (edu.cmu.ml.proppr.prove.wam.Feature)2 TIntDoubleIterator (gnu.trove.iterator.TIntDoubleIterator)2 TIntObjectHashMap (gnu.trove.map.hash.TIntObjectHashMap)2 TIntDoubleProcedure (gnu.trove.procedure.TIntDoubleProcedure)2 HashMap (java.util.HashMap)2 Before (org.junit.Before)2 Test (org.junit.Test)2 GrounderTest (edu.cmu.ml.proppr.GrounderTest)1 GraphFormatException (edu.cmu.ml.proppr.graph.GraphFormatException)1 LearningGraph (edu.cmu.ml.proppr.graph.LearningGraph)1 Outlink (edu.cmu.ml.proppr.prove.wam.Outlink)1 SparseGraphPluginTest (edu.cmu.ml.proppr.prove.wam.plugins.SparseGraphPluginTest)1 StatusLogger (edu.cmu.ml.proppr.util.StatusLogger)1