Use of ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork in project shifu by ShifuML.
In class ModelSpecLoaderUtils, method loadSubModels.
/**
 * Load sub-models under current model space.
 *
 * <p>
 * The models path is {@code evalConfig.getModelsPath()} when {@code evalConfig} is non-null and that path is
 * non-empty; otherwise it is {@code PathFinder.getModelsPath(sourceType)}. Every directory directly under that path
 * is passed to {@code loadSubModelSpec}, and each non-null result is collected.
 *
 * <p>
 * Loading is best-effort: an {@link IOException} raised while listing or loading is logged and the specs collected
 * so far are returned instead of propagating the error.
 *
 * @param modelConfig
 *            - {@link ModelConfig}, need this, since the model file may exist in HDFS
 * @param columnConfigList
 *            - List of {@link ColumnConfig}
 * @param evalConfig
 *            - {@link EvalConfig}, maybe null
 * @param sourceType
 *            - {@link SourceType}, HDFS or Local?
 * @param gbtConvertToProb
 *            - convert to probability or not for gbt model
 * @param gbtScoreConvertStrategy
 *            - gbt score conversion strategy
 * @return list of {@link ModelSpec} for sub models; never null, empty if the models path is missing or unreadable
 */
@SuppressWarnings("deprecation")
public static List<ModelSpec> loadSubModels(ModelConfig modelConfig, List<ColumnConfig> columnConfigList,
        EvalConfig evalConfig, RawSourceData.SourceType sourceType, Boolean gbtConvertToProb,
        String gbtScoreConvertStrategy) {
    List<ModelSpec> modelSpecs = new ArrayList<ModelSpec>();
    FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(sourceType);
    // we have to register PersistBasicFloatNetwork for loading such models
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());

    PathFinder pathFinder = new PathFinder(modelConfig);
    String modelsPath;
    if(evalConfig == null || StringUtils.isEmpty(evalConfig.getModelsPath())) {
        modelsPath = pathFinder.getModelsPath(sourceType);
    } else {
        modelsPath = evalConfig.getModelsPath();
    }

    try {
        FileStatus[] fsArr = fs.listStatus(new Path(modelsPath));
        if(fsArr == null) {
            // Hadoop 1.x FileSystem.listStatus returns null (instead of throwing FileNotFoundException) when the
            // path does not exist; iterating over it would throw an NPE that the IOException handler never sees.
            log.warn("Models path does not exist, no sub-models loaded: " + modelsPath);
            return modelSpecs;
        }
        for(FileStatus fileStatus: fsArr) {
            // Only sub-directories hold sub-models; plain files directly under the models path are ignored.
            if(fileStatus.isDir()) {
                ModelSpec modelSpec = loadSubModelSpec(modelConfig, columnConfigList, fileStatus, sourceType,
                        gbtConvertToProb, gbtScoreConvertStrategy);
                if(modelSpec != null) {
                    modelSpecs.add(modelSpec);
                }
            }
        }
    } catch (IOException e) {
        log.error("Error occurred when loading sub-models.", e);
    }
    return modelSpecs;
}
Use of ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork in project shifu by ShifuML.
In class ModelSpecLoaderUtils, method loadModel.
/**
 * Loads a model from {@code modelPath}, picking the decoder from the suffix of the file name.
 *
 * <ul>
 * <li>Name ends with the (lower-cased) LR algorithm name: the first line is parsed with {@code LR.loadFromString};
 * if that fails, the file is reopened and read as an Encog-persisted object.</li>
 * <li>Name ends with the (lower-cased) RF or GBT algorithm name: read with {@code TreeModel.loadFromStream}.</li>
 * <li>Anything else: gzip-compressed content is read with {@code NNModel.loadFromStream}, other content as an
 * Encog-persisted object.</li>
 * </ul>
 *
 * @param modelConfig
 *            model config (not read by this method)
 * @param modelPath
 *            the path to store model
 * @param fs
 *            file system used to store model
 * @param gbtConvertToProb
 *            convert gbt score to prob or not
 * @param gbtScoreConvertStrategy
 *            specify how to convert gbt raw score
 * @return the loaded model, or {@code null} if {@code modelPath} does not exist
 * @throws IOException
 *             if checking whether {@code modelPath} exists fails; any failure while opening or decoding the file
 *             is rethrown as a {@code ShifuException} with {@code ERROR_FAIL_TO_LOAD_MODEL_FILE}
 */
