Search in sources :

Example 1 with MulticlassScoringResult

use of com.airbnb.aerosolve.core.MulticlassScoringResult in project aerosolve by airbnb.

the class FullRankLinearModel method scoreItemMulticlass.

/**
 * Scores a combined feature vector against every entry of the label dictionary.
 *
 * <p>The item is flattened via {@code Util.flattenFeature}, scored via
 * {@code scoreFlatFeature}, and element {@code i} of the resulting score vector is paired
 * with the label stored at index {@code i} of {@code labelDictionary}.
 *
 * @param combinedItem the feature vector to score
 * @return one result per label dictionary entry, in dictionary order
 */
public ArrayList<MulticlassScoringResult> scoreItemMulticlass(FeatureVector combinedItem) {
    ArrayList<MulticlassScoringResult> scored = new ArrayList<>();
    Map<String, Map<String, Double>> flattened = Util.flattenFeature(combinedItem);
    FloatVector classScores = scoreFlatFeature(flattened);
    int labelIndex = 0;
    while (labelIndex < labelDictionary.size()) {
        MulticlassScoringResult entry = new MulticlassScoringResult();
        entry.setLabel(labelDictionary.get(labelIndex).getLabel());
        entry.setScore(classScores.values[labelIndex]);
        scored.add(entry);
        labelIndex++;
    }
    return scored;
}
Also used : MulticlassScoringResult(com.airbnb.aerosolve.core.MulticlassScoringResult) FloatVector(com.airbnb.aerosolve.core.util.FloatVector)

Example 2 with MulticlassScoringResult

use of com.airbnb.aerosolve.core.MulticlassScoringResult in project aerosolve by airbnb.

the class LowRankLinearModelTest method testScoreNonEmptyFeature.

/**
 * Scores the "animal", "color" and "fruit" feature vectors with a freshly built low-rank
 * linear model and checks that each call yields three per-label results.
 */
@Test
public void testScoreNonEmptyFeature() {
    FeatureVector animalFv = makeFeatureVector("animal");
    FeatureVector colorFv = makeFeatureVector("color");
    FeatureVector fruitFv = makeFeatureVector("fruit");
    LowRankLinearModel model = makeLowRankLinearModel();

    ArrayList<MulticlassScoringResult> s1 = model.scoreItemMulticlass(animalFv);
    // JUnit's contract is assertEquals(expected, actual); keep that order so a failure
    // reports the values the right way round.
    assertEquals(3, s1.size());
    // NOTE(review): expected 0.0f with delta 3.0f accepts any score in [-3, 3]. This looks
    // like the expected value ended up in the delta slot; confirm against the model fixture
    // before tightening it to assertEquals(3.0f, ..., 1e-10f).
    assertEquals(0.0f, s1.get(0).score, 3.0f);
    assertEquals(0.0f, s1.get(1).score, 1e-10f);
    assertEquals(0.0f, s1.get(2).score, 1e-10f);

    ArrayList<MulticlassScoringResult> s2 = model.scoreItemMulticlass(colorFv);
    assertEquals(3, s2.size());
    assertEquals(0.0f, s2.get(0).score, 1e-10f);
    // NOTE(review): same concern as above; delta 6.0f is presumably the intended expected value.
    assertEquals(0.0f, s2.get(1).score, 6.0f);
    assertEquals(0.0f, s2.get(2).score, 1e-10f);

    ArrayList<MulticlassScoringResult> s3 = model.scoreItemMulticlass(fruitFv);
    assertEquals(3, s3.size());
    assertEquals(0.0f, s3.get(0).score, 1e-10f);
    assertEquals(0.0f, s3.get(1).score, 1e-10f);
    // NOTE(review): same concern as above; delta 4.0f is presumably the intended expected value.
    assertEquals(0.0f, s3.get(2).score, 4.0f);
}
Also used : FeatureVector(com.airbnb.aerosolve.core.FeatureVector) MulticlassScoringResult(com.airbnb.aerosolve.core.MulticlassScoringResult) Test(org.junit.Test)

