Search in sources:

Example 6 with TIntDoubleIterator

use of gnu.trove.iterator.TIntDoubleIterator in project ProPPR by TeamCohen.

Class ParamsFile, method save.

/**
 * Writes a parameter map to {@code paramsFile}: an optional header (only when {@code config}
 * is non-null) followed by one entry per map element, keyed by the decimal string of its int key.
 *
 * <p>An {@code IOException} raised while opening or writing the file is reported via
 * {@code printStackTrace()} and otherwise swallowed, as before. The writer is now closed on
 * every path, including when writing fails part-way through.
 *
 * @param params     parameter values keyed by int id
 * @param paramsFile destination file (opened with {@code FileWriter}, so existing content is replaced)
 * @param config     configuration passed to the header writer; {@code null} omits the header
 */
public static void save(TIntDoubleMap params, File paramsFile, ModuleConfiguration config) {
    // try-with-resources guarantees close(); previously an exception from a write call
    // jumped straight to the catch block and leaked the open file handle.
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(paramsFile))) {
        // write header
        if (config != null)
            saveHeader(writer, config);
        // write params
        for (TIntDoubleIterator it = params.iterator(); it.hasNext(); ) {
            it.advance();
            saveParameter(writer, String.valueOf(it.key()), it.value());
        }
    } catch (IOException e) {
        // NOTE(review): failure is only printed, so callers cannot detect a failed/partial save.
        e.printStackTrace();
    }
}
Also used : FileWriter(java.io.FileWriter) TIntDoubleIterator(gnu.trove.iterator.TIntDoubleIterator) IOException(java.io.IOException) BufferedWriter(java.io.BufferedWriter)

Example 7 with TIntDoubleIterator

use of gnu.trove.iterator.TIntDoubleIterator in project ProPPR by TeamCohen.

Class Dictionary, method buildString.

/**
 * Appends every key:value pair of {@code map} to {@code sb}, writing {@code delim} before
 * each pair (including the first). The text appended is:
 *    $delim$key1:$value1$delim$key2:$value2 ... $delim$keyN:$valueN
 * Pairs are emitted in the map's iteration order.
 * @param map   the int-to-double map to serialize
 * @param sb    the builder to append to; it is modified in place
 * @param delim separator written before each pair
 * @return the same {@code sb} instance, to allow call chaining
 */
public static StringBuilder buildString(TIntDoubleMap map, StringBuilder sb, String delim) {
    TIntDoubleIterator entry = map.iterator();
    while (entry.hasNext()) {
        entry.advance();
        sb.append(delim);
        sb.append(entry.key());
        sb.append(":");
        sb.append(entry.value());
    }
    return sb;
}
Also used : TIntDoubleIterator(gnu.trove.iterator.TIntDoubleIterator)

Example 8 with TIntDoubleIterator

use of gnu.trove.iterator.TIntDoubleIterator in project ProPPR by TeamCohen.

Class SRW, method inferenceUpdate.

/**
 * Performs one synchronous update of the node scores {@code p} and their per-feature
 * derivatives {@code dp}, then replaces {@code ex.p} and {@code ex.dp} with the new arrays.
 *
 * <p>For every node u in {@code 0 .. node_hi-1}:
 * <ul>
 *   <li>p_u^{t+1} += alpha * s_u, where s is the query vector (0 for absent entries);</li>
 *   <li>for each neighbor v of u (edge ids {@code node_near_lo[u] .. node_near_hi[u]-1},
 *       with {@code xvi} the edge's offset inside that range):
 *       p_v^{t+1} += (1-alpha) * p_u^t * M_uv; d_v^{t+1} receives
 *       (1-alpha) * p_u^t * dM_uvi for each nonzero edge-feature entry, and
 *       (1-alpha) * d_ui^t * M_uv for each nonzero entry of d_u^t.</li>
 * </ul>
 * Nodes that are never an edge destination keep a {@code null} entry in the new {@code dp}.
 *
 * @param example the example to update; it is cast unconditionally to {@code PprExample}
 * @param status  progress logger; an info line is logged when {@code status.due(4)} is true
 * @throws IllegalStateException if an edge destination index is not below {@code node_hi}
 */
