Usage of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.
Class App3, method main.
/**
 * Entry point: loads the configuration, records it in the optional log file,
 * creates the output folder, and then runs App1 followed by App2 with
 * configurations derived from the loaded one.
 *
 * @param args exactly one element: the path of the properties file
 * @throws IllegalArgumentException if not exactly one argument is given
 * @throws Exception propagated from configuration loading, log-file setup, or App1/App2
 */
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    Logger logger = Logger.getAnonymousLogger();
    String logFile = config.getString("output.log");
    FileHandler fileHandler = null;
    try {
        if (!logFile.isEmpty()) {
            // getParentFile() is null for a bare file name such as "run.log";
            // only create directories when there is a parent to create.
            File logParent = new File(logFile).getParentFile();
            if (logParent != null) {
                logParent.mkdirs();
            }
            //todo should append?
            fileHandler = new FileHandler(logFile, true);
            fileHandler.setFormatter(new SimpleFormatter());
            logger.addHandler(fileHandler);
            // Do not also forward the config dump to the parent (console) handlers.
            logger.setUseParentHandlers(false);
        }
        logger.info(config.toString());
    } finally {
        // Release the log file even if logging failed, and detach the closed handler.
        if (fileHandler != null) {
            logger.removeHandler(fileHandler);
            fileHandler.close();
        }
    }
    File output = new File(config.getString("output.folder"));
    output.mkdirs();
    Config app1Config = createApp1Config(config);
    Config app2Config = createApp2Config(config);
    App1.main(app1Config);
    App2.main(app2Config);
}
Usage of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.
Class App6, method main.
/**
 * Command-line entry point for App6.
 *
 * <p>Loads the properties file named by the single argument and echoes the resulting
 * configuration to standard output. It then runs training when the {@code train} flag
 * is true, and afterwards testing when the {@code test} flag is true.
 *
 * @param args exactly one element: the properties file path
 * @throws IllegalArgumentException when the argument count is not one
 * @throws Exception propagated from loading the configuration, training, or testing
 */
public static void main(String[] args) throws Exception {
    final boolean singleArgument = args.length == 1;
    if (!singleArgument) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    final String propertiesPath = args[0];
    final Config appConfig = new Config(propertiesPath);
    System.out.println(appConfig);

    final boolean shouldTrain = appConfig.getBoolean("train");
    if (shouldTrain) {
        train(appConfig);
    }
    // The "test" flag is read only after training has finished.
    final boolean shouldTest = appConfig.getBoolean("test");
    if (shouldTest) {
        test(appConfig);
    }
}
Usage of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.
Class App6, method report.
/**
 * Generates evaluation reports for a data set using the serialized CRF model
 * found at {@code <output.folder>/model_crf}.
 *
 * <p>The report folder {@code <output.folder>/reports_crf/<dataName>_reports} is created
 * and emptied first. It then receives report.csv, data_info.json, model_config.json,
 * performance.json and individual_performance.json.
 *
 * @param config   configuration; reads {@code output.folder}, {@code predict.target}
 *                 and {@code report.classProbThreshold}
 * @param dataSet  the data set to evaluate
 * @param dataName name used for the report folder and inside the reports
 * @throws IllegalArgumentException if {@code predict.target} is neither
 *                                  subsetAccuracy nor instanceFMeasure
 * @throws Exception on I/O or model deserialization failures
 */
