use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
the class CRFLoss method calEmpiricalCountForFeature.
/**
 * Computes the empirical count associated with one feature-weight parameter.
 *
 * <p>The parameter is mapped to a (class, feature) pair through
 * {@code parameterToClass} and {@code parameterToFeature}. The result is the
 * sum, over the data points whose multi-label matches that class, of the
 * value of that feature. A feature index of {@code -1} is treated specially:
 * every matching data point contributes exactly 1 (presumably this is the
 * bias term -- confirm against where {@code parameterToFeature} is filled).
 *
 * @param parameterIndex index into {@code parameterToClass} / {@code parameterToFeature}
 * @return the accumulated empirical count for this parameter
 */
private double calEmpiricalCountForFeature(int parameterIndex) {
    final int label = parameterToClass[parameterIndex];
    final int feature = parameterToFeature[parameterIndex];
    final MultiLabel[] assignments = dataSet.getMultiLabels();
    double total = 0.0;

    if (feature == -1) {
        // Special case: count the matching data points themselves.
        final int numPoints = dataSet.getNumDataPoints();
        for (int row = 0; row < numPoints; row++) {
            if (assignments[row].matchClass(label)) {
                total += 1;
            }
        }
        return total;
    }

    // Regular feature: only the non-zero entries of the column can contribute.
    final Vector featureColumn = dataSet.getColumn(feature);
    for (Vector.Element entry : featureColumn.nonZeroes()) {
        final int row = entry.index();
        final double value = entry.get();
        if (assignments[row].matchClass(label)) {
            total += value;
        }
    }
    return total;
}
use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
the class CRFLoss method calEmpiricalCountForLabelPair.
/**
 * Counts the data points whose labels agree with the (l1, l2) configuration
 * encoded by a label-pair parameter.
 *
 * <p>The offset {@code parameterIndex - numWeightsForFeatures} selects the
 * label pair through {@code parameterToL1} / {@code parameterToL2}, and the
 * offset modulo 4 selects which of the four configurations is counted:
 * <ul>
 *   <li>0: neither l1 nor l2 is present</li>
 *   <li>1: l1 is present, l2 is absent</li>
 *   <li>2: l1 is absent, l2 is present</li>
 *   <li>3: both l1 and l2 are present</li>
 * </ul>
 *
 * @param parameterIndex global parameter index of a label-pair weight
 * @return the number of data points matching the selected configuration
 * @throws RuntimeException if the decoded configuration is not in 0..3
 */
private double calEmpiricalCountForLabelPair(int parameterIndex) {
    int start = parameterIndex - numWeightsForFeatures;
    int l1 = parameterToL1[start];
    int l2 = parameterToL2[start];
    int featureCase = start % 4;

    // The configuration does not depend on the data point, so decode it once
    // up front instead of re-running the switch for every data point.
    final boolean expectL1;
    final boolean expectL2;
    switch (featureCase) {
        case 0:
            expectL1 = false;
            expectL2 = false;
            break;
        case 1:
            expectL1 = true;
            expectL2 = false;
            break;
        case 2:
            expectL1 = false;
            expectL2 = true;
            break;
        case 3:
            expectL1 = true;
            expectL2 = true;
            break;
        default:
            // Only a negative offset could land here, and that would already
            // have failed the array lookups above.
            throw new RuntimeException("feature case :" + featureCase + " failed.");
    }

    MultiLabel[] multiLabels = dataSet.getMultiLabels();
    int numDataPoints = dataSet.getNumDataPoints();
    double empiricalCount = 0.0;
    for (int i = 0; i < numDataPoints; i++) {
        MultiLabel label = multiLabels[i];
        if (label.matchClass(l1) == expectL1 && label.matchClass(l2) == expectL2) {
            empiricalCount += 1.0;
        }
    }
    return empiricalCount;
}
use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
the class InstanceF1Predictor method showPredictBySupport.
/**
 * Runs the support-based F1 prediction for one instance and returns the
 * analysis of that prediction.
 *
 * <p>The CRF's support combinations and their predicted probabilities for
 * {@code vector} are handed to a {@link GeneralF1Predictor}; the resulting
 * prediction is then passed, together with the ground truth, to
 * {@code GeneralF1Predictor.showSupportPrediction}.
 *
 * @param vector feature vector of the instance
 * @param truth  ground-truth label set of the instance
 * @return the analysis produced by {@code GeneralF1Predictor.showSupportPrediction}
 */
