Search in sources:

Example 6 with Classifier.supervised.modelAdaptation._AdaptStruct

use of Classifier.supervised.modelAdaptation._AdaptStruct in project IR_Base by Linda-sunshine.

Class CoRegLR, method train.

/**
 * Jointly fits all users with L-BFGS: on every iteration the objective value and
 * gradient are accumulated over every user in {@code m_userList}, then handed to
 * {@code LBFGS.lbfgs} together with the shared weight vector.
 *
 * <p>The loop ends either when L-BFGS reports {@code iflag[0] == 0} or when the value
 * returned by {@code gradientTest()} changes by less than 0.1 between two consecutive
 * iterations (the first iteration is compared against 0).
 *
 * @return the objective value of the last evaluation performed by this method
 */
@Override
public double train() {
    int[] iflag = { 0 }, iprint = { -1, 3 };
    double fValue, oldFValue = Double.MAX_VALUE;
    int vSize = (m_featureSize + 1) * m_userList.size(), displayCount = 0;
    double oldMag = 0;
    initLBFGS();
    init();
    try {
        do {
            fValue = 0;
            // initialize gradient
            Arrays.fill(m_g, 0);
            // accumulate function values and gradients from each user
            for (_AdaptStruct user : m_userList) {
                fValue += calculateFuncValue(user);
                calculateGradients(user);
            }
            // added by Lin for stopping lbfgs.
            double curMag = gradientTest();
            if (Math.abs(oldMag - curMag) < 0.1) {
                // Record the objective just computed before leaving the loop; otherwise the
                // method would return the previous iteration's value, or Double.MAX_VALUE
                // when this test already fires on the first iteration.
                oldFValue = fValue;
                break;
            }
            oldMag = curMag;
            if (m_displayLv == 2) {
                System.out.println("Fvalue is " + fValue);
            } else if (m_displayLv == 1) {
                // "o" marks a decrease of the objective, "x" anything else
                if (fValue < oldFValue)
                    System.out.print("o");
                else
                    System.out.print("x");
                // wrap the progress line every 100 marks
                if (++displayCount % 100 == 0)
                    System.out.println();
            }
            oldFValue = fValue;
            // In the training process, sharedW is updated.
            LBFGS.lbfgs(vSize, 5, _CoRegLRAdaptStruct.getSharedW(), fValue, m_g, false, m_diag, iprint, 1e-3, 1e-16, iflag);
        } while (iflag[0] != 0);
        System.out.println();
    } catch (ExceptionWithIflag e) {
        // optimization failure is reported but not propagated; the weights reached so far
        // are still pushed to the users below
        e.printStackTrace();
    }
    setPersonalizedModel();
    return oldFValue;
}
Also used : Classifier.supervised.modelAdaptation._AdaptStruct(Classifier.supervised.modelAdaptation._AdaptStruct) ExceptionWithIflag(LBFGS.LBFGS.ExceptionWithIflag)

Example 7 with Classifier.supervised.modelAdaptation._AdaptStruct

use of Classifier.supervised.modelAdaptation._AdaptStruct in project IR_Base by Linda-sunshine.

Class CoRegLR, method setPersonalizedModel.

/**
 * Hands every user its own segment of the shared weight vector.
 *
 * <p>The shared vector is read as consecutive segments of {@code m_featureSize + 1}
 * entries, and the segment belonging to a user starts at {@code user.getId()} times
 * that segment length. Each segment is copied into {@code m_pWeights}, which is then
 * passed to {@code user.setPersonalizedModel}.
 */
@Override
protected void setPersonalizedModel() {
    final int segmentLength = m_featureSize + 1;
    // the same buffer is refilled and passed on for every user
    m_pWeights = new double[segmentLength];
    final double[] allWeights = _CoRegLRAdaptStruct.getSharedW();
    for (_AdaptStruct user : m_userList) {
        final int offset = user.getId() * segmentLength;
        System.arraycopy(allWeights, offset, m_pWeights, 0, segmentLength);
        user.setPersonalizedModel(m_pWeights);
    }
}
Also used : Classifier.supervised.modelAdaptation._AdaptStruct(Classifier.supervised.modelAdaptation._AdaptStruct)

Example 8 with Classifier.supervised.modelAdaptation._AdaptStruct

use of Classifier.supervised.modelAdaptation._AdaptStruct in project IR_Base by Linda-sunshine.

Class MTRegLR, method train.

/**
 * Optimizes {@code m_ws} with L-BFGS. Each iteration sums the objective value and
 * accumulates the gradient over all users in {@code m_userList} before calling
 * {@code LBFGS.lbfgs}; iteration continues until {@code iflag[0]} becomes 0.
 *
 * @return the objective value computed in the last iteration that was evaluated
 */
@Override
public double train() {
    final int[] iflag = { 0 };
    final int[] iprint = { -1, 3 };
    double fValue;
    double oldFValue = Double.MAX_VALUE;
    int displayCount = 0;
    initLBFGS();
    init();
    try {
        do {
            // start every iteration from a zero objective and a zero gradient
            fValue = 0;
            Arrays.fill(m_g, 0);
            for (_AdaptStruct user : m_userList) {
                // L + R^1(A_i)
                fValue += calculateFuncValue(user);
                calculateGradients(user);
            }
            if (m_displayLv == 2) {
                System.out.format("Fvalue is %.3f\t", fValue);
                gradientTest();
            } else if (m_displayLv == 1) {
                // "o": objective decreased, "x": it did not
                System.out.print(fValue < oldFValue ? "o" : "x");
                displayCount++;
                // break the progress line after every 100 marks
                if (displayCount % 100 == 0) {
                    System.out.println();
                }
            }
            oldFValue = fValue;
            // m_ws is the vector handed to the optimizer for updating
            LBFGS.lbfgs(m_ws.length, 6, m_ws, fValue, m_g, false, m_diag, iprint, 1e-3, 1e-16, iflag);
        } while (iflag[0] != 0);
        System.out.println();
    } catch (ExceptionWithIflag e) {
        System.err.println("********lbfgs fails here!******");
        e.printStackTrace();
    }
    setPersonalizedModel();
    return oldFValue;
}
Also used : Classifier.supervised.modelAdaptation._AdaptStruct(Classifier.supervised.modelAdaptation._AdaptStruct) ExceptionWithIflag(LBFGS.LBFGS.ExceptionWithIflag)

