Usage of com.airbnb.aerosolve.core.ModelRecord in the aerosolve project by Airbnb.
Class LinearModelTest, method testLoad:
/**
 * Round-trips a single-feature linear model through its line-oriented text encoding.
 *
 * <p>Writes a header record (model type {@code "linear"}, one record) followed by one feature
 * record ({@code string_feature:bbb} with weight 0.9). It then reloads the text with
 * {@code ModelFactory.createFromReader} and checks that scoring {@code makeFeatureVector()}
 * yields a score strictly between 0.89 and 0.91.
 *
 * <p>I/O failures propagate out of the test instead of being turned into a bare
 * {@code assertTrue(false)}, so a failing run reports the underlying exception and stack trace.
 */
@Test
public void testLoad() throws IOException {
  ModelHeader header = new ModelHeader();
  header.setModelType("linear");
  // Exactly one feature record follows the header record.
  header.setNumRecords(1);
  ModelRecord record1 = new ModelRecord();
  record1.setModelHeader(header);
  ModelRecord record2 = new ModelRecord();
  record2.setFeatureFamily("string_feature");
  record2.setFeatureName("bbb");
  record2.setFeatureWeight(0.9f);

  CharArrayWriter charWriter = new CharArrayWriter();
  // try-with-resources flushes and closes the buffered writer even if a write throws.
  try (BufferedWriter writer = new BufferedWriter(charWriter)) {
    writer.write(Util.encode(record1) + "\n");
    writer.write(Util.encode(record2) + "\n");
  }
  String serialized = charWriter.toString();
  assertTrue(serialized.length() > 0);

  BufferedReader reader = new BufferedReader(new StringReader(serialized));
  FeatureVector featureVector = makeFeatureVector();
  Optional<AbstractModel> model = ModelFactory.createFromReader(reader);
  assertTrue(model.isPresent());
  float score = model.get().scoreItem(featureVector);
  // Expect roughly the stored weight 0.9; the bounds give some float tolerance.
  assertTrue(score > 0.89f);
  assertTrue(score < 0.91f);
}
Usage of com.airbnb.aerosolve.core.ModelRecord in the aerosolve project by Airbnb.
Class LowRankLinearModelTest, method testLoad:
/**
 * Round-trips a low-rank linear model through its line-oriented text encoding.
 *
 * <p>The header carries {@code makeLabelDictionary()} and, for every label in
 * {@code makeLabelWeightVector()}, an embedding made of the first three components of that
 * label's vector. Four feature records in family {@code "a"} (cat, dog, fish, horse) all share
 * the weight vector {@code [1, 0, 0]}. After reloading via
 * {@code ModelFactory.createFromReader}, the multiclass scores for the "animal" and "color"
 * feature vectors are checked.
 *
 * <p>I/O failures propagate out of the test so a failing run reports the real exception.
 */
