Search in sources:

Example 1 with Optimize

Use of com.microsoft.z3.Optimize in project batfish by batfish.

Class PropertyAdder, method instrumentPathLength.

// Potentially useful in the future to optimize reachability when we know
// that there can't be routing loops e.g., due to a preliminary static analysis
/* public Map<String, BoolExpr> instrumentReachabilityFast(String router) {
    Context ctx = _encoderSlice.getCtx();
    Solver solver = _encoderSlice.getSolver();
    Map<String, BoolExpr> reachableVars = new HashMap<>();
    String sliceName = _encoderSlice.getSliceName();
    _encoderSlice
        .getGraph()
        .getConfigurations()
        .forEach(
            (r, conf) -> {
              int id = _encoderSlice.getEncoder().getId();
              String s2 = id + "_" + sliceName + "_reachable_" + r;
              BoolExpr var = ctx.mkBoolConst(s2);
              reachableVars.put(r, var);
              _encoderSlice.getAllVariables().put(var.toString(), var);
            });

    BoolExpr baseReach = reachableVars.get(router);
    _encoderSlice.add(baseReach);
    _encoderSlice
        .getGraph()
        .getEdgeMap()
        .forEach(
            (r, edges) -> {
              if (!r.equals(router)) {
                BoolExpr reach = reachableVars.get(r);
                BoolExpr hasRecursiveRoute = ctx.mkFalse();
                for (GraphEdge edge : edges) {
                  if (!edge.isAbstract()) {
                    BoolExpr fwd = _encoderSlice.getForwardsAcross().get(r, edge);
                    if (edge.getPeer() != null) {
                      BoolExpr peerReachable = reachableVars.get(edge.getPeer());
                      BoolExpr sendToReachable = ctx.mkAnd(fwd, peerReachable);
                      hasRecursiveRoute = ctx.mkOr(hasRecursiveRoute, sendToReachable);
                    }
                  }
                }
                solver.add(ctx.mkEq(reach, hasRecursiveRoute));
              }
            });

    return reachableVars;
  }

  public Map<String, BoolExpr> instrumentReachabilityFast(Set<GraphEdge> ges) {
    Context ctx = _encoderSlice.getCtx();
    Solver solver = _encoderSlice.getSolver();
    EncoderSlice slice = _encoderSlice;
    String sliceName = _encoderSlice.getSliceName();
    Graph g = slice.getGraph();
    Map<String, BoolExpr> reachableVars = new HashMap<>();

    _encoderSlice
        .getGraph()
        .getConfigurations()
        .forEach(
            (r, conf) -> {
              int id = _encoderSlice.getEncoder().getId();
              String s2 = id + "_" + sliceName + "_reachable_" + r;
              BoolExpr var = ctx.mkBoolConst(s2);
              reachableVars.put(r, var);
              _encoderSlice.getAllVariables().put(var.toString(), var);
            });

    for (Entry<String, List<GraphEdge>> entry : g.getEdgeMap().entrySet()) {
      String router = entry.getKey();
      List<GraphEdge> edges = entry.getValue();
      BoolExpr reach = reachableVars.get(router);

      // Add the base case, reachable if we forward to a directly connected interface
      BoolExpr hasDirectRoute = ctx.mkFalse();
      BoolExpr isAbsorbed = ctx.mkFalse();
      SymbolicRoute r = _encoderSlice.getBestNeighborPerProtocol(router, Protocol.CONNECTED);

      for (GraphEdge ge : edges) {
        if (!ge.isAbstract() && ges.contains(ge)) {
          // If a host, consider reachable
          if (g.isHost(router)) {
            hasDirectRoute = ctx.mkTrue();
            break;
          }
          // Reachable if we leave the network
          if (ge.getPeer() == null) {
            BoolExpr fwdIface = _encoderSlice.getForwardsAcross().get(ge.getRouter(), ge);
            assert (fwdIface != null);
            hasDirectRoute = ctx.mkOr(hasDirectRoute, fwdIface);
          }
          // Also reachable if connected route and we use it despite not forwarding
          if (r != null) {
            BitVecExpr dstIp = _encoderSlice.getSymbolicPacket().getDstIp();
            BitVecExpr ip = ctx.mkBV(ge.getStart().getIp().getIp().asLong(), 32);
            BoolExpr reachable = ctx.mkAnd(r.getPermitted(), ctx.mkEq(dstIp, ip));
            isAbsorbed = ctx.mkOr(isAbsorbed, reachable);
          }
        }
      }

      // Add the recursive case, where it is reachable through a neighbor
      BoolExpr hasRecursiveRoute = ctx.mkFalse();
      for (GraphEdge edge : edges) {
        if (!edge.isAbstract()) {
          BoolExpr fwd = _encoderSlice.getForwardsAcross().get(router, edge);
          if (edge.getPeer() != null) {
            BoolExpr peerReachable = reachableVars.get(edge.getPeer());
            BoolExpr sendToReachable = ctx.mkAnd(fwd, peerReachable);
            hasRecursiveRoute = ctx.mkOr(hasRecursiveRoute, sendToReachable);
          }
        }
      }

      BoolExpr cond = slice.mkOr(hasDirectRoute, isAbsorbed, hasRecursiveRoute);
      solver.add(slice.mkEq(reach, cond));
    }

    return reachableVars;
  } */
