Search in sources:

Example 11 with TIntDoubleHashMap

Use of gnu.trove.map.hash.TIntDoubleHashMap in project cogcomp-nlp by CogComp.

Class FeatureUtilities, method convert.

/**
 * Convert a feature set into a pair of arrays of integers and doubles by looking up the feature
 * name in the provided lexicon.
 *
 * <p>Each feature is mapped to an integer ID via {@code FeatureUtilities.getFeatureId}. Features
 * whose ID is negative are skipped. Features that map to the same ID have their values summed.
 * The returned ID array is sorted in ascending order, and the value array is aligned with it
 * index-by-index.
 *
 * @param features The feature set
 * @param lexicon The lexicon
 * @param trainingMode Should an unseen feature string be added to the lexicon? If this is
 *        false, unseen features will be given an ID whose value is one more than the number of
 *        features (as decided by {@code getFeatureId}; any negative ID it returns is dropped here).
 * @return a pair of int[] and double[], representing the sorted feature ids and their summed values.
 */
public static Pair<int[], double[]> convert(Set<Feature> features, Lexicon lexicon, boolean trainingMode) {
    TIntDoubleHashMap fMap = new TIntDoubleHashMap(features.size());
    for (Feature feature : features) {
        final int featureId = FeatureUtilities.getFeatureId(lexicon, trainingMode, feature);
        if (featureId < 0) {
            // A negative ID means "no usable feature": it contributes nothing to the output.
            continue;
        }
        // Accumulate values of features that share an ID; insert the value if the ID is new.
        double value = feature.getValue();
        fMap.adjustOrPutValue(featureId, value, value);
    }
    // keys() already returns a freshly allocated array, so it is safe to sort it in place.
    int[] ids = fMap.keys();
    Arrays.sort(ids);
    double[] vals = new double[ids.length];
    for (int i = 0; i < ids.length; i++) {
        vals[i] = fMap.get(ids[i]);
    }
    return new Pair<>(ids, vals);
}
Also used: TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap), RealPrimitiveFeature (edu.illinois.cs.cogcomp.lbjava.classify.RealPrimitiveFeature), DiscretePrimitiveFeature (edu.illinois.cs.cogcomp.lbjava.classify.DiscretePrimitiveFeature), Pair (edu.illinois.cs.cogcomp.core.datastructures.Pair)

Example 12 with TIntDoubleHashMap

Use of gnu.trove.map.hash.TIntDoubleHashMap in project ProPPR by TeamCohen.

Class SRWTest, method moreSetup.

@Override
public void moreSetup(LearningGraphBuilder lgb) {
    // Parameter vector backed by a ConcurrentHashMap, prepared by the learner's regularizer.
    uniformParams = srw.getRegularizer().setupParams(
            new SimpleParamVector<String>(new ConcurrentHashMap<String, Double>()));
    // Every named parameter starts at the squashing function's default value.
    final String[] paramNames = { "fromb", "tob", "fromr", "tor" };
    for (String paramName : paramNames) {
        uniformParams.put(paramName, srw.getSquashingFunction().defaultValue());
    }
    // Start vector: all weight (1.0) on node "r0".
    startVec = new TIntDoubleHashMap();
    int startNode = nodes.getId("r0");
    startVec.put(startNode, 1.0);
}
Also used: TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap), SimpleParamVector (edu.cmu.ml.proppr.util.math.SimpleParamVector)

Example 13 with TIntDoubleHashMap

Use of gnu.trove.map.hash.TIntDoubleHashMap in project ProPPR by TeamCohen.

Class TrainerTest, method setup.

