Search in sources:

Example 1 with TreeNode

use of ml.shifu.shifu.core.dtrain.dt.TreeNode in project shifu by ShifuML.

In class IndependentTreeModelUtils, method convertZipSpecToBinary:

/**
 * Converts a zip-spec model file (JSON model conf entry + binary trees entry) into a single
 * binary GBT model file.
 *
 * @param zipSpecFile   the input zip file containing the {@code MODEL_CONF} and {@code MODEL_TREES} entries
 * @param outputGbtFile the binary output file to write the assembled model to
 * @return {@code true} on success; {@code false} if either entry is missing/empty or an I/O error occurs
 */
public boolean convertZipSpecToBinary(File zipSpecFile, File outputGbtFile) {
    ZipInputStream zis = null;
    FileOutputStream outputStream = null;
    try {
        zis = new ZipInputStream(new FileInputStream(zipSpecFile));
        IndependentTreeModel model = null;
        List<List<TreeNode>> loadedTrees = null;
        ZipEntry entry;
        // Scan every entry; the conf and trees entries may appear in any order.
        while ((entry = zis.getNextEntry()) != null) {
            if (MODEL_CONF.equals(entry.getName())) {
                // Buffer the JSON entry fully before parsing the model configuration.
                ByteArrayOutputStream buffer = new ByteArrayOutputStream();
                IOUtils.copy(zis, buffer);
                model = JSONUtils.readValue(new ByteArrayInputStream(buffer.toByteArray()),
                        IndependentTreeModel.class);
            } else if (MODEL_TREES.equals(entry.getName())) {
                // Binary layout: treeCount, then per tree: nodeCount followed by serialized nodes.
                DataInputStream dis = new DataInputStream(zis);
                int treeCount = dis.readInt();
                loadedTrees = new ArrayList<List<TreeNode>>(treeCount);
                for (int t = 0; t < treeCount; t++) {
                    int nodeCount = dis.readInt();
                    List<TreeNode> forest = new ArrayList<TreeNode>(nodeCount);
                    for (int n = 0; n < nodeCount; n++) {
                        TreeNode node = new TreeNode();
                        node.readFields(dis);
                        forest.add(node);
                    }
                    loadedTrees.add(forest);
                }
            }
        }
        // Both pieces are required to assemble a complete model.
        if (model == null || CollectionUtils.isEmpty(loadedTrees)) {
            return false;
        }
        model.setTrees(loadedTrees);
        outputStream = new FileOutputStream(outputGbtFile);
        model.saveToInputStream(outputStream);
    } catch (IOException e) {
        logger.error("Error occurred when convert the zip format model to binary.", e);
        return false;
    } finally {
        IOUtils.closeQuietly(zis);
        IOUtils.closeQuietly(outputStream);
    }
    return true;
}
Also used : IndependentTreeModel(ml.shifu.shifu.core.dtrain.dt.IndependentTreeModel) ZipEntry(java.util.zip.ZipEntry) ArrayList(java.util.ArrayList) ZipInputStream(java.util.zip.ZipInputStream) TreeNode(ml.shifu.shifu.core.dtrain.dt.TreeNode) List(java.util.List) ArrayList(java.util.ArrayList)

Example 2 with TreeNode

use of ml.shifu.shifu.core.dtrain.dt.TreeNode in project shifu by ShifuML.

In class IndependentTreeModelUtils, method convertBinaryToZipSpec:

/**
 * Converts a binary tree-model file into the zip-spec format: a {@code MODEL_CONF} entry
 * holding the JSON model configuration (with trees stripped) plus a {@code MODEL_TREES}
 * entry holding the trees in binary form.
 *
 * @param treeModelFile the binary model input file
 * @param outputZipFile the zip file to create
 * @return {@code true} on success; {@code false} if the model has no trees or an I/O error occurs
 */