/**
 * Instruments the network with path-length information toward the destination ports identified
 * by the graph edges in {@code ges}.
 *
 * <p>One integer variable is created per router, registered with the slice, and constrained to
 * be at least -1. Then, for every router in the graph's edge map:
 *
 * <ul>
 *   <li>the length is 0 if either of these holds:
 *       <ul>
 *         <li>the router's forwards-across variable holds for a non-abstract edge in {@code ges}
 *             that has no peer;
 *         <li>the router's best CONNECTED route is permitted and the packet's destination IP
 *             equals the start address of a non-abstract edge in {@code ges};
 *       </ul>
 *   <li>otherwise the length is -1 if, for every non-abstract edge with a peer, the peer's length
 *       is negative or the router does not forward across that edge;
 *   <li>otherwise the length is n for some such forwarded-to peer whose length is n-1 with n-1
 *       &gt;= 0.
 * </ul>
 *
 * @param ges the graph edges identifying the destination ports
 * @return a map from router name to that router's path-length variable
 */
Map<String, ArithExpr> instrumentPathLength(Set<GraphEdge> ges) {
    Context ctx = _encoderSlice.getCtx();
    Solver solver = _encoderSlice.getSolver();
    String sliceName = _encoderSlice.getSliceName();
    Graph graph = _encoderSlice.getGraph();

    // Declare one path-length variable per router and register it with the slice.
    Map<String, ArithExpr> lenVars = new HashMap<>();
    for (String routerName : graph.getRouters()) {
        String varName =
            _encoderSlice.getEncoder().getId() + "_" + sliceName + "_path-length_" + routerName;
        ArithExpr lenVar = ctx.mkIntConst(varName);
        lenVars.put(routerName, lenVar);
        _encoderSlice.getAllVariables().put(lenVar.toString(), lenVar);
    }

    ArithExpr zero = ctx.mkInt(0);
    ArithExpr one = ctx.mkInt(1);
    ArithExpr minusOne = ctx.mkInt(-1);

    // -1 is the value assigned below when no neighbor offers a path, so it is the floor.
    for (ArithExpr lenVar : lenVars.values()) {
        solver.add(ctx.mkGe(lenVar, minusOne));
    }

    for (Entry<String, List<GraphEdge>> routerEdges : graph.getEdgeMap().entrySet()) {
        String routerName = routerEdges.getKey();
        List<GraphEdge> edges = routerEdges.getValue();
        ArithExpr pathLen = lenVars.get(routerName);

        // Base case (length 0): traffic leaves the network, or is absorbed by a connected route.
        BoolExpr leavesNetwork = ctx.mkFalse();
        BoolExpr absorbedHere = ctx.mkFalse();
        SymbolicRoute connectedRoute =
            _encoderSlice.getBestNeighborPerProtocol(routerName, Protocol.CONNECTED);
        for (GraphEdge target : edges) {
            if (target.isAbstract() || !ges.contains(target)) {
                continue;
            }
            if (target.getPeer() == null) {
                BoolExpr fwdOut = _encoderSlice.getForwardsAcross().get(target.getRouter(), target);
                assert (fwdOut != null);
                leavesNetwork = ctx.mkOr(leavesNetwork, fwdOut);
            }
            if (connectedRoute != null) {
                BitVecExpr dstIp = _encoderSlice.getSymbolicPacket().getDstIp();
                BitVecExpr ifaceIp = ctx.mkBV(target.getStart().getAddress().getIp().asLong(), 32);
                BoolExpr delivered =
                    ctx.mkAnd(connectedRoute.getPermitted(), ctx.mkEq(dstIp, ifaceIp));
                absorbedHere = ctx.mkOr(absorbedHere, delivered);
            }
        }

        // Recursive case: one more than a forwarded-to neighbor that has a non-negative length.
        BoolExpr noUsableNeighbor = ctx.mkTrue();
        BoolExpr viaNeighbor = ctx.mkFalse();
        for (GraphEdge nbrEdge : edges) {
            if (nbrEdge.isAbstract() || nbrEdge.getPeer() == null) {
                continue;
            }
            BoolExpr dataFwd = _encoderSlice.getForwardsAcross().get(routerName, nbrEdge);
            assert (dataFwd != null);
            ArithExpr peerLen = lenVars.get(nbrEdge.getPeer());
            BoolExpr unusable = ctx.mkOr(ctx.mkLt(peerLen, zero), ctx.mkNot(dataFwd));
            noUsableNeighbor = ctx.mkAnd(noUsableNeighbor, unusable);
            ArithExpr plusOne = ctx.mkAdd(peerLen, one);
            BoolExpr extendsPeer =
                ctx.mkAnd(ctx.mkGe(peerLen, zero), dataFwd, ctx.mkEq(pathLen, plusOne));
            viaNeighbor = ctx.mkOr(viaNeighbor, extendsPeer);
        }

        BoolExpr baseCase = _encoderSlice.mkOr(leavesNetwork, absorbedHere);
        BoolExpr recursiveCase =
            _encoderSlice.mkIf(noUsableNeighbor, ctx.mkEq(pathLen, minusOne), viaNeighbor);
        solver.add(_encoderSlice.mkIf(baseCase, ctx.mkEq(pathLen, zero), recursiveCase));
    }
    return lenVars;
}
Also used : Context(com.microsoft.z3.Context) ArithExpr(com.microsoft.z3.ArithExpr) BoolExpr(com.microsoft.z3.BoolExpr) Solver(com.microsoft.z3.Solver) HashMap(java.util.HashMap) BitVecExpr(com.microsoft.z3.BitVecExpr) Graph(org.batfish.symbolic.Graph) List(java.util.List) GraphEdge(org.batfish.symbolic.GraphEdge)

