Include trusted "cnf.g".
Include trusted "cnf-util.g".

%=============================================================================
% 'assignment' type
%=============================================================================

% Value recorded for a single variable:
%   UN     - not assigned (is_assigned returns ff for it)
%   TT, FF - assigned; is_compat_assignment treats TT as agreeing with the
%            boolean tt and FF as agreeing with ff.
Inductive assignment : type :=
  UN : assignment
| TT : assignment
| FF : assignment
.

% tt iff a carries a value: ff for UN, tt for TT and FF.
Define is_assigned :=
  fun(a:assignment).
  match a with
    UN => ff
  | TT => tt
  | FF => tt
  end.

% Assignment value matching the sign of literal l:
% TT when (lit_sign l) = tt, FF when (lit_sign l) = ff.
Define lit_assignment :=
  fun(^#owned l:lit).
  match (lit_sign l) with
    ff => FF
  | tt => TT
  end.

% ulit counterpart of lit_assignment: TT when (ulit_sign l) = tt,
% FF when (ulit_sign l) = ff.
Define ulit_assignment :=
  fun(l:ulit).
  match (ulit_sign l) with
    ff => FF
  | tt => TT
  end.

% Whether assignment a is consistent with boolean b:
% UN is consistent with either value, TT only with tt, FF only with ff.
Define is_compat_assignment :=
  fun(a:assignment)(b:bool).
  match a with
    UN => tt
  | TT => b
  | FF => (not b)
  end.

% Whether assignment a is exactly the value b: like is_compat_assignment,
% except that UN yields ff instead of tt.
Define is_equal_assignment := % stricter than is_compat_assignment
  fun(a:assignment)(b:bool).
  match a with
    UN => ff
  | TT => b
  | FF => (not b)
  end.


%=============================================================================
% low-level functions for assignment array
%=============================================================================

% Array-level version of the UN test performed by `unassigned`: works on a
% bare assignment array pa of size (inc_nv nv nv_ub) rather than on an
% AssignState. Returns tt iff entry v of pa is UN, ff for any other value.
% The caller's bound (leword v nv) = tt is converted by
% leword_nv_implies_ltword_inc_nv into the index bound uwarray_get needs.
Define unassigned2 :=
  fun(spec nv:word)
     (nv_ub:{ (ltword nv var_upper_bound) = tt })
     (^#unique_owned pa:<uwarray assignment (inc_nv nv nv_ub)>)
     (v:word)
     (u:{ (leword v nv) = tt }).
  abbrev p = [leword_nv_implies_ltword_inc_nv nv nv_ub v u] in
  match (uwarray_get assignment (inc_nv nv nv_ub) pa v p) with
    default => ff
  | UN => tt
  end.


%=============================================================================
% assignment state
%=============================================================================

% Assignment state for formula F over nv variables. Fields:
%   nv_ub    : bound (ltword nv var_upper_bound) = tt, used to size arrays
%   pa       : per-variable assignment value, size (inc_nv nv nv_ub)
%   why      : per-variable <option <aclause nv F>>; `assign` stores its
%              w argument here (presumably the reason clause -- confirm)
%   dls      : per-variable word; `assign` stores its dl argument here
%              (presumably the decision level -- confirm)
%   hist     : assigned literals in assignment order, nv slots
%   hist_cur : propagated until this point
%   hist_end : assigned until this point; `assign` writes the next literal
%              at this index and then increments it
% The commented-out fields are intended invariants, not yet enforced.
Inductive AssignState : Fun(nv:word)(F:formula).type :=
  assign_state :
    Fun(spec nv:word)  % # of variables
       (spec F:formula)
       (nv_ub:{ (ltword nv var_upper_bound) = tt })
       (#unique pa:<uwarray assignment (inc_nv nv nv_ub)>)
       (#unique why:<warray <option <aclause nv F>> (inc_nv nv nv_ub)>)
       (#unique dls:<uwarray word (inc_nv nv nv_ub)>)
       (#unique hist:<uwarray ulit nv>)
       (hist_cur:word) % propagated until this point
       (hist_end:word) % assigned until this point
       %(hist_ub1: { (leword hist_end nv) = tt })
       %(hist_ub2: { (leword hist_cur hist_end) = tt })
       %(hist_valid: all literals in hist are valid up to hist_end)
      . #unique <AssignState nv F>
.

% Fresh assignment state for nv variables and formula F.
%   - every variable is UN
%   - every `why` entry is nothing
%   - every dls entry is word0
%   - the history holds ulit_null in all nv slots
%   - both cursors (hist_cur, hist_end) start at 0x0
Define newAssignState :=
  fun(nv:word)
     (spec F:formula)
     (nv_ub:{ (ltword nv var_upper_bound) = tt })
    : #unique <AssignState nv F>
  .
  abbrev sz = (inc_nv nv nv_ub) in
  % per-variable arrays, all of size sz
  let pa_arr = (uwarray_new assignment sz UN) in
  let why_arr =
    let noval = (nothing <aclause nv F>) in
    % the array is initialised from an inspected noval; noval's own
    % reference is then released with dec
    let filled = (warray_new <option <aclause nv F>> sz (inspect <option <aclause nv F>> noval)) in
    do (dec <option <aclause nv F>> noval) filled end in
  let dl_arr = (uwarray_new word sz word0) in
  % history of assigned literals: one slot per variable
  let hist_arr = (uwarray_new ulit nv ulit_null) in
  % proofs for the (currently disabled) history invariants would be:
  %abbrev p1 = hypjoin (leword 0x0 nv) tt by [leZ (to_nat wordlen nv)] end in
  %abbrev p2 = join (leword 0x0 0x0) tt in
  (assign_state nv F nv_ub pa_arr why_arr dl_arr hist_arr 0x0 0x0)
  .

% Value currently recorded for variable v (with v <= nv).
% Steps:
%   1. inspect the unique state `as`;
%   2. read entry v of its assignment array;
%   3. consume the inspected handle before returning the value.
Define get_assignment :=
  fun(spec nv:word)(spec F:formula)
     (!#unique as:<AssignState nv F>)
     (v:word)
     (u:{ (leword v nv) = tt }).
  let st = (inspect_unique <AssignState nv F> as) in
  let res =
    match st with
      assign_state _ _ ub arr _ _ _ _ _ =>
        % index bound for an array of size (inc_nv nv ub)
        abbrev vbound = [leword_nv_implies_ltword_inc_nv nv ub v u] in
        (uwarray_get assignment (inc_nv nv ub) arr v vbound)
    end in
  do
    (consume_unique_owned <AssignState nv F> st)
    res
  end.

% tt iff variable v (with v <= nv) is UN in the state; TT or FF gives ff.
% The unique state `as` is inspected, entry v of its assignment array is
% read, and the inspected handle is consumed before the result is returned.
Define unassigned :=
  fun(spec nv:word)(spec F:formula)
     (!#unique as:<AssignState nv F>)
     (v:word)
     (u:{ (leword v nv) = tt }).
  let st = (inspect_unique <AssignState nv F> as) in
  let res =
    match st with
      assign_state _ _ ub arr _ _ _ _ _ =>
        % index bound for an array of size (inc_nv nv ub)
        abbrev vbound = [leword_nv_implies_ltword_inc_nv nv ub v u] in
        match (uwarray_get assignment (inc_nv nv ub) arr v vbound) with
          default => ff
        | UN => tt
        end
    end in
  do
    (consume_unique_owned <AssignState nv F> st)
    res
  end.

% If (unassigned as v) = tt, then v is in range: (leword v nv) = tt.
% Proof: case on (leword v nv).
%   - In the ff case, v is not below (inc_nv nv nv_ub), so the array read
%     inside `unassigned` aborts (p4). Hence (unassigned as v) evaluates to
%     abort, which contradicts u via aclash.
%   - The tt case is the goal itself.
Define unassigned_implies_leword :
  Forall(nv:word)(F:formula)
        (as:<AssignState nv F>)
        (v:word)
        (u:{ (unassigned as v) = tt }).
    { (leword v nv) = tt }
  :=
  foralli(nv:word)(F:formula)
         (as:<AssignState nv F>)
         (v:word)
         (u:{ (unassigned as v) = tt }).
  case as with assign_state _ _ nv_ub pa _ _ _ _ _ =>
  case (leword v nv) by q1 _ with
    ff =>
      % p1 relates (leword v nv) to (ltword v (inc_nv nv)) -- presumably an
      % equation between the two; confirm against its statement
      abbrev p1 = [leword_nv_eq_ltword_inc_nv nv nv_ub v] in
      % p2: the index is out of range at the word level
      abbrev p2 = hypjoin (ltword v (inc_nv nv)) ff by p1 q1 end in
      % p3: the same fact moved to the nat level via ltword_to_lt
      abbrev p3 = trans symm [ltword_to_lt v (inc_nv nv nv_ub)] p2 in
      % p4: reading pa at that out-of-range index aborts
      abbrev p4 = [vec_get_abort assignment (word_to_nat (inc_nv nv nv_ub)) pa (word_to_nat v) p3] in
      contra
      trans symm u
      trans hypjoin (unassigned as v) abort ! by p4 as_eq end
            aclash tt
      { (leword v nv) = tt }
  | tt => q1
  end
  end.

% Record literal l as assigned in the state, where v = (ulit_vnum l):
%   pa[v]            := (ulit_assignment l)  (TT if (ulit_sign l) = tt, else FF)
%   why[v]           := w
%   dls[v]           := dl
%   hist[hist_end]   := l, and hist_end is incremented
% hist_cur is left unchanged.
% Precondition u: v is currently unassigned; from it we obtain
% (leword v nv) = tt, which bounds every per-variable array index.
% Aborts if the history is already full ((ltword hist_end nv) = ff); the
% invariants that would rule this out are not yet part of AssignState.
Define assign :=
  fun(nv:word)
     (spec F:formula)
     (#unique as:<AssignState nv F>)
     (l:ulit)
     (w:<option <aclause nv F>>)
     (dl:word)
     (u:{ (unassigned as (ulit_vnum l)) = tt }): #unique <AssignState nv F>.
  abbrev u' = [unassigned_implies_leword nv F as (ulit_vnum l) u] in
  match as with
    assign_state _ _ nv_ub pa why dls
                 hist hist_cur hist_end %-hist_ub1 hist_ub2-% =>
      abbrev nv' = (inc_nv nv nv_ub) in
      let b = (ulit_assignment l) in
      let vv = (ulit_vnum l) in

      % q1: (leword vv nv) = tt
      abbrev q1 = hypjoin (leword vv nv) tt by u' vv_eq end in

      % q2: (ltword vv nv') = tt -- the index bound for the per-variable
      % arrays, obtained from q2_1: (ltword vv (inc_nv nv)) = tt
      abbrev q2_1 = [leword_nv_implies_ltword_inc_nv nv nv_ub vv q1] in
      abbrev q2 = hypjoin (ltword vv nv') tt by q2_1 end in

      let pa' = (uwarray_set assignment nv' pa vv b q2) in
      let why' = (warray_set <option <aclause nv F>> vv w nv' why q2) in
      let dls' = (uwarray_set word nv' dls vv dl q2) in
      % q3: (ltword hist_end nv) = tt -- the index bound for hist. It gets
      % its own name so it is not confused with the per-variable bound q2.
      match (ltword hist_end nv) by q3 Q3 with
        ff => abort <AssignState nv F>
      | tt =>
          let hist' = (uwarray_set ulit nv hist hist_end l q3) in
          % p1: side condition for word_inc_safe, derived from q3
          abbrev p1 = [ltword_implies_ltword_word_max hist_end nv q3] in
          let hist_end' = (word_inc_safe hist_end p1) in
          (assign_state nv F nv_ub pa' why' dls'
            hist' hist_cur hist_end')
      end
  end.


%==============================================================================
% assignment lemmas
%==============================================================================

% Totality of is_assigned. The proof cases on a and exhibits the result:
% ff for UN, tt for TT and FF.
Define is_assigned_total :=
  foralli(a:assignment).
  case a with
    UN => existsi ff { (is_assigned a) = * }
            hypjoin (is_assigned a) ff by a_eq end
  | TT => existsi tt { (is_assigned a) = * }
            hypjoin (is_assigned a) tt by a_eq end
  | FF => existsi tt { (is_assigned a) = * }
            hypjoin (is_assigned a) tt by a_eq end
  end.

% Register is_assigned as total, justified by is_assigned_total.
Total is_assigned is_assigned_total.

% Totality of lit_assignment. The proof cases on (lit_sign l):
% FF when the sign is ff, TT when it is tt.
Define lit_assignment_total :=
  foralli(l:lit).
  case (lit_sign l) by q1 _ with
    ff => existsi FF { (lit_assignment l) = * }
            hypjoin (lit_assignment l) FF by q1 end
  | tt => existsi TT { (lit_assignment l) = * }
            hypjoin (lit_assignment l) TT by q1 end
  end.

% Register lit_assignment as total, justified by lit_assignment_total.
Total lit_assignment lit_assignment_total.

% Totality of ulit_assignment. The proof cases on (ulit_sign l):
% FF when the sign is ff, TT when it is tt.
Define ulit_assignment_total :=
  foralli(l:ulit).
  case (ulit_sign l) by q1 _ with
    ff => existsi FF { (ulit_assignment l) = * }
            hypjoin (ulit_assignment l) FF by q1 end
  | tt => existsi TT { (ulit_assignment l) = * }
            hypjoin (ulit_assignment l) TT by q1 end
  end.

% Register ulit_assignment as total, justified by ulit_assignment_total.
Total ulit_assignment ulit_assignment_total.

% If (is_assigned a) = ff, then a = UN.
% The TT and FF cases are impossible: is_assigned returns tt there, which
% clashes with r.
Define is_assigned_ff :=
  foralli(a:assignment)(r:{ (is_assigned a) = ff }).
  case a with
    UN => a_eq
  | TT => contra
            trans symm hypjoin (is_assigned a) tt by a_eq end
            trans r
                  clash ff tt
            { a = UN }
  | FF => contra
            trans symm hypjoin (is_assigned a) tt by a_eq end
            trans r
                  clash ff tt
            { a = UN }
  end.
  
% An assigned value that is compatible with the sign of l is exactly
% (lit_assignment l).
% Proof by cases on a:
%   UN : ruled out by u1.
%   TT : u2 gives (lit_sign l) = tt.
%   FF : u2 gives (not (lit_sign l)) = tt, which not_tt turns into the sign
%        fact used to evaluate (lit_assignment l) to FF.
Define is_compat_assignment_lem1 :
  Forall(a:assignment)(l:lit)
        (u1:{ (is_assigned a) = tt })
        (u2:{ (is_compat_assignment a (lit_sign l)) = tt })
  .{ a = (lit_assignment l) }
  :=
  foralli(a:assignment)(l:lit)
         (u1:{ (is_assigned a) = tt })
         (u2:{ (is_compat_assignment a (lit_sign l)) = tt })
  .
  case a with
    UN => contra
            trans symm u1
            trans hypjoin (is_assigned a) ff by a_eq end
                  clash ff tt
            { a = (lit_assignment l) }
  | TT => abbrev p1 = hypjoin (lit_sign l) tt by u2 a_eq end in
          hypjoin a (lit_assignment l) by p1 a_eq end
  | FF => abbrev p1_1 = hypjoin (not (lit_sign l)) tt by u2 a_eq end in
          abbrev p1 = [not_tt (lit_sign l) p1_1] in
          hypjoin a (lit_assignment l) by p1 a_eq end
  end.

% Totality of is_equal_assignment. The proof cases on a and exhibits the
% result: ff for UN, s for TT, (not s) for FF.
Define is_equal_assignment_total :=
  foralli(a:assignment)(s:bool).
  case a with
    UN => existsi ff { (is_equal_assignment a s) = * }
            hypjoin (is_equal_assignment a s) ff by a_eq end
  | TT => existsi s { (is_equal_assignment a s) = * }
            hypjoin (is_equal_assignment a s) s by a_eq end
  | FF => existsi (not s) { (is_equal_assignment a s) = * }
            hypjoin (is_equal_assignment a s) (not s) by a_eq end
  end.

% Register is_equal_assignment as total, justified by the lemma above.
% (Like every other command, it must be terminated by '.'.)
Total is_equal_assignment is_equal_assignment_total.

% Extracts the bound (ltword nv var_upper_bound) = tt from any
% <AssignState nv F>.
% Proof steps:
%   1. The stored field nv_ub is stated for the constructor's index nv'.
%   2. cong over (join nv nv') rewrites it to nv.
%   3. trans with nv_ub yields the result.
Define assign_state_implies_nv_ub :=
  foralli(nv:word)(F:formula)
         (as:<AssignState nv F>).
  case as with assign_state nv' _ nv_ub _ _ _ _ _ _ =>
    trans cong (ltword * var_upper_bound) join nv nv'
          nv_ub
  end
  .
