Quot/quotient_term.ML
author Christian Urban <urbanc@in.tum.de>
Sat, 26 Dec 2009 07:15:30 +0100
changeset 790 3a48ffcf0f9a
parent 786 d6407afb913c
child 791 fb4bfbb1a291
permissions -rw-r--r--
generalised absrep function; needs consolidation

(* Term-level part of quotient lifting: aggregate abs/rep functions,   *)
(* regularization of raw terms and injection of Rep/Abs pairs.  Each   *)
(* _chk variant runs its base function and then type-checks the result *)
(* with Syntax.check_term.                                              *)
signature QUOTIENT_TERM =
sig
   (* raised, with a diagnostic message, when the raw and the quotient *)
   (* side cannot be matched                                           *)
   exception LIFT_MATCH of string
 
   (* absF: build abstraction functions (abs_...);      *)
   (* repF: build representation functions (rep_...)    *)
   datatype flag = absF | repF
   
   (* aggregate abs/rep function for a pair (raw type, quotient type) *)
   val absrep_fun: flag -> Proof.context -> (typ * typ) -> term
   val absrep_fun_chk: flag -> Proof.context -> (typ * typ) -> term

   (* regularizes a raw term against the corresponding quotient term *)
   val regularize_trm: Proof.context -> (term * term) -> term
   val regularize_trm_chk: Proof.context -> (term * term) -> term
   
   (* injects Rep/Abs pairs into a (regularized) raw term, guided by *)
   (* the corresponding quotient term                                 *)
   val inj_repabs_trm: Proof.context -> (term * term) -> term
   val inj_repabs_trm_chk: Proof.context -> (term * term) -> term
end;

structure Quotient_Term: QUOTIENT_TERM =
struct

open Quotient_Info;

(* Raised, with a human-readable diagnostic, whenever the raw and the  *)
(* quotient side cannot be matched during lifting.                      *)
exception LIFT_MATCH of string

(*******************************)
(* Aggregate Rep/Abs Functions *)
(*******************************)

(* The flag repF is for types in negative position, while absF is   *) 
(* for types in positive position. Because of this, function types  *)
(* need to be treated specially, since there the polarity changes.  *)

datatype flag = absF | repF

(* flips the polarity; applied when descending into the argument *)
(* (domain) part of a function type                               *)
fun negF flag =
  (case flag of
     absF => repF
   | repF => absF)

(* the identity constant with its type left as dummyT *)
val mk_id = Const (@{const_name "id"}, dummyT)

(* the identity constant instantiated at type  ty => ty *)
fun mk_identity ty =
  Const (@{const_name "id"}, Type ("fun", [ty, ty]))

