Search in sources :

Example 6 with BasicFloatNetwork

use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in project shifu by ShifuML.

From the class NNMaster, method initOrRecoverParams.

/**
 * Initializes the NN weights, recovering them from an existing model when one is found
 * under the guagua output path (recover or continuous-training mode).
 *
 * @param context the guagua master context whose properties carry the model output path
 * @return the initialized {@link NNParams}, with weights copied or fitted from the
 *         existing model when its structure is compatible
 * @throws GuaguaRuntimeException when model loading fails, or when the new network
 *             structure is smaller than the existing one and cannot hold its weights
 */
private NNParams initOrRecoverParams(MasterContext<NNParams, NNParams> context) {
    NNParams params = null;
    try {
        // an existing model, if any, lives under the configured guagua output path
        Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
        BasicML basicML = ModelSpecLoaderUtils.loadModel(modelConfig, modelPath,
                ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
        params = initWeights();
        BasicFloatNetwork existingModel = (BasicFloatNetwork) ModelSpecLoaderUtils.getBasicNetwork(basicML);
        if (existingModel == null) {
            LOG.info("Starting to train model from scratch.");
            return params;
        }
        LOG.info("Starting to train model from existing model {}.", modelPath);
        int structCompare = new NNStructureComparator().compare(this.flatNetwork, existingModel.getFlat());
        if (structCompare == 0) {
            // identical structure: reuse the existing weights as-is
            params.setWeights(existingModel.getFlat().getWeights());
            this.fixedWeightIndexSet = getFixedWights(fixedLayers);
        } else if (structCompare == 1) {
            // new structure strictly contains the existing one: embed the old weights into it
            this.fixedWeightIndexSet = fitExistingModelIn(existingModel.getFlat(), this.flatNetwork,
                    this.fixedLayers, this.fixedBias);
        } else {
            // new structure is smaller and cannot hold the existing network
            throw new GuaguaRuntimeException("Network changed for recover or continuous training. "
                    + "New network couldn't hold existing network!");
        }
    } catch (IOException e) {
        throw new GuaguaRuntimeException(e);
    }
    return params;
}
Also used : Path(org.apache.hadoop.fs.Path) BasicML(org.encog.ml.BasicML) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) GuaguaRuntimeException(ml.shifu.guagua.GuaguaRuntimeException) IOException(java.io.IOException)

Example 7 with BasicFloatNetwork

use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in project shifu by ShifuML.

From the class IndependentNNModel, method loadFromStream.

/**
 * Load model instance from input stream which is saved in NNOutput for specified binary format.
 *
 * <p>The stream layout read here is: an optional gzip wrapper, then the format version
 * ({@code int}), the normalization type name ({@code String}), the column count followed by
 * one {@link NNColumnStats} record per column, the column index mapping, and finally the
 * serialized networks.
 *
 * @param input
 *            the input stream, flat input stream or gzip input stream both OK
 * @param isRemoveNameSpace
 *            is remove name space or not
 * @return the nn model instance
 * @throws IOException
 *             any IOException in de-serialization.
 */
