Use of ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork in project shifu by ShifuML.
Class VarSelectMapper, method loadModel.
/**
 * Loads the first model, {@code model0.<algorithm>}, from the local distributed-cache copy of the model path and
 * stores it in {@code model} as an {@link MLRegression} instance.
 * <p>
 * {@link PersistBasicFloatNetwork} is registered with the Encog persistor registry first so that float-network
 * models can be deserialized.
 *
 * @throws IOException
 *             if the model file cannot be read
 */
private synchronized void loadModel() throws IOException {
    String threadName = Thread.currentThread().getName();
    LOG.debug("Before loading model with memory {} in thread {}.", MemoryUtils.getRuntimeMemoryStats(), threadName);
    long startTime = System.currentTimeMillis();

    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());

    // the model file comes from the local d-cache, so read it through the LOCAL file system
    FileSystem localFs = ShifuFileUtils.getFileSystemBySourceType(SourceType.LOCAL);
    Path firstModelPath = new Path("model0." + modelConfig.getAlgorithm().toLowerCase());
    model = (MLRegression) ModelSpecLoaderUtils.loadModel(modelConfig, firstModelPath, localFs);

    long elapsedMillis = System.currentTimeMillis() - startTime;
    LOG.debug("After load model class {} with time {}ms and memory {} in thread {}.", model.getClass().getName(),
            elapsedMillis, MemoryUtils.getRuntimeMemoryStats(), threadName);
}
Use of ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork in project shifu by ShifuML.
Class IndependentNNModel, method loadFromStream.
/**
 * Load model instance from input stream which is saved in NNOutput for specified binary format.
 * <p>
 * The stream may be plain or gzip-compressed. The format is detected by peeking at its first two bytes. The caller
 * keeps ownership of {@code input}; this method does not close it.
 *
 * @param input
 *            the input stream, flat input stream or gzip input stream both OK
 * @param isRemoveNameSpace
 *            is remove name space or not
 * @return the nn model instance
 * @throws IOException
 *             any IOException in de-serialization, including a failure while reading the stream header or an
 *             invalid gzip header
 */
