Usage of edu.neu.ccs.pyramid.configuration.Config in the pyramid project (by cheng-li): class ArffFormat, method writeConfigFile(ClfDataSet, File).
/**
 * Writes a small properties file recording the dimensions of the given
 * single-label classification data set (number of data points, features,
 * and classes) into the ARFF output directory.
 *
 * @param dataSet  data set whose dimensions are recorded
 * @param arffFile ARFF output directory; the config file is created inside
 *                 it under {@code ARFF_CONFIG_FILE_NAME}
 * @throws RuntimeException if the config file cannot be written
 */
private static void writeConfigFile(ClfDataSet dataSet, File arffFile) {
    File configFile = new File(arffFile, ARFF_CONFIG_FILE_NAME);
    Config config = new Config();
    config.setInt(ARFF_CONFIG_NUM_DATA_POINTS, dataSet.getNumDataPoints());
    config.setInt(ARFF_CONFIG_NUM_FEATURES, dataSet.getNumFeatures());
    config.setInt(ARFF_CONFIG_NUM_CLASSES, dataSet.getNumClasses());
    try {
        config.store(configFile);
    } catch (Exception e) {
        // Previously the exception was only printed and swallowed, which let
        // callers proceed with a missing/corrupt config file. Fail fast and
        // preserve the original cause instead.
        throw new RuntimeException("failed to write ARFF config file " + configFile, e);
    }
}
Usage of edu.neu.ccs.pyramid.configuration.Config in the pyramid project (by cheng-li): class ArffFormat, method writeConfigFile(MultiLabelClfDataSet, File).
/**
 * Writes a small properties file recording the dimensions of the given
 * multi-label classification data set (number of data points, features,
 * and classes) into the ARFF output directory.
 *
 * @param dataSet  data set whose dimensions are recorded
 * @param arffFile ARFF output directory; the config file is created inside
 *                 it under {@code ARFF_CONFIG_FILE_NAME}
 * @throws RuntimeException if the config file cannot be written
 */
private static void writeConfigFile(MultiLabelClfDataSet dataSet, File arffFile) {
    File configFile = new File(arffFile, ARFF_CONFIG_FILE_NAME);
    Config config = new Config();
    config.setInt(ARFF_CONFIG_NUM_DATA_POINTS, dataSet.getNumDataPoints());
    config.setInt(ARFF_CONFIG_NUM_FEATURES, dataSet.getNumFeatures());
    config.setInt(ARFF_CONFIG_NUM_CLASSES, dataSet.getNumClasses());
    try {
        config.store(configFile);
    } catch (Exception e) {
        // Previously the exception was only printed and swallowed, which let
        // callers proceed with a missing/corrupt config file. Fail fast and
        // preserve the original cause instead.
        throw new RuntimeException("failed to write ARFF config file " + configFile, e);
    }
}
Usage of edu.neu.ccs.pyramid.configuration.Config in the pyramid project (by cheng-li): class BRRerank, method main.
/**
 * Entry point: optionally trains a BR classifier, then calibrates it and
 * produces prediction, calibration, and classification reports.
 *
 * @param args single element: path to a properties file
 * @throws Exception if any stage fails
 */
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    final Config masterConfig = new Config(args[0]);
    // Stage 1 (optional): train the BR classifier via CBMEN.
    if (masterConfig.getBoolean("trainBR")) {
        System.out.println("Start training BR classifier");
        CBMEN.main(produceBRConfig(masterConfig));
        System.out.println("Finish training BR classifier");
    }
    // Stage 2: calibrate the trained model.
    calibrate(produceCaliConfig(masterConfig));
    // Stage 3: reporting and evaluation.
    report(masterConfig);
    calibration_eval(masterConfig);
    classification_eval(masterConfig);
}
Usage of edu.neu.ccs.pyramid.configuration.Config in the pyramid project (by cheng-li): class BRRerank, method report.
/**
 * Loads the test set and the trained/calibrated models, predicts a label set
 * for every test instance, computes uncalibrated and calibrated confidence
 * scores for each prediction, and writes a tab-separated report to
 * {@code <outputDir>/reports/set_prediction_and_confidence.txt}.
 *
 * @param config configuration; reads "dataPath", "outputDir" and "predictMode"
 * @throws Exception if data or model loading, prediction, or report writing fails
 * @throws IllegalArgumentException if "predictMode" is not "independent" or "rerank"
 */