Example 3 with MulticlassScoringResult

use of com.airbnb.aerosolve.core.MulticlassScoringResult in project aerosolve by airbnb.

the class LowRankLinearModelTest method testLoad.

/**
 * Serializes a "low_rank_linear" model (one header record followed by four weight records)
 * into an in-memory buffer, loads it back through {@code ModelFactory.createFromReader},
 * and checks the multiclass scores produced by the loaded model.
 */
@Test
public void testLoad() {
    ModelHeader header = new ModelHeader();
    header.setModelType("low_rank_linear");
    header.setLabelDictionary(makeLabelDictionary());

    // Convert each label's float weight vector into the List<Double> form the header stores.
    Map<String, FloatVector> labelWeightVector = makeLabelWeightVector();
    Map<String, java.util.List<Double>> labelEmbedding = new HashMap<>();
    for (Map.Entry<String, FloatVector> labelRepresentation : labelWeightVector.entrySet()) {
        float[] values = labelRepresentation.getValue().getValues();
        ArrayList<Double> arrayList = new ArrayList<>();
        for (int i = 0; i < 3; i++) {
            arrayList.add((double) values[i]);
        }
        labelEmbedding.put(labelRepresentation.getKey(), arrayList);
    }
    header.setLabelEmbedding(labelEmbedding);
    header.setNumRecords(4);

    // Every feature record shares the same weight vector.
    ArrayList<Double> ws = new ArrayList<>();
    ws.add(1.0);
    ws.add(0.0);
    ws.add(0.0);

    ModelRecord headerRecord = new ModelRecord();
    headerRecord.setModelHeader(header);

    CharArrayWriter charWriter = new CharArrayWriter();
    // try-with-resources flushes and closes the writer even if a write fails part-way.
    try (BufferedWriter writer = new BufferedWriter(charWriter)) {
        writer.write(Util.encode(headerRecord) + "\n");
        for (String featureName : new String[] {"cat", "dog", "fish", "horse"}) {
            ModelRecord weightRecord = new ModelRecord();
            weightRecord.setFeatureFamily("a");
            weightRecord.setFeatureName(featureName);
            weightRecord.setWeightVector(ws);
            writer.write(Util.encode(weightRecord) + "\n");
        }
    } catch (IOException e) {
        // Fail the test but keep the underlying cause in the stack trace.
        throw new AssertionError("Could not write", e);
    }

    String serialized = charWriter.toString();
    assertTrue(serialized.length() > 0);
    StringReader strReader = new StringReader(serialized);
    BufferedReader reader = new BufferedReader(strReader);
    FeatureVector animalFv = makeFeatureVector("animal");
    FeatureVector colorFv = makeFeatureVector("color");
    try {
        Optional<AbstractModel> model = ModelFactory.createFromReader(reader);
        assertTrue(model.isPresent());

        ArrayList<MulticlassScoringResult> s1 = model.get().scoreItemMulticlass(animalFv);
        // assertEquals takes (expected, actual); keep that order for correct failure messages.
        assertEquals(3, s1.size());
        // NOTE(review): expected 0.0f with delta 3.0f accepts any score in [-3, 3]. This looks
        // like the expected value ended up in the delta slot; confirm before tightening.
        assertEquals(0.0f, s1.get(0).score, 3.0f);
        assertEquals(0.0f, s1.get(1).score, 1e-10f);
        assertEquals(0.0f, s1.get(2).score, 1e-10f);

        ArrayList<MulticlassScoringResult> s2 = model.get().scoreItemMulticlass(colorFv);
        assertEquals(3, s2.size());
        assertEquals(0.0f, s2.get(0).score, 1e-10f);
        assertEquals(0.0f, s2.get(1).score, 1e-10f);
        assertEquals(0.0f, s2.get(2).score, 1e-10f);
    } catch (IOException e) {
        // Fail the test but keep the underlying cause in the stack trace.
        throw new AssertionError("Could not read", e);
    }
}
Also used : FeatureVector(com.airbnb.aerosolve.core.FeatureVector) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) FloatVector(com.airbnb.aerosolve.core.util.FloatVector) MulticlassScoringResult(com.airbnb.aerosolve.core.MulticlassScoringResult) CharArrayWriter(java.io.CharArrayWriter) BufferedWriter(java.io.BufferedWriter) ModelHeader(com.airbnb.aerosolve.core.ModelHeader) StringReader(java.io.StringReader) IOException(java.io.IOException) BufferedReader(java.io.BufferedReader) ModelRecord(com.airbnb.aerosolve.core.ModelRecord) Map(java.util.Map) Test(org.junit.Test)