Example 2 with Optimize

Use of com.microsoft.z3.Optimize in project VERDICT by ge-high-assurance.

Class VerdictSynthesis, method performSynthesisMaxSmt.

/**
 * Performs synthesis using Z3 MaxSMT.
 *
 * <p>The defense tree is asserted as a hard constraint. Each component-defense pair whose
 * normalized cost at {@code targetDal} is positive also gets a soft constraint. That constraint
 * prefers the pair's variable to be false and is weighted by the same normalized cost. Pairs
 * whose variable is true in the resulting model are reported as selected.
 *
 * @param tree the defense tree, asserted as a hard constraint
 * @param targetDal the target DAL used to compute each pair's normalized cost
 * @param factory the DLeaf factory supplying all component-defense pairs
 * @return the selected pairs together with their summed normalized cost divided by the value
 *     returned by {@code normalizeCosts}; empty if the constraints are unsatisfiable
 * @throws RuntimeException if a pair's value in the model is undefined
 * @deprecated use the multi-requirement approach instead
 */
@Deprecated
public static Optional<Pair<Set<ComponentDefense>, Double>> performSynthesisMaxSmt(DTree tree, int targetDal, DLeaf.Factory factory) {
    // The Context owns native Z3 memory. try-with-resources releases it on every exit path:
    // a normal return, the unsatisfiable case, or the RuntimeException thrown below.
    // NOTE(review): assumes toZ3(...) does not cache Z3 objects in the pairs beyond this call
    // - confirm.
    try (Context context = new Context()) {
        Optimize optimizer = context.mkOptimize();
        Collection<ComponentDefense> pairs = factory.allComponentDefensePairs();
        int costLcd = normalizeCosts(pairs);
        for (ComponentDefense pair : pairs) {
            if (pair.dalToNormCost(targetDal) > 0) {
                // this id ("cover") doesn't matter but we have to specify something
                optimizer.AssertSoft(context.mkNot(pair.toZ3(context)), pair.dalToNormCost(targetDal), "cover");
            }
        }
        optimizer.Assert(tree.toZ3(context));
        if (optimizer.Check() != Status.SATISFIABLE) {
            System.err.println("Synthesis: SMT not satisfiable, perhaps there are unmitigatable attacks");
            return Optional.empty();
        }
        Set<ComponentDefense> output = new LinkedHashSet<>();
        int totalNormalizedCost = 0;
        Model model = optimizer.getModel();
        for (ComponentDefense pair : pairs) {
            Expr expr = model.eval(pair.toZ3(context), true);
            switch (expr.getBoolValue()) {
                case Z3_L_TRUE:
                    output.add(pair);
                    totalNormalizedCost += pair.dalToNormCost(targetDal);
                    break;
                case Z3_L_FALSE:
                    break;
                case Z3_L_UNDEF:
                default:
                    throw new RuntimeException("Synthesis: Undefined variable in output model: " + pair.toString());
            }
        }
        return Optional.of(new Pair<>(output, ((double) totalNormalizedCost) / costLcd));
    }
}
Also used : Context(com.microsoft.z3.Context) LinkedHashSet(java.util.LinkedHashSet) BoolExpr(com.microsoft.z3.BoolExpr) ArithExpr(com.microsoft.z3.ArithExpr) Expr(com.microsoft.z3.Expr) ComponentDefense(com.ge.verdict.synthesis.dtree.DLeaf.ComponentDefense) Model(com.microsoft.z3.Model) Optimize(com.microsoft.z3.Optimize)