private static void report(Config config) throws Exception {
    MultiLabelClfDataSet dataset = TRECFormat.loadMultiLabelClfDataSet(Paths.get(config.getString("dataPath"), "test").toFile(), DataSetType.ML_CLF_SPARSE, true);
    CBM cbm = (CBM) Serialization.deserialize(Paths.get(config.getString("outputDir"), "models", "model"));
    cbm.setAllowEmpty(true);
    MultiLabelClassifier classifier = null;
    LabelCalibrator labelCalibrator = (LabelCalibrator) Serialization.deserialize(Paths.get(config.getString("outputDir"), "models", "label_calibrator"));
    VectorCalibrator setCalibrator = (VectorCalibrator) Serialization.deserialize(Paths.get(config.getString("outputDir"), "models", "set_calibrator"));
    PredictionFeatureExtractor predictionFeatureExtractor = (PredictionFeatureExtractor) Serialization.deserialize(Paths.get(config.getString("outputDir"), "models", "calibration_feature_extractor"));
    CalibrationDataGenerator calibrationDataGenerator = new CalibrationDataGenerator(labelCalibrator, predictionFeatureExtractor);
    switch(config.getString("predictMode")) {
        case "independent":
            classifier = new IndependentPredictor(cbm, labelCalibrator);
            break;
        case "rerank":
            classifier = (Reranker) setCalibrator;
            break;
        default:
            // message previously said "predict.mode", which is not the actual config key
            throw new IllegalArgumentException("illegal predictMode");
    }
    MultiLabel[] predictions = classifier.predict(dataset);
    // One CalibInfo per data point: uncalibrated score, calibrated score, and
    // whether the predicted set exactly matches the ground truth.
    List<CalibInfo> confidenceScores = IntStream.range(0, dataset.getNumDataPoints()).parallel().boxed().map(i -> {
        CalibrationDataGenerator.CalibrationInstance predictionInstance = calibrationDataGenerator.createInstance(cbm, dataset.getRow(i), predictions[i], dataset.getMultiLabels()[i], "accuracy");
        double calibrated = setCalibrator.calibrate(predictionInstance.vector);
        CalibInfo calibInfo = new CalibInfo();
        // by convention of the feature extractor, position 0 holds the raw (uncalibrated) score
        calibInfo.uncalibrated = predictionInstance.vector.get(0);
        calibInfo.calibrated = calibrated;
        calibInfo.accuracy = predictionInstance.correctness;
        return calibInfo;
    }).collect(Collectors.toList());
    StringBuilder stringBuilder = new StringBuilder();
    stringBuilder.append("set_prediction").append("\t").append("uncalibrated_confidence").append("\t").append("calibrated_confidence").append("\t").append("ground_truth").append("\t").append("set_accuracy").append("\n");
    for (int i = 0; i < dataset.getNumDataPoints(); i++) {
        stringBuilder.append(predictions[i]).append("\t").append(confidenceScores.get(i).uncalibrated).append("\t").append(confidenceScores.get(i).calibrated).append("\t").append(dataset.getMultiLabels()[i]).append("\t").append(predictions[i].equals(dataset.getMultiLabels()[i]) ? 1 : 0).append("\n");
    }
    FileUtils.writeStringToFile(Paths.get(config.getString("outputDir"), "reports", "set_prediction_and_confidence.txt").toFile(), stringBuilder.toString());
    System.out.println("set predictions and confidence scores are saved to " + Paths.get(config.getString("outputDir"), "reports", "set_prediction_and_confidence.txt").toString() + "\n");
}
Usage of edu.neu.ccs.pyramid.configuration.Config in the pyramid project (by cheng-li): class CBMLR, method main.
/**
 * Entry point: depending on the boolean switches "tune", "train" and "test"
 * in the supplied properties file, grid-searches hyper parameters, trains a
 * CBM model, and/or evaluates it on the test set.
 *
 * @param args single element: path to a properties file
 * @throws Exception if any stage fails
 */
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    System.out.println(config);
    VERBOSE = config.getBoolean("output.verbose");
    new File(config.getString("output.dir")).mkdirs();
    if (config.getBoolean("tune")) {
        System.out.println("============================================================");
        tuneHyperParameters(config);
        System.out.println("============================================================");
    }
    if (config.getBoolean("train")) {
        System.out.println("============================================================");
        trainModel(config);
        System.out.println("============================================================");
    }
    if (config.getBoolean("test")) {
        System.out.println("============================================================");
        test(config);
        System.out.println("============================================================");
    }
}

/**
 * Grid-searches every (variance, numComponents) candidate pair, keeps the
 * best result for the configured target metric, and stores the winning
 * hyper parameters to {@code <output.dir>/tuned_hyper_parameters.properties}.
 */