public static IndependentNNModel loadFromStream(InputStream input, boolean isRemoveNameSpace) throws IOException {
    DataInputStream dis = null;
    // check if gzip or not
    try {
        // peek at the first two bytes without consuming them (mark/reset) to detect gzip
        byte[] header = new byte[2];
        BufferedInputStream bis = new BufferedInputStream(input);
        bis.mark(2);
        int result = bis.read(header);
        bis.reset();
        // gzip magic number is stored little-endian, hence low byte first
        int ss = (header[0] & 0xff) | ((header[1] & 0xff) << 8);
        if (result != -1 && ss == GZIPInputStream.GZIP_MAGIC) {
            dis = new DataInputStream(new GZIPInputStream(bis));
        } else {
            dis = new DataInputStream(bis);
        }
    } catch (java.io.IOException e) {
        // NOTE(review): this fallback assumes nothing was consumed before the failure;
        // any bytes already read into the buffered wrapper are lost here — verify.
        dis = new DataInputStream(input);
    }
    int version = dis.readInt();
    IndependentNNModel.setVersion(version);
    String normStr = ml.shifu.shifu.core.dtrain.StringUtils.readString(dis);
    NormType normType = NormType.valueOf(normStr.toUpperCase());
    // for all features
    Map<Integer, String> numNameMap = new HashMap<Integer, String>();
    Map<Integer, List<String>> cateColumnNameNames = new HashMap<Integer, List<String>>();
    // for categorical features
    Map<Integer, Map<String, Double>> cateWoeMap = new HashMap<Integer, Map<String, Double>>();
    Map<Integer, Map<String, Double>> cateWgtWoeMap = new HashMap<Integer, Map<String, Double>>();
    Map<Integer, Map<String, Double>> binPosRateMap = new HashMap<Integer, Map<String, Double>>();
    // for numerical features
    Map<Integer, List<Double>> numerBinBoundaries = new HashMap<Integer, List<Double>>();
    Map<Integer, List<Double>> numerWoes = new HashMap<Integer, List<Double>>();
    Map<Integer, List<Double>> numerWgtWoes = new HashMap<Integer, List<Double>>();
    // for all features
    Map<Integer, Double> numerMeanMap = new HashMap<Integer, Double>();
    Map<Integer, Double> numerStddevMap = new HashMap<Integer, Double>();
    Map<Integer, Double> woeMeanMap = new HashMap<Integer, Double>();
    Map<Integer, Double> woeStddevMap = new HashMap<Integer, Double>();
    Map<Integer, Double> wgtWoeMeanMap = new HashMap<Integer, Double>();
    Map<Integer, Double> wgtWoeStddevMap = new HashMap<Integer, Double>();
    Map<Integer, Double> cutoffMap = new HashMap<Integer, Double>();
    Map<Integer, ColumnType> columnTypeMap = new HashMap<Integer, ColumnType>();
    Map<Integer, Map<String, Integer>> cateIndexMapping = new HashMap<Integer, Map<String, Integer>>();
    int columnSize = dis.readInt();
    for (int i = 0; i < columnSize; i++) {
        NNColumnStats cs = new NNColumnStats();
        cs.readFields(dis);
        List<Double> binWoes = cs.getBinCountWoes();
        List<Double> binWgtWoes = cs.getBinWeightWoes();
        List<Double> binPosRates = cs.getBinPosRates();
        int columnNum = cs.getColumnNum();
        columnTypeMap.put(columnNum, cs.getColumnType());
        if (isRemoveNameSpace) {
            // remove name-space in column name to make it be called by simple name
            numNameMap.put(columnNum, StringUtils.getSimpleColumnName(cs.getColumnName()));
        } else {
            numNameMap.put(columnNum, cs.getColumnName());
        }
        // for categorical features
        Map<String, Double> woeMap = new HashMap<String, Double>();
        Map<String, Double> woeWgtMap = new HashMap<String, Double>();
        Map<String, Double> posRateMap = new HashMap<String, Double>();
        Map<String, Integer> cateIndexMap = new HashMap<String, Integer>();
        if (cs.isCategorical() || cs.isHybrid()) {
            List<String> binCategories = cs.getBinCategories();
            cateColumnNameNames.put(columnNum, binCategories);
            for (int j = 0; j < binCategories.size(); j++) {
                String currCate = binCategories.get(j);
                if (currCate.contains(Constants.CATEGORICAL_GROUP_VAL_DELIMITER)) {
                    // merged category should be flatten, use own split function to avoid depending on guava jar in
                    // prediction; every member of the merged group maps to the same bin stats
                    String[] splits = StringUtils.split(currCate, Constants.CATEGORICAL_GROUP_VAL_DELIMITER);
                    for (String str : splits) {
                        woeMap.put(str, binWoes.get(j));
                        woeWgtMap.put(str, binWgtWoes.get(j));
                        posRateMap.put(str, binPosRates.get(j));
                        cateIndexMap.put(str, j);
                    }
                } else {
                    woeMap.put(currCate, binWoes.get(j));
                    woeWgtMap.put(currCate, binWgtWoes.get(j));
                    posRateMap.put(currCate, binPosRates.get(j));
                    cateIndexMap.put(currCate, j);
                }
            }
            // append last missing bin: the woe/rate lists carry one trailing entry at index
            // binCategories.size() used for the empty/missing category
            woeMap.put(Constants.EMPTY_CATEGORY, binWoes.get(binCategories.size()));
            woeWgtMap.put(Constants.EMPTY_CATEGORY, binWgtWoes.get(binCategories.size()));
            posRateMap.put(Constants.EMPTY_CATEGORY, binPosRates.get(binCategories.size()));
        }
        if (cs.isNumerical() || cs.isHybrid()) {
            numerBinBoundaries.put(columnNum, cs.getBinBoundaries());
            numerWoes.put(columnNum, binWoes);
            numerWgtWoes.put(columnNum, binWgtWoes);
        }
        cateWoeMap.put(columnNum, woeMap);
        cateWgtWoeMap.put(columnNum, woeWgtMap);
        binPosRateMap.put(columnNum, posRateMap);
        cateIndexMapping.put(columnNum, cateIndexMap);
        numerMeanMap.put(columnNum, cs.getMean());
        numerStddevMap.put(columnNum, cs.getStddev());
        woeMeanMap.put(columnNum, cs.getWoeMean());
        woeStddevMap.put(columnNum, cs.getWoeStddev());
        wgtWoeMeanMap.put(columnNum, cs.getWoeWgtMean());
        wgtWoeStddevMap.put(columnNum, cs.getWoeWgtStddev());
        cutoffMap.put(columnNum, cs.getCutoff());
    }
    // column index mapping: pairs of ints read back-to-back
    Map<Integer, Integer> columnMap = new HashMap<Integer, Integer>();
    int columnMapSize = dis.readInt();
    for (int i = 0; i < columnMapSize; i++) {
        columnMap.put(dis.readInt(), dis.readInt());
    }
    // finally the serialized networks (bagging may produce more than one)
    int size = dis.readInt();
    List<BasicFloatNetwork> networks = new ArrayList<BasicFloatNetwork>();
    for (int i = 0; i < size; i++) {
        networks.add(new PersistBasicFloatNetwork().readNetwork(dis));
    }
    // NOTE(review): numerWgtWoes is passed before numerWoes here — confirm this matches
    // the constructor's parameter order, since the declaration order above is reversed.
    return new IndependentNNModel(networks, normType, numNameMap, cateColumnNameNames, columnMap, cateWoeMap, cateWgtWoeMap, binPosRateMap, numerBinBoundaries, numerWgtWoes, numerWoes, cutoffMap, numerMeanMap, numerStddevMap, woeMeanMap, woeStddevMap, wgtWoeMeanMap, wgtWoeStddevMap, columnTypeMap, cateIndexMapping);
}
Also used : ColumnType(ml.shifu.shifu.container.obj.ColumnType) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) GZIPInputStream(java.util.zip.GZIPInputStream) BufferedInputStream(java.io.BufferedInputStream) ArrayList(java.util.ArrayList) List(java.util.List) PersistBasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) NormType(ml.shifu.shifu.container.obj.ModelNormalizeConf.NormType) IOException(java.io.IOException) DataInputStream(java.io.DataInputStream) HashMap(java.util.HashMap) Map(java.util.Map) PersistBasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork)

