Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
Class CRFF1Loss, method calEmpiricalCountForFeature.
/**
 * Computes the empirical count for a single parameter.
 *
 * <p>The parameter is mapped to a class and a feature through
 * {@code parameterToClass} and {@code parameterToFeature}. The count is the sum,
 * over all data points whose label set matches that class, of the feature's
 * value. A feature index of {@code -1} is treated as a constant feature of
 * value 1, so the count is simply the number of matching data points.
 *
 * @param parameterIndex index into {@code parameterToClass} / {@code parameterToFeature}
 * @return the accumulated empirical count
 */
private double calEmpiricalCountForFeature(int parameterIndex) {
    final int targetClass = parameterToClass[parameterIndex];
    final int feature = parameterToFeature[parameterIndex];
    double total = 0.0;

    if (feature != -1) {
        // Regular feature: only the non-zero entries of its column can contribute.
        Vector featureColumn = dataSet.getColumn(feature);
        MultiLabel[] labels = dataSet.getMultiLabels();
        for (Vector.Element entry : featureColumn.nonZeroes()) {
            if (labels[entry.index()].matchClass(targetClass)) {
                total += entry.get();
            }
        }
        return total;
    }

    // Feature index -1: every data point matching the class contributes exactly 1.
    final int numPoints = dataSet.getNumDataPoints();
    if (numPoints > 0) {
        MultiLabel[] labels = dataSet.getMultiLabels();
        for (int point = 0; point < numPoints; point++) {
            if (labels[point].matchClass(targetClass)) {
                total += 1;
            }
        }
    }
    return total;
}
Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
Class CRFInspector, method simplePredictionAnalysis.
/**
 * Builds a tab-separated text report of the prediction for one data point.
 *
 * <p>The report has one "single" line per reported class, ordered by descending
 * class probability, in the form
 * {@code id \t label \t single \t probability \t match}, where {@code match} is 1
 * if the class is among the true labels and 0 otherwise. A class is reported
 * when its probability reaches {@code classProbThreshold}, or it is a true
 * label, or it is part of the predicted label set.
 *
 * <p>A final "set" line follows, in the form
 * {@code id \t label1,label2,... \t set \t probability \t match}. It describes the
 * predicted label set; its probability is taken from the combination
 * probabilities when the prediction equals one of the CRF's support
 * combinations, and is 0 otherwise.
 *
 * @param crf                model providing combination and class probabilities
 * @param pluginPredictor    predictor producing the predicted label set
 * @param dataSet            data set holding the data point, its labels and translators
 * @param dataPointIndex     index of the data point to analyse
 * @param classProbThreshold minimum class probability for a class to be reported
 * @return the report text, every line terminated by a newline
 */
public static String simplePredictionAnalysis(CMLCRF crf, PluginPredictor<CMLCRF> pluginPredictor, MultiLabelClfDataSet dataSet, int dataPointIndex, double classProbThreshold) {
    final String tab = "\t";
    StringBuilder report = new StringBuilder();
    MultiLabel trueLabels = dataSet.getMultiLabels()[dataPointIndex];
    String id = dataSet.getIdTranslator().toExtId(dataPointIndex);
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    double[] combProbs = crf.predictCombinationProbs(dataSet.getRow(dataPointIndex));
    double[] classProbs = crf.calClassProbs(combProbs);
    MultiLabel predicted = pluginPredictor.predict(dataSet.getRow(dataPointIndex));

    // Collect every class worth reporting together with its probability.
    List<Pair<Integer, Double>> ranked = new ArrayList<Pair<Integer, Double>>();
    final int numClasses = crf.getNumClasses();
    for (int k = 0; k < numClasses; k++) {
        boolean confident = classProbs[k] >= classProbThreshold;
        if (confident || trueLabels.matchClass(k) || predicted.matchClass(k)) {
            ranked.add(new Pair<Integer, Double>(k, classProbs[k]));
        }
    }

    // Most probable classes first; the sort is stable, so ties keep class order.
    Comparator<Pair<Integer, Double>> byProbability = Comparator.comparing(entry -> entry.getSecond());
    ranked.sort(byProbability.reversed());

    for (Pair<Integer, Double> entry : ranked) {
        int label = entry.getFirst();
        double prob = entry.getSecond();
        int match = trueLabels.matchClass(label) ? 1 : 0;
        report.append(id).append(tab)
                .append(labelTranslator.toExtLabel(label)).append(tab)
                .append("single").append(tab)
                .append(prob).append(tab)
                .append(match).append("\n");
    }

    // Probability of the predicted set: looked up among the support combinations.
    double setProbability = 0;
    List<MultiLabel> support = crf.getSupportCombinations();
    for (int c = 0; c < support.size(); c++) {
        if (support.get(c).equals(predicted)) {
            setProbability = combProbs[c];
            break;
        }
    }

    List<Integer> predictedList = predicted.getMatchedLabelsOrdered();
    report.append(id).append(tab);
    String separator = "";
    for (Integer predictedLabel : predictedList) {
        report.append(separator).append(labelTranslator.toExtLabel(predictedLabel));
        separator = ",";
    }
    report.append(tab);

    int setMatch = predicted.equals(trueLabels) ? 1 : 0;
    report.append("set").append(tab)
            .append(setProbability).append(tab)
            .append(setMatch).append("\n");
    return report.toString();
}
Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
Class NoiseOptimizer, method updateTransformProb.
/**
 * Recomputes {@code transformProbs[dataPoint][comIndex]}.
 *
 * <p>If the data point's label set is not a subset of the combination at
 * {@code comIndex}, the entry is set to 0. Otherwise it is the product, over
 * every label {@code l} in the combination, of {@code alphas[l]} when the data
 * point also carries {@code l} and of {@code 1 - alphas[l]} when it does not.
 *
 * @param dataPoint index of the data point
 * @param comIndex  index of the label combination in {@code combinations}
 */
