use of ml.shifu.shifu.core.TreeModel in project shifu by ShifuML.
the class DTMaster method init.
/**
 * Initializes the master before training iterations start.
 *
 * Reads the model config and column config list from the context properties, resolves the effective training
 * parameters (grid-search parameters selected by trainer id when hyper parameters exist, otherwise the train
 * params of the model config), and then populates the tree-building fields: feature subset strategy/rate,
 * maxDepth, maxLeaves, maxBatchSplitSize, maxStatsMemory, treeNum, learningRate (GBT only) and impurity.
 * Finally it sets checkpoint settings, creates the work queues, and either creates the initial tree list (first
 * iteration) or calls {@code recoverMasterStatus} (any later iteration).
 *
 * @param context
 *            master context supplying the properties, the app id and the first-iteration flag
 * @throws RuntimeException
 *             if loading the model config or column config list fails with an IOException
 * @throws GuaguaRuntimeException
 *             if loading an existing model for continuous GBT training fails with an IOException
 */
@Override
public void init(MasterContext<DTMasterParams, DTWorkerParams> context) {
Properties props = context.getProps();
// init model config and column config list at first
SourceType sourceType;
try {
// source type falls back to HDFS when the property is not set
sourceType = SourceType.valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, SourceType.HDFS.toString()));
this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG), sourceType);
this.columnConfigList = CommonUtils.loadColumnConfigList(props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType);
} catch (IOException e) {
throw new RuntimeException(e);
}
// worker number is used to estimate nodes per iteration for stats
this.workerNumber = NumberFormatUtils.getInt(props.getProperty(GuaguaConstants.GUAGUA_WORKER_NUMBER), true);
// check if variables are set final selected
int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
// input count is the sum of the first two entries of the returned array
this.inputNum = inputOutputIndex[0] + inputOutputIndex[1];
// entry 3 equal to 1 is treated as "variable selection already done"
this.isAfterVarSelect = (inputOutputIndex[3] == 1);
// cache all feature list for sampling features
this.allFeatures = this.getAllFeatureList(columnConfigList, isAfterVarSelect);
// trainer id defaults to "0" when the property is missing
int trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));
// If grid search, select valid parameters; if not, parameters are what is in ModelConfig.json
GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
Map<String, Object> validParams = this.modelConfig.getTrain().getParams();
if (gs.hasHyperParam()) {
validParams = gs.getParams(trainerId);
LOG.info("Start grid search master with params: {}", validParams);
}
// optional "ValidationTolerance": a non-numeric value resets the tolerance to 0; when the key is absent the
// field is not assigned here and keeps whatever value it already has
Object vtObj = validParams.get("ValidationTolerance");
if (vtObj != null) {
try {
validationTolerance = Double.parseDouble(vtObj.toString());
LOG.warn("Validation by tolerance is enabled with value {}.", validationTolerance);
} catch (NumberFormatException ee) {
validationTolerance = 0d;
LOG.warn("Validation by tolerance isn't enabled because of non numerical value of ValidationTolerance: {}.", vtObj);
}
} else {
LOG.warn("Validation by tolerance isn't enabled.");
}
// tree related parameters initialization
// "FeatureSubsetStrategy" may be a number (used as featureSubsetRate, strategy set to null) or a strategy
// name (parsed by FeatureSubsetStrategy.of); when absent, TWOTHIRDS is used with rate 0
Object fssObj = validParams.get("FeatureSubsetStrategy");
if (fssObj != null) {
try {
this.featureSubsetRate = Double.parseDouble(fssObj.toString());
// no need validate featureSubsetRate is in (0,1], as already validated in ModelInspector
this.featureSubsetStrategy = null;
} catch (NumberFormatException ee) {
this.featureSubsetStrategy = FeatureSubsetStrategy.of(fssObj.toString());
}
} else {
LOG.warn("FeatureSubsetStrategy is not set, set to TWOTHRIDS by default in DTMaster.");
this.featureSubsetStrategy = FeatureSubsetStrategy.TWOTHIRDS;
this.featureSubsetRate = 0;
}
// max depth, 10 when "MaxDepth" is not configured
Object maxDepthObj = validParams.get("MaxDepth");
if (maxDepthObj != null) {
this.maxDepth = Integer.valueOf(maxDepthObj.toString());
} else {
this.maxDepth = 10;
}
// max leaves which is used for leaf-wised tree building, TODO add more benchmarks
// -1 when "MaxLeaves" is not configured
Object maxLeavesObj = validParams.get("MaxLeaves");
if (maxLeavesObj != null) {
this.maxLeaves = Integer.valueOf(maxLeavesObj.toString());
} else {
this.maxLeaves = -1;
}
// enable leaf wise tree building once maxLeaves is configured
if (this.maxLeaves > 0) {
this.isLeafWise = true;
}
// maxBatchSplitSize means each time split # of batch nodes
Object maxBatchSplitSizeObj = validParams.get("MaxBatchSplitSize");
if (maxBatchSplitSizeObj != null) {
this.maxBatchSplitSize = Integer.valueOf(maxBatchSplitSizeObj.toString());
} else {
// by default split 32 at most in a batch
this.maxBatchSplitSize = 32;
}
// only enforced when the JVM runs with assertions enabled (-ea)
assert this.maxDepth > 0 && this.maxDepth <= 20;
// hide in parameters, this to avoid OOM issue for each iteration
Object maxStatsMemoryMB = validParams.get("MaxStatsMemoryMB");
if (maxStatsMemoryMB != null) {
// configured value is in MB, converted to bytes here
this.maxStatsMemory = Long.valueOf(validParams.get("MaxStatsMemoryMB").toString()) * 1024 * 1024;
if (this.maxStatsMemory > ((2L * Runtime.getRuntime().maxMemory()) / 3)) {
// if >= 2/3 max memory, take 2/3 max memory to avoid OOM
this.maxStatsMemory = ((2L * Runtime.getRuntime().maxMemory()) / 3);
}
} else {
// by default it is 1/2 of heap, about 1.5G setting in current Shifu
this.maxStatsMemory = Runtime.getRuntime().maxMemory() / 2L;
}
// assert this.maxStatsMemory <= Math.min(Runtime.getRuntime().maxMemory() * 0.6, 800 * 1024 * 1024L);
// "TreeNum" is mandatory here: a missing key throws NullPointerException on toString()
this.treeNum = Integer.valueOf(validParams.get("TreeNum").toString());
this.isRF = ALGORITHM.RF.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
this.isGBDT = ALGORITHM.GBT.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
if (this.isGBDT) {
// learning rate only effective in gbdt
this.learningRate = Double.valueOf(validParams.get(CommonConstants.LEARNING_RATE).toString());
}
// initialize impurity type according to regression or classification
String imStr = validParams.get("Impurity").toString();
// 2 classes unless the model is a classification model, in which case the tag count is used
int numClasses = 2;
if (this.modelConfig.isClassification()) {
numClasses = this.modelConfig.getTags().size();
}
// these two parameters are used to stop tree growth
int minInstancesPerNode = Integer.valueOf(validParams.get("MinInstancesPerNode").toString());
double minInfoGain = Double.valueOf(validParams.get("MinInfoGain").toString());
// "entropy" and "gini" are matched case-insensitively; every other value falls back to Variance
if (imStr.equalsIgnoreCase("entropy")) {
impurity = new Entropy(numClasses, minInstancesPerNode, minInfoGain);
} else if (imStr.equalsIgnoreCase("gini")) {
impurity = new Gini(numClasses, minInstancesPerNode, minInfoGain);
} else {
impurity = new Variance(minInstancesPerNode, minInfoGain);
}
// checkpoint folder and interval (every # iterations to do checkpoint)
// interval defaults to "20", folder defaults to "tmp/cp_" + app id
this.checkpointInterval = NumberFormatUtils.getInt(context.getProps().getProperty(CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_INTERVAL, "20"));
this.checkpointOutput = new Path(context.getProps().getProperty(CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_FOLDER, "tmp/cp_" + context.getAppId()));
// cache conf to avoid new
this.conf = new Configuration();
// if continuous model training is enabled
this.isContinuousEnabled = Boolean.TRUE.toString().equalsIgnoreCase(context.getProps().getProperty(CommonConstants.CONTINUOUS_TRAINING));
this.dtEarlyStopDecider = new DTEarlyStopDecider(this.maxDepth);
// early stop is only switched on here; when the key is absent or false the field is left untouched
if (validParams.containsKey("EnableEarlyStop") && Boolean.valueOf(validParams.get("EnableEarlyStop").toString().toLowerCase())) {
this.enableEarlyStop = true;
}
LOG.info("Master init params: isAfterVarSel={}, featureSubsetStrategy={}, featureSubsetRate={} maxDepth={}, maxStatsMemory={}, " + "treeNum={}, impurity={}, workerNumber={}, minInstancesPerNode={}, minInfoGain={}, isRF={}, " + "isGBDT={}, isContinuousEnabled={}, enableEarlyStop={}.", isAfterVarSelect, featureSubsetStrategy, this.featureSubsetRate, maxDepth, maxStatsMemory, treeNum, imStr, this.workerNumber, minInstancesPerNode, minInfoGain, this.isRF, this.isGBDT, this.isContinuousEnabled, this.enableEarlyStop);
this.toDoQueue = new LinkedList<TreeNode>();
// toSplitQueue is only created for leaf-wise building
if (this.isLeafWise) {
this.toSplitQueue = new PriorityQueue<TreeNode>(64, new Comparator<TreeNode>() {
@Override
public int compare(TreeNode o1, TreeNode o2) {
// o2 is compared against o1, so the queue head is the node with the largest wgtCntRatio * gain
return Double.compare(o2.getNode().getWgtCntRatio() * o2.getNode().getGain(), o1.getNode().getWgtCntRatio() * o1.getNode().getGain());
}
});
}
// initialize trees
if (context.isFirstIteration()) {
if (this.isRF) {
// for random forest, trees are trained in parallel
// all treeNum root nodes are created up front, each with learning rate 1
this.trees = new CopyOnWriteArrayList<TreeNode>();
for (int i = 0; i < treeNum; i++) {
this.trees.add(new TreeNode(i, new Node(Node.ROOT_INDEX), 1d));
}
}
if (this.isGBDT) {
if (isContinuousEnabled) {
// continuous training: try to load an existing model from the GUAGUA_OUTPUT path
TreeModel existingModel;
try {
Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
existingModel = (TreeModel) ModelSpecLoaderUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
if (existingModel == null) {
// null means no existing model file or model file is in wrong format
this.trees = new CopyOnWriteArrayList<TreeNode>();
// learning rate is 1 for 1st
this.trees.add(new TreeNode(0, new Node(Node.ROOT_INDEX), 1d));
LOG.info("Starting to train model from scratch and existing model is empty.");
} else {
// the list returned by getTrees() is used directly and the new tree is appended to it
this.trees = existingModel.getTrees();
this.existingTreeSize = this.trees.size();
// starting from existing models, first tree learning rate is current learning rate
// (1 is used instead when the existing model holds no tree)
this.trees.add(new TreeNode(this.existingTreeSize, new Node(Node.ROOT_INDEX), this.existingTreeSize == 0 ? 1d : this.learningRate));
LOG.info("Starting to train model from existing model {} with existing trees {}.", modelPath, existingTreeSize);
}
} catch (IOException e) {
throw new GuaguaRuntimeException(e);
}
} else {
this.trees = new CopyOnWriteArrayList<TreeNode>();
// for GBDT, initialize the first tree. trees are trained sequentially,first tree learning rate is 1
this.trees.add(new TreeNode(0, new Node(Node.ROOT_INDEX), 1.0d));
}
}
} else {
// recover all states once master is fail over
LOG.info("Recover master status from checkpoint file {}", this.checkpointOutput);
recoverMasterStatus(sourceType);
}
}
use of ml.shifu.shifu.core.TreeModel in project shifu by ShifuML.
the class CommonUtils method computeTreeModelFeatureImportance.
/**
 * Gathers the feature importances of every {@link TreeModel} found in the given models and merges them.
 *
 * Entries of {@code models} that are not {@code TreeModel} instances are skipped. The collected per-model
 * importance maps are handed to {@code mergeImportanceList}, whose result is returned unchanged.
 *
 * @param models
 *            the candidate models; only instances of TreeModel contribute
 * @return the merged importance map produced by {@code mergeImportanceList}
 * @throws IllegalStateException
 *             if none of the given models is a TreeModel
 */
