use of com.airbnb.aerosolve.core.FeatureVector in project aerosolve by airbnb.
the class FeatureVectorGen method toFeatureVector.
// TODO add a new function to consider dense feature.
/**
 * Assembles a {@link FeatureVector} from raw feature values and the supplied families.
 *
 * <p>The vector's string-feature map is first handed to {@code setBIAS} and then gets one entry
 * per string family; its float-feature map gets one entry per float family. Each entry maps the
 * family name to whatever that family's {@code getFeatures()} returns.
 *
 * <p>After that, every non-null element of {@code features.values} is offered, under the name at
 * the same index of {@code features.names}, to the families in list order until one {@code add}
 * call returns true:
 * <ul>
 *   <li>{@code Double}, {@code Float}, {@code Integer}, {@code Long}: float families;</li>
 *   <li>{@code String} and {@code Boolean}: string families;</li>
 *   <li>any other type: ignored.</li>
 * </ul>
 *
 * @param features       parallel {@code names}/{@code values} arrays of raw features
 * @param stringFamilies families that receive string and boolean values
 * @param floatFamilies  families that receive numeric values
 * @return the newly created feature vector
 */
public static FeatureVector toFeatureVector(Features features, List<StringFamily> stringFamilies, List<FloatFamily> floatFamilies) {
  FeatureVector result = new FeatureVector();

  // String side: attach the map, add the bias entry, then register each family.
  final Map<String, Set<String>> stringsByFamily = new HashMap<>();
  result.setStringFeatures(stringsByFamily);
  setBIAS(stringsByFamily);
  for (StringFamily family : stringFamilies) {
    stringsByFamily.put(family.getFamilyName(), family.getFeatures());
  }

  // Float side: attach the map and register each family.
  final Map<String, Map<String, Double>> floatsByFamily = new HashMap<>();
  result.setFloatFeatures(floatsByFamily);
  for (FloatFamily family : floatFamilies) {
    floatsByFamily.put(family.getFamilyName(), family.getFeatures());
  }

  for (int index = 0; index < features.names.length; index++) {
    final Object value = features.values[index];
    if (value == null) {
      // Missing values are not offered to any family.
      continue;
    }
    final String name = features.names[index];

    final boolean numeric = value instanceof Double
        || value instanceof Float
        || value instanceof Integer
        || value instanceof Long;

    if (numeric) {
      // Stop at the first float family that accepts the value.
      for (FloatFamily family : floatFamilies) {
        if (family.add(name, value)) {
          break;
        }
      }
    } else if (value instanceof String) {
      // Stop at the first string family that accepts the value.
      for (StringFamily family : stringFamilies) {
        if (family.add(name, value)) {
          break;
        }
      }
    } else if (value instanceof Boolean) {
      // Booleans go through the Boolean-typed add call of the string families.
      for (StringFamily family : stringFamilies) {
        if (family.add(name, (Boolean) value)) {
          break;
        }
      }
    }
  }
  return result;
}
use of com.airbnb.aerosolve.core.FeatureVector in project aerosolve by airbnb.
the class Transformer method addContextToItems.
/**
 * Copies the features of the example's context into each of its items.
 *
 * <p>For each of the string, float and dense feature maps, the context's map is used when the
 * context and that map are both non-null; otherwise {@code null} is passed instead. The three
 * values are then forwarded, together with each item of {@code examples.example}, to
 * {@code addContextToItem}.
 *
 * @param examples the example whose items should receive the context features
 */
public void addContextToItems(Example examples) {
  final boolean hasContext = examples.context != null;

  final Map<String, Set<String>> contextStringFeatures =
      (hasContext && examples.context.stringFeatures != null)
          ? examples.context.getStringFeatures()
          : null;
  final Map<String, Map<String, Double>> contextFloatFeatures =
      (hasContext && examples.context.floatFeatures != null)
          ? examples.context.getFloatFeatures()
          : null;
  final Map<String, List<Double>> contextDenseFeatures =
      (hasContext && examples.context.denseFeatures != null)
          ? examples.context.getDenseFeatures()
          : null;

  for (FeatureVector item : examples.example) {
    addContextToItem(contextStringFeatures, contextFloatFeatures, contextDenseFeatures, item);
  }
}
use of com.airbnb.aerosolve.core.FeatureVector in project aerosolve by airbnb.
the class MinKernelDenseFeatureDictionary method getKNearestNeighbors.
/**
 * Calculates the Min Kernel value between {@code featureVector} and each dictionary element and
 * returns the retained elements as a new sparse (float) feature.
 *
 * <p>An empty {@link FeatureVector} is returned when the query has no dense features. Otherwise
 * the result holds a single float-feature family, keyed by {@code options.getOutputKey()}, that
 * maps the id of each retained dictionary element to its kernel value. At most
 * {@code options.getNumNearest()} elements are retained; whenever the queue grows beyond that,
 * its head (as ordered by {@code EntryComparator}) is discarded.
 *
 * <p>The dictionary element whose id equals the query's own id is skipped. Ids are compared by
 * value, not by reference, so the query is excluded even when the two id strings are distinct
 * objects.
 *
 * @param options       supplies the id key, output key and number of neighbours to keep
 * @param featureVector the query vector; it must carry a string feature under the id key
 * @return a feature vector with the neighbours' ids and kernel values, or an empty one
 */