static void report(Config config, MultiLabelClfDataSet dataSet, String dataName) throws Exception {
    System.out.println("generating reports for data set " + dataName);
    String output = config.getString("output.folder");
    String modelName = "model_crf";
    File analysisFolder = new File(new File(output, "reports_crf"), dataName + "_reports");
    analysisFolder.mkdirs();
    // Start from an empty folder so files from a previous run do not linger.
    FileUtils.cleanDirectory(analysisFolder);
    CMLCRF crf = (CMLCRF) Serialization.deserialize(new File(output, modelName));

    PluginPredictor<CMLCRF> predictorTmp;
    String predictTarget = config.getString("predict.target");
    switch (predictTarget) {
        case "subsetAccuracy":
            predictorTmp = new SubsetAccPredictor(crf);
            break;
        case "instanceFMeasure":
            predictorTmp = new InstanceF1Predictor(crf);
            break;
        default:
            throw new IllegalArgumentException("predict.target must be subsetAccuracy or instanceFMeasure");
    }
    // The lambdas below need an effectively final reference.
    final PluginPredictor<CMLCRF> predictor = predictorTmp;

    MLMeasures mlMeasures = new MLMeasures(predictor, dataSet);
    mlMeasures.getMacroAverage().setLabelTranslator(crf.getLabelTranslator());
    System.out.println("performance on dataset " + dataName);
    System.out.println(mlMeasures);

    // Report toggles: set one to false to skip that output file.
    boolean simpleCSV = true;
    if (simpleCSV) {
        double probThreshold = config.getDouble("report.classProbThreshold");
        File csv = new File(analysisFolder, "report.csv");
        // Rows are computed in parallel; collecting an ordered stream keeps data-point order.
        List<String> strs = IntStream.range(0, dataSet.getNumDataPoints()).parallel()
                .mapToObj(i -> CRFInspector.simplePredictionAnalysis(crf, predictor, dataSet, i, probThreshold))
                .collect(Collectors.toList());
        FileUtils.writeStringToFile(csv, String.join("", strs), false);
    }

    boolean dataInfoToJson = true;
    if (dataInfoToJson) {
        Set<String> modelLabels = IntStream.range(0, crf.getNumClasses())
                .mapToObj(i -> crf.getLabelTranslator().toExtLabel(i))
                .collect(Collectors.toSet());
        Set<String> dataSetLabels = DataSetUtil.gatherLabels(dataSet).stream()
                .map(i -> dataSet.getLabelTranslator().toExtLabel(i))
                .collect(Collectors.toSet());
        Set<String> modelNotDataLabels = SetUtil.complement(modelLabels, dataSetLabels);
        Set<String> dataNotModelLabels = SetUtil.complement(dataSetLabels, modelLabels);
        // try-with-resources closes the generator (and its file) even if a write fails.
        try (JsonGenerator jsonGenerator = new JsonFactory().createGenerator(
                new File(analysisFolder, "data_info.json"), JsonEncoding.UTF8)) {
            jsonGenerator.writeStartObject();
            jsonGenerator.writeStringField("dataSet", dataName);
            jsonGenerator.writeNumberField("numClassesInModel", crf.getNumClasses());
            jsonGenerator.writeNumberField("numClassesInDataSet", dataSetLabels.size());
            jsonGenerator.writeNumberField("numClassesInModelDataSetCombined", dataSet.getNumClasses());
            jsonGenerator.writeNumberField("numClassesInDataSetButNotModel", dataNotModelLabels.size());
            jsonGenerator.writeNumberField("numClassesInModelButNotDataSet", modelNotDataLabels.size());
            jsonGenerator.writeArrayFieldStart("classesInDataSetButNotModel");
            for (String label : dataNotModelLabels) {
                jsonGenerator.writeObject(label);
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeArrayFieldStart("classesInModelButNotDataSet");
            for (String label : modelNotDataLabels) {
                jsonGenerator.writeObject(label);
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeNumberField("labelCardinality", dataSet.labelCardinality());
            jsonGenerator.writeEndObject();
        }
    }

    // One mapper is reused for all the JSON dumps below.
    ObjectMapper objectMapper = new ObjectMapper();

    boolean modelConfigToJson = true;
    if (modelConfigToJson) {
        objectMapper.writeValue(new File(analysisFolder, "model_config.json"), config);
    }

    boolean performanceToJson = true;
    if (performanceToJson) {
        objectMapper.writeValue(new File(analysisFolder, "performance.json"), mlMeasures);
    }

    boolean individualPerformance = true;
    if (individualPerformance) {
        objectMapper.writeValue(new File(analysisFolder, "individual_performance.json"), mlMeasures.getMacroAverage());
    }
    System.out.println("reports generated");
}
Usage of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.
Class CBMEN, method reportF1Prediction.
/**
 * Predicts label sets for {@code dataSet} with the instance-F1 optimal plug-in predictor.
 * It prints and saves the resulting performance, then writes each prediction together
 * with the probability the model assigns to it.
 *
 * <p>Reads {@code output.dir} and {@code predict.piThreshold} from the config, and the
 * serialized candidate label sets from {@code <output.dir>/support}. Writes
 * {@code performance.txt} and {@code predictions.txt} (one {@code labels:probability}
 * line per prediction) under {@code <output.dir>/test_predictions/instance_f1_optimal},
 * both encoded as UTF-8.
 *
 * @param config  configuration to read paths and the threshold from
 * @param cbm     the trained model
 * @param dataSet the data set to predict on
 * @throws Exception on I/O or deserialization failures
 */
private static void reportF1Prediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("Making predictions on test set with the instance F1 optimal predictor");
    String output = config.getString("output.dir");
    PluginF1 pluginF1 = new PluginF1(cbm);
    // NOTE(review): Java deserialization; presumably "support" was written by an earlier
    // training step of this tool — only load trusted files here.
    @SuppressWarnings("unchecked")
    List<MultiLabel> support = (List<MultiLabel>) Serialization.deserialize(new File(output, "support"));
    pluginF1.setSupport(support);
    pluginF1.setPiThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predictions = pluginF1.predict(dataSet);
    MLMeasures mlMeasures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    System.out.println("test performance with the instance F1 optimal predictor");
    System.out.println(mlMeasures);
    File performanceFile = Paths.get(output, "test_predictions", "instance_f1_optimal", "performance.txt").toFile();
    // FileUtils creates missing parent directories; the predictions writer below relies on that.
    FileUtils.writeStringToFile(performanceFile, mlMeasures.toString(), "UTF-8");
    System.out.println("test performance is saved to " + performanceFile.toString());
    // Here we do not use approximation
    double[] setProbs = IntStream.range(0, predictions.length).parallel()
            .mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predictions[i]))
            .toArray();
    File predictionFile = Paths.get(output, "test_predictions", "instance_f1_optimal", "predictions.txt").toFile();
    try (BufferedWriter writer = java.nio.file.Files.newBufferedWriter(
            predictionFile.toPath(), java.nio.charset.StandardCharsets.UTF_8)) {
        // Iterate over the same range that setProbs was computed for.
        for (int i = 0; i < predictions.length; i++) {
            writer.write(predictions[i].toString());
            writer.write(":");
            writer.write(String.valueOf(setProbs[i]));
            writer.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println("============================================================");
}
Usage of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.
Class CBMLR, method reportF1Prediction.
/**
 * Predicts label sets for {@code dataSet} with the instance-F1 optimal plug-in predictor.
 * It prints and saves the resulting performance, then writes each prediction together
 * with the probability the model assigns to it.
 *
 * <p>Reads {@code output.dir} and {@code predict.piThreshold} from the config, and the
 * serialized candidate label sets from {@code <output.dir>/support}. Writes
 * {@code performance.txt} and {@code predictions.txt} (one {@code labels:probability}
 * line per prediction) under {@code <output.dir>/test_predictions/instance_f1_optimal},
 * both encoded as UTF-8.
 *
 * @param config  configuration to read paths and the threshold from
 * @param cbm     the trained model
 * @param dataSet the data set to predict on
 * @throws Exception on I/O or deserialization failures
 */
private static void reportF1Prediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("Making predictions on test set with the instance F1 optimal predictor");
    String output = config.getString("output.dir");
    PluginF1 pluginF1 = new PluginF1(cbm);
    // NOTE(review): Java deserialization; presumably "support" was written by an earlier
    // training step of this tool — only load trusted files here.
    @SuppressWarnings("unchecked")
    List<MultiLabel> support = (List<MultiLabel>) Serialization.deserialize(new File(output, "support"));
    pluginF1.setSupport(support);
    pluginF1.setPiThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predictions = pluginF1.predict(dataSet);
    MLMeasures mlMeasures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    System.out.println("test performance with the instance F1 optimal predictor");
    System.out.println(mlMeasures);
    File performanceFile = Paths.get(output, "test_predictions", "instance_f1_optimal", "performance.txt").toFile();
    // FileUtils creates missing parent directories; the predictions writer below relies on that.
    FileUtils.writeStringToFile(performanceFile, mlMeasures.toString(), "UTF-8");
    System.out.println("test performance is saved to " + performanceFile.toString());
    // Here we do not use approximation
    double[] setProbs = IntStream.range(0, predictions.length).parallel()
            .mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predictions[i]))
            .toArray();
    File predictionFile = Paths.get(output, "test_predictions", "instance_f1_optimal", "predictions.txt").toFile();
    try (BufferedWriter writer = java.nio.file.Files.newBufferedWriter(
            predictionFile.toPath(), java.nio.charset.StandardCharsets.UTF_8)) {
        // Iterate over the same range that setProbs was computed for.
        for (int i = 0; i < predictions.length; i++) {
            writer.write(predictions[i].toString());
            writer.write(":");
            writer.write(String.valueOf(setProbs[i]));
            writer.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println("============================================================");
}
Aggregations