Usage of ml.shifu.shifu.core.dtrain.dt.TreeNode in project shifu by ShifuML.
Class: IndependentTreeModelUtils — method: convertZipSpecToBinary.
/**
 * Converts a zip-format model spec (JSON conf entry + serialized trees entry) into the
 * single binary GBT/RF model file format.
 *
 * @param zipSpecFile input zip file containing the {@code MODEL_CONF} and {@code MODEL_TREES} entries
 * @param outputGbtFile destination file for the binary model
 * @return true on success; false if either entry is missing/empty or an I/O error occurs
 */
public boolean convertZipSpecToBinary(File zipSpecFile, File outputGbtFile) {
    ZipInputStream zipIn = null;
    FileOutputStream binaryOut = null;
    try {
        zipIn = new ZipInputStream(new FileInputStream(zipSpecFile));
        IndependentTreeModel model = null;
        List<List<TreeNode>> allTrees = null;
        ZipEntry entry;
        while ((entry = zipIn.getNextEntry()) != null) {
            String entryName = entry.getName();
            if (MODEL_CONF.equals(entryName)) {
                // Buffer the JSON conf entry fully before deserializing the model skeleton.
                ByteArrayOutputStream confBuffer = new ByteArrayOutputStream();
                IOUtils.copy(zipIn, confBuffer);
                model = JSONUtils.readValue(new ByteArrayInputStream(confBuffer.toByteArray()),
                        IndependentTreeModel.class);
            } else if (MODEL_TREES.equals(entryName)) {
                // Deliberately not closed: closing it would close the underlying zip stream.
                DataInputStream dataIn = new DataInputStream(zipIn);
                int bagSize = dataIn.readInt();
                allTrees = new ArrayList<List<TreeNode>>(bagSize);
                for (int bag = 0; bag < bagSize; bag++) {
                    int treeCount = dataIn.readInt();
                    List<TreeNode> forest = new ArrayList<TreeNode>(treeCount);
                    for (int t = 0; t < treeCount; t++) {
                        TreeNode node = new TreeNode();
                        node.readFields(dataIn);
                        forest.add(node);
                    }
                    allTrees.add(forest);
                }
            }
        }
        if (model == null || CollectionUtils.isEmpty(allTrees)) {
            // Both the conf and at least one tree are required to emit a binary model.
            return false;
        }
        model.setTrees(allTrees);
        binaryOut = new FileOutputStream(outputGbtFile);
        model.saveToInputStream(binaryOut);
    } catch (IOException e) {
        logger.error("Error occurred when convert the zip format model to binary.", e);
        return false;
    } finally {
        IOUtils.closeQuietly(zipIn);
        IOUtils.closeQuietly(binaryOut);
    }
    return true;
}
Usage of ml.shifu.shifu.core.dtrain.dt.TreeNode in project shifu by ShifuML.
Class: IndependentTreeModelUtils — method: convertBinaryToZipSpec.
/**
 * Converts a binary tree model file into the zip spec format: one {@code MODEL_CONF}
 * entry holding the JSON model conf (without trees) and one {@code MODEL_TREES} entry
 * holding the serialized trees.
 *
 * @param treeModelFile input binary model file
 * @param outputZipFile destination zip file
 * @return true on success; false if the model has no trees or an I/O error occurs
 */