public boolean convertBinaryToZipSpec(File treeModelFile, File outputZipFile) {
    FileInputStream modelInput = null;
    ZipOutputStream zos = null;
    try {
        modelInput = new FileInputStream(treeModelFile);
        IndependentTreeModel model = IndependentTreeModel.loadFromStream(modelInput);
        List<List<TreeNode>> trees = model.getTrees();
        // Detach the trees so the JSON conf entry does not duplicate them.
        model.setTrees(null);
        if (CollectionUtils.isEmpty(trees)) {
            logger.error("No trees found in the tree model.");
            return false;
        }
        zos = new ZipOutputStream(new FileOutputStream(outputZipFile));
        // Entry 1: model configuration as JSON.
        zos.putNextEntry(new ZipEntry(MODEL_CONF));
        ByteArrayOutputStream jsonBuffer = new ByteArrayOutputStream();
        // NOTE(review): OutputStreamWriter uses the platform charset here, and the writer is
        // never explicitly flushed — this relies on JSONUtils.writeValue flushing/closing it.
        JSONUtils.writeValue(new OutputStreamWriter(jsonBuffer), model);
        zos.write(jsonBuffer.toByteArray());
        IOUtils.closeQuietly(jsonBuffer);
        // Entry 2: trees in binary form — treeCount, then per tree: nodeCount + nodes.
        zos.putNextEntry(new ZipEntry(MODEL_TREES));
        DataOutputStream dos = new DataOutputStream(zos);
        dos.writeInt(trees.size());
        for (List<TreeNode> forest : trees) {
            dos.writeInt(forest.size());
            for (TreeNode node : forest) {
                node.write(dos);
            }
        }
        // Closing the DataOutputStream also closes/finishes the underlying ZipOutputStream;
        // the close in finally then becomes a harmless no-op.
        IOUtils.closeQuietly(dos);
    } catch (IOException e) {
        logger.error("Error occurred when convert the tree model to zip format.", e);
        return false;
    } finally {
        IOUtils.closeQuietly(zos);
        IOUtils.closeQuietly(modelInput);
    }
    return true;
}
Also used : IndependentTreeModel(ml.shifu.shifu.core.dtrain.dt.IndependentTreeModel) ZipEntry(java.util.zip.ZipEntry) ZipOutputStream(java.util.zip.ZipOutputStream) TreeNode(ml.shifu.shifu.core.dtrain.dt.TreeNode) List(java.util.List) ArrayList(java.util.ArrayList)

Example 3 with TreeNode

use of ml.shifu.shifu.core.dtrain.dt.TreeNode in project shifu by ShifuML.

In class TreeEnsemblePmmlCreator, method convert:

/**
 * Converts a single-bagging {@code IndependentTreeModel} into a PMML {@code MiningModel}
 * whose segmentation holds one weighted tree-model segment per tree.
 *
 * @param treeModel the source model; must contain exactly one bagging element
 * @return the assembled PMML mining model
 * @throws RuntimeException if the model contains more than one bagging element
 */