Example 8 with BasicFloatNetwork

use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in project shifu by ShifuML.

From the class ZscoreLocalTransformCreator, method build.

@Override
public LocalTransformations build(BasicML basicML) {
    LocalTransformations localTransformations = new LocalTransformations();
    if (basicML instanceof BasicFloatNetwork) {
        BasicFloatNetwork bfn = (BasicFloatNetwork) basicML;
        Set<Integer> featureSet = bfn.getFeatureSet();
        for (ColumnConfig config : columnConfigList) {
            if (config.isFinalSelect() && (CollectionUtils.isEmpty(featureSet) || featureSet.contains(config.getColumnNum()))) {
                double cutoff = modelConfig.getNormalizeStdDevCutOff();
                List<DerivedField> deriviedFields = config.isCategorical() ? createCategoricalDerivedField(config, cutoff, modelConfig.getNormalizeType()) : createNumericalDerivedField(config, cutoff, modelConfig.getNormalizeType());
                localTransformations.addDerivedFields(deriviedFields.toArray(new DerivedField[deriviedFields.size()]));
            }
        }
    } else {
        for (ColumnConfig config : columnConfigList) {
            if (config.isFinalSelect()) {
                double cutoff = modelConfig.getNormalizeStdDevCutOff();
                List<DerivedField> deriviedFields = config.isCategorical() ? createCategoricalDerivedField(config, cutoff, modelConfig.getNormalizeType()) : createNumericalDerivedField(config, cutoff, modelConfig.getNormalizeType());
                localTransformations.addDerivedFields(deriviedFields.toArray(new DerivedField[deriviedFields.size()]));
            }
        }
    }
    return localTransformations;
}
Also used : LocalTransformations(org.dmg.pmml.LocalTransformations) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) DerivedField(org.dmg.pmml.DerivedField)