@Test
public void testLoad() throws IOException {
  // Number of leading components of each label weight vector copied into the embedding.
  final int embeddingDim = 3;
  final String[] featureNames = {"cat", "dog", "fish", "horse"};

  ModelHeader header = new ModelHeader();
  header.setModelType("low_rank_linear");
  header.setLabelDictionary(makeLabelDictionary());
  Map<String, FloatVector> labelWeightVector = makeLabelWeightVector();
  Map<String, java.util.List<Double>> labelEmbedding = new HashMap<>();
  for (Map.Entry<String, FloatVector> labelRepresentation : labelWeightVector.entrySet()) {
    float[] values = labelRepresentation.getValue().getValues();
    ArrayList<Double> embedding = new ArrayList<>(embeddingDim);
    for (int i = 0; i < embeddingDim; i++) {
      embedding.add((double) values[i]);
    }
    labelEmbedding.put(labelRepresentation.getKey(), embedding);
  }
  header.setLabelEmbedding(labelEmbedding);
  // One record per feature name follows the header record.
  header.setNumRecords(featureNames.length);

  // Shared feature weight vector [1, 0, 0].
  ArrayList<Double> ws = new ArrayList<>();
  ws.add(1.0);
  ws.add(0.0);
  ws.add(0.0);

  ArrayList<ModelRecord> records = new ArrayList<>();
  ModelRecord headerRecord = new ModelRecord();
  headerRecord.setModelHeader(header);
  records.add(headerRecord);
  for (String featureName : featureNames) {
    ModelRecord featureRecord = new ModelRecord();
    featureRecord.setFeatureFamily("a");
    featureRecord.setFeatureName(featureName);
    featureRecord.setWeightVector(ws);
    records.add(featureRecord);
  }

  CharArrayWriter charWriter = new CharArrayWriter();
  // try-with-resources flushes and closes the buffered writer even if a write throws.
  try (BufferedWriter writer = new BufferedWriter(charWriter)) {
    for (ModelRecord rec : records) {
      writer.write(Util.encode(rec) + "\n");
    }
  }
  String serialized = charWriter.toString();
  assertTrue(serialized.length() > 0);

  BufferedReader reader = new BufferedReader(new StringReader(serialized));
  FeatureVector animalFv = makeFeatureVector("animal");
  FeatureVector colorFv = makeFeatureVector("color");
  Optional<AbstractModel> model = ModelFactory.createFromReader(reader);
  assertTrue(model.isPresent());

  ArrayList<MulticlassScoringResult> s1 = model.get().scoreItemMulticlass(animalFv);
  assertEquals(3, s1.size());
  // NOTE(review): the third argument is a tolerance, so this accepts any score in [-3, 3].
  // It may have been meant as an expected value; confirm against makeFeatureVector("animal")
  // before tightening.
  assertEquals(0.0f, s1.get(0).score, 3.0f);
  assertEquals(0.0f, s1.get(1).score, 1e-10f);
  assertEquals(0.0f, s1.get(2).score, 1e-10f);

  ArrayList<MulticlassScoringResult> s2 = model.get().scoreItemMulticlass(colorFv);
  assertEquals(3, s2.size());
  assertEquals(0.0f, s2.get(0).score, 1e-10f);
  assertEquals(0.0f, s2.get(1).score, 1e-10f);
  assertEquals(0.0f, s2.get(2).score, 1e-10f);
}
Usage of com.airbnb.aerosolve.core.ModelRecord in the aerosolve project by Airbnb.
Class KernelModel, method save:
/**
 * Writes this kernel model as encoded records, each followed by {@code writer.newLine()}.
 *
 * <p>The first record is a header holding the model type {@code "kernel"}, the feature
 * dictionary, and the number of support vectors. One record per support vector follows. The
 * writer is flushed but left open.
 *
 * @param writer destination for the encoded records; the caller remains responsible for closing it
 * @throws IOException if writing to or flushing {@code writer} fails
 */
@Override
public void save(BufferedWriter writer) throws IOException {
  ModelHeader modelHeader = new ModelHeader();
  modelHeader.setModelType("kernel");
  modelHeader.setDictionary(dictionary.getDictionary());
  modelHeader.setNumRecords((long) supportVectors.size());

  ModelRecord headerRecord = new ModelRecord();
  headerRecord.setModelHeader(modelHeader);
  String encodedHeader = Util.encode(headerRecord);
  writer.write(encodedHeader);
  writer.newLine();

  for (SupportVector supportVector : supportVectors) {
    ModelRecord svRecord = supportVector.toModelRecord();
    writer.write(Util.encode(svRecord));
    writer.newLine();
  }
  writer.flush();
}
Usage of com.airbnb.aerosolve.core.ModelRecord in the aerosolve project by Airbnb.
Class LinearModel, method save:
/**
 * Saves this linear model as encoded records, each followed by {@code writer.newLine()}.
 *
 * <p>The first record is a header holding the model type {@code "linear"}, the slope, the
 * offset, and the total number of feature weights. One record per (family, feature) weight
 * follows. The writer is flushed but left open.
 *
 * @param writer destination for the encoded records; the caller remains responsible for closing it
 * @throws IOException if writing to or flushing {@code writer} fails
 */