Example 9 with Classifier.supervised.modelAdaptation._AdaptStruct

use of Classifier.supervised.modelAdaptation._AdaptStruct in project IR_Base by Linda-sunshine.

Class MTRegLR, method setPersonalizedModel.

/**
 * Builds each user's personalized weights as {@code w_i + m_u * w_g} and passes them
 * to the user.
 *
 * <p>{@code m_ws} is read as one segment of {@code m_featureSize + 1} entries per user
 * (located by {@code getId()}), followed by the global segment that starts right after
 * the last user segment.
 */
public void setPersonalizedModel() {
    final int dim = m_featureSize + 1;
    // personalized weights, u*w_g + w_i; one buffer reused for every user
    final double[] combined = new double[dim];
    // w_g: everything in m_ws behind the per-user segments
    final double[] globalPart = Arrays.copyOfRange(m_ws, m_userList.size() * dim, m_ws.length);
    for (_AdaptStruct u : m_userList) {
        final int from = u.getId() * dim;
        final int to = from + dim;
        // w_i
        final double[] userPart = Arrays.copyOfRange(m_ws, from, to);
        for (int k = 0; k < userPart.length; k++) {
            combined[k] = userPart[k] + m_u * globalPart[k];
        }
        u.setPersonalizedModel(combined);
    }
}
Also used : Classifier.supervised.modelAdaptation._AdaptStruct(Classifier.supervised.modelAdaptation._AdaptStruct)

Example 10 with Classifier.supervised.modelAdaptation._AdaptStruct

use of Classifier.supervised.modelAdaptation._AdaptStruct in project IR_Base by Linda-sunshine.

Class RegLR, method train.

/**
 * Trains every user separately: for each user in {@code m_userList} the L-BFGS state
 * is re-initialized and the vector returned by {@code user.getUserModel()} is
 * optimized until {@code iflag[0]} becomes 0 or L-BFGS throws.
 *
 * @return the sum over all users of the last objective value evaluated for that user
 */
@Override
public double train() {
    final int[] iflag = { 0 };
    final int[] iprint = { -1, 3 };
    // kept outside the loop so the value added to the total is whatever was evaluated last
    double fValue = 0;
    double totalFvalue = 0;
    init();
    for (_AdaptStruct user : m_userList) {
        initLBFGS();
        iflag[0] = 0;
        try {
            final double[] w = user.getUserModel();
            double oldFValue = Double.MAX_VALUE;
            do {
                // reset the gradient before evaluating this user
                Arrays.fill(m_g, 0);
                fValue = calculateFuncValue(user);
                calculateGradients(user);
                if (m_displayLv == 2) {
                    System.out.println("Fvalue is " + fValue);
                    gradientTest();
                } else if (m_displayLv == 1) {
                    // "o": objective decreased, "x": it did not
                    System.out.print(fValue < oldFValue ? "o" : "x");
                }
                oldFValue = fValue;
                // w is the vector handed to the optimizer for updating
                LBFGS.lbfgs(w.length, 6, w, fValue, m_g, false, m_diag, iprint, 1e-4, 1e-32, iflag);
            } while (iflag[0] != 0);
        } catch (ExceptionWithIflag e) {
            // a failed optimization is only marked with "X"; training moves on to the next user
            if (m_displayLv <= 0) {
                System.out.println("X");
            } else {
                System.out.print("X");
            }
        }
        if (m_displayLv > 0) {
            System.out.println();
        }
        totalFvalue += fValue;
    }
    setPersonalizedModel();
    return totalFvalue;
}
Also used : Classifier.supervised.modelAdaptation._AdaptStruct(Classifier.supervised.modelAdaptation._AdaptStruct) ExceptionWithIflag(LBFGS.LBFGS.ExceptionWithIflag)

Aggregations

- Classifier.supervised.modelAdaptation._AdaptStruct (Classifier.supervised.modelAdaptation._AdaptStruct): 34
- structures._User (structures._User): 15
- File (java.io.File): 6
- PrintWriter (java.io.PrintWriter): 6
- structures._Review (structures._Review): 6
- IOException (java.io.IOException): 5
- ExceptionWithIflag (LBFGS.LBFGS.ExceptionWithIflag): 3
- structures._SparseFeature (structures._SparseFeature): 3
- Feature (Classifier.supervised.liblinear.Feature): 2
- Parameter (Classifier.supervised.liblinear.Parameter): 2
- Problem (Classifier.supervised.liblinear.Problem): 2
- ArrayList (java.util.ArrayList): 2
- structures._HDPThetaStar (structures._HDPThetaStar): 2
- structures._PerformanceStat (structures._PerformanceStat): 2
- FileNotFoundException (java.io.FileNotFoundException): 1
- HashSet (java.util.HashSet): 1
- structures._thetaStar (structures._thetaStar): 1