Use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.
Class ModelFactory, method createFromReader.
/**
 * Reads a serialized model from {@code reader}.
 *
 * <p>The first line must be an encoded {@link ModelRecord} carrying a {@link ModelHeader}. The
 * header's model type is passed to {@code createByName} to obtain the concrete model, which then
 * consumes the remaining lines via {@code loadInternal}.
 *
 * @param reader source positioned at the start of a saved model
 * @return the loaded model, or {@code Optional.absent()} when the input is empty, the header line
 *     cannot be decoded, the record carries no model header, or {@code createByName} does not
 *     recognize the model type (each case is logged)
 * @throws IOException if reading from {@code reader} fails
 */
public static Optional<AbstractModel> createFromReader(BufferedReader reader) throws IOException {
  String headerLine = reader.readLine();
  if (headerLine == null) {
    // readLine() returns null at end of stream; don't hand null to the decoder.
    log.error("Could not read model header: input is empty");
    return Optional.absent();
  }
  ModelRecord record = Util.decodeModel(headerLine);
  if (record == null) {
    log.error("Could not decode header " + headerLine);
    return Optional.absent();
  }
  ModelHeader header = record.getModelHeader();
  if (header == null) {
    log.error("Model record has no header: " + headerLine);
    return Optional.absent();
  }
  AbstractModel result = createByName(header.getModelType());
  if (result == null) {
    log.error("Unknown model type " + header.getModelType());
    return Optional.absent();
  }
  result.loadInternal(header, reader);
  return Optional.of(result);
}
Use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.
Class SplineModel, method save.
/**
 * Serializes this spline model to {@code writer}, one encoded {@link ModelRecord} per line.
 *
 * <p>The first line holds a {@link ModelHeader} with model type "spline", {@code numBins} (stored
 * as numHidden), slope, offset, and the number of feature records that follow. Each subsequent
 * line holds one (family, feature) pair with its spline weights and the spline's min/max values.
 * The writer is flushed but not closed.
 *
 * @param writer destination for the encoded model
 * @throws IOException if writing fails
 */
@Override
public void save(BufferedWriter writer) throws IOException {
  ModelHeader header = new ModelHeader();
  header.setModelType("spline");
  header.setNumHidden(numBins);
  header.setSlope(slope);
  header.setOffset(offset);
  // One record is written per (family, feature) pair, so the count is the sum of inner map sizes.
  long count = 0;
  for (Map<String, WeightSpline> features : weightSpline.values()) {
    count += features.size();
  }
  header.setNumRecords(count);
  ModelRecord headerRec = new ModelRecord();
  headerRec.setModelHeader(header);
  writer.write(Util.encode(headerRec));
  writer.newLine();
  for (Map.Entry<String, Map<String, WeightSpline>> familyMap : weightSpline.entrySet()) {
    for (Map.Entry<String, WeightSpline> feature : familyMap.getValue().entrySet()) {
      WeightSpline ws = feature.getValue();
      ModelRecord record = new ModelRecord();
      record.setFeatureFamily(familyMap.getKey());
      record.setFeatureName(feature.getKey());
      ArrayList<Double> arrayList = new ArrayList<Double>(ws.splineWeights.length);
      for (int i = 0; i < ws.splineWeights.length; i++) {
        arrayList.add((double) ws.splineWeights[i]);
      }
      record.setWeightVector(arrayList);
      record.setMinVal(ws.spline.getMinVal());
      record.setMaxVal(ws.spline.getMaxVal());
      writer.write(Util.encode(record));
      writer.newLine();
    }
  }
  writer.flush();
}
Use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.
Class BoostedStumpsModel, method loadInternal.
/**
 * Reads the stump records that follow the model header.
 *
 * <p>Exactly {@code header.getNumRecords()} lines are consumed. Each is decoded into a
 * {@link ModelRecord} and appended to {@code stumps} in file order.
 *
 * @param header the already-decoded model header
 * @param reader positioned at the first line after the header
 * @throws IOException if reading fails, the input ends before all records are read, or a line
 *     cannot be decoded
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
  long rows = header.getNumRecords();
  stumps = new ArrayList<>();
  for (long i = 0; i < rows; i++) {
    String line = reader.readLine();
    if (line == null) {
      // Truncated input: fail loudly rather than decoding null.
      throw new IOException(
          "Unexpected end of input: expected " + rows + " stump records but read " + i);
    }
    ModelRecord record = Util.decodeModel(line);
    if (record == null) {
      // Avoid storing a null stump that would only fail later at scoring time.
      throw new IOException("Could not decode stump record " + i);
    }
    stumps.add(record);
  }
}
Use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.
Class FullRankLinearModel, method loadInternal.
/**
 * Restores the label dictionary and per-feature weight vectors from a saved model.
 *
 * <p>The label dictionary is copied from {@code header}, and {@code buildLabelToIndex()} is
 * called. Then exactly {@code header.getNumRecords()} lines are read. Each line is decoded into a
 * {@link ModelRecord} whose weight vector is converted to a {@code FloatVector} and stored under
 * {@code weightVector[family][name]}.
 *
 * @param header the already-decoded model header
 * @param reader positioned at the first line after the header
 * @throws IOException if reading fails, the input ends before all records are read, or a line
 *     cannot be decoded
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
  long rows = header.getNumRecords();
  labelDictionary = new ArrayList<>();
  for (LabelDictionaryEntry entry : header.getLabelDictionary()) {
    labelDictionary.add(entry);
  }
  buildLabelToIndex();
  weightVector = new HashMap<>();
  for (long i = 0; i < rows; i++) {
    String line = reader.readLine();
    if (line == null) {
      // Truncated input: report it instead of failing with an NPE further down.
      throw new IOException(
          "Unexpected end of input: expected " + rows + " weight records but read " + i);
    }
    ModelRecord record = Util.decodeModel(line);
    if (record == null) {
      throw new IOException("Could not decode weight record " + i);
    }
    String family = record.getFeatureFamily();
    String name = record.getFeatureName();
    Map<String, FloatVector> inner = weightVector.get(family);
    if (inner == null) {
      inner = new HashMap<>();
      weightVector.put(family, inner);
    }
    int size = record.getWeightVector().size();
    FloatVector vec = new FloatVector(size);
    for (int j = 0; j < size; j++) {
      vec.values[j] = record.getWeightVector().get(j).floatValue();
    }
    inner.put(name, vec);
  }
}
Use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.
Class DecisionTreeTransform, method doTransform.
/**
 * Routes the vector's float features through {@code tree} and records the selected leaf.
 *
 * <p>The leaf record's {@code featureName} is added to the string family {@code outputLeaves}.
 * Its {@code featureWeight} is stored as {@code outputScoreName} in the float family
 * {@code outputScoreFamily}. The call does nothing when the vector has no float features.
 *
 * @param featureVector the vector to read from and write into
 */
@Override
public void doTransform(FeatureVector featureVector) {
  Map<String, Map<String, Double>> floats = featureVector.getFloatFeatures();
  if (floats != null) {
    // Presumably ensures the string-feature map exists before it is fetched; confirm in Util.
    Util.optionallyCreateStringFeatures(featureVector);
    Set<String> leafNames =
        Util.getOrCreateStringFeature(outputLeaves, featureVector.getStringFeatures());
    Map<String, Double> scores = Util.getOrCreateFloatFeature(outputScoreFamily, floats);
    int leaf = tree.getLeafIndex(floats);
    ModelRecord leafRecord = tree.getStumps().get(leaf);
    leafNames.add(leafRecord.featureName);
    scores.put(outputScoreName, leafRecord.featureWeight);
  }
}
Aggregations