Usage of ml.shifu.shifu.container.obj.ColumnConfig in the shifu project by ShifuML.
Class VarSelWorker, method init.
/**
 * Initializes the variable-selection worker from the job properties.
 *
 * Loads the model config and the column config list, reflectively instantiates the configured worker conductor,
 * prepares the training data set and the data purifier, and resolves the target column id and, when a weight
 * column name is configured, the weight column id.
 *
 * @param workerContext
 *            worker context whose properties supply the config locations and the conductor class name
 * @throws IllegalStateException
 *             if no worker conductor class name is configured
 * @throws RuntimeException
 *             if the configs cannot be loaded or the conductor / data purifier cannot be created
 */
@Override
public void init(WorkerContext<VarSelMasterResult, VarSelWorkerResult> workerContext) {
    Properties props = workerContext.getProps();
    try {
        RawSourceData.SourceType sourceType = RawSourceData.SourceType.valueOf(props.getProperty(
                CommonConstants.MODELSET_SOURCE_TYPE, RawSourceData.SourceType.HDFS.toString()));
        this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG),
                sourceType);
        this.columnConfigList = CommonUtils.loadColumnConfigList(
                props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType);

        String conductorClsName = props.getProperty(Constants.VAR_SEL_WORKER_CONDUCTOR);
        if (StringUtils.isBlank(conductorClsName)) {
            // Fail with an explicit message instead of an obscure failure inside Class.forName.
            throw new IllegalStateException("Missing worker conductor class property: "
                    + Constants.VAR_SEL_WORKER_CONDUCTOR);
        }
        // The conductor is pluggable and is created through its (ModelConfig, List) constructor.
        this.workerConductor = (AbstractWorkerConductor) Class.forName(conductorClsName)
                .getDeclaredConstructor(ModelConfig.class, List.class)
                .newInstance(this.modelConfig, this.columnConfigList);
    } catch (IOException e) {
        throw new RuntimeException("Fail to load ModelConfig or List<ColumnConfig>", e);
    } catch (ClassNotFoundException e) {
        // This is the worker-side conductor (VAR_SEL_WORKER_CONDUCTOR), not the master one.
        throw new RuntimeException("Invalid Worker Conductor class", e);
    } catch (InstantiationException e) {
        throw new RuntimeException("Fail to create instance", e);
    } catch (IllegalAccessException e) {
        throw new RuntimeException("Illegal access when creating instance", e);
    } catch (NoSuchMethodException e) {
        throw new RuntimeException("Fail to call method when creating instance", e);
    } catch (InvocationTargetException e) {
        throw new RuntimeException("Fail to invoke when creating instance", e);
    }

    List<Integer> normalizedColumnIdList = this.getNormalizedColumnIdList();
    this.inputNodeCount = normalizedColumnIdList.size();
    this.outputNodeCount = this.getTargetColumnCount();
    trainingDataSet = new TrainingDataSet(normalizedColumnIdList);

    try {
        dataPurifier = new DataPurifier(modelConfig, false);
    } catch (IOException e) {
        throw new RuntimeException("Fail to create DataPurifier", e);
    }

    this.targetColumnId = CommonUtils.getTargetColumnNum(this.columnConfigList);

    String weightColumnName = modelConfig.getWeightColumnName();
    if (StringUtils.isNotBlank(weightColumnName)) {
        // Trim once, outside the loop.
        String trimmedWeightColumnName = weightColumnName.trim();
        for (ColumnConfig columnConfig : columnConfigList) {
            // Compare from the known non-null side so a column without a name cannot cause an NPE.
            if (trimmedWeightColumnName.equalsIgnoreCase(columnConfig.getColumnName())) {
                this.weightColumnId = columnConfig.getColumnNum();
                break;
            }
        }
    }
}
Usage of ml.shifu.shifu.container.obj.ColumnConfig in the shifu project by ShifuML.
Class MiningSchemaCreator, method build.
/**
 * Builds the PMML {@link MiningSchema} for the raw dataset columns.
 *
 * A raw column is included when it is active:
 * - If {@code basicML} is a {@link BasicFloatNetwork}, activity is decided by {@code isActiveColumn} against the
 *   network's feature set.
 * - Otherwise a column is active when it is final-selected or is the target.
 *
 * In segment-expansion mode (more column configs than dataset headers), an inactive raw column is still included
 * when any of its segment-expanded copies is final-selected.
 *
 * @param basicML
 *            the trained model; may be {@code null} or a non-{@link BasicFloatNetwork} model
 * @return the mining schema with target and active mining fields
 */
