Use of com.airbnb.aerosolve.core.LabelDictionaryEntry in the aerosolve project by Airbnb.
Class LowRankLinearModelTest, method makeLabelDictionary.
/**
 * Builds the small label dictionary used by the low-rank linear model tests.
 *
 * <p>It contains three single-letter labels, in this order: "A" (animal), "C" (color) and
 * "F" (fruit). Each entry has a count of 1.
 *
 * @return a new, mutable list holding one {@code LabelDictionaryEntry} per label
 */
ArrayList<LabelDictionaryEntry> makeLabelDictionary() {
  ArrayList<LabelDictionaryEntry> dictionary = new ArrayList<>();
  String[] labelNames = {"A", "C", "F"};
  for (String labelName : labelNames) {
    LabelDictionaryEntry dictionaryEntry = new LabelDictionaryEntry();
    dictionaryEntry.setLabel(labelName);
    dictionaryEntry.setCount(1);
    dictionary.add(dictionaryEntry);
  }
  return dictionary;
}
Use of com.airbnb.aerosolve.core.LabelDictionaryEntry in the aerosolve project by Airbnb.
Class FullRankLinearModel, method loadInternal.
/**
 * Restores this model from a parsed header plus the serialized records that follow it.
 *
 * <p>The label dictionary is copied from the header, and {@code buildLabelToIndex()} is then
 * called. Next, exactly {@code header.getNumRecords()} lines are read from {@code reader}. Each
 * line is decoded with {@code Util.decodeModel}. The record's weight list is converted to a
 * {@code FloatVector} and stored in {@code weightVector} under its feature family and then its
 * feature name.
 *
 * @param header model header supplying the record count and the label dictionary
 * @param reader reader positioned at the first serialized record line
 * @throws IOException if reading fails, or if the reader reaches end of stream before
 *     {@code header.getNumRecords()} records have been read
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
  long rows = header.getNumRecords();
  labelDictionary = new ArrayList<>();
  for (LabelDictionaryEntry entry : header.getLabelDictionary()) {
    labelDictionary.add(entry);
  }
  buildLabelToIndex();
  weightVector = new HashMap<>();
  for (long i = 0; i < rows; i++) {
    String line = reader.readLine();
    if (line == null) {
      // readLine() returns null at end of stream; without this check a truncated model
      // would pass null into Util.decodeModel and fail with an unrelated error.
      throw new IOException(
          "Unexpected end of model: expected " + rows + " records but only found " + i);
    }
    ModelRecord record = Util.decodeModel(line);
    String family = record.getFeatureFamily();
    String name = record.getFeatureName();
    // Lazily create the per-family map the first time a family is seen.
    Map<String, FloatVector> inner = weightVector.get(family);
    if (inner == null) {
      inner = new HashMap<>();
      weightVector.put(family, inner);
    }
    // Narrow the record's boxed double weights into a FloatVector of the same length.
    int size = record.getWeightVector().size();
    FloatVector vec = new FloatVector(size);
    for (int j = 0; j < size; j++) {
      vec.values[j] = record.getWeightVector().get(j).floatValue();
    }
    inner.put(name, vec);
  }
}
Use of com.airbnb.aerosolve.core.LabelDictionaryEntry in the aerosolve project by Airbnb.
Class LowRankLinearModel, method loadInternal.
/**
 * Restores this model from a parsed header plus the serialized feature records that follow it.
 *
 * <p>The method proceeds in three steps:
 * <ol>
 *   <li>Copy the label dictionary from the header and call {@code buildLabelToIndex()}.
 *   <li>Convert every label embedding in the header into a {@code FloatVector} stored in
 *       {@code labelWeightVector}. {@code embeddingDimension} is taken from the length of one
 *       entry (whichever the map's iterator returns first), and every other entry must have the
 *       same length.
 *   <li>Read exactly {@code header.getNumRecords()} lines from {@code reader}. Decode each with
 *       {@code Util.decodeModel} and store its weight vector in {@code featureWeightVector}
 *       under its feature family and then its feature name.
 * </ol>
 *
 * @param header model header supplying the record count, label dictionary and label embedding
 * @param reader reader positioned at the first serialized feature record line
 * @throws IOException if reading fails, if the header has no label embedding, if label
 *     embeddings have inconsistent lengths, or if the reader reaches end of stream before
 *     {@code header.getNumRecords()} records have been read
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
  long rows = header.getNumRecords();
  labelDictionary = new ArrayList<>();
  for (LabelDictionaryEntry entry : header.getLabelDictionary()) {
    labelDictionary.add(entry);
  }
  buildLabelToIndex();

  Map<String, java.util.List<Double>> labelEmbedding = header.getLabelEmbedding();
  if (labelEmbedding == null || labelEmbedding.isEmpty()) {
    // The embedding dimension is inferred from an entry below, so at least one must exist;
    // otherwise iterator().next() would throw an opaque NoSuchElementException.
    throw new IOException("Model header contains no label embedding");
  }
  labelWeightVector = new HashMap<>();
  embeddingDimension = labelEmbedding.values().iterator().next().size();
  for (Map.Entry<String, java.util.List<Double>> labelRepresentation : labelEmbedding.entrySet()) {
    java.util.List<Double> values = labelRepresentation.getValue();
    String labelKey = labelRepresentation.getKey();
    if (values.size() != embeddingDimension) {
      // A mismatched length would otherwise crash (shorter) or be silently truncated (longer).
      throw new IOException(
          "Label embedding for '" + labelKey + "' has " + values.size()
              + " values; expected " + embeddingDimension);
    }
    FloatVector labelWeight = new FloatVector(embeddingDimension);
    for (int i = 0; i < embeddingDimension; i++) {
      labelWeight.set(i, values.get(i).floatValue());
    }
    labelWeightVector.put(labelKey, labelWeight);
  }

  featureWeightVector = new HashMap<>();
  for (long i = 0; i < rows; i++) {
    String line = reader.readLine();
    if (line == null) {
      // readLine() returns null at end of stream; without this check a truncated model
      // would pass null into Util.decodeModel and fail with an unrelated error.
      throw new IOException(
          "Unexpected end of model: expected " + rows + " records but only found " + i);
    }
    ModelRecord record = Util.decodeModel(line);
    String family = record.getFeatureFamily();
    String name = record.getFeatureName();
    // Lazily create the per-family map the first time a family is seen.
    Map<String, FloatVector> inner = featureWeightVector.get(family);
    if (inner == null) {
      inner = new HashMap<>();
      featureWeightVector.put(family, inner);
    }
    // Narrow the record's boxed double weights into a FloatVector of the same length.
    int size = record.getWeightVector().size();
    FloatVector vec = new FloatVector(size);
    for (int j = 0; j < size; j++) {
      vec.values[j] = record.getWeightVector().get(j).floatValue();
    }
    inner.put(name, vec);
  }
}
Aggregations