Search in sources:

Example 6 with ModelRecord

use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.

The class LinearModelTest, method testLoad.

/**
 * Serializes a one-weight linear model (header record plus a single feature weight of 0.9),
 * loads it back through {@code ModelFactory.createFromReader}, and checks that the score of the
 * fixture feature vector lies strictly between 0.89 and 0.91.
 */
@Test
public void testLoad() {
    // Header record: announces a "linear" model followed by exactly one weight record.
    ModelHeader modelHeader = new ModelHeader();
    modelHeader.setModelType("linear");
    modelHeader.setNumRecords(1);
    ModelRecord headerRecord = new ModelRecord();
    headerRecord.setModelHeader(modelHeader);

    // The only weight in the model.
    ModelRecord weightRecord = new ModelRecord();
    weightRecord.setFeatureFamily("string_feature");
    weightRecord.setFeatureName("bbb");
    weightRecord.setFeatureWeight(0.9f);

    // Write one encoded record per line into an in-memory buffer.
    CharArrayWriter sink = new CharArrayWriter();
    BufferedWriter out = new BufferedWriter(sink);
    try {
        out.write(Util.encode(headerRecord) + "\n");
        out.write(Util.encode(weightRecord) + "\n");
        out.close();
    } catch (IOException e) {
        assertTrue("Could not write", false);
    }
    String encodedModel = sink.toString();
    assertTrue(!encodedModel.isEmpty());

    // Read the model back from the serialized text and score the fixture vector.
    BufferedReader in = new BufferedReader(new StringReader(encodedModel));
    FeatureVector featureVector = makeFeatureVector();
    try {
        Optional<AbstractModel> loaded = ModelFactory.createFromReader(in);
        assertTrue(loaded.isPresent());
        float score = loaded.get().scoreItem(featureVector);
        assertTrue(score > 0.89f);
        assertTrue(score < 0.91f);
    } catch (IOException e) {
        assertTrue("Could not read", false);
    }
}
Also used : FeatureVector(com.airbnb.aerosolve.core.FeatureVector) ModelHeader(com.airbnb.aerosolve.core.ModelHeader) ModelRecord(com.airbnb.aerosolve.core.ModelRecord) Test(org.junit.Test)

Example 7 with ModelRecord

use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.

The class LowRankLinearModelTest, method testLoad.

/**
 * Serializes a low-rank linear model (header plus four feature records that all share the
 * weight vector (1, 0, 0)), loads it back through {@code ModelFactory.createFromReader}, and
 * checks the multiclass scores produced for the "animal" and "color" fixture vectors.
 */