public static BasicML loadModel(ModelConfig modelConfig, Path modelPath, FileSystem fs, boolean gbtConvertToProb,
        String gbtScoreConvertStrategy) throws IOException {
    if(!fs.exists(modelPath)) {
        // A missing model is not an error here; the caller receives null.
        return null;
    }

    // PersistBasicFloatNetwork must be registered before Encog can deserialize such networks.
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());

    FSDataInputStream in = null;
    BufferedReader lineReader = null;
    try {
        in = fs.open(modelPath);
        String fileName = modelPath.getName();

        if(fileName.endsWith(LogisticRegressionContants.LR_ALG_NAME.toLowerCase())) {
            lineReader = new BufferedReader(new InputStreamReader(in));
            try {
                return LR.loadFromString(lineReader.readLine());
            } catch (Exception lrParseFailure) {
                // Possibly an LR model persisted in Encog format. The reader has already consumed part of the
                // stream, so close it and decode from a freshly opened stream.
                IOUtils.closeQuietly(lineReader);
                in = fs.open(modelPath);
                return BasicML.class.cast(EncogDirectoryPersistence.loadObject(in));
            }
        }

        boolean isTreeModel = fileName.endsWith(CommonConstants.RF_ALG_NAME.toLowerCase())
                || fileName.endsWith(CommonConstants.GBT_ALG_NAME.toLowerCase());
        if(isTreeModel) {
            return TreeModel.loadFromStream(in, gbtConvertToProb, gbtScoreConvertStrategy);
        }

        // Remaining models: gzip-compressed content goes to NNModel, plain content to Encog persistence.
        GzipStreamPair pair = GzipStreamPair.isGZipFormat(in);
        if(pair.isGzip()) {
            return BasicML.class.cast(NNModel.loadFromStream(pair.getInput()));
        }
        return BasicML.class.cast(EncogDirectoryPersistence.loadObject(pair.getInput()));
    } catch (Exception e) {
        throw new ShifuException(ShifuErrorCode.ERROR_FAIL_TO_LOAD_MODEL_FILE, e,
                "the expecting model file is: " + modelPath);
    } finally {
        // Reader first (it wraps the stream), then whichever stream is current; both are quiet, null-safe closes.
        IOUtils.closeQuietly(lineReader);
        IOUtils.closeQuietly(in);
    }
}
Use of ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork in project shifu by ShifuML.
In class NNOutput, method initNetwork.
/**
 * Builds the network for this NN job and stores it in {@code this.network}, also recording the selected feature
 * indices in {@code this.subFeatures}.
 *
 * <p>
 * The feature subset is read from the {@code CommonConstants.SHIFU_NN_FEATURE_SUBSET} property of the master
 * context as a comma-separated list of column numbers (surrounding whitespace and empty entries are ignored). When
 * the property is blank, every feature column from {@code NormalUtils.getAllFeatureList} is used.
 *
 * <p>
 * Hidden-layer count, activation functions, hidden node counts and the output activation function come from
 * {@code validParams}.
 *
 * @param context
 *            master context whose properties may carry the feature subset
 */
@SuppressWarnings("unchecked")
private void initNetwork(MasterContext<NNParams, NNParams> context) {
    int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(),
            this.columnConfigList);
    boolean isLinearTarget = CommonUtils.isLinearTarget(modelConfig, columnConfigList);

    // Linear targets, regression and one-vs-all use the target output count (inputOutputIndex[1]); native
    // multi-class uses one output node per class, except the binary case which needs a single node.
    int classes = modelConfig.getTags().size();
    int outputNodeCount;
    if(isLinearTarget || modelConfig.isRegression() || modelConfig.getTrain().isOneVsAll()) {
        outputNodeCount = inputOutputIndex[1];
    } else {
        outputNodeCount = (classes == 2) ? 1 : classes;
    }

    int numLayers = (Integer) validParams.get(CommonConstants.NUM_HIDDEN_LAYERS);
    List<String> actFunc = (List<String>) validParams.get(CommonConstants.ACTIVATION_FUNC);
    List<Integer> hiddenNodeList = (List<Integer>) validParams.get(CommonConstants.NUM_HIDDEN_NODES);

    String subsetStr = context.getProps().getProperty(CommonConstants.SHIFU_NN_FEATURE_SUBSET);
    if(StringUtils.isBlank(subsetStr)) {
        // No explicit subset: train on every feature column (post-variable-selection when selection has run).
        boolean isAfterVarSelect = inputOutputIndex[0] != 0;
        this.subFeatures = new HashSet<Integer>(NormalUtils.getAllFeatureList(columnConfigList, isAfterVarSelect));
    } else {
        this.subFeatures = new HashSet<Integer>();
        for(String token: subsetStr.split(",")) {
            // Tolerate "1, 2" and stray commas instead of failing with NumberFormatException.
            String trimmed = token.trim();
            if(!trimmed.isEmpty()) {
                this.subFeatures.add(Integer.parseInt(trimmed));
            }
        }
    }

    int featureInputsCnt = DTrainUtils.getFeatureInputsCnt(modelConfig, columnConfigList, this.subFeatures);
    String outputActivationFunc = (String) validParams.get(CommonConstants.OUTPUT_ACTIVATION_FUNC);
    this.network = DTrainUtils.generateNetwork(featureInputsCnt, outputNodeCount, numLayers, actFunc,
            hiddenNodeList, false, this.dropoutRate, this.wgtInit, isLinearTarget, outputActivationFunc);
    ((BasicFloatNetwork) this.network).setFeatureSet(this.subFeatures);

    // register here to save models
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());
}
Use of ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork in project shifu by ShifuML.
In class BinaryNNSerializer, method save.
/**
 * Writes one or more NN models, plus the column statistics needed to normalize their inputs, into a single
 * gzip-compressed binary file.
 *
 * <p>
 * Write order: format version ({@code CommonConstants.NN_FORMAT_VERSION}); normalization type name; number of
 * column-stat records followed by each {@code NNColumnStats}; number of column-index mapping entries followed by
 * each (key, value) int pair; number of networks followed by each network written by
 * {@code PersistBasicFloatNetwork.saveNetwork}.
 *
 * @param modelConfig
 *            model config supplying the normalization type and the std-dev cutoff
 * @param columnConfigList
 *            column configs; stats are written only for columns present in {@code getIndexNameMapping}
 * @param basicNetworks
 *            networks to persist; each is cast to {@code BasicFloatNetwork}
 * @param fs
 *            file system to write to
 * @param output
 *            destination path
 * @throws IOException
 *             if creating, writing or finishing (closing) the file fails
 */