private static void tuneHyperParameters(Config config) throws Exception {
    System.out.println("Start hyper parameter tuning");
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    List<TuneResult> tuneResults = new ArrayList<>();
    // dataSets.get(0) is the train split, dataSets.get(1) the validation split
    List<MultiLabelClfDataSet> dataSets = loadTrainValidData(config);
    List<Double> variances = config.getDoubles("tune.variance.candidates");
    List<Integer> components = config.getIntegers("tune.numComponents.candidates");
    for (double variance : variances) {
        for (int component : components) {
            StopWatch stopWatch1 = new StopWatch();
            stopWatch1.start();
            HyperParameters hyperParameters = new HyperParameters();
            hyperParameters.numComponents = component;
            hyperParameters.variance = variance;
            System.out.println("---------------------------");
            System.out.println("Trying hyper parameters:");
            System.out.println("train.numComponents = " + hyperParameters.numComponents);
            System.out.println("train.variance = " + hyperParameters.variance);
            TuneResult tuneResult = tune(config, hyperParameters, dataSets.get(0), dataSets.get(1));
            System.out.println("Found optimal train.iterations = " + tuneResult.hyperParameters.iterations);
            System.out.println("Validation performance = " + tuneResult.performance);
            tuneResults.add(tuneResult);
            System.out.println("Time spent on trying this set of hyper parameters = " + stopWatch1);
        }
    }
    TuneResult best = selectBest(config, tuneResults);
    System.out.println("---------------------------");
    System.out.println("Hyper parameter tuning done.");
    System.out.println("Time spent on entire hyper parameter tuning = " + stopWatch);
    System.out.println("Best validation performance = " + best.performance);
    System.out.println("Best hyper parameters:");
    System.out.println("train.numComponents = " + best.hyperParameters.numComponents);
    System.out.println("train.variance = " + best.hyperParameters.variance);
    System.out.println("train.iterations = " + best.hyperParameters.iterations);
    Config tunedHypers = best.hyperParameters.asConfig();
    tunedHypers.store(new File(config.getString("output.dir"), "tuned_hyper_parameters.properties"));
    System.out.println("Tuned hyper parameters saved to " + new File(config.getString("output.dir"), "tuned_hyper_parameters.properties").getAbsolutePath());
}

/**
 * Picks the tuning result that is best for the configured target metric.
 * Higher is better for all supported metrics except hamming loss.
 *
 * @throws IllegalArgumentException if "tune.targetMetric" is not a supported metric
 */
private static TuneResult selectBest(Config config, List<TuneResult> tuneResults) {
    Comparator<TuneResult> comparator = Comparator.comparing(res -> res.performance);
    String predictTarget = config.getString("tune.targetMetric");
    switch(predictTarget) {
        // higher is better: fall through to a single max
        case "instance_set_accuracy":
        case "instance_f1":
        case "instance_log_likelihood":
            return tuneResults.stream().max(comparator).get();
        // lower is better
        case "instance_hamming_loss":
            return tuneResults.stream().min(comparator).get();
        default:
            throw new IllegalArgumentException("tune.targetMetric should be instance_set_accuracy, instance_f1, instance_hamming_loss, instance_log_likelihood");
    }
}

/**
 * Trains the model, taking hyper parameters either from the stored tuning
 * result (train.useTunedHyperParameters=true) or directly from the config.
 * Exits the JVM with status 1 if tuned hyper parameters are requested but
 * the tuning output file does not exist.
 */
private static void trainModel(Config config) throws Exception {
    HyperParameters hyperParameters;
    if (config.getBoolean("train.useTunedHyperParameters")) {
        File hyperFile = new File(config.getString("output.dir"), "tuned_hyper_parameters.properties");
        if (!hyperFile.exists()) {
            System.out.println("train.useTunedHyperParameters is set to true. But no tuned hyper parameters can be found in the output directory.");
            System.out.println("Please either run hyper parameter tuning, or provide hyper parameters manually and set train.useTunedHyperParameters=false.");
            System.exit(1);
        }
        Config tunedHypers = new Config(hyperFile);
        hyperParameters = new HyperParameters(tunedHypers);
        System.out.println("Start training with tuned hyper parameters:");
    } else {
        hyperParameters = new HyperParameters(config);
        System.out.println("Start training with given hyper parameters:");
    }
    System.out.println("train.numComponents = " + hyperParameters.numComponents);
    System.out.println("train.variance = " + hyperParameters.variance);
    System.out.println("train.iterations = " + hyperParameters.iterations);
    MultiLabelClfDataSet trainSet = loadTrainData(config);
    train(config, hyperParameters, trainSet);
}
Aggregations