Usage of `com.airbnb.aerosolve.core.MulticlassScoringResult` in the airbnb/aerosolve project.
Class `FullRankLinearModel`, method `scoreItemMulticlass`:
/**
 * Scores {@code combinedItem} against every label in this model's label dictionary.
 *
 * <p>The item's features are flattened with {@code Util.flattenFeature}, passed through
 * {@code scoreFlatFeature}, and the resulting vector is read position by position: dictionary
 * entry {@code k} is paired with {@code values[k]} of that vector.
 *
 * @param combinedItem the feature vector to score
 * @return a new list with one result per label dictionary entry, in dictionary order
 */
public ArrayList<MulticlassScoringResult> scoreItemMulticlass(FeatureVector combinedItem) {
  final Map<String, Map<String, Double>> flattened = Util.flattenFeature(combinedItem);
  final FloatVector labelScores = scoreFlatFeature(flattened);
  // NOTE(review): assumes labelScores.values has at least labelDictionary.size() entries;
  // otherwise the indexed read below throws ArrayIndexOutOfBoundsException.
  final int labelCount = labelDictionary.size();
  final ArrayList<MulticlassScoringResult> scored = new ArrayList<>();
  for (int labelIdx = 0; labelIdx < labelCount; ++labelIdx) {
    final MulticlassScoringResult entry = new MulticlassScoringResult();
    entry.setLabel(labelDictionary.get(labelIdx).getLabel());
    entry.setScore(labelScores.values[labelIdx]);
    scored.add(entry);
  }
  return scored;
}
Usage of `com.airbnb.aerosolve.core.MulticlassScoringResult` in the airbnb/aerosolve project.
Class `LowRankLinearModelTest`, method `testScoreNonEmptyFeature`:
/**
 * Scores the feature vectors built by {@code makeFeatureVector} for "animal", "color" and
 * "fruit" with the fixture low-rank model. Checks that each produces three per-label results
 * and that every score stays within the asserted bounds.
 */
@Test
public void testScoreNonEmptyFeature() {
  FeatureVector animalFv = makeFeatureVector("animal");
  FeatureVector colorFv = makeFeatureVector("color");
  FeatureVector fruitFv = makeFeatureVector("fruit");
  LowRankLinearModel model = makeLowRankLinearModel();

  // JUnit's assertEquals is (expected, actual); keeping that order makes failure messages correct.
  ArrayList<MulticlassScoringResult> s1 = model.scoreItemMulticlass(animalFv);
  assertEquals(3, s1.size());
  // NOTE(review): the third argument is the tolerance, so this only checks |score| <= 3.0f.
  // Confirm whether 3.0f was meant to be the expected score instead.
  assertEquals(0.0f, s1.get(0).score, 3.0f);
  assertEquals(0.0f, s1.get(1).score, 1e-10f);
  assertEquals(0.0f, s1.get(2).score, 1e-10f);

  ArrayList<MulticlassScoringResult> s2 = model.scoreItemMulticlass(colorFv);
  assertEquals(3, s2.size());
  assertEquals(0.0f, s2.get(0).score, 1e-10f);
  // NOTE(review): 6.0f is the tolerance here (|score| <= 6.0f); confirm intended expected value.
  assertEquals(0.0f, s2.get(1).score, 6.0f);
  assertEquals(0.0f, s2.get(2).score, 1e-10f);

  ArrayList<MulticlassScoringResult> s3 = model.scoreItemMulticlass(fruitFv);
  assertEquals(3, s3.size());
  assertEquals(0.0f, s3.get(0).score, 1e-10f);
  assertEquals(0.0f, s3.get(1).score, 1e-10f);
  // NOTE(review): 4.0f is the tolerance here (|score| <= 4.0f); confirm intended expected value.
  assertEquals(0.0f, s3.get(2).score, 4.0f);
}
Usage of `com.airbnb.aerosolve.core.MulticlassScoringResult` in the airbnb/aerosolve project.
Class `LowRankLinearModelTest`, method `testLoad`:
/**
 * Serializes a "low_rank_linear" model and reloads it through {@code ModelFactory}, then checks
 * the multiclass scores of the reloaded model.
 *
 * <p>The serialized model is one header record followed by one weight record each for the
 * features "cat", "dog", "fish" and "horse" in family "a".
 *
 * <p>I/O failures propagate as {@link IOException}, so the test fails with the original cause
 * instead of a bare assertion message.
 */
