Use of org.apache.mahout.math.DenseVector in project pyramid by cheng-li.
Class DBR, method predict.
/**
 * Predicts a label set with the two-stage model: {@code stage1BR} first, then {@code stage2BR}.
 *
 * <p>The stage-1 prediction is converted to a vector of length
 * {@code stage1BR.getNumClasses()}. If {@code useXStage2} is set, that vector is concatenated
 * with the original input (stage-1 output first); otherwise it is copied into a dense vector.
 * The resulting vector is what stage 2 predicts from.
 *
 * @param vector feature vector of the instance to classify
 * @return the stage-2 prediction
 */
@Override
public MultiLabel predict(Vector vector) {
    Vector firstStageOutput = stage1BR.predict(vector)
            .toVectorRandomSparse(stage1BR.getNumClasses());
    Vector secondStageInput = useXStage2
            ? Vectors.concatenate(firstStageOutput, vector)
            : new DenseVector(firstStageOutput);
    return stage2BR.predict(secondStageInput);
}
Use of org.apache.mahout.math.DenseVector in project pyramid by cheng-li.
Class BackTrackingLineSearcher, method moveAlongDirection.
/**
 * Moves the function's parameters along a search direction using backtracking line search.
 *
 * <p>Starting from {@code initialStepLength}, the step length is repeatedly multiplied by
 * {@code shrinkage} until the Armijo sufficient-decrease condition
 * {@code f(x + t*d) <= f(x) + c * t * (grad . d)} holds (accepted only while the starting
 * value is below +infinity), or until the step length becomes exactly 0.
 *
 * <p>If {@code searchDirection} is not a descent direction (its dot product with the gradient
 * is not negative), the negative gradient is searched instead, and the slope used in the
 * sufficient-decrease condition is recomputed for that direction.
 *
 * <p>Side effect: the function's parameters are left at the accepted position.
 *
 * @param searchDirection proposed search direction
 * @return the accepted move: old value, new value, step vector and step length
 */
public MoveInfo moveAlongDirection(Vector searchDirection) {
    if (logger.isDebugEnabled()) {
        logger.debug("start line search");
        // don't want to show too much; only show it on small problems
        if (searchDirection.size() < 100) {
            logger.debug("direction=" + searchDirection);
        }
    }
    MoveInfo moveInfo = new MoveInfo();
    double stepLength = initialStepLength;
    double value = function.getValue();
    moveInfo.setOldValue(value);
    Vector gradient = function.getGradient();
    double product = gradient.dot(searchDirection);
    Vector localSearchDir;
    if (product < 0) {
        localSearchDir = searchDirection;
    } else {
        if (logger.isWarnEnabled()) {
            logger.warn("Bad search direction! Use negative gradient instead. Product of gradient and search direction = " + product);
        }
        localSearchDir = gradient.times(-1);
        // The sufficient-decrease test must use the slope of the direction actually searched.
        // Keeping the old non-negative product would let the search accept value increases.
        product = gradient.dot(localSearchDir);
    }
    Vector initialPosition;
    // keep a copy of initial parameters, since setParameters below overwrites them
    if (function.getParameters().isDense()) {
        initialPosition = new DenseVector(function.getParameters());
    } else {
        initialPosition = new RandomAccessSparseVector(function.getParameters());
    }
    while (true) {
        Vector step = localSearchDir.times(stepLength);
        Vector target = initialPosition.plus(step);
        function.setParameters(target);
        double targetValue = function.getValue();
        if (logger.isDebugEnabled()) {
            logger.debug("step length = " + stepLength + ", target value = " + targetValue);
        }
        // Armijo condition; equality is accepted, as is standard.
        // stepLength == 0 guarantees termination once repeated shrinking underflows.
        if ((targetValue <= value + c * stepLength * product && value < Double.POSITIVE_INFINITY) || stepLength == 0) {
            moveInfo.setStep(step);
            moveInfo.setStepLength(stepLength);
            moveInfo.setNewValue(targetValue);
            break;
        }
        stepLength *= shrinkage;
    }
    if (logger.isDebugEnabled()) {
        logger.debug("line search done. " + moveInfo);
    }
    return moveInfo;
}
Use of org.apache.mahout.math.DenseVector in project pyramid by cheng-li.
Class SupervisedEmbeddingLoss, method getGradient.
/**
 * Computes the gradient of the loss with respect to the updated embedding matrix, flattened
 * row by row: entry {@code i * numFeatures + f} holds the derivative for feature {@code f} of
 * data point {@code i}.
 *
 * <p>For data point {@code i} the gradient is
 * {@code 2*alpha*(q_i - q_i_orig) + sum_j 4*beta*(p_sq - d_sq) * t_ij}, where:
 * <ul>
 *   <li>{@code (x, y)} is a point's projection onto columns 0 and 1 of {@code projMatrix};</li>
 *   <li>{@code p_sq} is the squared distance between the projections of points i and j;</li>
 *   <li>{@code d_sq} is the square of {@code distMatrix[i][j]};</li>
 *   <li>{@code t_ij[k]} is row {@code k} of {@code projMatrix} dotted with the 2-d projection
 *       difference.</li>
 * </ul>
 *
 * <p>The projection columns and rows, every point's projection, and each distance row are
 * computed once instead of inside the O(n^2) pair loop. The per-pair arithmetic is unchanged.
 *
 * @return flattened gradient of length {@code numDataPoints * numFeatures}
 */
