Search in sources:

Example 21 with BasicML

Use of org.encog.ml.BasicML in project shifu by ShifuML.

In class NNModelSpecTest, method testFitExistingModelIn.

@Test
public void testFitExistingModelIn() {
    // Load model0 and fit it into itself, first with input list {6}, then with {1}.
    BasicML baseModel = BasicML.class.cast(EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model0.nn")));
    FlatNetwork baseFlat = ((BasicNetwork) baseModel).getFlat();
    NNMaster nnMaster = new NNMaster();

    List<Integer> fixedIndices = new ArrayList<Integer>(
            nnMaster.fitExistingModelIn(baseFlat, baseFlat, Arrays.asList(new Integer[] { 6 })));
    Collections.sort(fixedIndices);
    Assert.assertEquals(fixedIndices.size(), 31);

    fixedIndices = new ArrayList<Integer>(
            nnMaster.fitExistingModelIn(baseFlat, baseFlat, Arrays.asList(new Integer[] { 1 })));
    Collections.sort(fixedIndices);
    Assert.assertEquals(fixedIndices.size(), 930);

    // Load model1 and fit model0 into it, using the three-argument overload and then the
    // four-argument overload with 'false'.
    BasicML extendedModel = BasicML.class.cast(EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model1.nn")));
    FlatNetwork extendedFlat = ((BasicNetwork) extendedModel).getFlat();

    fixedIndices = new ArrayList<Integer>(
            nnMaster.fitExistingModelIn(baseFlat, extendedFlat, Arrays.asList(new Integer[] { 1 })));
    Collections.sort(fixedIndices);
    Assert.assertEquals(fixedIndices.size(), 930);

    fixedIndices = new ArrayList<Integer>(
            nnMaster.fitExistingModelIn(baseFlat, extendedFlat, Arrays.asList(new Integer[] { 1 }), false));
    Collections.sort(fixedIndices);
    Assert.assertEquals(fixedIndices.size(), 900);
}
Also used: FlatNetwork (org.encog.neural.flat.FlatNetwork), NNMaster (ml.shifu.shifu.core.dtrain.nn.NNMaster), BasicNetwork (org.encog.neural.networks.BasicNetwork), ArrayList (java.util.ArrayList), BasicML (org.encog.ml.BasicML), File (java.io.File), Test (org.testng.annotations.Test)

Example 22 with BasicML

Use of org.encog.ml.BasicML in project shifu by ShifuML.

In class BinaryNNSerializer, method save.

/**
 * Serializes NN models into a single GZIP-compressed binary file, together with the normalization settings and
 * per-column statistics they need.
 * <p>
 * Written in order:
 * <ol>
 * <li>the format version;</li>
 * <li>the normalization type name;</li>
 * <li>the number of column stats, followed by each {@link NNColumnStats};</li>
 * <li>the size of the column mapping, followed by its key/value int pairs;</li>
 * <li>the number of networks, followed by each network.</li>
 * </ol>
 *
 * @param modelConfig model configuration; the normalization type and std-dev cutoff are read from it
 * @param columnConfigList column configs; only columns whose number is a key of {@code getIndexNameMapping(...)}
 *            get their stats written
 * @param basicNetworks networks to persist; each is cast to {@link BasicFloatNetwork}
 * @param fs file system on which {@code output} is created
 * @param output destination path
 * @throws IOException if creating, writing, or closing the output fails. Closing writes the GZIP trailer, so a
 *             close failure also means the file is incomplete.
 */
public static void save(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, List<BasicML> basicNetworks, FileSystem fs, Path output) throws IOException {
    DataOutputStream fos = null;
    try {
        fos = new DataOutputStream(new GZIPOutputStream(fs.create(output)));
        // version
        fos.writeInt(CommonConstants.NN_FORMAT_VERSION);
        // write normStr
        String normStr = modelConfig.getNormalize().getNormType().toString();
        ml.shifu.shifu.core.dtrain.StringUtils.writeString(fos, normStr);
        // compute columns needed
        Map<Integer, String> columnIndexNameMapping = getIndexNameMapping(columnConfigList);
        // write column stats to output
        List<NNColumnStats> csList = new ArrayList<NNColumnStats>();
        for (ColumnConfig cc : columnConfigList) {
            if (columnIndexNameMapping.containsKey(cc.getColumnNum())) {
                NNColumnStats cs = new NNColumnStats();
                cs.setCutoff(modelConfig.getNormalizeStdDevCutOff());
                cs.setColumnType(cc.getColumnType());
                cs.setMean(cc.getMean());
                cs.setStddev(cc.getStdDev());
                cs.setColumnNum(cc.getColumnNum());
                cs.setColumnName(cc.getColumnName());
                cs.setBinCategories(cc.getBinCategory());
                cs.setBinBoundaries(cc.getBinBoundary());
                cs.setBinPosRates(cc.getBinPosRate());
                cs.setBinCountWoes(cc.getBinCountWoe());
                cs.setBinWeightWoes(cc.getBinWeightedWoe());
                // TODO cache such computation
                double[] meanAndStdDev = Normalizer.calculateWoeMeanAndStdDev(cc, false);
                cs.setWoeMean(meanAndStdDev[0]);
                cs.setWoeStddev(meanAndStdDev[1]);
                double[] wgtMeanAndStdDev = Normalizer.calculateWoeMeanAndStdDev(cc, true);
                cs.setWoeWgtMean(wgtMeanAndStdDev[0]);
                cs.setWoeWgtStddev(wgtMeanAndStdDev[1]);
                csList.add(cs);
            }
        }
        fos.writeInt(csList.size());
        for (NNColumnStats cs : csList) {
            cs.write(fos);
        }
        // write column index mapping
        Map<Integer, Integer> columnMapping = getColumnMapping(columnConfigList);
        fos.writeInt(columnMapping.size());
        for (Entry<Integer, Integer> entry : columnMapping.entrySet()) {
            fos.writeInt(entry.getKey());
            fos.writeInt(entry.getValue());
        }
        // persist network, set it as list
        fos.writeInt(basicNetworks.size());
        for (BasicML network : basicNetworks) {
            new PersistBasicFloatNetwork().saveNetwork(fos, (BasicFloatNetwork) network);
        }
        // Close explicitly on the success path. Closing finishes the GZIP stream (trailer) and flushes to the
        // file system, so a failure here must reach the caller. IOUtils.closeStream ignores close errors, and
        // relying on it alone could silently leave a truncated model file behind.
        fos.close();
        fos = null;
    } finally {
        // Best-effort cleanup; fos is non-null only if an exception was thrown before the explicit close.
        IOUtils.closeStream(fos);
    }
}
Also used: ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig), DataOutputStream (java.io.DataOutputStream), ArrayList (java.util.ArrayList), BasicML (org.encog.ml.BasicML), GZIPOutputStream (java.util.zip.GZIPOutputStream), PersistBasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork)

