Quot/quotient_term.ML
author Christian Urban <urbanc@in.tum.de>
Wed, 13 Jan 2010 09:41:57 +0100
changeset 858 bb012513fb39
parent 856 433f7c17255f
child 865 5c6d76c3ba5c
permissions -rw-r--r--
tuned

signature QUOTIENT_TERM =
sig
  (* raised, with an explanatory message, when raw and quotient
     types or terms do not correspond *)
  exception LIFT_MATCH of string

  (* selects whether an abs or a rep function is generated *)
  datatype flag = absF | repF

  (* aggregate abs/rep function for a pair (rty, qty); the _chk
     variant passes the result through Syntax.check_term *)
  val absrep_fun: flag -> Proof.context -> typ * typ -> term
  val absrep_fun_chk: flag -> Proof.context -> typ * typ -> term

  (* aggregate equivalence relation for a pair (rty, qty); the _chk
     variant passes the result through Syntax.check_term *)
  val equiv_relation: Proof.context -> typ * typ -> term
  val equiv_relation_chk: Proof.context -> typ * typ -> term

  (* regularization of a raw term against a quotient term; the _chk
     variant passes the result through Syntax.check_term *)
  val regularize_trm: Proof.context -> term * term -> term
  val regularize_trm_chk: Proof.context -> term * term -> term

  (* injection of Rep/Abs into a raw term, guided by a quotient term;
     the _chk variant passes the result through Syntax.check_term *)
  val inj_repabs_trm: Proof.context -> term * term -> term
  val inj_repabs_trm_chk: Proof.context -> term * term -> term
end;

structure Quotient_Term: QUOTIENT_TERM =
struct

open Quotient_Info;

(* raised, with an explanatory message, by the functions of this
   structure when raw and quotient types or terms do not correspond *)
exception LIFT_MATCH of string



(*** Aggregate Rep/Abs Function ***)


(* The flag repF is for types in negative position; absF is for types
   in positive position. Because of this, function types need to be
   treated specially, since there the polarity changes (negF flips
   the flag).
*)

datatype flag = absF | repF

(* flips the polarity flag *)
fun negF flag =
  case flag of
    absF => repF
  | repF => absF

(* tests whether a term is the constant id (at any type) *)
fun is_identity trm =
  case trm of
    Const (@{const_name "id"}, _) => true
  | _ => false

(* the constant id at type ty --> ty *)
fun mk_identity ty =
let
  val id_ty = ty --> ty
in
  Const (@{const_name "id"}, id_ty)
end

(* composes two terms with comp; for repF the order
   of the two arguments is swapped *)
fun mk_fun_compose flag (trm1, trm2) =
let
  val comp = Const (@{const_name "comp"}, dummyT)
in
  case flag of
    absF => comp $ trm1 $ trm2
  | repF => comp $ trm2 $ trm1
end

(* looks up the map function registered for the type constructor s
   and returns it as a constant of type dummyT; raises LIFT_MATCH
   if the lookup raises NotFound *)
fun get_mapfun ctxt s =
let
  val thy = ProofContext.theory_of ctxt
  fun not_found () =
    raise LIFT_MATCH ("No map function for type " ^ quote s ^ " found.")
  val map_name = #mapfun (maps_lookup thy s) handle NotFound => not_found ()
in
  Const (map_name, dummyT)
end

(* makes a Free out of a TVar: the name is the name of the TVar
   without the leading prime, followed by its index; the type
   of the Free is dummyT *)
fun mk_Free (TVar ((x, i), _)) =
let
  val name = unprefix "'" x ^ string_of_int i
in
  Free (name, dummyT)
end

(* builds the aggregate map function for the rty-part of a
   quotient definition and abstracts over the variables in vs
   (they correspond to the type variables in rty)

   example: for  (?'a list * ?'b)
            it produces  %a b. prod_map (map a) b

   - a TVar becomes the corresponding Free
   - a type constructor without arguments becomes the identity
   - any other type constructor becomes its map function applied
     to the results for its arguments
*)
fun mk_mapfun ctxt vs rty =
let
  fun aux (ty as TVar _) = mk_Free ty
    | aux (ty as Type (_, [])) = mk_identity ty
    | aux (Type (s, tys)) = list_comb (get_mapfun ctxt s, map aux tys)
    | aux _ = raise LIFT_MATCH ("mk_mapfun (default)")