@Override
public MiningSchema build(BasicML basicML) {
    MiningSchema miningSchema = new MiningSchema();
    boolean isSegExpansionMode = columnConfigList.size() > datasetHeaders.length;
    int segSize = segmentExpansions.size();

    // instanceof is already false for null, so no separate null check is needed.
    boolean useFeatureSet = basicML instanceof BasicFloatNetwork;
    Set<Integer> featureSet = useFeatureSet ? ((BasicFloatNetwork) basicML).getFeatureSet() : null;

    for (ColumnConfig columnConfig : columnConfigList) {
        if (columnConfig.getColumnNum() >= datasetHeaders.length) {
            // in order: everything from here on is a segment-expanded column, not a raw one
            break;
        }
        // FIXME, if no variable is selected (applies to the non-feature-set path)
        boolean active = useFeatureSet ? isActiveColumn(featureSet, columnConfig)
                : (columnConfig.isFinalSelect() || columnConfig.isTarget());
        // Even if the raw column is not selected, keep it when one of its segment columns is selected.
        if (active || (isSegExpansionMode && hasSelectedSegmentColumn(columnConfig, segSize))) {
            appendMiningFields(miningSchema, columnConfig);
        }
    }
    return miningSchema;
}

/**
 * Returns whether any segment-expanded copy of {@code rawColumn} is final-selected.
 * Segment copy i lives at index {@code datasetHeaders.length * (i + 1) + rawColumn.getColumnNum()}.
 */
private boolean hasSelectedSegmentColumn(ColumnConfig rawColumn, int segSize) {
    for (int i = 0; i < segSize; i++) {
        int segIndex = datasetHeaders.length * (i + 1) + rawColumn.getColumnNum();
        if (segIndex >= columnConfigList.size()) {
            // indices only grow with i, so no later segment copy can exist either
            break;
        }
        if (columnConfigList.get(segIndex).isFinalSelect()) {
            return true;
        }
    }
    return false;
}

/** Adds the mining field(s) for {@code columnConfig}; a target column contributes a list of fields. */
private void appendMiningFields(MiningSchema miningSchema, ColumnConfig columnConfig) {
    if (columnConfig.isTarget()) {
        List<MiningField> targetFields = createTargetMingFields(columnConfig);
        miningSchema.addMiningFields(targetFields.toArray(new MiningField[targetFields.size()]));
    } else {
        miningSchema.addMiningFields(createActiveMingFields(columnConfig));
    }
}
Usage of ml.shifu.shifu.container.obj.ColumnConfig in the shifu project by ShifuML.
Class PersistWideAndDeep, method readColumnConfig.
/**
 * Reads one {@link ColumnConfig} from the given stream.
 *
 * @param dis
 *            stream positioned at the start of a serialized column config
 * @return a new column config populated via {@link ColumnConfig#read}
 * @throws IOException
 *             propagated from {@link ColumnConfig#read}
 */
private static ColumnConfig readColumnConfig(DataInputStream dis) throws IOException {
    final ColumnConfig result = new ColumnConfig();
    result.read(dis);
    return result;
}
Usage of ml.shifu.shifu.container.obj.ColumnConfig in the shifu project by ShifuML.
Class WDLWorker, method load.
/**
 * Logic to load data into memory list which includes float array for numerical features and sparse object array for
 * categorical features.
 *
 * Processing of one input line:
 * - Field i corresponds to columnConfigList.get(i).
 * - The field right after the last column config is read as the record weight (significance).
 * - The target column supplies the ideal value.
 * - Records may be dropped by negative-only sampling.
 * - Positive records (ideal == 1) may be up-weighted when up-sampling is enabled.
 * - The resulting {@link Data} is routed to the training or validation set by addDataPairToDataSet.
 *
 * @param currentKey
 *            key of the current record (not used in this method)
 * @param currentValue
 *            the current text line, split with this.splitter
 * @param context
 *            worker context; its attachment is passed to addDataPairToDataSet
 */