public GeneralF1Predictor.Analysis showPredictBySupport(Vector vector, MultiLabel truth) {
    final List<MultiLabel> combinations = cmlcrf.getSupportCombinations();
    final double[] combinationProbs = cmlcrf.predictCombinationProbs(vector);
    final int numClasses = cmlcrf.getNumClasses();
    final MultiLabel predicted =
            new GeneralF1Predictor().predict(numClasses, combinations, combinationProbs);
    return GeneralF1Predictor.showSupportPrediction(
            combinations, combinationProbs, truth, predicted, numClasses);
}
use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
the class CMLCRFElasticNet method calEmpiricalCountForFeature.
/**
 * Returns the empirical count for the feature-weight parameter at
 * {@code parameterIndex}.
 *
 * <p>{@code parameterToClass} and {@code parameterToFeature} translate the
 * parameter into a class and a feature. For an ordinary feature the count is
 * the sum of that feature's non-zero values over the data points whose
 * multi-label matches the class. When the feature index is {@code -1}, the
 * count is simply the number of data points matching the class (likely the
 * bias weight -- verify against the parameter layout).
 *
 * @param parameterIndex position of the parameter in the feature-weight section
 * @return the empirical count for that parameter
 */
private double calEmpiricalCountForFeature(int parameterIndex) {
    int targetClass = parameterToClass[parameterIndex];
    int featureId = parameterToFeature[parameterIndex];
    double count = 0.0;

    if (featureId != -1) {
        Vector featureColumn = dataSet.getColumn(featureId);
        MultiLabel[] labels = dataSet.getMultiLabels();
        // Zero entries add nothing, so only the non-zero ones are visited.
        for (Vector.Element entry : featureColumn.nonZeroes()) {
            if (labels[entry.index()].matchClass(targetClass)) {
                count += entry.get();
            }
        }
        return count;
    }

    // featureId == -1: each data point carrying the class counts once.
    MultiLabel[] labels = dataSet.getMultiLabels();
    int numDataPoints = dataSet.getNumDataPoints();
    for (int row = 0; row < numDataPoints; row++) {
        if (labels[row].matchClass(targetClass)) {
            count += 1;
        }
    }
    return count;
}
use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
the class MultiLabelSynthesizer method crfArgmax.
/**
 * Builds a synthetic multi-label data set whose labels are the predictions
 * of a hand-configured {@code CMLCRF}.
 *
 * <p>The data set has 1000 data points, 4 classes and 10 features. Every
 * feature value is drawn with {@code Sampling.doubleUniform(-1, 1)}. The CRF
 * is created over the label combinations returned by
 * {@code Enumerator.enumerate(numClass)}, and only its weights on features 0
 * and 1 are set explicitly. Each data point is then labelled with the
 * prediction of a {@code SubsetAccPredictor} wrapping that CRF.
 *
 * @return the generated data set with features and labels filled in
 */
public static MultiLabelClfDataSet crfArgmax() {
    final int numData = 1000;
    final int numClass = 4;
    final int numFeature = 10;
    // Row k holds the weights of class k on feature 0 and feature 1.
    final double[][] classWeights = {
            {0, 10},
            {10, 10},
            {10, 0},
            {-10, -10}
    };

    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numFeatures(numFeature)
            .numClasses(numClass)
            .numDataPoints(numData)
            .build();
    List<MultiLabel> support = Enumerator.enumerate(numClass);
    CMLCRF cmlcrf = new CMLCRF(numClass, numFeature, support);
    for (int k = 0; k < classWeights.length; k++) {
        for (int f = 0; f < classWeights[k].length; f++) {
            cmlcrf.getWeights().getWeightsWithoutBiasForClass(k).set(f, classWeights[k][f]);
        }
    }

    // Fill in the features row by row, column by column.
    for (int row = 0; row < numData; row++) {
        for (int col = 0; col < numFeature; col++) {
            dataSet.setFeatureValue(row, col, Sampling.doubleUniform(-1, 1));
        }
    }

    // Label every data point with the CRF's prediction for its feature row.
    SubsetAccPredictor predictor = new SubsetAccPredictor(cmlcrf);
    for (int row = 0; row < numData; row++) {
        dataSet.setLabels(row, predictor.predict(dataSet.getRow(row)));
    }
    return dataSet;
}
Aggregations