public boolean convertBinaryToZipSpec(File treeModelFile, File outputZipFile) {
    FileInputStream modelIn = null;
    ZipOutputStream zipOut = null;
    try {
        modelIn = new FileInputStream(treeModelFile);
        IndependentTreeModel model = IndependentTreeModel.loadFromStream(modelIn);
        List<List<TreeNode>> allTrees = model.getTrees();
        // Detach the trees so the JSON conf entry does not duplicate them.
        model.setTrees(null);
        if (CollectionUtils.isEmpty(allTrees)) {
            logger.error("No trees found in the tree model.");
            return false;
        }
        zipOut = new ZipOutputStream(new FileOutputStream(outputZipFile));

        // Entry 1: JSON model conf.
        zipOut.putNextEntry(new ZipEntry(MODEL_CONF));
        ByteArrayOutputStream confBuffer = new ByteArrayOutputStream();
        JSONUtils.writeValue(new OutputStreamWriter(confBuffer), model);
        zipOut.write(confBuffer.toByteArray());
        IOUtils.closeQuietly(confBuffer);

        // Entry 2: bag count, then per-bag tree count followed by each serialized tree.
        zipOut.putNextEntry(new ZipEntry(MODEL_TREES));
        DataOutputStream dataOut = new DataOutputStream(zipOut);
        dataOut.writeInt(allTrees.size());
        for (List<TreeNode> forest : allTrees) {
            dataOut.writeInt(forest.size());
            for (TreeNode node : forest) {
                node.write(dataOut);
            }
        }
        IOUtils.closeQuietly(dataOut);
    } catch (IOException e) {
        logger.error("Error occurred when convert the tree model to zip format.", e);
        return false;
    } finally {
        IOUtils.closeQuietly(zipOut);
        IOUtils.closeQuietly(modelIn);
    }
    return true;
}
Usage of ml.shifu.shifu.core.dtrain.dt.TreeNode in project shifu by ShifuML.
Class: TreeEnsemblePmmlCreator — method: convert.
/**
 * Builds a PMML {@link MiningModel} from an {@link IndependentTreeModel}: one
 * {@code Segment}/{@code TreeModel} per tree, combined with weighted averaging.
 * Only single-element (non-bagging) tree lists are supported.
 *
 * @param treeModel the independent tree model to translate
 * @return the assembled PMML mining model
 * @throws RuntimeException if the model contains more than one bagging group
 */
public MiningModel convert(IndependentTreeModel treeModel) {
    MiningModel gbt = new MiningModel();
    MiningSchema miningSchema =
            new TreeModelMiningSchemaCreator(this.modelConfig, this.columnConfigList).build(null);
    gbt.setMiningSchema(miningSchema);
    String miningFunction = treeModel.isClassification() ? "classification" : "regression";
    gbt.setMiningFunction(MiningFunction.fromValue(miningFunction));
    gbt.setTargets(createTargets(this.modelConfig));

    Segmentation segmentation = new Segmentation();
    gbt.setSegmentation(segmentation);
    segmentation.setMultipleModelMethod(MultipleModelMethod.fromValue("weightedAverage"));
    List<Segment> segments = segmentation.getSegments();

    // Only a one-element tree list can be translated; bagging is not supported in PMML.
    if (treeModel.getTrees().size() != 1) {
        throw new RuntimeException("Bagging model cannot be supported in PMML generation.");
    }

    int segmentIndex = 0;
    for (TreeNode treeNode : treeModel.getTrees().get(0)) {
        TreeNodePmmlElementCreator nodeCreator =
                new TreeNodePmmlElementCreator(this.modelConfig, this.columnConfigList, treeModel);
        org.dmg.pmml.tree.Node root = nodeCreator.convert(treeNode.getNode());
        TreeModelPmmlElementCreator modelCreator =
                new TreeModelPmmlElementCreator(this.modelConfig, this.columnConfigList);
        org.dmg.pmml.tree.TreeModel pmmlTree = modelCreator.convert(treeModel, root);
        pmmlTree.setModelName(String.valueOf(segmentIndex));

        Segment segment = new Segment();
        // GBDT weights are per-tree learning-rate shares; scale back by tree count so the
        // weightedAverage combination reproduces the additive GBDT score.
        double weight = treeModel.getWeights().get(0).get(segmentIndex);
        if (treeModel.isGBDT()) {
            weight *= treeModel.getTrees().size();
        }
        segment.setWeight(weight);
        // NOTE(review): "Segement" is a long-standing typo kept for compatibility with
        // existing consumers of the generated PMML ids.
        segment.setId("Segement" + String.valueOf(segmentIndex));
        segmentIndex++;
        segment.setPredicate(new True());
        segment.setModel(pmmlTree);
        segments.add(segment);
    }
    return gbt;
}
Usage of ml.shifu.shifu.core.dtrain.dt.TreeNode in project shifu by ShifuML.
Class: ExportModelProcessor — method: run.
/*
* (non-Javadoc)
*
* @see ml.shifu.shifu.core.processor.Processor#run()
*/
@Override
/**
 * Dispatches the export step by {@code type}: one-bagging binary model, per-model PMML,
 * one-bagging PMML, column stats, WOE mapping, WOE info, or correlation export.
 *
 * @return 0 on success, 2 if the correlation file is missing, -1 for an unsupported type
 * @throws Exception propagated from setup, model loading, or file writing
 */