Example 9 with BasicFloatNetwork

use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in project shifu by ShifuML.

From the class MiningSchemaCreator, method build.

@Override
public MiningSchema build(BasicML basicML) {
    MiningSchema miningSchema = new MiningSchema();
    boolean isSegExpansionMode = columnConfigList.size() > datasetHeaders.length;
    int segSize = segmentExpansions.size();
    if (basicML != null && basicML instanceof BasicFloatNetwork) {
        BasicFloatNetwork bfn = (BasicFloatNetwork) basicML;
        Set<Integer> featureSet = bfn.getFeatureSet();
        for (ColumnConfig columnConfig : columnConfigList) {
            if (columnConfig.getColumnNum() >= datasetHeaders.length) {
                // in order
                break;
            }
            if (isActiveColumn(featureSet, columnConfig)) {
                if (columnConfig.isTarget()) {
                    List<MiningField> miningFields = createTargetMingFields(columnConfig);
                    miningSchema.addMiningFields(miningFields.toArray(new MiningField[miningFields.size()]));
                } else {
                    miningSchema.addMiningFields(createActiveMingFields(columnConfig));
                }
            } else if (isSegExpansionMode) {
                // even current column not selected, if segment column selected, we should keep raw column
                for (int i = 0; i < segSize; i++) {
                    int newIndex = datasetHeaders.length * (i + 1) + columnConfig.getColumnNum();
                    ColumnConfig cc = columnConfigList.get(newIndex);
                    if (cc.isFinalSelect()) {
                        // if one segment feature is selected, we should put raw column in
                        if (columnConfig.isTarget()) {
                            List<MiningField> miningFields = createTargetMingFields(columnConfig);
                            miningSchema.addMiningFields(miningFields.toArray(new MiningField[miningFields.size()]));
                        } else {
                            miningSchema.addMiningFields(createActiveMingFields(columnConfig));
                        }
                        break;
                    }
                }
            }
        }
    } else {
        for (ColumnConfig columnConfig : columnConfigList) {
            if (columnConfig.getColumnNum() >= datasetHeaders.length) {
                // in order
                break;
            }
            // FIXME, if no variable is selected
            if (columnConfig.isFinalSelect() || columnConfig.isTarget()) {
                if (columnConfig.isTarget()) {
                    List<MiningField> miningFields = createTargetMingFields(columnConfig);
                    miningSchema.addMiningFields(miningFields.toArray(new MiningField[miningFields.size()]));
                } else {
                    miningSchema.addMiningFields(createActiveMingFields(columnConfig));
                }
            } else if (isSegExpansionMode) {
                // even current column not selected, if segment column selected, we should keep raw column
                for (int i = 0; i < segSize; i++) {
                    int newIndex = datasetHeaders.length * (i + 1) + columnConfig.getColumnNum();
                    ColumnConfig cc = columnConfigList.get(newIndex);
                    if (cc.isFinalSelect()) {
                        // if one segment feature is selected, we should put raw column in
                        if (columnConfig.isTarget()) {
                            List<MiningField> miningFields = createTargetMingFields(columnConfig);
                            miningSchema.addMiningFields(miningFields.toArray(new MiningField[miningFields.size()]));
                        } else {
                            miningSchema.addMiningFields(createActiveMingFields(columnConfig));
                        }
                        break;
                    }
                }
            }
        }
    }
    return miningSchema;
}
Also used : MiningField(org.dmg.pmml.MiningField) MiningSchema(org.dmg.pmml.MiningSchema) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ArrayList(java.util.ArrayList) List(java.util.List) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)