public MiningModel convert(IndependentTreeModel treeModel) {
    MiningModel miningModel = new MiningModel();
    MiningSchema schema = new TreeModelMiningSchemaCreator(this.modelConfig, this.columnConfigList).build(null);
    miningModel.setMiningSchema(schema);
    // Mining function follows the source model's task type.
    String function = treeModel.isClassification() ? "classification" : "regression";
    miningModel.setMiningFunction(MiningFunction.fromValue(function));
    miningModel.setTargets(createTargets(this.modelConfig));
    Segmentation segmentation = new Segmentation();
    miningModel.setSegmentation(segmentation);
    segmentation.setMultipleModelMethod(MultipleModelMethod.fromValue("weightedAverage"));
    List<Segment> segments = segmentation.getSegments();
    // such case we only support treeModel is one element list
    if (treeModel.getTrees().size() != 1) {
        throw new RuntimeException("Bagging model cannot be supported in PMML generation.");
    }
    int segmentIndex = 0;
    for (TreeNode treeNode : treeModel.getTrees().get(0)) {
        TreeNodePmmlElementCreator nodeCreator =
                new TreeNodePmmlElementCreator(this.modelConfig, this.columnConfigList, treeModel);
        org.dmg.pmml.tree.Node root = nodeCreator.convert(treeNode.getNode());
        TreeModelPmmlElementCreator modelCreator =
                new TreeModelPmmlElementCreator(this.modelConfig, this.columnConfigList);
        org.dmg.pmml.tree.TreeModel pmmlTree = modelCreator.convert(treeModel, root);
        pmmlTree.setModelName(String.valueOf(segmentIndex));
        Segment segment = new Segment();
        // GBDT weights are rescaled by tree count so weightedAverage reproduces the additive sum.
        double weight = treeModel.getWeights().get(0).get(segmentIndex);
        segment.setWeight(treeModel.isGBDT() ? weight * treeModel.getTrees().size() : weight);
        // NOTE(review): "Segement" is misspelled, but it is a runtime ID that downstream
        // consumers may match on — left byte-identical on purpose.
        segment.setId("Segement" + String.valueOf(segmentIndex));
        segmentIndex++;
        segment.setPredicate(new True());
        segment.setModel(pmmlTree);
        segments.add(segment);
    }
    return miningModel;
}
Also used : Segmentation(org.dmg.pmml.mining.Segmentation) True(org.dmg.pmml.True) Segment(org.dmg.pmml.mining.Segment) MiningModel(org.dmg.pmml.mining.MiningModel) MiningSchema(org.dmg.pmml.MiningSchema) TreeNode(ml.shifu.shifu.core.dtrain.dt.TreeNode)

Example 4 with TreeNode

use of ml.shifu.shifu.core.dtrain.dt.TreeNode in project shifu by ShifuML.

In class ExportModelProcessor, method run:

/*
     * (non-Javadoc)
     * 
     * @see ml.shifu.shifu.core.processor.Processor#run()
     */
/*
     * (non-Javadoc)
     * 
     * @see ml.shifu.shifu.core.processor.Processor#run()
     */
