Usage of ml.shifu.shifu.container.obj.ColumnType in project shifu by ShifuML.
The class IndependentNNModel, method loadFromStream.
/**
 * Load model instance from input stream which is saved in NNOutput for specified binary format.
 *
 * <p>
 * Fields are read in this order:
 * <ol>
 * <li>model version</li>
 * <li>norm type name</li>
 * <li>per-column {@code NNColumnStats}</li>
 * <li>the column mapping table</li>
 * <li>the serialized networks</li>
 * </ol>
 * Gzip compression is detected from the first two bytes, so plain and gzip streams are both accepted. This method
 * does not close {@code input}.
 *
 * @param input
 *            the input stream, flat input stream or gzip input stream both OK
 * @param isRemoveNameSpace
 *            if true, column names are stored by their simple name (name-space removed) so the model can be
 *            called with simple column names
 * @return the nn model instance
 * @throws IOException
 *             any IOException in de-serialization.
 */
public static IndependentNNModel loadFromStream(InputStream input, boolean isRemoveNameSpace) throws IOException {
    DataInputStream dis = openModelStream(input);
    int version = dis.readInt();
    IndependentNNModel.setVersion(version);
    String normStr = ml.shifu.shifu.core.dtrain.StringUtils.readString(dis);
    // Upper-case with Locale.ROOT. Under the default locale (e.g. Turkish, where 'i' maps to a dotted capital I)
    // the result could match no NormType constant.
    NormType normType = NormType.valueOf(normStr.toUpperCase(java.util.Locale.ROOT));
    // for all features
    Map<Integer, String> numNameMap = new HashMap<Integer, String>();
    Map<Integer, List<String>> cateColumnNameNames = new HashMap<Integer, List<String>>();
    // for categorical features: category value -> woe / weighted woe / positive rate
    Map<Integer, Map<String, Double>> cateWoeMap = new HashMap<Integer, Map<String, Double>>();
    Map<Integer, Map<String, Double>> cateWgtWoeMap = new HashMap<Integer, Map<String, Double>>();
    Map<Integer, Map<String, Double>> binPosRateMap = new HashMap<Integer, Map<String, Double>>();
    // for numerical features
    Map<Integer, List<Double>> numerBinBoundaries = new HashMap<Integer, List<Double>>();
    Map<Integer, List<Double>> numerWoes = new HashMap<Integer, List<Double>>();
    Map<Integer, List<Double>> numerWgtWoes = new HashMap<Integer, List<Double>>();
    // for all features
    Map<Integer, Double> numerMeanMap = new HashMap<Integer, Double>();
    Map<Integer, Double> numerStddevMap = new HashMap<Integer, Double>();
    Map<Integer, Double> woeMeanMap = new HashMap<Integer, Double>();
    Map<Integer, Double> woeStddevMap = new HashMap<Integer, Double>();
    Map<Integer, Double> wgtWoeMeanMap = new HashMap<Integer, Double>();
    Map<Integer, Double> wgtWoeStddevMap = new HashMap<Integer, Double>();
    Map<Integer, Double> cutoffMap = new HashMap<Integer, Double>();
    Map<Integer, ColumnType> columnTypeMap = new HashMap<Integer, ColumnType>();
    Map<Integer, Map<String, Integer>> cateIndexMapping = new HashMap<Integer, Map<String, Integer>>();
    int columnSize = dis.readInt();
    for (int i = 0; i < columnSize; i++) {
        NNColumnStats cs = new NNColumnStats();
        cs.readFields(dis);
        List<Double> binWoes = cs.getBinCountWoes();
        List<Double> binWgtWoes = cs.getBinWeightWoes();
        List<Double> binPosRates = cs.getBinPosRates();
        int columnNum = cs.getColumnNum();
        columnTypeMap.put(columnNum, cs.getColumnType());
        if (isRemoveNameSpace) {
            // remove name-space in column name to make it be called by simple name
            numNameMap.put(columnNum, StringUtils.getSimpleColumnName(cs.getColumnName()));
        } else {
            numNameMap.put(columnNum, cs.getColumnName());
        }
        // for categorical features
        Map<String, Double> woeMap = new HashMap<String, Double>();
        Map<String, Double> woeWgtMap = new HashMap<String, Double>();
        Map<String, Double> posRateMap = new HashMap<String, Double>();
        Map<String, Integer> cateIndexMap = new HashMap<String, Integer>();
        if (cs.isCategorical() || cs.isHybrid()) {
            List<String> binCategories = cs.getBinCategories();
            cateColumnNameNames.put(columnNum, binCategories);
            for (int j = 0; j < binCategories.size(); j++) {
                String currCate = binCategories.get(j);
                // A merged category is flattened so each member value maps to the group's bin. The project's own
                // split function is used to avoid depending on the guava jar in prediction.
                String[] cateValues = currCate.contains(Constants.CATEGORICAL_GROUP_VAL_DELIMITER)
                        ? StringUtils.split(currCate, Constants.CATEGORICAL_GROUP_VAL_DELIMITER)
                        : new String[] { currCate };
                for (String cateValue : cateValues) {
                    woeMap.put(cateValue, binWoes.get(j));
                    woeWgtMap.put(cateValue, binWgtWoes.get(j));
                    posRateMap.put(cateValue, binPosRates.get(j));
                    cateIndexMap.put(cateValue, j);
                }
            }
            // The bin right after the last category is the missing-value bin.
            // NOTE(review): cateIndexMap gets no EMPTY_CATEGORY entry; presumably intentional, confirm with callers.
            int missingBin = binCategories.size();
            woeMap.put(Constants.EMPTY_CATEGORY, binWoes.get(missingBin));
            woeWgtMap.put(Constants.EMPTY_CATEGORY, binWgtWoes.get(missingBin));
            posRateMap.put(Constants.EMPTY_CATEGORY, binPosRates.get(missingBin));
        }
        if (cs.isNumerical() || cs.isHybrid()) {
            numerBinBoundaries.put(columnNum, cs.getBinBoundaries());
            numerWoes.put(columnNum, binWoes);
            numerWgtWoes.put(columnNum, binWgtWoes);
        }
        // The per-column categorical maps are registered for every column, even when empty.
        cateWoeMap.put(columnNum, woeMap);
        cateWgtWoeMap.put(columnNum, woeWgtMap);
        binPosRateMap.put(columnNum, posRateMap);
        cateIndexMapping.put(columnNum, cateIndexMap);
        numerMeanMap.put(columnNum, cs.getMean());
        numerStddevMap.put(columnNum, cs.getStddev());
        woeMeanMap.put(columnNum, cs.getWoeMean());
        woeStddevMap.put(columnNum, cs.getWoeStddev());
        wgtWoeMeanMap.put(columnNum, cs.getWoeWgtMean());
        wgtWoeStddevMap.put(columnNum, cs.getWoeWgtStddev());
        cutoffMap.put(columnNum, cs.getCutoff());
    }
    // Pairs of ints (key first, then value). Presumably column number -> network input index; TODO confirm.
    Map<Integer, Integer> columnMap = new HashMap<Integer, Integer>();
    int columnMapSize = dis.readInt();
    for (int i = 0; i < columnMapSize; i++) {
        int key = dis.readInt();
        int value = dis.readInt();
        columnMap.put(key, value);
    }
    int size = dis.readInt();
    List<BasicFloatNetwork> networks = new ArrayList<BasicFloatNetwork>();
    for (int i = 0; i < size; i++) {
        networks.add(new PersistBasicFloatNetwork().readNetwork(dis));
    }
    return new IndependentNNModel(networks, normType, numNameMap, cateColumnNameNames, columnMap, cateWoeMap,
            cateWgtWoeMap, binPosRateMap, numerBinBoundaries, numerWgtWoes, numerWoes, cutoffMap, numerMeanMap,
            numerStddevMap, woeMeanMap, woeStddevMap, wgtWoeMeanMap, wgtWoeStddevMap, columnTypeMap,
            cateIndexMapping);
}

