quotient_def.ML
author Christian Urban <urbanc@in.tum.de>
Thu, 19 Nov 2009 14:17:10 +0100
changeset 319 0ae9d9e66cb7
parent 318 746b17e1d6d8
child 321 f46dc0ca08c3
permissions -rw-r--r--
updated to new Isabelle


signature QUOTIENT_DEF =
sig
  (* selects whether an abs or a rep function is to be built *)
  datatype flag = absF | repF

  (* helpers working on types *)
  val get_fun: flag -> (typ * typ) list -> Proof.context -> typ -> term
  val diff: typ * typ -> (typ * typ) list -> (typ * typ) list

  (* definition of lifted constants *)
  val make_def:
    binding -> typ -> mixfix -> Attrib.binding -> term -> Proof.context ->
    (term * thm) * local_theory
  val quotdef:
    (binding * typ * mixfix) * (Attrib.binding * term) -> local_theory ->
    (term * thm) * local_theory
  val quotdef_cmd:
    (binding * string * mixfix) * (Attrib.binding * string) -> local_theory ->
    local_theory
end;

structure Quotient_Def: QUOTIENT_DEF =
struct

(* wrapper for Local_Theory.define: keeps the defined term and the
   definitional theorem, and drops the name component of the result *)
fun define name mx attr rhs lthy =
  Local_Theory.define "" ((name, mx), (attr, rhs)) lthy
  |>> (fn (trm, (_, thm)) => (trm, thm))


(* calculates the aggregate abs and rep functions for a given type;
   repF is for constants' arguments; absF is for constants;
   function types need to be treated specially, since repF and absF
   are swapped (see negF) for the argument type of a function *)

(* flag selecting whether an abs (absF) or a rep (repF) function is built *)
datatype flag = absF | repF

(* swaps the flag: absF becomes repF and vice versa *)
fun negF flag =
  case flag of
    absF => repF
  | repF => absF

(* builds the (untyped, dummyT-annotated) abs or rep term for ty:
   - types listed in qenv yield the constant ABS_<name> or REP_<name>;
   - free type variables and nullary type constructors yield the identity;
   - other type constructors yield their registered map function applied
     to the terms built for the type arguments; for "fun" the flag is
     negated on the argument type;
   - anything else (schematic type variables) is rejected *)
