Usage of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in the ShifuML/shifu project — example 1: the TrainModelProcessor.inputOutputModelCheckSuccess method.
/**
 * Checks whether an existing model spec on the file system is structurally compatible with the
 * current training configuration (so it can be reused, e.g. for continuous training).
 *
 * Compatibility requires: input count not larger than the currently selected/candidate inputs,
 * identical output count, hidden layer count not larger than the configured one, per-layer hidden
 * node counts not larger than configured, and matching activation functions per hidden layer.
 *
 * @param fileSystem file system holding the model spec
 * @param modelPath path of the existing model spec
 * @param modelParams training params map (NUM_HIDDEN_LAYERS, NUM_HIDDEN_NODES, ACTIVATION_FUNC)
 * @return true if the existing model is compatible with current config, false otherwise
 * @throws IOException if the model spec cannot be read
 */
@SuppressWarnings("unchecked")
private boolean inputOutputModelCheckSuccess(FileSystem fileSystem, Path modelPath, Map<String, Object> modelParams) throws IOException {
    BasicML basicML = ModelSpecLoaderUtils.loadModel(this.modelConfig, modelPath, fileSystem);
    Object rawNetwork = ModelSpecLoaderUtils.getBasicNetwork(basicML);
    // Guard: a null or non-NN model spec is simply "not compatible" — fail the check
    // instead of blowing up with ClassCastException/NullPointerException.
    if(!(rawNetwork instanceof BasicFloatNetwork)) {
        return false;
    }
    BasicFloatNetwork model = (BasicFloatNetwork) rawNetwork;
    // [0] = selected input count, [1] = output count, [2] = candidate input count
    int[] inputOutputCandidateCounts = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(), getColumnConfigList());
    // before variable selection ([0] == 0), all candidate columns are potential inputs
    int inputs = inputOutputCandidateCounts[0] == 0 ? inputOutputCandidateCounts[2] : inputOutputCandidateCounts[0];
    if(model.getInputCount() > inputs || model.getOutputCount() != inputOutputCandidateCounts[1]) {
        return false;
    }
    // hidden layer count = total layers minus input and output layers
    if(model.getLayerCount() - 2 > (Integer) modelParams.get(CommonConstants.NUM_HIDDEN_LAYERS)) {
        return false;
    }
    List<Integer> hiddenNodeList = (List<Integer>) modelParams.get(CommonConstants.NUM_HIDDEN_NODES);
    List<String> actFuncList = (List<String>) modelParams.get(CommonConstants.ACTIVATION_FUNC);
    // hidden layers are layers 1 .. layerCount - 2; compare node counts and activations per layer
    for(int i = 1; i < model.getLayerCount() - 1; i++) {
        if(model.getLayerNeuronCount(i) > hiddenNodeList.get(i - 1)) {
            return false;
        }
        if(!isSameActivation(actFuncList.get(i - 1), model.getActivation(i))) {
            return false;
        }
    }
    return true;
}

/**
 * Returns true if the configured activation function name matches the activation instance
 * found in the existing model. Unrecognized names fall back to sigmoid, mirroring the
 * default used when a new network is generated.
 *
 * @param actFunc configured activation function name (case-insensitive)
 * @param activation activation instance from the existing model's layer
 * @return true if the configured name maps to the activation's exact class
 */
private static boolean isSameActivation(String actFunc, ActivationFunction activation) {
    Class<?> expected;
    if(actFunc.equalsIgnoreCase(NNConstants.NN_LINEAR)) {
        expected = ActivationLinear.class;
    } else if(actFunc.equalsIgnoreCase(NNConstants.NN_SIGMOID)) {
        expected = ActivationSigmoid.class;
    } else if(actFunc.equalsIgnoreCase(NNConstants.NN_TANH)) {
        expected = ActivationTANH.class;
    } else if(actFunc.equalsIgnoreCase(NNConstants.NN_LOG)) {
        expected = ActivationLOG.class;
    } else if(actFunc.equalsIgnoreCase(NNConstants.NN_SIN)) {
        expected = ActivationSIN.class;
    } else if(actFunc.equalsIgnoreCase(NNConstants.NN_RELU)) {
        expected = ActivationReLU.class;
    } else if(actFunc.equalsIgnoreCase(NNConstants.NN_LEAKY_RELU)) {
        expected = ActivationLeakyReLU.class;
    } else if(actFunc.equalsIgnoreCase(NNConstants.NN_SWISH)) {
        expected = ActivationSwish.class;
    } else if(actFunc.equalsIgnoreCase(NNConstants.NN_PTANH)) {
        expected = ActivationPTANH.class;
    } else {
        // unknown activation name: default to sigmoid, same as network generation
        expected = ActivationSigmoid.class;
    }
    return expected == activation.getClass();
}
Usage of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in the ShifuML/shifu project — example 2: the NNOutput.initNetwork method.
/**
 * Initializes the master-side neural network from the model/train configuration: derives the
 * output node count, resolves the feature subset (all selected features, or an explicit subset
 * from the SHIFU_NN_FEATURE_SUBSET property), generates the network, and registers the
 * persistor used to save/load models with feature-set information.
 *
 * @param context master computation context carrying job properties
 */