public static IndependentNNModel loadFromStream(InputStream input, boolean isRemoveNameSpace) throws IOException {
    DataInputStream dis = toDataInputStream(input);
    int version = dis.readInt();
    IndependentNNModel.setVersion(version);
    String normStr = ml.shifu.shifu.core.dtrain.StringUtils.readString(dis);
    // Locale.ROOT keeps enum lookup independent of the JVM default locale (e.g. Turkish dotted/dotless i)
    NormType normType = NormType.valueOf(normStr.toUpperCase(java.util.Locale.ROOT));
    // for all features
    Map<Integer, String> numNameMap = new HashMap<Integer, String>();
    Map<Integer, List<String>> cateColumnNameNames = new HashMap<Integer, List<String>>();
    // for categorical features
    Map<Integer, Map<String, Double>> cateWoeMap = new HashMap<Integer, Map<String, Double>>();
    Map<Integer, Map<String, Double>> cateWgtWoeMap = new HashMap<Integer, Map<String, Double>>();
    Map<Integer, Map<String, Double>> binPosRateMap = new HashMap<Integer, Map<String, Double>>();
    // for numerical features
    Map<Integer, List<Double>> numerBinBoundaries = new HashMap<Integer, List<Double>>();
    Map<Integer, List<Double>> numerWoes = new HashMap<Integer, List<Double>>();
    Map<Integer, List<Double>> numerWgtWoes = new HashMap<Integer, List<Double>>();
    // for all features
    Map<Integer, Double> numerMeanMap = new HashMap<Integer, Double>();
    Map<Integer, Double> numerStddevMap = new HashMap<Integer, Double>();
    Map<Integer, Double> woeMeanMap = new HashMap<Integer, Double>();
    Map<Integer, Double> woeStddevMap = new HashMap<Integer, Double>();
    Map<Integer, Double> wgtWoeMeanMap = new HashMap<Integer, Double>();
    Map<Integer, Double> wgtWoeStddevMap = new HashMap<Integer, Double>();
    Map<Integer, Double> cutoffMap = new HashMap<Integer, Double>();
    Map<Integer, ColumnType> columnTypeMap = new HashMap<Integer, ColumnType>();
    Map<Integer, Map<String, Integer>> cateIndexMapping = new HashMap<Integer, Map<String, Integer>>();
    int columnSize = dis.readInt();
    for (int i = 0; i < columnSize; i++) {
        NNColumnStats cs = new NNColumnStats();
        cs.readFields(dis);
        List<Double> binWoes = cs.getBinCountWoes();
        List<Double> binWgtWoes = cs.getBinWeightWoes();
        List<Double> binPosRates = cs.getBinPosRates();
        int columnNum = cs.getColumnNum();
        columnTypeMap.put(columnNum, cs.getColumnType());
        if (isRemoveNameSpace) {
            // remove name-space in column name to make it be called by simple name
            numNameMap.put(columnNum, StringUtils.getSimpleColumnName(cs.getColumnName()));
        } else {
            numNameMap.put(columnNum, cs.getColumnName());
        }
        // for categorical features
        Map<String, Double> woeMap = new HashMap<String, Double>();
        Map<String, Double> woeWgtMap = new HashMap<String, Double>();
        Map<String, Double> posRateMap = new HashMap<String, Double>();
        Map<String, Integer> cateIndexMap = new HashMap<String, Integer>();
        if (cs.isCategorical() || cs.isHybrid()) {
            List<String> binCategories = cs.getBinCategories();
            cateColumnNameNames.put(columnNum, binCategories);
            for (int j = 0; j < binCategories.size(); j++) {
                String currCate = binCategories.get(j);
                if (currCate.contains(Constants.CATEGORICAL_GROUP_VAL_DELIMITER)) {
                    // merged category should be flatten, use own split function to avoid depending on guava jar in
                    // prediction
                    String[] splits = StringUtils.split(currCate, Constants.CATEGORICAL_GROUP_VAL_DELIMITER);
                    for (String str : splits) {
                        woeMap.put(str, binWoes.get(j));
                        woeWgtMap.put(str, binWgtWoes.get(j));
                        posRateMap.put(str, binPosRates.get(j));
                        cateIndexMap.put(str, j);
                    }
                } else {
                    woeMap.put(currCate, binWoes.get(j));
                    woeWgtMap.put(currCate, binWgtWoes.get(j));
                    posRateMap.put(currCate, binPosRates.get(j));
                    cateIndexMap.put(currCate, j);
                }
            }
            // append last missing bin: its stats sit right after the per-category entries
            woeMap.put(Constants.EMPTY_CATEGORY, binWoes.get(binCategories.size()));
            woeWgtMap.put(Constants.EMPTY_CATEGORY, binWgtWoes.get(binCategories.size()));
            posRateMap.put(Constants.EMPTY_CATEGORY, binPosRates.get(binCategories.size()));
        }
        if (cs.isNumerical() || cs.isHybrid()) {
            numerBinBoundaries.put(columnNum, cs.getBinBoundaries());
            numerWoes.put(columnNum, binWoes);
            numerWgtWoes.put(columnNum, binWgtWoes);
        }
        cateWoeMap.put(columnNum, woeMap);
        cateWgtWoeMap.put(columnNum, woeWgtMap);
        binPosRateMap.put(columnNum, posRateMap);
        cateIndexMapping.put(columnNum, cateIndexMap);
        numerMeanMap.put(columnNum, cs.getMean());
        numerStddevMap.put(columnNum, cs.getStddev());
        woeMeanMap.put(columnNum, cs.getWoeMean());
        woeStddevMap.put(columnNum, cs.getWoeStddev());
        wgtWoeMeanMap.put(columnNum, cs.getWoeWgtMean());
        wgtWoeStddevMap.put(columnNum, cs.getWoeWgtStddev());
        cutoffMap.put(columnNum, cs.getCutoff());
    }
    Map<Integer, Integer> columnMap = new HashMap<Integer, Integer>();
    int columnMapSize = dis.readInt();
    for (int i = 0; i < columnMapSize; i++) {
        columnMap.put(dis.readInt(), dis.readInt());
    }
    int size = dis.readInt();
    List<BasicFloatNetwork> networks = new ArrayList<BasicFloatNetwork>();
    for (int i = 0; i < size; i++) {
        networks.add(new PersistBasicFloatNetwork().readNetwork(dis));
    }
    return new IndependentNNModel(networks, normType, numNameMap, cateColumnNameNames, columnMap, cateWoeMap,
            cateWgtWoeMap, binPosRateMap, numerBinBoundaries, numerWgtWoes, numerWoes, cutoffMap, numerMeanMap,
            numerStddevMap, woeMeanMap, woeStddevMap, wgtWoeMeanMap, wgtWoeStddevMap, columnTypeMap,
            cateIndexMapping);
}

