Usage of ml.shifu.shifu.core.dtrain.nn.NNColumnStats in the shifu project by ShifuML.
Class BinaryWDLSerializer, method save.
/**
 * Serializes a Wide-and-Deep model, together with the per-column statistics needed to normalize inputs at
 * prediction time, into a gzip-compressed binary file.
 *
 * <p>Written in order: the format version ({@code int}), two reserved {@code float}s, one reserved {@code double},
 * one reserved UTF string, the normalization type name, the number of column-stats records ({@code int}), each
 * {@link NNColumnStats} record, and finally the {@link WideAndDeep} model itself.
 *
 * @param modelConfig model configuration; supplies the normalization type and the std-dev cutoff
 * @param columnConfigList all column configs; only columns whose column number is a key of
 *            {@code getIndexNameMapping(columnConfigList)} are written
 * @param wideAndDeep the trained model to persist
 * @param fs file system on which {@code output} is opened via {@code fs.create(output)}
 * @param output destination path of the serialized model
 * @throws IOException if creating, writing or closing the output fails
 */
public static void save(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, WideAndDeep wideAndDeep,
        FileSystem fs, Path output) throws IOException {
    // try-with-resources instead of IOUtils.closeStream: closing the GZIP stream is what writes the gzip trailer
    // and flushes the last compressed bytes, so a failure there must propagate. Swallowing it would report a
    // truncated/corrupt model file as a successful save.
    try (DataOutputStream fos = new DataOutputStream(new GZIPOutputStream(fs.create(output)))) {
        // version
        fos.writeInt(CommonConstants.WDL_FORMAT_VERSION);
        // Reserved two float field, one double field and one string field
        fos.writeFloat(0.0f);
        fos.writeFloat(0.0f);
        fos.writeDouble(0.0d);
        fos.writeUTF("Reserved field");

        // write normStr
        String normStr = modelConfig.getNormalize().getNormType().toString();
        StringUtils.writeString(fos, normStr);

        // compute columns needed; only membership of the column number is used here
        Map<Integer, String> columnIndexNameMapping = getIndexNameMapping(columnConfigList);

        // collect column stats first so the record count can be written before the records
        List<NNColumnStats> csList = new ArrayList<>();
        for (ColumnConfig cc : columnConfigList) {
            if (!columnIndexNameMapping.containsKey(cc.getColumnNum())) {
                continue;
            }
            NNColumnStats cs = new NNColumnStats();
            cs.setCutoff(modelConfig.getNormalizeStdDevCutOff());
            cs.setColumnType(cc.getColumnType());
            cs.setMean(cc.getMean());
            cs.setStddev(cc.getStdDev());
            cs.setColumnNum(cc.getColumnNum());
            cs.setColumnName(cc.getColumnName());
            cs.setBinCategories(cc.getBinCategory());
            cs.setBinBoundaries(cc.getBinBoundary());
            cs.setBinPosRates(cc.getBinPosRate());
            cs.setBinCountWoes(cc.getBinCountWoe());
            cs.setBinWeightWoes(cc.getBinWeightedWoe());
            // TODO cache such computation
            // second argument selects weighted (true) vs. unweighted (false) WOE statistics
            double[] meanAndStdDev = Normalizer.calculateWoeMeanAndStdDev(cc, false);
            cs.setWoeMean(meanAndStdDev[0]);
            cs.setWoeStddev(meanAndStdDev[1]);
            double[] weightMeanAndStdDev = Normalizer.calculateWoeMeanAndStdDev(cc, true);
            cs.setWoeWgtMean(weightMeanAndStdDev[0]);
            cs.setWoeWgtStddev(weightMeanAndStdDev[1]);
            csList.add(cs);
        }

        // write column stats to output
        fos.writeInt(csList.size());
        for (NNColumnStats cs : csList) {
            cs.write(fos);
        }

        // persist WideAndDeep Model
        wideAndDeep.write(fos);
    }
}
Usage of ml.shifu.shifu.core.dtrain.nn.NNColumnStats in the shifu project by ShifuML.
Class IndependentWDLModel, method loadFromStream.
/**
 * Load model instance from input stream which is saved in WDLOutput for specified binary format.
 *
 * <p>The stream may be gzip-compressed or plain; compression is detected from the gzip magic number in the first
 * two bytes. The expected layout is: format version ({@code int}), two reserved {@code float}s, one reserved
 * {@code double}, one reserved UTF string, the normalization type name, the number of column-stats records
 * ({@code int}), each {@link NNColumnStats} record, and finally the {@link WideAndDeep} model.
 *
 * @param input the input stream, flat input stream or gzip input stream both OK
 * @param isRemoveNameSpace whether to strip the name-space from column names so that columns can be referenced
 *            by their simple name
 * @return an {@link IndependentWDLModel} wrapping the deserialized {@link WideAndDeep} model and the per-column
 *         normalization data read from the stream
 * @throws IOException any IOException in de-serialization, including a missing normalization type or a negative
 *             column count in the stream
 */