public int run() throws Exception {
    setUp(ModelStep.EXPORT);
    int status = 0;
    File pmmls = new File("pmmls");
    FileUtils.forceMkdir(pmmls);
    if (StringUtils.isBlank(type)) {
        type = PMML;
    }
    String modelsPath = pathFinder.getModelsPath(SourceType.LOCAL);
    if (type.equalsIgnoreCase(ONE_BAGGING_MODEL)) {
        if (!"nn".equalsIgnoreCase(modelConfig.getAlgorithm()) && !CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
            log.warn("Currently one bagging model is only supported in NN/GBT/RF algorithm.");
        } else {
            List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
            if (models.size() < 1) {
                log.warn("No model is found in {}.", modelsPath);
            } else {
                log.info("Convert nn models into one binary bagging model.");
                Configuration conf = new Configuration();
                Path output = new Path(pathFinder.getBaggingModelPath(SourceType.LOCAL), "model.b" + modelConfig.getAlgorithm());
                if ("nn".equalsIgnoreCase(modelConfig.getAlgorithm())) {
                    BinaryNNSerializer.save(modelConfig, columnConfigList, models, FileSystem.getLocal(conf), output);
                } else if (CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
                    List<List<TreeNode>> baggingTrees = new ArrayList<List<TreeNode>>();
                    for (int i = 0; i < models.size(); i++) {
                        TreeModel tm = (TreeModel) models.get(i);
                        // TreeModel only has one TreeNode instance although it is list inside
                        baggingTrees.add(tm.getIndependentTreeModel().getTrees().get(0));
                    }
                    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
                    // numerical + categorical = # of all input
                    int inputCount = inputOutputIndex[0] + inputOutputIndex[1];
                    BinaryDTSerializer.save(modelConfig, columnConfigList, baggingTrees, modelConfig.getParams().get("Loss").toString(), inputCount, FileSystem.getLocal(conf), output);
                }
                log.info("Please find one unified bagging model in local {}.", output);
            }
        }
    } else if (type.equalsIgnoreCase(PMML)) {
        // typical pmml generation: one .pmml file per loaded model
        List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
        PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise(), false);
        for (int index = 0; index < models.size(); index++) {
            String path = "pmmls" + File.separator + modelConfig.getModelSetName() + Integer.toString(index) + ".pmml";
            log.info("\t Start to generate " + path);
            PMML pmml = translator.build(Arrays.asList(new BasicML[] { models.get(index) }));
            PMMLUtils.savePMML(pmml, path);
        }
    } else if (type.equalsIgnoreCase(ONE_BAGGING_PMML_MODEL)) {
        // one unified bagging pmml generation
        log.info("Convert models into one bagging pmml model {} format", type);
        if (!"nn".equalsIgnoreCase(modelConfig.getAlgorithm())) {
            log.warn("Currently one bagging pmml model is only supported in NN algorithm.");
        } else {
            List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
            PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise(), true);
            String path = "pmmls" + File.separator + modelConfig.getModelSetName() + ".pmml";
            log.info("\t Start to generate one unified model to: " + path);
            PMML pmml = translator.build(models);
            PMMLUtils.savePMML(pmml, path);
        }
    } else if (type.equalsIgnoreCase(COLUMN_STATS)) {
        saveColumnStatus();
    } else if (type.equalsIgnoreCase(WOE_MAPPING)) {
        List<ColumnConfig> exportCatColumns = new ArrayList<ColumnConfig>();
        List<String> catVariables = getRequestVars();
        for (ColumnConfig columnConfig : this.columnConfigList) {
            if (CollectionUtils.isEmpty(catVariables) || isRequestColumn(catVariables, columnConfig)) {
                exportCatColumns.add(columnConfig);
            }
        }
        if (CollectionUtils.isNotEmpty(exportCatColumns)) {
            List<String> woeMappings = new ArrayList<String>();
            for (ColumnConfig columnConfig : exportCatColumns) {
                String woeMapText = rebinAndExportWoeMapping(columnConfig);
                woeMappings.add(woeMapText);
            }
            FileUtils.write(new File("woemapping.txt"), StringUtils.join(woeMappings, ",\n"));
        }
    } else if (type.equalsIgnoreCase(WOE)) {
        List<String> woeInfos = new ArrayList<String>();
        for (ColumnConfig columnConfig : this.columnConfigList) {
            // only export columns with real bins: multi-category categorical or multi-boundary numerical
            if (columnConfig.getBinLength() > 1 && ((columnConfig.isCategorical() && CollectionUtils.isNotEmpty(columnConfig.getBinCategory())) || (columnConfig.isNumerical() && CollectionUtils.isNotEmpty(columnConfig.getBinBoundary()) && columnConfig.getBinBoundary().size() > 1))) {
                List<String> varWoeInfos = generateWoeInfos(columnConfig);
                if (CollectionUtils.isNotEmpty(varWoeInfos)) {
                    woeInfos.addAll(varWoeInfos);
                    woeInfos.add("");
                }
            }
        }
        // BUGFIX: write the file once after collecting all column infos; previously this call
        // sat inside the column loop and rewrote varwoe_info.txt on every iteration.
        FileUtils.writeLines(new File("varwoe_info.txt"), woeInfos);
    } else if (type.equalsIgnoreCase(CORRELATION)) {
        // export correlation into mapping list
        if (!ShifuFileUtils.isFileExists(pathFinder.getLocalCorrelationCsvPath(), SourceType.LOCAL)) {
            log.warn("The correlation file doesn't exist. Please make sure you have ran `shifu stats -c`.");
            return 2;
        }
        return exportVariableCorr();
    } else {
        log.error("Unsupported output format - {}", type);
        status = -1;
    }
    clearUp(ModelStep.EXPORT);
    log.info("Done.");
    return status;
}
Usage of ml.shifu.shifu.core.dtrain.dt.TreeNode in project shifu by ShifuML.
Class: TreeModel — method: getFeatureImportances.
/**
* Get feature importance of current model.
*
* @return map of feature importance, key is column index.
*/
/**
 * Get feature importance of current model.
 *
 * <p>Per-tree importances are averaged over the number of bagging groups; only a
 * single-element tree list is supported.
 *
 * @return map of feature importance, key is column index.
 * @throws RuntimeException if the model has more than one bagging group
 */