public static Map<Integer, MutablePair<String, Double>> computeTreeModelFeatureImportance(List<BasicML> models) {
    List<Map<Integer, MutablePair<String, Double>>> perModelImportances = new ArrayList<Map<Integer, MutablePair<String, Double>>>();
    for (BasicML candidate : models) {
        if (!(candidate instanceof TreeModel)) {
            // only tree models carry feature importances
            continue;
        }
        perModelImportances.add(((TreeModel) candidate).getFeatureImportances());
    }
    if (perModelImportances.isEmpty()) {
        throw new IllegalStateException("Feature importance calculation abort due to no tree model found!!");
    }
    return mergeImportanceList(perModelImportances);
}
use of ml.shifu.shifu.core.TreeModel in project shifu by ShifuML.
the class TreeEnsemblePMMLTranslator method build.
/**
 * Builds a PMML document for the given tree model.
 *
 * The header carries the license text and an application entry whose name and version are read, on a
 * best-effort basis, from the "vendor" and "version" main attributes of the manifest of the jar containing
 * this class. When the jar or its manifest cannot be found or read, a warning is logged and name/version are
 * left unset; the PMML is still produced.
 *
 * @param basicML
 *            the model to translate; it is cast to {@link TreeModel}
 * @return the PMML holding the data dictionary and the converted mining model
 * @throws ClassCastException
 *             if {@code basicML} is not a TreeModel
 */