@Override
public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Text> currentValue, WorkerContext<WDLParams, WDLParams> context) {
// progress log every 5000 records
if ((++this.count) % 5000 == 0) {
LOG.info("Read {} records.", this.count);
}
// hashcode for fixed input split in train and validation
// (built only from the raw strings of valid feature columns, so the same line always yields the same hash)
long hashcode = 0;
float[] inputs = new float[this.numInputs];
// NOTE(review): this is recomputed for every record. It counts ALL categorical columns, including those that
// validColumn() rejects below, so trailing slots of cateInputs can remain null. Confirm that consumers only
// index into the array through inputIndexMap.
this.cateInputs = (int) this.columnConfigList.stream().filter(ColumnConfig::isCategorical).count();
SparseInput[] cateInputs = new SparseInput[this.cateInputs];
float ideal = 0f, significance = 1f;
// index: position in the split line; numIndex / cateIndex: next free slot in inputs / cateInputs
int index = 0, numIndex = 0, cateIndex = 0;
// use guava Splitter to iterate only once
for (String input : this.splitter.split(currentValue.getWritable().toString())) {
if (index == this.columnConfigList.size()) {
significance = getWeightValue(input);
// the last field is significance, break here
break;
} else {
ColumnConfig config = this.columnConfigList.get(index);
// NOTE(review): config is null-checked here, yet the else branch passes it to validColumn() and
// dereferences it (and the stream above would already fail on a null entry). This is only safe if
// validColumn() rejects null; confirm.
if (config != null && config.isTarget()) {
ideal = getFloatValue(input);
} else {
// final select some variables but meta and target are not included
if (validColumn(config)) {
// inputIndexMap records each column's slot within its own array (dense or sparse). The two
// arrays share one map. putIfAbsent keeps the first mapping, but the ++ in the argument
// always runs.
if (config.isNumerical()) {
inputs[numIndex] = getFloatValue(input);
this.inputIndexMap.putIfAbsent(config.getColumnNum(), numIndex++);
} else if (config.isCategorical()) {
cateInputs[cateIndex] = new SparseInput(config.getColumnNum(), getCateIndex(input, config));
this.inputIndexMap.putIfAbsent(config.getColumnNum(), cateIndex++);
}
hashcode = hashcode * 31 + input.hashCode();
}
}
}
index += 1;
}
// output delimiter in norm can be set by user now and if user set a special one later changed, this exception
// is helped to quick find such issue.
validateInputLength(context, inputs, numIndex);
// sample negative only logic here
if (sampleNegOnly(hashcode, ideal)) {
return;
}
// up sampling logic, just add more weights while bagging sampling rate is still not changed
if (modelConfig.isRegression() && isUpSampleEnabled() && Double.compare(ideal, 1d) == 0) {
// ideal == 1 means positive tags; sample + 1 to avoid sample count to 0
significance = significance * (this.upSampleRng.sample() + 1);
}
Data data = new Data(inputs, cateInputs, significance, ideal);
// split into validation and training data set according to validation rate
boolean isInTraining = this.addDataPairToDataSet(hashcode, data, context.getAttachment());
// update some positive or negative selected count in metrics
this.updateMetrics(data, isInTraining);
}
Usage of ml.shifu.shifu.container.obj.ColumnConfig in the shifu project by ShifuML.
Class WDLWorker, method initCateIndexMap.
/**
 * Rebuilds columnCategoryIndexMapping: column number -> (category value -> bin index).
 *
 * Only categorical columns with a non-null bin category list are mapped. Each bin entry is expanded with
 * CommonUtils.flattenCatValGrp, and every resulting value maps to that bin's index. If a value appears in
 * several bins, the later bin wins.
 */
private void initCateIndexMap() {
    this.columnCategoryIndexMapping = new HashMap<>();
    for (ColumnConfig columnConfig : this.columnConfigList) {
        if (!columnConfig.isCategorical()) {
            continue;
        }
        List<String> binCategories = columnConfig.getBinCategory();
        if (binCategories == null) {
            continue;
        }
        Map<String, Integer> valueToBin = new HashMap<>();
        int binIdx = 0;
        while (binIdx < binCategories.size()) {
            for (String categoryValue : CommonUtils.flattenCatValGrp(binCategories.get(binIdx))) {
                valueToBin.put(categoryValue, binIdx);
            }
            binIdx++;
        }
        this.columnCategoryIndexMapping.put(columnConfig.getColumnNum(), valueToBin);
    }
}
Aggregations