Nominal-General/nominal_permeq.ML
author Christian Urban <urbanc@in.tum.de>
Sun, 25 Apr 2010 09:13:16 +0200
changeset 1947 51f411b1197d
parent 1866 6d4e4bf9bce6
child 2064 2725853f43b9
permissions -rw-r--r--
tuned and cleaned

(*  Title:      nominal_permeq.ML
    Author:     Christian Urban
    Author:     Brian Huffman
*)

signature NOMINAL_PERMEQ =
sig
  (* eqvt_tac ctxt thms consts i: rewrites permutations in subgoal i.
     The theorems thms are tried before the stored eqvt lemmas; constants
     named in consts are not analysed. Prints a warning for permutations
     that cannot be eliminated. *)
  val eqvt_tac: Proof.context -> thm list -> string list -> int -> tactic
  (* like eqvt_tac, but raises an error instead of printing a warning *)
  val eqvt_strict_tac: Proof.context -> thm list -> string list -> int -> tactic
  
  (* proof methods wrapping eqvt_tac / eqvt_strict_tac on the first goal *)
  val perm_simp_meth: thm list * string list -> Proof.context -> Method.method
  val perm_strict_simp_meth: thm list * string list -> Proof.context -> Method.method
  (* parses the optional "add: thms" and "exclude: consts" arguments *)
  val args_parser: (thm list * string list) context_parser

  (* configuration option [[trace_eqvt]] for tracing rewrite steps *)
  val trace_eqvt: bool Config.T
  (* registers the trace_eqvt configuration option *)
  val setup: theory -> theory
end;

(* 

- eqvt_tac and eqvt_strict_tac take a list of theorems
  which are tried first when simplifying permutations

  the string list contains constants that should not be
  analysed (for example, there is no raw eqvt-lemma for
  the constant The; therefore it should not be analysed)

- setting [[trace_eqvt = true]] switches on tracing 
  information  


TODO:

 - provide a proper parser for the method (see Nominal2_Eqvt)
   
 - probably the list of bad constants should be a data slot

*)

structure Nominal_Permeq: NOMINAL_PERMEQ =
struct

open Nominal_ThmDecls;

(* tracing infrastructure *)

(* configuration option [[trace_eqvt]]; off by default *)
val (trace_eqvt, trace_eqvt_setup) = Attrib.config_bool "trace_eqvt" (K false);

fun trace_enabled ctxt = Config.get ctxt trace_eqvt

(* reports the rewrite step lhs ~> rhs of the equation result as a warning *)
fun trace_msg ctxt result = 
let
  val lhs_str = Syntax.string_of_term ctxt (term_of (Thm.lhs_of result))
  val rhs_str = Syntax.string_of_term ctxt (term_of (Thm.rhs_of result))
in
  warning (Pretty.string_of (Pretty.strs ["Rewriting", lhs_str, "to", rhs_str]))
end

(* runs conv on ctrm and returns its result unchanged; non-reflexive
   results are additionally reported via trace_msg *)
fun trace_conv ctxt conv ctrm =
let
  val result = conv ctrm
in
  if Thm.is_reflexive result 
  then result
  else (trace_msg ctxt result; result)
end

(* this conversion always fails, but prints 
   out the analysed term (unless its head is Trueprop) *)
fun trace_info_conv ctxt ctrm = 
let
  val trm = term_of ctrm
  val _ = case (head_of trm) of 
      @{const "Trueprop"} => ()
    | _ => warning ("Analysing term " ^ Syntax.string_of_term ctxt trm)
in
  Conv.no_conv ctrm
end

(* true iff the head of trm is a constant whose name is in bad_hds *)
fun has_bad_head bad_hds trm = 
  case (head_of trm) of 
    Const (s, _) => member (op=) bad_hds s 
  | _ => false 

(* conversion for applications: 
   rewrites  p \<bullet> (f x)  with the instantiated eqvt_apply theorem,
   but only if the head of the application is not a "bad head";
   the bad-head check is done before any destructuring work *)
fun eqvt_apply_conv bad_hds ctrm =
  case (term_of ctrm) of
    Const (@{const_name "permute"}, _) $ _ $ (trm $ _) =>
      if has_bad_head bad_hds trm
      then Conv.no_conv ctrm
      else
        let
          val (perm, t) = Thm.dest_comb ctrm
          val (_, p) = Thm.dest_comb perm
          val (f, x) = Thm.dest_comb t
          val a = ctyp_of_term x
          val b = ctyp_of_term t
          val ty_insts = map SOME [b, a]
          val term_insts = map SOME [p, f, x]
        in
          Drule.instantiate' ty_insts term_insts @{thm eqvt_apply}
        end
  | _ => Conv.no_conv ctrm