Example 23 with BasicML

Use of org.encog.ml.BasicML in project shifu by ShifuML.

In class EvalNormUDF, method exec.

/**
 * Normalizes one input record and, when {@code isAppendScore} is set, appends a model score.
 * <p>
 * Output fields follow {@code outputNames}:
 * <ul>
 * <li>position 0 is copied raw;</li>
 * <li>position 1 is copied raw, but defaults to {@code "1"} when empty;</li>
 * <li>positions {@code [2, 2 + validMetaSize)} are meta columns, copied raw;</li>
 * <li>every later column is normalized with {@code Normalizer.normalize}, preceded by its raw value when
 * {@code isOutputRaw} is set.</li>
 * </ul>
 * When scoring, the last field is:
 * <ul>
 * <li>{@code -999.0} if no models are loaded or no score was produced;</li>
 * <li>otherwise the average score if {@code scIndex < 0};</li>
 * <li>otherwise the score at {@code scIndex}.</li>
 * </ul>
 *
 * @param input the Pig input tuple for one record
 * @return the output tuple, or {@code null} in two cases: in CSV mode when the first value equals the first
 *         header name (presumably a header line), or when the record converts to an empty column map
 * @throws IOException propagated from tuple access or the helpers called here
 */
public Tuple exec(Tuple input) throws IOException {
    if (isCsvFormat) {
        String firstCol = ((input.get(0) == null) ? "" : input.get(0).toString());
        if (this.headers[0].equals(CommonUtils.normColumnName(firstCol))) {
            // TODO what to do if the column value == column name? ...
            return null;
        }
    }
    if (this.modelRunner == null && this.isAppendScore) {
        // here to initialize modelRunner, this is moved from constructor to here to avoid OOM in client side.
        // UDF in pig client will be initialized to get some metadata issues
        @SuppressWarnings("deprecation") List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelConfig, evalConfig, evalConfig.getDataSet().getSource(), evalConfig.getGbtConvertToProb(), evalConfig.getGbtScoreConvertStrategy());
        this.modelRunner = new ModelRunner(modelConfig, columnConfigList, this.headers, evalConfig.getDataSet().getDataDelimiter(), models);
        this.modelRunner.setScoreScale(Integer.parseInt(this.scale));
    }
    Map<NSColumn, String> rawDataNsMap = CommonUtils.convertDataIntoNsMap(input, this.headers, this.segFilterSize);
    if (MapUtils.isEmpty(rawDataNsMap)) {
        return null;
    }
    Tuple tuple = TupleFactory.getInstance().newTuple();
    for (int i = 0; i < this.outputNames.size(); i++) {
        String name = this.outputNames.get(i);
        String raw = rawDataNsMap.get(new NSColumn(name));
        if (i == 0) {
            tuple.append(raw);
        } else if (i == 1) {
            tuple.append(StringUtils.isEmpty(raw) ? "1" : raw);
        } else if (i < 2 + validMetaSize) {
            // [2, 2 + validMetaSize) are meta columns; i > 1 is already implied by the branches above
            tuple.append(raw);
        } else {
            ColumnConfig columnConfig = this.columnConfigMap.get(name);
            List<Double> normVals = Normalizer.normalize(columnConfig, raw, this.modelConfig.getNormalizeStdDevCutOff(), this.modelConfig.getNormalizeType());
            if (this.isOutputRaw) {
                tuple.append(raw);
            }
            for (Double normVal : normVals) {
                tuple.append(getOutputValue(normVal, true));
            }
        }
    }
    if (this.isAppendScore && this.modelRunner != null) {
        CaseScoreResult score = this.modelRunner.computeNsData(rawDataNsMap);
        // modelRunner is known non-null here (guarded above), so only model count and score are checked.
        if (this.modelRunner.getModelsCnt() == 0 || score == null) {
            tuple.append(-999.0);
        } else if (this.scIndex < 0) {
            tuple.append(score.getAvgScore());
        } else {
            tuple.append(score.getScores().get(this.scIndex));
        }
    }
    return tuple;
}
Also used: ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig), BasicML (org.encog.ml.BasicML), CaseScoreResult (ml.shifu.shifu.container.CaseScoreResult), Tuple (org.apache.pig.data.Tuple), ModelRunner (ml.shifu.shifu.core.ModelRunner), NSColumn (ml.shifu.shifu.column.NSColumn)

Aggregations

BasicML (org.encog.ml.BasicML): 23, File (java.io.File): 6, BasicNetwork (org.encog.neural.networks.BasicNetwork): 5, IOException (java.io.IOException): 4, ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig): 4, BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork): 4, PersistBasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork): 4, FileSystem (org.apache.hadoop.fs.FileSystem): 4, FlatNetwork (org.encog.neural.flat.FlatNetwork): 4, ArrayList (java.util.ArrayList): 3, NSColumn (ml.shifu.shifu.column.NSColumn): 3, ModelRunner (ml.shifu.shifu.core.ModelRunner): 3, ModelSpec (ml.shifu.shifu.core.model.ModelSpec): 3, MutablePair (org.apache.commons.lang3.tuple.MutablePair): 3, Configuration (org.apache.hadoop.conf.Configuration): 3, FileStatus (org.apache.hadoop.fs.FileStatus): 3, Path (org.apache.hadoop.fs.Path): 3, JarFile (java.util.jar.JarFile): 2, CaseScoreResult (ml.shifu.shifu.container.CaseScoreResult): 2, SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType): 2