Search in sources:

Example 11 with BasicFloatNetwork

Use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in the shifu project by ShifuML.

Method inputOutputModelCheckSuccess of class TrainModelProcessor.

/**
 * Checks whether a previously saved network at {@code modelPath} is structurally compatible with the current
 * training configuration.
 *
 * <p>All of the following must hold:
 * <ul>
 * <li>its input count does not exceed the expected input count;</li>
 * <li>its output count matches exactly;</li>
 * <li>it has no more hidden layers than configured;</li>
 * <li>no hidden layer has more neurons than configured;</li>
 * <li>every hidden layer uses the configured activation class.</li>
 * </ul>
 *
 * @param fileSystem
 *            file system the model is read from
 * @param modelPath
 *            location of the persisted model
 * @param modelParams
 *            training parameters; read for the hidden-layer count, per-layer node counts and activation names
 * @return {@code true} when every compatibility check passes, otherwise {@code false}
 * @throws IOException
 *             if loading the model fails
 */
@SuppressWarnings("unchecked")
private boolean inputOutputModelCheckSuccess(FileSystem fileSystem, Path modelPath, Map<String, Object> modelParams) throws IOException {
    BasicML loadedModel = ModelSpecLoaderUtils.loadModel(this.modelConfig, modelPath, fileSystem);
    BasicFloatNetwork network = (BasicFloatNetwork) ModelSpecLoaderUtils.getBasicNetwork(loadedModel);

    // Index 1 is the expected output count. Index 0 is the expected input count unless it is 0,
    // in which case index 2 is used instead.
    int[] candidateCounts = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(), getColumnConfigList());
    int allowedInputs = candidateCounts[0] != 0 ? candidateCounts[0] : candidateCounts[2];
    if (network.getInputCount() > allowedInputs || network.getOutputCount() != candidateCounts[1]) {
        return false;
    }

    // The saved model may have fewer hidden layers than configured, but never more.
    int modelHiddenLayers = network.getLayerCount() - 2;
    int configuredHiddenLayers = (Integer) modelParams.get(CommonConstants.NUM_HIDDEN_LAYERS);
    if (modelHiddenLayers > configuredHiddenLayers) {
        return false;
    }

    List<Integer> nodeLimits = (List<Integer>) modelParams.get(CommonConstants.NUM_HIDDEN_NODES);
    List<String> activationNames = (List<String>) modelParams.get(CommonConstants.ACTIVATION_FUNC);
    // Skip layer 0 and the last layer; only the hidden layers in between are compared.
    for (int layer = 1; layer <= modelHiddenLayers; layer++) {
        if (network.getLayerNeuronCount(layer) > nodeLimits.get(layer - 1)) {
            return false;
        }
        ActivationFunction actual = network.getActivation(layer);
        Class<?> expected = expectedHiddenActivationClass(activationNames.get(layer - 1));
        if (actual.getClass() != expected) {
            return false;
        }
    }
    return true;
}

/**
 * Maps a configured hidden-layer activation name (matched case-insensitively) to the exact activation class a
 * compatible network must use.
 *
 * @param actFunc
 *            configured activation name; must not be null
 * @return the expected activation class; unrecognised names map to {@link ActivationSigmoid}
 */
private static Class<?> expectedHiddenActivationClass(String actFunc) {
    if (actFunc.equalsIgnoreCase(NNConstants.NN_LINEAR)) {
        return ActivationLinear.class;
    }
    if (actFunc.equalsIgnoreCase(NNConstants.NN_SIGMOID)) {
        return ActivationSigmoid.class;
    }
    if (actFunc.equalsIgnoreCase(NNConstants.NN_TANH)) {
        return ActivationTANH.class;
    }
    if (actFunc.equalsIgnoreCase(NNConstants.NN_LOG)) {
        return ActivationLOG.class;
    }
    if (actFunc.equalsIgnoreCase(NNConstants.NN_SIN)) {
        return ActivationSIN.class;
    }
    if (actFunc.equalsIgnoreCase(NNConstants.NN_RELU)) {
        return ActivationReLU.class;
    }
    if (actFunc.equalsIgnoreCase(NNConstants.NN_LEAKY_RELU)) {
        return ActivationLeakyReLU.class;
    }
    if (actFunc.equalsIgnoreCase(NNConstants.NN_SWISH)) {
        return ActivationSwish.class;
    }
    if (actFunc.equalsIgnoreCase(NNConstants.NN_PTANH)) {
        return ActivationPTANH.class;
    }
    return ActivationSigmoid.class;
}
Also used: BasicML (org.encog.ml.BasicML), RequiredFieldList (org.apache.pig.LoadPushDown.RequiredFieldList), BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)

Example 12 with BasicFloatNetwork

Use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in the shifu project by ShifuML.

Method initNetwork of class NNOutput.

/**
 * Builds the network held by this output instance from the validated training parameters, restricts it to the
 * configured feature subset, and registers the float-network persistor so the model can be saved.
 *
 * <p>The feature subset is read from the {@code CommonConstants.SHIFU_NN_FEATURE_SUBSET} property as a
 * comma-separated list of column indices. Whitespace around entries and empty entries are ignored. A blank or
 * missing property selects all features returned by {@code NormalUtils.getAllFeatureList}.
 *
 * @param context
 *            master context; only its properties are read here
 */
