Nominal-General/nominal_eqvt.ML
author Christian Urban <urbanc@in.tum.de>
Mon, 07 Jun 2010 16:17:35 +0200
changeset 2216 1a9dbfe04f7d
parent 2168 ce0255ffaeb4
child 2306 86c977b4a9bb
permissions -rw-r--r--
new title for POPL paper

(*  Title:      nominal_eqvt.ML
    Author:     Stefan Berghofer (original code)
    Author:     Christian Urban

    Automatic proofs for equivariance of inductive predicates.
*)

signature NOMINAL_EQVT =
sig
  (* whole-relation equivariance tactic; arguments: context, names of the
     inductive predicates, permutation term, induction rule, introduction
     rules, subgoal index *)
  val eqvt_rel_tac: Proof.context -> string list -> term -> thm -> thm list -> int -> tactic
  (* tactic for the single induction case belonging to one introduction rule *)
  val eqvt_rel_single_case_tac: Proof.context -> string list -> term -> thm -> int -> tactic
  
  (* proves and stores (as <pred>.eqvt with [eqvt]) equivariance theorems for
     the given predicate constants, from their raw induction rule and
     introduction rules *)
  val equivariance: term list -> thm -> thm list -> Proof.context -> thm list * local_theory
  (* command entry point: looks up an inductive predicate by name and runs
     equivariance on it *)
  val equivariance_cmd: string -> Proof.context -> local_theory
end

structure Nominal_Eqvt : NOMINAL_EQVT =
struct

open Nominal_Permeq;
open Nominal_ThmDecls;