(* conversion for lambdas: rewrites a permutation applied to an
   abstraction with the eqvt_lambda theorem *)
fun eqvt_lambda_conv ctrm =
  case (term_of ctrm) of
    Const (@{const_name "permute"}, _) $ _ $ (Abs _) =>
      Conv.rewr_conv @{thm eqvt_lambda} ctrm
  | _ => Conv.no_conv ctrm

(* conversion that raises an error (strict_flag = true) or prints a 
   warning (strict_flag = false) if a permutation on a constant or 
   application cannot be analysed; permutations whose argument has a
   bad head are skipped silently. Always returns the reflexive 
   conversion of ctrm. *)
fun progress_info_conv ctxt strict_flag bad_hds ctrm =
let
  fun msg trm =
    if has_bad_head bad_hds trm then ()
    else (if strict_flag then error else warning) 
           ("Cannot solve equivariance for " ^ (Syntax.string_of_term ctxt trm))

  val _ = case (term_of ctrm) of
      Const (@{const_name "permute"}, _) $ _ $ (trm as Const _) => msg trm
    | Const (@{const_name "permute"}, _) $ _ $ (trm as _ $ _) => msg trm
    | _ => () 
in
  Conv.all_conv ctrm 
end

(* main conversion: tries, in order, the user theorems together with 
   eqvt_bound and the stored raw eqvt lemmas, the application and lambda 
   conversions, permute_pure, and finally the progress report; with 
   tracing enabled every step is reported *)
fun eqvt_conv ctxt strict_flag user_thms bad_hds ctrm =
let
  val first_conv_wrapper = 
    if trace_enabled ctxt 
    then Conv.first_conv o (cons (trace_info_conv ctxt)) o (map (trace_conv ctxt))
    else Conv.first_conv

  val pre_thms = map safe_mk_equiv user_thms @ @{thms eqvt_bound} @ get_eqvts_raw_thms ctxt
  val post_thms = map safe_mk_equiv @{thms permute_pure}
in
  first_conv_wrapper
    [ More_Conv.rewrs_conv pre_thms,
      eqvt_apply_conv bad_hds,
      eqvt_lambda_conv,
      More_Conv.rewrs_conv post_thms,
      progress_info_conv ctxt strict_flag bad_hds
    ] ctrm
end

(* shared implementation of eqvt_tac and eqvt_strict_tac: applies 
   eqvt_conv top-down to the subgoal (not exported by the signature) *)
fun gen_eqvt_tac strict_flag ctxt user_thms bad_hds = 
  CONVERSION 
    (More_Conv.top_conv (fn ctxt' => eqvt_conv ctxt' strict_flag user_thms bad_hds) ctxt)

(* raises an error if some permutations cannot be eliminated *)
fun eqvt_strict_tac ctxt user_thms bad_hds = 
  gen_eqvt_tac true ctxt user_thms bad_hds

(* prints a warning message if some permutations cannot be eliminated *)
fun eqvt_tac ctxt user_thms bad_hds = 
  gen_eqvt_tac false ctxt user_thms bad_hds

(* setup of the configuration value *)
val setup =
  trace_eqvt_setup


(** methods **)

val add_thms_parser = Scan.optional (Scan.lift (Args.add -- Args.colon) |-- Attrib.thms) [];
val exclude_consts_parser = Scan.optional (Scan.lift ((Args.$$$ "exclude") -- Args.colon) |-- 
  (Scan.repeat (Args.const true))) []

(* accepts "add: thms", "exclude: consts", both in either order, or nothing.
   The first alternative requires an explicit "add:" so that it can fail;
   if it were built only from optional parsers it would always succeed,
   and the "exclude: ... add: ..." alternative would be unreachable. *)
val args_parser =  
  (Scan.lift (Args.add -- Args.colon) |-- Attrib.thms) -- exclude_consts_parser ||
  exclude_consts_parser -- add_thms_parser >> swap

(* method wrapping eqvt_tac on the first goal *)
fun perm_simp_meth (thms, consts) ctxt = 
  SIMPLE_METHOD (HEADGOAL (eqvt_tac ctxt thms consts))

(* method wrapping eqvt_strict_tac on the first goal *)
fun perm_strict_simp_meth (thms, consts) ctxt = 
  SIMPLE_METHOD (HEADGOAL (eqvt_strict_tac ctxt thms consts))

end; (* structure *)