Search in sources :

Example 1 with NNColumnStats

use of ml.shifu.shifu.core.dtrain.nn.NNColumnStats in project shifu by ShifuML.

Class BinaryWDLSerializer, method save.

/**
 * Writes a WideAndDeep model and the statistics of its selected columns to {@code output} as a
 * gzip-compressed binary stream.
 *
 * <p>Layout, in write order: format version (int), four reserved fields (float, float, double,
 * UTF string), the normalization type name, the number of column-stat records, the column-stat
 * records themselves, and finally the serialized model.
 *
 * @param modelConfig      supplies the normalization type and the std-dev cutoff
 * @param columnConfigList all column configs; only those whose column number is present in the
 *                         mapping returned by {@code getIndexNameMapping} are written
 * @param wideAndDeep      the model to persist
 * @param fs               file system used to create {@code output}
 * @param output           destination path
 * @throws IOException if creating, writing or closing the output stream fails
 */
public static void save(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, WideAndDeep wideAndDeep, FileSystem fs, Path output) throws IOException {
    // try-with-resources instead of a quiet close in finally: closing is when the gzip stream is
    // finished and its trailer written, so a failure there must reach the caller rather than
    // leaving a truncated model file behind unnoticed.
    try (DataOutputStream fos = new DataOutputStream(new GZIPOutputStream(fs.create(output)))) {
        // version
        fos.writeInt(CommonConstants.WDL_FORMAT_VERSION);
        // Reserved two float field, one double field and one string field
        fos.writeFloat(0.0f);
        fos.writeFloat(0.0f);
        fos.writeDouble(0.0d);
        fos.writeUTF("Reserved field");
        // write normStr
        String normStr = modelConfig.getNormalize().getNormType().toString();
        StringUtils.writeString(fos, normStr);
        // compute columns needed
        Map<Integer, String> columnIndexNameMapping = getIndexNameMapping(columnConfigList);
        // Collect the stats first: the record count has to be written before the records.
        List<NNColumnStats> csList = new ArrayList<>();
        for (ColumnConfig cc : columnConfigList) {
            if (!columnIndexNameMapping.containsKey(cc.getColumnNum())) {
                continue;
            }
            NNColumnStats cs = new NNColumnStats();
            cs.setCutoff(modelConfig.getNormalizeStdDevCutOff());
            cs.setColumnType(cc.getColumnType());
            cs.setMean(cc.getMean());
            cs.setStddev(cc.getStdDev());
            cs.setColumnNum(cc.getColumnNum());
            cs.setColumnName(cc.getColumnName());
            cs.setBinCategories(cc.getBinCategory());
            cs.setBinBoundaries(cc.getBinBoundary());
            cs.setBinPosRates(cc.getBinPosRate());
            cs.setBinCountWoes(cc.getBinCountWoe());
            cs.setBinWeightWoes(cc.getBinWeightedWoe());
            // TODO cache such computation
            // Index 0 is stored as the mean and index 1 as the std-dev; the boolean flag selects
            // the weighted variant.
            double[] meanAndStdDev = Normalizer.calculateWoeMeanAndStdDev(cc, false);
            cs.setWoeMean(meanAndStdDev[0]);
            cs.setWoeStddev(meanAndStdDev[1]);
            double[] weightMeanAndStdDev = Normalizer.calculateWoeMeanAndStdDev(cc, true);
            cs.setWoeWgtMean(weightMeanAndStdDev[0]);
            cs.setWoeWgtStddev(weightMeanAndStdDev[1]);
            csList.add(cs);
        }
        // write column stats to output
        fos.writeInt(csList.size());
        for (NNColumnStats cs : csList) {
            cs.write(fos);
        }
        // persist WideAndDeep Model
        wideAndDeep.write(fos);
    }
}
Also used : NNColumnStats(ml.shifu.shifu.core.dtrain.nn.NNColumnStats) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) GZIPOutputStream(java.util.zip.GZIPOutputStream) DataOutputStream(java.io.DataOutputStream) ArrayList(java.util.ArrayList)

