Usage of `com.airbnb.aerosolve.core.ModelHeader` in the aerosolve project by Airbnb.
Class `ModelFactory`, method `createFromReader`:
/**
 * Reads a serialized model from {@code reader} and instantiates it.
 *
 * <p>The first line must decode (via {@code Util.decodeModel}) to a {@link ModelRecord} that
 * carries a {@link ModelHeader}. The header's model type is passed to {@code createByName};
 * the resulting model then reads the rest of the input through {@code loadInternal}.
 *
 * @param reader source of the serialized model, positioned at the header line
 * @return the loaded model, or {@code Optional.absent()} when the input is empty, the header
 *     line cannot be decoded, the decoded record has no header, or {@code createByName}
 *     returns null for the header's model type; each of these cases is logged
 * @throws IOException if reading from {@code reader} fails
 */
public static Optional<AbstractModel> createFromReader(BufferedReader reader) throws IOException {
  String headerLine = reader.readLine();
  if (headerLine == null) {
    // Empty input: report it here instead of handing null to decodeModel.
    log.error("Could not decode header: reader is empty");
    return Optional.absent();
  }
  ModelRecord record = Util.decodeModel(headerLine);
  if (record == null) {
    log.error("Could not decode header " + headerLine);
    return Optional.absent();
  }
  ModelHeader header = record.getModelHeader();
  if (header == null) {
    log.error("Model record has no header " + headerLine);
    return Optional.absent();
  }
  AbstractModel result = createByName(header.getModelType());
  if (result == null) {
    log.error("Unknown model type " + header.getModelType());
    return Optional.absent();
  }
  result.loadInternal(header, reader);
  return Optional.of(result);
}
Usage of `com.airbnb.aerosolve.core.ModelHeader` in the aerosolve project by Airbnb.
Class `SplineModel`, method `save`:
/**
 * Serializes this spline model to {@code writer}, one encoded {@link ModelRecord} per line.
 *
 * <p>The first line is a header record holding the model type {@code "spline"},
 * {@code numBins} (stored as numHidden), {@code slope}, {@code offset}, and the number of
 * feature records that follow. Each following line holds one (family, feature) pair's
 * spline weights plus the spline's min and max values.
 *
 * @param writer destination for the encoded records; flushed but not closed
 * @throws IOException if writing fails
 */
@Override
public void save(BufferedWriter writer) throws IOException {
  ModelHeader header = new ModelHeader();
  header.setModelType("spline");
  header.setNumHidden(numBins);
  header.setSlope(slope);
  header.setOffset(offset);
  // One record is written per (family, feature) pair, so the record count is the sum of
  // the inner map sizes; there is no need to visit every entry just to count it.
  long count = 0;
  for (Map<String, WeightSpline> family : weightSpline.values()) {
    count += family.size();
  }
  header.setNumRecords(count);
  ModelRecord headerRec = new ModelRecord();
  headerRec.setModelHeader(header);
  writer.write(Util.encode(headerRec));
  writer.newLine();
  for (Map.Entry<String, Map<String, WeightSpline>> familyMap : weightSpline.entrySet()) {
    for (Map.Entry<String, WeightSpline> feature : familyMap.getValue().entrySet()) {
      WeightSpline ws = feature.getValue();
      ModelRecord record = new ModelRecord();
      record.setFeatureFamily(familyMap.getKey());
      record.setFeatureName(feature.getKey());
      ArrayList<Double> weights = new ArrayList<Double>(ws.splineWeights.length);
      for (double w : ws.splineWeights) {
        weights.add(w);
      }
      record.setWeightVector(weights);
      record.setMinVal(ws.spline.getMinVal());
      record.setMaxVal(ws.spline.getMaxVal());
      writer.write(Util.encode(record));
      writer.newLine();
    }
  }
  writer.flush();
}
Usage of `com.airbnb.aerosolve.core.ModelHeader` in the aerosolve project by Airbnb.
Class `MlpModelTest`, method `testLoad`:
/**
 * Serializes a hand-built multilayer-perceptron model, loads it back through
 * {@link ModelFactory#createFromReader}, and checks the score it gives the vector returned by
 * {@code makeFeatureVector()}.
 *
 * <p>The serialized text holds one encoded {@link ModelRecord} per line:
 * <ul>
 *   <li>a header: type "multilayer_perceptron", one hidden layer, node counts [3, 1], and
 *       numRecords 2 (presumably the two input-layer records; confirm against the MLP
 *       loader);</li>
 *   <li>input-layer weights for features "in"/"a" and "in"/"b";</li>
 *   <li>two RELU bias records of sizes 3 and 1;</li>
 *   <li>three single-weight hidden-layer records.</li>
 * </ul>
 */
