Use of edu.neu.ccs.pyramid.regression.ClassScoreCalculation in project pyramid by cheng-li.
The class MLLogisticRegressionInspector, method analyzePrediction.
/**
 * Builds a prediction analysis for a single data point of a multi-label data set.
 * <p>
 * The analysis records the data point's internal and external ids, its true labels and
 * the model's predicted labels (each as internal indices and as external label names),
 * and one {@link ClassScoreCalculation} per requested class. Both probability fields are
 * set to {@code Double.NaN}.
 *
 * @param logisticRegression model used to predict and to break down class scores
 * @param dataSet            data set that contains the data point
 * @param dataPointIndex     internal index of the data point to analyze
 * @param classes            class indices for which a score calculation is produced
 * @param limit              forwarded to {@code decisionProcess} for every class
 * @return the populated analysis
 */
public static MultiLabelPredictionAnalysis analyzePrediction(MLLogisticRegression logisticRegression, MultiLabelClfDataSet dataSet, int dataPointIndex, List<Integer> classes, int limit) {
    MultiLabelPredictionAnalysis analysis = new MultiLabelPredictionAnalysis();
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    IdTranslator idTranslator = dataSet.getIdTranslator();

    // Identity of the data point: internal index plus its external id.
    analysis.setInternalId(dataPointIndex);
    analysis.setId(idTranslator.toExtId(dataPointIndex));

    // Ground-truth labels, first as internal indices, then as external names.
    analysis.setInternalLabels(dataSet.getMultiLabels()[dataPointIndex].getMatchedLabelsOrdered());
    List<String> trueLabelNames = new ArrayList<>();
    for (Integer labelIndex : dataSet.getMultiLabels()[dataPointIndex].getMatchedLabelsOrdered()) {
        trueLabelNames.add(labelTranslator.toExtLabel(labelIndex));
    }
    analysis.setLabels(trueLabelNames);
    analysis.setProbForTrueLabels(Double.NaN);

    // Model prediction, again as internal indices and as external names.
    MultiLabel predicted = logisticRegression.predict(dataSet.getRow(dataPointIndex));
    List<Integer> predictedIndices = predicted.getMatchedLabelsOrdered();
    analysis.setInternalPrediction(predictedIndices);
    List<String> predictedNames = new ArrayList<>();
    for (Integer labelIndex : predictedIndices) {
        predictedNames.add(labelTranslator.toExtLabel(labelIndex));
    }
    analysis.setPrediction(predictedNames);
    analysis.setProbForPredictedLabels(Double.NaN);

    // One score breakdown per requested class, in the order the classes were given.
    List<ClassScoreCalculation> perClassCalculations = classes.stream()
            .mapToInt(Integer::intValue)
            .mapToObj(k -> decisionProcess(logisticRegression, labelTranslator, dataSet.getRow(dataPointIndex), k, limit))
            .collect(Collectors.toCollection(ArrayList::new));
    analysis.setClassScoreCalculations(perClassCalculations);

    return analysis;
}
Use of edu.neu.ccs.pyramid.regression.ClassScoreCalculation in project pyramid by cheng-li.
The class MLLogisticRegressionInspector, method decisionProcess.
/**
 * Breaks down the score the model assigns to one class for one feature vector.
 * <p>
 * The returned calculation carries the class index, its external label and the score from
 * {@code predictClassScore}. Its rules are the class bias as a {@link ConstantRule},
 * followed by at most {@code limit} {@link LinearRule}s, one per feature, ordered by
 * descending absolute score (weight times feature value).
 *
 * @param logisticRegression model whose weights and feature list are inspected
 * @param labelTranslator    translates {@code classIndex} to its external label
 * @param vector             feature vector being scored
 * @param classIndex         class whose score is broken down
 * @param limit              maximum number of per-feature rules to keep
 * @return the score calculation for {@code classIndex}
 */