Example 2 with NNColumnStats

use of ml.shifu.shifu.core.dtrain.nn.NNColumnStats in project shifu by ShifuML.

Class IndependentWDLModel, method loadFromStream.

/**
 * Load a model instance from an input stream holding the WDL binary format.
 *
 * <p>Read order: format version (int), four reserved fields (float, float, double, UTF string),
 * the normalization type name, the number of column-stat records, the column-stat records, and
 * finally the serialized WideAndDeep network. The stream is not closed by this method.
 *
 * @param input             the input stream, flat input stream or gzip input stream both OK
 * @param isRemoveNameSpace whether to strip the name-space from column names so that columns can
 *                          be referenced by their simple name
 * @return the loaded IndependentWDLModel instance
 * @throws IOException any IOException in de-serialization.
 */
public static IndependentWDLModel loadFromStream(InputStream input, boolean isRemoveNameSpace) throws IOException {
    DataInputStream dis;
    // check if gzip or not: peek at the first two bytes and compare with the gzip magic number
    try {
        byte[] header = new byte[2];
        BufferedInputStream bis = new BufferedInputStream(input);
        bis.mark(2);
        int result = bis.read(header);
        bis.reset();
        // the gzip magic number is stored little-endian: low byte first
        int ss = (header[0] & 0xff) | ((header[1] & 0xff) << 8);
        if (result != -1 && ss == GZIPInputStream.GZIP_MAGIC) {
            dis = new DataInputStream(new GZIPInputStream(bis));
        } else {
            dis = new DataInputStream(bis);
        }
    } catch (java.io.IOException e) {
        // Best-effort fallback kept from the original behavior.
        // NOTE(review): if the failure happened after the buffered stream already pulled bytes
        // from 'input', reading from 'input' directly will start mid-stream - confirm intent.
        dis = new DataInputStream(input);
    }
    int version = dis.readInt();
    IndependentWDLModel.setVersion(version);
    // Reserved two float field, one double field and one string field
    dis.readFloat();
    dis.readFloat();
    dis.readDouble();
    dis.readUTF();
    // read normStr
    String normStr = StringUtils.readString(dis);
    // Upper-case with Locale.ROOT so the enum lookup does not depend on the JVM default locale
    // (e.g. under a Turkish locale 'i' would become a dotted capital I and valueOf would fail).
    NormType normType = NormType.valueOf(normStr != null ? normStr.toUpperCase(java.util.Locale.ROOT) : null);
    int columnSize = dis.readInt();
    // for all features
    Map<Integer, String> numNameMap = new HashMap<>(columnSize);
    // for numerical features
    Map<Integer, List<Double>> numerBinBoundaries = new HashMap<>(columnSize);
    Map<Integer, List<Double>> numerWoes = new HashMap<>(columnSize);
    Map<Integer, List<Double>> numerWgtWoes = new HashMap<>(columnSize);
    // for all features
    Map<Integer, Double> numerMeanMap = new HashMap<>(columnSize);
    Map<Integer, Double> numerStddevMap = new HashMap<>(columnSize);
    Map<Integer, Double> woeMeanMap = new HashMap<>(columnSize);
    Map<Integer, Double> woeStddevMap = new HashMap<>(columnSize);
    Map<Integer, Double> wgtWoeMeanMap = new HashMap<>(columnSize);
    Map<Integer, Double> wgtWoeStddevMap = new HashMap<>(columnSize);
    Map<Integer, Double> cutoffMap = new HashMap<>(columnSize);
    Map<Integer, Map<String, Integer>> cateIndexMapping = new HashMap<>(columnSize);
    for (int i = 0; i < columnSize; i++) {
        NNColumnStats cs = new NNColumnStats();
        cs.readFields(dis);
        List<Double> binWoes = cs.getBinCountWoes();
        List<Double> binWgtWoes = cs.getBinWeightWoes();
        int columnNum = cs.getColumnNum();
        if (isRemoveNameSpace) {
            // remove name-space in column name to make it be called by simple name
            numNameMap.put(columnNum, StringUtils.getSimpleColumnName(cs.getColumnName()));
        } else {
            numNameMap.put(columnNum, cs.getColumnName());
        }
        // for categorical features: category value -> bin index
        // NOTE(review): assumes readFields always leaves binCategories non-null, also for
        // numerical columns - TODO confirm against NNColumnStats.
        Map<String, Integer> cateIndexMap = new HashMap<>(cs.getBinCategories().size());
        if (cs.isCategorical() || cs.isHybrid()) {
            List<String> binCategories = cs.getBinCategories();
            for (int j = 0; j < binCategories.size(); j++) {
                String currCate = binCategories.get(j);
                if (currCate.contains(Constants.CATEGORICAL_GROUP_VAL_DELIMITER)) {
                    // merged category should be flatten, use own split function to avoid depending on guava jar in
                    // prediction
                    String[] splits = StringUtils.split(currCate, Constants.CATEGORICAL_GROUP_VAL_DELIMITER);
                    for (String str : splits) {
                        // every member of a merged group maps to the same bin index
                        cateIndexMap.put(str, j);
                    }
                } else {
                    cateIndexMap.put(currCate, j);
                }
            }
        }
        if (cs.isNumerical() || cs.isHybrid()) {
            numerBinBoundaries.put(columnNum, cs.getBinBoundaries());
            numerWoes.put(columnNum, binWoes);
            numerWgtWoes.put(columnNum, binWgtWoes);
        }
        // registered for every column; the map stays empty for purely numerical ones
        cateIndexMapping.put(columnNum, cateIndexMap);
        numerMeanMap.put(columnNum, cs.getMean());
        numerStddevMap.put(columnNum, cs.getStddev());
        woeMeanMap.put(columnNum, cs.getWoeMean());
        woeStddevMap.put(columnNum, cs.getWoeStddev());
        wgtWoeMeanMap.put(columnNum, cs.getWoeWgtMean());
        wgtWoeStddevMap.put(columnNum, cs.getWoeWgtStddev());
        cutoffMap.put(columnNum, cs.getCutoff());
    }
    // the network itself follows the column statistics in the stream
    WideAndDeep wideAndDeep = new WideAndDeep();
    wideAndDeep.readFields(dis);
    return new IndependentWDLModel(wideAndDeep, normType, cutoffMap, numNameMap, cateIndexMapping, numerBinBoundaries, numerWoes, numerWgtWoes, numerMeanMap, numerStddevMap, woeMeanMap, woeStddevMap, wgtWoeMeanMap, wgtWoeStddevMap);
}
Also used : NNColumnStats(ml.shifu.shifu.core.dtrain.nn.NNColumnStats) HashMap(java.util.HashMap) GZIPInputStream(java.util.zip.GZIPInputStream) BufferedInputStream(java.io.BufferedInputStream) ArrayList(java.util.ArrayList) List(java.util.List) NormType(ml.shifu.shifu.container.obj.ModelNormalizeConf.NormType) IOException(java.io.IOException) DataInputStream(java.io.DataInputStream) HashMap(java.util.HashMap) Map(java.util.Map)

Aggregations

ArrayList (java.util.ArrayList)2 NNColumnStats (ml.shifu.shifu.core.dtrain.nn.NNColumnStats)2 BufferedInputStream (java.io.BufferedInputStream)1 DataInputStream (java.io.DataInputStream)1 DataOutputStream (java.io.DataOutputStream)1 IOException (java.io.IOException)1 HashMap (java.util.HashMap)1 List (java.util.List)1 Map (java.util.Map)1 GZIPInputStream (java.util.zip.GZIPInputStream)1 GZIPOutputStream (java.util.zip.GZIPOutputStream)1 ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)1 NormType (ml.shifu.shifu.container.obj.ModelNormalizeConf.NormType)1