public void save(BufferedWriter writer) throws IOException {
  ModelHeader header = new ModelHeader();
  header.setModelType("linear");
  header.setSlope(slope);
  header.setOffset(offset);
  // The record count is the total number of feature weights across all families. Summing the
  // inner map sizes gives it directly, without walking every entry just to count.
  long count = 0;
  for (Map<String, Float> family : weights.values()) {
    count += family.size();
  }
  header.setNumRecords(count);
  ModelRecord headerRec = new ModelRecord();
  headerRec.setModelHeader(header);
  writer.write(Util.encode(headerRec));
  writer.newLine();
  for (Map.Entry<String, Map<String, Float>> familyMap : weights.entrySet()) {
    for (Map.Entry<String, Float> feature : familyMap.getValue().entrySet()) {
      ModelRecord record = new ModelRecord();
      record.setFeatureFamily(familyMap.getKey());
      record.setFeatureName(feature.getKey());
      record.setFeatureWeight(feature.getValue());
      writer.write(Util.encode(record));
      writer.newLine();
    }
  }
  writer.flush();
}
Usage of com.airbnb.aerosolve.core.ModelRecord in the aerosolve project by Airbnb.
Class MaxoutModel, method loadInternal:
/**
 * Loads the maxout weights that follow the model header.
 *
 * <p>Reads {@code header.getNumRecords()} encoded lines from {@code reader}. Each record becomes
 * a {@code WeightVector} holding the record's scale and the first {@code numHidden} entries of
 * its weight vector. The vector is stored in {@code weightVector} under the record's feature
 * family and feature name. The bias is then taken from family {@code "$SPECIAL"}, name
 * {@code "$BIAS"}.
 *
 * <p>Missing or malformed data is reported as an {@code IOException}. The original code relied
 * on {@code assert}, which is skipped unless the JVM runs with {@code -ea}.
 *
 * @param header decoded model header; supplies the record count and {@code numHidden}
 * @param reader source of the encoded records that follow the header
 * @throws IOException if reading fails, the input ends before all records are read, a record
 *     has fewer than {@code numHidden} weights, or the {@code $SPECIAL/$BIAS} entry is missing
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
  long rows = header.getNumRecords();
  numHidden = header.getNumHidden();
  weightVector = new HashMap<>();
  for (long i = 0; i < rows; i++) {
    String line = reader.readLine();
    if (line == null) {
      throw new IOException(
          "Unexpected end of input: expected " + rows + " records but read only " + i);
    }
    ModelRecord record = Util.decodeModel(line);
    String family = record.getFeatureFamily();
    String name = record.getFeatureName();
    Map<String, WeightVector> inner = weightVector.get(family);
    if (inner == null) {
      inner = new HashMap<>();
      weightVector.put(family, inner);
    }
    // Fetch once; every hidden unit reads from the same list.
    java.util.List<Double> hiddenWeights = record.getWeightVector();
    if (hiddenWeights == null || hiddenWeights.size() < numHidden) {
      throw new IOException("Record " + family + ":" + name + " has "
          + (hiddenWeights == null ? 0 : hiddenWeights.size())
          + " weights; expected at least " + numHidden);
    }
    WeightVector vec = new WeightVector();
    vec.scale = (float) record.getScale();
    vec.weights = new FloatVector(numHidden);
    for (int j = 0; j < numHidden; j++) {
      vec.weights.values[j] = hiddenWeights.get(j).floatValue();
    }
    inner.put(name, vec);
  }
  Map<String, WeightVector> special = weightVector.get("$SPECIAL");
  WeightVector loadedBias = (special == null) ? null : special.get("$BIAS");
  if (loadedBias == null) {
    throw new IOException("Model is missing the $SPECIAL/$BIAS weight vector");
  }
  bias = loadedBias;
}
Aggregations