public static ClassScoreCalculation decisionProcess(MLLogisticRegression logisticRegression, LabelTranslator labelTranslator, Vector vector, int classIndex, int limit) {
    ClassScoreCalculation calculation = new ClassScoreCalculation(
            classIndex,
            labelTranslator.toExtLabel(classIndex),
            logisticRegression.predictClassScore(vector, classIndex));

    // The bias always comes first and does not count against the limit.
    Rule biasRule = new ConstantRule(logisticRegression.getWeights().getBiasForClass(classIndex));
    calculation.addRule(biasRule);

    // Build one candidate rule per feature; every feature index is visited.
    List<LinearRule> candidates = new ArrayList<>();
    int numFeatures = 0;
    while (numFeatures < logisticRegression.getNumFeatures()) {
        int featureIndex = numFeatures++;
        Feature feature = logisticRegression.getFeatureList().get(featureIndex);
        double weight = logisticRegression.getWeights().getWeightsWithoutBiasForClass(classIndex).get(featureIndex);
        double value = vector.get(featureIndex);

        LinearRule candidate = new LinearRule();
        candidate.setFeature(feature);
        candidate.setFeatureValue(value);
        candidate.setScore(weight * value);
        candidate.setWeight(weight);
        candidates.add(candidate);
    }

    // Keep the rules with the largest absolute contribution, largest first.
    Comparator<LinearRule> byAbsoluteScore =
            Comparator.comparingDouble((LinearRule candidate) -> Math.abs(candidate.getScore()));
    candidates.stream()
            .sorted(byAbsoluteScore.reversed())
            .limit(limit)
            .forEachOrdered(calculation::addRule);

    return calculation;
}
Use of edu.neu.ccs.pyramid.regression.ClassScoreCalculation in project pyramid by cheng-li.
The class LKTreeBoostTest, method spam_load.
/**
 * Loads a serialized ensemble and the spam test set, then prints evaluation output.
 * <p>
 * Prints data set meta information, accuracy, AUC, the confusion matrix, top features for
 * class 0, per-class and macro-averaged measures, and for every class the number of data
 * points with that label next to the sum of predicted probabilities for it. Finally writes
 * the score calculation and the prediction analysis of the first data point as JSON files
 * under {@code TMP}.
 *
 * @throws Exception if loading, evaluation or writing fails
 */
static void spam_load() throws Exception {
    System.out.println("loading ensemble");
    File ensembleFile = new File(TMP, "/LKTreeBoostTest/ensemble.ser");
    LKBoost lkBoost = LKBoost.deserialize(ensembleFile);
    File testDataFile = new File(DATASETS, "/spam/trec_data/test.trec");
    ClfDataSet dataSet = TRECFormat.loadClfDataSet(testDataFile, DataSetType.CLF_DENSE, true);

    System.out.println("test data:");
    System.out.println(dataSet.getMetaInfo());
    System.out.println(dataSet.getLabelTranslator());

    // Overall quality measures.
    double accuracy = Accuracy.accuracy(lkBoost, dataSet);
    System.out.println(accuracy);
    System.out.println("auc = " + AUC.auc(lkBoost, dataSet));

    ConfusionMatrix confusionMatrix = new ConfusionMatrix(lkBoost, dataSet);
    System.out.println("confusion matrix:");
    System.out.println(confusionMatrix.printWithExtLabels());

    System.out.println("top featureList for class 0");
    System.out.println(LKBInspector.topFeatures(lkBoost, 0));

    System.out.println(new PerClassMeasures(confusionMatrix, 0));
    System.out.println(new PerClassMeasures(confusionMatrix, 1));
    System.out.println("macor-averaged:");
    System.out.println(new MacroAveragedMeasures(confusionMatrix));

    // For each class, compare the label count with the summed predicted probability.
    int[] trueLabels = dataSet.getLabels();
    List<double[]> probabilities = lkBoost.predictClassProbs(dataSet);
    for (int classIndex = 0; classIndex < dataSet.getNumClasses(); classIndex++) {
        int matchCount = 0;
        double probabilitySum = 0;
        for (int row = 0; row < dataSet.getNumDataPoints(); row++) {
            matchCount += (trueLabels[row] == classIndex) ? 1 : 0;
            probabilitySum += probabilities.get(row)[classIndex];
        }
        System.out.println("for class " + classIndex);
        System.out.println("number of matches =" + matchCount);
        System.out.println("sum of probs = " + probabilitySum);
    }

    // Dump the explanation objects for the first data point as JSON.
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    ClassScoreCalculation scoreCalculation = LKBInspector.decisionProcess(lkBoost, labelTranslator, dataSet.getRow(0), 0, 10);
    new ObjectMapper().writeValue(new File(TMP, "score_calculation.json"), scoreCalculation);

    PredictionAnalysis analysis = LKBInspector.analyzePrediction(lkBoost, dataSet, 0, 10);
    new ObjectMapper().writeValue(new File(TMP, "prediction_analysis.json"), analysis);
}
Aggregations