Search in sources:

Example 31 with ModelRecord

Use of com.airbnb.aerosolve.core.ModelRecord in the aerosolve project by Airbnb.

Class ModelFactory, method createFromReader.

/**
 * Reads a serialized model from {@code reader}.
 *
 * <p>The first line is expected to be an encoded {@link ModelRecord} carrying a {@link ModelHeader}.
 * The header's model type is passed to {@code createByName}; the resulting model then consumes the
 * remaining lines through its {@code loadInternal} implementation.
 *
 * @param reader source positioned at the header line
 * @return the loaded model, or {@link Optional#absent()} if the input is empty, the header line
 *     cannot be decoded, the decoded record carries no header, or no model exists for the header's
 *     type (each of these cases is logged)
 * @throws IOException propagated from {@code reader.readLine()} or from the selected model's
 *     {@code loadInternal}
 */
public static Optional<AbstractModel> createFromReader(BufferedReader reader) throws IOException {
    String headerLine = reader.readLine();
    if (headerLine == null) {
        // End of stream before any header: nothing to decode.
        log.error("Could not decode header: input is empty");
        return Optional.absent();
    }
    ModelRecord record = Util.decodeModel(headerLine);
    if (record == null) {
        log.error("Could not decode header " + headerLine);
        return Optional.absent();
    }
    ModelHeader header = record.getModelHeader();
    if (header == null) {
        // A decodable record without a header means the first line is not a model header.
        log.error("Model record has no header: " + headerLine);
        return Optional.absent();
    }
    AbstractModel result = createByName(header.getModelType());
    if (result == null) {
        log.error("No model available for type " + header.getModelType());
        return Optional.absent();
    }
    result.loadInternal(header, reader);
    return Optional.of(result);
}
Also used: ModelHeader (com.airbnb.aerosolve.core.ModelHeader), ModelRecord (com.airbnb.aerosolve.core.ModelRecord)

Example 32 with ModelRecord

Use of com.airbnb.aerosolve.core.ModelRecord in the aerosolve project by Airbnb.

Class SplineModel, method save.

/**
 * Writes this spline model to {@code writer}, one encoded {@link ModelRecord} per line.
 *
 * <p>The first line is a header record: model type {@code "spline"}, {@code numBins} (stored as
 * numHidden), {@code slope}, {@code offset}, and the number of feature records that follow. Each
 * following line holds one feature's family, name, spline weights (widened to {@code Double}), and
 * the spline's min/max values. The writer is flushed but not closed.
 *
 * @param writer destination for the serialized model
 * @throws IOException if writing to {@code writer} fails
 */
@Override
public void save(BufferedWriter writer) throws IOException {
    ModelHeader header = new ModelHeader();
    header.setModelType("spline");
    header.setNumHidden(numBins);
    header.setSlope(slope);
    header.setOffset(offset);
    // Every inner-map entry is written as exactly one record below, so the record count is just
    // the sum of the inner map sizes; there is no need to walk each entry.
    long count = 0;
    for (Map<String, WeightSpline> features : weightSpline.values()) {
        count += features.size();
    }
    header.setNumRecords(count);
    ModelRecord headerRec = new ModelRecord();
    headerRec.setModelHeader(header);
    writer.write(Util.encode(headerRec));
    writer.newLine();
    for (Map.Entry<String, Map<String, WeightSpline>> familyMap : weightSpline.entrySet()) {
        String family = familyMap.getKey();
        for (Map.Entry<String, WeightSpline> feature : familyMap.getValue().entrySet()) {
            WeightSpline ws = feature.getValue();
            ArrayList<Double> weights = new ArrayList<Double>(ws.splineWeights.length);
            // Widening each element to double matches the explicit (double) cast this replaced.
            for (double w : ws.splineWeights) {
                weights.add(w);
            }
            ModelRecord record = new ModelRecord();
            record.setFeatureFamily(family);
            record.setFeatureName(feature.getKey());
            record.setWeightVector(weights);
            record.setMinVal(ws.spline.getMinVal());
            record.setMaxVal(ws.spline.getMaxVal());
            writer.write(Util.encode(record));
            writer.newLine();
        }
    }
    writer.flush();
}
Also used: ModelHeader (com.airbnb.aerosolve.core.ModelHeader), ModelRecord (com.airbnb.aerosolve.core.ModelRecord)

Example 33 with ModelRecord

Use of com.airbnb.aerosolve.core.ModelRecord in the aerosolve project by Airbnb.

Class BoostedStumpsModel, method loadInternal.