Example 3 with Optimize

Use of com.microsoft.z3.Optimize in project VERDICT by ge-high-assurance.

Class VerdictSynthesis, method performSynthesisMultiple.

/**
 * Performs synthesis on multiple cyber requirements. This is the only version of synthesis that
 * should be used.
 *
 * <p>The Z3 context created here is closed before returning. The result contains only plain Java
 * values (components, DALs, {@link Fraction} costs).
 *
 * @param tree the defense tree
 * @param factory the dleaf factory used to construct the defense tree
 * @param costModel the cost model
 * @param partialSolution whether we are using partial solutions
 * @param inputSat whether the input tree is satisfied
 * @param meritAssignment whether to perform merit assignment if the input is satisfied
 * @param dumpSmtLib whether to output the intermediate SMT-LIB file for debugging
 * @return the result, if successful; empty if the constraints are unsatisfiable
 */
public static Optional<ResultsInstance> performSynthesisMultiple(DTree tree, DLeaf.Factory factory, CostModel costModel, boolean partialSolution, boolean inputSat, boolean meritAssignment, boolean dumpSmtLib) {
    // The Context owns native Z3 memory; try-with-resources releases it on every exit path.
    try (Context context = new Context()) {
        Optimize optimizer = context.mkOptimize();
        System.out.println("performSynthesisMultiple, configuration: partialSolution=" + partialSolution + ", inputSat=" + inputSat + ", meritAssignment=" + meritAssignment);
        Collection<ComponentDefense> compDefPairs = factory.allComponentDefensePairs();
        // Encode the logical structure of defense tree in MaxSMT
        // Encode cost(impl_defense_dal) >= the target cost (based on the severity of cyber req)
        // With merit assignment off and partial solution on, you will
        // subtract the cost of each DAL by the impl DAL cost, and use 0 if the result is negative.
        optimizer.Assert(tree.toZ3Multi(context));
        if (meritAssignment) {
            // Set upper bounds at the current values, so that no upgrades are reported:
            // component_defense_var <= the implemented DAL cost
            optimizer.Assert(context.mkAnd(compDefPairs.stream()
                .map(pair -> context.mkLe(pair.toZ3Multi(context), DLeaf.fractionToZ3(pair.dalToRawCost(pair.implDal), context)))
                .toArray(BoolExpr[]::new)));
        }
        // Make all component-defense var >= Cost(DAL(0)).
        // NOTE: the original comment here was truncated; its surviving fragment says that if
        // costs are not monotonic with respect to DAL, this lower bound can be changed to the
        // minimum of all costs.
        optimizer.Assert(context.mkAnd(compDefPairs.stream()
            .map(pair -> context.mkGe(pair.toZ3Multi(context), DLeaf.fractionToZ3(pair.dalToRawCost(0), context)))
            .toArray(BoolExpr[]::new)));
        // Encode objective function: the sum of all component-defense vars
        if (compDefPairs.isEmpty()) {
            optimizer.MkMinimize(context.mkInt(0));
        } else {
            optimizer.MkMinimize(context.mkAdd(compDefPairs.stream()
                .map(pair -> pair.toZ3Multi(context))
                .toArray(ArithExpr[]::new)));
        }
        if (dumpSmtLib) {
            // This dumps the file in the working directory, i.e. where the process was started
            // from. try-with-resources closes (and flushes) the writer even if toString() throws.
            try (PrintWriter writer = new PrintWriter("verdict-synthesis-dump.smtlib", "UTF-8")) {
                writer.println(optimizer.toString());
            } catch (FileNotFoundException | UnsupportedEncodingException e) {
                // Best effort: a failed debug dump must not abort synthesis.
                e.printStackTrace();
            }
        }
        if (optimizer.Check() != Status.SATISFIABLE) {
            System.err.println("Synthesis: SMT not satisfiable, perhaps there are unmitigatable attacks");
            return Optional.empty();
        }
        List<ResultsInstance.Item> items = new ArrayList<>();
        Fraction totalInputCost = new Fraction(0);
        Fraction totalOutputCost = new Fraction(0);
        Model model = optimizer.getModel();
        for (ComponentDefense pair : compDefPairs) {
            // get the value in the model
            RatNum expr = (RatNum) model.eval(pair.toZ3Multi(context), true);
            Fraction rawCost = new Fraction(expr.getNumerator().getInt(), expr.getDenominator().getInt());
            // convert back to DAL (the value in the model is cost rather than DAL)
            int dal = pair.rawCostToDal(rawCost);
            // but we don't trust the cost obtained directly from the model.
            // instead, we re-calculate using the cost model because it is less prone to failure
            // The cost of implemented DAL
            Fraction inputCost = costModel.cost(pair.defenseProperty, pair.component, pair.implDal);
            // The cost of output DAL from SMT
            Fraction outputCost = costModel.cost(pair.defenseProperty, pair.component, dal);
            // keep track of total cost
            totalInputCost = totalInputCost.add(inputCost);
            totalOutputCost = totalOutputCost.add(outputCost);
            items.add(new ResultsInstance.Item(pair.component, pair.defenseProperty, pair.implDal, dal, inputCost, outputCost));
        }
        return Optional.of(new ResultsInstance(partialSolution, meritAssignment, inputSat, totalInputCost, totalOutputCost, items));
    }
}
Also used : Context(com.microsoft.z3.Context) FileNotFoundException(java.io.FileNotFoundException) ArrayList(java.util.ArrayList) UnsupportedEncodingException(java.io.UnsupportedEncodingException) Fraction(org.apache.commons.math3.fraction.Fraction) ResultsInstance(com.ge.verdict.vdm.synthesis.ResultsInstance) ComponentDefense(com.ge.verdict.synthesis.dtree.DLeaf.ComponentDefense) Model(com.microsoft.z3.Model) RatNum(com.microsoft.z3.RatNum) Optimize(com.microsoft.z3.Optimize) PrintWriter(java.io.PrintWriter)

Aggregations

Context (com.microsoft.z3.Context)3 ComponentDefense (com.ge.verdict.synthesis.dtree.DLeaf.ComponentDefense)2 ArithExpr (com.microsoft.z3.ArithExpr)2 BoolExpr (com.microsoft.z3.BoolExpr)2 Model (com.microsoft.z3.Model)2 Optimize (com.microsoft.z3.Optimize)2 ResultsInstance (com.ge.verdict.vdm.synthesis.ResultsInstance)1 BitVecExpr (com.microsoft.z3.BitVecExpr)1 Expr (com.microsoft.z3.Expr)1 RatNum (com.microsoft.z3.RatNum)1 Solver (com.microsoft.z3.Solver)1 FileNotFoundException (java.io.FileNotFoundException)1 PrintWriter (java.io.PrintWriter)1 UnsupportedEncodingException (java.io.UnsupportedEncodingException)1 ArrayList (java.util.ArrayList)1 HashMap (java.util.HashMap)1 LinkedHashSet (java.util.LinkedHashSet)1 List (java.util.List)1 Fraction (org.apache.commons.math3.fraction.Fraction)1 Graph (org.batfish.symbolic.Graph)1