Search in sources:

Example 1 with ColumnType

Use of ml.shifu.shifu.container.obj.ColumnType in project shifu by ShifuML.

In class IndependentNNModel, method loadFromStream.

/**
 * Load model instance from input stream which is saved in NNOutput for specified binary format.
 *
 * <p>
 * The stream may hold either the plain binary model or a gzip-compressed copy of it; the first two bytes are
 * inspected for the gzip magic number to decide which. The stream is wrapped but never closed here, so the caller
 * remains responsible for closing {@code input}.
 *
 * <p>
 * Fields are read in this order: a version int, the norm type name, a column count followed by that many
 * {@code NNColumnStats} records, a column-map entry count followed by that many (key, value) int pairs, and finally a
 * network count followed by that many serialized networks.
 *
 * @param input
 *            the input stream, flat input stream or gzip input stream both OK
 * @param isRemoveNameSpace
 *            if true, column names are kept in their simple form with the name-space removed; otherwise the full
 *            column name is kept
 * @return the nn model instance
 * @throws IOException
 *             any IOException while sniffing the header, opening the gzip stream, or de-serializing the model
 */
public static IndependentNNModel loadFromStream(InputStream input, boolean isRemoveNameSpace) throws IOException {
    DataInputStream dis = openModelDataStream(input);

    int version = dis.readInt();
    // published through a static setter before the column stats are parsed; presumably NNColumnStats.readFields
    // depends on it — TODO confirm
    IndependentNNModel.setVersion(version);

    String normStr = ml.shifu.shifu.core.dtrain.StringUtils.readString(dis);
    // Locale.ROOT keeps the enum lookup independent of the JVM default locale (e.g. Turkish maps 'i' to a dotted
    // capital I, which would make valueOf fail for names containing "i" such as "hybrid").
    NormType normType = NormType.valueOf(normStr.toUpperCase(java.util.Locale.ROOT));

    // for all features
    Map<Integer, String> numNameMap = new HashMap<Integer, String>();
    Map<Integer, List<String>> cateColumnNameNames = new HashMap<Integer, List<String>>();
    // for categorical features
    Map<Integer, Map<String, Double>> cateWoeMap = new HashMap<Integer, Map<String, Double>>();
    Map<Integer, Map<String, Double>> cateWgtWoeMap = new HashMap<Integer, Map<String, Double>>();
    Map<Integer, Map<String, Double>> binPosRateMap = new HashMap<Integer, Map<String, Double>>();
    // for numerical features
    Map<Integer, List<Double>> numerBinBoundaries = new HashMap<Integer, List<Double>>();
    Map<Integer, List<Double>> numerWoes = new HashMap<Integer, List<Double>>();
    Map<Integer, List<Double>> numerWgtWoes = new HashMap<Integer, List<Double>>();
    // for all features
    Map<Integer, Double> numerMeanMap = new HashMap<Integer, Double>();
    Map<Integer, Double> numerStddevMap = new HashMap<Integer, Double>();
    Map<Integer, Double> woeMeanMap = new HashMap<Integer, Double>();
    Map<Integer, Double> woeStddevMap = new HashMap<Integer, Double>();
    Map<Integer, Double> wgtWoeMeanMap = new HashMap<Integer, Double>();
    Map<Integer, Double> wgtWoeStddevMap = new HashMap<Integer, Double>();
    Map<Integer, Double> cutoffMap = new HashMap<Integer, Double>();
    Map<Integer, ColumnType> columnTypeMap = new HashMap<Integer, ColumnType>();
    Map<Integer, Map<String, Integer>> cateIndexMapping = new HashMap<Integer, Map<String, Integer>>();

    int columnSize = dis.readInt();
    for (int i = 0; i < columnSize; i++) {
        NNColumnStats cs = new NNColumnStats();
        cs.readFields(dis);
        List<Double> binWoes = cs.getBinCountWoes();
        List<Double> binWgtWoes = cs.getBinWeightWoes();
        List<Double> binPosRates = cs.getBinPosRates();
        int columnNum = cs.getColumnNum();
        columnTypeMap.put(columnNum, cs.getColumnType());
        if (isRemoveNameSpace) {
            // remove name-space in column name to make it be called by simple name
            numNameMap.put(columnNum, StringUtils.getSimpleColumnName(cs.getColumnName()));
        } else {
            numNameMap.put(columnNum, cs.getColumnName());
        }
        // for categorical features; these maps are registered for every column, empty if not categorical/hybrid
        Map<String, Double> woeMap = new HashMap<String, Double>();
        Map<String, Double> woeWgtMap = new HashMap<String, Double>();
        Map<String, Double> posRateMap = new HashMap<String, Double>();
        Map<String, Integer> cateIndexMap = new HashMap<String, Integer>();
        if (cs.isCategorical() || cs.isHybrid()) {
            List<String> binCategories = cs.getBinCategories();
            cateColumnNameNames.put(columnNum, binCategories);
            for (int j = 0; j < binCategories.size(); j++) {
                String currCate = binCategories.get(j);
                if (currCate.contains(Constants.CATEGORICAL_GROUP_VAL_DELIMITER)) {
                    // merged category should be flatten, use own split function to avoid depending on guava jar in
                    // prediction; every member of the group shares bin j's statistics
                    String[] splits = StringUtils.split(currCate, Constants.CATEGORICAL_GROUP_VAL_DELIMITER);
                    for (String str : splits) {
                        woeMap.put(str, binWoes.get(j));
                        woeWgtMap.put(str, binWgtWoes.get(j));
                        posRateMap.put(str, binPosRates.get(j));
                        cateIndexMap.put(str, j);
                    }
                } else {
                    woeMap.put(currCate, binWoes.get(j));
                    woeWgtMap.put(currCate, binWgtWoes.get(j));
                    posRateMap.put(currCate, binPosRates.get(j));
                    cateIndexMap.put(currCate, j);
                }
            }
            // append last missing bin: the entry right after the last category holds the missing-value stats
            woeMap.put(Constants.EMPTY_CATEGORY, binWoes.get(binCategories.size()));
            woeWgtMap.put(Constants.EMPTY_CATEGORY, binWgtWoes.get(binCategories.size()));
            posRateMap.put(Constants.EMPTY_CATEGORY, binPosRates.get(binCategories.size()));
        }
        if (cs.isNumerical() || cs.isHybrid()) {
            numerBinBoundaries.put(columnNum, cs.getBinBoundaries());
            numerWoes.put(columnNum, binWoes);
            numerWgtWoes.put(columnNum, binWgtWoes);
        }
        cateWoeMap.put(columnNum, woeMap);
        cateWgtWoeMap.put(columnNum, woeWgtMap);
        binPosRateMap.put(columnNum, posRateMap);
        cateIndexMapping.put(columnNum, cateIndexMap);
        numerMeanMap.put(columnNum, cs.getMean());
        numerStddevMap.put(columnNum, cs.getStddev());
        woeMeanMap.put(columnNum, cs.getWoeMean());
        woeStddevMap.put(columnNum, cs.getWoeStddev());
        wgtWoeMeanMap.put(columnNum, cs.getWoeWgtMean());
        wgtWoeStddevMap.put(columnNum, cs.getWoeWgtStddev());
        cutoffMap.put(columnNum, cs.getCutoff());
    }

    Map<Integer, Integer> columnMap = new HashMap<Integer, Integer>();
    int columnMapSize = dis.readInt();
    for (int i = 0; i < columnMapSize; i++) {
        // key is written before value; read them in separate statements to make that order explicit
        int key = dis.readInt();
        int value = dis.readInt();
        columnMap.put(key, value);
    }

    int size = dis.readInt();
    List<BasicFloatNetwork> networks = new ArrayList<BasicFloatNetwork>();
    for (int i = 0; i < size; i++) {
        networks.add(new PersistBasicFloatNetwork().readNetwork(dis));
    }
    return new IndependentNNModel(networks, normType, numNameMap, cateColumnNameNames, columnMap, cateWoeMap, cateWgtWoeMap, binPosRateMap, numerBinBoundaries, numerWgtWoes, numerWoes, cutoffMap, numerMeanMap, numerStddevMap, woeMeanMap, woeStddevMap, wgtWoeMeanMap, wgtWoeStddevMap, columnTypeMap, cateIndexMapping);
}