public static void save(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, List<BasicML> basicNetworks,
        FileSystem fs, Path output) throws IOException {
    java.io.OutputStream rawOut = null;
    DataOutputStream fos = null;
    try {
        rawOut = fs.create(output);
        fos = new DataOutputStream(new GZIPOutputStream(rawOut));

        // version
        fos.writeInt(CommonConstants.NN_FORMAT_VERSION);
        // write normStr
        String normStr = modelConfig.getNormalize().getNormType().toString();
        ml.shifu.shifu.core.dtrain.StringUtils.writeString(fos, normStr);

        // compute columns needed
        Map<Integer, String> columnIndexNameMapping = getIndexNameMapping(columnConfigList);
        // write column stats to output
        List<NNColumnStats> csList = new ArrayList<NNColumnStats>();
        for(ColumnConfig cc: columnConfigList) {
            if(columnIndexNameMapping.containsKey(cc.getColumnNum())) {
                NNColumnStats cs = new NNColumnStats();
                cs.setCutoff(modelConfig.getNormalizeStdDevCutOff());
                cs.setColumnType(cc.getColumnType());
                cs.setMean(cc.getMean());
                cs.setStddev(cc.getStdDev());
                cs.setColumnNum(cc.getColumnNum());
                cs.setColumnName(cc.getColumnName());
                cs.setBinCategories(cc.getBinCategory());
                cs.setBinBoundaries(cc.getBinBoundary());
                cs.setBinPosRates(cc.getBinPosRate());
                cs.setBinCountWoes(cc.getBinCountWoe());
                cs.setBinWeightWoes(cc.getBinWeightedWoe());
                // TODO cache such computation
                double[] meanAndStdDev = Normalizer.calculateWoeMeanAndStdDev(cc, false);
                cs.setWoeMean(meanAndStdDev[0]);
                cs.setWoeStddev(meanAndStdDev[1]);
                double[] weightedMeanAndStdDev = Normalizer.calculateWoeMeanAndStdDev(cc, true);
                cs.setWoeWgtMean(weightedMeanAndStdDev[0]);
                cs.setWoeWgtStddev(weightedMeanAndStdDev[1]);
                csList.add(cs);
            }
        }
        fos.writeInt(csList.size());
        for(NNColumnStats cs: csList) {
            cs.write(fos);
        }

        // write column index mapping
        Map<Integer, Integer> columnMapping = getColumnMapping(columnConfigList);
        fos.writeInt(columnMapping.size());
        for(Entry<Integer, Integer> entry: columnMapping.entrySet()) {
            fos.writeInt(entry.getKey());
            fos.writeInt(entry.getValue());
        }

        // persist network, set it as list
        fos.writeInt(basicNetworks.size());
        for(BasicML network: basicNetworks) {
            new PersistBasicFloatNetwork().saveNetwork(fos, (BasicFloatNetwork) network);
        }

        // Close explicitly on the success path: GZIPOutputStream writes its remaining data and trailer on close,
        // so a failure here means a truncated file and must reach the caller instead of being swallowed below.
        fos.close();
        fos = null;
        rawOut = null;
    } finally {
        // Non-null only when something above failed. Close quietly so the original exception propagates; fall
        // back to the raw stream when the GZIP wrapper could not even be constructed.
        IOUtils.closeStream(fos != null ? fos : rawOut);
    }
}
Aggregations