public static IndependentWDLModel loadFromStream(InputStream input, boolean isRemoveNameSpace) throws IOException {
    DataInputStream dis = openPossiblyGzippedStream(input);

    int version = dis.readInt();
    // NOTE(review): the version is stored through a static setter, so it appears to be global state shared by
    // all loaded models — confirm this is safe if models of different versions are loaded concurrently.
    IndependentWDLModel.setVersion(version);
    // Reserved two float field, one double field and one string field
    dis.readFloat();
    dis.readFloat();
    dis.readDouble();
    dis.readUTF();

    // read normStr; Enum.valueOf(null) would throw an unhelpful NPE, so fail with a descriptive error instead
    String normStr = StringUtils.readString(dis);
    if (normStr == null) {
        throw new IOException("Corrupted WDL model stream: normalization type is missing.");
    }
    // Locale.ROOT keeps the upper-casing independent of the JVM default locale (e.g. Turkish dotted/dotless i)
    NormType normType = NormType.valueOf(normStr.toUpperCase(java.util.Locale.ROOT));

    int columnSize = dis.readInt();
    if (columnSize < 0) {
        throw new IOException("Corrupted WDL model stream: negative column count " + columnSize + ".");
    }

    // for all features
    Map<Integer, String> numNameMap = new HashMap<>(columnSize);
    // for numerical features
    Map<Integer, List<Double>> numerBinBoundaries = new HashMap<>(columnSize);
    Map<Integer, List<Double>> numerWoes = new HashMap<>(columnSize);
    Map<Integer, List<Double>> numerWgtWoes = new HashMap<>(columnSize);
    // for all features
    Map<Integer, Double> numerMeanMap = new HashMap<>(columnSize);
    Map<Integer, Double> numerStddevMap = new HashMap<>(columnSize);
    Map<Integer, Double> woeMeanMap = new HashMap<>(columnSize);
    Map<Integer, Double> woeStddevMap = new HashMap<>(columnSize);
    Map<Integer, Double> wgtWoeMeanMap = new HashMap<>(columnSize);
    Map<Integer, Double> wgtWoeStddevMap = new HashMap<>(columnSize);
    Map<Integer, Double> cutoffMap = new HashMap<>(columnSize);
    Map<Integer, Map<String, Integer>> cateIndexMapping = new HashMap<>(columnSize);

    for (int i = 0; i < columnSize; i++) {
        NNColumnStats cs = new NNColumnStats();
        cs.readFields(dis);
        int columnNum = cs.getColumnNum();

        if (isRemoveNameSpace) {
            // remove name-space in column name to make it be called by simple name
            numNameMap.put(columnNum, StringUtils.getSimpleColumnName(cs.getColumnName()));
        } else {
            numNameMap.put(columnNum, cs.getColumnName());
        }

        // every column gets a category mapping; it stays empty unless the column is categorical or hybrid
        if (cs.isCategorical() || cs.isHybrid()) {
            cateIndexMapping.put(columnNum, buildCategoryIndexMap(cs.getBinCategories()));
        } else {
            cateIndexMapping.put(columnNum, new HashMap<String, Integer>());
        }

        if (cs.isNumerical() || cs.isHybrid()) {
            numerBinBoundaries.put(columnNum, cs.getBinBoundaries());
            numerWoes.put(columnNum, cs.getBinCountWoes());
            numerWgtWoes.put(columnNum, cs.getBinWeightWoes());
        }

        numerMeanMap.put(columnNum, cs.getMean());
        numerStddevMap.put(columnNum, cs.getStddev());
        woeMeanMap.put(columnNum, cs.getWoeMean());
        woeStddevMap.put(columnNum, cs.getWoeStddev());
        wgtWoeMeanMap.put(columnNum, cs.getWoeWgtMean());
        wgtWoeStddevMap.put(columnNum, cs.getWoeWgtStddev());
        cutoffMap.put(columnNum, cs.getCutoff());
    }

    WideAndDeep wideAndDeep = new WideAndDeep();
    wideAndDeep.readFields(dis);
    return new IndependentWDLModel(wideAndDeep, normType, cutoffMap, numNameMap, cateIndexMapping,
            numerBinBoundaries, numerWoes, numerWgtWoes, numerMeanMap, numerStddevMap, woeMeanMap, woeStddevMap,
            wgtWoeMeanMap, wgtWoeStddevMap);
}

