Usage of org.encog.ml.BasicML in the project shifu (by ShifuML).
Class NNModelSpecTest, method testFitExistingModelIn:
/**
 * Checks the number of weight indexes {@link NNMaster#fitExistingModelIn} reports as fixed when fitting the
 * network from {@code model0.nn} into itself and into the network from {@code model1.nn}, for different
 * input index lists.
 *
 * <p>Assertions use JUnit's (expected, actual) argument order so failure messages report the values correctly.
 */
@Test
public void testFitExistingModelIn() {
    BasicML basicML = BasicML.class.cast(EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model0.nn")));
    BasicNetwork basicNetwork = (BasicNetwork) basicML;
    FlatNetwork flatNetwork = basicNetwork.getFlat();
    NNMaster master = new NNMaster();

    // same network on both sides, index list {6}
    Set<Integer> fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, flatNetwork, Collections.singletonList(6));
    Assert.assertEquals(31, fixedWeightIndexSet.size());

    // same network on both sides, index list {1}
    fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, flatNetwork, Collections.singletonList(1));
    Assert.assertEquals(930, fixedWeightIndexSet.size());

    BasicML extendedBasicML = BasicML.class.cast(EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model1.nn")));
    BasicNetwork extendedBasicNetwork = (BasicNetwork) extendedBasicML;
    FlatNetwork extendedFlatNetwork = extendedBasicNetwork.getFlat();

    // model0 fitted into model1, index list {1}
    fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, extendedFlatNetwork, Collections.singletonList(1));
    Assert.assertEquals(930, fixedWeightIndexSet.size());

    // same fit with the extra boolean overload set to false; NOTE(review): confirm the flag's meaning in NNMaster
    fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, extendedFlatNetwork, Collections.singletonList(1), false);
    Assert.assertEquals(900, fixedWeightIndexSet.size());
}
Usage of org.encog.ml.BasicML in the project shifu (by ShifuML).
Class BinaryNNSerializer, method save:
/**
 * Writes the normalization settings, per-column stats, column index mapping and the given networks to
 * {@code output} as a gzip-compressed binary stream.
 *
 * <p>Written in order: {@code CommonConstants.NN_FORMAT_VERSION}; the normalization type name; the number of
 * column-stats records followed by each record; the number of column-mapping entries followed by each
 * (key, value) int pair; the number of networks followed by each network.
 *
 * @param modelConfig model config supplying the normalization type and std-dev cutoff
 * @param columnConfigList column configs; only columns present in {@code getIndexNameMapping} are written
 * @param basicNetworks networks to persist; each element is cast to {@code BasicFloatNetwork}
 * @param fs file system used to create {@code output}
 * @param output destination path
 * @throws IOException if creating, writing, or closing the output stream fails
 */
public static void save(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, List<BasicML> basicNetworks, FileSystem fs, Path output) throws IOException {
    DataOutputStream fos = null;
    try {
        fos = new DataOutputStream(new GZIPOutputStream(fs.create(output)));
        // version
        fos.writeInt(CommonConstants.NN_FORMAT_VERSION);
        // write normStr
        String normStr = modelConfig.getNormalize().getNormType().toString();
        ml.shifu.shifu.core.dtrain.StringUtils.writeString(fos, normStr);

        // write column stats to output
        List<NNColumnStats> csList = toNNColumnStatsList(modelConfig, columnConfigList);
        fos.writeInt(csList.size());
        for (NNColumnStats cs : csList) {
            cs.write(fos);
        }

        // write column index mapping
        Map<Integer, Integer> columnMapping = getColumnMapping(columnConfigList);
        fos.writeInt(columnMapping.size());
        for (Entry<Integer, Integer> entry : columnMapping.entrySet()) {
            fos.writeInt(entry.getKey());
            fos.writeInt(entry.getValue());
        }

        // persist network, set it as list
        fos.writeInt(basicNetworks.size());
        for (BasicML network : basicNetworks) {
            new PersistBasicFloatNetwork().saveNetwork(fos, (BasicFloatNetwork) network);
        }

        // Close explicitly on the success path: closing finishes the gzip stream, and errors raised here must
        // reach the caller. IOUtils.closeStream (Hadoop) ignores such errors, which would hide a truncated file.
        fos.close();
        fos = null;
    } finally {
        // Non-null only if an earlier step threw; quiet best-effort cleanup so the original exception wins.
        IOUtils.closeStream(fos);
    }
}

/**
 * Builds one {@link NNColumnStats} per column in {@code columnConfigList} whose column number is a key of
 * {@code getIndexNameMapping(columnConfigList)}, in list order.
 */
