(* quotient_def.ML
   changeset 277:37636f2b1c19 (compared against 275:34ad627ac5d5; child 279:b2fd070c8833)
   diff legend: equal / deleted / inserted / replaced *)
       
     1 
       
     2 signature QUOTIENT_DEF =
       
     3 sig
       
     4   datatype flag = absF | repF
       
     5   val get_fun: flag -> (typ * typ) list -> Proof.context -> typ -> term * (typ * typ)
       
     6   val make_def: binding -> term -> typ -> mixfix -> Attrib.binding -> (typ * typ) list ->
       
     7     Proof.context -> (term * thm) * local_theory
       
     8 
       
     9   val quotdef: (binding * typ * mixfix) * (Attrib.binding * term) ->
       
    10     local_theory -> (term * thm) * local_theory
       
    11   val quotdef_cmd: (binding * string * mixfix) * (Attrib.binding * string) ->
       
    12     local_theory -> local_theory
       
    13 end;
       
    14 
       
    15 structure Quotient_Def: QUOTIENT_DEF =
       
    16 struct
       
    17 
       
    18 fun define name mx attr rhs lthy =
       
    19 let
       
    20   val ((rhs, (_ , thm)), lthy') =
       
    21      LocalTheory.define Thm.internalK ((name, mx), (attr, rhs)) lthy
       
    22 in
       
    23   ((rhs, thm), lthy')
       
    24 end
       
    25 
       
    26 fun lookup_qenv qenv qty =
       
    27   (case (AList.lookup (op=) qenv qty) of
       
    28     SOME rty => SOME (qty, rty)
       
    29   | NONE => NONE)
       
    30 
       
    31 
       
    32 (* calculates the aggregate abs and rep functions for a given type; 
       
    33    repF is for constants' arguments; absF is for constants;
       
    34    function types need to be treated specially, since repF and absF
       
    35    change *)
       
    36 
       
    37 datatype flag = absF | repF
       
    38 
       
    39 fun negF absF = repF
       
    40   | negF repF = absF
       
    41 
       
    42 fun get_fun flag qenv lthy ty =
       
    43 let
       
    44   
       
    45   fun get_fun_aux s fs_tys =
       
    46   let
       
    47     val (fs, tys) = split_list fs_tys
       
    48     val (otys, ntys) = split_list tys
       
    49     val oty = Type (s, otys)
       
    50     val nty = Type (s, ntys)
       
    51     val ftys = map (op -->) tys
       
    52   in
       
    53    (case (maps_lookup (ProofContext.theory_of lthy) s) of
       
    54       SOME info => (list_comb (Const (#mapfun info, ftys ---> (oty --> nty)), fs), (oty, nty))
       
    55     | NONE      => error ("no map association for type " ^ s))
       
    56   end
       
    57 
       
    58   fun get_fun_fun fs_tys =
       
    59   let
       
    60     val (fs, tys) = split_list fs_tys
       
    61     val ([oty1, oty2], [nty1, nty2]) = split_list tys
       
    62     val oty = nty1 --> oty2
       
    63     val nty = oty1 --> nty2
       
    64     val ftys = map (op -->) tys
       
    65   in
       
    66     (list_comb (Const (@{const_name "fun_map"}, ftys ---> oty --> nty), fs), (oty, nty))
       
    67   end
       
    68 
       
    69   fun get_const flag (qty, rty) =
       
    70   let 
       
    71     val thy = ProofContext.theory_of lthy
       
    72     val qty_name = Long_Name.base_name (fst (dest_Type qty))
       
    73   in
       
    74     case flag of
       
    75       absF => (Const (Sign.full_bname thy ("ABS_" ^ qty_name), rty --> qty), (rty, qty))
       
    76     | repF => (Const (Sign.full_bname thy ("REP_" ^ qty_name), qty --> rty), (qty, rty))
       
    77   end
       
    78 
       
    79   fun mk_identity ty = Abs ("", ty, Bound 0)
       
    80 
       
    81 in
       
    82   if (AList.defined (op=) qenv ty)
       
    83   then (get_const flag (the (lookup_qenv qenv ty)))
       
    84   else (case ty of
       
    85           TFree _ => (mk_identity ty, (ty, ty))
       
    86         | Type (_, []) => (mk_identity ty, (ty, ty)) 
       
    87         | Type ("fun" , [ty1, ty2]) => 
       
    88                  get_fun_fun [get_fun (negF flag) qenv lthy ty1, get_fun flag qenv lthy ty2]
       
    89         | Type (s, tys) => get_fun_aux s (map (get_fun flag qenv lthy) tys)
       
    90         | _ => raise ERROR ("no type variables"))
       
    91 end
       
    92 
       
    93 fun make_def nconst_bname rhs qty mx attr qenv lthy =
       
    94 let
       
    95   val (arg_tys, res_ty) = strip_type qty
       
    96 
       
    97   val rep_fns = map (fst o get_fun repF qenv lthy) arg_tys
       
    98   val abs_fn  = (fst o get_fun absF qenv lthy) res_ty
       
    99 
       
   100   fun mk_fun_map t s = 
       
   101         Const (@{const_name "fun_map"}, dummyT) $ t $ s
       
   102 
       
   103   val absrep_fn = fold_rev mk_fun_map rep_fns abs_fn
       
   104                   |> Syntax.check_term lthy 
       
   105 in
       
   106   define nconst_bname mx attr (absrep_fn $ rhs) lthy
       
   107 end
       
   108 
       
   109 
       
   110 (* returns all subterms where two types differ *)
       
   111 fun diff (T, S) Ds =
       
   112   case (T, S) of
       
   113     (TVar v, TVar u) => if v = u then Ds else (T, S)::Ds 
       
   114   | (TFree x, TFree y) => if x = y then Ds else (T, S)::Ds
       
   115   | (Type (a, Ts), Type (b, Us)) => 
       
   116       if a = b then diffs (Ts, Us) Ds else (T, S)::Ds
       
   117   | _ => (T, S)::Ds
       
   118 and diffs (T::Ts, U::Us) Ds = diffs (Ts, Us) (diff (T, U) Ds)
       
   119   | diffs ([], []) Ds = Ds
       
   120   | diffs _ _ = error "Unequal length of type arguments"
       
   121 
       
   122 
       
   123 fun error_msg lthy (qty, rty) =
       
   124 let 
       
   125   val qtystr = quote (Syntax.string_of_typ lthy qty)
       
   126   val rtystr = quote (Syntax.string_of_typ lthy rty)
       
   127 in
       
   128   error (implode ["Quotient type ", qtystr, " does not match with ", rtystr])
       
   129 end
       
   130 
       
   131 fun sanity_chk lthy qenv =
       
   132 let
       
   133    val qenv' = Quotient_Info.mk_qenv lthy
       
   134    val thy = ProofContext.theory_of lthy
       
   135 
       
   136    fun is_inst thy (qty, rty) (qty', rty') =
       
   137    if Sign.typ_instance thy (qty, qty')
       
   138    then let
       
   139           val inst = Sign.typ_match thy (qty', qty) Vartab.empty
       
   140         in
       
   141           rty = Envir.subst_type inst rty'
       
   142         end
       
   143    else false
       
   144 
       
   145    fun chk_inst (qty, rty) = 
       
   146      if exists (is_inst thy (qty, rty)) qenv' then true
       
   147      else error_msg lthy (qty, rty)
       
   148 in
       
   149   forall chk_inst qenv
       
   150 end
       
   151 
       
   152 
       
   153 fun quotdef ((bind, qty, mx), (attr, prop)) lthy =
       
   154 let   
       
   155   val (_, prop') = PrimitiveDefs.dest_def lthy (K true) (K false) (K false) prop
       
   156   val (_, rhs) = PrimitiveDefs.abs_def prop'
       
   157 
       
   158   val rty = fastype_of rhs
       
   159   val qenv = distinct (op=) (diff (qty, rty) []) 
       
   160 in
       
   161   sanity_chk lthy qenv;
       
   162   make_def bind rhs qty mx attr qenv lthy 
       
   163 end
       
   164 
       
   165 
       
   166 val quotdef_parser =
       
   167   (OuterParse.binding --
       
   168     (OuterParse.$$$ "::" |-- OuterParse.!!! (OuterParse.typ -- 
       
   169       OuterParse.opt_mixfix' --| OuterParse.where_)) >> OuterParse.triple2) -- 
       
   170        (SpecParse.opt_thm_name ":" -- OuterParse.prop)
       
   171 
       
   172 fun quotdef_cmd ((bind, qtystr, mx), (attr, propstr)) lthy = 
       
   173 let
       
   174   val qty  = (Syntax.check_typ lthy o Syntax.parse_typ lthy) qtystr
       
   175   val prop = (Syntax.check_prop lthy o Syntax.parse_prop lthy) propstr
       
   176 in
       
   177   quotdef ((bind, qty, mx), (attr, prop)) lthy |> snd
       
   178 end
       
   179 
       
   180 val _ = OuterSyntax.local_theory "quotient_def" "lifted definition of constants"
       
   181   OuterKeyword.thy_decl (quotdef_parser >> quotdef_cmd)
       
   182 
       
   183 end; (* structure *)
       
   184 
       
   185 open Quotient_Def;