@Override
public FeatureVector getKNearestNeighbors(KNearestNeighborsOptions options, FeatureVector featureVector) {
  FeatureVector result = new FeatureVector();
  Map<String, List<Double>> denseFeatures = featureVector.getDenseFeatures();
  if (denseFeatures == null) {
    return result;
  }
  final int numNearest = options.getNumNearest();
  PriorityQueue<SimpleEntry<String, Double>> pq = new PriorityQueue<>(numNearest + 1, new EntryComparator());
  String idKey = options.getIdKey();
  Map<String, Map<String, Double>> floatFeatures = new HashMap<>();
  String myId = featureVector.getStringFeatures().get(idKey).iterator().next();
  for (FeatureVector supportVector : dictionaryList) {
    Set<String> idSet = supportVector.getStringFeatures().get(idKey);
    String id = idSet.iterator().next();
    // Compare by value: '==' on strings only matches when both refer to the same object.
    if (java.util.Objects.equals(id, myId)) {
      continue;
    }
    // The kernel is only needed for entries that are not skipped.
    Double minKernel = FeatureVectorUtil.featureVectorMinKernel(featureVector, supportVector);
    pq.add(new SimpleEntry<String, Double>(id, minKernel));
    if (pq.size() > numNearest) {
      // Keep the queue bounded by dropping its current head.
      pq.poll();
    }
  }
  HashMap<String, Double> newFeature = new HashMap<>();
  while (pq.peek() != null) {
    SimpleEntry<String, Double> entry = pq.poll();
    newFeature.put(entry.getKey(), entry.getValue());
  }
  floatFeatures.put(options.getOutputKey(), newFeature);
  result.setFloatFeatures(floatFeatures);
  return result;
}
use of com.airbnb.aerosolve.core.FeatureVector in project aerosolve by airbnb.
the class LocalitySensitiveHashSparseFeatureDictionary method getKNearestNeighbors.
/**
 * Finds the nearest neighbours of {@code featureVector} among the candidates returned by
 * {@code getCandidates} and returns them as a new sparse (float) feature.
 *
 * <p>An empty {@link FeatureVector} is returned when the query has no string features or no
 * entry under {@code options.getFeatureKey()}. The hash table is built on first use (when
 * {@code haveLSH} is false). Otherwise the result holds a single float-feature family, keyed by
 * {@code options.getOutputKey()}, that maps the id of each retained candidate to its similarity.
 * At most {@code options.getNumNearest()} candidates are retained; whenever the queue grows
 * beyond that, its head (as ordered by {@code EntryComparator}) is discarded.
 *
 * <p>The candidate whose id equals the query's own id is skipped. Ids are compared by value, not
 * by reference, so the query is excluded even when the two id strings are distinct objects.
 *
 * @param options       supplies the feature key, id key, output key and number of neighbours
 * @param featureVector the query vector; it must carry a string feature under the id key
 * @return a feature vector with the neighbours' ids and similarities, or an empty one
 */
@Override
public FeatureVector getKNearestNeighbors(KNearestNeighborsOptions options, FeatureVector featureVector) {
  FeatureVector result = new FeatureVector();
  Map<String, Set<String>> stringFeatures = featureVector.getStringFeatures();
  if (stringFeatures == null) {
    return result;
  }
  String featureKey = options.getFeatureKey();
  Set<String> keys = stringFeatures.get(featureKey);
  if (keys == null) {
    return result;
  }
  if (!haveLSH) {
    buildHashTable(featureKey);
  }
  String idKey = options.getIdKey();
  final int numNearest = options.getNumNearest();
  PriorityQueue<SimpleEntry<String, Double>> pq = new PriorityQueue<>(numNearest + 1, new EntryComparator());
  Map<String, Map<String, Double>> floatFeatures = new HashMap<>();
  String myId = stringFeatures.get(idKey).iterator().next();
  Set<Integer> candidates = getCandidates(keys);
  for (Integer candidate : candidates) {
    FeatureVector supportVector = dictionaryList.get(candidate);
    Set<String> idSet = supportVector.getStringFeatures().get(idKey);
    String id = idSet.iterator().next();
    // Compare by value: '==' on strings only matches when both refer to the same object.
    if (java.util.Objects.equals(id, myId)) {
      continue;
    }
    // The similarity is only needed for candidates that are not skipped.
    double sim = similarity(featureVector, supportVector, featureKey);
    pq.add(new SimpleEntry<String, Double>(id, sim));
    if (pq.size() > numNearest) {
      // Keep the queue bounded by dropping its current head.
      pq.poll();
    }
  }
  HashMap<String, Double> newFeature = new HashMap<>();
  while (pq.peek() != null) {
    SimpleEntry<String, Double> entry = pq.poll();
    newFeature.put(entry.getKey(), entry.getValue());
  }
  floatFeatures.put(options.getOutputKey(), newFeature);
  result.setFloatFeatures(floatFeatures);
  return result;
}
use of com.airbnb.aerosolve.core.FeatureVector in project aerosolve by airbnb.
the class LowRankLinearModelTest method testLoad.
/**
 * Serializes a "low_rank_linear" model (one header record followed by four feature records),
 * loads it back through {@code ModelFactory.createFromReader} and checks the multiclass scores
 * produced for two feature vectors.
 */