public PMML build(BasicML basicML) {
    PMML pmml = new PMML();
    Header header = new Header();
    pmml.setHeader(header);
    header.setCopyright(" Copyright [2013-2017] PayPal Software Foundation\n" + "\n"
            + " Licensed under the Apache License, Version 2.0 (the \"License\");\n"
            + " you may not use this file except in compliance with the License.\n"
            + " You may obtain a copy of the License at\n" + "\n"
            + " http://www.apache.org/licenses/LICENSE-2.0\n" + "\n"
            + " Unless required by applicable law or agreed to in writing, software\n"
            + " distributed under the License is distributed on an \"AS IS\" BASIS,\n"
            + " WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
            + " See the License for the specific language governing permissions and\n"
            + " limitations under the License.\n");
    Application application = new Application();
    header.setApplication(application);

    String findContainingJar = JarManager.findContainingJar(TreeEnsemblePMMLTranslator.class);
    if (findContainingJar == null) {
        // Without this check the JarFile constructor fails with a NullPointerException whose message is
        // null, which used to be logged as the bare text "null".
        LOG.warn("Cannot locate the jar containing TreeEnsemblePMMLTranslator, PMML application name and version are not set.");
    } else {
        JarFile jar = null;
        try {
            jar = new JarFile(findContainingJar);
            final Manifest manifest = jar.getManifest();
            if (manifest == null) {
                // JarFile.getManifest() returns null when the jar has no manifest
                LOG.warn("No manifest found in " + findContainingJar
                        + ", PMML application name and version are not set.");
            } else {
                String vendor = manifest.getMainAttributes().getValue("vendor");
                application.setName(vendor);
                String version = manifest.getMainAttributes().getValue("version");
                application.setVersion(version);
            }
        } catch (Exception e) {
            // best effort only: keep building the PMML, but keep the cause in the log
            LOG.warn("Failed to read vendor/version from manifest of " + findContainingJar, e);
        } finally {
            if (jar != null) {
                try {
                    jar.close();
                } catch (IOException e) {
                    LOG.warn("Failed to close jar " + findContainingJar, e);
                }
            }
        }
    }

    pmml.setDataDictionary(dataDictionaryCreator.build(basicML));
    List<Model> models = pmml.getModels();
    Model miningModel = modelCreator.convert(((TreeModel) basicML).getIndependentTreeModel());
    models.add(miningModel);
    return pmml;
}
use of ml.shifu.shifu.core.TreeModel in project shifu by ShifuML.
the class DTWorker method init.
/**
 * Initializes the worker before training iterations start.
 *
 * Loads the model config and column config list, builds the category-value-to-bin-index mapping for categorical
 * columns, creates the data splitter, thread pool and memory-limited training/validation lists, resolves the
 * effective training parameters (grid-search parameters by trainer id when hyper parameters exist), and sets
 * impurity, loss and GBT-specific settings. On a non-first iteration it prepares recovery (flag for GBT, trees
 * from the last master result otherwise); on the first iteration with continuous GBT training it loads trees
 * from an existing model if one can be read.
 *
 * @param context
 *            worker context supplying the properties, app id, first-iteration flag and last master result
 * @throws RuntimeException
 *             if loading the model config or column config list fails with an IOException
 */
