Usage of edu.neu.ccs.pyramid.dataset.LabelTranslator in project pyramid (by cheng-li).
Class LogisticRegressionInspector, method analyzePrediction.
/**
 * Builds a {@link PredictionAnalysis} for one row of {@code dataSet} under a logistic regression model.
 *
 * <p>The analysis records:
 * <ul>
 *   <li>the row's internal index and external id;</li>
 *   <li>its true label, both as an internal index and as an external name;</li>
 *   <li>the predicted class;</li>
 *   <li>the probability of every class;</li>
 *   <li>one {@link ClassScoreCalculation} per class, produced by {@code decisionProcess}.</li>
 * </ul>
 *
 * @param logisticRegression model used to predict and score the row
 * @param dataSet            data set containing the row; its id/label translators supply external names
 * @param dataPointIndex     internal index of the row to analyze
 * @param limit              passed unchanged to {@code decisionProcess} for every class
 *                           (presumably caps how many terms are reported — confirm against decisionProcess)
 * @return the populated analysis
 */
public static PredictionAnalysis analyzePrediction(LogisticRegression logisticRegression, ClfDataSet dataSet, int dataPointIndex, int limit) {
    PredictionAnalysis analysis = new PredictionAnalysis();
    IdTranslator idMap = dataSet.getIdTranslator();
    LabelTranslator labelMap = dataSet.getLabelTranslator();

    // identity and ground truth of the analyzed row
    int trueLabel = dataSet.getLabels()[dataPointIndex];
    analysis.setInternalId(dataPointIndex)
            .setId(idMap.toExtId(dataPointIndex))
            .setInternalLabel(trueLabel)
            .setLabel(labelMap.toExtLabel(trueLabel));

    // model decision, stored both as an internal index and as an external label
    int predictedClass = logisticRegression.predict(dataSet.getRow(dataPointIndex));
    analysis.setInternalPrediction(predictedClass);
    analysis.setPrediction(labelMap.toExtLabel(predictedClass));

    // one probability entry and one score breakdown per class
    double[] classProbs = logisticRegression.predictClassProbs(dataSet.getRow(dataPointIndex));
    int numClasses = classProbs.length;
    List<ClassProbability> probabilityEntries = new ArrayList<>(numClasses);
    List<ClassScoreCalculation> scoreBreakdowns = new ArrayList<>(numClasses);
    for (int c = 0; c < numClasses; c++) {
        probabilityEntries.add(new ClassProbability(c, labelMap.toExtLabel(c), classProbs[c]));
        scoreBreakdowns.add(decisionProcess(logisticRegression, labelMap, dataSet.getRow(dataPointIndex), c, limit));
    }
    analysis.setClassProbabilities(probabilityEntries);
    analysis.setClassScoreCalculations(scoreBreakdowns);
    return analysis;
}
Usage of edu.neu.ccs.pyramid.dataset.LabelTranslator in project pyramid (by cheng-li).
Class LogisticRegressionInspector, method topFeatures.
/**
 * Ranks the features of one class by their average contribution on {@code dataSet} and keeps the
 * {@code limit} features with the largest absolute contribution.
 *
 * <p>Feature {@code i} has class weight {@code w}. Its utility is the mean of {@code w * x} over
 * the rows where its value {@code x} is nonzero. A feature that is zero in every row gets utility 0.
 * Because the stream is ordered and {@code sorted} is stable, ties keep feature-index order.
 *
 * @param logisticRegression model supplying the feature list, per-class weights and label names
 * @param dataSet            data whose columns are scanned; assumed to share the model's feature indexing
 * @param classIndex         internal index of the class to explain
 * @param limit              maximum number of features returned (a negative value makes
 *                           {@code Stream.limit} throw {@link IllegalArgumentException})
 * @return the selected features, their utilities, and the class index and external name
 */