@Before
public void setup() {
    super.setup();
    // Learner under test: L2-regularized SRW with a ReLU squashing function.
    SRW learner = new SRW();
    this.srw = learner;
    learner.setRegularizer(new RegularizationSchedule(learner, new RegularizeL2()));
    learner.setSquashingFunction(new ReLU<String>());
    this.initTrainer();
    // Shared query vector: weight 1.0 on node "r0".
    query = new TIntDoubleHashMap();
    int queryNode = nodes.getId("r0");
    query.put(queryNode, 1.0);
    // One serialized example per (b_i, r_j) pair; the b node is passed in the first id array and
    // the r node in the second (presumably positives/negatives, per PosNegRWExample's name).
    examples = new ArrayList<String>();
    for (int bIdx = 0; bIdx < this.magicNumber; bIdx++) {
        for (int rIdx = 0; rIdx < this.magicNumber; rIdx++) {
            int[] posIds = { nodes.getId("b" + bIdx) };
            int[] negIds = { nodes.getId("r" + rIdx) };
            PosNegRWExample example = new PosNegRWExample(brGraph, query, posIds, negIds);
            examples.add(example.serialize());
        }
    }
}
Also used: RegularizationSchedule (edu.cmu.ml.proppr.learn.RegularizationSchedule), TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap), RegularizeL2 (edu.cmu.ml.proppr.learn.RegularizeL2), PosNegRWExample (edu.cmu.ml.proppr.examples.PosNegRWExample), SRW (edu.cmu.ml.proppr.learn.SRW), Before (org.junit.Before)

Example 14 with TIntDoubleHashMap

Use of gnu.trove.map.hash.TIntDoubleHashMap in project ProPPR by TeamCohen.

Class FiniteDifferenceTest, method moreSetup.

@Override
public void moreSetup(LearningGraphBuilder lgb) {
    // Start vector: all weight (1.0) on node "r0".
    startVec = new TIntDoubleHashMap();
    int seedNode = nodes.getId("r0");
    startVec.put(seedNode, 1.0);
}
Also used: TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap)

Example 15 with TIntDoubleHashMap

Use of gnu.trove.map.hash.TIntDoubleHashMap in project ProPPR by TeamCohen.

Class SRW, method inferenceUpdate.

/**
 * Performs one iteration step (t to t+1) of the walk on the example's graph. It updates the node
 * scores {@code ex.p} and their per-feature derivatives {@code ex.dp}, replacing both arrays on
 * the example.
 *
 * <p>For each node u and each outgoing edge u-&gt;v:
 * <ul>
 * <li>{@code p_v += (1 - alpha) * p_u * M_uv}, plus {@code alpha * s_u} restart mass at u.</li>
 * <li>{@code d_vi += (1 - alpha) * (p_u * dM_uvi + d_ui * M_uv)}.</li>
 * </ul>
 *
 * @param example must be a {@code PprExample} (it is cast unconditionally)
 * @param status throttles the progress logging
 * @throws IllegalStateException if an edge points at a node id outside {@code [0, node_hi)}
 */