// Dispatches the export step by requested `type`: one-bagging binary model, per-model PMML,
// unified bagging PMML, column stats, WOE mapping, raw WOE info, or correlation export.
// Returns 0 on success, 2 when the correlation input is missing, -1 on unsupported type.
@Override
public int run() throws Exception {
    setUp(ModelStep.EXPORT);
    int status = 0;
    File pmmls = new File("pmmls");
    FileUtils.forceMkdir(pmmls);
    // Default export format is PMML when none is specified.
    if (StringUtils.isBlank(type)) {
        type = PMML;
    }
    String modelsPath = pathFinder.getModelsPath(SourceType.LOCAL);
    if (type.equalsIgnoreCase(ONE_BAGGING_MODEL)) {
        if (!"nn".equalsIgnoreCase(modelConfig.getAlgorithm()) && !CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
            log.warn("Currently one bagging model is only supported in NN/GBT/RF algorithm.");
        } else {
            List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
            if (models.size() < 1) {
                log.warn("No model is found in {}.", modelsPath);
            } else {
                log.info("Convert nn models into one binary bagging model.");
                Configuration conf = new Configuration();
                Path output = new Path(pathFinder.getBaggingModelPath(SourceType.LOCAL), "model.b" + modelConfig.getAlgorithm());
                if ("nn".equalsIgnoreCase(modelConfig.getAlgorithm())) {
                    BinaryNNSerializer.save(modelConfig, columnConfigList, models, FileSystem.getLocal(conf), output);
                } else if (CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
                    List<List<TreeNode>> baggingTrees = new ArrayList<List<TreeNode>>();
                    for (int i = 0; i < models.size(); i++) {
                        TreeModel tm = (TreeModel) models.get(i);
                        // TreeModel only has one TreeNode instance although it is list inside
                        baggingTrees.add(tm.getIndependentTreeModel().getTrees().get(0));
                    }
                    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
                    // numerical + categorical = # of all input
                    int inputCount = inputOutputIndex[0] + inputOutputIndex[1];
                    BinaryDTSerializer.save(modelConfig, columnConfigList, baggingTrees, modelConfig.getParams().get("Loss").toString(), inputCount, FileSystem.getLocal(conf), output);
                }
                log.info("Please find one unified bagging model in local {}.", output);
            }
        }
    } else if (type.equalsIgnoreCase(PMML)) {
        // typical pmml generation
        List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
        PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise(), false);
        for (int index = 0; index < models.size(); index++) {
            String path = "pmmls" + File.separator + modelConfig.getModelSetName() + Integer.toString(index) + ".pmml";
            log.info("\t Start to generate " + path);
            PMML pmml = translator.build(Arrays.asList(new BasicML[] { models.get(index) }));
            PMMLUtils.savePMML(pmml, path);
        }
    } else if (type.equalsIgnoreCase(ONE_BAGGING_PMML_MODEL)) {
        // one unified bagging pmml generation
        log.info("Convert models into one bagging pmml model {} format", type);
        if (!"nn".equalsIgnoreCase(modelConfig.getAlgorithm())) {
            log.warn("Currently one bagging pmml model is only supported in NN algorithm.");
        } else {
            List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
            PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise(), true);
            String path = "pmmls" + File.separator + modelConfig.getModelSetName() + ".pmml";
            log.info("\t Start to generate one unified model to: " + path);
            PMML pmml = translator.build(models);
            PMMLUtils.savePMML(pmml, path);
        }
    } else if (type.equalsIgnoreCase(COLUMN_STATS)) {
        saveColumnStatus();
    } else if (type.equalsIgnoreCase(WOE_MAPPING)) {
        List<ColumnConfig> exportCatColumns = new ArrayList<ColumnConfig>();
        List<String> catVariables = getRequestVars();
        for (ColumnConfig columnConfig : this.columnConfigList) {
            if (CollectionUtils.isEmpty(catVariables) || isRequestColumn(catVariables, columnConfig)) {
                exportCatColumns.add(columnConfig);
            }
        }
        if (CollectionUtils.isNotEmpty(exportCatColumns)) {
            List<String> woeMappings = new ArrayList<String>();
            for (ColumnConfig columnConfig : exportCatColumns) {
                String woeMapText = rebinAndExportWoeMapping(columnConfig);
                woeMappings.add(woeMapText);
            }
            FileUtils.write(new File("woemapping.txt"), StringUtils.join(woeMappings, ",\n"));
        }
    } else if (type.equalsIgnoreCase(WOE)) {
        List<String> woeInfos = new ArrayList<String>();
        for (ColumnConfig columnConfig : this.columnConfigList) {
            // Only export columns with real binning: multi-category categoricals or
            // numericals with more than one bin boundary.
            if (columnConfig.getBinLength() > 1 && ((columnConfig.isCategorical() && CollectionUtils.isNotEmpty(columnConfig.getBinCategory())) || (columnConfig.isNumerical() && CollectionUtils.isNotEmpty(columnConfig.getBinBoundary()) && columnConfig.getBinBoundary().size() > 1))) {
                List<String> varWoeInfos = generateWoeInfos(columnConfig);
                if (CollectionUtils.isNotEmpty(varWoeInfos)) {
                    woeInfos.addAll(varWoeInfos);
                    woeInfos.add("");
                }
            }
        }
        // FIX: write the file once after collecting all columns; previously this call sat
        // inside the loop and rewrote the whole file once per column (O(n^2) I/O, same
        // final content).
        FileUtils.writeLines(new File("varwoe_info.txt"), woeInfos);
    } else if (type.equalsIgnoreCase(CORRELATION)) {
        // export correlation into mapping list
        if (!ShifuFileUtils.isFileExists(pathFinder.getLocalCorrelationCsvPath(), SourceType.LOCAL)) {
            log.warn("The correlation file doesn't exist. Please make sure you have ran `shifu stats -c`.");
            return 2;
        }
        return exportVariableCorr();
    } else {
        log.error("Unsupported output format - {}", type);
        status = -1;
    }
    clearUp(ModelStep.EXPORT);
    log.info("Done.");
    return status;
}
Also used : Path(org.apache.hadoop.fs.Path) Configuration(org.apache.hadoop.conf.Configuration) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) BasicML(org.encog.ml.BasicML) PMMLTranslator(ml.shifu.shifu.core.pmml.PMMLTranslator) TreeModel(ml.shifu.shifu.core.TreeModel) TreeNode(ml.shifu.shifu.core.dtrain.dt.TreeNode) PMML(org.dmg.pmml.PMML) File(java.io.File)