/**
 * Wraps {@code input} in a {@link DataInputStream}, decompressing it when it starts with the gzip magic number.
 *
 * <p>
 * If probing or gzip header parsing fails, the buffered stream is rewound to the probe position and read as a
 * plain stream. Falling back to the raw {@code input} instead would skip the bytes already pulled into the buffer.
 */
private static DataInputStream openModelStream(InputStream input) throws IOException {
    BufferedInputStream bis = new BufferedInputStream(input);
    try {
        if (startsWithGzipMagic(bis)) {
            return new DataInputStream(new GZIPInputStream(bis));
        }
        return new DataInputStream(bis);
    } catch (IOException e) {
        try {
            bis.reset();
        } catch (IOException ignored) {
            // Cannot rewind (mark invalidated). Surface the original failure, which is the more informative one.
            throw e;
        }
        return new DataInputStream(bis);
    }
}

/**
 * Peeks at the next two bytes of {@code bis} and reports whether they are the gzip magic number. On normal return
 * the stream is reset to the position it had on entry.
 */
private static boolean startsWithGzipMagic(BufferedInputStream bis) throws IOException {
    byte[] header = new byte[2];
    bis.mark(header.length);
    int filled = 0;
    // A single read() may legally return fewer bytes than requested, so keep reading until two bytes or EOF.
    while (filled < header.length) {
        int n = bis.read(header, filled, header.length - filled);
        if (n == -1) {
            break;
        }
        filled += n;
    }
    bis.reset();
    if (filled < header.length) {
        return false;
    }
    // GZIP_MAGIC is stored little-endian in the first two bytes of a gzip stream.
    int magic = (header[0] & 0xff) | ((header[1] & 0xff) << 8);
    return magic == GZIPInputStream.GZIP_MAGIC;
}
Usage of ml.shifu.shifu.container.obj.ColumnType in project shifu by ShifuML.
The class IndependentNNModel, method convertDataMapToDoubleArray.
/**
 * Builds the network input vector for a single record.
 *
 * <p>
 * Processing for each entry of {@code columnNumIndexMap`}:
 * <ul>
 * <li>The column's raw value is looked up in {@code dataMap} by column name.</li>
 * <li>It is normalized according to the column type and {@code normType}.</li>
 * <li>The result is stored at the mapped index.</li>
 * </ul>
 * Columns that are not categorical, numerical or hybrid contribute {@code 0}. Mappings whose index is
 * {@code null} or not below the array length are skipped.
 *
 * @param dataMap
 *            raw record values keyed by column name
 * @return normalized inputs, one slot per entry of {@code columnNumIndexMap}
 * @throws IllegalStateException
 *             if a hybrid column is used with a norm type that has no woe variant
 */
