Use of org.apache.mahout.math.DenseVector in project pyramid by cheng-li.
Class AugmentedLRLoss, method penaltyGradient.
/**
 * Builds the gradient of the weight penalty term.
 *
 * <p>Feature weights are divided by {@code featureWeightVariance} and written to indices
 * {@code [0, numFeatures)}; component weights are divided by {@code componentWeightVariance}
 * and written immediately after, at {@code [numFeatures, numFeatures + numComponents)}.
 * The result has the same length as {@code augmentedLR.getAllWeights()}; any slot not
 * covered by the two ranges keeps the zero a fresh {@code DenseVector} starts with.
 * NOTE(review): presumably this is the gradient of a Gaussian/L2 prior w^2 / (2 * variance) — confirm.
 *
 * @return a new dense vector holding the penalty gradient
 */
private Vector penaltyGradient() {
    Vector featureW = augmentedLR.featureWeights();
    Vector componentW = augmentedLR.componentWeights();
    int totalSize = augmentedLR.getAllWeights().size();
    Vector gradient = new DenseVector(totalSize);

    // Single running position: feature block first, then the component block right after it.
    int pos = 0;
    for (int d = 0; d < numFeatures; d++, pos++) {
        gradient.set(pos, featureW.get(d) / featureWeightVariance);
    }
    for (int k = 0; k < numComponents; k++, pos++) {
        gradient.set(pos, componentW.get(k) / componentWeightVariance);
    }
    return gradient;
}
Use of org.apache.mahout.math.DenseVector in project pyramid by cheng-li.
Class MLACPlattScaling, method predictAssignmentProb.
/**
 * Scores {@code vector} with the underlying score estimator and returns the probability
 * that {@code logisticRegression} assigns to {@code assignment}, using the raw per-class
 * scores as its input features.
 *
 * @param vector the original feature vector
 * @param assignment the label assignment whose probability is requested
 * @return the probability reported by {@code logisticRegression.predictAssignmentProb}
 */
@Override
public double predictAssignmentProb(Vector vector, MultiLabel assignment) {
    double[] scores = scoreEstimator.predictClassScores(vector);
    // DenseVector(double[]) yields a vector of the same length holding the same values,
    // replacing the former element-by-element copy loop.
    Vector scoreVector = new DenseVector(scores);
    return logisticRegression.predictAssignmentProb(scoreVector, assignment);
}
Use of org.apache.mahout.math.DenseVector in project pyramid by cheng-li.
Class MLFlatScaling, method predictClassProbs.
/**
 * Computes one probability per class by passing each raw class score, on its own, as a
 * one-dimensional feature vector to {@code logisticRegression.predictClassProb(..., 1)}.
 * NOTE(review): class index 1 is presumably the "positive" class of a binary calibrator — confirm.
 *
 * @param vector the original feature vector
 * @return an array with one probability per score returned by the score estimator
 */
@Override
public double[] predictClassProbs(Vector vector) {
    double[] classScores = scoreEstimator.predictClassScores(vector);
    int numClasses = classScores.length;
    double[] classProbs = new double[numClasses];
    for (int label = 0; label < numClasses; label++) {
        // A fresh single-entry vector per class, holding only that class's score.
        Vector singleScore = new DenseVector(new double[] {classScores[label]});
        classProbs[label] = logisticRegression.predictClassProb(singleScore, 1);
    }
    return classProbs;
}
Use of org.apache.mahout.math.DenseVector in project pyramid by cheng-li.
Class RegressionTreeTest, method test4.
/**
 * Builds a fixed depth-2 tree (three split nodes, four leaves with values 1..4) and prints
 * {@code tree.probability(sample, node)} for every node in the order root, left split,
 * right split, leaves left to right, followed by {@code tree.predict(sample)}.
 * The sample has feature 0 = 1 and features 1 and 2 set to NaN.
 */
