Use of edu.neu.ccs.pyramid.configuration.Config in the project pyramid by cheng-li.
The class App2, method main.
/**
 * Command-line entry point. Expects exactly one argument: the path to a
 * properties file, which is loaded into a {@code Config} and handed to the
 * config-based overload of {@code main}.
 *
 * @param args command-line arguments; must contain a single properties-file path
 * @throws Exception propagated from configuration loading or the delegated run
 */
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    main(new Config(args[0]));
}
Use of edu.neu.ccs.pyramid.configuration.Config in the project pyramid by cheng-li.
The class App2, method train.
/**
 * Trains an IMLGradientBoosting model as directed by the configuration and
 * serializes it into the output folder under the fixed name "model_app3".
 * Optionally tracks per-label early stopping against the test set and, at the
 * end, writes the top features per label to both JSON and a plain-text file.
 *
 * @param config supplies the train.*, input.*, output.* and report.* keys read below
 * @param logger receives progress and performance messages
 * @throws Exception if data loading, model (de)serialization, or report writing fails
 */
static void train(Config config, Logger logger) throws Exception {
// Destination folder for the serialized model and feature reports.
String output = config.getString("output.folder");
int numIterations = config.getInt("train.numIterations");
int numLeaves = config.getInt("train.numLeaves");
double learningRate = config.getDouble("train.learningRate");
int minDataPerLeaf = config.getInt("train.minDataPerLeaf");
// NOTE(review): hard-coded "model_app3" inside App2 — presumably because App3
// drives this method; confirm before renaming.
String modelName = "model_app3";
// double featureSamplingRate = config.getDouble("train.featureSamplingRate");
// double dataSamplingRate = config.getDouble("train.dataSamplingRate");
// Wall-clock timer covering data loading plus the whole training loop.
StopWatch stopWatch = new StopWatch();
stopWatch.start();
MultiLabelClfDataSet dataSet = loadData(config, config.getString("input.trainData"));
// The test set is only loaded when test-progress reporting is enabled;
// otherwise it stays null and must not be dereferenced.
MultiLabelClfDataSet testSet = null;
if (config.getBoolean("train.showTestProgress")) {
testSet = loadData(config, config.getString("input.testData"));
}
int numClasses = dataSet.getNumClasses();
logger.info("number of class = " + numClasses);
IMLGBConfig imlgbConfig = new IMLGBConfig.Builder(dataSet).learningRate(learningRate).minDataPerLeaf(minDataPerLeaf).numLeaves(numLeaves).numSplitIntervals(config.getInt("train.numSplitIntervals")).usePrior(config.getBoolean("train.usePrior")).build();
IMLGradientBoosting boosting;
if (config.getBoolean("train.warmStart")) {
// Warm start: continue training the previously serialized model.
boosting = IMLGradientBoosting.deserialize(new File(output, modelName));
} else {
boosting = new IMLGradientBoosting(numClasses);
}
logger.info("During training, the performance is reported using Hamming loss optimal predictor");
logger.info("initialing trainer");
IMLGBTrainer trainer = new IMLGBTrainer(imlgbConfig, boosting);
boolean earlyStop = config.getBoolean("train.earlyStop");
// One EarlyStopper and one Terminator per label; both are consulted each time
// test progress is reported, and either can stop training for that label.
List<EarlyStopper> earlyStoppers = new ArrayList<>();
List<Terminator> terminators = new ArrayList<>();
if (earlyStop) {
for (int l = 0; l < numClasses; l++) {
EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, config.getInt("train.earlyStop.patience"));
earlyStopper.setMinimumIterations(config.getInt("train.earlyStop.minIterations"));
earlyStoppers.add(earlyStopper);
}
for (int l = 0; l < numClasses; l++) {
Terminator terminator = new Terminator();
// minIterations is scaled down by the progress interval because the
// terminator is only fed one value per progress check, not per iteration.
terminator.setMaxStableIterations(config.getInt("train.earlyStop.patience")).setMinIterations(config.getInt("train.earlyStop.minIterations") / config.getInt("train.showProgress.interval")).setAbsoluteEpsilon(config.getDouble("train.earlyStop.absoluteChange")).setRelativeEpsilon(config.getDouble("train.earlyStop.relativeChange")).setOperation(Terminator.Operation.OR);
terminators.add(terminator);
}
}
logger.info("trainer initialized");
// Counts labels still being trained; reaching zero ends the loop early.
int numLabelsLeftToTrain = numClasses;
int progressInterval = config.getInt("train.showProgress.interval");
for (int i = 1; i <= numIterations; i++) {
logger.info("iteration " + i);
trainer.iterate();
// Train/test performance is logged every progressInterval iterations and
// always on the final iteration.
if (config.getBoolean("train.showTrainProgress") && (i % progressInterval == 0 || i == numIterations)) {
logger.info("training set performance");
logger.info(new MLMeasures(boosting, dataSet).toString());
}
if (config.getBoolean("train.showTestProgress") && (i % progressInterval == 0 || i == numIterations)) {
logger.info("test set performance");
logger.info(new MLMeasures(boosting, testSet).toString());
// Early stopping is evaluated per label, based on KL divergence on the
// test set; note it only runs when test progress reporting is on.
if (earlyStop) {
for (int l = 0; l < numClasses; l++) {
EarlyStopper earlyStopper = earlyStoppers.get(l);
Terminator terminator = terminators.get(l);
if (!trainer.getShouldStop()[l]) {
double kl = KL(boosting, testSet, l);
earlyStopper.add(i, kl);
terminator.add(kl);
if (earlyStopper.shouldStop() || terminator.shouldTerminate()) {
logger.info("training for label " + l + " (" + dataSet.getLabelTranslator().toExtLabel(l) + ") should stop now");
logger.info("the best number of training iterations for the label is " + earlyStopper.getBestIteration());
trainer.setShouldStop(l);
numLabelsLeftToTrain -= 1;
logger.info("the number of labels left to be trained on = " + numLabelsLeftToTrain);
}
}
}
}
}
if (numLabelsLeftToTrain == 0) {
logger.info("all label training finished");
break;
}
}
logger.info("training done");
File serializedModel = new File(output, modelName);
//todo pick best models
boosting.serialize(serializedModel);
logger.info(stopWatch.toString());
if (earlyStop) {
// Summarize per-label test-performance history and final model sizes.
for (int l = 0; l < numClasses; l++) {
logger.info("----------------------------------------------------");
logger.info("test performance history for label " + l + ": " + earlyStoppers.get(l).history());
// NOTE(review): the "- 1" presumably excludes the prior/constant
// regressor from the reported size — confirm against IMLGradientBoosting.
logger.info("model size for label " + l + " = " + (boosting.getRegressors(l).size() - 1));
}
}
// Top-feature reporting is unconditionally enabled (flag kept for symmetry
// with the other report toggles).
boolean topFeaturesToFile = true;
if (topFeaturesToFile) {
logger.info("start writing top features");
int limit = config.getInt("report.topFeatures.limit");
List<TopFeatures> topFeaturesList = IntStream.range(0, boosting.getNumClasses()).mapToObj(k -> IMLGBInspector.topFeatures(boosting, k, limit)).collect(Collectors.toList());
ObjectMapper mapper = new ObjectMapper();
String file = "top_features.json";
mapper.writeValue(new File(output, file), topFeaturesList);
// Also render a human-readable text version, one section per label.
StringBuilder sb = new StringBuilder();
for (int l = 0; l < boosting.getNumClasses(); l++) {
sb.append("-------------------------").append("\n");
sb.append(dataSet.getLabelTranslator().toExtLabel(l)).append(":").append("\n");
for (Feature feature : topFeaturesList.get(l).getTopFeatures()) {
sb.append(feature.simpleString()).append(", ");
}
sb.append("\n");
}
FileUtils.writeStringToFile(new File(output, "top_features.txt"), sb.toString());
logger.info("finish writing top features");
}
}
Use of edu.neu.ccs.pyramid.configuration.Config in the project pyramid by cheng-li.
The class App2, method report.
/**
 * Generates evaluation reports for one data set against the serialized
 * "model_app3" model: overall performance, a simple per-instance CSV,
 * optional per-document prediction-rule JSON files, data/model label
 * comparison info, and copies of the model and data configurations.
 * Output goes into {output.folder}/reports_app3/{dataName}_reports, which is
 * created and wiped at the start of every run.
 *
 * @param config   supplies output.*, input.*, predict.target and report.* keys
 * @param dataName name of the data set folder to evaluate and report on
 * @param logger   receives progress messages
 * @throws Exception on I/O, deserialization, or configuration errors
 */