private double[] convertDataMapToDoubleArray(Map<String, Object> dataMap) {
    double[] inputs = new double[this.columnNumIndexMap.size()];
    for (Entry<Integer, Integer> mapping : this.columnNumIndexMap.entrySet()) {
        Integer columnNum = mapping.getKey();
        String columnName = this.numNameMap.get(columnNum);
        Object rawValue = dataMap.get(columnName);
        ColumnType columnType = this.columnTypeMap.get(columnNum);

        double normalized = 0d;
        if (columnType == ColumnType.C) {
            normalized = categoricalInputValue(columnNum, rawValue);
        } else if (columnType == ColumnType.N) {
            normalized = numericalInputValue(columnNum, rawValue);
        } else if (columnType == ColumnType.H) {
            normalized = hybridInputValue(columnNum, columnName, rawValue);
        }

        Integer position = mapping.getValue();
        boolean writable = position != null && position < inputs.length;
        if (writable) {
            inputs[position] = normalized;
        }
    }
    return inputs;
}

/** Normalizes a categorical column value according to the configured {@code normType}. */
private double categoricalInputValue(Integer columnNum, Object rawValue) {
    switch(this.normType) {
        case WOE:
        case HYBRID:
            return getCategoricalWoeValue(columnNum, rawValue, false);
        case WEIGHT_WOE:
        case WEIGHT_HYBRID:
            return getCategoricalWoeValue(columnNum, rawValue, true);
        case WOE_ZSCORE:
        case WOE_ZSCALE:
            return getCategoricalWoeZScoreValue(columnNum, rawValue, false);
        case WEIGHT_WOE_ZSCORE:
        case WEIGHT_WOE_ZSCALE:
            return getCategoricalWoeZScoreValue(columnNum, rawValue, true);
        case OLD_ZSCALE:
        case OLD_ZSCORE:
            return getCategoricalPosRateZScoreValue(columnNum, rawValue, true);
        case ZSCALE:
        case ZSCORE:
        default:
            return getCategoricalPosRateZScoreValue(columnNum, rawValue, false);
    }
}

/** Normalizes a numerical column value; norm types without a woe variant fall back to z-score. */
private double numericalInputValue(Integer columnNum, Object rawValue) {
    switch(this.normType) {
        case WOE:
            return getNumericalWoeValue(columnNum, rawValue, false);
        case WEIGHT_WOE:
            return getNumericalWoeValue(columnNum, rawValue, true);
        case WOE_ZSCORE:
        case WOE_ZSCALE:
            return getNumericalWoeZScoreValue(columnNum, rawValue, false);
        case WEIGHT_WOE_ZSCORE:
        case WEIGHT_WOE_ZSCALE:
            return getNumericalWoeZScoreValue(columnNum, rawValue, true);
        case OLD_ZSCALE:
        case OLD_ZSCORE:
        case ZSCALE:
        case ZSCORE:
        case HYBRID:
        case WEIGHT_HYBRID:
        default:
            return getNumericalZScoreValue(columnNum, rawValue);
    }
}

/** Normalizes a hybrid column value; only the woe-based norm types are supported. */
private double hybridInputValue(Integer columnNum, String columnName, Object rawValue) {
    switch(this.normType) {
        case WOE:
            return getHybridWoeValue(columnNum, rawValue, false);
        case WEIGHT_WOE:
            return getHybridWoeValue(columnNum, rawValue, true);
        case WOE_ZSCORE:
        case WOE_ZSCALE:
            return getHybridWoeZScoreValue(columnNum, rawValue, false);
        case WEIGHT_WOE_ZSCORE:
        case WEIGHT_WOE_ZSCALE:
            return getHybridWoeZScoreValue(columnNum, rawValue, true);
        case OLD_ZSCALE:
        case OLD_ZSCORE:
        case ZSCALE:
        case ZSCORE:
        case HYBRID:
        case WEIGHT_HYBRID:
        default:
            throw new IllegalStateException("Column type of " + columnName + " is hybrid, but normType is not woe related.");
    }
}
Aggregations