@SuppressWarnings("unchecked")
private void initNetwork(MasterContext<NNParams, NNParams> context) {
    int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(), this.columnConfigList);
    // Computed once and reused below, including for the generateNetwork call.
    boolean isLinearTarget = CommonUtils.isLinearTarget(modelConfig, columnConfigList);

    // For linear targets, regression and one-vs-all, take the output count from the candidate counts.
    // Otherwise binary classification uses a single output node and multi-class uses one node per class.
    int classes = modelConfig.getTags().size();
    int outputNodeCount;
    if (isLinearTarget || modelConfig.isRegression() || modelConfig.getTrain().isOneVsAll()) {
        outputNodeCount = inputOutputIndex[1];
    } else {
        outputNodeCount = classes == 2 ? 1 : classes;
    }

    int numLayers = (Integer) validParams.get(CommonConstants.NUM_HIDDEN_LAYERS);
    List<String> actFunc = (List<String>) validParams.get(CommonConstants.ACTIVATION_FUNC);
    List<Integer> hiddenNodeList = (List<Integer>) validParams.get(CommonConstants.NUM_HIDDEN_NODES);
    boolean isAfterVarSelect = inputOutputIndex[0] != 0;
    // cache all feature list for sampling features
    List<Integer> allFeatures = NormalUtils.getAllFeatureList(columnConfigList, isAfterVarSelect);
    String subsetStr = context.getProps().getProperty(CommonConstants.SHIFU_NN_FEATURE_SUBSET);
    if (StringUtils.isBlank(subsetStr)) {
        this.subFeatures = new HashSet<Integer>(allFeatures);
    } else {
        this.subFeatures = new HashSet<Integer>();
        for (String token : subsetStr.split(",")) {
            // Tolerate "1, 2, 3" and stray commas instead of failing with NumberFormatException.
            String trimmed = token.trim();
            if (!trimmed.isEmpty()) {
                this.subFeatures.add(Integer.parseInt(trimmed));
            }
        }
    }
    int featureInputsCnt = DTrainUtils.getFeatureInputsCnt(modelConfig, columnConfigList, this.subFeatures);
    String outputActivationFunc = (String) validParams.get(CommonConstants.OUTPUT_ACTIVATION_FUNC);
    this.network = DTrainUtils.generateNetwork(featureInputsCnt, outputNodeCount, numLayers, actFunc, hiddenNodeList, false, this.dropoutRate, this.wgtInit, isLinearTarget, outputActivationFunc);
    ((BasicFloatNetwork) this.network).setFeatureSet(this.subFeatures);
    // register here to save models
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());
}
Also used: List (java.util.List), PersistBasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork), BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)

Example 13 with BasicFloatNetwork

Use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in the shifu project by ShifuML.

Method compute of class IndependentNNModel.

/**
 * Given double array data, compute score values of neural network
 *
 * @param data
 *            data array includes only effective column data, numeric value is real value after normalization,
 *            categorical feature value is pos rates or woe .
 * @return neural network model output, if multiple models, do averaging on all models outputs
 * @throws IllegalStateException
 *             if no model is loaded, or if (with several models) a model's output length differs from the
 *             output count of the first model
 */
public double[] compute(double[] data) {
    if (this.basicNetworks == null || this.basicNetworks.size() == 0) {
        throw new IllegalStateException("no models inside");
    }
    if (this.basicNetworks.size() == 1) {
        return this.basicNetworks.get(0).compute(new BasicMLData(data)).getData();
    }

    int modelSize = this.basicNetworks.size();
    double[] results = new double[this.basicNetworks.get(0).getOutputCount()];
    for (BasicFloatNetwork network : this.basicNetworks) {
        double[] currResults = network.compute(new BasicMLData(data)).getData();
        // Explicit check instead of a JVM assert. Asserts are off by default, so a mismatch used to
        // either overflow `results` or silently produce a wrong average.
        if (currResults.length != results.length) {
            throw new IllegalStateException("model output size " + currResults.length
                    + " does not match expected output size " + results.length);
        }
        for (int i = 0; i < currResults.length; i++) {
            // directly do averaging on each model output element
            results[i] += currResults[i] / modelSize;
        }
    }
    return results;
}
Also used: BasicMLData (org.encog.ml.data.basic.BasicMLData), PersistBasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork), BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)

Aggregations

BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork): 13; ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig): 5; ArrayList (java.util.ArrayList): 4; PersistBasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork): 4; BasicML (org.encog.ml.BasicML): 4; List (java.util.List): 3; BasicMLData (org.encog.ml.data.basic.BasicMLData): 3; IOException (java.io.IOException): 2; Path (org.apache.hadoop.fs.Path): 2; RequiredFieldList (org.apache.pig.LoadPushDown.RequiredFieldList): 2; BufferedInputStream (java.io.BufferedInputStream): 1; DataInputStream (java.io.DataInputStream): 1; Comparator (java.util.Comparator): 1; HashMap (java.util.HashMap): 1; Map (java.util.Map): 1; Callable (java.util.concurrent.Callable): 1; GZIPInputStream (java.util.zip.GZIPInputStream): 1; GuaguaRuntimeException (ml.shifu.guagua.GuaguaRuntimeException): 1; GuaguaMapReduceClient (ml.shifu.guagua.mapreduce.GuaguaMapReduceClient): 1; ScoreObject (ml.shifu.shifu.container.ScoreObject): 1