static void report(Config config, String dataName, Logger logger) throws Exception {
logger.info("generating reports for data set " + dataName);
String output = config.getString("output.folder");
String modelName = "model_app3";
// Per-dataset report folder; recreated fresh on every invocation.
File analysisFolder = new File(new File(output, "reports_app3"), dataName + "_reports");
analysisFolder.mkdirs();
FileUtils.cleanDirectory(analysisFolder);
IMLGradientBoosting boosting = IMLGradientBoosting.deserialize(new File(output, modelName));
// Select the plugin predictor matching the target evaluation measure.
String predictTarget = config.getString("predict.target");
PluginPredictor<IMLGradientBoosting> pluginPredictorTmp = null;
switch(predictTarget) {
case "subsetAccuracy":
pluginPredictorTmp = new SubsetAccPredictor(boosting);
break;
case "hammingLoss":
pluginPredictorTmp = new HammingPredictor(boosting);
break;
case "instanceFMeasure":
pluginPredictorTmp = new InstanceF1Predictor(boosting);
break;
case "macroFMeasure":
// Macro-F1 needs per-label tuned thresholds, deserialized from a file
// presumably produced by an earlier tuning step — confirm the producer.
TunedMarginalClassifier tunedMarginalClassifier = (TunedMarginalClassifier) Serialization.deserialize(new File(output, "predictor_macro_f"));
pluginPredictorTmp = new MacroF1Predictor(boosting, tunedMarginalClassifier);
break;
default:
throw new IllegalArgumentException("unknown prediction target measure " + predictTarget);
}
// just to make Lambda expressions happy
final PluginPredictor<IMLGradientBoosting> pluginPredictor = pluginPredictorTmp;
MultiLabelClfDataSet dataSet = loadData(config, dataName);
MLMeasures mlMeasures = new MLMeasures(pluginPredictor, dataSet);
mlMeasures.getMacroAverage().setLabelTranslator(boosting.getLabelTranslator());
logger.info("performance on dataset " + dataName);
logger.info(mlMeasures.toString());
// Simple CSV report is unconditionally enabled (flag kept for symmetry).
boolean simpleCSV = true;
if (simpleCSV) {
logger.info("start generating simple CSV report");
double probThreshold = config.getDouble("report.classProbThreshold");
File csv = new File(analysisFolder, "report.csv");
// Analyses are computed in parallel, then written sequentially in order.
List<String> strs = IntStream.range(0, dataSet.getNumDataPoints()).parallel().mapToObj(i -> IMLGBInspector.simplePredictionAnalysis(boosting, pluginPredictor, dataSet, i, probThreshold)).collect(Collectors.toList());
try (BufferedWriter bw = new BufferedWriter(new FileWriter(csv))) {
for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
String str = strs.get(i);
// NOTE(review): no explicit separator is written between records —
// assumes each analysis string ends with its own newline; verify.
bw.write(str);
}
}
logger.info("finish generating simple CSV report");
}
boolean rulesToJson = config.getBoolean("report.showPredictionDetail");
if (rulesToJson) {
logger.info("start writing rules to json");
int ruleLimit = config.getInt("report.rule.limit");
int numDocsPerFile = config.getInt("report.numDocsPerFile");
// Documents are partitioned into numFiles JSON files of at most
// numDocsPerFile analyses each.
int numFiles = (int) Math.ceil((double) dataSet.getNumDataPoints() / numDocsPerFile);
double probThreshold = config.getDouble("report.classProbThreshold");
int labelSetLimit = config.getInt("report.labelSetLimit");
IntStream.range(0, numFiles).forEach(i -> {
int start = i * numDocsPerFile;
int end = start + numDocsPerFile;
List<MultiLabelPredictionAnalysis> partition = IntStream.range(start, Math.min(end, dataSet.getNumDataPoints())).parallel().mapToObj(a -> IMLGBInspector.analyzePrediction(boosting, pluginPredictor, dataSet, a, ruleLimit, labelSetLimit, probThreshold)).collect(Collectors.toList());
ObjectMapper mapper = new ObjectMapper();
String file = "report_" + (i + 1) + ".json";
try {
mapper.writeValue(new File(analysisFolder, file), partition);
} catch (IOException e) {
// Writing one partition failing does not abort the other partitions.
e.printStackTrace();
}
logger.info("progress = " + Progress.percentage(i + 1, numFiles));
});
logger.info("finish writing rules to json");
}
// Data info report is unconditionally enabled (flag kept for symmetry).
boolean dataInfoToJson = true;
if (dataInfoToJson) {
logger.info("start writing data info to json");
// Compare the label vocabularies of the model vs. the data set (by their
// external, human-readable label names).
Set<String> modelLabels = IntStream.range(0, boosting.getNumClasses()).mapToObj(i -> boosting.getLabelTranslator().toExtLabel(i)).collect(Collectors.toSet());
Set<String> dataSetLabels = DataSetUtil.gatherLabels(dataSet).stream().map(i -> dataSet.getLabelTranslator().toExtLabel(i)).collect(Collectors.toSet());
JsonGenerator jsonGenerator = new JsonFactory().createGenerator(new File(analysisFolder, "data_info.json"), JsonEncoding.UTF8);
jsonGenerator.writeStartObject();
jsonGenerator.writeStringField("dataSet", dataName);
jsonGenerator.writeNumberField("numClassesInModel", boosting.getNumClasses());
jsonGenerator.writeNumberField("numClassesInDataSet", dataSetLabels.size());
jsonGenerator.writeNumberField("numClassesInModelDataSetCombined", dataSet.getNumClasses());
Set<String> modelNotDataLabels = SetUtil.complement(modelLabels, dataSetLabels);
Set<String> dataNotModelLabels = SetUtil.complement(dataSetLabels, modelLabels);
jsonGenerator.writeNumberField("numClassesInDataSetButNotModel", dataNotModelLabels.size());
jsonGenerator.writeNumberField("numClassesInModelButNotDataSet", modelNotDataLabels.size());
jsonGenerator.writeArrayFieldStart("classesInDataSetButNotModel");
for (String label : dataNotModelLabels) {
jsonGenerator.writeObject(label);
}
jsonGenerator.writeEndArray();
jsonGenerator.writeArrayFieldStart("classesInModelButNotDataSet");
for (String label : modelNotDataLabels) {
jsonGenerator.writeObject(label);
}
jsonGenerator.writeEndArray();
jsonGenerator.writeNumberField("labelCardinality", dataSet.labelCardinality());
jsonGenerator.writeEndObject();
jsonGenerator.close();
logger.info("finish writing data info to json");
}
// Model config snapshot is unconditionally enabled.
boolean modelConfigToJson = true;
if (modelConfigToJson) {
logger.info("start writing model config to json");
ObjectMapper objectMapper = new ObjectMapper();
objectMapper.writeValue(new File(analysisFolder, "model_config.json"), config);
logger.info("finish writing model config to json");
}
// Data config is copied only if the source file exists; silently skipped otherwise.
boolean dataConfigToJson = true;
if (dataConfigToJson) {
logger.info("start writing data config to json");
File dataConfigFile = Paths.get(config.getString("input.folder"), "data_sets", dataName, "data_config.json").toFile();
if (dataConfigFile.exists()) {
FileUtils.copyFileToDirectory(dataConfigFile, analysisFolder);
}
logger.info("finish writing data config to json");
}
// Aggregate performance measures serialized as JSON.
boolean performanceToJson = true;
if (performanceToJson) {
ObjectMapper objectMapper = new ObjectMapper();
objectMapper.writeValue(new File(analysisFolder, "performance.json"), mlMeasures);
}
// Per-label (macro-average component) performance serialized as JSON.
boolean individualPerformance = true;
if (individualPerformance) {
logger.info("start writing individual label performance to json");
ObjectMapper objectMapper = new ObjectMapper();
objectMapper.writeValue(new File(analysisFolder, "individual_performance.json"), mlMeasures.getMacroAverage());
logger.info("finish writing individual label performance to json");
}
logger.info("reports generated");
}
Use of edu.neu.ccs.pyramid.configuration.Config in the project pyramid by cheng-li.
The class App3, method createApp1Config.
/**
 * Derives the configuration consumed by App1 (data-set creation) from the
 * combined App3 configuration by copying only the keys App1 understands.
 *
 * @param config the full App3 configuration
 * @return a new Config holding the App1-relevant keys present in {@code config}
 */