fun get_fun flag qenv lthy ty =
let
  val thy = ProofContext.theory_of lthy

  (* the map function registered for type constructor s, applied to args *)
  fun map_app s args =
    case maps_lookup thy s of
      NONE => error ("no map association for type " ^ s)
    | SOME info => list_comb (Const (#mapfun info, dummyT), args)

  (* the ABS_/REP_ constant named after the base name of qty *)
  fun abs_rep_const qty =
  let
    val base = Long_Name.base_name (fst (dest_Type qty))
    val prefix = (case flag of absF => "ABS_" | repF => "REP_")
  in
    Const (Sign.full_bname thy (prefix ^ base), dummyT)
  end

  val identity = Abs ("", ty, Bound 0)

  val same_flag = get_fun flag qenv lthy
  val opposite_flag = get_fun (negF flag) qenv lthy
in
  if AList.defined (op=) qenv ty
  then abs_rep_const ty
  else
    case ty of
      TFree _ => identity
    | Type (_, []) => identity
    | Type ("fun", [dom, rng]) =>
        map_app "fun" [opposite_flag dom, same_flag rng]
    | Type (s, tys) => map_app s (map same_flag tys)
    | _ => error ("no type variables allowed")
end

(* adds to the front of Ds every pair of corresponding subtypes at
   which the types T and S differ; arguments of equal type constructors
   are compared pairwise from left to right *)
fun diff (T, S) Ds =
let
  val differing = (T, S) :: Ds
in
  case (T, S) of
    (Type (a, Ts), Type (b, Us)) =>
      if a = b then diffs (Ts, Us) Ds else differing
  | (TFree x, TFree y) => if x = y then Ds else differing
  | (TVar v, TVar u) => if v = u then Ds else differing
  | _ => differing
end
and diffs (Ts, Us) Ds =
  if length Ts = length Us
  then fold diff (Ts ~~ Us) Ds
  else error "Unequal length of type arguments"


(* sanity check that every (qty, rty) pair of the calculated quotient
   environment is an instance of some pair of the stored quotient
   environment; raises an error for a pair that is not *)
fun sanity_chk qenv lthy =
let
  val thy = ProofContext.theory_of lthy
  val stored_qenv = Quotient_Info.mk_qenv lthy

  (* reports the offending pair *)
  fun mismatch (qty, rty) =
  let
    val qty_str = quote (Syntax.string_of_typ lthy qty)
    val rty_str = quote (Syntax.string_of_typ lthy rty)
  in
    error ("Quotient type " ^ qty_str ^ " does not match with " ^ rty_str)
  end

  (* qty must be an instance of qty', and the matching type
     substitution must turn rty' into rty *)
  fun is_inst (qty, rty) (qty', rty') =
    Sign.typ_instance thy (qty, qty') andalso
    rty = Envir.subst_type (Sign.typ_match thy (qty', qty) Vartab.empty) rty'

  fun chk_inst pair =
    exists (is_inst pair) stored_qenv orelse mismatch pair
in
  map chk_inst qenv
end

(* defines the constant nconst_bname of type qty from the term rhs:
   the quotient environment is the list of differences between qty and
   the type of rhs; rhs is wrapped with fun_map applied to the rep
   functions of the argument types and the abs function of the result
   type; the checked term is then defined, and the pair of the new
   constant and rhs is recorded through a declaration *)
fun make_def nconst_bname qty mx attr rhs lthy =
let
  val qenv = distinct (op=) (diff (qty, fastype_of rhs) [])
  val _ = sanity_chk qenv lthy

  val (arg_tys, res_ty) = strip_type qty
  val rep_fns = map (get_fun repF qenv lthy) arg_tys
  val abs_fn = get_fun absF qenv lthy res_ty

  fun fun_map_of rep rest =
    Const (@{const_name "fun_map"}, dummyT) $ rep $ rest

  val lifted_rhs =
    Syntax.check_term lthy (fold_rev fun_map_of rep_fns abs_fn $ rhs)

  val ((trm, thm), lthy') = define nconst_bname mx attr lifted_rhs lthy

  (* stores the constant information under the name of the binding *)
  fun register phi =
    qconsts_update_generic (Binding.name_of nconst_bname)
      (qconsts_transfer phi {qconst = trm, rconst = rhs})
in
  ((trm, thm), Local_Theory.declaration true register lthy')
end

(* interface and syntax setup *)

(* the ML-interface takes a 5-tuple consisting of  *)
(*                                                 *)
(* - the name of the constant to be lifted         *)
(* - its type                                      *)
(* - its mixfix annotation                         *)
(* - a meta-equation defining the constant,        *)
(*   and the attributes for this meta-equality     *)

(* passes prop through LocalDefs.cert_def and Primitive_Defs.abs_def,
   and hands the resulting right-hand side to make_def *)
fun quotdef ((bind, qty, mx), (attr, prop)) lthy =
  prop
  |> LocalDefs.cert_def lthy
  |> snd
  |> Primitive_Defs.abs_def
  |> snd
  |> (fn rhs => make_def bind qty mx attr rhs lthy)

(* string-based variant of quotdef: reads the type and the proposition
   in lthy, and returns only the resulting local theory *)
fun quotdef_cmd ((bind, qtystr, mx), (attr, propstr)) lthy =
let
  val spec = (bind, Syntax.read_typ lthy qtystr, mx)
  val eqn = (attr, Syntax.read_prop lthy propstr)
in
  snd (quotdef (spec, eqn) lthy)
end


(* parser for:  binding "::" typ mixfix? "where" thm_name? prop
   yielding ((binding, typ, mixfix), (thm_name, prop)) *)
val quotdef_parser =
let
  val typ_and_mixfix =
    OuterParse.typ -- OuterParse.opt_mixfix' --| OuterParse.where_

  val const_spec =
    OuterParse.binding --
      (OuterParse.$$$ "::" |-- OuterParse.!!! typ_and_mixfix)
    >> OuterParse.triple2

  val equation = SpecParse.opt_thm_name ":" -- OuterParse.prop
in
  const_spec -- equation
end

(* registers the local-theory command "quotient_def" *)
val _ =
let
  val command_name = "quotient_def"
  val description = "lifted definition of constants"
in
  OuterSyntax.local_theory command_name description
    OuterKeyword.thy_decl (quotdef_parser >> quotdef_cmd)
end

end; (* structure *)

(* makes the components of Quotient_Def accessible without qualification *)
open Quotient_Def;