Example 10 with BasicFloatNetwork

use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in project shifu by ShifuML.

From the class DTrainUtils, method generateNetwork.

/**
 * Builds a {@link BasicFloatNetwork}: a linear input layer (with optional dropout),
 * {@code numLayers} hidden dropout layers whose activations are chosen by name from
 * {@code actFunc}, and an output layer whose shape depends on {@code isLinearTarget}.
 *
 * @param in number of input neurons
 * @param out number of output neurons
 * @param numLayers number of hidden layers; {@code actFunc} and {@code hiddenNodeList}
 *            must each supply at least this many entries
 * @param actFunc per-hidden-layer activation names (case-insensitive; unrecognized names
 *            fall back to sigmoid)
 * @param hiddenNodeList per-hidden-layer neuron counts
 * @param isRandomizeWeights whether to randomize weights after the structure is finalized
 * @param dropoutRate dropout rate applied to hidden layers (input layer uses 40% of it
 *            when input-layer dropout is enabled)
 * @param wgtInit weight initializer name; null/empty or unrecognized values use the
 *            default {@code network.reset()} randomization
 * @param isLinearTarget when true, the output layer has bias and a (near-)linear
 *            activation; when false it is a bias-less sigmoid layer
 * @param outputActivationFunc output activation name, only consulted when
 *            {@code isLinearTarget} is true
 * @return the constructed (and optionally randomized) network
 */