private static List<NNColumnStats> toNNColumnStatsList(ModelConfig modelConfig, List<ColumnConfig> columnConfigList) {
    Map<Integer, String> columnIndexNameMapping = getIndexNameMapping(columnConfigList);
    List<NNColumnStats> csList = new ArrayList<NNColumnStats>();
    for (ColumnConfig cc : columnConfigList) {
        if (!columnIndexNameMapping.containsKey(cc.getColumnNum())) {
            continue;
        }
        NNColumnStats cs = new NNColumnStats();
        cs.setCutoff(modelConfig.getNormalizeStdDevCutOff());
        cs.setColumnType(cc.getColumnType());
        cs.setMean(cc.getMean());
        cs.setStddev(cc.getStdDev());
        cs.setColumnNum(cc.getColumnNum());
        cs.setColumnName(cc.getColumnName());
        cs.setBinCategories(cc.getBinCategory());
        cs.setBinBoundaries(cc.getBinBoundary());
        cs.setBinPosRates(cc.getBinPosRate());
        cs.setBinCountWoes(cc.getBinCountWoe());
        cs.setBinWeightWoes(cc.getBinWeightedWoe());
        // TODO cache such computation
        double[] meanAndStdDev = Normalizer.calculateWoeMeanAndStdDev(cc, false);
        cs.setWoeMean(meanAndStdDev[0]);
        cs.setWoeStddev(meanAndStdDev[1]);
        double[] wgtMeanAndStdDev = Normalizer.calculateWoeMeanAndStdDev(cc, true);
        cs.setWoeWgtMean(wgtMeanAndStdDev[0]);
        cs.setWoeWgtStddev(wgtMeanAndStdDev[1]);
        csList.add(cs);
    }
    return csList;
}
Usage of org.encog.ml.BasicML in the project shifu (by ShifuML).
Class EvalNormUDF, method exec:
/**
 * Builds the output tuple for one input record: pass-through and normalized columns, optionally followed by
 * a model score.
 *
 * @param input the raw input record
 * @return the output tuple, or {@code null} when the record is a repeated CSV header or maps to no data
 * @throws IOException if reading the tuple fails
 */
public Tuple exec(Tuple input) throws IOException {
    if (isCsvFormat) {
        Object firstField = input.get(0);
        String firstCol = (firstField == null) ? "" : firstField.toString();
        // A first field equal to the first header name is treated as a header row and dropped.
        if (this.headers[0].equals(CommonUtils.normColumnName(firstCol))) {
            // TODO what to do if the column value == column name? ...
            return null;
        }
    }

    // Lazily create the model runner on first use (kept out of the constructor because Pig also
    // instantiates UDFs on the client side just to read metadata, which could otherwise cause OOM).
    if (this.isAppendScore && this.modelRunner == null) {
        @SuppressWarnings("deprecation") List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelConfig, evalConfig, evalConfig.getDataSet().getSource(), evalConfig.getGbtConvertToProb(), evalConfig.getGbtScoreConvertStrategy());
        this.modelRunner = new ModelRunner(modelConfig, columnConfigList, this.headers, evalConfig.getDataSet().getDataDelimiter(), models);
        this.modelRunner.setScoreScale(Integer.parseInt(this.scale));
    }

    Map<NSColumn, String> rawDataNsMap = CommonUtils.convertDataIntoNsMap(input, this.headers, this.segFilterSize);
    if (MapUtils.isEmpty(rawDataNsMap)) {
        return null;
    }

    Tuple tuple = TupleFactory.getInstance().newTuple();
    int metaEndExclusive = 2 + validMetaSize;
    int outputCount = this.outputNames.size();
    for (int idx = 0; idx < outputCount; idx++) {
        String columnName = this.outputNames.get(idx);
        String rawValue = rawDataNsMap.get(new NSColumn(columnName));
        boolean passThrough = idx == 0 || (idx >= 2 && idx < metaEndExclusive);
        if (passThrough) {
            // output column 0 and meta columns [2, 2 + validMetaSize) are copied unchanged
            tuple.append(rawValue);
        } else if (idx == 1) {
            // empty value at output position 1 defaults to "1" (presumably the weight column — confirm)
            tuple.append(StringUtils.isEmpty(rawValue) ? "1" : rawValue);
        } else {
            ColumnConfig config = this.columnConfigMap.get(columnName);
            List<Double> normalized = Normalizer.normalize(config, rawValue, this.modelConfig.getNormalizeStdDevCutOff(), this.modelConfig.getNormalizeType());
            if (this.isOutputRaw) {
                tuple.append(rawValue);
            }
            for (Double value : normalized) {
                tuple.append(getOutputValue(value, true));
            }
        }
    }

    if (!this.isAppendScore || this.modelRunner == null) {
        return tuple;
    }

    CaseScoreResult score = this.modelRunner.computeNsData(rawDataNsMap);
    if (this.modelRunner.getModelsCnt() == 0 || score == null) {
        // sentinel score when no model or no result is available
        tuple.append(-999.0);
    } else if (this.scIndex >= 0) {
        tuple.append(score.getScores().get(this.scIndex));
    } else {
        tuple.append(score.getAvgScore());
    }
    return tuple;
}
Aggregations