@Override
public void init(WorkerContext<DTMasterParams, DTWorkerParams> context) {
Properties props = context.getProps();
try {
// source type falls back to HDFS when the property is not set
SourceType sourceType = SourceType.valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, SourceType.HDFS.toString()));
this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG), sourceType);
this.columnConfigList = CommonUtils.loadColumnConfigList(props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType);
} catch (IOException e) {
throw new RuntimeException(e);
}
// for each categorical column with bin categories: map every flattened category value to the index of its bin,
// keyed by column number
this.columnCategoryIndexMapping = new HashMap<Integer, Map<String, Integer>>();
for (ColumnConfig config : this.columnConfigList) {
if (config.isCategorical()) {
if (config.getBinCategory() != null) {
Map<String, Integer> tmpMap = new HashMap<String, Integer>();
for (int i = 0; i < config.getBinCategory().size(); i++) {
// one bin entry may hold a group of values; each of them maps to the same bin index i
List<String> catVals = CommonUtils.flattenCatValGrp(config.getBinCategory().get(i));
for (String cval : catVals) {
tmpMap.put(cval, i);
}
}
this.columnCategoryIndexMapping.put(config.getColumnNum(), tmpMap);
}
}
}
this.hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
// create Splitter
String delimiter = context.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER);
this.splitter = MapReduceUtils.generateShifuOutputSplitter(delimiter);
// k-fold cross validation is enabled when the configured fold number is positive
Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
if (kCrossValidation != null && kCrossValidation > 0) {
isKFoldCV = true;
LOG.info("Cross validation is enabled by kCrossValidation: {}.", kCrossValidation);
}
// NOTE(review): Double.compare below unboxes upSampleWeight; a null from getUpSampleWeight() would throw
// NullPointerException here - confirm it can never be null
Double upSampleWeight = modelConfig.getTrain().getUpSampleWeight();
if (Double.compare(upSampleWeight, 1d) != 0 && (modelConfig.isRegression() || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()))) {
// set mean to upSampleWeight -1 and get sample + 1 to make sure no zero sample value
LOG.info("Enable up sampling with weight {}.", upSampleWeight);
this.upSampleRng = new PoissonDistribution(upSampleWeight - 1);
}
this.isContinuousEnabled = Boolean.TRUE.toString().equalsIgnoreCase(context.getProps().getProperty(CommonConstants.CONTINUOUS_TRAINING));
this.workerThreadCount = modelConfig.getTrain().getWorkerThreadCount();
this.threadPool = Executors.newFixedThreadPool(this.workerThreadCount);
// enable shut down logic
// the callback stops the pool immediately and then waits up to 2 seconds for termination
context.addCompletionCallBack(new WorkerCompletionCallBack<DTMasterParams, DTWorkerParams>() {
@Override
public void callback(WorkerContext<DTMasterParams, DTWorkerParams> context) {
DTWorker.this.threadPool.shutdownNow();
try {
DTWorker.this.threadPool.awaitTermination(2, TimeUnit.SECONDS);
} catch (InterruptedException e) {
// restore the interrupt flag instead of swallowing it
Thread.currentThread().interrupt();
}
}
});
// trainer id defaults to "0" when the property is missing
this.trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));
this.isOneVsAll = modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll();
// with hyper parameters configured, the parameter set is selected by trainer id; otherwise train params are used
GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
Map<String, Object> validParams = this.modelConfig.getTrain().getParams();
if (gs.hasHyperParam()) {
validParams = gs.getParams(this.trainerId);
LOG.info("Start grid search worker with params: {}", validParams);
}
// "TreeNum" is mandatory here: a missing key throws NullPointerException on toString()
this.treeNum = Integer.valueOf(validParams.get("TreeNum").toString());
// fraction of max heap usable for data, "0.6" when the property is not set
double memoryFraction = Double.valueOf(context.getProps().getProperty("guagua.data.memoryFraction", "0.6"));
LOG.info("Max heap memory: {}, fraction: {}", Runtime.getRuntime().maxMemory(), memoryFraction);
double validationRate = this.modelConfig.getValidSetRate();
if (StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath())) {
// fixed 0.6 and 0.4 of max memory for trainingData and validationData
this.trainingData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction * 0.6), new ArrayList<Data>());
this.validationData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction * 0.4), new ArrayList<Data>());
} else {
if (Double.compare(validationRate, 0d) != 0) {
// split the data memory budget between training and validation by validationRate
this.trainingData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction * (1 - validationRate)), new ArrayList<Data>());
this.validationData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction * validationRate), new ArrayList<Data>());
} else {
// no validation set: the whole budget goes to training data and validationData is not assigned here
this.trainingData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction), new ArrayList<Data>());
}
}
int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
// numerical + categorical = # of all input
this.inputCount = inputOutputIndex[0] + inputOutputIndex[1];
// regression outputNodeCount is 1, binaryClassfication, it is 1, OneVsAll it is 1, Native classification it is
// 1, with index of 0,1,2,3 denotes different classes
this.isAfterVarSelect = (inputOutputIndex[3] == 1);
// NOTE(review): this uses a null/empty check while the memory split above uses StringUtils.isNotBlank, so a
// whitespace-only path is treated differently by the two checks - confirm this is intended
this.isManualValidation = (modelConfig.getValidationDataSetRawPath() != null && !"".equals(modelConfig.getValidationDataSetRawPath()));
// tag count for classification models, otherwise 2
int numClasses = this.modelConfig.isClassification() ? this.modelConfig.getTags().size() : 2;
String imStr = validParams.get("Impurity").toString();
int minInstancesPerNode = Integer.valueOf(validParams.get("MinInstancesPerNode").toString());
double minInfoGain = Double.valueOf(validParams.get("MinInfoGain").toString());
// impurity names are matched case-insensitively; any unrecognized value falls back to Variance
if (imStr.equalsIgnoreCase("entropy")) {
impurity = new Entropy(numClasses, minInstancesPerNode, minInfoGain);
} else if (imStr.equalsIgnoreCase("gini")) {
impurity = new Gini(numClasses, minInstancesPerNode, minInfoGain);
} else if (imStr.equalsIgnoreCase("friedmanmse")) {
impurity = new FriedmanMSE(minInstancesPerNode, minInfoGain);
} else {
impurity = new Variance(minInstancesPerNode, minInfoGain);
}
this.isRF = ALGORITHM.RF.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
this.isGBDT = ALGORITHM.GBT.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
// built-in loss names are matched case-insensitively; any other value is treated as a class name to instantiate
String lossStr = validParams.get("Loss").toString();
if (lossStr.equalsIgnoreCase("log")) {
this.loss = new LogLoss();
} else if (lossStr.equalsIgnoreCase("absolute")) {
this.loss = new AbsoluteLoss();
} else if (lossStr.equalsIgnoreCase("halfgradsquared")) {
this.loss = new HalfGradSquaredLoss();
} else if (lossStr.equalsIgnoreCase("squared")) {
this.loss = new SquaredLoss();
} else {
try {
this.loss = (Loss) ClassUtils.newInstance(Class.forName(lossStr));
} catch (ClassNotFoundException e) {
// only a missing class falls back to SquaredLoss; other failures are not handled here
LOG.warn("Class not found for {}, using default SquaredLoss", lossStr);
this.loss = new SquaredLoss();
}
}
// learning rate, sample-with-replacement and dropout rate are only read for GBT
if (this.isGBDT) {
this.learningRate = Double.valueOf(validParams.get(CommonConstants.LEARNING_RATE).toString());
Object swrObj = validParams.get("GBTSampleWithReplacement");
if (swrObj != null) {
this.gbdtSampleWithReplacement = Boolean.TRUE.toString().equalsIgnoreCase(swrObj.toString());
}
Object dropoutObj = validParams.get(CommonConstants.DROPOUT_RATE);
if (dropoutObj != null) {
this.dropOutRate = Double.valueOf(dropoutObj.toString());
}
}
this.isStratifiedSampling = this.modelConfig.getTrain().getStratifiedSample();
// checkpoint folder defaults to "tmp/cp_" + app id
this.checkpointOutput = new Path(context.getProps().getProperty(CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_FOLDER, "tmp/cp_" + context.getAppId()));
LOG.info("Worker init params:isAfterVarSel={}, treeNum={}, impurity={}, loss={}, learningRate={}, gbdtSampleWithReplacement={}, isRF={}, isGBDT={}, isStratifiedSampling={}, isKFoldCV={}, kCrossValidation={}, dropOutRate={}", isAfterVarSelect, treeNum, impurity.getClass().getName(), loss.getClass().getName(), this.learningRate, this.gbdtSampleWithReplacement, this.isRF, this.isGBDT, this.isStratifiedSampling, this.isKFoldCV, kCrossValidation, this.dropOutRate);
// for fail over, load existing trees
if (!context.isFirstIteration()) {
if (this.isGBDT) {
// set flag here and recover later in doComputing, this is to make sure recover after load part which
// can load latest trees in #doCompute
isNeedRecoverGBDTPredict = true;
} else {
// RF , trees are recovered from last master results
recoverTrees = context.getLastMasterResult().getTrees();
}
}
// continuous GBT training on the first iteration: start from the trees of an existing model if one can be loaded
if (context.isFirstIteration() && this.isContinuousEnabled && this.isGBDT) {
Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
TreeModel existingModel = null;
try {
existingModel = (TreeModel) ModelSpecLoaderUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
} catch (IOException e) {
// a load failure is logged and treated the same as "no existing model"
LOG.error("Error in get existing model, will ignore and start from scratch", e);
}
if (existingModel == null) {
LOG.warn("No model is found even set to continuous model training.");
return;
} else {
recoverTrees = existingModel.getTrees();
LOG.info("Loading existing {} trees", recoverTrees.size());
}
}
}
use of ml.shifu.shifu.core.TreeModel in project shifu by ShifuML.
the class ExportModelProcessor method run.
/**
 * Runs the export step for the format selected by {@code type} (PMML when {@code type} is blank).
 *
 * Supported formats, matched case-insensitively: one unified bagging model (NN and tree algorithms), one PMML
 * file per model, one unified bagging PMML (NN only), column stats, WOE mapping, WOE info and correlation.
 * PMML files are written under the local "pmmls" folder, WOE outputs to "woemapping.txt" and
 * "varwoe_info.txt" in the working directory.
 *
 * @return 0 when the export was handled, -1 for an unsupported format, 2 when the correlation file is missing,
 *         or the result of {@code exportVariableCorr()} for the correlation format
 * @throws Exception
 *             if any of the called export helpers or file operations fail
 * @see ml.shifu.shifu.core.processor.Processor#run()
 */
