Use of ml.shifu.shifu.core.dtrain.random.XavierWeightRandomizer in project shifu by ShifuML.
The generateNetwork method of the DTrainUtils class.
// public static BasicNetwork generateNetwork(int in, int out, int numLayers, List<String> actFunc,
//         List<Integer> hiddenNodeList, boolean isRandomizeWeights, double dropoutRate) {
//     return generateNetwork(in, out, numLayers, actFunc, hiddenNodeList, isRandomizeWeights, dropoutRate,
//             WGT_INIT_DEFAULT);
// }
public static BasicNetwork generateNetwork(int in, int out, int numLayers, List<String> actFunc,
        List<Integer> hiddenNodeList, boolean isRandomizeWeights, double dropoutRate, String wgtInit,
        boolean isLinearTarget, String outputActivationFunc) {
    final BasicFloatNetwork network = new BasicFloatNetwork();
    // shifuconfig exposes a switch that enables/disables dropout on the input layer
    if (Boolean.valueOf(Environment.getProperty(CommonConstants.SHIFU_TRAIN_NN_INPUTLAYERDROPOUT_ENABLE, "true"))) {
        // keep the input-layer dropout rate at 40% of the hidden-layer dropout rate
        network.addLayer(new BasicDropoutLayer(new ActivationLinear(), true, in, dropoutRate * 0.4d));
    } else {
        network.addLayer(new BasicDropoutLayer(new ActivationLinear(), true, in, 0d));
    }
    // int hiddenNodes = 0;
    for (int i = 0; i < numLayers; i++) {
        String func = actFunc.get(i);
        Integer numHiddenNode = hiddenNodeList.get(i);
        // hiddenNodes += numHiddenNode;
        if (func.equalsIgnoreCase(NNConstants.NN_LINEAR)) {
            network.addLayer(new BasicDropoutLayer(new ActivationLinear(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_SIGMOID)) {
            network.addLayer(new BasicDropoutLayer(new ActivationSigmoid(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_TANH)) {
            network.addLayer(new BasicDropoutLayer(new ActivationTANH(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_LOG)) {
            network.addLayer(new BasicDropoutLayer(new ActivationLOG(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_SIN)) {
            network.addLayer(new BasicDropoutLayer(new ActivationSIN(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_RELU)) {
            network.addLayer(new BasicDropoutLayer(new ActivationReLU(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_LEAKY_RELU)) {
            network.addLayer(new BasicDropoutLayer(new ActivationLeakyReLU(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_SWISH)) {
            network.addLayer(new BasicDropoutLayer(new ActivationSwish(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_PTANH)) {
            network.addLayer(new BasicDropoutLayer(new ActivationPTANH(), true, numHiddenNode, dropoutRate));
        } else {
            network.addLayer(new BasicDropoutLayer(new ActivationSigmoid(), true, numHiddenNode, dropoutRate));
        }
    }
    if (isLinearTarget) {
        if (NNConstants.NN_RELU.equalsIgnoreCase(outputActivationFunc)) {
            network.addLayer(new BasicLayer(new ActivationReLU(), true, out));
        } else if (NNConstants.NN_LEAKY_RELU.equalsIgnoreCase(outputActivationFunc)) {
            network.addLayer(new BasicLayer(new ActivationLeakyReLU(), true, out));
        } else if (NNConstants.NN_SWISH.equalsIgnoreCase(outputActivationFunc)) {
            network.addLayer(new BasicLayer(new ActivationSwish(), true, out));
        } else {
            network.addLayer(new BasicLayer(new ActivationLinear(), true, out));
        }
    } else {
        network.addLayer(new BasicLayer(new ActivationSigmoid(), false, out));
    }
    NeuralStructure structure = network.getStructure();
    if (structure instanceof FloatNeuralStructure) {
        ((FloatNeuralStructure) structure).finalizeStruct();
    } else {
        structure.finalizeStructure();
    }
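    // Weight initialization dispatches on the (case-insensitive) wgtInit name; an empty,
    // default, or unrecognized value falls back to the network's built-in reset() randomization.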
    if (isRandomizeWeights) {
        if (wgtInit == null || wgtInit.length() == 0) {
            // default randomization
            network.reset();
        } else if (wgtInit.equalsIgnoreCase(WGT_INIT_GAUSSIAN)) {
            new GaussianRandomizer(0, 1).randomize(network);
        } else if (wgtInit.equalsIgnoreCase(WGT_INIT_XAVIER)) {
            new XavierWeightRandomizer().randomize(network);
        } else if (wgtInit.equalsIgnoreCase(WGT_INIT_HE)) {
            new HeWeightRandomizer().randomize(network);
        } else if (wgtInit.equalsIgnoreCase(WGT_INIT_LECUN)) {
            new LecunWeightRandomizer().randomize(network);
        } else if (wgtInit.equalsIgnoreCase(WGT_INIT_DEFAULT)) {
            // default randomization
            network.reset();
        } else {
            // default randomization
            network.reset();
        }
    }
    return network;
}
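For reference, a hypothetical caller of this method might look like the sketch below. The feature count, hidden-layer sizes, and dropout rate are illustrative assumptions, not values from the shifu source, and the import paths for DTrainUtils and NNConstants are assumed; the wgtInit argument is what routes initialization to the XavierWeightRandomizer shown above.

import java.util.Arrays;
import java.util.List;

import org.encog.neural.networks.BasicNetwork;

import ml.shifu.shifu.core.dtrain.DTrainUtils;
import ml.shifu.shifu.core.dtrain.nn.NNConstants;

public class GenerateNetworkExample {
    public static void main(String[] args) {
        // Two hidden layers: 30 sigmoid nodes, then 20 ReLU nodes (sizes are illustrative).
        List<String> actFunc = Arrays.asList(NNConstants.NN_SIGMOID, NNConstants.NN_RELU);
        List<Integer> hiddenNodeList = Arrays.asList(30, 20);
        BasicNetwork network = DTrainUtils.generateNetwork(
                100,                          // in: number of input features (assumed)
                1,                            // out: single score output
                2,                            // numLayers: number of hidden layers
                actFunc,
                hiddenNodeList,
                true,                         // isRandomizeWeights: run the init branch
                0.1d,                         // dropoutRate (input layer gets 40% of this by default)
                DTrainUtils.WGT_INIT_XAVIER,  // wgtInit: selects XavierWeightRandomizer
                false,                        // isLinearTarget: sigmoid output layer
                null);                        // outputActivationFunc: unused when isLinearTarget is false
        System.out.println("Layers: " + network.getLayerCount());
    }
}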
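Conceptually, Xavier (Glorot) initialization draws each weight with a scale tied to the layer's fan-in and fan-out so that activation variance stays roughly constant across layers. The following is a minimal, self-contained sketch of the uniform variant of that scheme; it is not the shifu XavierWeightRandomizer implementation, just the underlying idea.

import java.util.Random;

/** Minimal sketch of uniform Xavier/Glorot initialization (illustrative only). */
public final class XavierSketch {

    /**
     * Fills a fanOut x fanIn weight matrix with w ~ U(-limit, limit),
     * where limit = sqrt(6 / (fanIn + fanOut)).
     */
    static double[][] xavierUniform(int fanIn, int fanOut, Random rnd) {
        double limit = Math.sqrt(6.0 / (fanIn + fanOut));
        double[][] weights = new double[fanOut][fanIn];
        for (int i = 0; i < fanOut; i++) {
            for (int j = 0; j < fanIn; j++) {
                // nextDouble() is in [0, 1); shift and scale into [-limit, limit)
                weights[i][j] = (rnd.nextDouble() * 2.0 - 1.0) * limit;
            }
        }
        return weights;
    }

    public static void main(String[] args) {
        double[][] w = xavierUniform(100, 30, new Random(42));
        System.out.println("First weight: " + w[0][0]);
    }
}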