@Test
public void testLoad() throws IOException {
  ModelHeader header = new ModelHeader();
  header.setModelType("low_rank_linear");
  header.setLabelDictionary(makeLabelDictionary());

  // Convert each label's FloatVector into the List<Double> form the header stores.
  Map<String, FloatVector> labelWeightVector = makeLabelWeightVector();
  Map<String, java.util.List<Double>> labelEmbedding = new HashMap<>();
  for (Map.Entry<String, FloatVector> labelRepresentation : labelWeightVector.entrySet()) {
    float[] values = labelRepresentation.getValue().getValues();
    ArrayList<Double> embedding = new ArrayList<>();
    // Only the first 3 components are copied, matching the 3-element weight vector below.
    for (int i = 0; i < 3; i++) {
      embedding.add((double) values[i]);
    }
    labelEmbedding.put(labelRepresentation.getKey(), embedding);
  }
  header.setLabelEmbedding(labelEmbedding);
  header.setNumRecords(4);

  ArrayList<Double> ws = new ArrayList<>();
  ws.add(1.0);
  ws.add(0.0);
  ws.add(0.0);

  ModelRecord headerRecord = new ModelRecord();
  headerRecord.setModelHeader(header);

  CharArrayWriter charWriter = new CharArrayWriter();
  try (BufferedWriter writer = new BufferedWriter(charWriter)) {
    writer.write(Util.encode(headerRecord) + "\n");
    // One record per feature, written in the same order as the original fixture.
    for (String featureName : new String[] {"cat", "dog", "fish", "horse"}) {
      ModelRecord record = new ModelRecord();
      record.setFeatureFamily("a");
      record.setFeatureName(featureName);
      record.setWeightVector(ws);
      writer.write(Util.encode(record) + "\n");
    }
  }

  String serialized = charWriter.toString();
  assertTrue(serialized.length() > 0);

  FeatureVector animalFv = makeFeatureVector("animal");
  FeatureVector colorFv = makeFeatureVector("color");
  try (BufferedReader reader = new BufferedReader(new StringReader(serialized))) {
    Optional<AbstractModel> model = ModelFactory.createFromReader(reader);
    assertTrue(model.isPresent());

    // JUnit's assertEquals is (expected, actual).
    ArrayList<MulticlassScoringResult> s1 = model.get().scoreItemMulticlass(animalFv);
    assertEquals(3, s1.size());
    // NOTE(review): 3.0f is the tolerance, so this only checks |score| <= 3.0f; confirm
    // whether it was meant to be the expected score.
    assertEquals(0.0f, s1.get(0).score, 3.0f);
    assertEquals(0.0f, s1.get(1).score, 1e-10f);
    assertEquals(0.0f, s1.get(2).score, 1e-10f);

    ArrayList<MulticlassScoringResult> s2 = model.get().scoreItemMulticlass(colorFv);
    assertEquals(3, s2.size());
    assertEquals(0.0f, s2.get(0).score, 1e-10f);
    assertEquals(0.0f, s2.get(1).score, 1e-10f);
    assertEquals(0.0f, s2.get(2).score, 1e-10f);
  }
}
Usage of `com.airbnb.aerosolve.core.MulticlassScoringResult` in the airbnb/aerosolve project.
Class `LowRankLinearModel`, method `scoreItemMulticlass`:
/**
 * Produces one scoring result per label in the label dictionary for {@code combinedItem}.
 *
 * <p>Features are flattened with {@code Util.flattenFeature} and scored by
 * {@code scoreFlatFeature}. The score for dictionary entry {@code k} is {@code values[k]} of
 * the returned vector.
 *
 * @param combinedItem the feature vector to score
 * @return a new list of results, ordered like the label dictionary
 */
public ArrayList<MulticlassScoringResult> scoreItemMulticlass(FeatureVector combinedItem) {
  final Map<String, Map<String, Double>> flatItem = Util.flattenFeature(combinedItem);
  final FloatVector perLabel = scoreFlatFeature(flatItem);
  final ArrayList<MulticlassScoringResult> ranked = new ArrayList<>();
  // NOTE(review): assumes perLabel.values covers every dictionary index.
  int pos = 0;
  while (pos < labelDictionary.size()) {
    final MulticlassScoringResult res = new MulticlassScoringResult();
    res.setLabel(labelDictionary.get(pos).getLabel());
    res.setScore(perLabel.values[pos]);
    ranked.add(res);
    pos++;
  }
  return ranked;
}
Usage of `com.airbnb.aerosolve.core.MulticlassScoringResult` in the airbnb/aerosolve project.
Class `LowRankLinearModelTest`, method `testScoreEmptyFeature`:
/**
 * Scores an empty {@code FeatureVector} with the fixture low-rank model. Checks that it
 * returns three per-label results, each with a score of 0.
 */
@Test
public void testScoreEmptyFeature() {
  FeatureVector featureVector = new FeatureVector();
  LowRankLinearModel model = makeLowRankLinearModel();
  ArrayList<MulticlassScoringResult> score = model.scoreItemMulticlass(featureVector);
  // JUnit's assertEquals is (expected, actual); the original had the size arguments swapped.
  assertEquals(3, score.size());
  assertEquals(0.0f, score.get(0).score, 1e-10f);
  assertEquals(0.0f, score.get(1).score, 1e-10f);
  assertEquals(0.0f, score.get(2).score, 1e-10f);
}
Aggregations