Example 4 with MulticlassScoringResult

use of com.airbnb.aerosolve.core.MulticlassScoringResult in project aerosolve by airbnb.

the class LowRankLinearModel method scoreItemMulticlass.

/**
 * Produces a score for each label known to this model.
 *
 * <p>The input is flattened with {@code Util.flattenFeature} and handed to
 * {@code scoreFlatFeature}; the i-th value of the returned vector becomes the score of the
 * i-th {@code labelDictionary} entry.
 *
 * @param combinedItem the feature vector to score
 * @return a list holding one label/score pair per label dictionary entry, in dictionary order
 */
public ArrayList<MulticlassScoringResult> scoreItemMulticlass(FeatureVector combinedItem) {
    final ArrayList<MulticlassScoringResult> output = new ArrayList<>();
    final Map<String, Map<String, Double>> flat = Util.flattenFeature(combinedItem);
    final FloatVector labelScores = scoreFlatFeature(flat);
    for (int labelIdx = 0; labelIdx < labelDictionary.size(); ++labelIdx) {
        final MulticlassScoringResult labelResult = new MulticlassScoringResult();
        labelResult.setLabel(labelDictionary.get(labelIdx).getLabel());
        labelResult.setScore(labelScores.values[labelIdx]);
        output.add(labelResult);
    }
    return output;
}
Also used : MulticlassScoringResult(com.airbnb.aerosolve.core.MulticlassScoringResult) FloatVector(com.airbnb.aerosolve.core.util.FloatVector)

Example 5 with MulticlassScoringResult

use of com.airbnb.aerosolve.core.MulticlassScoringResult in project aerosolve by airbnb.

the class LowRankLinearModelTest method testScoreEmptyFeature.

/**
 * Scores an empty feature vector and checks that the model still returns three results,
 * each with a score of zero (within 1e-10).
 */
@Test
public void testScoreEmptyFeature() {
    FeatureVector featureVector = new FeatureVector();
    LowRankLinearModel model = makeLowRankLinearModel();
    ArrayList<MulticlassScoringResult> score = model.scoreItemMulticlass(featureVector);
    // assertEquals takes (expected, actual); keep that order for correct failure messages.
    assertEquals(3, score.size());
    for (int i = 0; i < 3; i++) {
        assertEquals(0.0f, score.get(i).score, 1e-10f);
    }
}
Also used : FeatureVector(com.airbnb.aerosolve.core.FeatureVector) MulticlassScoringResult(com.airbnb.aerosolve.core.MulticlassScoringResult) Test(org.junit.Test)

Aggregations

MulticlassScoringResult (com.airbnb.aerosolve.core.MulticlassScoringResult)5 FeatureVector (com.airbnb.aerosolve.core.FeatureVector)3 FloatVector (com.airbnb.aerosolve.core.util.FloatVector)3 Test (org.junit.Test)3 ModelHeader (com.airbnb.aerosolve.core.ModelHeader)1 ModelRecord (com.airbnb.aerosolve.core.ModelRecord)1 BufferedReader (java.io.BufferedReader)1 BufferedWriter (java.io.BufferedWriter)1 CharArrayWriter (java.io.CharArrayWriter)1 IOException (java.io.IOException)1 StringReader (java.io.StringReader)1 ArrayList (java.util.ArrayList)1 HashMap (java.util.HashMap)1 Map (java.util.Map)1