Search in sources:

Example 1 with RegularizeL2

Use of `edu.cmu.ml.proppr.learn.RegularizeL2` in project ProPPR by TeamCohen.

Class `GradientFinderTest`, method `setup`:

/**
 * Test fixture: configures an {@link SRW} with an L2 {@link RegularizationSchedule} and a
 * ReLU squashing function, initializes the trainer, seeds {@code query} with weight 1.0 on
 * node "r0", and fills {@code examples} with {@code magicNumber * magicNumber} hand-serialized
 * training examples over {@code brGraph}.
 *
 * <p>Each example line is tab-separated: the literal "r0", the query node id, one positive
 * node id ("b" + k), one negative node id ("r" + p), followed by a graph section that depends
 * only on {@code brGraph}. Because that graph section is identical for every example, it is
 * serialized once up front instead of once per (k, p) pair.
 */
@Before
public void setup() {
    super.setup();
    this.srw = new SRW();
    this.srw.setRegularizer(new RegularizationSchedule(this.srw, new RegularizeL2()));
    this.srw.setSquashingFunction(new ReLU<String>());
    this.initTrainer();
    query = new TIntDoubleHashMap();
    query.put(nodes.getId("r0"), 1.0);

    // Loop-invariant: depends only on brGraph, not on k or p.
    String graphSection = serializeGraphSection();

    examples = new ArrayList<String>();
    for (int k = 0; k < this.magicNumber; k++) {
        for (int p = 0; p < this.magicNumber; p++) {
            StringBuilder serialized = new StringBuilder("r0");
            serialized.append("\t").append(nodes.getId("r0"));    // query
            serialized.append("\t").append(nodes.getId("b" + k)); // pos
            serialized.append("\t").append(nodes.getId("r" + p)); // neg
            serialized.append("\t").append(graphSection);         // nodes, edges, label deps, features, edge list
            examples.add(serialized.toString());
        }
    }
}

/**
 * Serializes the {@code brGraph}-dependent tail of an example line.
 *
 * <p>The layout is tab-separated: node count, edge count, label-dependency count, then a
 * colon-separated feature-symbol list. One {@code "u->v:f@w,f@w,..."} entry follows for each
 * edge, with a tab before each entry. The label-dependency count sums, over every source node
 * {@code u}, the number of distinct feature ids on edges leaving {@code u} times the number of
 * such edges.
 *
 * @return the serialized graph section (no leading tab)
 */
private String serializeGraphSection() {
    StringBuilder body = new StringBuilder();

    // Feature symbols are fetched at i + 1 (presumably a 1-based symbol table — confirm).
    for (int i = 0; i < brGraph.getFeatureSet().size(); i++) {
        if (i > 0) {
            body.append(":");
        }
        body.append(brGraph.featureLibrary.getSymbol(i + 1));
    }

    int labelDependencies = 0;
    for (int u = 0; u < brGraph.node_hi; u++) {
        HashSet<Integer> outgoingFeatures = new HashSet<Integer>();
        for (int ec = brGraph.node_near_lo[u]; ec < brGraph.node_near_hi[u]; ec++) {
            int v = brGraph.edge_dest[ec];
            body.append("\t").append(u).append("->").append(v).append(":");
            for (int lc = brGraph.edge_labels_lo[ec]; lc < brGraph.edge_labels_hi[ec]; lc++) {
                outgoingFeatures.add(brGraph.label_feature_id[lc]);
                if (lc > brGraph.edge_labels_lo[ec]) {
                    body.append(",");
                }
                body.append(brGraph.label_feature_id[lc]).append("@").append(brGraph.label_feature_weight[lc]);
            }
        }
        labelDependencies += outgoingFeatures.size() * (brGraph.node_near_hi[u] - brGraph.node_near_lo[u]);
    }

    return new StringBuilder()
            .append(brGraph.nodeSize()).append("\t")
            .append(brGraph.edgeSize()).append("\t")
            .append(labelDependencies).append("\t")
            .append(body)
            .toString();
}
Also used: RegularizationSchedule (edu.cmu.ml.proppr.learn.RegularizationSchedule), TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap), RegularizeL2 (edu.cmu.ml.proppr.learn.RegularizeL2), SRW (edu.cmu.ml.proppr.learn.SRW), HashSet (java.util.HashSet), Before (org.junit.Before)

Example 2 with RegularizeL2

Use of `edu.cmu.ml.proppr.learn.RegularizeL2` in project ProPPR by TeamCohen.

Class `TrainerTest`, method `setup`:

/**
 * Per-test fixture for the trainer tests.
 *
 * <p>Builds an {@link SRW} that uses an L2 {@link RegularizationSchedule} and a ReLU squashing
 * function, then initializes the trainer. It seeds {@code query} with weight 1.0 on node "r0".
 * Finally it serializes one {@link PosNegRWExample} over {@code brGraph} for every pair of
 * positive node "b" + posIndex and negative node "r" + negIndex, with both indices in
 * [0, magicNumber).
 */
@Before
public void setup() {
    super.setup();

    // Configure the walker fully before publishing it and initializing the trainer.
    SRW walker = new SRW();
    walker.setRegularizer(new RegularizationSchedule(walker, new RegularizeL2()));
    walker.setSquashingFunction(new ReLU<String>());
    this.srw = walker;
    this.initTrainer();

    this.query = new TIntDoubleHashMap();
    this.query.put(nodes.getId("r0"), 1.0);

    this.examples = new ArrayList<String>();
    for (int posIndex = 0; posIndex < this.magicNumber; posIndex++) {
        for (int negIndex = 0; negIndex < this.magicNumber; negIndex++) {
            int[] positives = { nodes.getId("b" + posIndex) };
            int[] negatives = { nodes.getId("r" + negIndex) };
            PosNegRWExample example = new PosNegRWExample(brGraph, query, positives, negatives);
            this.examples.add(example.serialize());
        }
    }
}
Also used: RegularizationSchedule (edu.cmu.ml.proppr.learn.RegularizationSchedule), TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap), RegularizeL2 (edu.cmu.ml.proppr.learn.RegularizeL2), PosNegRWExample (edu.cmu.ml.proppr.examples.PosNegRWExample), SRW (edu.cmu.ml.proppr.learn.SRW), Before (org.junit.Before)

Aggregations

RegularizationSchedule (edu.cmu.ml.proppr.learn.RegularizationSchedule): 2; RegularizeL2 (edu.cmu.ml.proppr.learn.RegularizeL2): 2; SRW (edu.cmu.ml.proppr.learn.SRW): 2; TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap): 2; Before (org.junit.Before): 2; PosNegRWExample (edu.cmu.ml.proppr.examples.PosNegRWExample): 1; HashSet (java.util.HashSet): 1