public static TopFeatures topFeatures(LogisticRegression logisticRegression, DataSet dataSet, int classIndex, int limit) {
    FeatureList features = logisticRegression.getFeatureList();
    Vector classWeights = logisticRegression.getWeights().getWeightsWithoutBiasForClass(classIndex);
    // strongly negative contributions matter as much as strongly positive ones: rank by magnitude, largest first
    Comparator<FeatureUtility> byMagnitudeDesc =
            Comparator.comparingDouble((FeatureUtility fu) -> Math.abs(fu.getUtility())).reversed();

    List<FeatureUtility> ranked = IntStream.range(0, classWeights.size())
            .parallel()
            .mapToObj(i -> {
                Vector column = dataSet.getColumn(i);
                int nonZeros = column.getNumNonZeroElements();
                double utility = 0;
                if (nonZeros != 0) {
                    double weight = classWeights.get(i);
                    // accumulate weight * value per element; factoring the weight out of the sum would change rounding
                    for (Vector.Element entry : column.nonZeroes()) {
                        utility += weight * entry.get();
                    }
                    utility /= nonZeros;
                }
                return new FeatureUtility(features.get(i)).setUtility(utility);
            })
            .sorted(byMagnitudeDesc)
            .limit(limit)
            .collect(Collectors.toList());

    TopFeatures result = new TopFeatures();
    result.setTopFeatures(ranked.stream().map(FeatureUtility::getFeature).collect(Collectors.toList()));
    result.setUtilities(ranked.stream().map(FeatureUtility::getUtility).collect(Collectors.toList()));
    result.setClassIndex(classIndex);
    result.setClassName(logisticRegression.getLabelTranslator().toExtLabel(classIndex));
    return result;
}
Usage of edu.neu.ccs.pyramid.dataset.LabelTranslator in project pyramid (by cheng-li).
Class CRFInspector, method simplePredictionAnalysis.
/**
 * Renders a tab-separated report of the CRF prediction for one row of {@code dataSet}.
 *
 * <p>First, one {@code single} line is emitted for every class that meets any of these conditions:
 * <ul>
 *   <li>its probability (from {@code crf.calClassProbs}) is at least {@code classProbThreshold};</li>
 *   <li>it is a true label;</li>
 *   <li>it is predicted.</li>
 * </ul>
 * These lines are ordered by descending probability, with ties kept in class-index order. Their
 * columns are: id, external label, {@code single}, probability, 1/0 whether the class is a true label.
 *
 * <p>Last, one {@code set} line is emitted. Its columns are:
 * <ul>
 *   <li>the id;</li>
 *   <li>the comma-separated predicted labels;</li>
 *   <li>{@code set};</li>
 *   <li>the probability the CRF assigns to that exact label combination (0 if the combination is not
 *       in {@code crf.getSupportCombinations()});</li>
 *   <li>1/0 whether the predicted set equals the true set.</li>
 * </ul>
 *
 * @param crf                model providing combination and per-class probabilities
 * @param pluginPredictor    predictor that chooses the reported label set
 * @param dataSet            data set containing the row and its id/label translators
 * @param dataPointIndex     internal index of the row to report
 * @param classProbThreshold minimum probability for a class to be listed on its own merit
 * @return the report, each line terminated by {@code "\n"}
 */
public static String simplePredictionAnalysis(CMLCRF crf, PluginPredictor<CMLCRF> pluginPredictor, MultiLabelClfDataSet dataSet, int dataPointIndex, double classProbThreshold) {
    MultiLabel trueLabels = dataSet.getMultiLabels()[dataPointIndex];
    String id = dataSet.getIdTranslator().toExtId(dataPointIndex);
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    double[] combProbs = crf.predictCombinationProbs(dataSet.getRow(dataPointIndex));
    double[] classProbs = crf.calClassProbs(combProbs);
    MultiLabel predicted = pluginPredictor.predict(dataSet.getRow(dataPointIndex));

    // keep every class that is likely, true, or predicted, paired with its probability
    List<Pair<Integer, Double>> ranked = new ArrayList<>();
    for (int k = 0; k < crf.getNumClasses(); k++) {
        boolean relevant = classProbs[k] >= classProbThreshold
                || trueLabels.matchClass(k)
                || predicted.matchClass(k);
        if (relevant) {
            ranked.add(new Pair<Integer, Double>(k, classProbs[k]));
        }
    }
    // most probable first; List.sort is stable, so equal probabilities keep class-index order
    Comparator<Pair<Integer, Double>> byProbability = Comparator.comparing(pair -> pair.getSecond());
    ranked.sort(byProbability.reversed());

    StringBuilder report = new StringBuilder();
    for (Pair<Integer, Double> entry : ranked) {
        int label = entry.getFirst();
        double prob = entry.getSecond();
        int match = trueLabels.matchClass(label) ? 1 : 0;
        report.append(id).append("\t")
                .append(labelTranslator.toExtLabel(label)).append("\t")
                .append("single").append("\t")
                .append(prob).append("\t")
                .append(match).append("\n");
    }

    // probability of the exact predicted combination; stays 0 when it is not a support combination
    double setProbability = 0;
    List<MultiLabel> support = crf.getSupportCombinations();
    int position = 0;
    for (MultiLabel candidate : support) {
        if (candidate.equals(predicted)) {
            setProbability = combProbs[position];
            break;
        }
        position++;
    }

    report.append(id).append("\t");
    String separator = "";
    for (Integer predictedLabel : predicted.getMatchedLabelsOrdered()) {
        report.append(separator).append(labelTranslator.toExtLabel(predictedLabel));
        separator = ",";
    }
    report.append("\t");
    int setMatch = predicted.equals(trueLabels) ? 1 : 0;
    report.append("set").append("\t").append(setProbability).append("\t").append(setMatch).append("\n");
    return report.toString();
}
Usage of edu.neu.ccs.pyramid.dataset.LabelTranslator in project pyramid (by cheng-li).
Class LogisticRegressionInspector, method topFeatures.
/**
 * Returns the {@code limit} features of class {@code classIndex} whose raw model weights have the
 * largest absolute value, together with those weights.
 *
 * <p>No data is consulted: a feature's utility is simply its weight for the class (bias excluded).
 * Because the stream is ordered and {@code sorted} is stable, ties keep feature-index order.
 *
 * <p>TODO: if features are on different scales, their weights are not comparable.
 *
 * @param logisticRegression model supplying the feature list, per-class weights and label names
 * @param classIndex         internal index of the class to explain
 * @param limit              maximum number of features returned (a negative value makes
 *                           {@code Stream.limit} throw {@link IllegalArgumentException})
 * @return the selected features, their weights, and the class index and external name
 */
