
Example 21 with DenseVector

Use of org.apache.mahout.math.DenseVector in project pyramid by cheng-li.

In class Weights, method deepCopy.

// todo buggy
public Weights deepCopy() {
    Weights copy = new Weights(this.numClasses, numFeatures);
    copy.weightVector = new DenseVector(this.weightVector);
    return copy;
}
Also used: DenseVector (org.apache.mahout.math.DenseVector)
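The source flags deepCopy as "// todo buggy" without saying what is wrong, so the sketch below does not claim to fix that bug. It only illustrates what a deep copy of such a weights holder involves. The class WeightsSketch is hypothetical: a plain double[] stands in for Mahout's DenseVector, and the flattened layout of numClasses * (numFeatures + 1) entries is an assumption.

```java
import java.util.Arrays;

// Hypothetical stand-in for a weights holder; a plain double[] replaces Mahout's DenseVector.
class WeightsSketch {
    final int numClasses;
    final int numFeatures;
    // Assumed layout: one block of (numFeatures + 1) weights per class, flattened.
    final double[] weights;

    WeightsSketch(int numClasses, int numFeatures) {
        this.numClasses = numClasses;
        this.numFeatures = numFeatures;
        this.weights = new double[numClasses * (numFeatures + 1)];
    }

    // Deep copy: every field is copied and the weights get their own array,
    // so writes to the copy never show up in the original.
    WeightsSketch deepCopy() {
        WeightsSketch copy = new WeightsSketch(numClasses, numFeatures);
        System.arraycopy(weights, 0, copy.weights, 0, weights.length);
        return copy;
    }

    public static void main(String[] args) {
        WeightsSketch original = new WeightsSketch(2, 3);
        original.weights[0] = 1.5;
        WeightsSketch copy = original.deepCopy();
        copy.weights[0] = -7.0;
        System.out.println(original.weights[0]); // prints 1.5
        System.out.println(Arrays.toString(copy.weights));
    }
}
```

Copying the values, rather than sharing one array between both objects, is what keeps the copy independent. With a shared array, a later write through either object would be visible through both.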

Example 22 with DenseVector

Use of org.apache.mahout.math.DenseVector in project pyramid by cheng-li.

In class GeneralF1Predictor, method exhaustiveSearch.

public static MultiLabel exhaustiveSearch(int numClasses, Matrix lossMatrix, List<Double> probabilities) {
    double bestScore = Double.POSITIVE_INFINITY;
    Vector vector = new DenseVector(probabilities.size());
    for (int i = 0; i < vector.size(); i++) {
        vector.set(i, probabilities.get(i));
    }
    List<MultiLabel> multiLabels = Enumerator.enumerate(numClasses);
    MultiLabel multiLabel = null;
    for (int j = 0; j < lossMatrix.numCols(); j++) {
        Vector column = lossMatrix.viewColumn(j);
        double score = column.dot(vector);
        System.out.println("column " + j + ", expected loss = " + score);
        if (score < bestScore) {
            bestScore = score;
            multiLabel = multiLabels.get(j);
        }
    }
    return multiLabel;
}
Also used: MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel), DenseVector (org.apache.mahout.math.DenseVector), Vector (org.apache.mahout.math.Vector)
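exhaustiveSearch picks the candidate whose expected loss is smallest. That expected loss is column j of lossMatrix dotted with the probability vector. The sketch below restates that argmin on plain arrays. The class name ExpectedLossSearch is hypothetical. The layout lossMatrix[i][j] (row i = outcome, column j = candidate) is an assumption chosen to match the viewColumn(j).dot(vector) call in the original. The sketch also returns a column index rather than a MultiLabel.

```java
// Hypothetical sketch of the expected-loss argmin in exhaustiveSearch, on plain arrays.
// Assumed layout: lossMatrix[i][j] is the loss of predicting candidate j when outcome i holds.
class ExpectedLossSearch {

    // Expected loss of candidate j: sum over i of lossMatrix[i][j] * probabilities[i].
    static double expectedLoss(double[][] lossMatrix, double[] probabilities, int j) {
        double sum = 0.0;
        for (int i = 0; i < probabilities.length; i++) {
            sum += lossMatrix[i][j] * probabilities[i];
        }
        return sum;
    }

    // Index of the column with the smallest expected loss, or -1 if there are no columns.
    static int bestColumn(double[][] lossMatrix, double[] probabilities) {
        if (lossMatrix.length != probabilities.length) {
            throw new IllegalArgumentException("need one matrix row per probability");
        }
        int numCols = lossMatrix.length == 0 ? 0 : lossMatrix[0].length;
        int best = -1;
        double bestScore = Double.POSITIVE_INFINITY;
        for (int j = 0; j < numCols; j++) {
            double score = expectedLoss(lossMatrix, probabilities, j);
            if (score < bestScore) { // strict <, so the first column wins ties
                bestScore = score;
                best = j;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] p = {0.7, 0.3};
        double[][] loss = {{0.0, 1.0}, {1.0, 0.0}};
        System.out.println(bestColumn(loss, p)); // prints 0
    }
}
```

As in the original, the strict comparison keeps the first column on ties. No candidate is selected if the matrix has no columns: the sketch returns -1 and the original returns null. The sketch also omits the per-column System.out.println.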

Example 23 with DenseVector

Use of org.apache.mahout.math.DenseVector in project pyramid by cheng-li.

In class LBFGS, method iterate.

public void iterate() {
    if (logger.isDebugEnabled()) {
        logger.debug("start one iteration");
    }
    // we need to make a copy of the gradient; should not use pointer
    Vector oldGradient = new DenseVector(function.getGradient());
    Vector direction = findDirection();
    if (logger.isDebugEnabled()) {
        logger.debug("norm of direction = " + direction.norm(2));
    }
    BackTrackingLineSearcher.MoveInfo moveInfo = lineSearcher.moveAlongDirection(direction);
    Vector s = moveInfo.getStep();
    Vector newGradient = function.getGradient();
    Vector y = newGradient.minus(oldGradient);
    double denominator = y.dot(s);
    // todo what to do if denominator is not positive?
    // round-off errors and an ill-conditioned inverse Hessian
    double rho = 0;
    if (denominator > 0) {
        rho = 1 / denominator;
    } else {
        terminator.forceTerminate();
        if (logger.isWarnEnabled()) {
            logger.warn("denominator <= 0, force to terminate");
        }
        // if (logger.isWarnEnabled()){
        //     logger.warn("denominator <= 0, give up the current iteration; reset history, and directly jump to next iteration!");
        // }
        // reset();
        // return;
    }
    if (logger.isDebugEnabled()) {
        if (y.size() < 100) {
            logger.debug("y= " + y);
            logger.debug("s= " + s);
        }
        logger.debug("denominator = " + denominator);
        logger.debug("rho = " + rho);
    }
    sQueue.add(s);
    yQueue.add(y);
    rhoQueue.add(rho);
    if (sQueue.size() > m) {
        sQueue.remove();
        yQueue.remove();
        rhoQueue.remove();
    }
    double value = function.getValue();
    terminator.add(value);
    if (logger.isDebugEnabled()) {
        logger.debug("finish one iteration. loss = " + value);
    }
}
Also used: DenseVector (org.apache.mahout.math.DenseVector), Vector (org.apache.mahout.math.Vector)
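iterate() stores each step s and gradient change y = newGradient - oldGradient together with rho = 1 / (y·s). It keeps at most m pairs and forces termination when y·s <= 0, meaning the curvature condition fails. The original still appends that pair with rho = 0 after calling forceTerminate(). The hypothetical LbfgsHistory class below sketches the same bookkeeping on plain arrays under a different, simpler policy: it rejects a pair that fails the curvature check instead of storing it. This is neither the forced termination nor the commented-out history reset from the original.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Hypothetical sketch of L-BFGS curvature-pair bookkeeping on plain arrays.
class LbfgsHistory {
    private final int m; // maximum number of stored (s, y) pairs
    final Deque<double[]> sQueue = new ArrayDeque<>();
    final Deque<double[]> yQueue = new ArrayDeque<>();
    final Deque<Double> rhoQueue = new ArrayDeque<>();

    LbfgsHistory(int m) {
        this.m = m;
    }

    static double dot(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    // s = x_{k+1} - x_k, y = g_{k+1} - g_k. Returns false and stores nothing when
    // y.s is not positive (the negated test also rejects NaN).
    boolean add(double[] s, double[] y) {
        double denominator = dot(y, s);
        if (!(denominator > 0)) {
            return false;
        }
        sQueue.addLast(s.clone());
        yQueue.addLast(y.clone());
        rhoQueue.addLast(1.0 / denominator);
        if (sQueue.size() > m) { // drop the oldest pair to keep only the m most recent
            sQueue.removeFirst();
            yQueue.removeFirst();
            rhoQueue.removeFirst();
        }
        return true;
    }

    public static void main(String[] args) {
        LbfgsHistory history = new LbfgsHistory(2);
        history.add(new double[]{1.0}, new double[]{2.0});
        history.add(new double[]{1.0}, new double[]{-1.0}); // rejected: y.s < 0
        System.out.println(history.sQueue.size()); // prints 1
    }
}
```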

Example 24 with DenseVector

Use of org.apache.mahout.math.DenseVector in project pyramid by cheng-li.

In class LBFGS, method findDirection.

Vector findDirection() {
    Vector g = function.getGradient();
    // todo
    // if (rhoQueue.size()==0){
    //     if (logger.isDebugEnabled()){
    //         logger.debug("use negative gradient as search direction");
    //     }
    //     return g.times(-1);
    // }
    // using dense vector is much faster
    Vector q = new DenseVector(g.size());
    q.assign(g);
    Iterator<Double> rhoDesIterator = rhoQueue.descendingIterator();
    Iterator<Vector> sDesIterator = sQueue.descendingIterator();
    Iterator<Vector> yDesIterator = yQueue.descendingIterator();
    LinkedList<Double> alphaQueue = new LinkedList<>();
    while (rhoDesIterator.hasNext()) {
        double rho = rhoDesIterator.next();
        Vector s = sDesIterator.next();
        Vector y = yDesIterator.next();
        double alpha = s.dot(q) * rho;
        alphaQueue.addFirst(alpha);
        // seems no need to use "assign"
        q = q.minus(y.times(alpha));
    }
    double gamma = gamma();
    // use H_k^0 = gamma I
    Vector r = q.times(gamma);
    Iterator<Double> rhoIterator = rhoQueue.iterator();
    Iterator<Vector> sIterator = sQueue.iterator();
    Iterator<Vector> yIterator = yQueue.iterator();
    Iterator<Double> alphaIterator = alphaQueue.iterator();
    while (rhoIterator.hasNext()) {
        double rho = rhoIterator.next();
        Vector s = sIterator.next();
        Vector y = yIterator.next();
        double alpha = alphaIterator.next();
        double beta = y.dot(r) * rho;
        r = r.plus(s.times(alpha - beta));
    }
    return r.times(-1);
}
Also used: DenseVector (org.apache.mahout.math.DenseVector), Vector (org.apache.mahout.math.Vector), LinkedList (java.util.LinkedList)
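findDirection() is the L-BFGS two-loop recursion. A backward pass over the stored pairs computes alpha_i = rho_i s_i·q and subtracts alpha_i y_i from q. The result is then scaled by gamma (H_k^0 = gamma I). A forward pass then adds (alpha_i - beta_i) s_i with beta_i = rho_i y_i·r. The sketch below reproduces this on plain arrays, with lists ordered oldest to newest like the queues in the original. The body of gamma() is not shown in the source, so two choices here are assumptions. The first is the common scaling gamma = s·y / y·y for the newest pair. The second is the fallback gamma = 1 for an empty history, which yields the negative gradient, the case the commented-out block handles.

```java
import java.util.List;

// Sketch of the L-BFGS two-loop recursion on plain double[] arrays.
// s, y and rho are ordered oldest -> newest.
class TwoLoopRecursion {

    static double dot(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    // Returns the search direction -H g, where H is the implicit inverse-Hessian
    // approximation built from the stored (s, y, rho) triples.
    static double[] direction(double[] g, List<double[]> s, List<double[]> y, List<Double> rho) {
        int k = s.size();
        int n = g.length;
        double[] q = g.clone();
        double[] alpha = new double[k];
        // First loop: newest pair to oldest.
        for (int i = k - 1; i >= 0; i--) {
            alpha[i] = rho.get(i) * dot(s.get(i), q);
            for (int d = 0; d < n; d++) q[d] -= alpha[i] * y.get(i)[d];
        }
        // Assumed initial scaling H^0 = gamma * I, with gamma from the newest pair.
        double gamma = 1.0;
        if (k > 0) {
            double[] sNew = s.get(k - 1), yNew = y.get(k - 1);
            gamma = dot(sNew, yNew) / dot(yNew, yNew);
        }
        double[] r = new double[n];
        for (int d = 0; d < n; d++) r[d] = gamma * q[d];
        // Second loop: oldest pair to newest.
        for (int i = 0; i < k; i++) {
            double beta = rho.get(i) * dot(y.get(i), r);
            for (int d = 0; d < n; d++) r[d] += (alpha[i] - beta) * s.get(i)[d];
        }
        for (int d = 0; d < n; d++) r[d] = -r[d];
        return r;
    }

    public static void main(String[] args) {
        // f(x) = x^2: gradient 2x, curvature 2. At x = 2, g = 4; one pair s = 1, y = 2, rho = 1/2.
        double[] d = direction(new double[]{4.0}, List.of(new double[]{1.0}),
                List.of(new double[]{2.0}), List.of(0.5));
        System.out.println(d[0]); // prints -2.0
    }
}
```

On this one-dimensional quadratic the result equals the Newton step -g / f''(x) = -4 / 2, which is a quick sanity check for the ordering of the two loops.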

Example 25 with DenseVector

Use of org.apache.mahout.math.DenseVector in project pyramid by cheng-li.

In class Weights, method readObject.

private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
    numFeatures = in.readInt();
    serializableWeights = (double[]) in.readObject();
    weightVector = new DenseVector(numFeatures + 1);
    for (int i = 0; i < serializableWeights.length; i++) {
        weightVector.set(i, serializableWeights[i]);
    }
}
Also used: DenseVector (org.apache.mahout.math.DenseVector)
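readObject rebuilds weightVector from a serialized double[] preceded by numFeatures, allocating numFeatures + 1 entries. The hypothetical class below sketches the same pattern with a plain array: the vector field is transient and rebuilt by hand. Its writeObject is an assumed counterpart that writes the fields in the order readObject reads them, since the source shows only readObject. The length check is an addition not present in the original. It rejects a stream whose array does not have exactly numFeatures + 1 entries.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

// Hypothetical class mirroring the custom readObject pattern with a plain array.
class SerializableWeightsSketch implements Serializable {
    private static final long serialVersionUID = 1L;

    private int numFeatures;
    // transient: restored by readObject instead of default field serialization
    private transient double[] weightVector;

    SerializableWeightsSketch(int numFeatures) {
        this.numFeatures = numFeatures;
        // numFeatures + 1 slots as in the original (the extra slot is presumably a bias)
        this.weightVector = new double[numFeatures + 1];
    }

    double get(int i) { return weightVector[i]; }
    void set(int i, double value) { weightVector[i] = value; }
    int size() { return weightVector.length; }

    // Assumed counterpart to readObject: same fields, same order.
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(numFeatures);
        out.writeObject(weightVector);
    }

    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        numFeatures = in.readInt();
        double[] serialized = (double[]) in.readObject();
        if (serialized.length != numFeatures + 1) {
            throw new InvalidObjectException(
                    "expected " + (numFeatures + 1) + " weights, got " + serialized.length);
        }
        weightVector = serialized.clone(); // defensive copy of the deserialized array
    }

    // Serializes and deserializes an instance in memory.
    static SerializableWeightsSketch roundTrip(SerializableWeightsSketch w) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeObject(w);
            }
            try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return (SerializableWeightsSketch) in.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        SerializableWeightsSketch w = new SerializableWeightsSketch(2);
        w.set(2, -3.5);
        System.out.println(roundTrip(w).get(2)); // prints -3.5
    }
}
```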

Aggregations

DenseVector (org.apache.mahout.math.DenseVector): 79
Vector (org.apache.mahout.math.Vector): 73
MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 9
RandomAccessSparseVector (org.apache.mahout.math.RandomAccessSparseVector): 8
MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 7
SequentialAccessSparseVector (org.apache.mahout.math.SequentialAccessSparseVector): 6
Pair (edu.neu.ccs.pyramid.util.Pair): 4
List (java.util.List): 3
IntStream (java.util.stream.IntStream): 3
EnumeratedIntegerDistribution (org.apache.commons.math3.distribution.EnumeratedIntegerDistribution): 3
LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression): 2
DataSet (edu.neu.ccs.pyramid.dataset.DataSet): 2
EmpiricalCDF (edu.neu.ccs.pyramid.util.EmpiricalCDF): 2
IntegerDistribution (org.apache.commons.math3.distribution.IntegerDistribution): 2
MultivariateNormalDistribution (org.apache.commons.math3.distribution.MultivariateNormalDistribution): 2
Classifier (edu.neu.ccs.pyramid.classification.Classifier): 1
Weights (edu.neu.ccs.pyramid.classification.logistic_regression.Weights): 1
RegDataSet (edu.neu.ccs.pyramid.dataset.RegDataSet): 1
ConstantRegressor (edu.neu.ccs.pyramid.regression.ConstantRegressor): 1
BernoulliDistribution (edu.neu.ccs.pyramid.util.BernoulliDistribution): 1