private static Config createApp1Config(Config config) {
    // Keys shared verbatim between the App3 config and the App1 config.
    String[] sharedKeys = { "output.folder", "output.trainFolder", "output.testFolder", "output.log", "train.feature.useInitialFeatures", "train.feature.categFeature.filter", "train.feature.categFeature.percentThreshold", "train.feature.ngram.n", "train.feature.ngram.minDf", "train.feature.ngram.slop", "train.feature.missingValue", "train.feature.addExternalNgrams", "train.feature.externalNgramFile", "train.feature.analyzer", "train.feature.filterNgramsByKeyWords", "train.feature.filterNgrams.keyWordsFile", "train.feature.filterNgramsByRegex", "train.feature.filterNgrams.regex", "train.feature.useCodeDescription", "train.feature.codeDesc.File", "train.feature.codeDesc.analyzer", "train.feature.codeDesc.matchField", "train.feature.codeDesc.minMatchPercentage", "index.indexName", "index.clusterName", "index.documentType", "index.clientType", "index.hosts", "index.ports", "train.label.field", "train.label.filter", "train.label.filter.prefix", "train.feature.featureFieldPrefix", "train.feature.ngram.extractionFields", "train.splitQuery", "test.splitQuery", "train.feature.ngram.matchScoreType", "createTrainSet", "createTestSet", "train.feature.ngram.selection", "train.feature.ngram.selectPerLabel", "train.label.order" };
    Config result = new Config();
    Config.copyExisting(config, result, sharedKeys);
    return result;
}
Use of edu.neu.ccs.pyramid.configuration.Config in the project pyramid by cheng-li.
The class App3, method createApp2Config.
/**
 * Derives the configuration consumed by App2 (training/reporting) from the
 * combined App3 configuration: copies the keys App2 understands, then points
 * App2's input locations at App1's output locations.
 *
 * @param config the full App3 configuration
 * @return a new Config ready to be passed to App2
 */
private static Config createApp2Config(Config config) {
    // Keys shared verbatim between the App3 config and the App2 config.
    String[] sharedKeys = { "output.folder", "output.log", "train", "test", "tune", "predict.target", "train.warmStart", "train.usePrior", "train.numIterations", "train.numLeaves", "train.learningRate", "train.minDataPerLeaf", "train.numSplitIntervals", "train.showTrainProgress", "train.showTestProgress", "train.earlyStop.patience", "train.earlyStop.minIterations", "train.earlyStop", "train.earlyStop.absoluteChange", "train.earlyStop.relativeChange", "train.showProgress.interval", "train.generateReports", "tune.data", "tune.FMeasure.beta", "report.topFeatures.limit", "report.rule.limit", "report.numDocsPerFile", "report.classProbThreshold", "report.labelSetLimit", "report.showPredictionDetail" };
    Config result = new Config();
    Config.copyExisting(config, result, sharedKeys);
    // App2 reads its inputs from wherever App1 wrote its outputs.
    result.setString("input.folder", config.getString("output.folder"));
    result.setString("input.trainData", config.getString("output.trainFolder"));
    result.setString("input.testData", config.getString("output.testFolder"));
    return result;
}
Aggregations