private static void test4() {
    // Split nodes: feature index, threshold, and the left/right probabilities they carry.
    Node root = new Node();
    root.setFeatureIndex(0);
    root.setThreshold(0.0);
    root.setLeftProb(0.3);
    root.setRightProb(0.7);

    Node leftSplit = new Node();
    leftSplit.setFeatureIndex(1);
    leftSplit.setThreshold(0.1);
    leftSplit.setLeftProb(0.8);
    leftSplit.setRightProb(0.2);

    Node rightSplit = new Node();
    rightSplit.setFeatureIndex(2);
    rightSplit.setThreshold(0.2);
    rightSplit.setLeftProb(0.1);
    rightSplit.setRightProb(0.9);

    // Leaves, left to right, carrying values 1, 2, 3, 4.
    Node[] leafNodes = new Node[4];
    for (int i = 0; i < leafNodes.length; i++) {
        Node leaf = new Node();
        leaf.setLeaf(true);
        leaf.setValue(i + 1);
        leafNodes[i] = leaf;
    }

    root.setLeftChild(leftSplit);
    root.setRightChild(rightSplit);
    leftSplit.setLeftChild(leafNodes[0]);
    leftSplit.setRightChild(leafNodes[1]);
    rightSplit.setLeftChild(leafNodes[2]);
    rightSplit.setRightChild(leafNodes[3]);

    RegressionTree tree = new RegressionTree();
    tree.root = root;
    for (Node leaf : leafNodes) {
        tree.leaves.add(leaf);
    }

    Vector sample = new DenseVector(3);
    sample.set(0, 1);
    sample.set(1, Double.NaN);
    sample.set(2, Double.NaN);

    Node[] allNodes = {
        root, leftSplit, rightSplit, leafNodes[0], leafNodes[1], leafNodes[2], leafNodes[3]
    };
    for (Node node : allNodes) {
        System.out.println(tree.probability(sample, node));
    }
    System.out.println(tree.predict(sample));
}
Use of org.apache.mahout.math.DenseVector in project pyramid by cheng-li.
Class RegressionTreeTest, method test3.
/**
 * Builds the same fixed depth-2 tree as the other tree tests (three split nodes, four
 * leaves with values 1..4) and prints {@code tree.probability(sample, node)} for every node
 * in the order root, left split, right split, leaves left to right, followed by
 * {@code tree.predict(sample)}. The sample has feature 0 = -1, feature 1 = 0.2, and
 * feature 2 set to NaN.
 */
private static void test3() {
    // Split nodes: feature index, threshold, and the left/right probabilities they carry.
    Node top = new Node();
    top.setFeatureIndex(0);
    top.setThreshold(0.0);
    top.setLeftProb(0.3);
    top.setRightProb(0.7);

    Node lowerLeft = new Node();
    lowerLeft.setFeatureIndex(1);
    lowerLeft.setThreshold(0.1);
    lowerLeft.setLeftProb(0.8);
    lowerLeft.setRightProb(0.2);

    Node lowerRight = new Node();
    lowerRight.setFeatureIndex(2);
    lowerRight.setThreshold(0.2);
    lowerRight.setLeftProb(0.1);
    lowerRight.setRightProb(0.9);

    // Leaves, left to right, carrying values 1, 2, 3, 4.
    Node[] leaves = new Node[4];
    for (int idx = 0; idx < leaves.length; idx++) {
        Node leaf = new Node();
        leaf.setLeaf(true);
        leaf.setValue(idx + 1);
        leaves[idx] = leaf;
    }

    top.setLeftChild(lowerLeft);
    top.setRightChild(lowerRight);
    lowerLeft.setLeftChild(leaves[0]);
    lowerLeft.setRightChild(leaves[1]);
    lowerRight.setLeftChild(leaves[2]);
    lowerRight.setRightChild(leaves[3]);

    RegressionTree tree = new RegressionTree();
    tree.root = top;
    for (Node leaf : leaves) {
        tree.leaves.add(leaf);
    }

    Vector sample = new DenseVector(3);
    sample.set(0, -1);
    sample.set(1, 0.2);
    sample.set(2, Double.NaN);

    Node[] visitOrder = {top, lowerLeft, lowerRight, leaves[0], leaves[1], leaves[2], leaves[3]};
    for (Node node : visitOrder) {
        System.out.println(tree.probability(sample, node));
    }
    System.out.println(tree.predict(sample));
}
Aggregations