public static BasicNetwork generateNetwork(int in, int out, int numLayers, List<String> actFunc, List<Integer> hiddenNodeList, boolean isRandomizeWeights, double dropoutRate, String wgtInit, boolean isLinearTarget, String outputActivationFunc) {
    final BasicFloatNetwork network = new BasicFloatNetwork();
    // in shifuconfig, we have a switch to control enable input layer dropout
    if (Boolean.valueOf(Environment.getProperty(CommonConstants.SHIFU_TRAIN_NN_INPUTLAYERDROPOUT_ENABLE, "true"))) {
        // we need to guarantee that input layer dropout rate is 40% of hiddenlayer dropout rate
        network.addLayer(new BasicDropoutLayer(new ActivationLinear(), true, in, dropoutRate * 0.4d));
    } else {
        network.addLayer(new BasicDropoutLayer(new ActivationLinear(), true, in, 0d));
    }
    // map each hidden layer's activation name onto the corresponding encog activation
    for (int i = 0; i < numLayers; i++) {
        String func = actFunc.get(i);
        Integer numHiddenNode = hiddenNodeList.get(i);
        if (func.equalsIgnoreCase(NNConstants.NN_LINEAR)) {
            network.addLayer(new BasicDropoutLayer(new ActivationLinear(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_SIGMOID)) {
            network.addLayer(new BasicDropoutLayer(new ActivationSigmoid(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_TANH)) {
            network.addLayer(new BasicDropoutLayer(new ActivationTANH(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_LOG)) {
            network.addLayer(new BasicDropoutLayer(new ActivationLOG(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_SIN)) {
            network.addLayer(new BasicDropoutLayer(new ActivationSIN(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_RELU)) {
            network.addLayer(new BasicDropoutLayer(new ActivationReLU(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_LEAKY_RELU)) {
            network.addLayer(new BasicDropoutLayer(new ActivationLeakyReLU(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_SWISH)) {
            network.addLayer(new BasicDropoutLayer(new ActivationSwish(), true, numHiddenNode, dropoutRate));
        } else if (func.equalsIgnoreCase(NNConstants.NN_PTANH)) {
            network.addLayer(new BasicDropoutLayer(new ActivationPTANH(), true, numHiddenNode, dropoutRate));
        } else {
            // unrecognized activation name: fall back to sigmoid
            network.addLayer(new BasicDropoutLayer(new ActivationSigmoid(), true, numHiddenNode, dropoutRate));
        }
    }
    if (isLinearTarget) {
        // regression-style output: biased layer with (near-)linear activation
        if (NNConstants.NN_RELU.equalsIgnoreCase(outputActivationFunc)) {
            network.addLayer(new BasicLayer(new ActivationReLU(), true, out));
        } else if (NNConstants.NN_LEAKY_RELU.equalsIgnoreCase(outputActivationFunc)) {
            network.addLayer(new BasicLayer(new ActivationLeakyReLU(), true, out));
        } else if (NNConstants.NN_SWISH.equalsIgnoreCase(outputActivationFunc)) {
            network.addLayer(new BasicLayer(new ActivationSwish(), true, out));
        } else {
            network.addLayer(new BasicLayer(new ActivationLinear(), true, out));
        }
    } else {
        // classification-style output: bias-less sigmoid layer
        network.addLayer(new BasicLayer(new ActivationSigmoid(), false, out));
    }
    // float structure needs its own finalization path when available
    NeuralStructure structure = network.getStructure();
    if (network.getStructure() instanceof FloatNeuralStructure) {
        ((FloatNeuralStructure) structure).finalizeStruct();
    } else {
        structure.finalizeStructure();
    }
    if (isRandomizeWeights) {
        if (wgtInit == null || wgtInit.length() == 0) {
            // default randomization
            network.reset();
        } else if (wgtInit.equalsIgnoreCase(WGT_INIT_GAUSSIAN)) {
            new GaussianRandomizer(0, 1).randomize(network);
        } else if (wgtInit.equalsIgnoreCase(WGT_INIT_XAVIER)) {
            new XavierWeightRandomizer().randomize(network);
        } else if (wgtInit.equalsIgnoreCase(WGT_INIT_HE)) {
            new HeWeightRandomizer().randomize(network);
        } else if (wgtInit.equalsIgnoreCase(WGT_INIT_LECUN)) {
            new LecunWeightRandomizer().randomize(network);
        } else if (wgtInit.equalsIgnoreCase(WGT_INIT_DEFAULT)) {
            // default randomization
            network.reset();
        } else {
            // default randomization
            network.reset();
        }
    }
    return network;
}
Also used : LecunWeightRandomizer(ml.shifu.shifu.core.dtrain.random.LecunWeightRandomizer) XavierWeightRandomizer(ml.shifu.shifu.core.dtrain.random.XavierWeightRandomizer) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) FloatNeuralStructure(ml.shifu.shifu.core.dtrain.dataset.FloatNeuralStructure) BasicLayer(org.encog.neural.networks.layers.BasicLayer) GaussianRandomizer(org.encog.mathutil.randomize.GaussianRandomizer) HeWeightRandomizer(ml.shifu.shifu.core.dtrain.random.HeWeightRandomizer) FloatNeuralStructure(ml.shifu.shifu.core.dtrain.dataset.FloatNeuralStructure) NeuralStructure(org.encog.neural.networks.structure.NeuralStructure)

Aggregations

BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)13 ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)5 ArrayList (java.util.ArrayList)4 PersistBasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork)4 BasicML (org.encog.ml.BasicML)4 List (java.util.List)3 BasicMLData (org.encog.ml.data.basic.BasicMLData)3 IOException (java.io.IOException)2 Path (org.apache.hadoop.fs.Path)2 RequiredFieldList (org.apache.pig.LoadPushDown.RequiredFieldList)2 BufferedInputStream (java.io.BufferedInputStream)1 DataInputStream (java.io.DataInputStream)1 Comparator (java.util.Comparator)1 HashMap (java.util.HashMap)1 Map (java.util.Map)1 Callable (java.util.concurrent.Callable)1 GZIPInputStream (java.util.zip.GZIPInputStream)1 GuaguaRuntimeException (ml.shifu.guagua.GuaguaRuntimeException)1 GuaguaMapReduceClient (ml.shifu.guagua.mapreduce.GuaguaMapReduceClient)1 ScoreObject (ml.shifu.shifu.container.ScoreObject)1