/**
 * Wraps {@code input} in a {@link DataInputStream}, transparently decompressing it when it starts with the gzip
 * magic number.
 *
 * <p>
 * IOExceptions propagate to the caller. The stream must not fall back to reading {@code input} directly: the
 * {@link BufferedInputStream} may already have pulled bytes out of it, so a fallback would resume parsing at an
 * arbitrary offset.
 *
 * @param input
 *            the raw model stream, not closed here
 * @return a data stream positioned at the first byte of the (decompressed) model
 * @throws IOException
 *             if peeking the header or opening the gzip stream fails
 */
private static DataInputStream openModelDataStream(InputStream input) throws IOException {
    BufferedInputStream bis = new BufferedInputStream(input);
    byte[] header = new byte[2];
    // peek at the first two bytes; loop because a single read may legally return fewer bytes than requested
    bis.mark(header.length);
    int read = 0;
    while (read < header.length) {
        int n = bis.read(header, read, header.length - read);
        if (n < 0) {
            break;
        }
        read += n;
    }
    bis.reset();
    // GZIP_MAGIC (0x8b1f) is the little-endian view of the first two gzip bytes (0x1f, 0x8b)
    int magic = (header[0] & 0xff) | ((header[1] & 0xff) << 8);
    if (read == header.length && magic == GZIPInputStream.GZIP_MAGIC) {
        return new DataInputStream(new GZIPInputStream(bis));
    }
    return new DataInputStream(bis);
}
Also used: ColumnType (ml.shifu.shifu.container.obj.ColumnType), HashMap (java.util.HashMap), ArrayList (java.util.ArrayList), GZIPInputStream (java.util.zip.GZIPInputStream), BufferedInputStream (java.io.BufferedInputStream), List (java.util.List), PersistBasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork), BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork), NormType (ml.shifu.shifu.container.obj.ModelNormalizeConf.NormType), IOException (java.io.IOException), DataInputStream (java.io.DataInputStream), Map (java.util.Map)