(* Turns the schematic type variable ?'x.i into a free term variable  *)
(* named x followed by i (e.g. ?'a.0 gives "a0"), with type dummyT.     *)
(* mk_mapfun uses these frees as the placeholder arguments of the map  *)
(* functions it builds.  Any other type is rejected with LIFT_MATCH    *)
(* instead of an anonymous Match from a non-exhaustive pattern.        *)
fun mk_Free (TVar ((x, i), _)) = Free (unprefix "'" x ^ string_of_int i, dummyT)
  | mk_Free _ = raise LIFT_MATCH ("mk_Free (type variable expected)")

(* Composes the two terms in an order that depends on the polarity:  *)
(*   absF:  trm1 o trm2        repF:  trm2 o trm1                    *)
fun mk_compose flag (trm1, trm2) =
let
  val (outer, inner) =
    (case flag of
       absF => (trm1, trm2)
     | repF => (trm2, trm1))
in
  Const (@{const_name "comp"}, dummyT) $ outer $ inner
end

(* The map function registered for the type constructor s, as a   *)
(* constant of type dummyT.  Raises LIFT_MATCH if none is known.   *)
fun get_mapfun ctxt s =
let
  val thy = ProofContext.theory_of ctxt
  fun missing () = LIFT_MATCH ("No map function for type " ^ (quote s) ^ " found.")
  val name = #mapfun (maps_lookup thy s) handle NotFound => raise missing ()
in
  Const (name, dummyT)
end

(* Builds the map-function term for the type pattern ty:               *)
(*  - a schematic type variable becomes its mk_Free placeholder,       *)
(*  - a nullary type constructor becomes the identity,                  *)
(*  - an applied type constructor becomes its registered map function   *)
(*    applied to the terms built for its arguments.                     *)
(* The placeholders for vs are then lambda-abstracted, the first        *)
(* element of vs being the outermost binder.                            *)
fun mk_mapfun ctxt vs ty =
let
  fun build (tvar as TVar _) = mk_Free tvar
    | build (Type (_, [])) = mk_id
    | build (Type (s, args)) = list_comb (get_mapfun ctxt s, map build args)
    | build _ = raise LIFT_MATCH ("mk_mapfun_aux (default)")
  val placeholders = map mk_Free vs
in
  fold_rev Term.lambda placeholders (build ty)
end

(* The (raw type, quotient type) pair registered for the quotient  *)
(* type named s.  Raises LIFT_MATCH if s is not a known quotient.   *)
fun get_rty_qty ctxt s =
let
  val thy = ProofContext.theory_of ctxt
  val {rtyp, qtyp, ...} = quotdata_lookup thy s
    handle NotFound => raise LIFT_MATCH ("No quotient type " ^ (quote s) ^ " found.")
in
  (rtyp, qtyp)
end

(* For the schematic type variable v, returns the types it is bound to  *)
(* in rtyenv and in qtyenv (in that order).  Fails (via `the`) if v is   *)
(* unbound in either environment.                                        *)
fun double_lookup rtyenv qtyenv v =
let
  val (idx, _) = dest_TVar v
  fun lookup env = snd (the (Vartab.lookup env idx))
in
  (lookup rtyenv, lookup qtyenv)
end

(* The basic abs_<name> (absF) or rep_<name> (repF) constant of the   *)
(* quotient type qty_str, where <name> is its base name, qualified by  *)
(* Sign.full_bname in the current theory.                             *)
fun absrep_const flag ctxt qty_str =
let
  val thy = ProofContext.theory_of ctxt
  val kind = (case flag of absF => "abs_" | repF => "rep_")
in
  Const (Sign.full_bname thy (kind ^ Long_Name.base_name qty_str), dummyT)
end

(* Builds the aggregate abstraction (absF) or representation (repF)     *)
(* function between the raw type rty and the quotient type qty.  The     *)
(* result contains constants of type dummyT (from get_mapfun,            *)
(* absrep_const, mk_compose), so it still needs type checking, as done   *)
(* by absrep_fun_chk.                                                    *)
(*  - equal types: the identity at qty;                                  *)
(*  - function types: the map function of "fun", with the polarity       *)
(*    flipped for the argument type;                                     *)
(*  - equal type constructors: their map function applied to the         *)
(*    aggregate functions of the type arguments;                         *)
(*  - different constructors: qty is the quotient type s'; the map       *)
(*    function over its raw type pattern is applied to the aggregate     *)
(*    functions of the instantiated type arguments and composed with the *)
(*    abs_/rep_ constant of s';                                          *)
(*  - equal free type variables: the identity.                           *)
(* Raises LIFT_MATCH for schematic or mismatching type variables and for *)
(* every other combination.                                              *)
fun absrep_fun flag ctxt (rty, qty) =
  if rty = qty  
  then mk_identity qty
  else
    case (rty, qty) of
      (Type ("fun", [ty1, ty2]), Type ("fun", [ty1', ty2'])) =>
        let
          (* the argument position is negative: switch abs <-> rep *)
          val arg1 = absrep_fun (negF flag) ctxt (ty1, ty1')
          val arg2 = absrep_fun flag ctxt (ty2, ty2')
        in
          list_comb (get_mapfun ctxt "fun", [arg1, arg2])
        end
    | (Type (s, tys), Type (s', tys')) =>
        if s = s'
        then 
           let
             val args = map (absrep_fun flag ctxt) (tys ~~ tys')
           in
             list_comb (get_mapfun ctxt s, args)
           end
        else
           let
             val thy = ProofContext.theory_of ctxt
             (* NOTE(review): non-exhaustive binding; assumes the registered *)
             (* quotient type is always a type constructor application       *)
             val (rty_pat, qty_pat as Type (_, vs)) = get_rty_qty ctxt s'
             (* instantiations of the quotient's type variables on both sides *)
             val rtyenv = Sign.typ_match thy (rty_pat, rty) Vartab.empty
             val qtyenv = Sign.typ_match thy (qty_pat, qty) Vartab.empty
             val args_aux = map (double_lookup rtyenv qtyenv) vs            
             (* NOTE(review): flag is passed unchanged to every argument,     *)
             (* which presumes the variables occur only positively in rty_pat *)
             (* -- confirm                                                    *)
             val args = map (absrep_fun flag ctxt) args_aux
             val map_fun = mk_mapfun ctxt vs rty_pat       
             val result = list_comb (map_fun, args) 
           in
             mk_compose flag (absrep_const flag ctxt s', result)
           end 
    | (TFree x, TFree x') =>
        if x = x'
        then mk_identity qty
        else raise (LIFT_MATCH "absrep_fun (frees)")
    | (TVar _, TVar _) => raise (LIFT_MATCH "absrep_fun (vars)")
    | _ => 
         let
           val rty_str = Syntax.string_of_typ ctxt rty
           val qty_str = Syntax.string_of_typ ctxt qty
         in
           raise (LIFT_MATCH ("absrep_fun (default) " ^ rty_str ^ " " ^ qty_str))
         end

(* absrep_fun followed by type checking of the result.  As a debugging *)
(* aid, the header "rty / qty" and both printed types are emitted as    *)
(* tracing output first.                                                *)
fun absrep_fun_chk flag ctxt (rty, qty) =
let
  val printed = map (Syntax.string_of_typ ctxt) [rty, qty]
  val _ = List.app tracing ("rty / qty" :: printed)
  val raw_fun = absrep_fun flag ctxt (rty, qty)
in
  Syntax.check_term ctxt raw_fun
end

(* Regularizing an rtrm means:
 
 - Quantifiers over types that need lifting are replaced 
   by bounded quantifiers, for example:

      All P  ----> Ball (Respects R) P

   where the aggregate relation R is given by the rty and qty;
 
 - Abstractions over types that need lifting are replaced
   by bounded abstractions, for example:
      
      %x. P  ----> Babs (Respects R) %x. P

 - Equalities over types that need lifting are replaced by
   corresponding equivalence relations, for example:

      A = B  ----> R A B

   or 

      A = B  ----> (R ===> R) A B
 
   for more complicated types of A and B
*)

(* Instantiates the schematic type variables of trm so that trm gets  *)
(* type ty; fails in Sign.typ_match if the type of trm does not match  *)
(* ty.                                                                 *)
fun force_typ thy trm ty =
  map_types
    (Envir.subst_type (Sign.typ_match thy (fastype_of trm, ty) Vartab.empty))
    trm

(* Builds the aggregate equivalence relation for the type pair          *)
(* (rty, qty); it becomes the argument of Respects.                     *)
(*  - equal types, and any combination other than two type constructor  *)
(*    applications (e.g. type variables), give equality at rty;         *)
(*  - equal type constructors give the registered relation map of the   *)
(*    constructor, applied to the relations of the type arguments;      *)
(*  - different constructors give the equivalence relation registered   *)
(*    for the quotient type s', instantiated to type rty => rty => bool. *)
(* Raises LIFT_MATCH if a needed relation map or quotient is missing.    *)
fun mk_resp_arg ctxt (rty, qty) =
let
  val thy = ProofContext.theory_of ctxt
in  
  if rty = qty
  then HOLogic.eq_const rty
  else
    case (rty, qty) of
      (Type (s, tys), Type (s', tys')) =>
       if s = s' 
       then 
         let
           val exc = LIFT_MATCH ("mk_resp_arg (no relation map function found for type " ^ s ^ ")") 
           val relmap = #relmap (maps_lookup thy s) handle NotFound => raise exc
           val args = map (mk_resp_arg ctxt) (tys ~~ tys')
         in
           list_comb (Const (relmap, dummyT), args) 
         end  
       else 
         let
           (* the lookup below is for the quotient type s', so the *)
           (* message must name s' (not the raw type s)            *)
           val exc = LIFT_MATCH ("mk_resp_arg (no quotient found for type " ^ s' ^ ")") 
           val equiv_rel = #equiv_rel (quotdata_lookup thy s') handle NotFound => raise exc
           (* FIXME: check in this case that the rty and qty *)
           (* FIXME: correspond to each other *)
         in
           (* we need to instantiate the TVars in the relation *)
           force_typ thy equiv_rel (rty --> rty --> @{typ bool})
         end
      | _ => HOLogic.eq_const rty
             (* FIXME: check that the types correspond to each other? *)
end

(* Babs, Ball, Bex and Respects as constants of type dummyT; their *)
(* types are expected to be inferred by later type checking         *)
local
  fun untyped name = Const (name, dummyT)
in
  val mk_babs = untyped @{const_name Babs}
  val mk_ball = untyped @{const_name Ball}
  val mk_bex  = untyped @{const_name Bex}
  val mk_resp = untyped @{const_name Respects}
end

(* If both terms are abstractions, applies f to their bodies and rebuilds *)
(* the first abstraction around the result; otherwise applies f to the    *)
(* pair itself.  Bound variables are not renamed, which is sufficient     *)
(* for regularization.                                                    *)
fun apply_subt f (Abs (x, T, t), Abs (_, _, t')) = Abs (x, T, f (t, t'))
  | apply_subt f terms = f terms

(* the major type of All and Ex quantifiers: the domain of the domain *)
(* of the quantifier constant's type                                  *)
val qnt_typ = domain_type o domain_type


(* Walks the raw term rtrm and the quotient term qtrm in parallel and   *)
(* returns a regularized version of rtrm:                               *)
(*  - abstractions and All/Ex quantifiers whose types differ on the two *)
(*    sides become Babs/Ball/Bex bounded by Respects of the aggregate   *)
(*    relation computed by mk_resp_arg;                                 *)
(*  - equalities at different types become the aggregate relation;      *)
(*  - a raw relation facing an equality is checked against mk_resp_arg; *)
(*  - a raw term facing a quotient constant must be a constant with the *)
(*    same name, or match the raw constant registered for qtrm;         *)
(*  - applications are processed component-wise; bound variables must   *)
(*    have the same index.                                              *)
(* The result may contain dummyTs.  For regularization, bound variables *)
(* need no special treatment.  Raises LIFT_MATCH on any mismatch.       *)
(* NOTE: the order of the cases is significant (e.g. the equality case  *)
(* must precede the relation case and the general constant case).      *)

fun regularize_trm ctxt (rtrm, qtrm) =
  case (rtrm, qtrm) of
    (Abs (x, ty, t), Abs (_, ty', t')) =>
       let
         val subtrm = Abs(x, ty, regularize_trm ctxt (t, t'))
       in
         (* only abstractions over types that need lifting are bounded *)
         if ty = ty' then subtrm
         else mk_babs $ (mk_resp $ mk_resp_arg ctxt (ty, ty')) $ subtrm
       end

  | (Const (@{const_name "All"}, ty) $ t, Const (@{const_name "All"}, ty') $ t') =>
       let
         val subtrm = apply_subt (regularize_trm ctxt) (t, t')
       in
         (* qnt_typ extracts the quantified type from the quantifier's type *)
         if ty = ty' then Const (@{const_name "All"}, ty) $ subtrm
         else mk_ball $ (mk_resp $ mk_resp_arg ctxt (qnt_typ ty, qnt_typ ty')) $ subtrm
       end

  | (Const (@{const_name "Ex"}, ty) $ t, Const (@{const_name "Ex"}, ty') $ t') =>
       let
         val subtrm = apply_subt (regularize_trm ctxt) (t, t')
       in
         if ty = ty' then Const (@{const_name "Ex"}, ty) $ subtrm
         else mk_bex $ (mk_resp $ mk_resp_arg ctxt (qnt_typ ty, qnt_typ ty')) $ subtrm
       end

  | (* equalities need to be replaced by appropriate equivalence relations *) 
    (Const (@{const_name "op ="}, ty), Const (@{const_name "op ="}, ty')) =>
         if ty = ty' then rtrm
         else mk_resp_arg ctxt (domain_type ty, domain_type ty') 

  | (* in this case we just check whether the given equivalence relation is correct *) 
    (rel, Const (@{const_name "op ="}, ty')) =>
       let 
         val exc = LIFT_MATCH "regularise (relation mismatch)"
         val rel_ty = fastype_of rel
         val rel' = mk_resp_arg ctxt (domain_type rel_ty, domain_type ty') 
       in 
         (* comparison up to alpha-conversion *)
         if rel' aconv rel then rtrm else raise exc
       end  

  | (_, Const _) =>
       let 
         fun same_name (Const (s, T)) (Const (s', T')) = (s = s') (*andalso (T = T')*)
           | same_name _ _ = false
          (* TODO/FIXME: This test is not enough. *) 
          (*             Why?                     *)
          (* Because constants can have the same name but not be the same
             constant.  All overloaded constants have the same name but because
             of different types they do differ.
        
             This code will let one write a theorem where plus on nat is
             matched to plus on int, even if the latter is defined differently.
    
             This would result in hard to understand failures in injection and
             cleaning. *)
           (* cu: if I also test the type, then something else breaks *)
       in
         if same_name rtrm qtrm then rtrm
         else 
           let 
             (* otherwise rtrm must be an instance of the raw constant *)
             (* registered for the quotient constant qtrm              *)
             val thy = ProofContext.theory_of ctxt
             val qtrm_str = Syntax.string_of_term ctxt qtrm
             val exc1 = LIFT_MATCH ("regularize (constant " ^ qtrm_str ^ " not found)")
             val exc2 = LIFT_MATCH ("regularize (constant " ^ qtrm_str ^ " mismatch)")
             val rtrm' = #rconst (qconsts_lookup thy qtrm) handle NotFound => raise exc1
           in 
             if Pattern.matches thy (rtrm', rtrm) 
             then rtrm else raise exc2
           end
       end 

  | (t1 $ t2, t1' $ t2') =>
       (regularize_trm ctxt (t1, t1')) $ (regularize_trm ctxt (t2, t2'))

  | (Bound i, Bound i') =>
       if i = i' then rtrm 
       else raise (LIFT_MATCH "regularize (bounds mismatch)")

  | _ =>
       let 
         val rtrm_str = Syntax.string_of_term ctxt rtrm
         val qtrm_str = Syntax.string_of_term ctxt qtrm
       in
         raise (LIFT_MATCH ("regularize failed (default: " ^ rtrm_str ^ "," ^ qtrm_str ^ ")"))
       end

(* regularize_trm followed by type checking of the result, which is *)
(* expected to instantiate the dummyTs left by regularization        *)
fun regularize_trm_chk ctxt (rtrm, qtrm) =
let
  val regularized = regularize_trm ctxt (rtrm, qtrm)
in
  Syntax.check_term ctxt regularized
end

(*
Injection of Rep/Abs means:

  For abstractions:
  * If the type of the abstraction needs lifting, then we add Rep/Abs 
    around the abstraction; otherwise we leave it unchanged.
 
  For applications:
  
  * If the application involves a bounded quantifier, we recurse on 
    the second argument. If the application is a bounded abstraction,
    we always put an Rep/Abs around it (since bounded abstractions
    are assumed to always need lifting). Otherwise we recurse on both 
    arguments.

  For constants:

  * If the constant is (op =), we always leave it unchanged.
    Otherwise, if the type of the constant needs lifting, we put
    a Rep/Abs around it.

  For free variables:

  * We put a Rep/Abs around it if the type needs lifting.

  Vars case cannot occur.
*)

(* Wraps trm as  rep (abs trm), where rep and abs are the aggregate *)
(* functions for the type pair (T, T'); rep is built first.          *)
fun mk_repabs ctxt (T, T') trm =
let
  val rep_fn = absrep_fun repF ctxt (T, T')
  val abs_fn = absrep_fun absF ctxt (T, T')
in
  rep_fn $ (abs_fn $ trm)
end


(* Injects Rep/Abs pairs into the (regularized) raw term rtrm, guided by *)
(* the quotient term qtrm: where the types of corresponding subterms     *)
(* differ, the raw subterm t is replaced by  rep (abs t)  (mk_repabs).    *)
(* Bounded quantifiers are kept and only their body is processed;        *)
(* bounded abstractions are always wrapped.  Bound variables need proper *)
(* treatment because subterm types are computed with fastype_of, so      *)
(* abstraction bodies are processed with the bound variable replaced by  *)
(* a free variable (Term.dest_abs) and re-abstracted afterwards.         *)
(* Raises LIFT_MATCH for any combination not covered below.              *)
(* NOTE: the order of the cases is significant.                          *)

fun inj_repabs_trm ctxt (rtrm, qtrm) =
 case (rtrm, qtrm) of
    (Const (@{const_name "Ball"}, T) $ r $ t, Const (@{const_name "All"}, _) $ t') =>
       Const (@{const_name "Ball"}, T) $ r $ (inj_repabs_trm ctxt (t, t'))

  | (Const (@{const_name "Bex"}, T) $ r $ t, Const (@{const_name "Ex"}, _) $ t') =>
       Const (@{const_name "Bex"}, T) $ r $ (inj_repabs_trm ctxt (t, t'))

  | (Const (@{const_name "Babs"}, T) $ r $ t, t' as (Abs _)) =>
      let
        val rty = fastype_of rtrm
        val qty = fastype_of qtrm
      in
        (* bounded abstractions are wrapped unconditionally *)
        mk_repabs ctxt (rty, qty) (Const (@{const_name "Babs"}, T) $ r $ (inj_repabs_trm ctxt (t, t')))
      end

  | (Abs (x, T, t), Abs (x', T', t')) =>
      let
        val rty = fastype_of rtrm
        val qty = fastype_of qtrm
        (* open both bodies with free variables so that fastype_of works *)
        val (y, s) = Term.dest_abs (x, T, t)
        val (_, s') = Term.dest_abs (x', T', t')
        val yvar = Free (y, T)
        (* re-abstract over the raw side's free variable *)
        val result = Term.lambda_name (y, yvar) (inj_repabs_trm ctxt (s, s'))
      in
        if rty = qty then result
        else mk_repabs ctxt (rty, qty) result
      end

  | (t $ s, t' $ s') =>  
       (inj_repabs_trm ctxt (t, t')) $ (inj_repabs_trm ctxt (s, s'))

  (* only the types of the free variables are compared, not their names *)
  | (Free (_, T), Free (_, T')) => 
        if T = T' then rtrm 
        else mk_repabs ctxt (T, T') rtrm

  | (_, Const (@{const_name "op ="}, _)) => rtrm

  | (_, Const (_, T')) =>
      let
        val rty = fastype_of rtrm
      in 
        if rty = T' then rtrm
        else mk_repabs ctxt (rty, T') rtrm
      end   
  
  | _ => raise (LIFT_MATCH "injection (default)")

(* inj_repabs_trm followed by type checking of the result *)
fun inj_repabs_trm_chk ctxt (rtrm, qtrm) =
let
  val injected = inj_repabs_trm ctxt (rtrm, qtrm)
in
  Syntax.check_term ctxt injected
end

end; (* structure *)