open DomAbs
open DomBasic
open DomMem
open Global
open InterCfg
open Syn
open TStr
open UserInputType

(** Monadic memory-update helpers.

    [M] is the monad in which memory operations run, and [MB] supplies the
    primitive accessors on [Mem.t] inside that monad. Every helper below
    threads its memory/value result through [M]. *)
module Make (M : Monad.Monad) (MB : sig
  (* Read primitive; used by [mem_lookup]. *)
  val mem_find : Loc.t -> Mem.t -> Val.t M.m

  (* Write primitive; used as the strong update by [add]/[mem_update]. *)
  val mem_add : Loc.t -> Val.t -> Mem.t -> Mem.t M.m

  (* Weak-write primitive selected by [weak_add] in [Strong] mode. *)
  val mem_weak_add : Loc.t -> Val.t -> Mem.t -> Mem.t M.m

  (* Weak-write primitive selected by [weak_add] in [Weak] mode.
     NOTE(review): how it differs from [mem_weak_add] is not visible here. *)
  val mem_pre_weak_add : Loc.t -> Val.t -> Mem.t -> Mem.t M.m
end) =
struct
  (** [can_strong_update mode g l] is always [false] in [Weak] mode.
      In [Strong] mode it defers to [approx_one_loc g l], which presumably
      checks that [l] denotes a single concrete cell (confirm). *)
  let can_strong_update mode g l =
    match mode with
    | Strong -> approx_one_loc g l
    | Weak -> false

  (** [mem_lookup lvs m0] reads each location of [lvs] from [m0] with
      [MB.mem_find] and joins the results, starting from [Val.bot].
      An empty [lvs] therefore yields [Val.bot]. The reads are chained in
      the order [PowLoc.fold] visits the locations. *)
  let mem_lookup lvs m0 =
    let ( >>= ) = M.bind in
    let step loc pending =
      pending >>= fun joined ->
      MB.mem_find loc m0 >>= fun found ->
      M.ret (Val.join joined found)
    in
    PowLoc.fold step lvs (M.ret Val.bot)

  (** [add l v m] is [MB.mem_add l v m], the strong update. *)
  let add l v m = MB.mem_add l v m

  (** [weak_add mode l v m] writes through a weak-write primitive.
      [Weak] selects [MB.mem_pre_weak_add]; [Strong] selects
      [MB.mem_weak_add]. *)
  let weak_add mode l v m =
    let write =
      match mode with
      | Strong -> MB.mem_weak_add
      | Weak -> MB.mem_pre_weak_add
    in
    write l v m

  (** [mem_update mode g l v m0] strongly updates [l] with [v] (via [add])
      when [can_strong_update mode g l] holds. Otherwise it falls back to
      [weak_add mode]. *)
  let mem_update mode g l v m0 =
    let strong = can_strong_update mode g l in
    if strong then add l v m0 else weak_add mode l v m0

  (** [mem_wupdate mode lvs v m0] applies [weak_add mode] with value [v] to
      every location in [lvs], threading the memory through [M]. A strong
      update is never attempted here. An empty [lvs] gives back [M.ret m0]. *)
  let mem_wupdate mode lvs v m0 =
    let ( >>= ) = M.bind in
    PowLoc.fold
      (fun lv pending -> pending >>= fun m -> weak_add mode lv v m)
      lvs (M.ret m0)

  (** [list_fold2_m f l1 l2 acc] folds [f] over the pairs of [l1] and [l2]
      from left to right, threading the accumulator through [M].
      Unlike [List.fold_left2], it never fails on lists of unequal length.
      It simply stops at the end of the shorter list. *)
  let rec list_fold2_m f l1 l2 acc =
    match l1, l2 with
    | a :: rest1, b :: rest2 ->
      M.bind (f a b acc) (list_fold2_m f rest1 rest2)
    | [], _ | _ :: _, [] -> M.ret acc

  (** [bind_arg mode f x v m0] writes [v] to the location of variable
      [(f, x)], obtained with [var_of_lvar] and then [loc_of_var].
      The write goes through [mem_wupdate], so it is always one of the
      weak adds selected by [mode]. *)
  let bind_arg mode f x v m0 =
    let target = loc_of_var (var_of_lvar (f, x)) in
    mem_wupdate mode (PowLoc.singleton target) v m0

  (** [bind_args mode g vs f m0] pairs the argument names that [get_args]
      returns for [f] in [g.G.icfg] with the values [vs], and binds each
      pair with [bind_arg]. Surplus names or values are ignored. When
      [get_args] yields [None], the memory is returned unchanged. *)
  let bind_args mode g vs f m0 =
    match get_args g.G.icfg f with
    | None -> M.ret m0
    | Some formals -> list_fold2_m (bind_arg mode f) formals vs m0
end