quotient_def.ML
author Cezary Kaliszyk <kaliszyk@in.tum.de>
Mon, 23 Nov 2009 15:08:09 +0100
changeset 342 eb15be678ac4
parent 331 345c422b1cb5
child 365 ba057402ea53
permissions -rw-r--r--
lift_thm with a goal.


signature QUOTIENT_DEF =
sig
  (* Direction selector for get_fun: build the aggregate abstraction
     function (absF) or the aggregate representation function (repF). *)
  datatype flag = absF | repF

  (* Outer-syntax entry point behind the quotient_def command: reads the
     type and the defining proposition from strings, then calls quotdef. *)
  val quotdef_cmd:
    (binding * string * mixfix) * (Attrib.binding * string) ->
    local_theory -> local_theory

  (* ML-level entry point: lifts the right-hand side of an already parsed
     meta-equation and defines the constant via make_def. *)
  val quotdef:
    (binding * typ * mixfix) * (Attrib.binding * term) ->
    local_theory -> (term * thm) * local_theory

  (* Defines a lifted constant from a raw term; returns the new constant,
     its defining theorem and the updated local theory. *)
  val make_def:
    binding -> typ -> mixfix -> Attrib.binding -> term ->
    Proof.context -> (term * thm) * local_theory

  (* Builds the abs/rep function relating a (raw type, quotient type) pair. *)
  val get_fun: flag -> Proof.context -> typ * typ -> term
end;

structure Quotient_Def: QUOTIENT_DEF =
struct

(* Thin wrapper around Local_Theory.define.  It defines name (with mixfix mx
   and theorem binding/attributes attr) to be rhs.  It returns the term that
   Local_Theory.define yields for the new constant, the defining theorem,
   and the updated local theory; the fact name is dropped. *)