public static TopFeatures topFeatures(LogisticRegression logisticRegression, int classIndex, int limit) {
    FeatureList features = logisticRegression.getFeatureList();
    Vector classWeights = logisticRegression.getWeights().getWeightsWithoutBiasForClass(classIndex);
    // rank by magnitude, largest first, so strongly negative weights are reported too
    Comparator<FeatureUtility> byMagnitudeDesc =
            Comparator.comparingDouble((FeatureUtility fu) -> Math.abs(fu.getUtility())).reversed();

    List<FeatureUtility> ranked = IntStream.range(0, classWeights.size())
            .mapToObj(i -> new FeatureUtility(features.get(i)).setUtility(classWeights.get(i)))
            .sorted(byMagnitudeDesc)
            .limit(limit)
            .collect(Collectors.toList());

    TopFeatures result = new TopFeatures();
    result.setTopFeatures(ranked.stream().map(FeatureUtility::getFeature).collect(Collectors.toList()));
    result.setUtilities(ranked.stream().map(FeatureUtility::getUtility).collect(Collectors.toList()));
    result.setClassIndex(classIndex);
    result.setClassName(logisticRegression.getLabelTranslator().toExtLabel(classIndex));
    return result;
}
Usage of edu.neu.ccs.pyramid.dataset.LabelTranslator in project pyramid (by cheng-li).
Class LKBInspector, method analyzePrediction.
/**
 * Builds a {@link PredictionAnalysis} for one row of {@code dataSet} under an {@link LKBoost} model.
 *
 * <p>The analysis records:
 * <ul>
 *   <li>the row's internal index and external id;</li>
 *   <li>its true label, both as an internal index and as an external name;</li>
 *   <li>the predicted class;</li>
 *   <li>the probability of every class;</li>
 *   <li>one {@link ClassScoreCalculation} per class, produced by {@code decisionProcess}.</li>
 * </ul>
 *
 * <p>TODO: speed up.
 *
 * @param boosting       model used to predict and score the row
 * @param dataSet        data set containing the row; its id/label translators supply external names
 * @param dataPointIndex internal index of the row to analyze
 * @param limit          passed unchanged to {@code decisionProcess} for every class
 *                       (presumably caps how many terms are reported — confirm against decisionProcess)
 * @return the populated analysis
 */
public static PredictionAnalysis analyzePrediction(LKBoost boosting, ClfDataSet dataSet, int dataPointIndex, int limit) {
    PredictionAnalysis analysis = new PredictionAnalysis();
    IdTranslator idMap = dataSet.getIdTranslator();
    LabelTranslator labelMap = dataSet.getLabelTranslator();

    // identity and ground truth of the analyzed row
    int trueLabel = dataSet.getLabels()[dataPointIndex];
    analysis.setInternalId(dataPointIndex)
            .setId(idMap.toExtId(dataPointIndex))
            .setInternalLabel(trueLabel)
            .setLabel(labelMap.toExtLabel(trueLabel));

    // model decision, stored both as an internal index and as an external label
    int predictedClass = boosting.predict(dataSet.getRow(dataPointIndex));
    analysis.setInternalPrediction(predictedClass);
    analysis.setPrediction(labelMap.toExtLabel(predictedClass));

    // one probability entry and one score breakdown per class
    double[] classProbs = boosting.predictClassProbs(dataSet.getRow(dataPointIndex));
    int numClasses = classProbs.length;
    List<ClassProbability> probabilityEntries = new ArrayList<>(numClasses);
    List<ClassScoreCalculation> scoreBreakdowns = new ArrayList<>(numClasses);
    for (int c = 0; c < numClasses; c++) {
        probabilityEntries.add(new ClassProbability(c, labelMap.toExtLabel(c), classProbs[c]));
        scoreBreakdowns.add(decisionProcess(boosting, labelMap, dataSet.getRow(dataPointIndex), c, limit));
    }
    analysis.setClassProbabilities(probabilityEntries);
    analysis.setClassScoreCalculations(scoreBreakdowns);
    return analysis;
}
Aggregations