(* Conversion that rewrites a cterm with the induct_atomize rules on top of
   HOL_basic_ss, using a solver that never discharges side conditions.
   Presumably this turns meta-level connectives into object-level ones;
   confirm against HOL's induct_atomize. *)
val atomize_conv =
  let
    val atomize_ss = HOL_basic_ss addsimps @{thms induct_atomize}
  in
    MetaSimplifier.rewrite_cterm (true, false, false) (fn _ => fn _ => NONE) atomize_ss
  end;

(* applies atomize_conv to every premise of a rule *)
val atomize_intr = Conv.fconv_rule (Conv.prems_conv ~1 atomize_conv);

(* applies atomize_conv to the premises nested under the parameters of each
   premise of an induction rule, i.e. to the assumptions of its cases *)
fun atomize_induct ctxt =
  let
    val atomize_case_prems = Conv.prems_conv ~1 atomize_conv
    val atomize_case = Conv.params_conv ~1 (fn _ => atomize_case_prems) ctxt
  in
    Conv.fconv_rule (Conv.prems_conv ~1 atomize_case)
  end;


(*
 map_term f t applies the partial rewrite f throughout t:
   - where f returns SOME u, that subterm is replaced by u and not
     traversed further;
   - where f returns NONE, map_term descends into applications and
     abstraction bodies.
 The result is NONE if f fired nowhere, otherwise SOME of the rewritten term.

 Used to turn a theorem F[t] into the goal F[f t]. For the goal to be
 provable from the theorem (see map_thm), F needs to be monotone.
*)
fun map_term f t =
  (case f t of
     SOME t' => SOME t'
   | NONE => map_term' f t)
and map_term' f (t $ u) =
      (case (map_term f t, map_term f u) of
         (NONE, NONE) => NONE
       | (fun_opt, arg_opt) => SOME (the_default t fun_opt $ the_default u arg_opt))
  | map_term' f (Abs (s, T, body)) =
      Option.map (fn body' => Abs (s, T, body')) (map_term f body)
  | map_term' _ _ = NONE;

(* Tactic deriving subgoal 1 from thm. In map_thm, the goal is the
   proposition of thm rewritten with map_term.
   Steps:
     - insert thm as a premise (cut_facts_tac) and turn it into an
       implication (etac rev_mp);
     - repeatedly simplify with split_def and resolve with the monotonicity
       rules from Inductive.get_monos, presumably to push the implication
       inward;
     - close each remaining implication by impI followed by either
       assumption or the caller-supplied tac. *)
fun map_thm_tac ctxt tac thm =
let
  val monos = Inductive.get_monos ctxt
  val simps = HOL_basic_ss addsimps @{thms split_def}
in
  EVERY [cut_facts_tac [thm] 1, etac rev_mp 1, 
    REPEAT_DETERM (FIRSTGOAL (simp_tac simps THEN' resolve_tac monos)),
    REPEAT_DETERM (rtac impI 1 THEN (atac 1 ORELSE tac))]
end

(* Proves the theorem whose proposition is thm's proposition rewritten by
   map_term f, using map_thm_tac with the extra tactic tac. If f fires
   nowhere, thm itself is returned unchanged. *)
fun map_thm ctxt f tac thm =
  (case map_term f (prop_of thm) of
     SOME goal => Goal.prove ctxt [] [] goal (fn _ => map_thm_tac ctxt tac thm)
   | NONE => thm)

(*
 Inductive premises can be of the form "R ... & P ...", where the head
 constant of the left conjunct is one of the predicate names in names.
 Every such conjunction found by map_term is replaced by its right part
 "P ...", and the result is proved from thm using conjunct2. If no such
 conjunction occurs, thm is returned unchanged.
*)
fun transform_prem ctxt names thm =
let
  fun is_named_pred trm =
    (case head_of trm of
       Const (c, _) => member (op =) names c
     | _ => false)
  fun drop_pred_conj (Const ("op &", _) $ lhs $ rhs) =
        if is_named_pred lhs then SOME rhs else NONE
    | drop_pred_conj _ = NONE
in
  map_thm ctxt drop_pred_conj (etac conjunct2 1) thm
end


(** equivariance tactics **)

(* permute_boolE; eqvt_rel_single_case_tac instantiates it with the inverse
   permutation (mk_minus pi) *)
val perm_boolE = @{thm permute_boolE}
(* cancellation rule(s), passed as extra theorems to eqvt_strict_tac in
   eqvt_rel_single_case_tac *)
val perm_cancel = @{thms permute_minus_cancel(2)}

(* Tactic for the induction case belonging to the introduction rule intro.

   The subgoal is first rewritten with eqvt_strict_tac. pred_names is passed
   as its string-list argument; presumably these constants are left alone,
   which should be confirmed against Nominal_Permeq.

   Then, inside a SUBPROOF, intro is applied, and each resulting subgoal is
   closed by the first of these that succeeds:
     tac1 - resolution with a premise;
     tac2 - rtac with permute_boolE instantiated to -pi, then
            eqvt_strict_tac with the cancellation rules, then a premise;
     tac3 - as tac2, but additionally simplifying with permute_fun_def,
            minus_minus and split_paired_all, then with permute_bool_def,
            before resolving with a premise.

   The premises used are first passed through transform_prem, so
   "R ... & P ..." premises (R in pred_names) are reduced to "P ...". *)
fun eqvt_rel_single_case_tac ctxt pred_names pi intro  = 
let
  val thy = ProofContext.theory_of ctxt
  val cpi = Thm.cterm_of thy (mk_minus pi)
  val pi_intro_rule = Drule.instantiate' [] [SOME cpi] perm_boolE
  val simps1 = HOL_basic_ss addsimps @{thms permute_fun_def minus_minus split_paired_all}
  val simps2 = HOL_basic_ss addsimps @{thms permute_bool_def}
in
  eqvt_strict_tac ctxt [] pred_names THEN'
  (* the SUBPROOF's focused context is bound to ctxt, shadowing the outer one *)
  SUBPROOF (fn {prems, context as ctxt, ...} =>
    let
      val prems' = map (transform_prem ctxt pred_names) prems
      val tac1 = resolve_tac prems'
      val tac2 = EVERY' [ rtac pi_intro_rule, 
            eqvt_strict_tac ctxt perm_cancel pred_names, resolve_tac prems' ]
      val tac3 = EVERY' [ rtac pi_intro_rule, 
            eqvt_strict_tac ctxt perm_cancel pred_names, simp_tac simps1, 
            simp_tac simps2, resolve_tac prems']
    in
      (rtac intro THEN_ALL_NEW FIRST' [tac1, tac2, tac3]) 1 
    end) ctxt
end

(* Equivariance tactic for the whole relation. It applies the induction rule
   induct, then one eqvt_rel_single_case_tac per introduction rule in intros.
   EVERY' runs these tactics in sequence on the same subgoal index. *)
fun eqvt_rel_tac ctxt pred_names pi induct intros =
  EVERY' (rtac induct :: map (eqvt_rel_single_case_tac ctxt pred_names pi) intros)


(** equivariance procedure *)

(* For pred = c x1 ... xn, builds the HOL implication
     pred --> c (mk_perm pi x1) ... (mk_perm pi xn),
   i.e. every argument of the predicate is permuted by pi.
   PROBLEM (author's note): this violates the form of eqvt lemmas. *)
fun prepare_goal pi pred =
let
  val (head, args) = strip_comb pred
  val permuted_pred = list_comb (head, map (mk_perm pi) args)
in
  HOLogic.mk_imp (pred, permuted_pred)
end

(* Stores thm under "<base>.eqvt", where <base> is the base name of the
   constant name, with the [eqvt] attribute (eqvt_add) attached.
   Returns the stored theorem together with the updated context. *)
fun note_named_thm (name, thm) ctxt = 
let
  val base = Long_Name.base_name name
  val binding = Binding.qualified_name (Long_Name.qualify base "eqvt")
  val eqvt_attrib = Attrib.internal (K eqvt_add)
  val ((_, [noted_thm]), ctxt') =
    Local_Theory.note ((binding, [eqvt_attrib]), [thm]) ctxt
in
  (noted_thm, ctxt')
end

(* Proves equivariance of the inductive predicates pred_trms (constants),
   given their raw induction rule and introduction rules. Each resulting
   theorem is stored via note_named_thm as <pred>.eqvt with the [eqvt]
   attribute.

   Steps:
     - fail with an error if any predicate is already equivariant;
     - atomize the induction and introduction rules;
     - fix a fresh permutation variable p (of type perm);
     - for each "pred --> _" conjunct in the induction rule's conclusion,
       build a prepare_goal implication and conjoin them all;
     - prove the conjunction with eqvt_rel_tac;
     - export it back to ctxt, split it into conjuncts, and turn each
       implication into a rule via mp.

   Returns the stored theorems and the updated local theory.
   NOTE(review): pairing with pred_names (~~) assumes the conjuncts appear
   in the same order as pred_trms. *)
fun equivariance pred_trms raw_induct intrs ctxt = 
let
  val is_already_eqvt = 
    filter (is_eqvt ctxt) pred_trms
    |> map (Syntax.string_of_term ctxt)
  val _ = if null is_already_eqvt then ()
    else error ("Already equivariant: " ^ commas is_already_eqvt)

  val pred_names = map (fst o dest_Const) pred_trms
  val raw_induct' = atomize_induct ctxt raw_induct
  val intrs' = map atomize_intr intrs
  (* import the induction conclusion into a context that also fixes a fresh "p" *)
  val (([raw_concl], [raw_pi]), ctxt') = 
    ctxt 
    |> Variable.import_terms false [concl_of raw_induct'] 
    ||>> Variable.variant_fixes ["p"]
  val pi = Free (raw_pi, @{typ perm})
  (* left-hand sides of the implications in the conjunctive conclusion *)
  val preds = map (fst o HOLogic.dest_imp)
    (HOLogic.dest_conj (HOLogic.dest_Trueprop raw_concl));
  val goal = HOLogic.mk_Trueprop 
    (foldr1 HOLogic.mk_conj (map (prepare_goal pi) preds))
  val thms = Datatype_Aux.split_conj_thm (Goal.prove ctxt' [] [] goal 
    (fn {context,...} => eqvt_rel_tac context pred_names pi raw_induct' intrs' 1)
    |> singleton (ProofContext.export ctxt' ctxt))
  val thms' = map (fn th => zero_var_indexes (th RS mp)) thms
in
  ctxt |> fold_map note_named_thm (pred_names ~~ thms')   
end

(* Looks up the inductive predicate pred_name, interned as a constant name
   of the current theory, and proves its equivariance. Only the updated
   local theory is returned; the theorems stay stored by equivariance. *)
fun equivariance_cmd pred_name ctxt =
let
  val const_name = Sign.intern_const (ProofContext.theory_of ctxt) pred_name
  val (_, {preds, raw_induct, intrs, ...}) = Inductive.the_inductive ctxt const_name
in
  snd (equivariance preds raw_induct intrs ctxt)
end

(* Registers the outer-syntax command "equivariance <name>" as a thy_decl
   local-theory command that calls equivariance_cmd. *)
local structure P = Parse and K = Keyword in

val _ =
  Outer_Syntax.local_theory "equivariance"
    "Proves equivariance for inductive predicate involving nominal datatypes." 
      K.thy_decl (P.xname >> equivariance_cmd);
end;

end (* structure *)