Nominal/Nominal2.thy
changeset 3065 51ef8a3cb6ef
parent 3063 32abaea424bd
child 3076 2b1b8404fe0d
--- a/Nominal/Nominal2.thy	Thu Dec 15 16:20:11 2011 +0000
+++ b/Nominal/Nominal2.thy	Thu Dec 15 16:20:42 2011 +0000
@@ -60,10 +60,10 @@
 
 ML {*
 fun get_cnstrs dts =
-  map (fn (_, _, _, constrs) => constrs) dts
+  map snd dts
 
 fun get_typed_cnstrs dts =
-  flat (map (fn (_, bn, _, constrs) => 
+  flat (map (fn ((bn, _, _), constrs) => 
    (map (fn (bn', _, _) => (Binding.name_of bn, Binding.name_of bn')) constrs)) dts)
 
 fun get_cnstr_strs dts =
@@ -94,8 +94,8 @@
   fun raw_dts_aux1 (bind, tys, _) =
     (raw_bind bind, map (replace_typ ty_ss) tys, NoSyn)
 
-  fun raw_dts_aux2 (ty_args, bind, _, constrs) =
-    (ty_args, raw_bind bind, NoSyn, map raw_dts_aux1 constrs)
+  fun raw_dts_aux2 ((bind, ty_args, _), constrs) =
+    ((raw_bind bind, ty_args, NoSyn), map raw_dts_aux1 constrs)
 in
   map raw_dts_aux2 dts
 end
@@ -146,7 +146,7 @@
   val thy = Local_Theory.exit_global lthy
   val thy_name = Context.theory_name thy
 
-  val dt_names = map (fn (_, s, _, _) => Binding.name_of s) dts
+  val dt_names = map (fn ((s, _, _), _) => Binding.name_of s) dts
   val dt_full_names = map (Long_Name.qualify thy_name) dt_names 
   val dt_full_names' = add_raws dt_full_names
   val dts_env = dt_full_names ~~ dt_full_names'
@@ -333,7 +333,7 @@
       alpha_bn_rsp @ raw_perm_bn_rsp) lthy5
 
   val _ = trace_msg (K "Defining the quotient types...")
-  val qty_descr = map (fn (vs, bind, mx, _) => (vs, bind, mx)) dts
+  val qty_descr = map (fn ((bind, vs, mx), _) => (map fst vs, bind, mx)) dts
      
   val (qty_infos, lthy7) = 
     let
@@ -527,45 +527,58 @@
 
 section {* Preparing and parsing of the specification *}
 
+ML {*
+fun parse_spec ctxt ((tname, tvs, mx), constrs) =
+let
+  val tvs' = map (apsnd (Typedecl.read_constraint ctxt)) tvs
+  val constrs' = constrs
+    |> map (fn (c, Ts, mx', bns) => (c, map ((Syntax.parse_typ ctxt) o snd) Ts, mx'))
+in
+  ((tname, tvs', mx), constrs')
+end
+
+fun check_specs ctxt specs =
+  let
+    fun prep_spec ((tname, args, mx), constrs) tys =
+      let
+        val (args', tys1) = chop (length args) tys;
+        val (constrs', tys3) = (constrs, tys1) |-> fold_map (fn (cname, cargs, mx') => fn tys2 =>
+          let val (cargs', tys3) = chop (length cargs) tys2;
+          in ((cname, cargs', mx'), tys3) end);
+      in 
+        (((tname, map dest_TFree args', mx), constrs'), tys3) 
+      end
+
+    val all_tys =
+      specs |> maps (fn ((_, args, _), cs) => map TFree args @ maps #2 cs)
+      |> Syntax.check_typs ctxt;
+
+  in 
+    #1 (fold_map prep_spec specs all_tys) 
+  end
+*}
+
 ML {* 
 (* generates the parsed datatypes and 
    declares the constructors 
 *)
 fun prepare_dts dt_strs thy = 
 let
-  fun inter_fs_sort thy (a, S) = 
-    (a, Type.inter_sort (Sign.tsig_of thy) (@{sort fs}, S)) 
-
-  fun mk_type tname sorts (cname, cargs, mx) =
-  let
-    val full_tname = Sign.full_name thy tname
-    val ty = Type (full_tname, map (TFree o inter_fs_sort thy) sorts)
-  in
-    (cname, cargs ---> ty, mx)
-  end
+  val ctxt = Proof_Context.init_global thy
+    |> fold (fn ((_, args, _), _) => fold (fn (a, _) =>
+         Variable.declare_typ (TFree (a, dummyS))) args) dt_strs
+ 
+  val dts = check_specs ctxt (map (parse_spec ctxt) dt_strs)
+  
+  fun mk_constr_trms ((tname, tvs, _), constrs) =
+    let
+      val full_tname = Sign.full_name thy tname
+      val ty = Type (full_tname, map TFree tvs)
+    in
+      map (fn (c, tys, mx) => (c, tys ---> ty, mx)) constrs
+    end 
 
-  fun prep_constr (cname, cargs, mx, _) (constrs, sorts) =
-  let 
-    val (cargs', sorts') = 
-      fold_map (Datatype.read_typ thy) (map snd cargs) sorts
-      |>> map (map_type_tfree (TFree o inter_fs_sort thy)) 
-  in 
-    (constrs @ [(cname, cargs', mx)], sorts') 
-  end
-  
-  fun prep_dts (tvs, tname, mx, constrs) (constr_trms, dts, sorts) =
-  let 
-
-    val (constrs', sorts') = 
-      fold prep_constr constrs ([], sorts)
-
-    val constr_trms' = 
-      map (mk_type tname (rev sorts')) constrs'
-  in 
-    (constr_trms @ constr_trms', dts @ [(tvs, tname, mx, constrs')], sorts') 
-  end
-
-  val (constr_trms, dts, sorts) = fold prep_dts dt_strs ([], [], []);
+  val constr_trms = flat (map mk_constr_trms dts)
 in
   thy
   |> Sign.add_consts_i constr_trms
@@ -681,7 +694,7 @@
 fun nominal_datatype2_cmd (opt_thms_name, dt_strs, bn_fun_strs, bn_eq_strs) lthy = 
 let
   val pre_typs = 
-    map (fn (tvs, tname, mx, _) => (tname, length tvs, mx)) dt_strs
+    map (fn ((tname, tvs, mx), _) => (tname, length tvs, mx)) dt_strs
 
   (* this theory is used just for parsing *)
   val thy = Proof_Context.theory_of lthy  
@@ -707,6 +720,7 @@
   structure S = Scan
 
   fun triple ((x, y), z) = (x, y, z)
+  fun triple2 ((x, y), z) = (y, x, z)
   fun tuple1 ((x, y, z), u) = (x, y, z, u)
   fun tuple2 (((x, y), z), u) = (x, y, u, z)
   fun tuple3 ((x, y), (z, u)) = (x, y, z, u)
@@ -730,8 +744,8 @@
 
 (* datatype parser *)
 val dt_parser =
-  (P.type_args -- P.binding -- P.opt_mixfix >> triple) -- 
-    (P.$$$ "=" |-- P.enum1 "|" cnstr_parser) >> tuple1
+  (P.type_args_constrained -- P.binding -- P.opt_mixfix >> triple2) -- 
+    (P.$$$ "=" |-- P.enum1 "|" cnstr_parser)
 
 (* binding function parser *)
 val bnfun_parser = 
@@ -748,11 +762,6 @@
   (main_parser >> nominal_datatype2_cmd)
 *}
 
-(*
-ML {*
-trace := true
-*}
-*)
 
 end