in
  (* the Frees for vs are built before the body is traversed *)
  fold_rev Term.lambda (map mk_Free vs) (aux rty)
end

(* returns the pair (rtyp, qtyp) stored in the quotient data for s;
   raises LIFT_MATCH if the lookup raises NotFound *)
fun get_rty_qty ctxt s =
let
  val thy = ProofContext.theory_of ctxt
  fun not_found () =
    raise LIFT_MATCH ("No quotient type " ^ quote s ^ " found.")
  val info = quotdata_lookup thy s handle NotFound => not_found ()
in
  (#rtyp info, #qtyp info)
end

(* looks up the type variable v in both type environments
   and returns the pair of assigned types; v must be bound
   in both environments *)
fun double_lookup rtyenv qtyenv v =
let
  val (idx, _) = dest_TVar v
  fun assigned env = snd (the (Vartab.lookup env idx))
in
  (assigned rtyenv, assigned qtyenv)
end

(* matches a type pattern with a type

   - returns the type environment produced by Sign.typ_match,
     starting from the empty environment
   - if the types do not match, the error continuation err is
     called with ctxt, ty_pat and ty

   The handler names Type.TYPE_MATCH explicitly: an identifier
   that is not an exception constructor is a variable pattern
   and would catch every exception, not only matching failures.
*)
fun match ctxt err ty_pat ty =
let
  val thy = ProofContext.theory_of ctxt
in
  Sign.typ_match thy (ty_pat, ty) Vartab.empty
  handle Type.TYPE_MATCH => err ctxt ty_pat ty
end

(* produces the abs or rep constant for a qty: its name is
   "abs_" or "rep_" followed by the base name of qty_str,
   qualified via Sign.full_bname; its type is dummyT *)
fun absrep_const flag ctxt qty_str =
let
  val thy = ProofContext.theory_of ctxt
  val pre = (case flag of absF => "abs_" | repF => "rep_")
  val bname = pre ^ Long_Name.base_name qty_str
in
  Const (Sign.full_bname thy bname, dummyT)
end

(* error continuation for match: raises LIFT_MATCH with a
   message that mentions both types *)
fun absrep_match_err ctxt ty_pat ty =
let
  val show_typ = quote o Syntax.string_of_typ ctxt
  val msg = space_implode " "
    ["absrep_fun (Types ", show_typ ty_pat, "and", show_typ ty, " do not match.)"]
in
  raise LIFT_MATCH msg
end


(** generation of an aggregate absrep function **)

(* - In case of equal types we just return the identity.           
     
   - In case of TFrees we also return the identity.
                                                             
   - In case of function types we recurse taking   
     the polarity change into account.              
                                                                   
   - If the type constructors are equal, we recurse for the        
     arguments and build the appropriate map function.             
                                                                   
   - If the type constructors are unequal, there must be an        
     instance of quotient types:         
                          
       - we first look up the corresponding rty_pat and qty_pat    
         from the quotient definition; the arguments of qty_pat    
         must be some distinct TVars                               
       - we then match the rty_pat with rty and qty_pat with qty;  
         if matching fails the types do not correspond -> error                  
       - the matching produces two environments; we look up the    
         assignments for the qty_pat variables and recurse on the  
         assignments                                               
       - we prefix the aggregate map function for the rty_pat,     
         which is an abstraction over all type variables           
       - finally we compose the result with the appropriate        
         absrep function in case at least one argument produced
         a non-identity function /
         otherwise we just return the appropriate absrep
         function                                          
                                                                   
     The composition is necessary for types like                   
                                                                 
        ('a list) list / ('a foo) foo                              
                                                                 
     The matching is necessary for types like                      
                                                                 
        ('a * 'a) list / 'a bar   

     The test is necessary in order to eliminate superfluous
     identity maps.                                 
*)  

fun absrep_fun flag ctxt (rty, qty) =
let
  (* recursive call with the context fixed *)
  fun absrep fl tys = absrep_fun fl ctxt tys
in
  if rty = qty
  then mk_identity rty
  else
    case (rty, qty) of
      (Type ("fun", [dom, rng]), Type ("fun", [dom', rng'])) =>
        let
          (* the flag is flipped for the argument type *)
          val dom_fun = absrep (negF flag) (dom, dom')
          val rng_fun = absrep flag (rng, rng')
        in
          list_comb (get_mapfun ctxt "fun", [dom_fun, rng_fun])
        end
    | (Type (s, tys), Type (s', tys')) =>
        if s <> s'
        then
          (* different type constructors: s' must be a quotient type *)
          let
            val (rty_pat, qty_pat as Type (_, vs)) = get_rty_qty ctxt s'
            val rtyenv = match ctxt absrep_match_err rty_pat rty
            val qtyenv = match ctxt absrep_match_err qty_pat qty
            val ty_pairs = map (double_lookup rtyenv qtyenv) vs
            val sub_funs = map (absrep flag) ty_pairs
            val map_fun = mk_mapfun ctxt vs rty_pat
            val absrep_c = absrep_const flag ctxt s'
          in
            (* no composition if all arguments are identities *)
            if forall is_identity sub_funs
            then absrep_c
            else mk_fun_compose flag (absrep_c, list_comb (map_fun, sub_funs))
          end
        else
          (* same type constructor: map over the arguments *)
          let
            val sub_funs = map (absrep flag) (tys ~~ tys')
          in
            list_comb (get_mapfun ctxt s, sub_funs)
          end
    | (TFree x, TFree x') =>
        if x = x'
        then mk_identity rty
        else raise (LIFT_MATCH "absrep_fun (frees)")
    | (TVar _, TVar _) => raise (LIFT_MATCH "absrep_fun (vars)")
    | _ => raise (LIFT_MATCH "absrep_fun (default)")
end

(* like absrep_fun, but the result is passed through Syntax.check_term *)
fun absrep_fun_chk flag ctxt (rty, qty) =
  Syntax.check_term ctxt (absrep_fun flag ctxt (rty, qty))




(*** Aggregate Equivalence Relation ***)


(* works very similarly to the absrep generation,
   except there is no need for polarities
*)

(* instantiates the TVars in trm so that the term is of type ty:
   the type of trm is matched against ty and the resulting
   type environment is applied to all types in trm *)
fun force_typ ctxt trm ty =
let
  val thy = ProofContext.theory_of ctxt
  val tyenv = Sign.typ_match thy (fastype_of trm, ty) Vartab.empty
in
  map_types (Envir.subst_type tyenv) trm
end

(* tests whether a term is the equality constant (at any type) *)
fun is_eq trm =
  case trm of
    Const (@{const_name "op ="}, _) => true
  | _ => false

(* applies the constant rel_conj to the two given relations *)
fun mk_rel_compose (trm1, trm2) =
let
  val rel_conj = Const (@{const_name "rel_conj"}, dummyT)
in
  rel_conj $ trm1 $ trm2
end

(* looks up the relation map registered for the type constructor s
   and returns it as a constant of type dummyT; raises LIFT_MATCH
   if the lookup raises NotFound *)
fun get_relmap ctxt s =
let
  val thy = ProofContext.theory_of ctxt
  fun not_found () =
    raise LIFT_MATCH ("get_relmap (no relation map function found for type " ^ s ^ ")")
  val rel_name = #relmap (maps_lookup thy s) handle NotFound => not_found ()
in
  Const (rel_name, dummyT)
end

(* builds the aggregate relation map for rty and abstracts over
   the variables in vs

   - a TVar becomes the corresponding Free
   - a type constructor without arguments becomes equality
   - any other type constructor becomes its relation map applied
     to the results for its arguments
*)
fun mk_relmap ctxt vs rty =
let
  fun aux (ty as TVar _) = mk_Free ty
    | aux (ty as Type (_, [])) = HOLogic.eq_const ty
    | aux (Type (s, tys)) = list_comb (get_relmap ctxt s, map aux tys)
    | aux _ = raise LIFT_MATCH ("mk_relmap (default)")
in
  (* the Frees for vs are built before the body is traversed *)
  fold_rev Term.lambda (map mk_Free vs) (aux rty)
end

(* returns the equiv_rel component of the quotient data for s;
   raises LIFT_MATCH if the lookup raises NotFound *)
fun get_equiv_rel ctxt s =
let
  val thy = ProofContext.theory_of ctxt
  fun not_found () =
    raise LIFT_MATCH ("get_quotdata (no quotient found for type " ^ s ^ ")")
in
  #equiv_rel (quotdata_lookup thy s) handle NotFound => not_found ()
end

(* error continuation for match: raises LIFT_MATCH with a
   message that mentions both types *)
fun equiv_match_err ctxt ty_pat ty =
let
  val show_typ = quote o Syntax.string_of_typ ctxt
  val msg = space_implode " "
    ["equiv_relation (Types ", show_typ ty_pat, "and", show_typ ty, " do not match.)"]
in
  raise LIFT_MATCH msg
end

(* builds the aggregate equivalence relation
   that will be the argument of Respects

   - equal types, and pairs that are not both type
     constructors, yield equality on rty
   - equal type constructors yield the relation map applied
     to the relations for the arguments
   - unequal type constructors are treated as an instance of
     the quotient type s'
*)
fun equiv_relation ctxt (rty, qty) =
  if rty = qty
  then HOLogic.eq_const rty
  else
    case (rty, qty) of
      (Type (s, tys), Type (s', tys')) =>
        if s <> s'
        then
          let
            val (rty_pat, qty_pat as Type (_, vs)) = get_rty_qty ctxt s'
            val rtyenv = match ctxt equiv_match_err rty_pat rty
            val qtyenv = match ctxt equiv_match_err qty_pat qty
            val ty_pairs = map (double_lookup rtyenv qtyenv) vs
            val sub_rels = map (equiv_relation ctxt) ty_pairs
            val rel_map = mk_relmap ctxt vs rty_pat
            val rel_ty = [rty, rty] ---> @{typ bool}
            val quot_rel = force_typ ctxt (get_equiv_rel ctxt s') rel_ty
          in
            (* no composition if all arguments are equalities *)
            if forall is_eq sub_rels
            then quot_rel
            else mk_rel_compose (list_comb (rel_map, sub_rels), quot_rel)
          end
        else
          let
            val sub_rels = map (equiv_relation ctxt) (tys ~~ tys')
          in
            list_comb (get_relmap ctxt s, sub_rels)
          end
    | _ => HOLogic.eq_const rty

(* like equiv_relation, but the result is passed through Syntax.check_term *)
fun equiv_relation_chk ctxt (rty, qty) =
  Syntax.check_term ctxt (equiv_relation ctxt (rty, qty))



(*** Regularization ***)

(* Regularizing an rtrm means:
 
 - Quantifiers over types that need lifting are replaced 
   by bounded quantifiers, for example:

      All P  ----> Ball (Respects R) P

   where the aggregate relation R is given by the rty and qty;
 
 - Abstractions over types that need lifting are replaced
   by bounded abstractions, for example:
      
      %x. P  ----> Babs (Respects R) %x. P

 - Equalities over types that need lifting are replaced by
   corresponding equivalence relations, for example:

      A = B  ----> R A B

   or

      A = B  ----> (R ===> R) A B

   for more complicated types of A and B
*)

(* the constants Babs, Ball, Bex and Respects, each at type dummyT *)
local
  fun dummy_const c = Const (c, dummyT)
in
  val mk_babs = dummy_const @{const_name Babs}
  val mk_ball = dummy_const @{const_name Ball}
  val mk_bex  = dummy_const @{const_name Bex}
  val mk_resp = dummy_const @{const_name Respects}
end

(* - if both terms are abstractions, f is applied to the pair
     of their bodies and the result is put back under the
     binder of the first term;
   - otherwise f is applied to the pair of terms itself;
   - used by regularize, therefore abstracted
     variables do not have to be treated specially
*)
fun apply_subt f (Abs (x, T, t), Abs (_, _, t')) = Abs (x, T, f (t, t'))
  | apply_subt f (trm1, trm2) = f (trm1, trm2)

(* the major type of All and Ex quantifiers:
   the domain of the domain of the quantifier's type *)
fun qnt_typ ty = (domain_type o domain_type) ty

(* produces a regularized version of rtrm

   - the result might contain dummyTs

   - for regularisation we do not need any
     special treatment of bound variables

   - raises LIFT_MATCH if rtrm and qtrm do not correspond
*)
fun regularize_trm ctxt (rtrm, qtrm) =
  case (rtrm, qtrm) of
    (Abs (x, ty, t), Abs (_, ty', t')) =>
       let
         val subtrm = Abs(x, ty, regularize_trm ctxt (t, t'))
       in
         (* abstractions over types that differ become bounded abstractions *)
         if ty = ty' then subtrm
         else mk_babs $ (mk_resp $ equiv_relation ctxt (ty, ty')) $ subtrm
       end

  | (Const (@{const_name "All"}, ty) $ t, Const (@{const_name "All"}, ty') $ t') =>
       let
         val subtrm = apply_subt (regularize_trm ctxt) (t, t')
       in
         (* quantifiers over types that differ become bounded quantifiers *)
         if ty = ty' then Const (@{const_name "All"}, ty) $ subtrm
         else mk_ball $ (mk_resp $ equiv_relation ctxt (qnt_typ ty, qnt_typ ty')) $ subtrm
       end

  | (Const (@{const_name "Ex"}, ty) $ t, Const (@{const_name "Ex"}, ty') $ t') =>
       let
         val subtrm = apply_subt (regularize_trm ctxt) (t, t')
       in
         if ty = ty' then Const (@{const_name "Ex"}, ty) $ subtrm
         else mk_bex $ (mk_resp $ equiv_relation ctxt (qnt_typ ty, qnt_typ ty')) $ subtrm
       end

  | (* equalities need to be replaced by appropriate equivalence relations *) 
    (Const (@{const_name "op ="}, ty), Const (@{const_name "op ="}, ty')) =>
         if ty = ty' then rtrm
         else equiv_relation ctxt (domain_type ty, domain_type ty') 

  | (* in this case we just check whether the given equivalence relation is correct *) 
    (rel, Const (@{const_name "op ="}, ty')) =>
       let 
         (* FIXME: better exception handling *)  
         fun exc rel rel' = LIFT_MATCH ("regularise (relation mismatch)\n[" ^
           Syntax.string_of_term ctxt rel ^ " :: " ^
           Syntax.string_of_typ ctxt (fastype_of rel) ^ "]\n[" ^
           Syntax.string_of_term ctxt rel' ^ " :: " ^
           Syntax.string_of_typ ctxt (fastype_of rel') ^ "]")
         val rel_ty = fastype_of rel
         val rel' = equiv_relation_chk ctxt (domain_type rel_ty, domain_type ty') 
       in 
         if rel' aconv rel then rtrm else raise (exc rel rel')
       end  

  | (_, Const _) =>
       let
         val thy = ProofContext.theory_of ctxt
         (* decides whether the type T of a raw constant corresponds
            to the type T' of the quotient constant *)
         fun matches_typ T T' =
           case (T, T') of
             (TFree _, TFree _) => true
           | (TVar _, TVar _) => true
           | (Type (s, tys), Type (s', tys')) =>
               (s = s' andalso tys = tys') orelse
               (* 'andalso' is built-in syntax, so it needs to be expanded *)
               (fold (fn x => fn y => x andalso y) (map2 matches_typ tys tys') (s = s')
                handle UnequalLengths => false) orelse
               (* s' might be a quotient type whose raw type has T as an
                  instance; only a failed lookup counts as "no match", any
                  other exception is propagated *)
               (Sign.typ_instance thy (T, #rtyp (Quotient_Info.quotdata_lookup thy s'))
                handle Quotient_Info.NotFound => false)
           | _ => false
         fun same_const (Const (s, T)) (Const (s', T')) = (s = s') andalso matches_typ T T'
           | same_const _ _ = false
       in
         if same_const rtrm qtrm then rtrm
         else
           let
             val qtrm_str = Syntax.string_of_term ctxt qtrm
             val exc1 = LIFT_MATCH ("regularize (constant " ^ qtrm_str ^ " not found)")
             val exc2 = LIFT_MATCH ("regularize (constant " ^ qtrm_str ^ " mismatch)")
             val rtrm' = #rconst (qconsts_lookup thy qtrm) handle NotFound => raise exc1
           in 
             if Pattern.matches thy (rtrm', rtrm) 
             then rtrm else raise exc2
           end
       end 

  | (t1 $ t2, t1' $ t2') =>
       (regularize_trm ctxt (t1, t1')) $ (regularize_trm ctxt (t2, t2'))

  | (Bound i, Bound i') =>
       if i = i' then rtrm 
       else raise (LIFT_MATCH "regularize (bounds mismatch)")

  | _ =>
       let 
         val rtrm_str = Syntax.string_of_term ctxt rtrm
         val qtrm_str = Syntax.string_of_term ctxt qtrm
       in
         raise (LIFT_MATCH ("regularize failed (default: " ^ rtrm_str ^ "," ^ qtrm_str ^ ")"))
       end

(* like regularize_trm, but the result is passed through Syntax.check_term *)
fun regularize_trm_chk ctxt (rtrm, qtrm) =
  Syntax.check_term ctxt (regularize_trm ctxt (rtrm, qtrm))



(*** Rep/Abs Injection ***)

(*
Injection of Rep/Abs means:

  For abstractions:

  * If the type of the abstraction needs lifting, then we add Rep/Abs 
    around the abstraction; otherwise we leave it unchanged.
 
  For applications:
  
  * If the application involves a bounded quantifier, we recurse on 
    the second argument. If the application is a bounded abstraction,
    we always put a Rep/Abs around it (since bounded abstractions
    are assumed to always need lifting). Otherwise we recurse on both 
    arguments.

  For constants:

  * If the constant is (op =), we always leave it unchanged. 
    Otherwise, if the type of the constant needs lifting, we put
    a Rep/Abs around it. 

  For free variables:

  * We put a Rep/Abs around it if the type needs lifting.

  Vars case cannot occur.
*)

(* applies the aggregate abs function for (T, T') to trm
   and the aggregate rep function to the result *)
fun mk_repabs ctxt (T, T') trm =
let
  val rep_fun = absrep_fun repF ctxt (T, T')
  val abs_fun = absrep_fun absF ctxt (T, T')
in
  rep_fun $ (abs_fun $ trm)
end

(* raises LIFT_MATCH with msg followed by both terms *)
fun inj_repabs_err ctxt msg rtrm qtrm =
let
  val show_trm = quote o Syntax.string_of_term ctxt
in
  raise LIFT_MATCH (space_implode " " [msg, show_trm rtrm, "and", show_trm qtrm])
end


(* injects Rep/Abs into rtrm, guided by qtrm;
   bound variables need to be treated properly,
   as the type of subterms needs to be calculated *)
fun inj_repabs_trm ctxt (rtrm, qtrm) =
let
  (* recursive call with the context fixed *)
  fun inj trms = inj_repabs_trm ctxt trms

  (* Rep/Abs is only added if the two types differ *)
  fun repabs_unless_equal (rty, qty) trm =
    if rty = qty then trm
    else mk_repabs ctxt (rty, qty) trm
in
  case (rtrm, qtrm) of
    (Const (@{const_name "Ball"}, T) $ r $ t, Const (@{const_name "All"}, _) $ t') =>
      Const (@{const_name "Ball"}, T) $ r $ inj (t, t')

  | (Const (@{const_name "Bex"}, T) $ r $ t, Const (@{const_name "Ex"}, _) $ t') =>
      Const (@{const_name "Bex"}, T) $ r $ inj (t, t')

  | (Const (@{const_name "Babs"}, T) $ r $ t, t' as (Abs _)) =>
      let
        val rty = fastype_of rtrm
        val qty = fastype_of qtrm
        val babs = Const (@{const_name "Babs"}, T) $ r $ inj (t, t')
      in
        (* bounded abstractions always get a Rep/Abs *)
        mk_repabs ctxt (rty, qty) babs
      end

  | (Abs (x, T, t), Abs (x', T', t')) =>
      let
        val rty = fastype_of rtrm
        val qty = fastype_of qtrm
        (* the bodies are opened so that their types can be calculated *)
        val (y, s) = Term.dest_abs (x, T, t)
        val (_, s') = Term.dest_abs (x', T', t')
        val body = inj (s, s')
        val result = Term.lambda_name (y, Free (y, T)) body
      in
        repabs_unless_equal (rty, qty) result
      end

  | (t $ s, t' $ s') => inj (t, t') $ inj (s, s')

  | (Free (_, T), Free (_, T')) => repabs_unless_equal (T, T') rtrm

  | (_, Const (@{const_name "op ="}, _)) => rtrm

  | (_, Const (_, T')) => repabs_unless_equal (fastype_of rtrm, T') rtrm

  | _ => inj_repabs_err ctxt "injection (default):" rtrm qtrm
end

(* like inj_repabs_trm, but the result is passed through Syntax.check_term *)
fun inj_repabs_trm_chk ctxt (rtrm, qtrm) =
  Syntax.check_term ctxt (inj_repabs_trm ctxt (rtrm, qtrm))

end; (* structure *)