@Test
public void testLoad() {
  CharArrayWriter charWriter = new CharArrayWriter();
  BufferedWriter writer = new BufferedWriter(charWriter);

  // Header: model type, label dictionary and a label embedding derived from the weight vectors.
  ModelHeader header = new ModelHeader();
  header.setModelType("low_rank_linear");
  header.setLabelDictionary(makeLabelDictionary());
  Map<String, FloatVector> labelWeightVector = makeLabelWeightVector();
  Map<String, java.util.List<Double>> labelEmbedding = new HashMap<>();
  for (Map.Entry<String, FloatVector> labelRepresentation : labelWeightVector.entrySet()) {
    float[] values = labelRepresentation.getValue().getValues();
    ArrayList<Double> arrayList = new ArrayList<>();
    // Only the first three components of each weight vector are copied.
    for (int i = 0; i < 3; i++) {
      arrayList.add((double) values[i]);
    }
    labelEmbedding.put(labelRepresentation.getKey(), arrayList);
  }
  header.setLabelEmbedding(labelEmbedding);
  header.setNumRecords(4);

  // All four feature records share the same weight vector (1, 0, 0).
  ArrayList<Double> ws = new ArrayList<>();
  ws.add(1.0);
  ws.add(0.0);
  ws.add(0.0);
  ModelRecord record1 = new ModelRecord();
  record1.setModelHeader(header);
  ModelRecord record2 = new ModelRecord();
  record2.setFeatureFamily("a");
  record2.setFeatureName("cat");
  record2.setWeightVector(ws);
  ModelRecord record3 = new ModelRecord();
  record3.setFeatureFamily("a");
  record3.setFeatureName("dog");
  record3.setWeightVector(ws);
  ModelRecord record4 = new ModelRecord();
  record4.setFeatureFamily("a");
  record4.setFeatureName("fish");
  record4.setWeightVector(ws);
  ModelRecord record5 = new ModelRecord();
  record5.setFeatureFamily("a");
  record5.setFeatureName("horse");
  record5.setWeightVector(ws);

  // Write one encoded record per line.
  try {
    writer.write(Util.encode(record1) + "\n");
    writer.write(Util.encode(record2) + "\n");
    writer.write(Util.encode(record3) + "\n");
    writer.write(Util.encode(record4) + "\n");
    writer.write(Util.encode(record5) + "\n");
    writer.close();
  } catch (IOException e) {
    // Fail the test but keep the underlying cause in the failure report.
    throw new AssertionError("Could not write", e);
  }
  String serialized = charWriter.toString();
  assertTrue(serialized.length() > 0);

  // Read the model back and score two feature vectors.
  StringReader strReader = new StringReader(serialized);
  BufferedReader reader = new BufferedReader(strReader);
  FeatureVector animalFv = makeFeatureVector("animal");
  FeatureVector colorFv = makeFeatureVector("color");
  try {
    Optional<AbstractModel> model = ModelFactory.createFromReader(reader);
    assertTrue(model.isPresent());
    ArrayList<MulticlassScoringResult> s1 = model.get().scoreItemMulticlass(animalFv);
    // assertEquals takes (expected, actual); this order keeps failure messages accurate.
    assertEquals(3, s1.size());
    // NOTE(review): the tolerance here is 3.0f, unlike the 1e-10f used for the other scores, so
    // this only checks that the score lies within [-3, 3]. Confirm whether an expected value of
    // 3.0f with a tight tolerance was intended.
    assertEquals(0.0f, s1.get(0).score, 3.0f);
    assertEquals(0.0f, s1.get(1).score, 1e-10f);
    assertEquals(0.0f, s1.get(2).score, 1e-10f);
    ArrayList<MulticlassScoringResult> s2 = model.get().scoreItemMulticlass(colorFv);
    assertEquals(3, s2.size());
    assertEquals(0.0f, s2.get(0).score, 1e-10f);
    assertEquals(0.0f, s2.get(1).score, 1e-10f);
    assertEquals(0.0f, s2.get(2).score, 1e-10f);
  } catch (IOException e) {
    // Fail the test but keep the underlying cause in the failure report.
    throw new AssertionError("Could not read", e);
  }
}
Aggregations