protected void inferenceUpdate(PosNegRWExample example, StatusLogger status) {
    PprExample ex = (PprExample) example;
    int nodeHi = ex.getGraph().node_hi;
    double[] pNext = new double[nodeHi];
    TIntDoubleMap[] dNext = new TIntDoubleMap[nodeHi];
    // (1 - alpha) does not change during the update, so compute it once.
    double oneMinusAlpha = 1 - c.apr.alpha;
    // p: 2. for each node u
    for (int uid = 0; uid < nodeHi; uid++) {
        if (log.isInfoEnabled() && status.due(4))
            log.info("Inference: node " + (uid + 1) + " of " + (nodeHi));
        // p: 2(a) p_u^{t+1} += alpha * s_u
        pNext[uid] += c.apr.alpha * Dictionary.safeGet(ex.getQueryVec(), uid, 0.0);
        // p: 2(b) for each neighbor v of u; xvi indexes the edge within u's row of M / dM
        for (int eid = ex.getGraph().node_near_lo[uid], xvi = 0; eid < ex.getGraph().node_near_hi[uid]; eid++, xvi++) {
            int vid = ex.getGraph().edge_dest[eid];
            // Message now matches the condition (the guard also fires when vid == length).
            if (vid >= pNext.length) {
                throw new IllegalStateException("vid=" + vid + " >= pNext.length=" + pNext.length);
            }
            // p: 2(b)i. p_v^{t+1} += (1-alpha) * p_u^t * M_uv
            pNext[vid] += oneMinusAlpha * ex.p[uid] * ex.M[uid][xvi];
            // d: i. for each feature i in dM_uv:
            // dM_lo/dM_hi delimit this edge's slice of dM_value / dM_feature_id
            if (dNext[vid] == null)
                dNext[vid] = new TIntDoubleHashMap(ex.dM_hi[uid][xvi] - ex.dM_lo[uid][xvi]);
            for (int dmi = ex.dM_lo[uid][xvi]; dmi < ex.dM_hi[uid][xvi]; dmi++) {
                // d_vi^{t+1} += (1-alpha) * p_u^{t} * dM_uvi
                if (ex.dM_value[dmi] == 0)
                    continue;
                double inc = oneMinusAlpha * ex.p[uid] * ex.dM_value[dmi];
                dNext[vid].adjustOrPutValue(ex.dM_feature_id[dmi], inc, inc);
            }
            // skip when d is empty (u had no derivative map from the previous step)
            if (ex.dp[uid] == null)
                continue;
            for (TIntDoubleIterator it = ex.dp[uid].iterator(); it.hasNext(); ) {
                it.advance();
                if (it.value() == 0)
                    continue;
                // d_vi^{t+1} += (1-alpha) * d_ui^t * M_uv
                double inc = oneMinusAlpha * it.value() * ex.M[uid][xvi];
                dNext[vid].adjustOrPutValue(it.key(), inc, inc);
            }
        }
    }
    // sanity check on p: only evaluated when debug logging is on; p should sum to ~1
    if (log.isDebugEnabled()) {
        double sum = 0;
        for (double d : pNext) sum += d;
        if (Math.abs(sum - 1.0) > c.apr.epsilon)
            log.error("invalid p computed: " + sum);
    }
    ex.p = pNext;
    ex.dp = dNext;
}
Also used : TIntDoubleHashMap(gnu.trove.map.hash.TIntDoubleHashMap) TIntDoubleMap(gnu.trove.map.TIntDoubleMap) TIntDoubleIterator(gnu.trove.iterator.TIntDoubleIterator) PprExample(edu.cmu.ml.proppr.examples.PprExample)

Example 9 with TIntDoubleIterator

use of gnu.trove.iterator.TIntDoubleIterator in project ProPPR by TeamCohen.

Class SRW, method sgd.