protected void inferenceUpdate(PosNegRWExample example, StatusLogger status) {
    PprExample ex = (PprExample) example;
    final int nodeCount = ex.getGraph().node_hi;
    // Share of mass that follows an edge instead of restarting at the query vector: (1 - alpha).
    final double walkFactor = 1 - c.apr.alpha;
    double[] pNext = new double[nodeCount];
    TIntDoubleMap[] dNext = new TIntDoubleMap[nodeCount];
    // p: 2. for each node u
    for (int uid = 0; uid < nodeCount; uid++) {
        if (log.isInfoEnabled() && status.due(4)) {
            log.info("Inference: node " + (uid + 1) + " of " + nodeCount);
        }
        // p: 2(a) p_u^{t+1} += alpha * s_u
        pNext[uid] += c.apr.alpha * Dictionary.safeGet(ex.getQueryVec(), uid, 0.0);
        // p: 2(b) for each neighbor v of u (xvi is the edge's local index within u's adjacency):
        for (int eid = ex.getGraph().node_near_lo[uid], xvi = 0; eid < ex.getGraph().node_near_hi[uid]; eid++, xvi++) {
            int vid = ex.getGraph().edge_dest[eid];
            if (vid >= pNext.length) {
                throw new IllegalStateException("vid=" + vid + " >= pNext.length=" + pNext.length);
            }
            // p: 2(b)i. p_v^{t+1} += (1-alpha) * p_u^t * M_uv
            pNext[vid] += walkFactor * ex.p[uid] * ex.M[uid][xvi];
            // d: i. for each feature i in dM_uv:
            if (dNext[vid] == null) {
                dNext[vid] = new TIntDoubleHashMap(ex.dM_hi[uid][xvi] - ex.dM_lo[uid][xvi]);
            }
            for (int dmi = ex.dM_lo[uid][xvi]; dmi < ex.dM_hi[uid][xvi]; dmi++) {
                // d_vi^{t+1} += (1-alpha) * p_u^{t} * dM_uvi
                if (ex.dM_value[dmi] == 0) {
                    continue;
                }
                double inc = walkFactor * ex.p[uid] * ex.dM_value[dmi];
                dNext[vid].adjustOrPutValue(ex.dM_feature_id[dmi], inc, inc);
            }
            // ex.dp[uid] is null when u has no derivative map (e.g. it received no edge in the
            // previous step), so there is nothing of d_u to propagate along this edge.
            if (ex.dp[uid] == null) {
                continue;
            }
            for (TIntDoubleIterator it = ex.dp[uid].iterator(); it.hasNext(); ) {
                it.advance();
                if (it.value() == 0) {
                    continue;
                }
                // d_vi^{t+1} += (1-alpha) * d_ui^t * M_uv
                double inc = walkFactor * it.value() * ex.M[uid][xvi];
                dNext[vid].adjustOrPutValue(it.key(), inc, inc);
            }
        }
    }
    // Sanity check on p (only when debug logging is on): the scores should sum to ~1.
    if (log.isDebugEnabled()) {
        double sum = 0;
        for (double d : pNext) {
            sum += d;
        }
        if (Math.abs(sum - 1.0) > c.apr.epsilon) {
            log.error("invalid p computed: " + sum);
        }
    }
    ex.p = pNext;
    ex.dp = dNext;
}
Also used: TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap), TIntDoubleMap (gnu.trove.map.TIntDoubleMap), TIntDoubleIterator (gnu.trove.iterator.TIntDoubleIterator), PprExample (edu.cmu.ml.proppr.examples.PprExample)

Aggregations

TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap): 18; TIntDoubleMap (gnu.trove.map.TIntDoubleMap): 9; PosNegRWExample (edu.cmu.ml.proppr.examples.PosNegRWExample): 3; PprExample (edu.cmu.ml.proppr.examples.PprExample): 2; RegularizationSchedule (edu.cmu.ml.proppr.learn.RegularizationSchedule): 2; RegularizeL2 (edu.cmu.ml.proppr.learn.RegularizeL2): 2; SRW (edu.cmu.ml.proppr.learn.SRW): 2; Feature (edu.cmu.ml.proppr.prove.wam.Feature): 2; TIntDoubleIterator (gnu.trove.iterator.TIntDoubleIterator): 2; TIntObjectHashMap (gnu.trove.map.hash.TIntObjectHashMap): 2; TIntDoubleProcedure (gnu.trove.procedure.TIntDoubleProcedure): 2; HashMap (java.util.HashMap): 2; Before (org.junit.Before): 2; Test (org.junit.Test): 2; GrounderTest (edu.cmu.ml.proppr.GrounderTest): 1; GraphFormatException (edu.cmu.ml.proppr.graph.GraphFormatException): 1; LearningGraph (edu.cmu.ml.proppr.graph.LearningGraph): 1; Outlink (edu.cmu.ml.proppr.prove.wam.Outlink): 1; SparseGraphPluginTest (edu.cmu.ml.proppr.prove.wam.plugins.SparseGraphPluginTest): 1; StatusLogger (edu.cmu.ml.proppr.util.StatusLogger): 1