@SuppressWarnings("unchecked")
private void initNetwork(MasterContext<NNParams, NNParams> context) {
    // [0] = selected input count, [1] = output count, [2] = candidate input count
    int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(), this.columnConfigList);
    boolean isLinearTarget = CommonUtils.isLinearTarget(modelConfig, columnConfigList);
    // if one-vs-all classification, outputNodeCount is set per config; if classes == 2, it is 1
    int classes = modelConfig.getTags().size();
    int outputNodeCount = (isLinearTarget || modelConfig.isRegression()) ? inputOutputIndex[1]
            : (modelConfig.getTrain().isOneVsAll() ? inputOutputIndex[1] : (classes == 2 ? 1 : classes));
    int numLayers = (Integer) validParams.get(CommonConstants.NUM_HIDDEN_LAYERS);
    List<String> actFunc = (List<String>) validParams.get(CommonConstants.ACTIVATION_FUNC);
    List<Integer> hiddenNodeList = (List<Integer>) validParams.get(CommonConstants.NUM_HIDDEN_NODES);
    boolean isAfterVarSelect = inputOutputIndex[0] != 0;
    // cache all feature list for sampling features
    List<Integer> allFeatures = NormalUtils.getAllFeatureList(columnConfigList, isAfterVarSelect);
    String subsetStr = context.getProps().getProperty(CommonConstants.SHIFU_NN_FEATURE_SUBSET);
    if(StringUtils.isBlank(subsetStr)) {
        // no explicit subset configured: use all (selected) features
        this.subFeatures = new HashSet<Integer>(allFeatures);
    } else {
        // explicit comma-separated column-index subset from job property
        this.subFeatures = new HashSet<Integer>();
        for(String split: subsetStr.split(",")) {
            this.subFeatures.add(Integer.parseInt(split));
        }
    }
    int featureInputsCnt = DTrainUtils.getFeatureInputsCnt(modelConfig, columnConfigList, this.subFeatures);
    String outputActivationFunc = (String) validParams.get(CommonConstants.OUTPUT_ACTIVATION_FUNC);
    // reuse isLinearTarget computed above instead of calling CommonUtils.isLinearTarget again
    this.network = DTrainUtils.generateNetwork(featureInputsCnt, outputNodeCount, numLayers, actFunc, hiddenNodeList, false, this.dropoutRate, this.wgtInit, isLinearTarget, outputActivationFunc);
    ((BasicFloatNetwork) this.network).setFeatureSet(this.subFeatures);
    // register here to save models with their feature-set information
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());
}
Usage of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in the ShifuML/shifu project — example 3: the IndependentNNModel.compute method.
/**
 * Computes neural network score values for one record.
 *
 * @param data
 *            input array containing only effective column data; numeric values are real values
 *            after normalization, categorical feature values are pos rates or WOE
 * @return the network output; with multiple models, the element-wise average over all model
 *         outputs
 * @throws IllegalStateException if no model has been loaded
 */
public double[] compute(double[] data) {
    if(this.basicNetworks == null || this.basicNetworks.size() == 0) {
        throw new IllegalStateException("no models inside");
    }
    int modelSize = this.basicNetworks.size();
    // single-model fast path: no averaging needed
    if(modelSize == 1) {
        return this.basicNetworks.get(0).compute(new BasicMLData(data)).getData();
    }
    double[] averaged = new double[this.basicNetworks.get(0).getOutputCount()];
    for(BasicFloatNetwork network: this.basicNetworks) {
        double[] single = network.compute(new BasicMLData(data)).getData();
        assert single.length == averaged.length;
        // accumulate each model's contribution, already divided by the model count
        for(int i = 0; i < single.length; i++) {
            averaged[i] += single[i] / modelSize;
        }
    }
    return averaged;
}
Aggregations