Use of edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor in project pyramid by cheng-li.
Class App2, method report.
/**
 * Generates all reports for one data set using the serialized "model_app3"
 * model stored in the configured output folder.
 *
 * <p>The reports are written to {@code <output.folder>/reports_app3/<dataName>_reports},
 * which is created if necessary and emptied first. Depending on the switches
 * below, the folder receives {@code report.csv}, {@code report_N.json} files,
 * {@code data_info.json}, {@code model_config.json}, a copy of the data set's
 * {@code data_config.json} (if it exists), {@code performance.json} and
 * {@code individual_performance.json}.
 *
 * @param config   configuration; read keys include {@code output.folder},
 *                 {@code predict.target}, {@code input.folder} and the
 *                 {@code report.*} settings
 * @param dataName name of the data set to load and report on
 * @param logger   receives progress and performance messages
 * @throws IllegalArgumentException if {@code predict.target} is not a known
 *                 measure, or {@code report.numDocsPerFile} is not positive
 * @throws Exception if the model or data cannot be loaded or a report file
 *                 cannot be written
 */
static void report(Config config, String dataName, Logger logger) throws Exception {
    logger.info("generating reports for data set " + dataName);
    String output = config.getString("output.folder");
    String modelName = "model_app3";
    File analysisFolder = new File(new File(output, "reports_app3"), dataName + "_reports");
    analysisFolder.mkdirs();
    FileUtils.cleanDirectory(analysisFolder);
    IMLGradientBoosting boosting = IMLGradientBoosting.deserialize(new File(output, modelName));
    String predictTarget = config.getString("predict.target");
    // Every non-throwing branch assigns the predictor exactly once, so it can be
    // declared final directly and captured by the lambdas below.
    final PluginPredictor<IMLGradientBoosting> pluginPredictor;
    switch (predictTarget) {
        case "subsetAccuracy":
            pluginPredictor = new SubsetAccPredictor(boosting);
            break;
        case "hammingLoss":
            pluginPredictor = new HammingPredictor(boosting);
            break;
        case "instanceFMeasure":
            pluginPredictor = new InstanceF1Predictor(boosting);
            break;
        case "macroFMeasure":
            TunedMarginalClassifier tunedMarginalClassifier = (TunedMarginalClassifier) Serialization.deserialize(new File(output, "predictor_macro_f"));
            pluginPredictor = new MacroF1Predictor(boosting, tunedMarginalClassifier);
            break;
        default:
            throw new IllegalArgumentException("unknown prediction target measure " + predictTarget);
    }
    MultiLabelClfDataSet dataSet = loadData(config, dataName);
    MLMeasures mlMeasures = new MLMeasures(pluginPredictor, dataSet);
    mlMeasures.getMacroAverage().setLabelTranslator(boosting.getLabelTranslator());
    logger.info("performance on dataset " + dataName);
    logger.info(mlMeasures.toString());
    // A single mapper serves every JSON report written below.
    ObjectMapper objectMapper = new ObjectMapper();

    boolean simpleCSV = true;
    if (simpleCSV) {
        logger.info("start generating simple CSV report");
        double probThreshold = config.getDouble("report.classProbThreshold");
        File csv = new File(analysisFolder, "report.csv");
        // Rows are computed in parallel; the collected list keeps data point order.
        List<String> strs = IntStream.range(0, dataSet.getNumDataPoints()).parallel().mapToObj(i -> IMLGBInspector.simplePredictionAnalysis(boosting, pluginPredictor, dataSet, i, probThreshold)).collect(Collectors.toList());
        try (BufferedWriter bw = new BufferedWriter(new FileWriter(csv))) {
            for (String str : strs) {
                bw.write(str);
            }
        }
        logger.info("finish generating simple CSV report");
    }

    boolean rulesToJson = config.getBoolean("report.showPredictionDetail");
    if (rulesToJson) {
        logger.info("start writing rules to json");
        int ruleLimit = config.getInt("report.rule.limit");
        int numDocsPerFile = config.getInt("report.numDocsPerFile");
        if (numDocsPerFile <= 0) {
            // A non-positive value would make the file count below meaningless.
            throw new IllegalArgumentException("report.numDocsPerFile must be positive, but was " + numDocsPerFile);
        }
        int numFiles = (int) Math.ceil((double) dataSet.getNumDataPoints() / numDocsPerFile);
        double probThreshold = config.getDouble("report.classProbThreshold");
        int labelSetLimit = config.getInt("report.labelSetLimit");
        // Plain loop so that a failed write propagates like every other write in this method.
        for (int i = 0; i < numFiles; i++) {
            int start = i * numDocsPerFile;
            int end = Math.min(start + numDocsPerFile, dataSet.getNumDataPoints());
            List<MultiLabelPredictionAnalysis> partition = IntStream.range(start, end).parallel().mapToObj(a -> IMLGBInspector.analyzePrediction(boosting, pluginPredictor, dataSet, a, ruleLimit, labelSetLimit, probThreshold)).collect(Collectors.toList());
            String file = "report_" + (i + 1) + ".json";
            objectMapper.writeValue(new File(analysisFolder, file), partition);
            logger.info("progress = " + Progress.percentage(i + 1, numFiles));
        }
        logger.info("finish writing rules to json");
    }

    boolean dataInfoToJson = true;
    if (dataInfoToJson) {
        logger.info("start writing data info to json");
        Set<String> modelLabels = IntStream.range(0, boosting.getNumClasses()).mapToObj(i -> boosting.getLabelTranslator().toExtLabel(i)).collect(Collectors.toSet());
        Set<String> dataSetLabels = DataSetUtil.gatherLabels(dataSet).stream().map(i -> dataSet.getLabelTranslator().toExtLabel(i)).collect(Collectors.toSet());
        Set<String> modelNotDataLabels = SetUtil.complement(modelLabels, dataSetLabels);
        Set<String> dataNotModelLabels = SetUtil.complement(dataSetLabels, modelLabels);
        // try-with-resources closes the generator (and its file) even if a write fails.
        try (JsonGenerator jsonGenerator = new JsonFactory().createGenerator(new File(analysisFolder, "data_info.json"), JsonEncoding.UTF8)) {
            jsonGenerator.writeStartObject();
            jsonGenerator.writeStringField("dataSet", dataName);
            jsonGenerator.writeNumberField("numClassesInModel", boosting.getNumClasses());
            jsonGenerator.writeNumberField("numClassesInDataSet", dataSetLabels.size());
            jsonGenerator.writeNumberField("numClassesInModelDataSetCombined", dataSet.getNumClasses());
            jsonGenerator.writeNumberField("numClassesInDataSetButNotModel", dataNotModelLabels.size());
            jsonGenerator.writeNumberField("numClassesInModelButNotDataSet", modelNotDataLabels.size());
            jsonGenerator.writeArrayFieldStart("classesInDataSetButNotModel");
            for (String label : dataNotModelLabels) {
                jsonGenerator.writeObject(label);
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeArrayFieldStart("classesInModelButNotDataSet");
            for (String label : modelNotDataLabels) {
                jsonGenerator.writeObject(label);
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeNumberField("labelCardinality", dataSet.labelCardinality());
            jsonGenerator.writeEndObject();
        }
        logger.info("finish writing data info to json");
    }

    boolean modelConfigToJson = true;
    if (modelConfigToJson) {
        logger.info("start writing model config to json");
        objectMapper.writeValue(new File(analysisFolder, "model_config.json"), config);
        logger.info("finish writing model config to json");
    }

    boolean dataConfigToJson = true;
    if (dataConfigToJson) {
        logger.info("start writing data config to json");
        File dataConfigFile = Paths.get(config.getString("input.folder"), "data_sets", dataName, "data_config.json").toFile();
        // The data config is optional; it is copied only when present.
        if (dataConfigFile.exists()) {
            FileUtils.copyFileToDirectory(dataConfigFile, analysisFolder);
        }
        logger.info("finish writing data config to json");
    }

    boolean performanceToJson = true;
    if (performanceToJson) {
        objectMapper.writeValue(new File(analysisFolder, "performance.json"), mlMeasures);
    }

    boolean individualPerformance = true;
    if (individualPerformance) {
        logger.info("start writing individual label performance to json");
        objectMapper.writeValue(new File(analysisFolder, "individual_performance.json"), mlMeasures.getMacroAverage());
        logger.info("finish writing individual label performance to json");
    }
    logger.info("reports generated");
}
Use of edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor in project pyramid by cheng-li.
Class IMLGBInspector, method analyzePrediction.
// TODO: speed up
/**
 * Assembles a {@link MultiLabelPredictionAnalysis} for a single data point.
 *
 * <p>The analysis records the data point's ids, its true labels, the labels
 * predicted by {@code pluginPredictor}, the assignment probabilities of both
 * label sets, the score calculation and probability of every reported class,
 * and a ranking of at most {@code labelSetLimit} label sets.
 *
 * <p>A class is reported when its probability reaches
 * {@code classProbThreshold}, or when it is among the true or the predicted
 * labels. Which probability routines are used depends on the runtime type of
 * {@code pluginPredictor}; for a predictor that is none of the four types
 * checked below, the label-set probabilities are not set and the label-set
 * ranking is {@code null}.
 *
 * @param boosting           the model being inspected
 * @param pluginPredictor    the predictor that produces the label prediction
 * @param dataSet            data set containing the data point
 * @param dataPointIndex     internal index of the data point
 * @param ruleLimit          passed through to {@code decisionProcess}
 * @param labelSetLimit      maximum number of label sets in the ranking
 * @param classProbThreshold minimum class probability for a class to be reported
 * @return the populated analysis
 */
public static MultiLabelPredictionAnalysis analyzePrediction(IMLGradientBoosting boosting, PluginPredictor<IMLGradientBoosting> pluginPredictor, MultiLabelClfDataSet dataSet, int dataPointIndex, int ruleLimit, int labelSetLimit, double classProbThreshold) {
    // The predictor type selects the "with constraint" or "without constraint" probability routines.
    final boolean useConstraint = pluginPredictor instanceof SubsetAccPredictor || pluginPredictor instanceof InstanceF1Predictor;
    final boolean useNoConstraint = pluginPredictor instanceof HammingPredictor || pluginPredictor instanceof MacroF1Predictor;
    final LabelTranslator translator = dataSet.getLabelTranslator();
    final MultiLabel trueLabels = dataSet.getMultiLabels()[dataPointIndex];

    // Identity and ground truth.
    MultiLabelPredictionAnalysis analysis = new MultiLabelPredictionAnalysis();
    analysis.setInternalId(dataPointIndex);
    analysis.setId(dataSet.getIdTranslator().toExtId(dataPointIndex));
    analysis.setInternalLabels(trueLabels.getMatchedLabelsOrdered());
    analysis.setLabels(trueLabels.getMatchedLabelsOrdered().stream()
            .map(translator::toExtLabel)
            .collect(Collectors.toList()));
    if (useConstraint) {
        analysis.setProbForTrueLabels(boosting.predictAssignmentProbWithConstraint(dataSet.getRow(dataPointIndex), trueLabels));
    }
    if (useNoConstraint) {
        analysis.setProbForTrueLabels(boosting.predictAssignmentProbWithoutConstraint(dataSet.getRow(dataPointIndex), trueLabels));
    }

    // Prediction made by the plug-in predictor.
    final MultiLabel predicted = pluginPredictor.predict(dataSet.getRow(dataPointIndex));
    List<Integer> predictedIndices = predicted.getMatchedLabelsOrdered();
    analysis.setInternalPrediction(predictedIndices);
    analysis.setPrediction(predictedIndices.stream()
            .map(translator::toExtLabel)
            .collect(Collectors.toList()));
    if (useConstraint) {
        analysis.setProbForPredictedLabels(boosting.predictAssignmentProbWithConstraint(dataSet.getRow(dataPointIndex), predicted));
    }
    if (useNoConstraint) {
        analysis.setProbForPredictedLabels(boosting.predictAssignmentProbWithoutConstraint(dataSet.getRow(dataPointIndex), predicted));
    }

    // Classes worth reporting: probable enough, truly present, or predicted.
    final double[] classProbs = boosting.predictClassProbs(dataSet.getRow(dataPointIndex));
    List<Integer> reportedClasses = IntStream.range(0, boosting.getNumClasses())
            .filter(k -> classProbs[k] >= classProbThreshold || trueLabels.matchClass(k) || predicted.matchClass(k))
            .boxed()
            .collect(Collectors.toList());

    List<ClassScoreCalculation> scoreCalculations = reportedClasses.stream()
            .mapToInt(Integer::intValue)
            .mapToObj(k -> decisionProcess(boosting, translator, dataSet.getRow(dataPointIndex), k, ruleLimit))
            .collect(Collectors.toList());
    analysis.setClassScoreCalculations(scoreCalculations);

    List<MultiLabelPredictionAnalysis.ClassRankInfo> classRanking = new ArrayList<>();
    for (Integer label : reportedClasses) {
        MultiLabelPredictionAnalysis.ClassRankInfo info = new MultiLabelPredictionAnalysis.ClassRankInfo();
        info.setClassIndex(label);
        info.setClassName(translator.toExtLabel(label));
        info.setProb(boosting.predictClassProb(dataSet.getRow(dataPointIndex), label));
        classRanking.add(info);
    }
    analysis.setPredictedRanking(classRanking);

    // Label-set ranking; stays null when the predictor is of neither family.
    List<MultiLabelPredictionAnalysis.LabelSetProbInfo> setRanking = null;
    if (useConstraint) {
        // Score every assignment known to the model, then keep the most probable ones.
        final double[] assignmentProbs = boosting.predictAllAssignmentProbsWithConstraint(dataSet.getRow(dataPointIndex));
        setRanking = IntStream.range(0, boosting.getAssignments().size())
                .mapToObj(i -> new MultiLabelPredictionAnalysis.LabelSetProbInfo(boosting.getAssignments().get(i), assignmentProbs[i], translator))
                .sorted(Comparator.comparing(MultiLabelPredictionAnalysis.LabelSetProbInfo::getProbability).reversed())
                .limit(labelSetLimit)
                .collect(Collectors.toList());
    }
    if (useNoConstraint) {
        // Pull candidates one at a time from the class-probability based enumerator.
        setRanking = new ArrayList<>();
        DynamicProgramming enumerator = new DynamicProgramming(classProbs);
        for (int taken = 0; taken < labelSetLimit; taken++) {
            DynamicProgramming.Candidate next = enumerator.nextHighest();
            setRanking.add(new MultiLabelPredictionAnalysis.LabelSetProbInfo(next.getMultiLabel(), next.getProbability(), translator));
        }
    }
    analysis.setPredictedLabelSetRanking(setRanking);
    return analysis;
}
Use of edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor in project pyramid by cheng-li.
Class App6, method report.
/**
 * Generates all reports for one data set using the serialized "model_crf"
 * model stored in the configured output folder.
 *
 * <p>The reports are written to {@code <output.folder>/reports_crf/<dataName>_reports},
 * which is created if necessary and emptied first. The folder receives
 * {@code report.csv}, {@code data_info.json}, {@code model_config.json},
 * {@code performance.json} and {@code individual_performance.json}.
 * Progress and overall performance are printed to standard output.
 *
 * @param config   configuration; read keys are {@code output.folder},
 *                 {@code predict.target} and {@code report.classProbThreshold}
 * @param dataSet  the data set to evaluate and report on
 * @param dataName name of the data set, used in the folder name and reports
 * @throws IllegalArgumentException if {@code predict.target} is neither
 *                 {@code subsetAccuracy} nor {@code instanceFMeasure}
 * @throws Exception if the model cannot be loaded or a report file cannot be written
 */
static void report(Config config, MultiLabelClfDataSet dataSet, String dataName) throws Exception {
    System.out.println("generating reports for data set " + dataName);
    String output = config.getString("output.folder");
    String modelName = "model_crf";
    File analysisFolder = new File(new File(output, "reports_crf"), dataName + "_reports");
    analysisFolder.mkdirs();
    FileUtils.cleanDirectory(analysisFolder);
    CMLCRF crf = (CMLCRF) Serialization.deserialize(new File(output, modelName));
    String predictTarget = config.getString("predict.target");
    // Every non-throwing branch assigns the predictor exactly once, so it can be
    // declared final directly and captured by the lambda below.
    final PluginPredictor<CMLCRF> predictor;
    switch (predictTarget) {
        case "subsetAccuracy":
            predictor = new SubsetAccPredictor(crf);
            break;
        case "instanceFMeasure":
            predictor = new InstanceF1Predictor(crf);
            break;
        default:
            throw new IllegalArgumentException("predict.target must be subsetAccuracy or instanceFMeasure");
    }
    MLMeasures mlMeasures = new MLMeasures(predictor, dataSet);
    mlMeasures.getMacroAverage().setLabelTranslator(crf.getLabelTranslator());
    System.out.println("performance on dataset " + dataName);
    System.out.println(mlMeasures);
    // A single mapper serves every JSON report written below.
    ObjectMapper objectMapper = new ObjectMapper();

    boolean simpleCSV = true;
    if (simpleCSV) {
        // Simple CSV report: one analysis string per data point, concatenated in order.
        double probThreshold = config.getDouble("report.classProbThreshold");
        File csv = new File(analysisFolder, "report.csv");
        List<String> strs = IntStream.range(0, dataSet.getNumDataPoints()).parallel().mapToObj(i -> CRFInspector.simplePredictionAnalysis(crf, predictor, dataSet, i, probThreshold)).collect(Collectors.toList());
        StringBuilder sb = new StringBuilder();
        for (String str : strs) {
            sb.append(str);
        }
        FileUtils.writeStringToFile(csv, sb.toString(), false);
    }

    boolean dataInfoToJson = true;
    if (dataInfoToJson) {
        // Data info: how the label sets of the model and the data set relate.
        Set<String> modelLabels = IntStream.range(0, crf.getNumClasses()).mapToObj(i -> crf.getLabelTranslator().toExtLabel(i)).collect(Collectors.toSet());
        Set<String> dataSetLabels = DataSetUtil.gatherLabels(dataSet).stream().map(i -> dataSet.getLabelTranslator().toExtLabel(i)).collect(Collectors.toSet());
        Set<String> modelNotDataLabels = SetUtil.complement(modelLabels, dataSetLabels);
        Set<String> dataNotModelLabels = SetUtil.complement(dataSetLabels, modelLabels);
        // try-with-resources closes the generator (and its file) even if a write fails.
        try (JsonGenerator jsonGenerator = new JsonFactory().createGenerator(new File(analysisFolder, "data_info.json"), JsonEncoding.UTF8)) {
            jsonGenerator.writeStartObject();
            jsonGenerator.writeStringField("dataSet", dataName);
            jsonGenerator.writeNumberField("numClassesInModel", crf.getNumClasses());
            jsonGenerator.writeNumberField("numClassesInDataSet", dataSetLabels.size());
            jsonGenerator.writeNumberField("numClassesInModelDataSetCombined", dataSet.getNumClasses());
            jsonGenerator.writeNumberField("numClassesInDataSetButNotModel", dataNotModelLabels.size());
            jsonGenerator.writeNumberField("numClassesInModelButNotDataSet", modelNotDataLabels.size());
            jsonGenerator.writeArrayFieldStart("classesInDataSetButNotModel");
            for (String label : dataNotModelLabels) {
                jsonGenerator.writeObject(label);
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeArrayFieldStart("classesInModelButNotDataSet");
            for (String label : modelNotDataLabels) {
                jsonGenerator.writeObject(label);
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeNumberField("labelCardinality", dataSet.labelCardinality());
            jsonGenerator.writeEndObject();
        }
    }

    boolean modelConfigToJson = true;
    if (modelConfigToJson) {
        // Model config as JSON.
        objectMapper.writeValue(new File(analysisFolder, "model_config.json"), config);
    }

    boolean performanceToJson = true;
    if (performanceToJson) {
        objectMapper.writeValue(new File(analysisFolder, "performance.json"), mlMeasures);
    }

    boolean individualPerformance = true;
    if (individualPerformance) {
        // Individual label performance as JSON.
        objectMapper.writeValue(new File(analysisFolder, "individual_performance.json"), mlMeasures.getMacroAverage());
    }
    System.out.println("reports generated");
}
Aggregations