Usage of `com.airbnb.aerosolve.core.ModelRecord` in the aerosolve project (airbnb).
Class `KernelModel`, method `loadInternal`:
/**
 * Restores the kernel model's dictionary and support vectors.
 *
 * <p>The dictionary is rebuilt from {@code header.getDictionary()}. The reader must then supply
 * exactly {@code header.getNumRecords()} lines, each an encoded {@link ModelRecord} that is
 * wrapped in a {@code SupportVector}.
 *
 * @param header model header carrying the dictionary and the record count
 * @param reader source of the encoded support-vector records, one per line
 * @throws IOException if the reader fails or the stream ends before all records are read
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
  long rows = header.getNumRecords();
  dictionary = new StringDictionary(header.getDictionary());
  supportVectors = new ArrayList<>();
  for (long i = 0; i < rows; i++) {
    String line = reader.readLine();
    // readLine() returns null at end of stream; fail with context instead of passing null on.
    if (line == null) {
      throw new IOException("Unexpected end of kernel model stream: expected " + rows
          + " support vector records but read only " + i);
    }
    ModelRecord record = Util.decodeModel(line);
    supportVectors.add(new SupportVector(record));
  }
}
Usage of `com.airbnb.aerosolve.core.ModelRecord` in the aerosolve project (airbnb).
Class `LinearModel`, method `loadInternal`:
/**
 * Loads a linear model from a buffered stream.
 *
 * <p>Slope and offset are copied from {@code header} only when set. Then
 * {@code header.getNumRecords()} encoded {@link ModelRecord} lines are read. Each line supplies
 * one feature weight, stored in {@code weights} under its family and name.
 *
 * @param header model header carrying the optional slope/offset and the record count
 * @param reader source of the encoded weight records, one per line
 * @throws IOException if the reader fails or the stream ends before all records are read
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
  long rows = header.getNumRecords();
  // Very old models did not set slope and offset so check first.
  if (header.isSetSlope()) {
    slope = header.getSlope();
  }
  if (header.isSetOffset()) {
    offset = header.getOffset();
  }
  weights = new HashMap<>();
  for (long i = 0; i < rows; i++) {
    String line = reader.readLine();
    // readLine() returns null at end of stream; fail with context instead of passing null on.
    if (line == null) {
      throw new IOException("Unexpected end of linear model stream: expected " + rows
          + " weight records but read only " + i);
    }
    ModelRecord record = Util.decodeModel(line);
    String family = record.getFeatureFamily();
    String name = record.getFeatureName();
    Map<String, Float> inner = weights.get(family);
    if (inner == null) {
      inner = new HashMap<>();
      weights.put(family, inner);
    }
    float weight = (float) record.getFeatureWeight();
    inner.put(name, weight);
  }
}
Usage of `com.airbnb.aerosolve.core.ModelRecord` in the aerosolve project (airbnb).
Class `LowRankLinearModel`, method `loadInternal`:
/**
 * Restores a low-rank linear model from {@code header} and {@code reader}.
 *
 * <p>The label dictionary and per-label embedding vectors come from {@code header}. The embedding
 * dimension is taken from the size of the first label vector. The reader must then supply
 * {@code header.getNumRecords()} lines, each an encoded {@link ModelRecord} holding one feature's
 * weight vector, stored under its family and name.
 *
 * @param header model header carrying the label dictionary, label embeddings and record count
 * @param reader source of the encoded feature records, one per line
 * @throws IOException if the reader fails, the header has no label embedding, a label vector is
 *     shorter than the embedding dimension, or the stream ends before all records are read
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
  long rows = header.getNumRecords();
  labelDictionary = new ArrayList<>();
  for (LabelDictionaryEntry entry : header.getLabelDictionary()) {
    labelDictionary.add(entry);
  }
  buildLabelToIndex();
  labelWeightVector = new HashMap<>();

  Map<String, java.util.List<Double>> labelEmbedding = header.getLabelEmbedding();
  // Without this guard an empty map fails inside iterator().next() (NoSuchElementException)
  // and a missing one with an NPE, neither of which says what is wrong with the model.
  if (labelEmbedding == null || labelEmbedding.isEmpty()) {
    throw new IOException("Low rank linear model header has no label embedding");
  }
  // The first label vector defines the dimension used for every label.
  embeddingDimension = labelEmbedding.values().iterator().next().size();
  for (Map.Entry<String, java.util.List<Double>> labelRepresentation : labelEmbedding.entrySet()) {
    java.util.List<Double> values = labelRepresentation.getValue();
    String labelKey = labelRepresentation.getKey();
    if (values.size() < embeddingDimension) {
      throw new IOException("Label embedding for '" + labelKey + "' has " + values.size()
          + " values, expected " + embeddingDimension);
    }
    FloatVector labelWeight = new FloatVector(embeddingDimension);
    for (int i = 0; i < embeddingDimension; i++) {
      labelWeight.set(i, values.get(i).floatValue());
    }
    labelWeightVector.put(labelKey, labelWeight);
  }

  featureWeightVector = new HashMap<>();
  for (long i = 0; i < rows; i++) {
    String line = reader.readLine();
    // readLine() returns null at end of stream; fail with context instead of passing null on.
    if (line == null) {
      throw new IOException("Unexpected end of low rank linear model stream: expected " + rows
          + " feature records but read only " + i);
    }
    ModelRecord record = Util.decodeModel(line);
    String family = record.getFeatureFamily();
    String name = record.getFeatureName();
    Map<String, FloatVector> inner = featureWeightVector.get(family);
    if (inner == null) {
      inner = new HashMap<>();
      featureWeightVector.put(family, inner);
    }
    java.util.List<Double> weightVector = record.getWeightVector();
    FloatVector vec = new FloatVector(weightVector.size());
    for (int j = 0; j < weightVector.size(); j++) {
      vec.values[j] = weightVector.get(j).floatValue();
    }
    inner.put(name, vec);
  }
}
Usage of `com.airbnb.aerosolve.core.ModelRecord` in the aerosolve project (airbnb).
Class `LowRankLinearModel`, method `save`:
/**
 * Writes this low-rank linear model to {@code writer}, one encoded {@link ModelRecord} per line.
 *
 * <p>The first line is a header record. It carries the model type {@code "low_rank_linear"}, the
 * number of feature records that follow, the label dictionary, and each label's embedding
 * (its first {@code embeddingDimension} values widened to double). Each following line holds one
 * feature's family, name and weight vector. The writer is flushed but not closed.
 *
 * @param writer destination for the encoded records
 * @throws IOException if writing or flushing fails
 */