@Test
public void testLoad() {
    CharArrayWriter charWriter = new CharArrayWriter();
    BufferedWriter writer = new BufferedWriter(charWriter);
    ModelHeader header = new ModelHeader();
    header.setModelType("low_rank_linear");
    header.setLabelDictionary(makeLabelDictionary());
    // Convert each label's float weights into the List<Double> form the header stores.
    Map<String, FloatVector> labelWeightVector = makeLabelWeightVector();
    Map<String, java.util.List<Double>> labelEmbedding = new HashMap<>();
    for (Map.Entry<String, FloatVector> labelRepresentation : labelWeightVector.entrySet()) {
        float[] values = labelRepresentation.getValue().getValues();
        ArrayList<Double> arrayList = new ArrayList<>();
        // Only the first three components are copied, matching the 3-element weight vectors below.
        for (int i = 0; i < 3; i++) {
            arrayList.add((double) values[i]);
        }
        labelEmbedding.put(labelRepresentation.getKey(), arrayList);
    }
    header.setLabelEmbedding(labelEmbedding);
    // Four feature records (cat, dog, fish, horse) follow the header record.
    header.setNumRecords(4);
    ArrayList<Double> ws = new ArrayList<>();
    ws.add(1.0);
    ws.add(0.0);
    ws.add(0.0);
    ModelRecord record1 = new ModelRecord();
    record1.setModelHeader(header);
    ModelRecord record2 = new ModelRecord();
    record2.setFeatureFamily("a");
    record2.setFeatureName("cat");
    record2.setWeightVector(ws);
    ModelRecord record3 = new ModelRecord();
    record3.setFeatureFamily("a");
    record3.setFeatureName("dog");
    record3.setWeightVector(ws);
    ModelRecord record4 = new ModelRecord();
    record4.setFeatureFamily("a");
    record4.setFeatureName("fish");
    record4.setWeightVector(ws);
    ModelRecord record5 = new ModelRecord();
    record5.setFeatureFamily("a");
    record5.setFeatureName("horse");
    record5.setWeightVector(ws);
    try {
        // One encoded record per line, header first.
        writer.write(Util.encode(record1) + "\n");
        writer.write(Util.encode(record2) + "\n");
        writer.write(Util.encode(record3) + "\n");
        writer.write(Util.encode(record4) + "\n");
        writer.write(Util.encode(record5) + "\n");
        writer.close();
    } catch (IOException e) {
        assertTrue("Could not write", false);
    }
    String serialized = charWriter.toString();
    assertTrue(serialized.length() > 0);
    StringReader strReader = new StringReader(serialized);
    BufferedReader reader = new BufferedReader(strReader);
    FeatureVector animalFv = makeFeatureVector("animal");
    FeatureVector colorFv = makeFeatureVector("color");
    try {
        Optional<AbstractModel> model = ModelFactory.createFromReader(reader);
        assertTrue(model.isPresent());
        ArrayList<MulticlassScoringResult> s1 = model.get().scoreItemMulticlass(animalFv);
        // assertEquals takes (expected, actual); the expected value goes first so that a
        // failure message reports the two values the right way round.
        assertEquals(3, s1.size());
        // NOTE(review): the delta here is 3.0f, so any score in [-3, 3] passes. This may have
        // been meant as an expected value of 3.0f with a small delta -- confirm the intent.
        assertEquals(0.0f, s1.get(0).score, 3.0f);
        assertEquals(0.0f, s1.get(1).score, 1e-10f);
        assertEquals(0.0f, s1.get(2).score, 1e-10f);
        ArrayList<MulticlassScoringResult> s2 = model.get().scoreItemMulticlass(colorFv);
        assertEquals(3, s2.size());
        assertEquals(0.0f, s2.get(0).score, 1e-10f);
        assertEquals(0.0f, s2.get(1).score, 1e-10f);
        assertEquals(0.0f, s2.get(2).score, 1e-10f);
    } catch (IOException e) {
        assertTrue("Could not read", false);
    }
}
Also used : FeatureVector(com.airbnb.aerosolve.core.FeatureVector) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) FloatVector(com.airbnb.aerosolve.core.util.FloatVector) MulticlassScoringResult(com.airbnb.aerosolve.core.MulticlassScoringResult) CharArrayWriter(java.io.CharArrayWriter) BufferedWriter(java.io.BufferedWriter) ModelHeader(com.airbnb.aerosolve.core.ModelHeader) StringReader(java.io.StringReader) IOException(java.io.IOException) BufferedReader(java.io.BufferedReader) ModelRecord(com.airbnb.aerosolve.core.ModelRecord) Map(java.util.Map) Test(org.junit.Test)

Example 8 with ModelRecord

use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.

The class KernelModel, method save.

/**
 * Writes this kernel model to {@code writer}: first an encoded header record, then one encoded
 * record per support vector, each on its own line. The writer is flushed at the end but is not
 * closed here.
 *
 * @param writer destination for the encoded records
 * @throws IOException if writing to or flushing {@code writer} fails
 */
@Override
public void save(BufferedWriter writer) throws IOException {
    // The header carries the model type, the dictionary and the number of records that follow.
    ModelHeader modelHeader = new ModelHeader();
    modelHeader.setModelType("kernel");
    modelHeader.setDictionary(dictionary.getDictionary());
    modelHeader.setNumRecords(supportVectors.size());

    ModelRecord headerRecord = new ModelRecord();
    headerRecord.setModelHeader(modelHeader);
    writer.write(Util.encode(headerRecord));
    writer.newLine();

    // One line per support vector, in iteration order.
    for (SupportVector supportVector : supportVectors) {
        ModelRecord vectorRecord = supportVector.toModelRecord();
        writer.write(Util.encode(vectorRecord));
        writer.newLine();
    }
    writer.flush();
}
Also used : ModelHeader(com.airbnb.aerosolve.core.ModelHeader) ModelRecord(com.airbnb.aerosolve.core.ModelRecord) SupportVector(com.airbnb.aerosolve.core.util.SupportVector)

Example 9 with ModelRecord

use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.

The class LinearModel, method save.

/**
 * Writes this linear model to {@code writer}: first an encoded header record (model type,
 * slope, offset and the number of weight records), then one encoded record per
 * (feature family, feature name, weight) entry, each on its own line. The writer is flushed at
 * the end but is not closed here.
 *
 * @param writer destination for the encoded records
 * @throws IOException if writing to or flushing {@code writer} fails
 */