private void updateTransformProb(int dataPoint, int comIndex) {
    MultiLabel observed = dataSet.getMultiLabels()[dataPoint];
    MultiLabel combination = combinations.get(comIndex);

    if (!observed.isSubsetOf(combination)) {
        transformProbs[dataPoint][comIndex] = 0;
        return;
    }

    double product = 1;
    for (int label : combination.getMatchedLabels()) {
        product *= observed.matchClass(label) ? alphas[label] : (1 - alphas[label]);
    }
    transformProbs[dataPoint][comIndex] = product;
}
Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
Class InstanceF1Predictor, method predict.
/**
 * Predicts a label set for the given feature vector.
 *
 * <p>Obtains the CRF's support combinations and their probabilities for
 * {@code vector}, then delegates the choice of the label set to a fresh
 * {@link GeneralF1Predictor}.
 *
 * @param vector feature vector of the instance to predict
 * @return the label set returned by {@code GeneralF1Predictor.predict}
 */
@Override
public MultiLabel predict(Vector vector) {
    List<MultiLabel> combinations = cmlcrf.getSupportCombinations();
    double[] combinationProbs = cmlcrf.predictCombinationProbs(vector);
    return new GeneralF1Predictor().predict(numClasses, combinations, combinationProbs);
}
Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
Class CMLCRFElasticNet, method expandData.
/**
 * Builds the expanded data set for the support combination at index {@code l}.
 *
 * <p>The result is a {@link SequentialSparseDataSet} created with
 * {@code numData} and {@code numParameters} as its dimensions. For every data
 * point {@code i}:
 * <ul>
 *   <li>for each label {@code y} matched by the combination, column
 *       {@code (numFeature + 1) * y} is set to 1.0 and the non-zero features of
 *       row {@code i} of {@code dataSet} are copied into the columns that
 *       follow it, at offset {@code index + 1};</li>
 *   <li>for each entry {@code y} of {@code combinationToLabelPair.get(l)},
 *       column {@code numWeightsForFeatures + y} is set to 1.0.</li>
 * </ul>
 *
 * @param l index of the combination in {@code supportedCombinations}
 * @return the newly built expanded data set
 */
private DataSet expandData(int l) {
    SequentialSparseDataSet newData = new SequentialSparseDataSet(numData, numParameters, false);
    MultiLabel label = supportedCombinations.get(l);
    List<Integer> labelPairForL = combinationToLabelPair.get(l);
    // Width of the column block owned by one label: one leading slot plus the features.
    final int blockSize = numFeature + 1;
    // TODO: parallelism
    for (int i = 0; i < numData; i++) {
        // The row does not depend on the label, so fetch it once per data point
        // instead of once per matched label.
        Vector row = dataSet.getRow(i);
        // add feature-label feature
        for (int y : label.getMatchedLabels()) {
            final int blockStart = blockSize * y;
            // set bias as 1
            newData.setFeatureValue(i, blockStart, 1.0);
            for (Vector.Element element : row.nonZeroes()) {
                newData.setFeatureValue(i, blockStart + element.index() + 1, element.get());
            }
        }
        // add label-pair features
        for (int y : labelPairForL) {
            newData.setFeatureValue(i, numWeightsForFeatures + y, 1.0);
        }
    }
    return newData;
}
Aggregations