86 fun replace_typ ty_ss (Type (a, Ts)) = Type (replace_str ty_ss a, map (replace_typ ty_ss) Ts) |
85 fun replace_typ ty_ss (Type (a, Ts)) = Type (replace_str ty_ss a, map (replace_typ ty_ss) Ts) |
87 | replace_typ ty_ss T = T |
86 | replace_typ ty_ss T = T |
88 |
87 |
89 fun raw_dts ty_ss dts = |
88 fun raw_dts ty_ss dts = |
90 let |
89 let |
91 val ty_ss' = ty_ss ~~ (add_raws ty_ss) |
|
92 |
90 |
93 fun raw_dts_aux1 (bind, tys, mx) = |
91 fun raw_dts_aux1 (bind, tys, mx) = |
94 (raw_bind bind, map (replace_typ ty_ss') tys, mx) |
92 (raw_bind bind, map (replace_typ ty_ss) tys, mx) |
95 |
93 |
96 fun raw_dts_aux2 (ty_args, bind, mx, constrs) = |
94 fun raw_dts_aux2 (ty_args, bind, mx, constrs) = |
97 (ty_args, raw_bind bind, mx, map raw_dts_aux1 constrs) |
95 (ty_args, raw_bind bind, mx, map raw_dts_aux1 constrs) |
98 in |
96 in |
99 map raw_dts_aux2 dts |
97 map raw_dts_aux2 dts |
106 fun replace_term trm_ss ty_ss trm = |
104 fun replace_term trm_ss ty_ss trm = |
107 trm |> Term.map_aterms (replace_aterm trm_ss) |> map_types (replace_typ ty_ss) |
105 trm |> Term.map_aterms (replace_aterm trm_ss) |> map_types (replace_typ ty_ss) |
108 *} |
106 *} |
109 |
107 |
110 ML {* |
108 ML {* |
111 fun get_constrs dts = |
109 fun get_cnstrs dts = |
112 flat (map (fn (_, _, _, constrs) => constrs) dts) |
110 map (fn (_, _, _, constrs) => constrs) dts |
113 |
111 |
114 fun get_typed_constrs dts = |
112 fun get_typed_cnstrs dts = |
115 flat (map (fn (_, bn, _, constrs) => |
113 flat (map (fn (_, bn, _, constrs) => |
116 (map (fn (bn', _, _) => (Binding.name_of bn, Binding.name_of bn')) constrs)) dts) |
114 (map (fn (bn', _, _) => (Binding.name_of bn, Binding.name_of bn')) constrs)) dts) |
117 |
115 |
118 fun get_constr_strs dts = |
116 fun get_cnstr_strs dts = |
119 map (fn (bn, _, _) => Binding.name_of bn) (get_constrs dts) |
117 map (fn (bn, _, _) => Binding.name_of bn) (flat (get_cnstrs dts)) |
120 |
118 |
121 fun get_bn_fun_strs bn_funs = |
119 fun get_bn_fun_strs bn_funs = |
122 map (fn (bn_fun, _, _) => Binding.name_of bn_fun) bn_funs |
120 map (fn (bn_fun, _, _) => Binding.name_of bn_fun) bn_funs |
123 *} |
121 *} |
124 |
122 |
125 ML {* |
123 ML {* |
126 fun raw_dts_decl dt_names dts lthy = |
124 fun rawify_dts dt_names dts dts_env = |
127 let |
125 let |
128 val thy = ProofContext.theory_of lthy |
126 val raw_dts = raw_dts dts_env dts |
129 val dt_full_names = map (Sign.full_bname thy) dt_names |
|
130 val raw_dts = raw_dts dt_full_names dts |
|
131 val raw_dt_names = add_raws dt_names |
127 val raw_dt_names = add_raws dt_names |
132 in |
128 in |
133 (raw_dt_names, raw_dts) |
129 (raw_dt_names, raw_dts) |
134 end |
130 end |
135 *} |
131 *} |
136 |
132 |
137 ML {* |
133 ML {* |
138 fun raw_bn_fun_decl dt_names dts bn_funs bn_eqs lthy = |
134 fun rawify_bn_funs dts_env cnstrs_env bn_fun_env bn_funs bn_eqs = |
|
135 let |
|
136 val bn_funs' = map (fn (bn, ty, mx) => |
|
137 (raw_bind bn, replace_typ dts_env ty, mx)) bn_funs |
|
138 |
|
139 val bn_eqs' = map (fn (attr, trm) => |
|
140 (attr, replace_term (cnstrs_env @ bn_fun_env) dts_env trm)) bn_eqs |
|
141 in |
|
142 (bn_funs', bn_eqs') |
|
143 end |
|
144 *} |
|
145 |
|
146 ML {* |
|
147 fun add_primrec_wrapper funs eqs lthy = |
|
148 if null funs then (([], []), lthy) |
|
149 else |
|
150 let |
|
151 val eqs' = map (fn (_, eq) => (Attrib.empty_binding, eq)) eqs |
|
152 val funs' = map (fn (bn, ty, mx) => (bn, SOME ty, mx)) funs |
|
153 in |
|
154 Primrec.add_primrec funs' eqs' lthy |
|
155 end |
|
156 *} |
|
157 |
|
158 |
|
159 ML {* |
|
160 fun nominal_datatype2 dts bn_funs bn_eqs binds lthy = |
139 let |
161 let |
140 val thy = ProofContext.theory_of lthy |
162 val thy = ProofContext.theory_of lthy |
141 |
163 val thy_name = Context.theory_name thy |
142 val dt_names' = add_raws dt_names |
164 |
143 val dt_full_names = map (Sign.full_bname thy) dt_names |
165 val conf = Datatype.default_config |
144 val dt_full_names' = map (Sign.full_bname thy) dt_names' |
166 |
145 val dt_env = dt_full_names ~~ dt_full_names' |
167 val dt_names = map (fn (_, s, _, _) => Binding.name_of s) dts |
146 |
168 val dt_full_names = map (Long_Name.qualify thy_name) dt_names |
147 val ctrs_names = map (Sign.full_bname thy) (get_constr_strs dts) |
169 val dt_full_names' = add_raws dt_full_names |
148 val ctrs_names' = map (fn (x, y) => (Sign.full_bname_path thy (add_raw x) (add_raw y))) |
170 val dts_env = dt_full_names ~~ dt_full_names' |
149 (get_typed_constrs dts) |
171 |
150 val ctrs_env = ctrs_names ~~ ctrs_names' |
172 val cnstrs = get_cnstr_strs dts |
|
173 val cnstrs_ty = get_typed_cnstrs dts |
|
174 val cnstrs_full_names = map (Long_Name.qualify thy_name) cnstrs |
|
175 val cnstrs_full_names' = map (fn (x, y) => Long_Name.qualify thy_name |
|
176 (Long_Name.qualify (add_raw x) (add_raw y))) cnstrs_ty |
|
177 val cnstrs_env = cnstrs_full_names ~~ cnstrs_full_names' |
151 |
178 |
152 val bn_fun_strs = get_bn_fun_strs bn_funs |
179 val bn_fun_strs = get_bn_fun_strs bn_funs |
153 val bn_fun_strs' = add_raws bn_fun_strs |
180 val bn_fun_strs' = add_raws bn_fun_strs |
154 val bn_fun_env = bn_fun_strs ~~ bn_fun_strs' |
181 val bn_fun_env = bn_fun_strs ~~ bn_fun_strs' |
155 |
182 |
156 val bn_funs' = map (fn (bn, ty, mx) => |
183 val (raw_dt_names, raw_dts) = rawify_dts dt_names dts dts_env |
157 (raw_bind bn, SOME (replace_typ dt_env ty), mx)) bn_funs |
184 |
158 |
185 val (raw_bn_funs, raw_bn_eqs) = rawify_bn_funs dts_env cnstrs_env bn_fun_env bn_funs bn_eqs |
159 val bn_eqs' = map (fn (_, trm) => |
186 in |
160 (Attrib.empty_binding, replace_term (ctrs_env @ bn_fun_env) dt_env trm)) bn_eqs |
187 lthy |
161 in |
188 |> Local_Theory.theory_result (Datatype.add_datatype conf raw_dt_names raw_dts) |
162 if null bn_eqs |
189 ||>> add_primrec_wrapper raw_bn_funs raw_bn_eqs |
163 then (([], []), lthy) |
190 end |
164 else Primrec.add_primrec bn_funs' bn_eqs' lthy |
191 *} |
165 end |
|
166 *} |
|
167 |
|
168 ML {* |
|
169 fun nominal_datatype2 dts bn_funs bn_eqs binds lthy = |
|
170 let |
|
171 val conf = Datatype.default_config |
|
172 val dts_names = map (fn (_, s, _, _) => Binding.name_of s) dts |
|
173 |
|
174 val (raw_dt_names, raw_dts) = raw_dts_decl dts_names dts lthy |
|
175 |
|
176 in |
|
177 lthy |
|
178 |> Local_Theory.theory_result (Datatype.add_datatype conf raw_dt_names raw_dts) |
|
179 ||>> raw_bn_fun_decl dts_names dts bn_funs bn_eqs |
|
180 end |
|
181 *} |
|
182 |
|
183 |
192 |
184 ML {* |
193 ML {* |
185 (* parsing the datatypes and declaring *) |
194 (* parsing the datatypes and declaring *) |
186 (* constructors in the local theory *) |
195 (* constructors in the local theory *) |
187 fun prepare_dts dt_strs lthy = |
196 fun prepare_dts dt_strs lthy = |
259 case AList.lookup (op =) xs x of |
266 case AList.lookup (op =) xs x of |
260 SOME x => x |
267 SOME x => x |
261 | NONE => error ("cannot find " ^ x ^ " in the binding specification."); |
268 | NONE => error ("cannot find " ^ x ^ " in the binding specification."); |
262 *} |
269 *} |
263 |
270 |
264 |
|
265 |
|
266 ML {* |
271 ML {* |
267 fun prepare_binds dt_strs lthy = |
272 fun prepare_binds dt_strs lthy = |
268 let |
273 let |
269 fun extract_annos_binds dt_strs = |
274 fun extract_annos_binds dt_strs = |
270 map ((map (fn (_, antys, _, bns) => (map fst antys, bns))) o forth) dt_strs |
275 map (map (fn (_, antys, _, bns) => (map fst antys, bns))) dt_strs |
271 |
276 |
272 fun prep_bn env bn_str = |
277 fun prep_bn env bn_str = |
273 case (Syntax.read_term lthy bn_str) of |
278 case (Syntax.read_term lthy bn_str) of |
274 Free (x, _) => (env_lookup env x, NONE) |
279 Free (x, _) => (NONE, env_lookup env x) |
275 | Const (a, T) $ Free (x, _) => (env_lookup env x, SOME (Const (a, T))) |
280 | Const (a, T) $ Free (x, _) => (SOME (Const (a, T)), env_lookup env x) |
276 | _ => error (bn_str ^ " not allowed as binding specification."); |
281 | _ => error (bn_str ^ " not allowed as binding specification."); |
277 |
282 |
278 fun prep_typ env opt_name = |
283 fun prep_typ env (i, opt_name) = |
279 case opt_name of |
284 case opt_name of |
280 NONE => [] |
285 NONE => [] |
281 | SOME x => find_all (op=) env x; |
286 | SOME x => find_all (op=) env (x,i); |
282 |
287 |
283 (* annos - list of annotation for each type (either NONE or SOME fo a type *) |
288 (* annos - list of annotation for each type (either NONE or SOME fo a type *) |
284 |
289 |
285 fun prep_binds (annos, bind_strs) = |
290 fun prep_binds (annos, bind_strs) = |
286 let |
291 let |
287 val env = mk_env annos (* for ever label the index *) |
292 val env = mk_env annos (* for every label the index *) |
288 val binds = map (fn (x, y) => (x, prep_bn env y)) bind_strs |
293 val binds = map (fn (x, y) => (x, prep_bn env y)) bind_strs |
289 in |
294 in |
290 map (prep_typ binds) annos |
295 map_index (prep_typ binds) annos |
291 end |
296 end |
292 |
297 |
293 in |
298 in |
294 map (map prep_binds) (extract_annos_binds dt_strs) |
299 map (map prep_binds) (extract_annos_binds (get_cnstrs dt_strs)) |
295 end |
300 end |
296 *} |
301 *} |
297 |
302 |
298 ML {* |
303 ML {* |
299 fun nominal_datatype2_cmd (dt_strs, bn_fun_strs, bn_eq_strs) lthy = |
304 fun nominal_datatype2_cmd (dt_strs, bn_fun_strs, bn_eq_strs) lthy = |