/**
 * Wraps {@code input} in a {@link DataInputStream}, inflating it transparently when it starts with the gzip magic
 * number.
 * <p>
 * IOExceptions are propagated instead of falling back to reading {@code input} directly. Once the
 * {@link BufferedInputStream} has been read from, it may already hold bytes pulled out of {@code input}. Reading
 * {@code input} afterwards would then start mid-stream and silently yield corrupted data.
 */
private static DataInputStream toDataInputStream(InputStream input) throws IOException {
    BufferedInputStream bis = new BufferedInputStream(input);
    // peek at the first two bytes, then rewind so the chosen decoder sees the whole stream
    bis.mark(2);
    // single-byte reads block until a byte or EOF, so a short read cannot hide the magic number
    int first = bis.read();
    int second = bis.read();
    bis.reset();
    boolean isGzip = first != -1 && second != -1 && (first | (second << 8)) == GZIPInputStream.GZIP_MAGIC;
    if (isGzip) {
        return new DataInputStream(new GZIPInputStream(bis));
    }
    return new DataInputStream(bis);
}
Use of ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork in project shifu by ShifuML.
Class ModelSpecLoaderUtils, method locateBasicModels.
/**
 * Find model spec files.
 * <p>
 * Regular model files are looked up first. If there are none, generic models are looked up and, when found,
 * returned as-is. Otherwise the (possibly empty) list is sorted by file name, ignoring case. It is then truncated
 * to the number of models the last training run is expected to have produced. That number is the bagging number,
 * overridden in turn by the tag count for one-vs-all classification, by the k-fold count when k-fold is enabled,
 * and by the number of flattened grid-search parameter sets when grid search is configured.
 *
 * @param modelConfig
 *            model config
 * @param evalConfig
 *            eval configuration
 * @param sourceType
 *            {@link SourceType} LOCAL or HDFS?
 * @return The basic model file list, never {@code null}; empty when no model has been produced yet
 * @throws IOException
 *             Exception when fail to load basic models
 */
public static List<FileStatus> locateBasicModels(ModelConfig modelConfig, EvalConfig evalConfig, SourceType sourceType) throws IOException {
    // we have to register PersistBasicFloatNetwork for loading such models
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());
    List<FileStatus> listStatus = findModels(modelConfig, evalConfig, sourceType);
    if (CollectionUtils.isEmpty(listStatus)) {
        // no regular model files; don't fail here because there may be generic (sub-)models instead
        listStatus = findGenericModels(modelConfig, evalConfig, sourceType);
        if (CollectionUtils.isNotEmpty(listStatus)) {
            return listStatus;
        }
        // if models not found, continue with an empty list, which makes eval work while training is in progress.
        // Normalize a null result so that Collections.sort below cannot throw a NullPointerException.
        if (listStatus == null) {
            listStatus = new ArrayList<FileStatus>();
        }
    }
    // to avoid the *unix and windows file list order
    Collections.sort(listStatus, new Comparator<FileStatus>() {
        @Override
        public int compare(FileStatus f1, FileStatus f2) {
            return f1.getPath().getName().compareToIgnoreCase(f2.getPath().getName());
        }
    });
    // added in shifu 0.2.5 to slice models not belonging to last training
    int baggingModelSize = modelConfig.getTrain().getBaggingNum();
    if (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()) {
        baggingModelSize = modelConfig.getTags().size();
    }
    Integer kCrossValidation = modelConfig.getTrain().getNumKFold();
    if (kCrossValidation != null && kCrossValidation > 0) {
        // if kfold is enabled , bagging set it to bagging model size
        baggingModelSize = kCrossValidation;
    }
    GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
    if (gs.hasHyperParam()) {
        // if it is grid search, set model size to all flatten params
        baggingModelSize = gs.getFlattenParams().size();
    }
    return listStatus.size() <= baggingModelSize ? listStatus : listStatus.subList(0, baggingModelSize);
}
Use of ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork in project shifu by ShifuML.
Class ModelSpecLoaderUtils, method loadBasicModels.
/**
 * Load models from the specified local directory.
 * <p>
 * Every file in {@code modelsPath} whose name ends with {@code "." + alg} (lower-cased) is loaded, in file-name
 * order. NN files may be gzip-compressed or plain Encog files. LR, GBT and RF files are loaded by their own model
 * classes. Matching files of any other algorithm are skipped.
 *
 * @param modelsPath
 *            - a directory that contains the model files
 * @param alg
 *            the algorithm
 * @param isConvertToProb
 *            if convert to prob for gbt model
 * @param gbtScoreConvertStrategy
 *            specify how to convert gbt raw score
 * @return - a list of @BasicML
 * @throws IllegalArgumentException
 *             if {@code modelsPath} or {@code alg} is null, or {@code alg} is DT
 * @throws IOException
 *             - when the directory cannot be listed or a model file fails to load
 */
