Various changes to support Nominal2 commands in local contexts.
authorwebertj
Wed, 27 Mar 2013 16:08:30 +0100
changeset 3214 13ab4f0a0b0e
parent 3213 a8724924a62e
child 3215 3cfd4fc42840
Various changes to support Nominal2 commands in local contexts.
Nominal/Nominal2.thy
Nominal/Nominal2_Abs.thy
Nominal/Nominal2_Base.thy
Nominal/nominal_basics.ML
Nominal/nominal_eqvt.ML
Nominal/nominal_inductive.ML
Nominal/nominal_library.ML
Nominal/nominal_thmdecls.ML
--- a/Nominal/Nominal2.thy	Tue Mar 26 16:41:31 2013 +0100
+++ b/Nominal/Nominal2.thy	Wed Mar 27 16:08:30 2013 +0100
@@ -281,7 +281,7 @@
     let
       val AlphaResult {alpha_trms, alpha_bn_trms, alpha_raw_induct, alpha_intros, ...} = alpha_result
     in
-      Nominal_Eqvt.raw_equivariance (alpha_trms @ alpha_bn_trms) alpha_raw_induct alpha_intros lthy5
+      Nominal_Eqvt.raw_equivariance lthy5 (alpha_trms @ alpha_bn_trms) alpha_raw_induct alpha_intros
     end
 
   val alpha_eqvt_norm = map (Nominal_ThmDecls.eqvt_transform lthy5) alpha_eqvt
@@ -445,7 +445,7 @@
     |> map (simplify (HOL_basic_ss addsimps qfv_supp_thms))
     |> map (simplify (HOL_basic_ss addsimps @{thms prod_fv_supp prod_alpha_eq Abs_eq_iff[symmetric]}))
 
-  (* filters the theormes that are of the form "qfv = supp" *)
+  (* filters the theorems that are of the form "qfv = supp" *)
   fun is_qfv_thm (@{term Trueprop} $ (Const (@{const_name HOL.eq}, _) $ lhs $ _)) = member (op=) qfvs lhs
   | is_qfv_thm _ = false
 
@@ -742,8 +742,4 @@
     (main_parser >> nominal_datatype2_cmd)
 *}
 
-
 end
-
-
-
--- a/Nominal/Nominal2_Abs.thy	Tue Mar 26 16:41:31 2013 +0100
+++ b/Nominal/Nominal2_Abs.thy	Wed Mar 27 16:08:30 2013 +0100
@@ -9,7 +9,7 @@
 section {* Abstractions *}
 
 fun
-  alpha_set 
+  alpha_set
 where
   alpha_set[simp del]:
   "alpha_set (bs, x) R f p (cs, y) \<longleftrightarrow> 
@@ -1079,18 +1079,11 @@
 lemma prod_fv_supp:
   shows "prod_fv supp supp = supp"
 by (rule ext)
-   (auto simp add: prod_fv.simps supp_Pair)
+   (auto simp add: supp_Pair)
 
 lemma prod_alpha_eq:
   shows "prod_alpha (op=) (op=) = (op=)"
   unfolding prod_alpha_def
   by (auto intro!: ext)
 
-
-
-
-
-
-
 end
-
--- a/Nominal/Nominal2_Base.thy	Tue Mar 26 16:41:31 2013 +0100
+++ b/Nominal/Nominal2_Base.thy	Wed Mar 27 16:08:30 2013 +0100
@@ -769,7 +769,7 @@
 
 subsection {* Eqvt infrastructure *}
 
-text {* Setup of the theorem attributes @{text eqvt} and @{text eqvt_raw} *}
+text {* Setup of the theorem attributes @{text eqvt} and @{text eqvt_raw}. *}
 
 ML_file "nominal_thmdecls.ML"
 setup "Nominal_ThmDecls.setup"
@@ -1056,11 +1056,11 @@
   given.
 *}
 
-class le_eqvt = ord +
-  assumes le_eqvt [eqvt]: "p \<bullet> (x \<le> y) = ((p \<bullet> x) \<le> (p \<bullet> (y::('a::{order, pt}))))"
+class le_eqvt = order +
+  assumes le_eqvt [eqvt]: "p \<bullet> (x \<le> y) = ((p \<bullet> x) \<le> (p \<bullet> (y::('a::{pt, order}))))"
 
 class inf_eqvt = Inf +
-  assumes inf_eqvt [eqvt]: "p \<bullet> (Inf X) = Inf (p \<bullet> (X::'a::{complete_lattice, pt} set))"
+  assumes inf_eqvt [eqvt]: "p \<bullet> (Inf X) = Inf (p \<bullet> (X::('a::{pt, complete_lattice}) set))"
 
 instantiation bool :: le_eqvt
 begin
@@ -3228,13 +3228,11 @@
   (@{const_name "atom"}, SOME @{typ "'a::at_base \<Rightarrow> atom"}) *}
 
 
-
 section {* Library functions for the nominal infrastructure *}
 
 ML_file "nominal_library.ML"
 
 
-
 section {* The freshness lemma according to Andy Pitts *}
 
 lemma freshness_lemma:
@@ -3417,18 +3415,16 @@
   shows "(FRESH x. P x \<or> Q x) \<longleftrightarrow> (FRESH x. P x) \<or> (FRESH x. Q x)"
 using P Q by (rule FRESH_binop_iff)
 
+
 section {* Automation for creating concrete atom types *}
 
-text {* at the moment only single-sort concrete atoms are supported *}
+text {* At the moment only single-sort concrete atoms are supported. *}
 
 ML_file "nominal_atoms.ML"
 
 
-section {* automatic equivariance procedure for inductive definitions *}
+section {* Automatic equivariance procedure for inductive definitions *}
 
 ML_file "nominal_eqvt.ML"
 
-
-
-
 end
