open BinInt
open DItv
open DomAbs
open DomArrayBlk
open DomBasic
open DomMem
open Global
open InterNode
open SemMem
open String0
open Sumbool
open Syn
open TStr
open UserInputType

(** val eval_const : constant -> Val.t **)

(* Abstracts a constant as a numeric interval value.
   - Integer and character literals with a known value map to [Itv.of_int].
   - A real constant maps to the interval spanned by its two bounds.
   - An integer literal with no recorded value, and any enum constant, map
     to [Itv.top]. *)
let eval_const c =
  let itv =
    match c with
    | CInt64 (Some z) | CChr z -> Itv.of_int z
    | CInt64 None | CEnum -> Itv.top
    | CReal (lb, ub) -> Itv.of_ints lb ub
  in
  val_of_itv itv

(** val eval_uop : unop -> Val.t -> Val.t **)

(* Applies unary operator [u] to the interval component of [v].
   A bottom operand is propagated unchanged as [Val.bot].
   [Neg] is computed as [0 - i]. [BNot] and [LNot] delegate to the
   corresponding [Itv] transfer functions. *)
let eval_uop u v =
  if Val.eq_dec v Val.bot then Val.bot
  else
    let i = itv_of_val v in
    let result =
      match u with
      | Neg -> Itv.minus Itv.zero i
      | BNot -> Itv.b_not_itv i
      | LNot -> Itv.not_itv i
    in
    val_of_itv result

(** val is_array_loc : Loc.t -> bool **)

(* True exactly when [l] is a [Loc.Inl] location whose first component is a
   [VarAllocsite.Inr] (presumably an allocation site rather than a program
   variable -- confirm against VarAllocsite). Every other location is false. *)
let is_array_loc l =
  match l with
  | Loc.Inl (VarAllocsite.Inr _, _) -> true
  | Loc.Inl (VarAllocsite.Inl _, _) | Loc.Inr _ -> false

(** val array_loc_of_val : Val.t -> Val.t **)

(* Keeps only the locations of [v] that satisfy [is_array_loc] and wraps the
   resulting location set back into a value. *)
let array_loc_of_val v =
  let all_locs = pow_loc_of_val v in
  let array_locs = PowLoc.filter is_array_loc all_locs in
  val_of_pow_loc array_locs

(** val eval_bop : binop -> Val.t -> Val.t -> Val.t **)

(* Abstract transfer function for binary operator [b].
   - Arithmetic, shift, comparison, bitwise and logical operators combine the
     interval components of [v1] and [v2] with the matching [Itv] operation.
   - [MinusPP] yields [Itv.top].
   - [MinusPI] and every remaining operator (presumably the pointer-plus-
     integer forms such as PlusPI/IndexPI -- confirm against Syn) keep [v1]'s
     array locations. They shift [v1]'s array blocks by [v2]'s interval
     (subtracting for [MinusPI], adding otherwise).
   NOTE(review): unlike [eval_uop], bottom operands are not special-cased
   here. *)
let eval_bop b v1 v2 =
  let arith op = val_of_itv (op (itv_of_val v1) (itv_of_val v2)) in
  let shift_ptr move =
    Val.join (array_loc_of_val v1)
      (val_of_array (move (array_of_val v1) (itv_of_val v2)))
  in
  match b with
  | PlusA -> arith Itv.plus
  | MinusA -> arith Itv.minus
  | MinusPI -> shift_ptr ArrayBlk.minus_offset
  | MinusPP -> val_of_itv Itv.top
  | Mult -> arith Itv.times
  | Div -> arith Itv.divide
  | Mod -> arith Itv.mod_itv
  | Shiftlt -> arith Itv.l_shift_itv
  | Shiftrt -> arith Itv.r_shift_itv
  | Lt -> arith Itv.lt_itv
  | Gt -> arith Itv.gt_itv
  | Le -> arith Itv.le_itv
  | Ge -> arith Itv.ge_itv
  | Eq -> arith Itv.eq_itv
  | Ne -> arith Itv.ne_itv
  | BAnd -> arith Itv.b_and_itv
  | BXor -> arith Itv.b_xor_itv
  | BOr -> arith Itv.b_or_itv
  | LAnd -> arith Itv.and_itv
  | LOr -> arith Itv.or_itv
  | _ -> shift_ptr ArrayBlk.plus_offset

(** val eval_string : char list -> Val.t **)