public static List<BasicML> loadBasicModels(final String modelsPath, final ALGORITHM alg, boolean isConvertToProb, String gbtScoreConvertStrategy) throws IOException {
    // report the actual reason instead of always blaming the model path
    if (modelsPath == null) {
        throw new IllegalArgumentException("The model path shouldn't be null");
    }
    if (alg == null) {
        throw new IllegalArgumentException("The algorithm shouldn't be null");
    }
    if (ALGORITHM.DT.equals(alg)) {
        throw new IllegalArgumentException("Loading basic models is not supported for algorithm DT");
    }
    // we have to register PersistBasicFloatNetwork for loading such models
    if (ALGORITHM.NN.equals(alg)) {
        PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());
    }
    // computed once instead of per file; Locale.ROOT keeps the suffix independent of the default locale
    final String modelSuffix = "." + alg.name().toLowerCase(java.util.Locale.ROOT);
    File modelsPathDir = new File(modelsPath);
    File[] modelFiles = modelsPathDir.listFiles(new FilenameFilter() {
        @Override
        public boolean accept(File dir, String name) {
            return name.endsWith(modelSuffix);
        }
    });
    // listFiles returns null when the path is not a directory or an I/O error occurs
    if (modelFiles == null) {
        throw new IOException(String.format("Failed to list files in %s", modelsPathDir.getAbsolutePath()));
    }
    // sort file names
    Arrays.sort(modelFiles, new Comparator<File>() {
        @Override
        public int compare(File from, File to) {
            return from.getName().compareTo(to.getName());
        }
    });
    List<BasicML> models = new ArrayList<BasicML>(modelFiles.length);
    for (File nnf : modelFiles) {
        InputStream is = null;
        try {
            is = new FileInputStream(nnf);
            if (ALGORITHM.NN.equals(alg)) {
                GzipStreamPair pair = GzipStreamPair.isGZipFormat(is);
                if (pair.isGzip()) {
                    models.add(BasicML.class.cast(NNModel.loadFromStream(pair.getInput())));
                } else {
                    models.add(BasicML.class.cast(EncogDirectoryPersistence.loadObject(pair.getInput())));
                }
            } else if (ALGORITHM.LR.equals(alg)) {
                models.add(LR.loadFromStream(is));
            } else if (ALGORITHM.GBT.equals(alg) || ALGORITHM.RF.equals(alg)) {
                models.add(TreeModel.loadFromStream(is, isConvertToProb, gbtScoreConvertStrategy));
            }
        } finally {
            IOUtils.closeQuietly(is);
        }
    }
    return models;
}
Use of ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork in project shifu by ShifuML.
Class NNModelSpecTest, method testModelFitIn.
/**
 * Fits the network stored in model5.nn into the larger network stored in model6.nn. Checks the number of weight
 * indexes reported as fixed, once with the default call and once with the extra boolean argument set to false.
 */
@Test
public void testModelFitIn() {
    // float networks must be registered before Encog can deserialize these model files
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());

    BasicNetwork baseNetwork = (BasicNetwork) BasicML.class
            .cast(EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model5.nn")));
    FlatNetwork baseFlat = baseNetwork.getFlat();

    BasicNetwork widerNetwork = (BasicNetwork) BasicML.class
            .cast(EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model6.nn")));
    FlatNetwork widerFlat = widerNetwork.getFlat();

    NNMaster nnMaster = new NNMaster();

    Set<Integer> defaultFixed = nnMaster.fitExistingModelIn(baseFlat, widerFlat, Arrays.asList(1, 2, 3));
    Assert.assertEquals(defaultFixed.size(), 931);

    Set<Integer> fixedWithFlagOff = nnMaster.fitExistingModelIn(baseFlat, widerFlat, Arrays.asList(1, 2, 3), false);
    Assert.assertEquals(fixedWithFlagOff.size(), 910);
}
Aggregations