diff -r ade2f8fcf8e8 -r 51ef8a3cb6ef Nominal/Nominal2.thy --- a/Nominal/Nominal2.thy Thu Dec 15 16:20:11 2011 +0000 +++ b/Nominal/Nominal2.thy Thu Dec 15 16:20:42 2011 +0000 @@ -60,10 +60,10 @@ ML {* fun get_cnstrs dts = - map (fn (_, _, _, constrs) => constrs) dts + map snd dts fun get_typed_cnstrs dts = - flat (map (fn (_, bn, _, constrs) => + flat (map (fn ((bn, _, _), constrs) => (map (fn (bn', _, _) => (Binding.name_of bn, Binding.name_of bn')) constrs)) dts) fun get_cnstr_strs dts = @@ -94,8 +94,8 @@ fun raw_dts_aux1 (bind, tys, _) = (raw_bind bind, map (replace_typ ty_ss) tys, NoSyn) - fun raw_dts_aux2 (ty_args, bind, _, constrs) = - (ty_args, raw_bind bind, NoSyn, map raw_dts_aux1 constrs) + fun raw_dts_aux2 ((bind, ty_args, _), constrs) = + ((raw_bind bind, ty_args, NoSyn), map raw_dts_aux1 constrs) in map raw_dts_aux2 dts end @@ -146,7 +146,7 @@ val thy = Local_Theory.exit_global lthy val thy_name = Context.theory_name thy - val dt_names = map (fn (_, s, _, _) => Binding.name_of s) dts + val dt_names = map (fn ((s, _, _), _) => Binding.name_of s) dts val dt_full_names = map (Long_Name.qualify thy_name) dt_names val dt_full_names' = add_raws dt_full_names val dts_env = dt_full_names ~~ dt_full_names' @@ -333,7 +333,7 @@ alpha_bn_rsp @ raw_perm_bn_rsp) lthy5 val _ = trace_msg (K "Defining the quotient types...") - val qty_descr = map (fn (vs, bind, mx, _) => (vs, bind, mx)) dts + val qty_descr = map (fn ((bind, vs, mx), _) => (map fst vs, bind, mx)) dts val (qty_infos, lthy7) = let @@ -527,45 +527,58 @@ section {* Preparing and parsing of the specification *} +ML {* +fun parse_spec ctxt ((tname, tvs, mx), constrs) = +let + val tvs' = map (apsnd (Typedecl.read_constraint ctxt)) tvs + val constrs' = constrs + |> map (fn (c, Ts, mx', bns) => (c, map ((Syntax.parse_typ ctxt) o snd) Ts, mx')) +in + ((tname, tvs', mx), constrs') +end + +fun check_specs ctxt specs = + let + fun prep_spec ((tname, args, mx), constrs) tys = + let + val (args', tys1) = chop (length args) tys; + val (constrs', tys3) = (constrs, tys1) |-> fold_map (fn (cname, cargs, mx') => fn tys2 => + let val (cargs', tys3) = chop (length cargs) tys2; + in ((cname, cargs', mx'), tys3) end); + in + (((tname, map dest_TFree args', mx), constrs'), tys3) + end + + val all_tys = + specs |> maps (fn ((_, args, _), cs) => map TFree args @ maps #2 cs) + |> Syntax.check_typs ctxt; + + in + #1 (fold_map prep_spec specs all_tys) + end +*} + ML {* (* generates the parsed datatypes and declares the constructors *) fun prepare_dts dt_strs thy = let - fun inter_fs_sort thy (a, S) = - (a, Type.inter_sort (Sign.tsig_of thy) (@{sort fs}, S)) - - fun mk_type tname sorts (cname, cargs, mx) = - let - val full_tname = Sign.full_name thy tname - val ty = Type (full_tname, map (TFree o inter_fs_sort thy) sorts) - in - (cname, cargs ---> ty, mx) - end + val ctxt = Proof_Context.init_global thy + |> fold (fn ((_, args, _), _) => fold (fn (a, _) => + Variable.declare_typ (TFree (a, dummyS))) args) dt_strs + + val dts = check_specs ctxt (map (parse_spec ctxt) dt_strs) + + fun mk_constr_trms ((tname, tvs, _), constrs) = + let + val full_tname = Sign.full_name thy tname + val ty = Type (full_tname, map TFree tvs) + in + map (fn (c, tys, mx) => (c, tys ---> ty, mx)) constrs + end - fun prep_constr (cname, cargs, mx, _) (constrs, sorts) = - let - val (cargs', sorts') = - fold_map (Datatype.read_typ thy) (map snd cargs) sorts - |>> map (map_type_tfree (TFree o inter_fs_sort thy)) - in - (constrs @ [(cname, cargs', mx)], sorts') - end - - fun prep_dts (tvs, tname, mx, constrs) (constr_trms, dts, sorts) = - let - - val (constrs', sorts') = - fold prep_constr constrs ([], sorts) - - val constr_trms' = - map (mk_type tname (rev sorts')) constrs' - in - (constr_trms @ constr_trms', dts @ [(tvs, tname, mx, constrs')], sorts') - end - - val (constr_trms, dts, sorts) = fold prep_dts dt_strs ([], [], []); + val constr_trms = flat (map mk_constr_trms dts) in thy |> Sign.add_consts_i constr_trms @@ -681,7 +694,7 @@ fun nominal_datatype2_cmd (opt_thms_name, dt_strs, bn_fun_strs, bn_eq_strs) lthy = let val pre_typs = - map (fn (tvs, tname, mx, _) => (tname, length tvs, mx)) dt_strs + map (fn ((tname, tvs, mx), _) => (tname, length tvs, mx)) dt_strs (* this theory is used just for parsing *) val thy = Proof_Context.theory_of lthy @@ -707,6 +720,7 @@ structure S = Scan fun triple ((x, y), z) = (x, y, z) + fun triple2 ((x, y), z) = (y, x, z) fun tuple1 ((x, y, z), u) = (x, y, z, u) fun tuple2 (((x, y), z), u) = (x, y, u, z) fun tuple3 ((x, y), (z, u)) = (x, y, z, u) @@ -730,8 +744,8 @@ (* datatype parser *) val dt_parser = - (P.type_args -- P.binding -- P.opt_mixfix >> triple) -- - (P.$$$ "=" |-- P.enum1 "|" cnstr_parser) >> tuple1 + (P.type_args_constrained -- P.binding -- P.opt_mixfix >> triple2) -- + (P.$$$ "=" |-- P.enum1 "|" cnstr_parser) (* binding function parser *) val bnfun_parser = @@ -748,11 +762,6 @@ (main_parser >> nominal_datatype2_cmd) *} -(* -ML {* -trace := true -*} -*) end