/**
 * Applies one stochastic-gradient step to {@code params} (edited in place).
 * The gradient for {@code ex} is computed once. Then, for every nonzero entry whose feature
 * symbol is trainable, learningRate(feature) * gradient is subtracted from that parameter.
 * A warning is logged if the adjusted parameter becomes infinite.
 */
protected void sgd(ParamVector<String, ?> params, PosNegRWExample ex) {
    TIntDoubleMap grad = gradient(params, ex);
    TIntDoubleIterator it = grad.iterator();
    while (it.hasNext()) {
        it.advance();
        double g = it.value();
        // zero entries would not move the parameter; skip them
        if (g == 0)
            continue;
        String feature = ex.getGraph().featureLibrary.getSymbol(it.key());
        if (!trainable(feature))
            continue;
        params.adjustValue(feature, -learningRate(feature) * g);
        if (params.get(feature).isInfinite())
            log.warn("Infinity at " + feature + "; gradient " + g);
    }
}
Also used : TIntDoubleMap(gnu.trove.map.TIntDoubleMap) TIntDoubleIterator(gnu.trove.iterator.TIntDoubleIterator)

Example 10 with TIntDoubleIterator

use of gnu.trove.iterator.TIntDoubleIterator in project ProPPR by TeamCohen.

Class SRW, method inference.

/**
 * Fills {@code p} and {@code dp} on the example. {@code p} is initialized from the query
 * vector and {@code dp} starts with all-null entries. Then {@code inferenceUpdate} is applied
 * {@code c.apr.maxDepth} times.
 *
 * @param params  parameter vector; not read by this method (NOTE(review): presumably kept
 *                for subclasses/overrides — confirm)
 * @param example example whose {@code p} and {@code dp} fields are (re)allocated and filled
 * @param status  progress logger; an info line is logged when {@code status.due(3)} is true
 */
protected void inference(ParamVector<String, ?> params, PosNegRWExample example, StatusLogger status) {
    // The former self-cast was redundant, and so was Arrays.fill(ex.p, 0.0):
    // a freshly allocated double[] is already all zeros.
    PosNegRWExample ex = example;
    int nodeHi = ex.getGraph().node_hi;
    ex.p = new double[nodeHi];
    ex.dp = new TIntDoubleMap[nodeHi];
    // copy query into p
    for (TIntDoubleIterator it = ex.getQueryVec().iterator(); it.hasNext(); ) {
        it.advance();
        ex.p[it.key()] = it.value();
    }
    for (int i = 0; i < c.apr.maxDepth; i++) {
        if (log.isInfoEnabled() && status.due(3))
            log.info("APR: iter " + (i + 1) + " of " + (c.apr.maxDepth));
        inferenceUpdate(ex, status);
    }
}
Also used : PosNegRWExample(edu.cmu.ml.proppr.examples.PosNegRWExample) TIntDoubleIterator(gnu.trove.iterator.TIntDoubleIterator)

Aggregations

- TIntDoubleIterator (gnu.trove.iterator.TIntDoubleIterator): 13
- TIntDoubleMap (gnu.trove.map.TIntDoubleMap): 6
- PosNegRWExample (edu.cmu.ml.proppr.examples.PosNegRWExample): 3
- TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap): 3
- PprExample (edu.cmu.ml.proppr.examples.PprExample): 2
- DprExample (edu.cmu.ml.proppr.examples.DprExample): 1
- Feature (edu.cmu.ml.proppr.prove.wam.Feature): 1
- SimpleParamVector (edu.cmu.ml.proppr.util.math.SimpleParamVector): 1
- TIntIterator (gnu.trove.iterator.TIntIterator): 1
- TIntArrayList (gnu.trove.list.array.TIntArrayList): 1
- BufferedWriter (java.io.BufferedWriter): 1
- FileWriter (java.io.FileWriter): 1
- IOException (java.io.IOException): 1
- HashMap (java.util.HashMap): 1
- HashSet (java.util.HashSet): 1
- Map (java.util.Map): 1