fun define name mx attr rhs lthy =
let
  val ((trm, (_ , thm)), lthy') =
     Local_Theory.define ((name, mx), (attr, rhs)) lthy
in
  ((trm, thm), lthy')
end

datatype flag = absF | repF

(* Swaps the direction; used for the argument side of function types. *)
fun negF absF = repF
  | negF repF = absF

(* The identity function at type ty. *)
fun mk_identity ty = Const (@{const_name "id"}, ty --> ty)

(* Quoted, pretty-printed forms of both types, for error messages. *)
fun ty_strs lthy (ty1, ty2) = 
  (quote (Syntax.string_of_typ lthy ty1),
   quote (Syntax.string_of_typ lthy ty2))

(* Raises LIFT_MATCH reporting that rty and qty do not match. *)
fun ty_lift_error1 lthy rty qty =
let
  val (rty_str, qty_str) = ty_strs lthy (rty, qty) 
  val msg = ["quotient type", qty_str, "and lifted type", rty_str, "do not match."]
in
  raise LIFT_MATCH (space_implode " " msg)
end

(* Raises LIFT_MATCH reporting that schematic type variables occur. *)
fun ty_lift_error2 lthy rty qty =
let
  val (rty_str, qty_str) = ty_strs lthy (rty, qty)   
  val msg = ["No type variables allowed in", qty_str, "and", rty_str, "."]
in
  raise LIFT_MATCH (space_implode " " msg)
end

(* Applies the map function registered (via maps_lookup) for the type
   constructor s to the component functions fs.  It raises LIFT_MATCH
   when no map function is registered for s. *)
fun get_fun_aux lthy s fs =
  case (maps_lookup (ProofContext.theory_of lthy) s) of
    SOME info => list_comb (Const (#mapfun info, dummyT), fs)
  | NONE      => raise LIFT_MATCH (space_implode " " ["No map function for type", quote s, "."])

(* Returns the constant ABS_<tyname> (for absF) or REP_<tyname> (for repF),
   where <tyname> is the base name of qty's type constructor.  The constant
   carries dummyT, so its type is filled in later by type checking.
   NOTE(review): the name is qualified with the current theory's naming
   (Sign.full_bname thy).  It presumably only resolves when the quotient
   type was introduced in the same naming context; confirm. *)
fun get_const flag lthy _ qty =
(* FIXME: check here that _ and qty are related *)
let 
  val thy = ProofContext.theory_of lthy
  val qty_name = Long_Name.base_name (fst (dest_Type qty))
in
  case flag of
    absF => Const (Sign.full_bname thy ("ABS_" ^ qty_name), dummyT)
  | repF => Const (Sign.full_bname thy ("REP_" ^ qty_name), dummyT)
end


(* Calculates the aggregate abs (absF) or rep (repF) function between the
   raw type rty and the quotient type qty.  repF is for constants'
   arguments and absF is for constants (their results).  Function types
   are treated specially: the flag is flipped (negF) for the argument type
   and kept for the result type, and both parts are combined with the map
   function registered for "fun".  Equal nullary types and equal type
   variables map to the identity.  For other equal type constructors, the
   component functions are combined with that constructor's map function.
   Different constructors map to the ABS_/REP_ constant of qty. *)

fun get_fun flag lthy (rty, qty) =
  case (rty, qty) of 
    (Type ("fun", [ty1, ty2]), Type ("fun", [ty1', ty2'])) =>
     let
       val fs_ty1 = get_fun (negF flag) lthy (ty1, ty1')
       val fs_ty2 = get_fun flag lthy (ty2, ty2')
     in  
       get_fun_aux lthy "fun" [fs_ty1, fs_ty2]
     end 
  | (Type (s, []), Type (s', [])) =>
     if s = s'
     then mk_identity qty 
     else get_const flag lthy rty qty
  | (Type (s, tys), Type (s', tys')) =>
     if s = s'
     then get_fun_aux lthy s' (map (get_fun flag lthy) (tys ~~ tys'))
     else get_const flag lthy rty qty
  | (TFree x, TFree x') =>
     if x = x'
     then mk_identity qty 
     else ty_lift_error1 lthy rty qty
  | (TVar _, TVar _) => ty_lift_error2 lthy rty qty
  | _ => ty_lift_error1 lthy rty qty

(* Strips leading argument types off a (raw type, quotient type) pair in
   lockstep, stopping as soon as either side is no longer a function type.
   It returns the paired argument types and the pair of remaining result
   types.  Stripping both sides independently, as strip_type does, fails
   on arity mismatches: the later zip raises UnequalLengths, for example
   when the quotient result type is itself a quotient of a function type.
   When both sides have the same number of arrows, the result is the same
   as stripping each side fully. *)
fun strip_fun_types (Type ("fun", [rarg, rres]), Type ("fun", [qarg, qres])) =
      let
        val (arg_tys, res_tys) = strip_fun_types (rres, qres)
      in
        ((rarg, qarg) :: arg_tys, res_tys)
      end
  | strip_fun_types tys = ([], tys)

(* Defines the lifted constant qconst_bname of type qty from the raw term
   rhs.  The right-hand side of the definition is
     fun_map rep_1 (fun_map rep_2 ... abs) $ rhs
   where each rep_i is get_fun repF applied to the i-th (raw, quotient)
   argument-type pair, and abs is get_fun absF applied to the result-type
   pair.  The new constant is also recorded through a declaration
   (qconsts_update_gen, keyed by the base name of qconst_bname), together
   with the raw term rhs. *)
fun make_def qconst_bname qty mx attr rhs lthy =
let
  val rty = fastype_of rhs
  val (arg_tys, res_tys) = strip_fun_types (rty, qty)

  val rep_fns = map (get_fun repF lthy) arg_tys
  val abs_fn  = get_fun absF lthy res_tys

  fun mk_fun_map t s =  
        Const (@{const_name "fun_map"}, dummyT) $ t $ s

  (* fold_rev nests right-to-left, so the first argument's rep function
     ends up outermost; type checking resolves the dummyT constants. *)
  val absrep_trm = (fold_rev mk_fun_map rep_fns abs_fn $ rhs)
                   |> Syntax.check_term lthy 

  val ((trm, thm), lthy') = define qconst_bname mx attr absrep_trm lthy

  val qconst_str = Binding.name_of qconst_bname
  fun qcinfo phi = qconsts_transfer phi {qconst = trm, rconst = rhs}
  val lthy'' = Local_Theory.declaration true
                 (fn phi => qconsts_update_gen qconst_str (qcinfo phi)) lthy'
in
  ((trm, thm), lthy'')
end

(* interface and syntax setup *)

(* The ML interface takes a pair ((bind, qty, mx), (attr, prop)):
     - bind: the name of the constant to be lifted
     - qty:  its type
     - mx:   its mixfix annotation
     - attr: the name and attributes for the resulting definition theorem
     - prop: a meta-equation defining the constant
   Only the right-hand side of prop is used.  The name and type of the new
   constant come from bind and qty. *)

fun quotdef ((bind, qty, mx), (attr, prop)) lthy =
let   
  val (_, prop') = LocalDefs.cert_def lthy prop
  val (_, rhs) = Primitive_Defs.abs_def prop'
in
  make_def bind qty mx attr rhs lthy 
end

(* Reads the type and the proposition from their string forms in lthy,
   runs quotdef and keeps only the resulting local theory. *)
fun quotdef_cmd ((bind, qtystr, mx), (attr, propstr)) lthy = 
let
  val qty  = Syntax.read_typ lthy qtystr
  val prop = Syntax.read_prop lthy propstr
in
  quotdef ((bind, qty, mx), (attr, prop)) lthy |> snd
end

(* Parses:  name :: type [mixfix] where [thm_name:] prop
   and delivers ((name, type, mixfix), (thm_name, prop)). *)
val quotdef_parser =
  (OuterParse.binding --
    (OuterParse.$$$ "::" |-- OuterParse.!!! (OuterParse.typ -- 
      OuterParse.opt_mixfix' --| OuterParse.where_)) >> OuterParse.triple2) -- 
       (SpecParse.opt_thm_name ":" -- OuterParse.prop)

val _ = OuterSyntax.local_theory "quotient_def" "lifted definition of constants"
  OuterKeyword.thy_decl (quotdef_parser >> quotdef_cmd)

end; (* structure *)

(* Make the names exported through QUOTIENT_DEF (the flag constructors
   absF/repF, get_fun, make_def, quotdef, quotdef_cmd) available unqualified
   at top level. *)
open Quotient_Def;