@Test
public void testLoad() {
  CharArrayWriter charWriter = new CharArrayWriter();
  // create header record
  ModelHeader header = new ModelHeader();
  header.setModelType("multilayer_perceptron");
  header.setNumHiddenLayers(1);
  ArrayList<Integer> nodeNum = new ArrayList<>(2);
  nodeNum.add(3);
  nodeNum.add(1);
  header.setNumberHiddenNodes(nodeNum);
  header.setNumRecords(2);
  ModelRecord record1 = new ModelRecord();
  record1.setModelHeader(header);
  // create records for input layer
  ModelRecord record2 = new ModelRecord();
  record2.setFeatureFamily("in");
  record2.setFeatureName("a");
  ArrayList<Double> in1 = new ArrayList<>();
  in1.add(0.0);
  in1.add(1.0);
  in1.add(1.0);
  record2.setWeightVector(in1);
  ModelRecord record3 = new ModelRecord();
  record3.setFeatureFamily("in");
  record3.setFeatureName("b");
  ArrayList<Double> in2 = new ArrayList<>();
  in2.add(1.0);
  in2.add(1.0);
  in2.add(0.0);
  record3.setWeightVector(in2);
  // create records for bias
  ModelRecord record4 = new ModelRecord();
  ArrayList<Double> b1 = new ArrayList<>();
  b1.add(0.0);
  b1.add(0.0);
  b1.add(0.0);
  record4.setWeightVector(b1);
  record4.setFunctionForm(FunctionForm.RELU);
  ModelRecord record5 = new ModelRecord();
  ArrayList<Double> b2 = new ArrayList<>();
  b2.add(0.0);
  record5.setWeightVector(b2);
  record5.setFunctionForm(FunctionForm.RELU);
  // create records for hidden layer
  ModelRecord record6 = new ModelRecord();
  ArrayList<Double> h1 = new ArrayList<>();
  h1.add(0.5);
  record6.setWeightVector(h1);
  ModelRecord record7 = new ModelRecord();
  ArrayList<Double> h2 = new ArrayList<>();
  h2.add(1.0);
  record7.setWeightVector(h2);
  ModelRecord record8 = new ModelRecord();
  ArrayList<Double> h3 = new ArrayList<>();
  h3.add(2.0);
  record8.setWeightVector(h3);
  // try-with-resources closes (and so flushes) the writer even when a write fails; the
  // IOException is kept as the cause so the failure report shows what went wrong.
  try (BufferedWriter writer = new BufferedWriter(charWriter)) {
    writer.write(Util.encode(record1) + "\n");
    writer.write(Util.encode(record2) + "\n");
    writer.write(Util.encode(record3) + "\n");
    writer.write(Util.encode(record4) + "\n");
    writer.write(Util.encode(record5) + "\n");
    writer.write(Util.encode(record6) + "\n");
    writer.write(Util.encode(record7) + "\n");
    writer.write(Util.encode(record8) + "\n");
  } catch (IOException e) {
    throw new AssertionError("Could not write", e);
  }
  String serialized = charWriter.toString();
  assertTrue(serialized.length() > 0);
  FeatureVector fv = makeFeatureVector();
  try (BufferedReader reader = new BufferedReader(new StringReader(serialized))) {
    Optional<AbstractModel> model = ModelFactory.createFromReader(reader);
    assertTrue(model.isPresent());
    float s = model.get().scoreItem(fv);
    // JUnit's assertEquals takes (expected, actual, delta).
    assertEquals(6.0f, s, 1e-10f);
  } catch (IOException e) {
    throw new AssertionError("Could not read", e);
  }
}
Usage of `com.airbnb.aerosolve.core.ModelHeader` in the aerosolve project by Airbnb.
Class `AdditiveModelTest`, method `testLoad`:
/**
 * Serializes a hand-built additive model, loads it back through
 * {@link ModelFactory#createFromReader}, and checks its score for
 * {@code makeFeatureVector(2.0f, 7.0f)}.
 *
 * <p>The serialized text holds one encoded {@link ModelRecord} per line:
 * <ul>
 *   <li>a header: type "additive", numRecords 4;</li>
 *   <li>two Spline records, "spline_float"/"aaa" on [1, 3] and "spline_string"/"bbb" on
 *       [1, 2], both with weights [5, 10, -20];</li>
 *   <li>two Linear records, "linear_float"/"ccc" and "linear_string"/"ddd", both with
 *       weights [1, 2].</li>
 * </ul>
 */