/**
 * Loads the stumps that follow the header line.
 *
 * <p>Reads exactly {@code header.getNumRecords()} lines, decodes each into a {@link ModelRecord},
 * and appends them in order to a freshly allocated {@code stumps} list (any previous contents are
 * discarded).
 *
 * @param header decoded model header; its record count says how many lines to read
 * @param reader source positioned just after the header line
 * @throws IOException if reading fails, the input ends before all records were read, or a line
 *     cannot be decoded (previously these cases silently stored {@code null} stumps)
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
    long rows = header.getNumRecords();
    stumps = new ArrayList<>();
    for (long i = 0; i < rows; i++) {
        String line = reader.readLine();
        if (line == null) {
            throw new IOException(
                "Unexpected end of input: read " + i + " of " + rows + " stump records");
        }
        ModelRecord record = Util.decodeModel(line);
        if (record == null) {
            throw new IOException("Could not decode stump record " + i + ": " + line);
        }
        stumps.add(record);
    }
}
Also used: ModelRecord (com.airbnb.aerosolve.core.ModelRecord)

Example 34 with ModelRecord

Use of com.airbnb.aerosolve.core.ModelRecord in the aerosolve project by Airbnb.

Class FullRankLinearModel, method loadInternal.

/**
 * Loads the label dictionary and per-feature weight vectors of a full-rank linear model.
 *
 * <p>The label dictionary is copied from {@code header} and {@code buildLabelToIndex()} is called
 * on it. Then exactly {@code header.getNumRecords()} lines are read; each is decoded into a
 * {@link ModelRecord} whose weight list is converted to a {@link FloatVector} and stored under
 * {@code weightVector[family][name]}. {@code weightVector} is replaced with a new map.
 *
 * @param header decoded model header supplying the label dictionary and record count
 * @param reader source positioned just after the header line
 * @throws IOException if reading fails, the input ends before all records were read, or a line
 *     cannot be decoded (previously these cases surfaced as a NullPointerException)
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
    long rows = header.getNumRecords();
    labelDictionary = new ArrayList<>(header.getLabelDictionary());
    buildLabelToIndex();
    weightVector = new HashMap<>();
    for (long i = 0; i < rows; i++) {
        String line = reader.readLine();
        if (line == null) {
            throw new IOException(
                "Unexpected end of input: read " + i + " of " + rows + " weight records");
        }
        ModelRecord record = Util.decodeModel(line);
        if (record == null) {
            throw new IOException("Could not decode weight record " + i + ": " + line);
        }
        String family = record.getFeatureFamily();
        String name = record.getFeatureName();
        Map<String, FloatVector> inner = weightVector.get(family);
        if (inner == null) {
            inner = new HashMap<>();
            weightVector.put(family, inner);
        }
        FloatVector vec = new FloatVector(record.getWeightVector().size());
        int j = 0;
        for (Double weight : record.getWeightVector()) {
            vec.values[j++] = weight.floatValue();
        }
        inner.put(name, vec);
    }
}
Also used: LabelDictionaryEntry (com.airbnb.aerosolve.core.LabelDictionaryEntry), ModelRecord (com.airbnb.aerosolve.core.ModelRecord), FloatVector (com.airbnb.aerosolve.core.util.FloatVector)

Example 35 with ModelRecord

Use of com.airbnb.aerosolve.core.ModelRecord in the aerosolve project by Airbnb.

Class DecisionTreeTransform, method doTransform.

/**
 * Runs {@code featureVector}'s float features through the decision tree and records the leaf reached.
 *
 * <p>If the vector has no float features, nothing happens. Otherwise the reached leaf's
 * {@code featureName} is added to the {@code outputLeaves} string family, and its
 * {@code featureWeight} is stored under {@code outputScoreName} in the {@code outputScoreFamily}
 * float family. Both families are obtained through the {@code Util.getOrCreate*} helpers.
 */
@Override
public void doTransform(FeatureVector featureVector) {
    final Map<String, Map<String, Double>> floats = featureVector.getFloatFeatures();
    if (floats != null) {
        // Presumably makes sure the string-feature map exists before it is read below.
        Util.optionallyCreateStringFeatures(featureVector);
        final Set<String> leafNames =
            Util.getOrCreateStringFeature(outputLeaves, featureVector.getStringFeatures());
        final Map<String, Double> scores =
            Util.getOrCreateFloatFeature(outputScoreFamily, floats);
        // NOTE(review): assumes getLeafIndex always yields a valid index into getStumps();
        // confirm what it returns for an empty tree.
        final int leafIndex = tree.getLeafIndex(floats);
        final ModelRecord reached = tree.getStumps().get(leafIndex);
        leafNames.add(reached.featureName);
        scores.put(outputScoreName, reached.featureWeight);
    }
}
Also used: ModelRecord (com.airbnb.aerosolve.core.ModelRecord)

Aggregations

ModelRecord (com.airbnb.aerosolve.core.ModelRecord): 40; ModelHeader (com.airbnb.aerosolve.core.ModelHeader): 14; Test (org.junit.Test): 11; ArrayList (java.util.ArrayList): 10; FloatVector (com.airbnb.aerosolve.core.util.FloatVector): 8; Map (java.util.Map): 5; FeatureVector (com.airbnb.aerosolve.core.FeatureVector): 4; BufferedReader (java.io.BufferedReader): 3; BufferedWriter (java.io.BufferedWriter): 3; CharArrayWriter (java.io.CharArrayWriter): 3; IOException (java.io.IOException): 3; StringReader (java.io.StringReader): 3; LabelDictionaryEntry (com.airbnb.aerosolve.core.LabelDictionaryEntry): 2; AbstractFunction (com.airbnb.aerosolve.core.function.AbstractFunction): 2; Function (com.airbnb.aerosolve.core.function.Function): 2; NDTreeModelTest (com.airbnb.aerosolve.core.models.NDTreeModelTest): 2; SupportVector (com.airbnb.aerosolve.core.util.SupportVector): 2; HashMap (java.util.HashMap): 2; DebugScoreRecord (com.airbnb.aerosolve.core.DebugScoreRecord): 1; MulticlassScoringResult (com.airbnb.aerosolve.core.MulticlassScoringResult): 1