Use of com.airbnb.aerosolve.core.util.SupportVector in project aerosolve by Airbnb.
Class KernelModel, method onlineUpdate.
/**
 * Applies one online gradient step to the weight of every support vector.
 *
 * <p>For each support vector, the weight {@code w} becomes
 * {@code w + (-learningRate * grad) * k}, where {@code k} is that support vector's
 * unweighted response ({@code evaluateUnweighted}) to the example's feature vector.
 *
 * @param grad gradient signal for this example (presumably d(loss)/d(score) — confirm with callers)
 * @param learningRate step size multiplying the gradient
 * @param flatFeatures flattened features of the example, converted through {@code dictionary}
 */
@Override
public void onlineUpdate(float grad, float learningRate, Map<String, Map<String, Double>> flatFeatures) {
  // Same scalar step for every support vector; only the kernel response differs.
  final float step = -learningRate * grad;
  final FloatVector features = dictionary.makeVectorFromSparseFloats(flatFeatures);
  for (SupportVector support : supportVectors) {
    final float kernelValue = support.evaluateUnweighted(features);
    final float increment = step * kernelValue;
    support.setWeight(support.getWeight() + increment);
  }
}
Use of com.airbnb.aerosolve.core.util.SupportVector in project aerosolve by Airbnb.
Class KernelModel, method save.
/**
 * Writes this model as newline-delimited encoded records.
 *
 * <p>The first line is a header record carrying the model type {@code "kernel"}, the feature
 * dictionary and the number of support vectors. It is followed by one encoded record per support
 * vector, in list order. The writer is flushed but not closed.
 *
 * @param writer destination for the encoded records
 * @throws IOException if writing to {@code writer} fails
 */
@Override
public void save(BufferedWriter writer) throws IOException {
  final ModelHeader modelHeader = new ModelHeader();
  modelHeader.setModelType("kernel");
  modelHeader.setDictionary(dictionary.getDictionary());
  final long numRecords = supportVectors.size();
  modelHeader.setNumRecords(numRecords);

  final ModelRecord wrapper = new ModelRecord();
  wrapper.setModelHeader(modelHeader);
  writer.write(Util.encode(wrapper));
  writer.newLine();

  for (SupportVector support : supportVectors) {
    final ModelRecord record = support.toModelRecord();
    writer.write(Util.encode(record));
    writer.newLine();
  }
  writer.flush();
}
Use of com.airbnb.aerosolve.core.util.SupportVector in project aerosolve by Airbnb.
Class KernelModel, method loadInternal.
/**
 * Restores the feature dictionary and the support vectors from a serialized kernel model.
 *
 * <p>After the header, {@code reader} must supply {@code header.getNumRecords()} lines. Each line
 * is an encoded record decoded with {@code Util.decodeModel} into one {@code SupportVector}. The
 * model's {@code dictionary} and {@code supportVectors} are replaced only after every record has
 * been read successfully. A failed load therefore leaves the previous state untouched.
 *
 * @param header the already-decoded model header
 * @param reader source positioned just after the header line
 * @throws IOException if reading fails, or if the stream ends before all declared records are read
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
  long rows = header.getNumRecords();
  StringDictionary loadedDictionary = new StringDictionary(header.getDictionary());
  ArrayList<SupportVector> loadedVectors = new ArrayList<>();
  for (long i = 0; i < rows; i++) {
    String line = reader.readLine();
    if (line == null) {
      // readLine() returns null at end of stream. Fail with a clear message
      // instead of handing null to the decoder.
      throw new IOException(
          "Kernel model is truncated: header declares " + rows
              + " support vector records but only " + i + " were found");
    }
    ModelRecord record = Util.decodeModel(line);
    loadedVectors.add(new SupportVector(record));
  }
  // Commit only after the whole model has been read.
  dictionary = loadedDictionary;
  supportVectors = loadedVectors;
}
Use of com.airbnb.aerosolve.core.util.SupportVector in project aerosolve by Airbnb.
Class KernelModel, method scoreItem.
/**
 * Scores an item as the sum of {@code SupportVector.evaluate} over all support vectors.
 *
 * <p>The item's features are flattened with {@code Util.flattenFeature} and converted to a
 * {@code FloatVector} through {@code dictionary} before evaluation. {@code evaluate} presumably
 * includes the support vector's weight, in contrast to the {@code evaluateUnweighted} call used
 * during online updates; confirm in {@code SupportVector}.
 *
 * @param combinedItem the feature vector to score
 * @return the summed response; {@code 0.0f} when there are no support vectors
 */
@Override
public float scoreItem(FeatureVector combinedItem) {
  Map<String, Map<String, Double>> flatFeatures = Util.flattenFeature(combinedItem);
  FloatVector vec = dictionary.makeVectorFromSparseFloats(flatFeatures);
  float sum = 0.0f;
  // For-each, like the other methods: avoids per-element get(i), which is O(n) on
  // non-RandomAccess lists.
  for (SupportVector sv : supportVectors) {
    sum += sv.evaluate(vec);
  }
  return sum;
}
Aggregations