--- a/Nominal/Nominal2.thy Thu Dec 15 16:20:11 2011 +0000
+++ b/Nominal/Nominal2.thy Thu Dec 15 16:20:42 2011 +0000
@@ -60,10 +60,10 @@
ML {*
fun get_cnstrs dts =
- map (fn (_, _, _, constrs) => constrs) dts
+ map snd dts
fun get_typed_cnstrs dts =
- flat (map (fn (_, bn, _, constrs) =>
+ flat (map (fn ((bn, _, _), constrs) =>
(map (fn (bn', _, _) => (Binding.name_of bn, Binding.name_of bn')) constrs)) dts)
fun get_cnstr_strs dts =
@@ -94,8 +94,8 @@
fun raw_dts_aux1 (bind, tys, _) =
(raw_bind bind, map (replace_typ ty_ss) tys, NoSyn)
- fun raw_dts_aux2 (ty_args, bind, _, constrs) =
- (ty_args, raw_bind bind, NoSyn, map raw_dts_aux1 constrs)
+ fun raw_dts_aux2 ((bind, ty_args, _), constrs) =
+ ((raw_bind bind, ty_args, NoSyn), map raw_dts_aux1 constrs)
in
map raw_dts_aux2 dts
end
@@ -146,7 +146,7 @@
val thy = Local_Theory.exit_global lthy
val thy_name = Context.theory_name thy
- val dt_names = map (fn (_, s, _, _) => Binding.name_of s) dts
+ val dt_names = map (fn ((s, _, _), _) => Binding.name_of s) dts
val dt_full_names = map (Long_Name.qualify thy_name) dt_names
val dt_full_names' = add_raws dt_full_names
val dts_env = dt_full_names ~~ dt_full_names'
@@ -333,7 +333,7 @@
alpha_bn_rsp @ raw_perm_bn_rsp) lthy5
val _ = trace_msg (K "Defining the quotient types...")
- val qty_descr = map (fn (vs, bind, mx, _) => (vs, bind, mx)) dts
+ val qty_descr = map (fn ((bind, vs, mx), _) => (map fst vs, bind, mx)) dts
val (qty_infos, lthy7) =
let
@@ -527,45 +527,58 @@
section {* Preparing and parsing of the specification *}
+ML {*
+fun parse_spec ctxt ((tname, tvs, mx), constrs) =
+let
+ val tvs' = map (apsnd (Typedecl.read_constraint ctxt)) tvs
+ val constrs' = constrs
+ |> map (fn (c, Ts, mx', bns) => (c, map ((Syntax.parse_typ ctxt) o snd) Ts, mx'))
+in
+ ((tname, tvs', mx), constrs')
+end
+
+fun check_specs ctxt specs =
+ let
+ fun prep_spec ((tname, args, mx), constrs) tys =
+ let
+ val (args', tys1) = chop (length args) tys;
+ val (constrs', tys3) = (constrs, tys1) |-> fold_map (fn (cname, cargs, mx') => fn tys2 =>
+ let val (cargs', tys3) = chop (length cargs) tys2;
+ in ((cname, cargs', mx'), tys3) end);
+ in
+ (((tname, map dest_TFree args', mx), constrs'), tys3)
+ end
+
+ val all_tys =
+ specs |> maps (fn ((_, args, _), cs) => map TFree args @ maps #2 cs)
+ |> Syntax.check_typs ctxt;
+
+ in
+ #1 (fold_map prep_spec specs all_tys)
+ end
+*}
+
ML {*
(* generates the parsed datatypes and
declares the constructors
*)
fun prepare_dts dt_strs thy =
let
- fun inter_fs_sort thy (a, S) =
- (a, Type.inter_sort (Sign.tsig_of thy) (@{sort fs}, S))
-
- fun mk_type tname sorts (cname, cargs, mx) =
- let
- val full_tname = Sign.full_name thy tname
- val ty = Type (full_tname, map (TFree o inter_fs_sort thy) sorts)
- in
- (cname, cargs ---> ty, mx)
- end
+ val ctxt = Proof_Context.init_global thy
+ |> fold (fn ((_, args, _), _) => fold (fn (a, _) =>
+ Variable.declare_typ (TFree (a, dummyS))) args) dt_strs
+
+ val dts = check_specs ctxt (map (parse_spec ctxt) dt_strs)
+
+ fun mk_constr_trms ((tname, tvs, _), constrs) =
+ let
+ val full_tname = Sign.full_name thy tname
+ val ty = Type (full_tname, map TFree tvs)
+ in
+ map (fn (c, tys, mx) => (c, tys ---> ty, mx)) constrs
+ end
- fun prep_constr (cname, cargs, mx, _) (constrs, sorts) =
- let
- val (cargs', sorts') =
- fold_map (Datatype.read_typ thy) (map snd cargs) sorts
- |>> map (map_type_tfree (TFree o inter_fs_sort thy))
- in
- (constrs @ [(cname, cargs', mx)], sorts')
- end
-
- fun prep_dts (tvs, tname, mx, constrs) (constr_trms, dts, sorts) =
- let
-
- val (constrs', sorts') =
- fold prep_constr constrs ([], sorts)
-
- val constr_trms' =
- map (mk_type tname (rev sorts')) constrs'
- in
- (constr_trms @ constr_trms', dts @ [(tvs, tname, mx, constrs')], sorts')
- end
-
- val (constr_trms, dts, sorts) = fold prep_dts dt_strs ([], [], []);
+ val constr_trms = flat (map mk_constr_trms dts)
in
thy
|> Sign.add_consts_i constr_trms
@@ -681,7 +694,7 @@
fun nominal_datatype2_cmd (opt_thms_name, dt_strs, bn_fun_strs, bn_eq_strs) lthy =
let
val pre_typs =
- map (fn (tvs, tname, mx, _) => (tname, length tvs, mx)) dt_strs
+ map (fn ((tname, tvs, mx), _) => (tname, length tvs, mx)) dt_strs
(* this theory is used just for parsing *)
val thy = Proof_Context.theory_of lthy
@@ -707,6 +720,7 @@
structure S = Scan
fun triple ((x, y), z) = (x, y, z)
+ fun triple2 ((x, y), z) = (y, x, z)
fun tuple1 ((x, y, z), u) = (x, y, z, u)
fun tuple2 (((x, y), z), u) = (x, y, u, z)
fun tuple3 ((x, y), (z, u)) = (x, y, z, u)
@@ -730,8 +744,8 @@
(* datatype parser *)
val dt_parser =
- (P.type_args -- P.binding -- P.opt_mixfix >> triple) --
- (P.$$$ "=" |-- P.enum1 "|" cnstr_parser) >> tuple1
+ (P.type_args_constrained -- P.binding -- P.opt_mixfix >> triple2) --
+ (P.$$$ "=" |-- P.enum1 "|" cnstr_parser)
(* binding function parser *)
val bnfun_parser =
@@ -748,11 +762,6 @@
(main_parser >> nominal_datatype2_cmd)
*}
-(*
-ML {*
-trace := true
-*}
-*)
end