Use of edu.neu.ccs.pyramid.feature.TopFeatures in project pyramid by cheng-li.
From the class MLLogisticRegressionInspector, the method topFeatures:
public static TopFeatures topFeatures(MLLogisticRegression logisticRegression, int classIndex, int limit) {
    FeatureList featureList = logisticRegression.getFeatureList();
    // per-class weight vector, excluding the bias term
    Vector weights = logisticRegression.getWeights().getWeightsWithoutBiasForClass(classIndex);
    Comparator<FeatureUtility> comparator = Comparator.comparing(FeatureUtility::getUtility);
    // keep only features with positive weights, sorted by weight in descending order
    List<Feature> list = IntStream.range(0, weights.size())
            .mapToObj(i -> new FeatureUtility(featureList.get(i)).setUtility(weights.get(i)))
            .filter(featureUtility -> featureUtility.getUtility() > 0)
            .sorted(comparator.reversed())
            .map(FeatureUtility::getFeature)
            .limit(limit)
            .collect(Collectors.toList());
    TopFeatures topFeatures = new TopFeatures();
    topFeatures.setTopFeatures(list);
    topFeatures.setClassIndex(classIndex);
    LabelTranslator labelTranslator = logisticRegression.getLabelTranslator();
    topFeatures.setClassName(labelTranslator.toExtLabel(classIndex));
    return topFeatures;
}
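A minimal calling sketch, not from the project: printTopFeatures is a hypothetical helper that assumes a trained MLLogisticRegression and the same imports as the inspector class above; numClasses is passed explicitly, and getClassName is assumed to mirror the setClassName setter used in the method. getTopFeatures and simpleString are taken from their uses elsewhere on this page.

static void printTopFeatures(MLLogisticRegression logisticRegression, int numClasses, int limit) {
    for (int k = 0; k < numClasses; k++) {
        TopFeatures top = MLLogisticRegressionInspector.topFeatures(logisticRegression, k, limit);
        // class name was filled in from the label translator inside topFeatures()
        System.out.println(top.getClassName() + ":");
        for (Feature feature : top.getTopFeatures()) {
            System.out.println("  " + feature.simpleString());
        }
    }
}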
Use of edu.neu.ccs.pyramid.feature.TopFeatures in project pyramid by cheng-li.
From the class AdaBoostMHInspector, the method topFeatures:
public static TopFeatures topFeatures(AdaBoostMH boosting, int classIndex, int limit) {
    Map<Feature, Double> totalContributions = new HashMap<>();
    List<Regressor> regressors = boosting.getRegressors(classIndex);
    // only regression trees contribute feature importance scores
    List<RegressionTree> trees = regressors.stream()
            .filter(regressor -> regressor instanceof RegressionTree)
            .map(regressor -> (RegressionTree) regressor)
            .collect(Collectors.toList());
    // sum each feature's importance over all trees for this class
    for (RegressionTree tree : trees) {
        Map<Feature, Double> contributions = RegTreeInspector.featureImportance(tree);
        for (Map.Entry<Feature, Double> entry : contributions.entrySet()) {
            Feature feature = entry.getKey();
            Double contribution = entry.getValue();
            double oldValue = totalContributions.getOrDefault(feature, 0.0);
            double newValue = oldValue + contribution;
            totalContributions.put(feature, newValue);
        }
    }
    Comparator<Map.Entry<Feature, Double>> comparator = Comparator.comparing(Map.Entry::getValue);
    List<Feature> list = totalContributions.entrySet().stream()
            .sorted(comparator.reversed())
            .limit(limit)
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());
    TopFeatures topFeatures = new TopFeatures();
    topFeatures.setTopFeatures(list);
    topFeatures.setClassIndex(classIndex);
    LabelTranslator labelTranslator = boosting.getLabelTranslator();
    topFeatures.setClassName(labelTranslator.toExtLabel(classIndex));
    return topFeatures;
}
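As a side note on the accumulation loop: since Java 8 the per-feature summation can be written more compactly with Map.merge. A behavior-equivalent sketch, reusing the same types and imports as the method above:

static void accumulate(List<RegressionTree> trees, Map<Feature, Double> totalContributions) {
    for (RegressionTree tree : trees) {
        // add each tree's importance scores into the running totals
        RegTreeInspector.featureImportance(tree)
                .forEach((feature, contribution) -> totalContributions.merge(feature, contribution, Double::sum));
    }
}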
Use of edu.neu.ccs.pyramid.feature.TopFeatures in project pyramid by cheng-li.
From the class App2, the method train:
static void train(Config config, Logger logger) throws Exception {
    String output = config.getString("output.folder");
    int numIterations = config.getInt("train.numIterations");
    int numLeaves = config.getInt("train.numLeaves");
    double learningRate = config.getDouble("train.learningRate");
    int minDataPerLeaf = config.getInt("train.minDataPerLeaf");
    String modelName = "model_app3";
    // double featureSamplingRate = config.getDouble("train.featureSamplingRate");
    // double dataSamplingRate = config.getDouble("train.dataSamplingRate");
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    MultiLabelClfDataSet dataSet = loadData(config, config.getString("input.trainData"));
    MultiLabelClfDataSet testSet = null;
    if (config.getBoolean("train.showTestProgress")) {
        testSet = loadData(config, config.getString("input.testData"));
    }
    int numClasses = dataSet.getNumClasses();
    logger.info("number of classes = " + numClasses);
    IMLGBConfig imlgbConfig = new IMLGBConfig.Builder(dataSet)
            .learningRate(learningRate)
            .minDataPerLeaf(minDataPerLeaf)
            .numLeaves(numLeaves)
            .numSplitIntervals(config.getInt("train.numSplitIntervals"))
            .usePrior(config.getBoolean("train.usePrior"))
            .build();
    // warm start resumes training from a previously serialized model
    IMLGradientBoosting boosting;
    if (config.getBoolean("train.warmStart")) {
        boosting = IMLGradientBoosting.deserialize(new File(output, modelName));
    } else {
        boosting = new IMLGradientBoosting(numClasses);
    }
    logger.info("During training, the performance is reported using the Hamming loss optimal predictor");
    logger.info("initializing trainer");
    IMLGBTrainer trainer = new IMLGBTrainer(imlgbConfig, boosting);
    // one EarlyStopper and one Terminator per label; training for a label
    // stops as soon as either of them fires
    boolean earlyStop = config.getBoolean("train.earlyStop");
    List<EarlyStopper> earlyStoppers = new ArrayList<>();
    List<Terminator> terminators = new ArrayList<>();
    if (earlyStop) {
        for (int l = 0; l < numClasses; l++) {
            EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, config.getInt("train.earlyStop.patience"));
            earlyStopper.setMinimumIterations(config.getInt("train.earlyStop.minIterations"));
            earlyStoppers.add(earlyStopper);
        }
        for (int l = 0; l < numClasses; l++) {
            Terminator terminator = new Terminator();
            terminator.setMaxStableIterations(config.getInt("train.earlyStop.patience"))
                    .setMinIterations(config.getInt("train.earlyStop.minIterations") / config.getInt("train.showProgress.interval"))
                    .setAbsoluteEpsilon(config.getDouble("train.earlyStop.absoluteChange"))
                    .setRelativeEpsilon(config.getDouble("train.earlyStop.relativeChange"))
                    .setOperation(Terminator.Operation.OR);
            terminators.add(terminator);
        }
    }
    logger.info("trainer initialized");
    int numLabelsLeftToTrain = numClasses;
    int progressInterval = config.getInt("train.showProgress.interval");
    for (int i = 1; i <= numIterations; i++) {
        logger.info("iteration " + i);
        trainer.iterate();
        if (config.getBoolean("train.showTrainProgress") && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("training set performance");
            logger.info(new MLMeasures(boosting, dataSet).toString());
        }
        if (config.getBoolean("train.showTestProgress") && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("test set performance");
            logger.info(new MLMeasures(boosting, testSet).toString());
            if (earlyStop) {
                for (int l = 0; l < numClasses; l++) {
                    EarlyStopper earlyStopper = earlyStoppers.get(l);
                    Terminator terminator = terminators.get(l);
                    if (!trainer.getShouldStop()[l]) {
                        // monitor the per-label KL divergence on the test set
                        double kl = KL(boosting, testSet, l);
                        earlyStopper.add(i, kl);
                        terminator.add(kl);
                        if (earlyStopper.shouldStop() || terminator.shouldTerminate()) {
                            logger.info("training for label " + l + " (" + dataSet.getLabelTranslator().toExtLabel(l) + ") should stop now");
                            logger.info("the best number of training iterations for the label is " + earlyStopper.getBestIteration());
                            trainer.setShouldStop(l);
                            numLabelsLeftToTrain -= 1;
                            logger.info("the number of labels left to be trained on = " + numLabelsLeftToTrain);
                        }
                    }
                }
            }
        }
        if (numLabelsLeftToTrain == 0) {
            logger.info("all label training finished");
            break;
        }
    }
    logger.info("training done");
    File serializedModel = new File(output, modelName);
    // todo pick best models
    boosting.serialize(serializedModel);
    logger.info(stopWatch.toString());
    if (earlyStop) {
        for (int l = 0; l < numClasses; l++) {
            logger.info("----------------------------------------------------");
            logger.info("test performance history for label " + l + ": " + earlyStoppers.get(l).history());
            logger.info("model size for label " + l + " = " + (boosting.getRegressors(l).size() - 1));
        }
    }
    boolean topFeaturesToFile = true;
    if (topFeaturesToFile) {
        logger.info("start writing top features");
        int limit = config.getInt("report.topFeatures.limit");
        // rank features per class, then dump both JSON and a human-readable summary
        List<TopFeatures> topFeaturesList = IntStream.range(0, boosting.getNumClasses())
                .mapToObj(k -> IMLGBInspector.topFeatures(boosting, k, limit))
                .collect(Collectors.toList());
        ObjectMapper mapper = new ObjectMapper();
        String file = "top_features.json";
        mapper.writeValue(new File(output, file), topFeaturesList);
        StringBuilder sb = new StringBuilder();
        for (int l = 0; l < boosting.getNumClasses(); l++) {
            sb.append("-------------------------").append("\n");
            sb.append(dataSet.getLabelTranslator().toExtLabel(l)).append(":").append("\n");
            for (Feature feature : topFeaturesList.get(l).getTopFeatures()) {
                sb.append(feature.simpleString()).append(", ");
            }
            sb.append("\n");
        }
        FileUtils.writeStringToFile(new File(output, "top_features.txt"), sb.toString());
        logger.info("finish writing top features");
    }
}
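For orientation, the configuration keys read by train above can be collected into a single properties-style file. Every key below appears in the method; the values are illustrative placeholders, not defaults taken from the project (the commented-out sampling-rate keys are omitted):

output.folder=/path/to/output
input.trainData=/path/to/train
input.testData=/path/to/test
train.numIterations=100
train.numLeaves=2
train.learningRate=0.1
train.minDataPerLeaf=1
train.numSplitIntervals=1000
train.usePrior=true
train.warmStart=false
train.showTrainProgress=true
train.showTestProgress=true
train.showProgress.interval=1
train.earlyStop=true
train.earlyStop.patience=2
train.earlyStop.minIterations=10
train.earlyStop.absoluteChange=0
train.earlyStop.relativeChange=0.001
report.topFeatures.limit=10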