public void save(BufferedWriter writer) throws IOException {
    ModelHeader header = new ModelHeader();
    header.setModelType("linear");
    header.setSlope(slope);
    header.setOffset(offset);
    // The header must announce how many weight records follow: the sum of the sizes of the
    // per-family maps. Map.size() gives this directly, with no need to walk every entry.
    long count = 0;
    for (Map<String, Float> familyWeights : weights.values()) {
        count += familyWeights.size();
    }
    header.setNumRecords(count);
    ModelRecord headerRec = new ModelRecord();
    headerRec.setModelHeader(header);
    writer.write(Util.encode(headerRec));
    writer.newLine();
    // One line per individual feature weight.
    for (Map.Entry<String, Map<String, Float>> familyMap : weights.entrySet()) {
        for (Map.Entry<String, Float> feature : familyMap.getValue().entrySet()) {
            ModelRecord record = new ModelRecord();
            record.setFeatureFamily(familyMap.getKey());
            record.setFeatureName(feature.getKey());
            record.setFeatureWeight(feature.getValue());
            writer.write(Util.encode(record));
            writer.newLine();
        }
    }
    writer.flush();
}
Also used : ModelHeader(com.airbnb.aerosolve.core.ModelHeader) ModelRecord(com.airbnb.aerosolve.core.ModelRecord)

Example 10 with ModelRecord

use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.

The class MaxoutModel, method loadInternal.

/**
 * Rebuilds the per-family weight table of this model from {@code reader}.
 *
 * <p>Reads {@code header.getNumRecords()} lines, decodes each into a {@code ModelRecord}, and
 * stores its scale and its first {@code numHidden} weights under (feature family, feature
 * name). Afterwards the entry stored under family {@code "$SPECIAL"} and name {@code "$BIAS"}
 * is assigned to {@code bias}.
 *
 * @param header model header supplying the record count and the number of hidden units
 * @param reader source positioned at the first record line after the header
 * @throws IOException if reading fails, if the stream ends before all announced records were
 *     read, or if the {@code "$SPECIAL"}/{@code "$BIAS"} entry is missing
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
    long rows = header.getNumRecords();
    numHidden = header.getNumHidden();
    weightVector = new HashMap<>();
    for (long i = 0; i < rows; i++) {
        String line = reader.readLine();
        // readLine() returns null at end of stream; report truncation explicitly rather than
        // handing null to the decoder.
        if (line == null) {
            throw new IOException(
                "Unexpected end of model data: header announced " + rows
                    + " records but only " + i + " were found");
        }
        ModelRecord record = Util.decodeModel(line);
        String family = record.getFeatureFamily();
        String name = record.getFeatureName();
        // Create the inner map for a family the first time it is seen.
        Map<String, WeightVector> inner = weightVector.get(family);
        if (inner == null) {
            inner = new HashMap<>();
            weightVector.put(family, inner);
        }
        WeightVector vec = new WeightVector();
        vec.scale = (float) record.getScale();
        vec.weights = new FloatVector(numHidden);
        for (int j = 0; j < numHidden; j++) {
            vec.weights.values[j] = record.getWeightVector().get(j).floatValue();
        }
        inner.put(name, vec);
    }
    // These were `assert` statements, which are disabled unless the JVM runs with -ea; explicit
    // checks make a model without a bias fail at load time instead of later.
    Map<String, WeightVector> special = weightVector.get("$SPECIAL");
    if (special == null) {
        throw new IOException("Model data contains no \"$SPECIAL\" feature family");
    }
    bias = special.get("$BIAS");
    if (bias == null) {
        throw new IOException("Model data contains no \"$BIAS\" entry in family \"$SPECIAL\"");
    }
}
Also used : ModelRecord(com.airbnb.aerosolve.core.ModelRecord) FloatVector(com.airbnb.aerosolve.core.util.FloatVector)

Aggregations

ModelRecord (com.airbnb.aerosolve.core.ModelRecord)40 ModelHeader (com.airbnb.aerosolve.core.ModelHeader)14 Test (org.junit.Test)11 ArrayList (java.util.ArrayList)10 FloatVector (com.airbnb.aerosolve.core.util.FloatVector)8 Map (java.util.Map)5 FeatureVector (com.airbnb.aerosolve.core.FeatureVector)4 BufferedReader (java.io.BufferedReader)3 BufferedWriter (java.io.BufferedWriter)3 CharArrayWriter (java.io.CharArrayWriter)3 IOException (java.io.IOException)3 StringReader (java.io.StringReader)3 LabelDictionaryEntry (com.airbnb.aerosolve.core.LabelDictionaryEntry)2 AbstractFunction (com.airbnb.aerosolve.core.function.AbstractFunction)2 Function (com.airbnb.aerosolve.core.function.Function)2 NDTreeModelTest (com.airbnb.aerosolve.core.models.NDTreeModelTest)2 SupportVector (com.airbnb.aerosolve.core.util.SupportVector)2 HashMap (java.util.HashMap)2 DebugScoreRecord (com.airbnb.aerosolve.core.DebugScoreRecord)1 MulticlassScoringResult (com.airbnb.aerosolve.core.MulticlassScoringResult)1