Nominal/nominal_function_common.ML
changeset 2819 4bd584ff4fab
child 2821 c7d4bd9e89e0
equal deleted inserted replaced
2818:8fe80e9f796d 2819:4bd584ff4fab
       
     1 (*  Nominal Function Common
       
     2     Author: Christian Urban
       
     3 
       
     4 Common definitions and other infrastructure for the function package.
       
     5 *)
       
     6 
       
     7 signature FUNCTION_DATA =
       
     8 sig
       
     9 
       
    10 type info =
       
    11  {is_partial : bool,
       
    12   defname : string,
       
    13     (* contains no logical entities: invariant under morphisms: *)
       
    14   add_simps : (binding -> binding) -> string -> (binding -> binding) ->
       
    15     Attrib.src list -> thm list -> local_theory -> thm list * local_theory,
       
    16   case_names : string list,
       
    17   fs : term list,
       
    18   R : term,
       
    19   psimps: thm list,
       
    20   pinducts: thm list,
       
    21   simps : thm list option,
       
    22   inducts : thm list option,
       
    23   termination: thm}
       
    24 
       
    25 end
       
    26 
       
    27 structure Function_Data : FUNCTION_DATA =
       
    28 struct
       
    29 
       
    30 type info =
       
    31  {is_partial : bool,
       
    32   defname : string,
       
    33     (* contains no logical entities: invariant under morphisms: *)
       
    34   add_simps : (binding -> binding) -> string -> (binding -> binding) ->
       
    35     Attrib.src list -> thm list -> local_theory -> thm list * local_theory,
       
    36   case_names : string list,
       
    37   fs : term list,
       
    38   R : term,
       
    39   psimps: thm list,
       
    40   pinducts: thm list,
       
    41   simps : thm list option,
       
    42   inducts : thm list option,
       
    43   termination: thm}
       
    44 
       
    45 end
       
    46 
       
    47 structure Nominal_Function_Common =
       
    48 struct
       
    49 
       
    50 open Function_Data
       
    51 
       
    52 local open Function_Lib in
       
    53 
       
    54 (* Profiling *)
       
    55 val profile = Unsynchronized.ref false;
       
    56 
       
    57 fun PROFILE msg = if !profile then timeap_msg msg else I
       
    58 
       
    59 
       
    60 val acc_const_name = @{const_name accp}
       
    61 fun mk_acc domT R =
       
    62   Const (acc_const_name, (domT --> domT --> HOLogic.boolT) --> domT --> HOLogic.boolT) $ R 
       
    63 
       
    64 val function_name = suffix "C"
       
    65 val graph_name = suffix "_graph"
       
    66 val rel_name = suffix "_rel"
       
    67 val dom_name = suffix "_dom"
       
    68 
       
    69 (* Termination rules *)
       
    70 
       
    71 structure TerminationRule = Generic_Data
       
    72 (
       
    73   type T = thm list
       
    74   val empty = []
       
    75   val extend = I
       
    76   val merge = Thm.merge_thms
       
    77 );
       
    78 
       
    79 val get_termination_rules = TerminationRule.get
       
    80 val store_termination_rule = TerminationRule.map o cons
       
    81 val apply_termination_rule = resolve_tac o get_termination_rules o Context.Proof
       
    82 
       
    83 
       
    84 (* Function definition result data *)
       
    85 
       
    86 datatype function_result = FunctionResult of
       
    87  {fs: term list,
       
    88   G: term,
       
    89   R: term,
       
    90   psimps : thm list,
       
    91   simple_pinducts : thm list,
       
    92   cases : thm,
       
    93   termination : thm,
       
    94   domintros : thm list option}
       
    95 
       
    96 fun morph_function_data ({add_simps, case_names, fs, R, psimps, pinducts,
       
    97   simps, inducts, termination, defname, is_partial} : info) phi =
       
    98     let
       
    99       val term = Morphism.term phi val thm = Morphism.thm phi val fact = Morphism.fact phi
       
   100       val name = Binding.name_of o Morphism.binding phi o Binding.name
       
   101     in
       
   102       { add_simps = add_simps, case_names = case_names,
       
   103         fs = map term fs, R = term R, psimps = fact psimps,
       
   104         pinducts = fact pinducts, simps = Option.map fact simps,
       
   105         inducts = Option.map fact inducts, termination = thm termination,
       
   106         defname = name defname, is_partial=is_partial }
       
   107     end
       
   108 
       
   109 structure FunctionData = Generic_Data
       
   110 (
       
   111   type T = (term * info) Item_Net.T;
       
   112   val empty : T = Item_Net.init (op aconv o pairself fst) (single o fst);
       
   113   val extend = I;
       
   114   fun merge tabs : T = Item_Net.merge tabs;
       
   115 )
       
   116 
       
   117 val get_function = FunctionData.get o Context.Proof;
       
   118 
       
   119 
       
   120 fun lift_morphism thy f =
       
   121   let
       
   122     val term = Drule.term_rule thy f
       
   123   in
       
   124     Morphism.thm_morphism f $> Morphism.term_morphism term
       
   125     $> Morphism.typ_morphism (Logic.type_map term)
       
   126   end
       
   127 
       
   128 fun import_function_data t ctxt =
       
   129   let
       
   130     val thy = Proof_Context.theory_of ctxt
       
   131     val ct = cterm_of thy t
       
   132     val inst_morph = lift_morphism thy o Thm.instantiate
       
   133 
       
   134     fun match (trm, data) =
       
   135       SOME (morph_function_data data (inst_morph (Thm.match (cterm_of thy trm, ct))))
       
   136       handle Pattern.MATCH => NONE
       
   137   in
       
   138     get_first match (Item_Net.retrieve (get_function ctxt) t)
       
   139   end
       
   140 
       
   141 fun import_last_function ctxt =
       
   142   case Item_Net.content (get_function ctxt) of
       
   143     [] => NONE
       
   144   | (t, data) :: _ =>
       
   145     let
       
   146       val ([t'], ctxt') = Variable.import_terms true [t] ctxt
       
   147     in
       
   148       import_function_data t' ctxt'
       
   149     end
       
   150 
       
   151 val all_function_data = Item_Net.content o get_function
       
   152 
       
   153 fun add_function_data (data : info as {fs, termination, ...}) =
       
   154   FunctionData.map (fold (fn f => Item_Net.update (f, data)) fs)
       
   155   #> store_termination_rule termination
       
   156 
       
   157 
       
   158 (* Simp rules for termination proofs *)
       
   159 
       
   160 structure Termination_Simps = Named_Thms
       
   161 (
       
   162   val name = "termination_simp"
       
   163   val description = "simplification rules for termination proofs"
       
   164 )
       
   165 
       
   166 
       
   167 (* Default Termination Prover *)
       
   168 
       
   169 structure TerminationProver = Generic_Data
       
   170 (
       
   171   type T = Proof.context -> tactic
       
   172   val empty = (fn _ => error "Termination prover not configured")
       
   173   val extend = I
       
   174   fun merge (a, _) = a
       
   175 )
       
   176 
       
   177 val set_termination_prover = TerminationProver.put
       
   178 val get_termination_prover = TerminationProver.get o Context.Proof
       
   179 
       
   180 
       
   181 (* Configuration management *)
       
   182 datatype nominal_function_opt
       
   183   = Sequential
       
   184   | Default of string
       
   185   | DomIntros
       
   186   | No_Partials
       
   187   | Invariant of string
       
   188 
       
   189 datatype nominal_function_config = NominalFunctionConfig of
       
   190  {sequential: bool,
       
   191   default: string option,
       
   192   domintros: bool,
       
   193   partials: bool,
       
   194   inv: string option}
       
   195 
       
   196 fun apply_opt Sequential (NominalFunctionConfig {sequential, default, domintros, partials, inv}) =
       
   197     NominalFunctionConfig 
       
   198       {sequential=true, default=default, domintros=domintros, partials=partials, inv=inv}
       
   199   | apply_opt (Default d) (NominalFunctionConfig {sequential, default, domintros, partials, inv}) =
       
   200     NominalFunctionConfig 
       
   201       {sequential=sequential, default=SOME d, domintros=domintros, partials=partials, inv=inv}
       
   202   | apply_opt DomIntros (NominalFunctionConfig {sequential, default, domintros, partials, inv}) =
       
   203     NominalFunctionConfig 
       
   204       {sequential=sequential, default=default, domintros=true, partials=partials, inv=inv}
       
   205   | apply_opt No_Partials (NominalFunctionConfig {sequential, default, domintros, partials, inv}) =
       
   206     NominalFunctionConfig 
       
   207       {sequential=sequential, default=default, domintros=domintros, partials=false, inv=inv}
       
   208   | apply_opt (Invariant s) (NominalFunctionConfig {sequential, default, domintros, partials, inv}) =
       
   209     NominalFunctionConfig 
       
   210       {sequential=sequential, default=default, domintros=domintros, partials=partials, inv = SOME s}
       
   211 
       
   212 val nominal_default_config =
       
   213   NominalFunctionConfig { sequential=false, default=NONE,
       
   214     domintros=false, partials=true, inv=NONE}
       
   215 
       
   216 
       
   217 (* Analyzing function equations *)
       
   218 
       
   219 fun split_def ctxt check_head geq =
       
   220   let
       
   221     fun input_error msg = cat_lines [msg, Syntax.string_of_term ctxt geq]
       
   222     val qs = Term.strip_qnt_vars "all" geq
       
   223     val imp = Term.strip_qnt_body "all" geq
       
   224     val (gs, eq) = Logic.strip_horn imp
       
   225 
       
   226     val (f_args, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
       
   227       handle TERM _ => error (input_error "Not an equation")
       
   228 
       
   229     val (head, args) = strip_comb f_args
       
   230 
       
   231     val fname = fst (dest_Free head) handle TERM _ => ""
       
   232     val _ = check_head fname
       
   233   in
       
   234     (fname, qs, gs, args, rhs)
       
   235   end
       
   236 
       
   237 (* Check for all sorts of errors in the input *)
       
   238 fun check_defs ctxt fixes eqs =
       
   239   let
       
   240     val fnames = map (fst o fst) fixes
       
   241 
       
   242     fun check geq =
       
   243       let
       
   244         fun input_error msg = error (cat_lines [msg, Syntax.string_of_term ctxt geq])
       
   245 
       
   246         fun check_head fname =
       
   247           member (op =) fnames fname orelse
       
   248           input_error ("Illegal equation head. Expected " ^ commas_quote fnames)
       
   249 
       
   250         val (fname, qs, gs, args, rhs) = split_def ctxt check_head geq
       
   251 
       
   252         val _ = length args > 0 orelse input_error "Function has no arguments:"
       
   253 
       
   254         fun add_bvs t is = add_loose_bnos (t, 0, is)
       
   255         val rvs = (subtract (op =) (fold add_bvs args []) (add_bvs rhs []))
       
   256                     |> map (fst o nth (rev qs))
       
   257 
       
   258         val _ = null rvs orelse input_error
       
   259           ("Variable" ^ plural " " "s " rvs ^ commas_quote rvs ^
       
   260            " occur" ^ plural "s" "" rvs ^ " on right hand side only:")
       
   261 
       
   262         val _ = forall (not o Term.exists_subterm
       
   263           (fn Free (n, _) => member (op =) fnames n | _ => false)) (gs @ args)
       
   264           orelse input_error "Defined function may not occur in premises or arguments"
       
   265 
       
   266         val freeargs = map (fn t => subst_bounds (rev (map Free qs), t)) args
       
   267         val funvars = filter (fn q => exists (exists_subterm (fn (Free q') $ _ => q = q' | _ => false)) freeargs) qs
       
   268         val _ = null funvars orelse (warning (cat_lines
       
   269           ["Bound variable" ^ plural " " "s " funvars ^
       
   270           commas_quote (map fst funvars) ^ " occur" ^ plural "s" "" funvars ^
       
   271           " in function position.", "Misspelled constructor???"]); true)
       
   272       in
       
   273         (fname, length args)
       
   274       end
       
   275 
       
   276     val grouped_args = AList.group (op =) (map check eqs)
       
   277     val _ = grouped_args
       
   278       |> map (fn (fname, ars) =>
       
   279         length (distinct (op =) ars) = 1
       
   280         orelse error ("Function " ^ quote fname ^
       
   281           " has different numbers of arguments in different equations"))
       
   282 
       
   283     val not_defined = subtract (op =) (map fst grouped_args) fnames
       
   284     val _ = null not_defined
       
   285       orelse error ("No defining equations for function" ^
       
   286         plural " " "s " not_defined ^ commas_quote not_defined)
       
   287 
       
   288     fun check_sorts ((fname, fT), _) =
       
   289       Sorts.of_sort (Sign.classes_of (Proof_Context.theory_of ctxt)) (fT, HOLogic.typeS)
       
   290       orelse error (cat_lines
       
   291       ["Type of " ^ quote fname ^ " is not of sort " ^ quote "type" ^ ":",
       
   292        Syntax.string_of_typ (Config.put show_sorts true ctxt) fT])
       
   293 
       
   294     val _ = map check_sorts fixes
       
   295   in
       
   296     ()
       
   297   end
       
   298 
       
   299 (* Preprocessors *)
       
   300 
       
   301 type fixes = ((string * typ) * mixfix) list
       
   302 type 'a spec = (Attrib.binding * 'a list) list
       
   303 type preproc = nominal_function_config -> Proof.context -> fixes -> term spec ->
       
   304   (term list * (thm list -> thm spec) * (thm list -> thm list list) * string list)
       
   305 
       
   306 val fname_of = fst o dest_Free o fst o strip_comb o fst o HOLogic.dest_eq o
       
   307   HOLogic.dest_Trueprop o Logic.strip_imp_concl o snd o dest_all_all
       
   308 
       
   309 fun mk_case_names i "" k = mk_case_names i (string_of_int (i + 1)) k
       
   310   | mk_case_names _ n 0 = []
       
   311   | mk_case_names _ n 1 = [n]
       
   312   | mk_case_names _ n k = map (fn i => n ^ "_" ^ string_of_int i) (1 upto k)
       
   313 
       
   314 fun empty_preproc check _ ctxt fixes spec =
       
   315   let
       
   316     val (bnds, tss) = split_list spec
       
   317     val ts = flat tss
       
   318     val _ = check ctxt fixes ts
       
   319     val fnames = map (fst o fst) fixes
       
   320     val indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) ts
       
   321 
       
   322     fun sort xs = partition_list (fn i => fn (j,_) => i = j) 0 (length fnames - 1) 
       
   323       (indices ~~ xs) |> map (map snd)
       
   324 
       
   325     (* using theorem names for case name currently disabled *)
       
   326     val cnames = map_index (fn (i, _) => mk_case_names i "" 1) bnds |> flat
       
   327   in
       
   328     (ts, curry op ~~ bnds o Library.unflat tss, sort, cnames)
       
   329   end
       
   330 
       
   331 structure Preprocessor = Generic_Data
       
   332 (
       
   333   type T = preproc
       
   334   val empty : T = empty_preproc check_defs
       
   335   val extend = I
       
   336   fun merge (a, _) = a
       
   337 )
       
   338 
       
   339 val get_preproc = Preprocessor.get o Context.Proof
       
   340 val set_preproc = Preprocessor.map o K
       
   341 
       
   342 
       
   343 
       
   344 local
       
   345   val option_parser = Parse.group "option"
       
   346     ((Parse.reserved "sequential" >> K Sequential)
       
   347      || ((Parse.reserved "default" |-- Parse.term) >> Default)
       
   348      || (Parse.reserved "domintros" >> K DomIntros)
       
   349      || (Parse.reserved "no_partials" >> K No_Partials))
       
   350 
       
   351   fun config_parser default =
       
   352     (Scan.optional (Parse.$$$ "(" |-- Parse.!!! (Parse.list1 option_parser) --| Parse.$$$ ")") [])
       
   353      >> (fn opts => fold apply_opt opts default)
       
   354 in
       
   355   fun function_parser default_cfg =
       
   356       config_parser default_cfg -- Parse.fixes -- Parse_Spec.where_alt_specs
       
   357 end
       
   358 
       
   359 
       
   360 end
       
   361 end