@Test
public void testLoad() {
  CharArrayWriter charWriter = new CharArrayWriter();
  ModelHeader header = new ModelHeader();
  header.setModelType("additive");
  header.setNumRecords(4);
  ArrayList<Double> ws = new ArrayList<>();
  ws.add(5.0);
  ws.add(10.0);
  ws.add(-20.0);
  ArrayList<Double> wl = new ArrayList<>();
  wl.add(1.0);
  wl.add(2.0);
  ModelRecord record1 = new ModelRecord();
  record1.setModelHeader(header);
  ModelRecord record2 = new ModelRecord();
  record2.setFunctionForm(FunctionForm.Spline);
  record2.setFeatureFamily("spline_float");
  record2.setFeatureName("aaa");
  record2.setWeightVector(ws);
  record2.setMinVal(1.0);
  record2.setMaxVal(3.0);
  ModelRecord record3 = new ModelRecord();
  record3.setFunctionForm(FunctionForm.Spline);
  record3.setFeatureFamily("spline_string");
  record3.setFeatureName("bbb");
  record3.setWeightVector(ws);
  record3.setMinVal(1.0);
  record3.setMaxVal(2.0);
  ModelRecord record4 = new ModelRecord();
  record4.setFunctionForm(FunctionForm.Linear);
  record4.setFeatureFamily("linear_float");
  record4.setFeatureName("ccc");
  record4.setWeightVector(wl);
  ModelRecord record5 = new ModelRecord();
  record5.setFunctionForm(FunctionForm.Linear);
  record5.setFeatureFamily("linear_string");
  record5.setFeatureName("ddd");
  record5.setWeightVector(wl);
  // try-with-resources closes (and so flushes) the writer even when a write fails; the
  // IOException is kept as the cause so the failure report shows what went wrong.
  try (BufferedWriter writer = new BufferedWriter(charWriter)) {
    writer.write(Util.encode(record1) + "\n");
    writer.write(Util.encode(record2) + "\n");
    writer.write(Util.encode(record3) + "\n");
    writer.write(Util.encode(record4) + "\n");
    writer.write(Util.encode(record5) + "\n");
  } catch (IOException e) {
    throw new AssertionError("Could not write", e);
  }
  String serialized = charWriter.toString();
  assertTrue(serialized.length() > 0);
  FeatureVector featureVector = makeFeatureVector(2.0f, 7.0f);
  try (BufferedReader reader = new BufferedReader(new StringReader(serialized))) {
    Optional<AbstractModel> model = ModelFactory.createFromReader(reader);
    assertTrue(model.isPresent());
    float score = model.get().scoreItem(featureVector);
    assertEquals(8.0f + 10.0f + 15.0f, score, 0.001f);
  } catch (IOException e) {
    throw new AssertionError("Could not read", e);
  }
}
Aggregations