Example 5 with TreeNode

use of ml.shifu.shifu.core.dtrain.dt.TreeNode in project shifu by ShifuML.

In class TreeModel, method getFeatureImportances:

/**
 * Get feature importance of current model.
 *
 * <p>Only single-element (non-bagging) tree lists are supported; each tree's per-column
 * importance is accumulated and divided by the bagging list size.
 *
 * @return map of feature importance, key is column index.
 * @throws RuntimeException if the model holds more than one bagging element
 */
public Map<Integer, MutablePair<String, Double>> getFeatureImportances() {
    Map<Integer, MutablePair<String, Double>> result = new HashMap<Integer, MutablePair<String, Double>>();
    Map<Integer, String> nameMapping = this.getIndependentTreeModel().getNumNameMapping();
    // NOTE(review): after the guard below, treeSize is always 1, so the division by it is
    // a no-op — possibly the intent was the forest size (getTrees().get(0).size()); behavior
    // kept as-is pending confirmation.
    int treeSize = this.getIndependentTreeModel().getTrees().size();
    // such case we only support treeModel is one element list
    if (this.getIndependentTreeModel().getTrees().size() != 1) {
        throw new RuntimeException("Bagging model cannot be supported in Tree Model one element feature importance computing.");
    }
    for (TreeNode tree : this.getIndependentTreeModel().getTrees().get(0)) {
        // Importance contributed by this single tree, keyed by column index.
        Map<Integer, Double> treeImportances = tree.computeFeatureImportance();
        // Merge this tree's contributions into the running totals.
        for (Entry<Integer, Double> entry : treeImportances.entrySet()) {
            Integer columnIndex = entry.getKey();
            double share = entry.getValue() / treeSize;
            MutablePair<String, Double> accumulated = result.get(columnIndex);
            if (accumulated == null) {
                result.put(columnIndex, MutablePair.of(nameMapping.get(columnIndex), share));
            } else {
                accumulated.setValue(accumulated.getValue() + share);
            }
        }
    }
    return result;
}
Also used : HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) MutablePair(org.apache.commons.lang3.tuple.MutablePair) TreeNode(ml.shifu.shifu.core.dtrain.dt.TreeNode)

Aggregations

TreeNode (ml.shifu.shifu.core.dtrain.dt.TreeNode)5 ArrayList (java.util.ArrayList)2 List (java.util.List)2 ZipEntry (java.util.zip.ZipEntry)2 IndependentTreeModel (ml.shifu.shifu.core.dtrain.dt.IndependentTreeModel)2 File (java.io.File)1 HashMap (java.util.HashMap)1 LinkedHashMap (java.util.LinkedHashMap)1 ZipInputStream (java.util.zip.ZipInputStream)1 ZipOutputStream (java.util.zip.ZipOutputStream)1 ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)1 TreeModel (ml.shifu.shifu.core.TreeModel)1 PMMLTranslator (ml.shifu.shifu.core.pmml.PMMLTranslator)1 MutablePair (org.apache.commons.lang3.tuple.MutablePair)1 Configuration (org.apache.hadoop.conf.Configuration)1 Path (org.apache.hadoop.fs.Path)1 MiningSchema (org.dmg.pmml.MiningSchema)1 PMML (org.dmg.pmml.PMML)1 True (org.dmg.pmml.True)1 MiningModel (org.dmg.pmml.mining.MiningModel)1