(* The string contents are ignored: the result is always the interval
   value [Itv.zero_pos]. *)
let eval_string _str =
  val_of_itv Itv.zero_pos

(** val eval_string_loc : char list -> Allocsite.t -> PowLoc.t -> Val.t **)

(* Abstract value of a string literal [s] allocated at site [a].
   The result joins the location set [lvs] with an array block built by
   [ArrayBlk.make a Itv.zero size (Itv.of_int 1)]. Here [size] is the length
   of [s] plus one (presumably room for the terminating NUL -- confirm).
   The size is computed with [+ 1] rather than [Pervasives.succ]:
   [Pervasives] was deprecated in OCaml 4.08 and removed in OCaml 5.0. *)
let eval_string_loc s a lvs =
  let size = Z.of_nat (length s + 1) in
  Val.join (val_of_pow_loc lvs)
    (val_of_array (ArrayBlk.make a Itv.zero (Itv.of_int size) (Itv.of_int 1)))

(** val deref_of_val : Val.t -> PowLoc.t **)

(* Locations [v] may point to: its direct location set joined with the
   locations of the array blocks it carries. *)
let deref_of_val v =
  let direct = pow_loc_of_val v in
  let via_arrays = ArrayBlk.pow_loc_of_array (array_of_val v) in
  PowLoc.join direct via_arrays