--- a/Nominal/nominal_basics.ML	Tue Mar 26 16:41:31 2013 +0100
+++ b/Nominal/nominal_basics.ML	Wed Mar 27 16:08:30 2013 +0100
@@ -1,5 +1,6 @@
-(*  Title:      nominal_basic.ML
+(*  Title:      nominal_basics.ML
     Author:     Christian Urban
+    Author:     Tjark Weber
 
   Basic functions for nominal.
 *)
@@ -23,9 +24,9 @@
   val map4: ('a -> 'b -> 'c -> 'd -> 'e) -> 'a list -> 'b list -> 'c list -> 'd list -> 'e list
   val split_filter: ('a -> bool) -> 'a list -> 'a list * 'a list
   val fold_left: ('a * 'a -> 'a) -> 'a list -> 'a -> 'a
-  
+
   val is_true: term -> bool
- 
+
   val dest_listT: typ -> typ
   val dest_fsetT: typ -> typ
 
@@ -49,6 +50,10 @@
   val mk_perm: term -> term -> term
   val dest_perm: term -> term * term
 
+  (* functions to deal with constants in local contexts *)
+  val long_name: Proof.context -> string -> string
+  val is_fixed: Proof.context -> term -> bool
+  val fixed_nonfixed_args: Proof.context -> term -> term * term list
 end
 
 
@@ -126,17 +131,14 @@
   | fold_left f [x] z = x
   | fold_left f (x :: y :: xs) z = fold_left f (f (x, y) :: xs) z
 
-
-
 fun is_true @{term "Trueprop True"} = true
-  | is_true _ = false 
+  | is_true _ = false
 
 fun dest_listT (Type (@{type_name list}, [T])) = T
   | dest_listT T = raise TYPE ("dest_listT: list type expected", [T], [])
 
 fun dest_fsetT (Type (@{type_name fset}, [T])) = T
-  | dest_fsetT T = raise TYPE ("dest_fsetT: fset type expected", [T], []);
-
+  | dest_fsetT T = raise TYPE ("dest_fsetT: fset type expected", [T], [])
 
 fun mk_id trm = HOLogic.id_const (fastype_of trm) $ trm
 
@@ -146,7 +148,7 @@
 
 fun mk_exists (a, T) t =  HOLogic.exists_const T $ Abs (a, T, t)
 
-fun sum_case_const ty1 ty2 ty3 = 
+fun sum_case_const ty1 ty2 ty3 =
   Const (@{const_name sum_case}, [ty1 --> ty3, ty2 --> ty3, Type (@{type_name sum}, [ty1, ty2])] ---> ty3)
 
 fun mk_sum_case trm1 trm2 =
@@ -155,12 +157,10 @@
     val ty2 = domain_type (fastype_of trm2)
   in
     sum_case_const ty1 ty2 ty3 $ trm1 $ trm2
-  end 
-
+  end
 
-fun mk_equiv r = r RS @{thm eq_reflection};
-fun safe_mk_equiv r = mk_equiv r handle Thm.THM _ => r;
-
+fun mk_equiv r = r RS @{thm eq_reflection}
+fun safe_mk_equiv r = mk_equiv r handle Thm.THM _ => r
 
 fun mk_minus p = @{term "uminus::perm => perm"} $ p
 
@@ -172,7 +172,53 @@
 fun mk_perm p trm = mk_perm_ty (fastype_of trm) p trm
 
 fun dest_perm (Const (@{const_name "permute"}, _) $ p $ t) = (p, t)
-  | dest_perm t = raise TERM ("dest_perm", [t]);
+  | dest_perm t = raise TERM ("dest_perm", [t])
+
+
+(** functions to deal with constants in local contexts **)
+
+(* returns the fully qualified name of a constant *)
+fun long_name ctxt name =
+  case head_of (Syntax.read_term ctxt name) of
+    Const (s, _) => s
+  | _ => error ("Undeclared constant: " ^ quote name)
+
+(* returns true iff the argument term is a fixed Free *)
+fun is_fixed_Free ctxt (Free (s, _)) = Variable.is_fixed ctxt s
+  | is_fixed_Free _ _ = false
+
+(* returns true iff c is a constant or fixed Free applied to
+   fixed parameters *)
+fun is_fixed ctxt c =
+  let
+    val (c, args) = strip_comb c
+  in
+    (is_Const c orelse is_fixed_Free ctxt c)
+      andalso List.all (is_fixed_Free ctxt) args
+  end
+
+(* splits a list into the longest prefix containing only elements
+   that satisfy p, and the rest of the list *)
+fun chop_while p =
+  let
+    fun chop_while_aux acc [] =
+      (rev acc, [])
+      | chop_while_aux acc (x::xs) =
+      if p x then chop_while_aux (x::acc) xs else (rev acc, x::xs)
+  in
+    chop_while_aux []
+  end
+
+(* takes a combination "c $ fixed1 $ ... $ fixedN $ not-fixed $ ..."
+   to the pair ("c $ fixed1 $ ... $ fixedN", ["not-fixed", ...]). *)
+fun fixed_nonfixed_args ctxt c_args =
+  let
+    val (c, args)     = strip_comb c_args
+    val (frees, args) = chop_while (is_fixed_Free ctxt) args
+    val c_frees       = list_comb (c, frees)
+  in
+    (c_frees, args)
+  end
 
 end (* structure *)
 
--- a/Nominal/nominal_eqvt.ML	Tue Mar 26 16:41:31 2013 +0100
+++ b/Nominal/nominal_eqvt.ML	Wed Mar 27 16:08:30 2013 +0100
@@ -1,14 +1,14 @@
 (*  Title:      nominal_eqvt.ML
     Author:     Stefan Berghofer (original code)
     Author:     Christian Urban
+    Author:     Tjark Weber
 
     Automatic proofs for equivariance of inductive predicates.
 *)
 
-
 signature NOMINAL_EQVT =
 sig
-  val raw_equivariance: term list -> thm -> thm list -> Proof.context -> thm list
+  val raw_equivariance: Proof.context -> term list -> thm -> thm list -> thm list
   val equivariance_cmd: string -> Proof.context -> local_theory
 end
 
@@ -18,24 +18,26 @@
 open Nominal_Permeq;
 open Nominal_ThmDecls;
 
-val atomize_conv = 
+val atomize_conv =
   Raw_Simplifier.rewrite_cterm (true, false, false) (K (K NONE))
-    (HOL_basic_ss addsimps @{thms induct_atomize});
-val atomize_intr = Conv.fconv_rule (Conv.prems_conv ~1 atomize_conv);
+    (HOL_basic_ss addsimps @{thms induct_atomize})
+
+val atomize_intr = Conv.fconv_rule (Conv.prems_conv ~1 atomize_conv)
+
 fun atomize_induct ctxt = Conv.fconv_rule (Conv.prems_conv ~1
-  (Conv.params_conv ~1 (K (Conv.prems_conv ~1 atomize_conv)) ctxt));
+  (Conv.params_conv ~1 (K (Conv.prems_conv ~1 atomize_conv)) ctxt))
 
 
 (** equivariance tactics **)
 
-fun eqvt_rel_single_case_tac ctxt pred_names pi intro  = 
+fun eqvt_rel_single_case_tac ctxt pred_names pi intro =
   let
     val thy = Proof_Context.theory_of ctxt
     val cpi = Thm.cterm_of thy pi
     val pi_intro_rule = Drule.instantiate' [] [NONE, SOME cpi] @{thm permute_boolI}
     val eqvt_sconfig = eqvt_strict_config addexcls pred_names
     val simps1 = HOL_basic_ss addsimps @{thms permute_fun_def permute_self split_paired_all}
-    val simps2 = HOL_basic_ss addsimps @{thms permute_bool_def  permute_minus_cancel(2)}
+    val simps2 = HOL_basic_ss addsimps @{thms permute_bool_def permute_minus_cancel(2)}
   in
     eqvt_tac ctxt eqvt_sconfig THEN'
     SUBPROOF (fn {prems, context as ctxt, ...} =>
@@ -43,9 +45,8 @@
         val prems' = map (transform_prem2 ctxt pred_names) prems
         val prems'' = map (fn thm => eqvt_rule ctxt eqvt_sconfig (thm RS pi_intro_rule)) prems'
         val prems''' = map (simplify simps2 o simplify simps1) prems''
-
       in
-        HEADGOAL (rtac intro THEN_ALL_NEW resolve_tac (prems' @ prems'' @ prems''')) 
+        HEADGOAL (rtac intro THEN_ALL_NEW resolve_tac (prems' @ prems'' @ prems'''))
       end) ctxt
   end
 
@@ -57,78 +58,88 @@
   end
 
 
-(** equivariance procedure *)
+(** equivariance procedure **)
 
-fun prepare_goal pi pred =
+fun prepare_goal ctxt pi pred_with_args =
   let
-    val (c, xs) = strip_comb pred;
+    val (c, xs) = strip_comb pred_with_args
+    fun is_nonfixed_Free (Free (s, _)) = not (Variable.is_fixed ctxt s)
+      | is_nonfixed_Free _ = false
+    fun mk_perm_nonfixed_Free t =
+      if is_nonfixed_Free t then mk_perm pi t else t
   in
-    HOLogic.mk_imp (pred, list_comb (c, map (mk_perm pi) xs))
+    HOLogic.mk_imp (pred_with_args,
+      list_comb (c, map mk_perm_nonfixed_Free xs))
   end
 
-(* stores thm under name.eqvt and adds [eqvt]-attribute *)
+fun name_of (Const (s, _)) = s
 
-fun get_name (Const (a, _)) = a
-  | get_name (Free  (a, _)) = a
-
-fun raw_equivariance pred_trms raw_induct intrs ctxt = 
+fun raw_equivariance ctxt preds raw_induct intrs =
   let
-    val is_already_eqvt = 
-      filter (is_eqvt ctxt) pred_trms
-      |> map (Syntax.string_of_term ctxt)
+    (* FIXME: polymorphic predicates should either be rejected or
+              specialized to arguments of sort pt *)
+
+    val is_already_eqvt = filter (is_eqvt ctxt) preds
     val _ = if null is_already_eqvt then ()
-      else error ("Already equivariant: " ^ commas is_already_eqvt)
+      else error ("Already equivariant: " ^ commas
+        (map (Syntax.string_of_term ctxt) is_already_eqvt))
 
-    val pred_names = map get_name pred_trms
+    val pred_names = map (name_of o head_of) preds
     val raw_induct' = atomize_induct ctxt raw_induct
     val intrs' = map atomize_intr intrs
-  
-    val (([raw_concl], [raw_pi]), ctxt') = 
-      ctxt 
-      |> Variable.import_terms false [concl_of raw_induct'] 
+
+    val (([raw_concl], [raw_pi]), ctxt') =
+      ctxt
+      |> Variable.import_terms false [concl_of raw_induct']
       ||>> Variable.variant_fixes ["p"]
     val pi = Free (raw_pi, @{typ perm})
-  
-    val preds = map (fst o HOLogic.dest_imp)
-      (HOLogic.dest_conj (HOLogic.dest_Trueprop raw_concl));
-  
-    val goal = HOLogic.mk_Trueprop 
-      (foldr1 HOLogic.mk_conj (map (prepare_goal pi) preds)) 
-  in 
-    Goal.prove ctxt' [] [] goal 
-      (fn {context,...} => eqvt_rel_tac context pred_names pi raw_induct' intrs' 1)
-      |> Datatype_Aux.split_conj_thm 
+
+    val preds_with_args = raw_concl
+      |> HOLogic.dest_Trueprop
+      |> HOLogic.dest_conj
+      |> map (fst o HOLogic.dest_imp)
+
+    val goal = preds_with_args
+      |> map (prepare_goal ctxt pi)
+      |> foldr1 HOLogic.mk_conj
+      |> HOLogic.mk_Trueprop
+  in
+    Goal.prove ctxt' [] [] goal
+      (fn {context, ...} => eqvt_rel_tac context pred_names pi raw_induct' intrs' 1)
+      |> Datatype_Aux.split_conj_thm
       |> Proof_Context.export ctxt' ctxt
       |> map (fn th => th RS mp)
       |> map zero_var_indexes
   end
 
 
-fun note_named_thm (name, thm) ctxt = 
+(** stores thm under name.eqvt and adds [eqvt]-attribute **)
+
+fun note_named_thm (name, thm) ctxt =
   let
-    val thm_name = Binding.qualified_name 
+    val thm_name = Binding.qualified_name
       (Long_Name.qualify (Long_Name.base_name name) "eqvt")
     val attr = Attrib.internal (K eqvt_add)
-    val ((_, [thm']), ctxt') =  Local_Theory.note ((thm_name, [attr]), [thm]) ctxt
+    val ((_, [thm']), ctxt') = Local_Theory.note ((thm_name, [attr]), [thm]) ctxt
   in
     (thm', ctxt')
   end
 
+
+(** equivariance command **)
+
 fun equivariance_cmd pred_name ctxt =
   let
-    val thy = Proof_Context.theory_of ctxt
     val ({names, ...}, {preds, raw_induct, intrs, ...}) =
-      Inductive.the_inductive ctxt (Sign.intern_const thy pred_name)
-    val thms = raw_equivariance preds raw_induct intrs ctxt 
+      Inductive.the_inductive ctxt (long_name ctxt pred_name)
+    val thms = raw_equivariance ctxt preds raw_induct intrs
   in
     fold_map note_named_thm (names ~~ thms) ctxt |> snd
   end
 
-
 val _ =
   Outer_Syntax.local_theory @{command_spec "equivariance"}
-    "Proves equivariance for inductive predicate involving nominal datatypes." 
-      (Parse.xname >> equivariance_cmd)
-
+    "Proves equivariance for inductive predicate involving nominal datatypes."
+      (Parse.const >> equivariance_cmd)
 
 end (* structure *)
--- a/Nominal/nominal_inductive.ML	Tue Mar 26 16:41:31 2013 +0100
+++ b/Nominal/nominal_inductive.ML	Wed Mar 27 16:08:30 2013 +0100
@@ -1,5 +1,6 @@
 (*  Title:      nominal_inductive.ML
     Author:     Christian Urban
+    Author:     Tjark Weber
 
     Infrastructure for proving strong induction theorems
     for inductive predicates involving nominal datatypes.
@@ -7,27 +8,26 @@
     Code based on an earlier version by Stefan Berghofer.
 *)
 
-
 signature NOMINAL_INDUCTIVE =
 sig
-  val prove_strong_inductive: string list -> string list -> term list list -> thm -> thm list -> 
+  val prove_strong_inductive: string list -> string list -> term list list -> thm -> thm list ->
     Proof.context -> Proof.state
-
   val prove_strong_inductive_cmd: xstring * (string * string list) list -> Proof.context -> Proof.state
 end
 
 structure Nominal_Inductive : NOMINAL_INDUCTIVE =
 struct
 
-
-fun mk_cplus p q = Thm.apply (Thm.apply @{cterm "plus :: perm => perm => perm"} p) q 
+fun mk_cplus p q =
+  Thm.apply (Thm.apply @{cterm "plus :: perm => perm => perm"} p) q
 
-fun mk_cminus p = Thm.apply @{cterm "uminus :: perm => perm"} p 
+fun mk_cminus p =
+  Thm.apply @{cterm "uminus :: perm => perm"} p
 
-fun minus_permute_intro_tac p = 
+fun minus_permute_intro_tac p =
   rtac (Drule.instantiate' [] [SOME (mk_cminus p)] @{thm permute_boolE})
 
-fun minus_permute_elim p thm = 
+fun minus_permute_elim p thm =
   thm RS (Drule.instantiate' [] [NONE, SOME (mk_cminus p)] @{thm permute_boolI})
 
 (* fixme: move to nominal_library *)
@@ -36,25 +36,28 @@
   | real_head_of (Const (@{const_name all}, _) $ Abs (_, _, t)) = real_head_of t
   | real_head_of (Const (@{const_name All}, _) $ Abs (_, _, t)) = real_head_of t
   | real_head_of (Const ("HOL.induct_forall", _) $ Abs (_, _, t)) = real_head_of t
-  | real_head_of t = head_of t  
+  | real_head_of t = head_of t
 
 
-fun mk_vc_compat (avoid, avoid_trm) prems concl_args params = 
-  let
-    val vc_goal = concl_args
-      |> HOLogic.mk_tuple
-      |> mk_fresh_star avoid_trm 
-      |> HOLogic.mk_Trueprop
-      |> (curry Logic.list_implies) prems
-      |> fold_rev (Logic.all o Free) params
-    val finite_goal = avoid_trm
-      |> mk_finite
-      |> HOLogic.mk_Trueprop
-      |> (curry Logic.list_implies) prems
-      |> fold_rev (Logic.all o Free) params
-  in 
-    if null avoid then [] else [vc_goal, finite_goal]
-  end
+fun mk_vc_compat (avoid, avoid_trm) prems concl_args params =
+  if null avoid then
+    []
+  else
+    let
+      val vc_goal = concl_args
+        |> HOLogic.mk_tuple
+        |> mk_fresh_star avoid_trm
+        |> HOLogic.mk_Trueprop
+        |> (curry Logic.list_implies) prems
+        |> fold_rev (Logic.all o Free) params
+      val finite_goal = avoid_trm
+        |> mk_finite
+        |> HOLogic.mk_Trueprop
+        |> (curry Logic.list_implies) prems
+        |> fold_rev (Logic.all o Free) params
+    in
+      [vc_goal, finite_goal]
+    end
 
 (* fixme: move to nominal_library *)
 fun map_term prop f trm =
@@ -77,12 +80,15 @@
     |> (fn t => HOLogic.all_const @{typ perm} $  lambda p t)
   end
 
-fun induct_forall_const T = Const ("HOL.induct_forall", (T --> @{typ bool}) --> @{typ bool})
-fun mk_induct_forall (a, T) t =  induct_forall_const T $ Abs (a, T, t)
+fun induct_forall_const T =
+  Const ("HOL.induct_forall", (T --> @{typ bool}) --> @{typ bool})
+
+fun mk_induct_forall (a, T) t =
+  induct_forall_const T $ Abs (a, T, t)
 
 fun add_c_prop qnt Ps (c, c_name, c_ty) trm =
   let
-    fun add t = 
+    fun add t =
       let
         val (P, args) = strip_comb t
         val (P_name, P_ty) = dest_Free P
@@ -91,7 +97,7 @@
           |> qnt ? map (incr_boundvars 1)
       in
         list_comb (Free (P_name, (c_ty :: ty_args) ---> bool), c :: args')
-        |> qnt ? mk_induct_forall (c_name, c_ty)
+          |> qnt ? mk_induct_forall (c_name, c_ty)
       end
   in
     map_term (member (op =) Ps o head_of) add trm
@@ -100,7 +106,7 @@
 fun prep_prem Ps c_name c_ty (avoid, avoid_trm) (params, prems, concl) =
   let
     val prems' = prems
-      |> map (incr_boundvars 1) 
+      |> map (incr_boundvars 1)
       |> map (add_c_prop true Ps (Bound 0, c_name, c_ty))
 
     val avoid_trm' = avoid_trm
@@ -109,14 +115,14 @@
       |> (fn t => mk_fresh_star_ty c_ty t (Bound 0))
       |> HOLogic.mk_Trueprop
 
-    val prems'' = 
-      if null avoid 
-      then prems' 
+    val prems'' =
+      if null avoid
+      then prems'
       else avoid_trm' :: prems'
 
     val concl' = concl
-      |> incr_boundvars 1 
-      |> add_c_prop false Ps (Bound 0, c_name, c_ty)  
+      |> incr_boundvars 1
+      |> add_c_prop false Ps (Bound 0, c_name, c_ty)
   in
     mk_full_horn (params @ [(c_name, c_ty)]) prems'' concl'
   end
@@ -129,7 +135,7 @@
 
 (* fixme: move to nominal_library *)
 fun map7 _ [] [] [] [] [] [] [] = []
-  | map7 f (x :: xs) (y :: ys) (z :: zs) (u :: us) (v :: vs) (r :: rs) (s :: ss) = 
+  | map7 f (x :: xs) (y :: ys) (z :: zs) (u :: us) (v :: vs) (r :: rs) (s :: ss) =
       f x y z u v r s :: map7 f xs ys zs us vs rs ss
 
 (* local abbreviations *)
@@ -137,18 +143,20 @@
 local
   open Nominal_Permeq
 in
-(* by default eqvt_strict_config contains unwanted @{thm permute_pure} *) 
+
+  (* by default eqvt_strict_config contains unwanted @{thm permute_pure} *) 
 
-val eqvt_sconfig = eqvt_strict_config addpres @{thms permute_minus_cancel}
+  val eqvt_sconfig = eqvt_strict_config addpres @{thms permute_minus_cancel}
 
-fun eqvt_stac ctxt = eqvt_tac ctxt eqvt_sconfig
-fun eqvt_srule ctxt = eqvt_rule ctxt eqvt_sconfig
+  fun eqvt_stac ctxt = eqvt_tac ctxt eqvt_sconfig
+  fun eqvt_srule ctxt = eqvt_rule ctxt eqvt_sconfig
 
 end
 
 val all_elims = 
   let
-     fun spec' ct = Drule.instantiate' [SOME (ctyp_of_term ct)] [NONE, SOME ct] @{thm spec}
+    fun spec' ct =
+      Drule.instantiate' [SOME (ctyp_of_term ct)] [NONE, SOME ct] @{thm spec}
   in
     fold (fn ct => fn th => th RS spec' ct)
   end
@@ -216,6 +224,7 @@
 
 val supp_perm_eq' = @{lemma "fresh_star (supp (permute p x)) q ==> permute p x == permute (q + p) x" 
   by (simp add: supp_perm_eq)}
+
 val fresh_star_plus = @{lemma "fresh_star (permute q (permute p x)) c ==> fresh_star (permute (q + p) x) c" 
   by (simp add: permute_plus)}
 
@@ -321,8 +330,9 @@
       |> map Logic.strip_horn
       |> split_list
 
-    val intr_concls_args = map (snd o strip_comb o HOLogic.dest_Trueprop) intr_concls 
-      
+    val intr_concls_args =
+      map (snd o fixed_nonfixed_args ctxt' o HOLogic.dest_Trueprop) intr_concls
+
     val avoid_trms = avoids
       |> (map o map) (setify ctxt') 
       |> map fold_union
@@ -383,35 +393,33 @@
 
 fun prove_strong_inductive_cmd (pred_name, avoids) ctxt =
   let
-    val thy = Proof_Context.theory_of ctxt;
     val ({names, ...}, {raw_induct, intrs, ...}) =
-      Inductive.the_inductive ctxt (Sign.intern_const thy pred_name);
+      Inductive.the_inductive ctxt (long_name ctxt pred_name)
 
-    val rule_names = 
-      hd names
+    val rule_names = hd names
       |> the o Induct.lookup_inductP ctxt
       |> fst o Rule_Cases.get
       |> map (fst o fst)
 
-    val _ = (case duplicates (op = o pairself fst) avoids of
+    val case_names = map fst avoids
+    val _ = case duplicates (op =) case_names of
         [] => ()
-      | xs => error ("Duplicate case names: " ^ commas_quote (map fst xs)))
-
-    val _ = (case subtract (op =) rule_names (map fst avoids) of
+      | xs => error ("Duplicate case names: " ^ commas_quote xs)
+    val _ = case subtract (op =) rule_names case_names of
         [] => ()
-      | xs => error ("No such case(s) in inductive definition: " ^ commas_quote xs))
+      | xs => error ("No such case(s) in inductive definition: " ^ commas_quote xs)
 
     val avoids_ordered = order_default (op =) [] rule_names avoids
-      
+
     fun read_avoids avoid_trms intr =
       let
         (* fixme hack *)
         val (((_, ctrms), _), ctxt') = Variable.import true [intr] ctxt
         val trms = map (term_of o snd) ctrms
-        val ctxt'' = fold Variable.declare_term trms ctxt' 
+        val ctxt'' = fold Variable.declare_term trms ctxt'
       in
-        map (Syntax.read_term ctxt'') avoid_trms 
-      end 
+        map (Syntax.read_term ctxt'') avoid_trms
+      end
 
     val avoid_trms = map2 read_avoids avoids_ordered intrs
   in
@@ -420,11 +428,10 @@
 
 (* outer syntax *)
 local
-
-  val single_avoid_parser = 
+  val single_avoid_parser =
     Parse.name -- (@{keyword ":"} |-- Parse.and_list1 Parse.term)
 
-  val avoids_parser = 
+  val avoids_parser =
     Scan.optional (@{keyword "avoids"} |-- Parse.enum1 "|" single_avoid_parser) []
 
   val main_parser = Parse.xname -- avoids_parser
--- a/Nominal/nominal_library.ML	Tue Mar 26 16:41:31 2013 +0100
+++ b/Nominal/nominal_library.ML	Wed Mar 27 16:08:30 2013 +0100
@@ -182,25 +182,25 @@
     then mk_atom_fset_ty ty t
   else if is_atom_list ctxt ty
     then mk_atom_list_ty ty t
-  else raise TERM ("atomify", [t])
+  else raise TERM ("atomify: term is not an atom, set or list of atoms", [t])
 
 fun setify_ty ctxt ty t =
   if is_atom ctxt ty
-    then  HOLogic.mk_set @{typ atom} [mk_atom_ty ty t]
+    then HOLogic.mk_set @{typ atom} [mk_atom_ty ty t]
   else if is_atom_set ctxt ty
     then mk_atom_set_ty ty t
   else if is_atom_fset ctxt ty
     then @{term "fset :: atom fset => atom set"} $ mk_atom_fset_ty ty t
   else if is_atom_list ctxt ty
     then @{term "set :: atom list => atom set"} $ mk_atom_list_ty ty t
-  else raise TERM ("setify", [t])
+  else raise TERM ("setify: term is not an atom, set or list of atoms", [t])
 
 fun listify_ty ctxt ty t =
   if is_atom ctxt ty
     then HOLogic.mk_list @{typ atom} [mk_atom_ty ty t]
   else if is_atom_list ctxt ty
     then mk_atom_list_ty ty t
-  else raise TERM ("listify", [t])
+  else raise TERM ("listify: term is not an atom or list of atoms", [t])
 
 fun atomify ctxt t = atomify_ty ctxt (fastype_of t) t
 fun setify ctxt t  = setify_ty ctxt (fastype_of t) t
@@ -489,14 +489,13 @@
   map_thm ctxt (split_conj2 names) (etac conjunct2 1) thm
 
 
-(* transformes a theorem into one of the object logic *)
+(* transforms a theorem into one of the object logic *)
 val atomize = Conv.fconv_rule Object_Logic.atomize o forall_intr_vars;
 fun atomize_rule i thm =
   Conv.fconv_rule (Conv.concl_conv i Object_Logic.atomize) thm
 fun atomize_concl thm = atomize_rule (length (prems_of thm)) thm
 
 
-
 (* applies a tactic to a formula composed of conjunctions *)
 fun conj_tac tac i =
   let
@@ -509,7 +508,6 @@
     SUBGOAL select i
   end
 
-
 end (* structure *)
 
 open Nominal_Library;
\ No newline at end of file
--- a/Nominal/nominal_thmdecls.ML	Tue Mar 26 16:41:31 2013 +0100
+++ b/Nominal/nominal_thmdecls.ML	Wed Mar 27 16:08:30 2013 +0100
@@ -1,27 +1,59 @@
 (*  Title:      nominal_thmdecls.ML
     Author:     Christian Urban
+    Author:     Tjark Weber
 
-  Infrastructure for the lemma collection "eqvts".
+  Infrastructure for the lemma collections "eqvts", "eqvts_raw".
 
   Provides the attributes [eqvt] and [eqvt_raw], and the theorem
-  lists eqvts and eqvts_raw. The first attribute will store the 
-  theorem in the eqvts list and also in the eqvts_raw list. For 
-  the latter the theorem is expected to be of the form
+  lists "eqvts" and "eqvts_raw".
+
+  The [eqvt] attribute expects a theorem of the form
+
+    ?p \<bullet> (c ?x1 ?x2 ...) = c (?p \<bullet> ?x1) (?p \<bullet> ?x2) ...    (1)
+
+  or, if c is a relation with arity >= 1, of the form
+
+    c ?x1 ?x2 ... ==> c (?p \<bullet> ?x1) (?p \<bullet> ?x2) ...         (2)
 
-    p o (c x1 x2 ...) = c (p o x1) (p o x2) ...    (1)
+  [eqvt] will store this theorem in the form (1) or, if c
+  is a relation with arity >= 1, in the form
+
+    c (?p \<bullet> ?x1) (?p \<bullet> ?x2) ... = c ?x1 ?x2 ...           (3)
 
-  or
+  in "eqvts". (The orientation of (3) was chosen because
+  Isabelle's simplifier uses equations from left to right.)
+  [eqvt] will also derive and store the theorem
+
+    ?p \<bullet> c == c                                           (4)
+
+  in "eqvts_raw".
 
-    c x1 x2 ... ==> c (p o x1) (p o x2) ...        (2)
+  (1)-(4) are all logically equivalent. We consider (1) and (2)
+  to be more end-user friendly, i.e., slightly more natural to
+  understand and prove, while (3) and (4) make the rewriting
+  system for equivariance more predictable and less prone to
+  looping in Isabelle.
 
-  and it is stored in the form
+  The [eqvt_raw] attribute expects a theorem of the form (4),
+  and merely stores it in "eqvts_raw".
 
-    p o c == c
+  [eqvt_raw] is provided because certain equivariance theorems
+  would lead to looping when used for simplification in the form
+  (1): notably, equivariance of permute (infix \<bullet>), i.e.,
+  ?p \<bullet> (?q \<bullet> ?x) = (?p \<bullet> ?q) \<bullet> (?p \<bullet> ?x).
 
-  The [eqvt_raw] attribute just adds the theorem to eqvts_raw.
+  To support binders such as All/Ex/Ball/Bex etc., which are
+  typically applied to abstractions, argument terms ?xi (as well
+  as permuted arguments ?p \<bullet> ?xi) in (1)-(3) need not be eta-
+  contracted, i.e., they may be of the form "%z. ?xi z" or
+  "%z. (?p \<bullet> ?x) z", respectively.
 
-  TODO: In case of the form in (2) one should also
-        add the equational form to eqvts
+  For convenience, argument terms ?xi (as well as permuted
+  arguments ?p \<bullet> ?xi) in (1)-(3) may actually be tuples, e.g.,
+  "(?xi, ?xj)" or "(?p \<bullet> ?xi, ?p \<bullet> ?xj)", respectively.
+
+  In (1)-(4), "c" is either a (global) constant or a locally
+  fixed parameter, e.g., of a locale or type class.
 *)
 
 signature NOMINAL_THMDECLS =
@@ -46,206 +78,246 @@
   val extend = I;
   val merge = Item_Net.merge);
 
+(* EqvtRawData is implemented with a Termtab (rather than an
+   Item_Net) so that we can efficiently decide whether a given
+   constant has a corresponding equivariance theorem stored, cf.
+   the function is_eqvt. *)
 structure EqvtRawData = Generic_Data
 ( type T = thm Termtab.table;
   val empty = Termtab.empty;
   val extend = I;
   val merge = Termtab.merge (K true));
 
-val eqvts = Item_Net.content o EqvtData.get;
-val eqvts_raw = map snd o Termtab.dest o EqvtRawData.get;
-
-val get_eqvts_thms = eqvts o  Context.Proof;
-val get_eqvts_raw_thms = eqvts_raw o Context.Proof;
-
-fun checked_update key net =
-  (if Item_Net.member net key then 
-     warning "Theorem already declared as equivariant."
-   else (); 
-   Item_Net.update key net)
-
-val add_thm = EqvtData.map o checked_update;
-val del_thm = EqvtData.map o Item_Net.remove;
+val eqvts = Item_Net.content o EqvtData.get
+val eqvts_raw = map snd o Termtab.dest o EqvtRawData.get
 
-fun add_raw_thm thm = 
-  case prop_of thm of
-    Const ("==", _) $ _ $ (c as Const _) => EqvtRawData.map (Termtab.update (c, thm)) 
-  | _ => raise THM ("Theorem must be a meta-equality where the right-hand side is a constant.", 0, [thm]) 
-
-fun del_raw_thm thm = 
-  case prop_of thm of
-    Const ("==", _) $ _ $ (c as Const _) => EqvtRawData.map (Termtab.delete c)
-  | _ => raise THM ("Theorem must be a meta-equality where the right-hand side is a constant.", 0, [thm]) 
-
-fun is_eqvt ctxt trm =
-  case trm of 
-    (c as Const _) => Termtab.defined (EqvtRawData.get (Context.Proof ctxt)) c
-  | _ => false (* raise TERM ("Term must be a constant.", [trm]) *)
-
+val get_eqvts_thms = eqvts o Context.Proof
+val get_eqvts_raw_thms = eqvts_raw o Context.Proof
 
 
-(** transformation of eqvt lemmas **)
+(** raw equivariance lemmas **)
 
-fun get_perms trm =
-  case trm of 
-     Const (@{const_name permute}, _) $ _ $ (Bound _) => 
-       raise TERM ("get_perms called on bound", [trm])
-   | Const (@{const_name permute}, _) $ p $ _ => [p]
-   | t $ u => get_perms t @ get_perms u
-   | Abs (_, _, t) => get_perms t 
-   | _ => []
+(* Returns true iff an equivariance lemma exists in "eqvts_raw"
+   for a given term. *)
+val is_eqvt =
+  Termtab.defined o EqvtRawData.get o Context.Proof
 
-fun add_perm p trm =
+(* Returns c if thm is of the form (4), raises an error
+   otherwise. *)
+fun key_of_raw_thm context thm =
   let
-    fun aux trm = 
-      case trm of 
-        Bound _ => trm
-      | Const _ => trm
-      | t $ u => aux t $ aux u
-      | Abs (x, ty, t) => Abs (x, ty, aux t)
-      | _ => mk_perm p trm
+    fun error_msg () =
+      error
+        ("Theorem must be of the form \"?p \<bullet> c \<equiv> c\", with c a constant or fixed parameter:\n" ^
+         Syntax.string_of_term (Context.proof_of context) (prop_of thm))
+  in
+    case prop_of thm of
+      Const ("==", _) $ (Const (@{const_name "permute"}, _) $ p $ c) $ c' =>
+        if is_Var p andalso is_fixed (Context.proof_of context) c andalso c aconv c' then
+          c
+        else
+          error_msg ()
+    | _ => error_msg ()
+  end
+
+fun add_raw_thm thm context =
+  let
+    val c = key_of_raw_thm context thm
   in
-    strip_comb trm
-    ||> map aux
-    |> list_comb
-  end  
+    if Termtab.defined (EqvtRawData.get context) c then
+      warning ("Replacing existing raw equivariance theorem for \"" ^
+        Syntax.string_of_term (Context.proof_of context) c ^ "\".")
+    else ();
+    EqvtRawData.map (Termtab.update (c, thm)) context
+  end
+
+fun del_raw_thm thm context =
+  let
+    val c = key_of_raw_thm context thm
+  in
+    if Termtab.defined (EqvtRawData.get context) c then
+      EqvtRawData.map (Termtab.delete c) context
+    else (
+      warning ("Cannot delete non-existing raw equivariance theorem for \"" ^
+        Syntax.string_of_term (Context.proof_of context) c ^ "\".");
+      context
+    )
+  end
 
 
-(* tests whether there is a disagreement between the permutations, 
-   and that there is at least one permutation *)
-fun is_bad_list [] = true
-  | is_bad_list [_] = false
-  | is_bad_list (p::q::ps) = if p = q then is_bad_list (q::ps) else true
+(** adding/deleting lemmas to/from "eqvts" **)
+
+fun add_thm thm context =
+  (
+    if Item_Net.member (EqvtData.get context) thm then
+      warning ("Theorem already declared as equivariant:\n" ^
+        Syntax.string_of_term (Context.proof_of context) (prop_of thm))
+    else ();
+    EqvtData.map (Item_Net.update thm) context
+  )
+
+fun del_thm thm context =
+  (
+    if Item_Net.member (EqvtData.get context) thm then
+      EqvtData.map (Item_Net.remove thm) context
+    else (
+      warning ("Cannot delete non-existing equivariance theorem:\n" ^
+        Syntax.string_of_term (Context.proof_of context) (prop_of thm));
+      context
+    )
+  )
 
 
-(* transforms equations into the "p o c == c"-form 
-   from p o (c x1 ...xn) = c (p o x1) ... (p o xn) *)
+(** transformation of equivariance lemmas **)
+
+(* Transforms a theorem of the form (1) into the form (4). *)
+local
 
-fun eqvt_transform_eq_tac thm = 
-let
-  val ss_thms = @{thms permute_minus_cancel permute_prod.simps split_paired_all}
+fun tac thm =
+  let
+    val ss_thms = @{thms "permute_minus_cancel" "permute_prod.simps" "split_paired_all"}
+  in
+    REPEAT o FIRST'
+      [CHANGED o simp_tac (HOL_basic_ss addsimps ss_thms),
+       rtac (thm RS @{thm "trans"}),
+       rtac @{thm "trans"[OF "permute_fun_def"]} THEN' rtac @{thm "ext"}]
+  end
+
 in
-  REPEAT o FIRST' 
-    [CHANGED o simp_tac (HOL_basic_ss addsimps ss_thms),
-     rtac (thm RS @{thm trans}),
-     rtac @{thm trans[OF permute_fun_def]} THEN' rtac @{thm ext}]
-end
-
-fun eqvt_transform_eq ctxt thm = 
-  let
-    val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop (prop_of thm))
-      handle TERM _ => error "Equivariance lemma must be an equality."
-    val (p, t) = dest_perm lhs 
-      handle TERM _ => error "Equivariance lemma is not of the form p \<bullet> c...  = c..."
 
-    val ps = get_perms rhs handle TERM _ => []  
-    val (c, c') = (head_of t, head_of rhs)
-    val msg = "Equivariance lemma is not of the right form "
+fun thm_4_of_1 ctxt thm =
+  let
+    val (p, c) = thm |> prop_of |> HOLogic.dest_Trueprop
+      |> HOLogic.dest_eq |> fst |> dest_perm ||> fst o (fixed_nonfixed_args ctxt)
+    val goal = HOLogic.mk_Trueprop (HOLogic.mk_eq (mk_perm p c, c))
+    val ([goal', p'], ctxt') = Variable.import_terms false [goal, p] ctxt
   in
-    if c <> c' 
-      then error (msg ^ "(constants do not agree).")
-    else if is_bad_list (p :: ps)  
-      then error (msg ^ "(permutations do not agree).") 
-    else if not (rhs aconv (add_perm p t))
-      then error (msg ^ "(arguments do not agree).")
-    else if is_Const t 
-      then safe_mk_equiv thm
-    else 
-      let 
-        val goal = HOLogic.mk_Trueprop (HOLogic.mk_eq (mk_perm p c, c))
-        val ([goal', p'], ctxt') = Variable.import_terms false [goal, p] ctxt
-      in
-        Goal.prove ctxt [] [] goal' (fn _ => eqvt_transform_eq_tac thm 1)
-        |> singleton (Proof_Context.export ctxt' ctxt)
-        |> safe_mk_equiv
-        |> zero_var_indexes
-      end
+    Goal.prove ctxt [] [] goal' (fn _ => tac thm 1)
+      |> singleton (Proof_Context.export ctxt' ctxt)
+      |> (fn th => th RS @{thm "eq_reflection"})
+      |> zero_var_indexes
+  end
+  handle TERM _ =>
+    raise THM ("thm_4_of_1", 0, [thm])
+
+end (* local *)
+
+(* Transforms a theorem of the form (2) into the form (1). *)
+local
+
+fun tac thm thm' =
+  let
+    val ss_thms = @{thms "permute_minus_cancel"(2)}
+  in
+    EVERY' [rtac @{thm "iffI"}, dtac @{thm "permute_boolE"}, rtac thm, atac,
+      rtac @{thm "permute_boolI"}, dtac thm', full_simp_tac (HOL_basic_ss addsimps ss_thms)]
   end
 
-(* transforms equations into the "p o c == c"-form 
-   from R x1 ...xn ==> R (p o x1) ... (p o xn) *)
+in
 
-fun eqvt_transform_imp_tac ctxt p p' thm = 
+fun thm_1_of_2 ctxt thm =
   let
-    val thy = Proof_Context.theory_of ctxt
-    val cp = Thm.cterm_of thy p
-    val cp' = Thm.cterm_of thy (mk_minus p')
-    val thm' = Drule.cterm_instantiate [(cp, cp')] thm
-    val simp = HOL_basic_ss addsimps @{thms permute_minus_cancel(2)}
-  in
-    EVERY' [rtac @{thm iffI}, dtac @{thm permute_boolE}, rtac thm, atac,
-      rtac @{thm permute_boolI}, dtac thm', full_simp_tac simp]
-  end
-
-fun eqvt_transform_imp ctxt thm =
-  let
-    val (prem, concl) = pairself HOLogic.dest_Trueprop (Logic.dest_implies (prop_of thm))
-    val (c, c') = (head_of prem, head_of concl)
-    val ps = get_perms concl handle TERM _ => []  
-    val p = try hd ps
-    val msg = "Equivariance lemma is not of the right form "
+    val (prem, concl) = thm |> prop_of |> Logic.dest_implies |> pairself HOLogic.dest_Trueprop
+    (* since argument terms "?p \<bullet> ?x1" may actually be eta-expanded
+       or tuples, we need the following function to find ?p *)
+    fun find_perm (Const (@{const_name "permute"}, _) $ (p as Var _) $ _) = p
+      | find_perm (Const (@{const_name "Pair"}, _) $ x $ _) = find_perm x
+      | find_perm (Abs (_, _, body)) = find_perm body
+      | find_perm _ = raise THM ("thm_3_of_2", 0, [thm])
+    val p = concl |> dest_comb |> snd |> find_perm
+    val goal = HOLogic.mk_Trueprop (HOLogic.mk_eq (mk_perm p prem, concl))
+    val ([goal', p'], ctxt') = Variable.import_terms false [goal, p] ctxt
+    val certify = cterm_of (theory_of_thm thm)
+    val thm' = Drule.cterm_instantiate [(certify p, certify (mk_minus p'))] thm
   in
-    if c <> c' 
-      then error (msg ^ "(constants do not agree).")
-    else if is_bad_list ps  
-      then error (msg ^ "(permutations do not agree).") 
-    else if not (concl aconv (add_perm (the p) prem)) 
-      then error (msg ^ "(arguments do not agree).")
-    else 
-      let
-        val prem' = mk_perm (the p) prem    
-        val goal = HOLogic.mk_Trueprop (HOLogic.mk_eq (prem', concl))
-        val ([goal', p'], ctxt') = Variable.import_terms false [goal, the p] ctxt
-      in
-        Goal.prove ctxt' [] [] goal'
-          (fn _ => eqvt_transform_imp_tac ctxt' (the p) p' thm 1) 
-        |> singleton (Proof_Context.export ctxt' ctxt)
-      end
-  end     
+    Goal.prove ctxt' [] [] goal' (fn _ => tac thm thm' 1)
+      |> singleton (Proof_Context.export ctxt' ctxt)
+  end
+  handle TERM _ =>
+    raise THM ("thm_1_of_2", 0, [thm])
+
+end (* local *)
+
+(* Transforms a theorem of the form (1) into the form (3). *)
+fun thm_3_of_1 _ thm =
+  (thm RS (@{thm "permute_bool_def"} RS @{thm "sym"} RS @{thm "trans"}) RS @{thm "sym"})
+    |> zero_var_indexes
+
+local
+  val msg = cat_lines
+    ["Equivariance theorem must be of the form",
+     "  ?p \<bullet> (c ?x1 ?x2 ...) = c (?p \<bullet> ?x1) (?p \<bullet> ?x2) ...",
+     "or, if c is a relation with arity >= 1, of the form",
+     "  c ?x1 ?x2 ... ==> c (?p \<bullet> ?x1) (?p \<bullet> ?x2) ..."]
+in
 
-fun eqvt_transform ctxt thm = 
-  case (prop_of thm) of
-    @{const "Trueprop"} $ (Const (@{const_name "HOL.eq"}, _) $ 
-      (Const (@{const_name "permute"}, _) $ _ $ _) $ _) => 
-        eqvt_transform_eq ctxt thm
-  | @{const "==>"} $ (@{const "Trueprop"} $ _) $ (@{const "Trueprop"} $ _) => 
-      eqvt_transform_imp ctxt thm |> eqvt_transform_eq ctxt
-  | _ => raise error "Only _ = _ and _ ==> _ cases are implemented."
- 
+(* Transforms a theorem of the form (1) or (2) into the form (4). *)
+fun eqvt_transform ctxt thm =
+  (case prop_of thm of @{const "Trueprop"} $ _ =>
+    thm_4_of_1 ctxt thm
+  | @{const "==>"} $ _ $ _ =>
+    thm_4_of_1 ctxt (thm_1_of_2 ctxt thm)
+  | _ =>
+    error msg)
+  handle THM _ =>
+    error msg
+
+(* Transforms a theorem of the form (1) into theorems of the
+   form (1) (or, if c is a relation with arity >= 1, of the form
+   (3)) and (4); transforms a theorem of the form (2) into
+   theorems of the form (3) and (4). *)
+fun eqvt_and_raw_transform ctxt thm =
+  (case prop_of thm of @{const "Trueprop"} $ (Const (@{const_name "HOL.eq"}, _) $ _ $ c_args) =>
+    let
+      val th' =
+        if fastype_of c_args = @{typ "bool"}
+            andalso (not o null) (snd (fixed_nonfixed_args ctxt c_args)) then
+          thm_3_of_1 ctxt thm
+        else
+          thm
+    in
+      (th', thm_4_of_1 ctxt thm)
+    end
+  | @{const "==>"} $ _ $ _ =>
+    let
+      val th1 = thm_1_of_2 ctxt thm
+    in
+      (thm_3_of_1 ctxt th1, thm_4_of_1 ctxt th1)
+    end
+  | _ =>
+    error msg)
+  handle THM _ =>
+    error msg
+
+end (* local *)
+
 
 (** attributes **)
 
-val eqvt_add = Thm.declaration_attribute 
-  (fn thm => fn context =>
-   let
-     val thm' = eqvt_transform (Context.proof_of context) thm
-   in
-     context |> add_thm thm |> add_raw_thm thm'
-   end)
+val eqvt_raw_add = Thm.declaration_attribute add_raw_thm
+val eqvt_raw_del = Thm.declaration_attribute del_raw_thm
 
-val eqvt_del = Thm.declaration_attribute
-  (fn thm => fn context =>
-   let
-     val thm' = eqvt_transform (Context.proof_of context) thm
-   in
-     context |> del_thm thm  |> del_raw_thm thm'
-   end)
+fun eqvt_add_or_del eqvt_fn raw_fn =
+  Thm.declaration_attribute
+    (fn thm => fn context =>
+      let
+        val (eqvt, raw) = eqvt_and_raw_transform (Context.proof_of context) thm
+      in
+        context |> eqvt_fn eqvt |> raw_fn raw
+      end)
 
-val eqvt_raw_add = Thm.declaration_attribute add_raw_thm;
-val eqvt_raw_del = Thm.declaration_attribute del_raw_thm;
+val eqvt_add = eqvt_add_or_del add_thm add_raw_thm
+val eqvt_del = eqvt_add_or_del del_thm del_raw_thm
 
 
 (** setup function **)
 
 val setup =
-  Attrib.setup @{binding "eqvt"} (Attrib.add_del eqvt_add eqvt_del) 
-    (cat_lines ["Declaration of equivariance lemmas - they will automtically be",  
-       "brought into the form p o c == c"]) #>
-  Attrib.setup @{binding "eqvt_raw"} (Attrib.add_del eqvt_raw_add eqvt_raw_del) 
-    (cat_lines ["Declaration of equivariance lemmas - no",
-       "transformation is performed"]) #>
+  Attrib.setup @{binding "eqvt"} (Attrib.add_del eqvt_add eqvt_del)
+    "Declaration of equivariance lemmas - they will automatically be brought into the form ?p \<bullet> c \<equiv> c" #>
+  Attrib.setup @{binding "eqvt_raw"} (Attrib.add_del eqvt_raw_add eqvt_raw_del)
+    "Declaration of raw equivariance lemmas - no transformation is performed" #>
   Global_Theory.add_thms_dynamic (@{binding "eqvts"}, eqvts) #>
-  Global_Theory.add_thms_dynamic (@{binding "eqvts_raw"}, eqvts_raw);
-
+  Global_Theory.add_thms_dynamic (@{binding "eqvts_raw"}, eqvts_raw)
 
 end;