public void save(BufferedWriter writer) throws IOException {
  long numFeatureRecords = 0;
  for (Map<String, FloatVector> featuresOfFamily : featureWeightVector.values()) {
    numFeatureRecords += featuresOfFamily.size();
  }

  Map<String, java.util.List<Double>> embeddingByLabel = new HashMap<>();
  for (Map.Entry<String, FloatVector> labelEntry : labelWeightVector.entrySet()) {
    float[] rawEmbedding = labelEntry.getValue().getValues();
    ArrayList<Double> widened = new ArrayList<>();
    for (int d = 0; d < embeddingDimension; d++) {
      widened.add((double) rawEmbedding[d]);
    }
    embeddingByLabel.put(labelEntry.getKey(), widened);
  }

  ModelHeader header = new ModelHeader();
  header.setModelType("low_rank_linear");
  header.setNumRecords(numFeatureRecords);
  header.setLabelDictionary(labelDictionary);
  header.setLabelEmbedding(embeddingByLabel);

  ModelRecord headerRecord = new ModelRecord();
  headerRecord.setModelHeader(header);
  writer.write(Util.encode(headerRecord));
  writer.newLine();

  for (Map.Entry<String, Map<String, FloatVector>> familyEntry : featureWeightVector.entrySet()) {
    String familyName = familyEntry.getKey();
    for (Map.Entry<String, FloatVector> featureEntry : familyEntry.getValue().entrySet()) {
      ArrayList<Double> weightList = new ArrayList<>();
      for (float w : featureEntry.getValue().values) {
        weightList.add((double) w);
      }
      ModelRecord featureRecord = new ModelRecord();
      featureRecord.setFeatureFamily(familyName);
      featureRecord.setFeatureName(featureEntry.getKey());
      featureRecord.setWeightVector(weightList);
      writer.write(Util.encode(featureRecord));
      writer.newLine();
    }
  }
  writer.flush();
}
Usage of `com.airbnb.aerosolve.core.ModelRecord` in the aerosolve project (airbnb).
Class `MlpModel`, method `loadInternal`:
/**
 * Restores a multi-layer perceptron from {@code header} and {@code reader}.
 *
 * <p>The layer layout comes from {@code header}. After the header the stream supplies, in order:
 * <ol>
 *   <li>{@code header.getNumRecords()} input-layer records, each one feature's weight vector
 *       keyed by family and name;
 *   <li>{@code numHiddenLayers + 1} records, each carrying a bias vector and an activation
 *       function form, stored under index {@code i};
 *   <li>for each hidden layer {@code i}, {@code layerNodeNumber.get(i)} records, each carrying
 *       one weight vector.
 * </ol>
 *
 * <p>NOTE(review): {@code layerNodeNumber}, {@code inputLayerWeights}, {@code bias} and
 * {@code activationFunction} are appended to, not cleared, so this presumably expects a freshly
 * constructed model — confirm against callers.
 *
 * @param header model header carrying the layer layout and input-layer record count
 * @param reader source of the encoded records, one per line
 * @throws IOException if the reader fails or the stream ends before every expected record is read
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
  numHiddenLayers = header.getNumHiddenLayers();
  List<Integer> hiddenNodeNumber = header.getNumberHiddenNodes();
  layerNodeNumber.addAll(hiddenNodeNumber);

  // load input layer weights; a long index matches the long record count and cannot overflow
  long rows = header.getNumRecords();
  for (long i = 0; i < rows; i++) {
    ModelRecord record = readNextRecord(reader, "input layer weights");
    String family = record.getFeatureFamily();
    String name = record.getFeatureName();
    Map<String, FloatVector> inner = inputLayerWeights.get(family);
    if (inner == null) {
      inner = new HashMap<>();
      inputLayerWeights.put(family, inner);
    }
    List<Double> weightVector = record.getWeightVector();
    FloatVector vec = new FloatVector(weightVector.size());
    for (int j = 0; j < weightVector.size(); j++) {
      vec.values[j] = weightVector.get(j).floatValue();
    }
    inner.put(name, vec);
  }

  // load bias and activation function
  for (int i = 0; i < numHiddenLayers + 1; i++) {
    ModelRecord record = readNextRecord(reader, "bias and activation function");
    List<Double> arrayList = record.getWeightVector();
    FloatVector layerBias = new FloatVector(arrayList.size());
    for (int j = 0; j < arrayList.size(); j++) {
      layerBias.set(j, arrayList.get(j).floatValue());
    }
    bias.put(i, layerBias);
    activationFunction.add(record.getFunctionForm());
  }

  // load the hiddenLayerWeights, one record per (layer + node)
  for (int i = 0; i < numHiddenLayers; i++) {
    int nodesInLayer = layerNodeNumber.get(i);
    ArrayList<FloatVector> weights = new ArrayList<>();
    for (int j = 0; j < nodesInLayer; j++) {
      ModelRecord record = readNextRecord(reader, "hidden layer weights");
      List<Double> arrayList = record.getWeightVector();
      FloatVector w = new FloatVector(arrayList.size());
      for (int k = 0; k < arrayList.size(); k++) {
        w.set(k, arrayList.get(k).floatValue());
      }
      weights.add(w);
    }
    hiddenLayerWeights.put(i, weights);
  }
}

/**
 * Reads the next line of {@code reader} and decodes it as a {@link ModelRecord}.
 *
 * @param section human-readable name of the part of the model being read, used in the error
 * @throws IOException if the reader fails or is already at end of stream
 */
private static ModelRecord readNextRecord(BufferedReader reader, String section)
    throws IOException {
  String line = reader.readLine();
  // readLine() returns null at end of stream; fail with context instead of passing null on.
  if (line == null) {
    throw new IOException("Unexpected end of MLP model stream while reading " + section);
  }
  return Util.decodeModel(line);
}
Aggregations