/**
 * Wraps {@code input} in a {@link DataInputStream}, inserting a {@link GZIPInputStream} when the first two bytes
 * carry the gzip magic number. The sniffed bytes are pushed back with mark/reset, so the returned stream starts
 * at the beginning of the data.
 *
 * <p>Errors propagate instead of falling back to reading {@code input} directly: the {@link BufferedInputStream}
 * may already have buffered bytes pulled from {@code input}, so such a fallback would parse from a wrong offset.
 */
private static DataInputStream openPossiblyGzippedStream(InputStream input) throws IOException {
    BufferedInputStream bis = new BufferedInputStream(input);
    bis.mark(2);
    // single-byte reads block until a byte or EOF is available, so a short read cannot hide the magic number
    int first = bis.read();
    int second = bis.read();
    bis.reset();
    // GZIP_MAGIC is stored little-endian: 0x1f first, then 0x8b
    boolean gzipped = first != -1 && second != -1 && (first | (second << 8)) == GZIPInputStream.GZIP_MAGIC;
    if (gzipped) {
        return new DataInputStream(new GZIPInputStream(bis));
    }
    return new DataInputStream(bis);
}

/**
 * Maps every category value to the index of the bin it belongs to. A merged bin (several raw values joined by
 * {@code Constants.CATEGORICAL_GROUP_VAL_DELIMITER}) is flattened so each raw value maps to the merged bin's
 * index. A {@code null} category list yields an empty map.
 */
private static Map<String, Integer> buildCategoryIndexMap(List<String> binCategories) {
    if (binCategories == null) {
        return new HashMap<String, Integer>();
    }
    Map<String, Integer> cateIndexMap = new HashMap<String, Integer>(binCategories.size());
    for (int j = 0; j < binCategories.size(); j++) {
        String currCate = binCategories.get(j);
        if (currCate.contains(Constants.CATEGORICAL_GROUP_VAL_DELIMITER)) {
            // merged category should be flatten, use own split function to avoid depending on guava jar in
            // prediction
            String[] splits = StringUtils.split(currCate, Constants.CATEGORICAL_GROUP_VAL_DELIMITER);
            for (String str : splits) {
                cateIndexMap.put(str, j);
            }
        } else {
            cateIndexMap.put(currCate, j);
        }
    }
    return cateIndexMap;
}
Aggregations