public Vector getGradient() {
    int numData = this.updatedEmbeddingMatrix.getNumDataPoints();
    int numFeatures = this.updatedEmbeddingMatrix.getNumFeatures();
    int vecSize = numData * numFeatures;
    Vector finalGradient = new DenseVector(vecSize);

    // Loop-invariant parts of the projection matrix.
    Vector projCol0 = this.projMatrix.getColumn(0);
    Vector projCol1 = this.projMatrix.getColumn(1);
    int numProjRows = this.projMatrix.getNumDataPoints();
    Vector[] projRows = new Vector[numProjRows];
    for (int k = 0; k < numProjRows; k++) {
        projRows[k] = this.projMatrix.getRow(k);
    }

    // 2-d projection of every updated embedding row, computed once per point (not once per pair).
    double[] projX = new double[numData];
    double[] projY = new double[numData];
    for (int j = 0; j < numData; j++) {
        Vector q_j = this.updatedEmbeddingMatrix.getRow(j);
        projX[j] = projCol0.dot(q_j);
        projY[j] = projCol1.dot(q_j);
    }

    for (int i = 0; i < numData; i++) {
        Vector gradient = new DenseVector(numFeatures);
        Vector q_i = this.updatedEmbeddingMatrix.getRow(i);
        Vector q_i_orig = this.embeddingMatrix.getRow(i);
        gradient = gradient.plus(q_i.minus(q_i_orig).times(2.0 * this.alpha));

        double pi_x = projX[i];
        double pi_y = projY[i];
        Vector distRow = this.distMatrix.getRow(i);
        for (int j = 0; j < numData; j++) {
            double dx = pi_x - projX[j];
            double dy = pi_y - projY[j];
            double p_sq = dx * dx + dy * dy;
            double d = distRow.get(j);
            double d_sq = d * d;
            Vector p_dist_vec = new DenseVector(2);
            p_dist_vec.set(0, dx);
            p_dist_vec.set(1, dy);
            Vector tmp = new DenseVector(numProjRows);
            for (int k = 0; k < numProjRows; k++) {
                tmp.set(k, projRows[k].dot(p_dist_vec));
            }
            gradient = gradient.plus(tmp.times(4.0 * this.beta * (p_sq - d_sq)));
        }

        for (int j = 0; j < numFeatures; j++) {
            finalGradient.set(i * numFeatures + j, gradient.get(j));
        }
    }
    return finalGradient;
}
Use of org.apache.mahout.math.DenseVector in project pyramid by cheng-li.
Class PlattScaling, method transform.
/**
 * Maps a raw (uncalibrated) score to a calibrated probability.
 *
 * <p>The score becomes the single entry of a one-dimensional vector. That vector is scored by
 * {@code logisticRegression}, and the probability it assigns to class 1 is returned.
 *
 * @param uncalibrated raw score to calibrate
 * @return {@code logisticRegression}'s predicted probability of class 1 for this score
 */
public double transform(double uncalibrated) {
    Vector singleFeature = new DenseVector(new double[] {uncalibrated});
    return logisticRegression.predictClassProb(singleFeature, 1);
}
Use of org.apache.mahout.math.DenseVector in project pyramid by cheng-li.
Class KMeans, method updateCenters.
/**
 * Recomputes the centroid of cluster {@code k} as the mean of the data points whose
 * assignment is {@code k}, and stores it in {@code centers[k]}.
 *
 * <p>An empty cluster has no mean. Dividing by a zero count would fill the centroid with NaN
 * and corrupt every later distance computation. Instead, the existing centroid is kept, or an
 * all-zero centroid is stored if none exists yet.
 *
 * @param k zero-based cluster index
 */
private void updateCenters(int k) {
    Vector center = new DenseVector(dataSet.getNumFeatures());
    double count = 0;
    for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
        if (assignments[i] == k) {
            Vector instance = dataSet.getRow(i);
            for (int j = 0; j < instance.size(); j++) {
                center.set(j, center.get(j) + instance.get(j));
            }
            count += 1;
        }
    }
    if (count == 0) {
        if (centers[k] == null) {
            // no previous centroid to fall back on; use the (all-zero) accumulator
            centers[k] = center;
        }
        System.out.println("no instances assigned to cluster " + (k + 1) + "; centroid not recomputed");
        return;
    }
    center = center.divide(count);
    centers[k] = center;
    System.out.println("update the centroid of cluster " + (k + 1) + " based on " + (int) count + " instances in the cluster");
}
Aggregations