Use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in project shifu by ShifuML.
In class NNMaster, method initOrRecoverParams:
private NNParams initOrRecoverParams(MasterContext<NNParams, NNParams> context) {
    // read existing model weights
    NNParams params = null;
    try {
        Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
        BasicML basicML = ModelSpecLoaderUtils.loadModel(modelConfig, modelPath,
                ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
        params = initWeights();
        BasicFloatNetwork existingModel = (BasicFloatNetwork) ModelSpecLoaderUtils.getBasicNetwork(basicML);
        if (existingModel != null) {
            LOG.info("Starting to train model from existing model {}.", modelPath);
            int mspecCompareResult = new NNStructureComparator().compare(this.flatNetwork, existingModel.getFlat());
            if (mspecCompareResult == 0) {
                // same model structure
                params.setWeights(existingModel.getFlat().getWeights());
                this.fixedWeightIndexSet = getFixedWights(fixedLayers);
            } else if (mspecCompareResult == 1) {
                // new model structure is larger than the existing one
                this.fixedWeightIndexSet = fitExistingModelIn(existingModel.getFlat(), this.flatNetwork,
                        this.fixedLayers, this.fixedBias);
            } else {
                // new model structure is smaller and couldn't hold the existing one
                throw new GuaguaRuntimeException(
                        "Network changed for recover or continuous training. New network couldn't hold existing network!");
            }
        } else {
            LOG.info("Starting to train model from scratch.");
        }
    } catch (IOException e) {
        throw new GuaguaRuntimeException(e);
    }
    return params;
}
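The comparator result above follows a three-way contract: 0 means the new and existing networks have identical structure, 1 means the new network is strictly larger and can absorb the existing weights, and -1 means it is smaller and recovery must fail. A self-contained sketch of that contract, using hypothetical layer-size arrays as stand-ins for the flat networks (this is an illustration, not Shifu's NNStructureComparator implementation):

static int compareStructures(int[] newLayers, int[] oldLayers) {
    if (java.util.Arrays.equals(newLayers, oldLayers)) {
        return 0;   // same structure: weights can be copied one-to-one
    }
    if (newLayers.length >= oldLayers.length) {
        for (int i = 0; i < oldLayers.length; i++) {
            if (newLayers[i] < oldLayers[i]) {
                return -1;   // a layer shrank: cannot hold the existing weights
            }
        }
        return 1;   // new network is a superset: fit the old weights in
    }
    return -1;      // fewer layers: cannot hold the existing network
}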
Use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in project shifu by ShifuML.
In class IndependentNNModel, method loadFromStream:
/**
 * Load a model instance from an input stream saved by NNOutput in the specified binary format.
 *
 * @param input
 *            the input stream; both a plain input stream and a gzip input stream are supported
 * @param isRemoveNameSpace
 *            whether to remove the name-space from column names
 * @return the NN model instance
 * @throws IOException
 *             any IOException thrown during de-serialization
 */
public static IndependentNNModel loadFromStream(InputStream input, boolean isRemoveNameSpace) throws IOException {
    DataInputStream dis = null;
    // check whether the stream is gzip-compressed or not
    try {
        byte[] header = new byte[2];
        BufferedInputStream bis = new BufferedInputStream(input);
        bis.mark(2);
        int result = bis.read(header);
        bis.reset();
        int ss = (header[0] & 0xff) | ((header[1] & 0xff) << 8);
        if (result != -1 && ss == GZIPInputStream.GZIP_MAGIC) {
            dis = new DataInputStream(new GZIPInputStream(bis));
        } else {
            dis = new DataInputStream(bis);
        }
    } catch (java.io.IOException e) {
        dis = new DataInputStream(input);
    }

    int version = dis.readInt();
    IndependentNNModel.setVersion(version);
    String normStr = ml.shifu.shifu.core.dtrain.StringUtils.readString(dis);
    NormType normType = NormType.valueOf(normStr.toUpperCase());

    // for all features
    Map<Integer, String> numNameMap = new HashMap<Integer, String>();
    Map<Integer, List<String>> cateColumnNameNames = new HashMap<Integer, List<String>>();
    // for categorical features
    Map<Integer, Map<String, Double>> cateWoeMap = new HashMap<Integer, Map<String, Double>>();
    Map<Integer, Map<String, Double>> cateWgtWoeMap = new HashMap<Integer, Map<String, Double>>();
    Map<Integer, Map<String, Double>> binPosRateMap = new HashMap<Integer, Map<String, Double>>();
    // for numerical features
    Map<Integer, List<Double>> numerBinBoundaries = new HashMap<Integer, List<Double>>();
    Map<Integer, List<Double>> numerWoes = new HashMap<Integer, List<Double>>();
    Map<Integer, List<Double>> numerWgtWoes = new HashMap<Integer, List<Double>>();
    // for all features
    Map<Integer, Double> numerMeanMap = new HashMap<Integer, Double>();
    Map<Integer, Double> numerStddevMap = new HashMap<Integer, Double>();
    Map<Integer, Double> woeMeanMap = new HashMap<Integer, Double>();
    Map<Integer, Double> woeStddevMap = new HashMap<Integer, Double>();
    Map<Integer, Double> wgtWoeMeanMap = new HashMap<Integer, Double>();
    Map<Integer, Double> wgtWoeStddevMap = new HashMap<Integer, Double>();
    Map<Integer, Double> cutoffMap = new HashMap<Integer, Double>();
    Map<Integer, ColumnType> columnTypeMap = new HashMap<Integer, ColumnType>();
    Map<Integer, Map<String, Integer>> cateIndexMapping = new HashMap<Integer, Map<String, Integer>>();

    int columnSize = dis.readInt();
    for (int i = 0; i < columnSize; i++) {
        NNColumnStats cs = new NNColumnStats();
        cs.readFields(dis);

        List<Double> binWoes = cs.getBinCountWoes();
        List<Double> binWgtWoes = cs.getBinWeightWoes();
        List<Double> binPosRates = cs.getBinPosRates();

        int columnNum = cs.getColumnNum();
        columnTypeMap.put(columnNum, cs.getColumnType());

        if (isRemoveNameSpace) {
            // remove the name-space from the column name so it can be referenced by its simple name
            numNameMap.put(columnNum, StringUtils.getSimpleColumnName(cs.getColumnName()));
        } else {
            numNameMap.put(columnNum, cs.getColumnName());
        }

        // for categorical features
        Map<String, Double> woeMap = new HashMap<String, Double>();
        Map<String, Double> woeWgtMap = new HashMap<String, Double>();
        Map<String, Double> posRateMap = new HashMap<String, Double>();
        Map<String, Integer> cateIndexMap = new HashMap<String, Integer>();

        if (cs.isCategorical() || cs.isHybrid()) {
            List<String> binCategories = cs.getBinCategories();
            cateColumnNameNames.put(columnNum, binCategories);
            for (int j = 0; j < binCategories.size(); j++) {
                String currCate = binCategories.get(j);
                if (currCate.contains(Constants.CATEGORICAL_GROUP_VAL_DELIMITER)) {
                    // a merged category should be flattened; use our own split function to avoid
                    // depending on the guava jar at prediction time
                    String[] splits = StringUtils.split(currCate, Constants.CATEGORICAL_GROUP_VAL_DELIMITER);
                    for (String str : splits) {
                        woeMap.put(str, binWoes.get(j));
                        woeWgtMap.put(str, binWgtWoes.get(j));
                        posRateMap.put(str, binPosRates.get(j));
                        cateIndexMap.put(str, j);
                    }
                } else {
                    woeMap.put(currCate, binWoes.get(j));
                    woeWgtMap.put(currCate, binWgtWoes.get(j));
                    posRateMap.put(currCate, binPosRates.get(j));
                    cateIndexMap.put(currCate, j);
                }
            }
            // append the last bin, which holds missing values
            woeMap.put(Constants.EMPTY_CATEGORY, binWoes.get(binCategories.size()));
            woeWgtMap.put(Constants.EMPTY_CATEGORY, binWgtWoes.get(binCategories.size()));
            posRateMap.put(Constants.EMPTY_CATEGORY, binPosRates.get(binCategories.size()));
        }

        if (cs.isNumerical() || cs.isHybrid()) {
            numerBinBoundaries.put(columnNum, cs.getBinBoundaries());
            numerWoes.put(columnNum, binWoes);
            numerWgtWoes.put(columnNum, binWgtWoes);
        }

        cateWoeMap.put(columnNum, woeMap);
        cateWgtWoeMap.put(columnNum, woeWgtMap);
        binPosRateMap.put(columnNum, posRateMap);
        cateIndexMapping.put(columnNum, cateIndexMap);
        numerMeanMap.put(columnNum, cs.getMean());
        numerStddevMap.put(columnNum, cs.getStddev());
        woeMeanMap.put(columnNum, cs.getWoeMean());
        woeStddevMap.put(columnNum, cs.getWoeStddev());
        wgtWoeMeanMap.put(columnNum, cs.getWoeWgtMean());
        wgtWoeStddevMap.put(columnNum, cs.getWoeWgtStddev());
        cutoffMap.put(columnNum, cs.getCutoff());
    }

    Map<Integer, Integer> columnMap = new HashMap<Integer, Integer>();
    int columnMapSize = dis.readInt();
    for (int i = 0; i < columnMapSize; i++) {
        columnMap.put(dis.readInt(), dis.readInt());
    }

    int size = dis.readInt();
    List<BasicFloatNetwork> networks = new ArrayList<BasicFloatNetwork>();
    for (int i = 0; i < size; i++) {
        networks.add(new PersistBasicFloatNetwork().readNetwork(dis));
    }

    return new IndependentNNModel(networks, normType, numNameMap, cateColumnNameNames, columnMap, cateWoeMap,
            cateWgtWoeMap, binPosRateMap, numerBinBoundaries, numerWgtWoes, numerWoes, cutoffMap, numerMeanMap,
            numerStddevMap, woeMeanMap, woeStddevMap, wgtWoeMeanMap, wgtWoeStddevMap, columnTypeMap,
            cateIndexMapping);
}
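The two-byte sniff at the top works because a gzip stream always starts with the bytes 0x1f 0x8b; combining them little-endian, (header[0] & 0xff) | ((header[1] & 0xff) << 8), yields 0x8b1f, which is exactly GZIPInputStream.GZIP_MAGIC, so both plain and gzip-compressed model files load transparently. A minimal call sketch follows; the file path and feature names are hypothetical, and the compute(Map) scoring entry point is an assumption about the model's API:

// hypothetical model path; loadFromStream itself detects gzip vs. plain
try (InputStream in = new FileInputStream("model0.b")) {
    IndependentNNModel model = IndependentNNModel.loadFromStream(in, true);
    Map<String, Object> row = new HashMap<String, Object>();
    row.put("age", "35");      // numerical feature in raw string form (hypothetical name)
    row.put("state", "CA");    // categorical feature (hypothetical name)
    double[] scores = model.compute(row);   // assuming a compute(Map) scoring API
}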
Use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in project shifu by ShifuML.
In class ZscoreLocalTransformCreator, method build:
@Override
public LocalTransformations build(BasicML basicML) {
    LocalTransformations localTransformations = new LocalTransformations();
    if (basicML instanceof BasicFloatNetwork) {
        BasicFloatNetwork bfn = (BasicFloatNetwork) basicML;
        Set<Integer> featureSet = bfn.getFeatureSet();
        for (ColumnConfig config : columnConfigList) {
            if (config.isFinalSelect()
                    && (CollectionUtils.isEmpty(featureSet) || featureSet.contains(config.getColumnNum()))) {
                double cutoff = modelConfig.getNormalizeStdDevCutOff();
                List<DerivedField> derivedFields = config.isCategorical()
                        ? createCategoricalDerivedField(config, cutoff, modelConfig.getNormalizeType())
                        : createNumericalDerivedField(config, cutoff, modelConfig.getNormalizeType());
                localTransformations.addDerivedFields(derivedFields.toArray(new DerivedField[derivedFields.size()]));
            }
        }
    } else {
        for (ColumnConfig config : columnConfigList) {
            if (config.isFinalSelect()) {
                double cutoff = modelConfig.getNormalizeStdDevCutOff();
                List<DerivedField> derivedFields = config.isCategorical()
                        ? createCategoricalDerivedField(config, cutoff, modelConfig.getNormalizeType())
                        : createNumericalDerivedField(config, cutoff, modelConfig.getNormalizeType());
                localTransformations.addDerivedFields(derivedFields.toArray(new DerivedField[derivedFields.size()]));
            }
        }
    }
    return localTransformations;
}
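Each DerivedField produced here encodes Shifu's z-score normalization with a standard-deviation cutoff. Conceptually, the numerical transform the PMML expresses is the following clamp (a sketch of the math, not the PMML-building code itself):

// z-score with std-dev cutoff: normalize, then clamp to [-cutoff, cutoff]
static double zscore(double raw, double mean, double stddev, double cutoff) {
    double z = (raw - mean) / stddev;
    return Math.max(-cutoff, Math.min(cutoff, z));
}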
Use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in project shifu by ShifuML.
In class MiningSchemaCreator, method build:
@Override
public MiningSchema build(BasicML basicML) {
    MiningSchema miningSchema = new MiningSchema();
    boolean isSegExpansionMode = columnConfigList.size() > datasetHeaders.length;
    int segSize = segmentExpansions.size();
    if (basicML != null && basicML instanceof BasicFloatNetwork) {
        BasicFloatNetwork bfn = (BasicFloatNetwork) basicML;
        Set<Integer> featureSet = bfn.getFeatureSet();
        for (ColumnConfig columnConfig : columnConfigList) {
            if (columnConfig.getColumnNum() >= datasetHeaders.length) {
                // column configs are in order; stop once past the raw dataset columns
                break;
            }
            if (isActiveColumn(featureSet, columnConfig)) {
                if (columnConfig.isTarget()) {
                    List<MiningField> miningFields = createTargetMingFields(columnConfig);
                    miningSchema.addMiningFields(miningFields.toArray(new MiningField[miningFields.size()]));
                } else {
                    miningSchema.addMiningFields(createActiveMingFields(columnConfig));
                }
            } else if (isSegExpansionMode) {
                // even if the current column is not selected, keep the raw column when one of its
                // segment-expanded columns is selected
                for (int i = 0; i < segSize; i++) {
                    int newIndex = datasetHeaders.length * (i + 1) + columnConfig.getColumnNum();
                    ColumnConfig cc = columnConfigList.get(newIndex);
                    if (cc.isFinalSelect()) {
                        // one segment feature is selected, so put the raw column in
                        if (columnConfig.isTarget()) {
                            List<MiningField> miningFields = createTargetMingFields(columnConfig);
                            miningSchema.addMiningFields(miningFields.toArray(new MiningField[miningFields.size()]));
                        } else {
                            miningSchema.addMiningFields(createActiveMingFields(columnConfig));
                        }
                        break;
                    }
                }
            }
        }
    } else {
        for (ColumnConfig columnConfig : columnConfigList) {
            if (columnConfig.getColumnNum() >= datasetHeaders.length) {
                // column configs are in order; stop once past the raw dataset columns
                break;
            }
            // FIXME, if no variable is selected
            if (columnConfig.isFinalSelect() || columnConfig.isTarget()) {
                if (columnConfig.isTarget()) {
                    List<MiningField> miningFields = createTargetMingFields(columnConfig);
                    miningSchema.addMiningFields(miningFields.toArray(new MiningField[miningFields.size()]));
                } else {
                    miningSchema.addMiningFields(createActiveMingFields(columnConfig));
                }
            } else if (isSegExpansionMode) {
                // even if the current column is not selected, keep the raw column when one of its
                // segment-expanded columns is selected
                for (int i = 0; i < segSize; i++) {
                    int newIndex = datasetHeaders.length * (i + 1) + columnConfig.getColumnNum();
                    ColumnConfig cc = columnConfigList.get(newIndex);
                    if (cc.isFinalSelect()) {
                        // one segment feature is selected, so put the raw column in
                        if (columnConfig.isTarget()) {
                            List<MiningField> miningFields = createTargetMingFields(columnConfig);
                            miningSchema.addMiningFields(miningFields.toArray(new MiningField[miningFields.size()]));
                        } else {
                            miningSchema.addMiningFields(createActiveMingFields(columnConfig));
                        }
                        break;
                    }
                }
            }
        }
    }
    return miningSchema;
}
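The index arithmetic datasetHeaders.length * (i + 1) + columnNum works because segment-expanded column configs are appended after the raw columns, one full copy of the header width per segment. A worked example with hypothetical sizes:

// With 100 raw dataset headers and 2 segment expansions, raw column 5
// has expanded copies at indexes 105 and 205 in columnConfigList:
int headers = 100, columnNum = 5, segSize = 2;
for (int i = 0; i < segSize; i++) {
    int newIndex = headers * (i + 1) + columnNum;
    System.out.println(newIndex);   // prints 105, then 205
}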
Use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in project shifu by ShifuML.
In class DTrainUtils, method generateNetwork:
public static BasicNetwork generateNetwork(int in, int out, int numLayers, List<String> actFunc,
        List<Integer> hiddenNodeList, boolean isRandomizeWeights, double dropoutRate, String wgtInit,
        boolean isLinearTarget, String outputActivationFunc) {
    final BasicFloatNetwork network = new BasicFloatNetwork();
    // shifuconfig exposes a switch to enable input-layer dropout
    if (Boolean.valueOf(Environment.getProperty(CommonConstants.SHIFU_TRAIN_NN_INPUTLAYERDROPOUT_ENABLE, "true"))) {
        // keep the input-layer dropout rate at 40% of the hidden-layer dropout rate
        network.addLayer(new BasicDropoutLayer(new ActivationLinear(), true, in, dropoutRate * 0.4d));
    } else {
        network.addLayer(new BasicDropoutLayer(new ActivationLinear(), true, in, 0d));
    }
    for (int i = 0; i < numLayers; i++) {
        String func = actFunc.get(i);
        Integer numHiddenNode = hiddenNodeList.get(i);
        if (func.equalsIgnoreCase(NNConstants.NN_LINEAR)) {
            network.addLayer(new BasicDropoutLayer(new ActivationLinear(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_SIGMOID)) {
            network.addLayer(new BasicDropoutLayer(new ActivationSigmoid(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_TANH)) {
            network.addLayer(new BasicDropoutLayer(new ActivationTANH(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_LOG)) {
            network.addLayer(new BasicDropoutLayer(new ActivationLOG(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_SIN)) {
            network.addLayer(new BasicDropoutLayer(new ActivationSIN(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_RELU)) {
            network.addLayer(new BasicDropoutLayer(new ActivationReLU(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_LEAKY_RELU)) {
            network.addLayer(new BasicDropoutLayer(new ActivationLeakyReLU(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_SWISH)) {
            network.addLayer(new BasicDropoutLayer(new ActivationSwish(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_PTANH)) {
            network.addLayer(new BasicDropoutLayer(new ActivationPTANH(), true, numHiddenNode, dropoutRate));
        } else {
            // fall back to sigmoid for unknown activation names
            network.addLayer(new BasicDropoutLayer(new ActivationSigmoid(), true, numHiddenNode, dropoutRate));
        }
    }
    if (isLinearTarget) {
        if (NNConstants.NN_RELU.equalsIgnoreCase(outputActivationFunc)) {
            network.addLayer(new BasicLayer(new ActivationReLU(), true, out));
        } else if (NNConstants.NN_LEAKY_RELU.equalsIgnoreCase(outputActivationFunc)) {
            network.addLayer(new BasicLayer(new ActivationLeakyReLU(), true, out));
        } else if (NNConstants.NN_SWISH.equalsIgnoreCase(outputActivationFunc)) {
            network.addLayer(new BasicLayer(new ActivationSwish(), true, out));
        } else {
            network.addLayer(new BasicLayer(new ActivationLinear(), true, out));
        }
    } else {
        network.addLayer(new BasicLayer(new ActivationSigmoid(), false, out));
    }
    NeuralStructure structure = network.getStructure();
    if (network.getStructure() instanceof FloatNeuralStructure) {
        ((FloatNeuralStructure) structure).finalizeStruct();
    } else {
        structure.finalizeStructure();
    }
    if (isRandomizeWeights) {
        if (wgtInit == null || wgtInit.length() == 0 || wgtInit.equalsIgnoreCase(WGT_INIT_DEFAULT)) {
            // default randomization
            network.reset();
        } else if (wgtInit.equalsIgnoreCase(WGT_INIT_GAUSSIAN)) {
            new GaussianRandomizer(0, 1).randomize(network);
        } else if (wgtInit.equalsIgnoreCase(WGT_INIT_XAVIER)) {
            new XavierWeightRandomizer().randomize(network);
        } else if (wgtInit.equalsIgnoreCase(WGT_INIT_HE)) {
            new HeWeightRandomizer().randomize(network);
        } else if (wgtInit.equalsIgnoreCase(WGT_INIT_LECUN)) {
            new LecunWeightRandomizer().randomize(network);
        } else {
            // unknown name: fall back to default randomization
            network.reset();
        }
    }
    return network;
}
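A minimal call sketch for the method above; the activation names follow the NNConstants branches shown, while the weight-init string and hyperparameter values are illustrative assumptions:

List<String> actFunc = Arrays.asList("tanh", "relu");   // one activation per hidden layer
List<Integer> hiddenNodes = Arrays.asList(50, 20);      // one node count per hidden layer
BasicNetwork network = DTrainUtils.generateNetwork(
        30,            // input nodes
        1,             // output nodes
        2,             // number of hidden layers
        actFunc,
        hiddenNodes,
        true,          // randomize weights
        0.1d,          // hidden-layer dropout rate (input layer then gets 0.04)
        "xavier",      // assumed value of WGT_INIT_XAVIER
        false,         // classification target, so sigmoid output layer
        null);         // output activation only used when isLinearTarget is true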