
Example 1 with Weights

Use of edu.neu.ccs.pyramid.classification.logistic_regression.Weights in project pyramid by cheng-li.

Class CBMInspector, method getMean.

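public static Weights getMean(CBM bmm, int label) {
    int numClusters = bmm.getNumComponents();
    // read the weight-vector length and feature count from the first binary classifier;
    // every (cluster, label) classifier is assumed to have the same shape
    LogisticRegression first = (LogisticRegression) bmm.getBinaryClassifiers()[0][0];
    int length = first.getWeights().getAllWeights().size();
    int numFeatures = first.getNumFeatures();
    // sum the flattened weight vectors of each cluster's classifier for this label
    Vector mean = new DenseVector(length);
    for (int k = 0; k < numClusters; k++) {
        mean = mean.plus(((LogisticRegression) bmm.getBinaryClassifiers()[k][label]).getWeights().getAllWeights());
    }
    // average over the clusters and wrap the result as binary (2-class) weights
    mean = mean.divide(numClusters);
    return new Weights(2, numFeatures, mean);
}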
Also used: Weights (edu.neu.ccs.pyramid.classification.logistic_regression.Weights), LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression), DenseVector (org.apache.mahout.math.DenseVector), Vector (org.apache.mahout.math.Vector)
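The core of getMean is an element-wise average of one weight vector per cluster. The sketch below shows that step without the pyramid or Mahout dependencies, assuming each row of a plain double[][] stands in for the flattened weight vector (getAllWeights()) of one cluster's classifier for a fixed label. The class name WeightAverager and the method mean are hypothetical and not part of the pyramid project.

```java
import java.util.Arrays;

/**
 * Hypothetical, dependency-free sketch of the averaging in getMean.
 * Each row of clusterWeights plays the role of one cluster's flattened
 * weight vector for a fixed label.
 */
public class WeightAverager {

    /** Returns the element-wise mean of equal-length weight vectors. */
    public static double[] mean(double[][] clusterWeights) {
        if (clusterWeights.length == 0) {
            throw new IllegalArgumentException("need at least one cluster");
        }
        int length = clusterWeights[0].length;
        double[] sum = new double[length];
        for (double[] w : clusterWeights) {
            if (w.length != length) {
                throw new IllegalArgumentException("all weight vectors must have the same length");
            }
            // accumulate, analogous to mean.plus(...) in getMean
            for (int i = 0; i < length; i++) {
                sum[i] += w[i];
            }
        }
        // divide by the number of clusters, analogous to mean.divide(numClusters)
        for (int i = 0; i < length; i++) {
            sum[i] /= clusterWeights.length;
        }
        return sum;
    }

    public static void main(String[] args) {
        double[][] weights = { {1.0, 2.0, 3.0}, {3.0, 4.0, 5.0} };
        // prints [2.0, 3.0, 4.0]
        System.out.println(Arrays.toString(mean(weights)));
    }
}
```

Unlike the original, this sketch checks that all vectors share one length; getMean instead relies on every classifier having the same shape as the first one.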

Aggregations

LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression): 1
Weights (edu.neu.ccs.pyramid.classification.logistic_regression.Weights): 1
DenseVector (org.apache.mahout.math.DenseVector): 1
Vector (org.apache.mahout.math.Vector): 1