use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
the class MLLogisticRegression method predictClassProbs.
/**
 * Computes a per-class probability for the given instance by marginalising
 * over assignments: the probability of each assignment, as returned by
 * {@code predictAssignmentProbs}, is added to the total of every label that
 * the assignment matches.
 * <p>
 * Flagged as an expensive operation: it walks every assignment and each of
 * its matched labels.
 *
 * @param vector the instance to score
 * @return an array of length {@code numClasses}; entry {@code k} is the sum of
 *         the probabilities of all assignments whose matched labels include {@code k}
 */
public double[] predictClassProbs(Vector vector) {
    double[] perAssignment = predictAssignmentProbs(vector);
    double[] marginals = new double[numClasses];
    for (int idx = 0; idx < assignments.size(); idx++) {
        MultiLabel candidate = assignments.get(idx);
        double weight = perAssignment[idx];
        // Every label present in this assignment receives the assignment's full probability.
        for (Integer label : candidate.getMatchedLabels()) {
            marginals[label] += weight;
        }
    }
    return marginals;
}
use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
the class MLLogisticRegression method calClassProbs.
/**
 * Turns already-computed assignment probabilities into per-class probabilities.
 * For each assignment, its probability is added to the total of every label
 * the assignment matches.
 *
 * @param assignmentProbs probability of each assignment, indexed in the same
 *                        order as {@code assignments}
 * @return an array of length {@code numClasses} holding the summed probability per class
 */
double[] calClassProbs(double[] assignmentProbs) {
    double[] marginals = new double[numClasses];
    final int total = assignments.size();
    int idx = 0;
    while (idx < total) {
        MultiLabel candidate = assignments.get(idx);
        double weight = assignmentProbs[idx];
        for (Integer label : candidate.getMatchedLabels()) {
            marginals[label] = marginals[label] + weight;
        }
        idx++;
    }
    return marginals;
}
use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
the class NoiseOptimizer method updateAlpha.
/**
 * Recomputes {@code alphas[classIndex]} from the current {@code targets}.
 * <p>
 * Only combinations that match {@code classIndex} contribute. Over those
 * combinations, the denominator sums {@code targets[i][c]} for every data
 * point, and the numerator sums the same values restricted to data points
 * whose own multi-label also matches {@code classIndex}.
 *
 * @param classIndex index of the class whose alpha is updated
 */
private void updateAlpha(int classIndex) {
    MultiLabel[] observed = dataSet.getMultiLabels();
    double matchedMass = 0;
    double totalMass = 0;
    for (int comb = 0; comb < combinations.size(); comb++) {
        // Combinations that do not contain the class contribute nothing.
        if (!combinations.get(comb).matchClass(classIndex)) {
            continue;
        }
        for (int row = 0; row < dataSet.getNumDataPoints(); row++) {
            MultiLabel rowLabel = observed[row];
            totalMass += targets[row][comb];
            if (rowLabel.matchClass(classIndex)) {
                matchedMass += targets[row][comb];
            }
        }
    }
    // NOTE(review): the third argument is presumably the value used when the
    // denominator is zero -- confirm against SafeDivide.
    alphas[classIndex] = SafeDivide.divide(matchedMass, totalMass, 1);
}
use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
the class HammingLossTest method main.
/**
 * Prints the Hamming loss of two hand-built cases next to the value each is
 * expected to produce.
 * <p>
 * Test example comes from:
 * http://scikit-learn.org/stable/modules/model_evaluation.html#hamming-loss
 *
 * @param args unused
 */
public static void main(String[] args) {
    // Case 1: a single instance, scored with 4 classes.
    MultiLabel truth = new MultiLabel();
    MultiLabel guess = new MultiLabel();
    for (int label : new int[] {1, 2, 3, 4}) {
        truth.addLabel(label);
    }
    // Label 2 is deliberately added twice, mirroring the original call sequence.
    for (int label : new int[] {2, 2, 3, 4}) {
        guess.addLabel(label);
    }
    MultiLabel[] labels = {truth};
    MultiLabel[] predictions = {guess};
    System.out.println("Expected (value=0.25) - Output: " + HammingLoss.hammingLoss(labels, predictions, 4));

    // Case 2: two instances, scored with 2 classes.
    MultiLabel[] labels1 = {new MultiLabel(), new MultiLabel()};
    MultiLabel[] predictions1 = {new MultiLabel(), new MultiLabel()};
    labels1[0].addLabel(0);
    labels1[0].addLabel(1);
    labels1[1].addLabel(1);
    for (MultiLabel prediction : predictions1) {
        prediction.addLabel(0);
    }
    System.out.println("Expected (value=0.75) - Output: " + HammingLoss.hammingLoss(labels1, predictions1, 2));
}
use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
the class MultiLabelSynthesizer method flipOne.
/**
 * Builds a synthetic multi-label data set and then toggles one randomly
 * chosen label per data point.
 * <p>
 * Steps, in order: draw one weight vector per class, draw every feature
 * value, attach class {@code k} to data point {@code i} when the dot product
 * of the class weights and the row is non-negative, and finally flip (add if
 * absent, remove if present) one sampled class on each data point.
 *
 * @param numData    number of data points to generate
 * @param numFeature number of features per data point
 * @param numClass   number of classes
 * @return the generated data set with one label flipped per data point
 */
public static MultiLabelClfDataSet flipOne(int numData, int numFeature, int numClass) {
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numFeatures(numFeature)
            .numClasses(numClass)
            .numDataPoints(numData)
            .build();

    // One weight vector per class, each component drawn via Sampling.doubleUniform(-1, 1).
    Vector[] weights = new Vector[numClass];
    for (int cls = 0; cls < numClass; cls++) {
        Vector classWeights = new DenseVector(numFeature);
        for (int feat = 0; feat < numFeature; feat++) {
            classWeights.set(feat, Sampling.doubleUniform(-1, 1));
        }
        weights[cls] = classWeights;
    }

    // Feature values, drawn the same way, row by row.
    for (int row = 0; row < numData; row++) {
        for (int feat = 0; feat < numFeature; feat++) {
            dataSet.setFeatureValue(row, feat, Sampling.doubleUniform(-1, 1));
        }
    }

    // A class is attached to a row when its weight vector scores the row at zero or above.
    for (int row = 0; row < numData; row++) {
        for (int cls = 0; cls < numClass; cls++) {
            if (weights[cls].dot(dataSet.getRow(row)) >= 0) {
                dataSet.addLabel(row, cls);
            }
        }
    }

    // Toggle one sampled class per row.
    // NOTE(review): Sampling.intUniform(0, numClass - 1) presumably treats both bounds as inclusive -- confirm.
    for (int row = 0; row < numData; row++) {
        int flipped = Sampling.intUniform(0, numClass - 1);
        MultiLabel rowLabel = dataSet.getMultiLabels()[row];
        if (!rowLabel.matchClass(flipped)) {
            rowLabel.addLabel(flipped);
        } else {
            rowLabel.removeLabel(flipped);
        }
    }
    return dataSet;
}
Aggregations