public Map<Integer, MutablePair<String, Double>> getFeatureImportances() {
    Map<Integer, MutablePair<String, Double>> merged = new HashMap<Integer, MutablePair<String, Double>>();
    Map<Integer, String> nameMapping = this.getIndependentTreeModel().getNumNameMapping();
    int treeSize = this.getIndependentTreeModel().getTrees().size();
    // such case we only support treeModel is one element list
    if (this.getIndependentTreeModel().getTrees().size() != 1) {
        throw new RuntimeException("Bagging model cannot be supported in Tree Model one element feature importance computing.");
    }
    for (TreeNode tree : this.getIndependentTreeModel().getTrees().get(0)) {
        // compute this tree's importances, then fold them into the running totals
        Map<Integer, Double> treeImportances = tree.computeFeatureImportance();
        for (Entry<Integer, Double> entry : treeImportances.entrySet()) {
            Integer columnIdx = entry.getKey();
            double share = entry.getValue() / treeSize;
            MutablePair<String, Double> existing = merged.get(columnIdx);
            if (existing == null) {
                merged.put(columnIdx, MutablePair.of(nameMapping.get(columnIdx), share));
            } else {
                existing.setValue(existing.getValue() + share);
                merged.put(columnIdx, existing);
            }
        }
    }
    return merged;
}
Aggregations