@Override
public int run() throws Exception {
    setUp(ModelStep.EXPORT);
    int status = 0;
    File pmmls = new File("pmmls");
    FileUtils.forceMkdir(pmmls);
    if (StringUtils.isBlank(type)) {
        type = PMML;
    }
    String modelsPath = pathFinder.getModelsPath(SourceType.LOCAL);
    if (type.equalsIgnoreCase(ONE_BAGGING_MODEL)) {
        if (!"nn".equalsIgnoreCase(modelConfig.getAlgorithm()) && !CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
            log.warn("Currently one bagging model is only supported in NN/GBT/RF algorithm.");
        } else {
            List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
            if (models.size() < 1) {
                log.warn("No model is found in {}.", modelsPath);
            } else {
                log.info("Convert nn models into one binary bagging model.");
                Configuration conf = new Configuration();
                // output file is named "model.b" + algorithm under the local bagging model path
                Path output = new Path(pathFinder.getBaggingModelPath(SourceType.LOCAL), "model.b" + modelConfig.getAlgorithm());
                if ("nn".equalsIgnoreCase(modelConfig.getAlgorithm())) {
                    BinaryNNSerializer.save(modelConfig, columnConfigList, models, FileSystem.getLocal(conf), output);
                } else if (CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
                    List<List<TreeNode>> baggingTrees = new ArrayList<List<TreeNode>>();
                    for (int i = 0; i < models.size(); i++) {
                        TreeModel tm = (TreeModel) models.get(i);
                        // TreeModel only has one TreeNode instance although it is list inside
                        baggingTrees.add(tm.getIndependentTreeModel().getTrees().get(0));
                    }
                    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
                    // numerical + categorical = # of all input
                    int inputCount = inputOutputIndex[0] + inputOutputIndex[1];
                    BinaryDTSerializer.save(modelConfig, columnConfigList, baggingTrees, modelConfig.getParams().get("Loss").toString(), inputCount, FileSystem.getLocal(conf), output);
                }
                log.info("Please find one unified bagging model in local {}.", output);
            }
        }
    } else if (type.equalsIgnoreCase(PMML)) {
        // typical pmml generation: one file per model, named <modelSetName><index>.pmml
        List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
        PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise(), false);
        for (int index = 0; index < models.size(); index++) {
            String path = "pmmls" + File.separator + modelConfig.getModelSetName() + Integer.toString(index) + ".pmml";
            log.info("\t Start to generate " + path);
            PMML pmml = translator.build(Arrays.asList(new BasicML[] { models.get(index) }));
            PMMLUtils.savePMML(pmml, path);
        }
    } else if (type.equalsIgnoreCase(ONE_BAGGING_PMML_MODEL)) {
        // one unified bagging pmml generation
        log.info("Convert models into one bagging pmml model {} format", type);
        if (!"nn".equalsIgnoreCase(modelConfig.getAlgorithm())) {
            log.warn("Currently one bagging pmml model is only supported in NN algorithm.");
        } else {
            List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
            PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise(), true);
            String path = "pmmls" + File.separator + modelConfig.getModelSetName() + ".pmml";
            log.info("\t Start to generate one unified model to: " + path);
            PMML pmml = translator.build(models);
            PMMLUtils.savePMML(pmml, path);
        }
    } else if (type.equalsIgnoreCase(COLUMN_STATS)) {
        saveColumnStatus();
    } else if (type.equalsIgnoreCase(WOE_MAPPING)) {
        // export all columns when no variables are requested, otherwise only the requested ones
        List<ColumnConfig> exportCatColumns = new ArrayList<ColumnConfig>();
        List<String> catVariables = getRequestVars();
        for (ColumnConfig columnConfig : this.columnConfigList) {
            if (CollectionUtils.isEmpty(catVariables) || isRequestColumn(catVariables, columnConfig)) {
                exportCatColumns.add(columnConfig);
            }
        }
        if (CollectionUtils.isNotEmpty(exportCatColumns)) {
            List<String> woeMappings = new ArrayList<String>();
            for (ColumnConfig columnConfig : exportCatColumns) {
                String woeMapText = rebinAndExportWoeMapping(columnConfig);
                woeMappings.add(woeMapText);
            }
            FileUtils.write(new File("woemapping.txt"), StringUtils.join(woeMappings, ",\n"));
        }
    } else if (type.equalsIgnoreCase(WOE)) {
        List<String> woeInfos = new ArrayList<String>();
        for (ColumnConfig columnConfig : this.columnConfigList) {
            // only columns with more than one bin and non-empty bin definitions contribute WOE info
            if (columnConfig.getBinLength() > 1 && ((columnConfig.isCategorical() && CollectionUtils.isNotEmpty(columnConfig.getBinCategory())) || (columnConfig.isNumerical() && CollectionUtils.isNotEmpty(columnConfig.getBinBoundary()) && columnConfig.getBinBoundary().size() > 1))) {
                List<String> varWoeInfos = generateWoeInfos(columnConfig);
                if (CollectionUtils.isNotEmpty(varWoeInfos)) {
                    woeInfos.addAll(varWoeInfos);
                    // blank line separates the info of consecutive variables
                    woeInfos.add("");
                }
            }
        }
        // Written once after all columns are collected. The file used to be rewritten inside the loop for
        // every column, which produced the same final content at the cost of one full rewrite per column.
        FileUtils.writeLines(new File("varwoe_info.txt"), woeInfos);
    } else if (type.equalsIgnoreCase(CORRELATION)) {
        // export correlation into mapping list
        // NOTE(review): both returns in this branch leave before clearUp(ModelStep.EXPORT) runs - confirm
        // that skipping clearUp here is intended
        if (!ShifuFileUtils.isFileExists(pathFinder.getLocalCorrelationCsvPath(), SourceType.LOCAL)) {
            log.warn("The correlation file doesn't exist. Please make sure you have ran `shifu stats -c`.");
            return 2;
        }
        return exportVariableCorr();
    } else {
        log.error("Unsupported output format - {}", type);
        status = -1;
    }
    clearUp(ModelStep.EXPORT);
    log.info("Done.");
    return status;
}
Aggregations