module Make =
 functor (M:Monad.Monad) ->
 functor (MB:sig
  val mem_find : Loc.t -> Mem.t -> Val.t M.m

  val mem_add : Loc.t -> Val.t -> Mem.t -> Mem.t M.m

  val mem_weak_add : Loc.t -> Val.t -> Mem.t -> Mem.t M.m

  val mem_pre_weak_add : Loc.t -> Val.t -> Mem.t -> Mem.t M.m
 end) ->
 struct
  (* Memory-access primitives ([mem_lookup], ...) instantiated over the same
     monad [M] and backend [MB]. *)
  module SemMem = Make(M)(MB)

  (** val eval_var : t -> GVar.t -> bool -> Loc.t **)

  (* Location of variable [x] at [node]. Globals are resolved with
     [var_of_gvar]. Locals are qualified with the procedure id of [node]. *)
  let eval_var node x is_global =
    if is_global then loc_of_var (var_of_gvar x)
    else loc_of_var (var_of_lvar (get_pid node, x))

  (** val eval : update_mode -> t -> exp -> Mem.t -> Val.t M.m **)

  (* Evaluates expression [e] at [node] in memory [m0], inside monad [M].
     [mode] is only threaded through to the recursive calls. *)
  let rec eval mode node e m0 =
    match e with
    | Const (c, _) -> M.ret (eval_const c)
    | Lval (l, _) ->
      M.bind (eval_lv mode node l m0) (fun lv -> SemMem.mem_lookup lv m0)
    | SizeOf (n_opt, _) | SizeOfE (n_opt, _) ->
      (* A recorded size goes through [Itv.of_int]; a missing one is top. *)
      let size_itv =
        match n_opt with
        | Some n -> Itv.of_int n
        | None -> Itv.top
      in
      M.ret (val_of_itv size_itv)
    | SizeOfStr (s, _) ->
      (* Length of [s] plus one. [+ 1] replaces [Pervasives.succ], which was
         removed in OCaml 5.0. *)
      let size = Z.of_nat (length s + 1) in
      M.ret (val_of_itv (Itv.of_int size))
    | AlignOf (t0, _) -> M.ret (val_of_itv (Itv.of_int t0))
    | AlignOfE (_, _) -> M.ret (val_of_itv Itv.top)
    | UnOp (u, e0, _) ->
      M.bind (eval mode node e0 m0) (fun v -> M.ret (eval_uop u v))
    | BinOp (b, e1, e2, _) ->
      M.bind (eval mode node e1 m0) (fun v1 ->
        M.bind (eval mode node e2 m0) (fun v2 -> M.ret (eval_bop b v1 v2)))
    | Question (e1, e2, e3, _) ->
      (* Branches are pruned using the interval of the condition:
         - bottom yields bottom;
         - exactly zero evaluates only [e3];
         - if [Itv.le_dec Itv.zero i1] fails (presumably: zero is not a
           possible value) only [e2] is evaluated;
         - otherwise both results are joined. *)
      M.bind (eval mode node e1 m0) (fun v1 ->
        let i1 = itv_of_val v1 in
        if Itv.eq_dec i1 Itv.bot
        then M.ret Val.bot
        else if Itv.eq_dec i1 Itv.zero
             then eval mode node e3 m0
             else if sumbool_not (Itv.le_dec Itv.zero i1)
                  then eval mode node e2 m0
                  else M.bind (eval mode node e2 m0) (fun v2 ->
                         M.bind (eval mode node e3 m0) (fun v3 ->
                           M.ret (Val.join v2 v3))))
    | CastE (new_stride_opt, e0, _) ->
      (* With a new stride, only the array component of the value is recast. *)
      (match new_stride_opt with
       | Some new_stride ->
         M.bind (eval mode node e0 m0) (fun v ->
           let array_v = ArrayBlk.cast_array_int new_stride (array_of_val v)
           in
           M.ret (modify_array v array_v))
       | None -> eval mode node e0 m0)
    | AddrOf (l, _) | StartOf (l, _) ->
      (* Both yield the location set of the lvalue itself. *)
      M.bind (eval_lv mode node l m0) (fun lv -> M.ret (val_of_pow_loc lv))

  (** val eval_lv : update_mode -> t -> lval -> Mem.t -> PowLoc.t M.m **)

  (* Locations denoted by lvalue [lv]. The host is either a variable's
     location or the value of a dereferenced expression. The host is then
     refined by the offset chain [ofs]. *)
  and eval_lv mode node lv m0 =
    let Coq_lval_intro (lhost', ofs, _) = lv in
    M.bind
      (match lhost' with
       | VarLhost (vi, is_global) ->
         let x = eval_var node vi is_global in
         M.ret (val_of_pow_loc (PowLoc.singleton x))
       | MemLhost e -> eval mode node e m0) (fun v ->
      resolve_offset mode node v ofs m0)

  (** val resolve_offset :
      update_mode -> t -> Val.t -> offset -> Mem.t -> PowLoc.t M.m **)

  (* Walks offset chain [os] starting from value [v].
     - [NoOffset]: the locations [v] may point to ([deref_of_val]).
     - [FOffset]: appends field [f] to the direct locations and to the
       struct locations of the array blocks.
     - [IOffset]: reads memory at [v]'s locations and shifts the array
       component of the result by the index interval. *)
  and resolve_offset mode node v os m0 =
    match os with
    | NoOffset -> M.ret (deref_of_val v)
    | FOffset (f, os') ->
      resolve_offset mode node
        (val_of_pow_loc
          (PowLoc.join (pow_loc_append_field (pow_loc_of_val v) f)
            (ArrayBlk.pow_loc_of_struct_w_field (array_of_val v) f))) os' m0
    | IOffset (e, os') ->
      M.bind (eval mode node e m0) (fun idx ->
        M.bind (SemMem.mem_lookup (deref_of_val v) m0) (fun v' ->
          let v'0 =
            modify_array v'
              (ArrayBlk.plus_offset (array_of_val v') (itv_of_val idx))
          in
          resolve_offset mode node v'0 os' m0))

  (** val eval_list :
      update_mode -> t -> exp list -> Mem.t -> Val.t list M.m **)

  (* Evaluates [es] left to right within [M], preserving their order. *)
  let rec eval_list mode node es m0 =
    match es with
    | [] -> M.ret []
    | e :: tl ->
      M.bind (eval mode node e m0) (fun v ->
        M.bind (eval_list mode node tl m0) (fun tl' -> M.ret (v :: tl')))

  (** val eval_alloc' : t -> Val.t -> Val.t **)

  (* Value of an allocation at [node] with size [sz_v]. The node's allocation
     site location is joined with an array block starting at offset zero,
     sized by [sz_v]'s interval, with stride [Itv.of_int 1]. *)
  let eval_alloc' node sz_v =
    let allocsite = allocsite_of_node node in
    let pow_loc = PowLoc.singleton (loc_of_allocsite allocsite) in
    Val.join (val_of_pow_loc pow_loc)
      (val_of_array
        (ArrayBlk.make allocsite Itv.zero (itv_of_val sz_v) (Itv.of_int 1)))

  (** val eval_alloc : update_mode -> t -> alloc -> Mem.t -> Val.t M.m **)

  (* Evaluates the size expression [a], then builds the allocated value. *)
  let eval_alloc mode node a mem =
    M.bind (eval mode node a mem) (fun sz_v -> M.ret (eval_alloc' node sz_v))
 end