Example 2 with ColumnType

Use of ml.shifu.shifu.container.obj.ColumnType in project shifu by ShifuML.

In class IndependentNNModel, method convertDataMapToDoubleArray.

/**
 * Builds the network input vector from a raw record keyed by column name.
 *
 * <p>
 * For each column number in {@code columnNumIndexMap}, the raw value is looked up in {@code dataMap} by the column's
 * name. The value is then normalized according to the column type (categorical, numerical or hybrid) and the model's
 * {@code normType}, and stored at the input position mapped to that column. A column whose type is none of those
 * three contributes {@code 0}. A mapped position that is {@code null} or beyond the end of the array is skipped.
 *
 * @param dataMap
 *            raw values keyed by column name
 * @return the normalized input array, sized to {@code columnNumIndexMap}
 * @throws IllegalStateException
 *             if a hybrid column is paired with a norm type that is not woe related
 */
private double[] convertDataMapToDoubleArray(Map<String, Object> dataMap) {
    double[] inputs = new double[this.columnNumIndexMap.size()];
    for (Entry<Integer, Integer> columnToPosition : this.columnNumIndexMap.entrySet()) {
        Integer columnNum = columnToPosition.getKey();
        String columnName = this.numNameMap.get(columnNum);
        Object rawValue = dataMap.get(columnName);
        ColumnType type = this.columnTypeMap.get(columnNum);

        double normalized;
        if (type == ColumnType.C) {
            normalized = normalizeCategoricalColumn(columnNum, rawValue);
        } else if (type == ColumnType.N) {
            normalized = normalizeNumericalColumn(columnNum, rawValue);
        } else if (type == ColumnType.H) {
            normalized = normalizeHybridColumn(columnNum, columnName, rawValue);
        } else {
            // missing or unrecognized column type: the slot keeps a zero
            normalized = 0d;
        }

        Integer position = columnToPosition.getValue();
        boolean writable = position != null && position < inputs.length;
        if (writable) {
            inputs[position] = normalized;
        }
    }
    return inputs;
}

/**
 * Normalizes one categorical column value according to {@code normType}.
 * Any norm type not listed explicitly falls back to the non-legacy positive-rate z-score.
 */
private double normalizeCategoricalColumn(Integer columnNum, Object rawValue) {
    switch(this.normType) {
        case WOE:
        case HYBRID:
            return getCategoricalWoeValue(columnNum, rawValue, false);
        case WEIGHT_WOE:
        case WEIGHT_HYBRID:
            return getCategoricalWoeValue(columnNum, rawValue, true);
        case WOE_ZSCORE:
        case WOE_ZSCALE:
            return getCategoricalWoeZScoreValue(columnNum, rawValue, false);
        case WEIGHT_WOE_ZSCORE:
        case WEIGHT_WOE_ZSCALE:
            return getCategoricalWoeZScoreValue(columnNum, rawValue, true);
        case OLD_ZSCALE:
        case OLD_ZSCORE:
            return getCategoricalPosRateZScoreValue(columnNum, rawValue, true);
        case ZSCALE:
        case ZSCORE:
        default:
            return getCategoricalPosRateZScoreValue(columnNum, rawValue, false);
    }
}

/**
 * Normalizes one numerical column value according to {@code normType}.
 * Woe-based norm types use the (weighted) bin woe; every other norm type, including HYBRID, uses the plain z-score.
 */
private double normalizeNumericalColumn(Integer columnNum, Object rawValue) {
    switch(this.normType) {
        case WOE:
            return getNumericalWoeValue(columnNum, rawValue, false);
        case WEIGHT_WOE:
            return getNumericalWoeValue(columnNum, rawValue, true);
        case WOE_ZSCORE:
        case WOE_ZSCALE:
            return getNumericalWoeZScoreValue(columnNum, rawValue, false);
        case WEIGHT_WOE_ZSCORE:
        case WEIGHT_WOE_ZSCALE:
            return getNumericalWoeZScoreValue(columnNum, rawValue, true);
        default:
            // OLD_ZSCALE, OLD_ZSCORE, ZSCALE, ZSCORE, HYBRID, WEIGHT_HYBRID and anything else
            return getNumericalZScoreValue(columnNum, rawValue);
    }
}

/**
 * Normalizes one hybrid column value; only woe-related norm types are supported.
 *
 * @throws IllegalStateException
 *             for any norm type that is not woe related
 */
private double normalizeHybridColumn(Integer columnNum, String columnName, Object rawValue) {
    switch(this.normType) {
        case WOE:
            return getHybridWoeValue(columnNum, rawValue, false);
        case WEIGHT_WOE:
            return getHybridWoeValue(columnNum, rawValue, true);
        case WOE_ZSCORE:
        case WOE_ZSCALE:
            return getHybridWoeZScoreValue(columnNum, rawValue, false);
        case WEIGHT_WOE_ZSCORE:
        case WEIGHT_WOE_ZSCALE:
            return getHybridWoeZScoreValue(columnNum, rawValue, true);
        default:
            throw new IllegalStateException("Column type of " + columnName + " is hybrid, but normType is not woe related.");
    }
}
Also used: ColumnType (ml.shifu.shifu.container.obj.ColumnType)

Aggregations

ColumnType (ml.shifu.shifu.container.obj.ColumnType): 2, BufferedInputStream (java.io.BufferedInputStream): 1, DataInputStream (java.io.DataInputStream): 1, IOException (java.io.IOException): 1, ArrayList (java.util.ArrayList): 1, HashMap (java.util.HashMap): 1, List (java.util.List): 1, Map (java.util.Map): 1, GZIPInputStream (java.util.zip.GZIPInputStream): 1, NormType (ml.shifu.shifu.container.obj.ModelNormalizeConf.NormType): 1, BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork): 1, PersistBasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork): 1