86 fun replace_typ ty_ss (Type (a, Ts)) = Type (replace_str ty_ss a, map (replace_typ ty_ss) Ts) |
86 fun replace_typ ty_ss (Type (a, Ts)) = Type (replace_str ty_ss a, map (replace_typ ty_ss) Ts) |
87 | replace_typ ty_ss T = T |
87 | replace_typ ty_ss T = T |
88 |
88 |
89 fun raw_dts ty_ss dts = |
89 fun raw_dts ty_ss dts = |
90 let |
90 let |
91 val ty_ss' = ty_ss ~~ (add_raws ty_ss) |
|
92 |
91 |
93 fun raw_dts_aux1 (bind, tys, mx) = |
92 fun raw_dts_aux1 (bind, tys, mx) = |
94 (raw_bind bind, map (replace_typ ty_ss') tys, mx) |
93 (raw_bind bind, map (replace_typ ty_ss) tys, mx) |
95 |
94 |
96 fun raw_dts_aux2 (ty_args, bind, mx, constrs) = |
95 fun raw_dts_aux2 (ty_args, bind, mx, constrs) = |
97 (ty_args, raw_bind bind, mx, map raw_dts_aux1 constrs) |
96 (ty_args, raw_bind bind, mx, map raw_dts_aux1 constrs) |
98 in |
97 in |
99 map raw_dts_aux2 dts |
98 map raw_dts_aux2 dts |
106 fun replace_term trm_ss ty_ss trm = |
105 fun replace_term trm_ss ty_ss trm = |
107 trm |> Term.map_aterms (replace_aterm trm_ss) |> map_types (replace_typ ty_ss) |
106 trm |> Term.map_aterms (replace_aterm trm_ss) |> map_types (replace_typ ty_ss) |
108 *} |
107 *} |
109 |
108 |
110 ML {* |
109 ML {* |
111 fun get_constrs dts = |
110 fun get_cnstrs dts = |
112 flat (map (fn (_, _, _, constrs) => constrs) dts) |
111 flat (map (fn (_, _, _, constrs) => constrs) dts) |
113 |
112 |
114 fun get_typed_constrs dts = |
113 fun get_typed_cnstrs dts = |
115 flat (map (fn (_, bn, _, constrs) => |
114 flat (map (fn (_, bn, _, constrs) => |
116 (map (fn (bn', _, _) => (Binding.name_of bn, Binding.name_of bn')) constrs)) dts) |
115 (map (fn (bn', _, _) => (Binding.name_of bn, Binding.name_of bn')) constrs)) dts) |
117 |
116 |
118 fun get_constr_strs dts = |
117 fun get_cnstr_strs dts = |
119 map (fn (bn, _, _) => Binding.name_of bn) (get_constrs dts) |
118 map (fn (bn, _, _) => Binding.name_of bn) (get_cnstrs dts) |
120 |
119 |
121 fun get_bn_fun_strs bn_funs = |
120 fun get_bn_fun_strs bn_funs = |
122 map (fn (bn_fun, _, _) => Binding.name_of bn_fun) bn_funs |
121 map (fn (bn_fun, _, _) => Binding.name_of bn_fun) bn_funs |
123 *} |
122 *} |
124 |
123 |
125 ML {* |
124 ML {* |
126 fun raw_dts_decl dt_names dts lthy = |
125 fun rawify_dts dt_names dts dts_env = |
127 let |
126 let |
128 val thy = ProofContext.theory_of lthy |
127 val raw_dts = raw_dts dts_env dts |
129 val dt_full_names = map (Sign.full_bname thy) dt_names |
|
130 val raw_dts = raw_dts dt_full_names dts |
|
131 val raw_dt_names = add_raws dt_names |
128 val raw_dt_names = add_raws dt_names |
132 in |
129 in |
133 (raw_dt_names, raw_dts) |
130 (raw_dt_names, raw_dts) |
134 end |
131 end |
135 *} |
132 *} |
136 |
133 |
137 ML {* |
134 ML {* |
138 fun raw_bn_fun_decl dt_names dts bn_funs bn_eqs lthy = |
135 fun rawify_bn_funs dts_env cnstrs_env bn_fun_env bn_funs bn_eqs = |
|
136 let |
|
137 val bn_funs' = map (fn (bn, ty, mx) => |
|
138 (raw_bind bn, replace_typ dts_env ty, mx)) bn_funs |
|
139 |
|
140 val bn_eqs' = map (fn (attr, trm) => |
|
141 (attr, replace_term (cnstrs_env @ bn_fun_env) dts_env trm)) bn_eqs |
|
142 in |
|
143 (bn_funs', bn_eqs') |
|
144 end |
|
145 *} |
|
146 |
|
147 ML {* |
|
148 fun add_primrec_wrapper funs eqs lthy = |
|
149 if null funs then (([], []), lthy) |
|
150 else |
|
151 let |
|
152 val eqs' = map (fn (_, eq) => (Attrib.empty_binding, eq)) eqs |
|
153 val funs' = map (fn (bn, ty, mx) => (bn, SOME ty, mx)) funs |
|
154 in |
|
155 Primrec.add_primrec funs' eqs' lthy |
|
156 end |
|
157 *} |
|
158 |
|
159 |
|
160 ML {* |
|
161 fun nominal_datatype2 dts bn_funs bn_eqs binds lthy = |
139 let |
162 let |
140 val thy = ProofContext.theory_of lthy |
163 val thy = ProofContext.theory_of lthy |
141 |
164 val thy_name = Context.theory_name thy |
142 val dt_names' = add_raws dt_names |
165 |
143 val dt_full_names = map (Sign.full_bname thy) dt_names |
166 val conf = Datatype.default_config |
144 val dt_full_names' = map (Sign.full_bname thy) dt_names' |
167 |
145 val dt_env = dt_full_names ~~ dt_full_names' |
168 val dt_names = map (fn (_, s, _, _) => Binding.name_of s) dts |
146 |
169 val dt_full_names = map (Long_Name.qualify thy_name) dt_names |
147 val ctrs_names = map (Sign.full_bname thy) (get_constr_strs dts) |
170 val dt_full_names' = add_raws dt_full_names |
148 val ctrs_names' = map (fn (x, y) => (Sign.full_bname_path thy (add_raw x) (add_raw y))) |
171 val dts_env = dt_full_names ~~ dt_full_names' |
149 (get_typed_constrs dts) |
172 |
150 val ctrs_env = ctrs_names ~~ ctrs_names' |
173 val cnstrs = get_cnstr_strs dts |
|
174 val cnstrs_ty = get_typed_cnstrs dts |
|
175 val cnstrs_full_names = map (Long_Name.qualify thy_name) cnstrs |
|
176 val cnstrs_full_names' = map (fn (x, y) => Long_Name.qualify thy_name |
|
177 (Long_Name.qualify (add_raw x) (add_raw y))) cnstrs_ty |
|
178 val cnstrs_env = cnstrs_full_names ~~ cnstrs_full_names' |
151 |
179 |
152 val bn_fun_strs = get_bn_fun_strs bn_funs |
180 val bn_fun_strs = get_bn_fun_strs bn_funs |
153 val bn_fun_strs' = add_raws bn_fun_strs |
181 val bn_fun_strs' = add_raws bn_fun_strs |
154 val bn_fun_env = bn_fun_strs ~~ bn_fun_strs' |
182 val bn_fun_env = bn_fun_strs ~~ bn_fun_strs' |
155 |
183 |
156 val bn_funs' = map (fn (bn, ty, mx) => |
184 val (raw_dt_names, raw_dts) = rawify_dts dt_names dts dts_env |
157 (raw_bind bn, SOME (replace_typ dt_env ty), mx)) bn_funs |
185 |
158 |
186 val (raw_bn_funs, raw_bn_eqs) = rawify_bn_funs dts_env cnstrs_env bn_fun_env bn_funs bn_eqs |
159 val bn_eqs' = map (fn (_, trm) => |
187 in |
160 (Attrib.empty_binding, replace_term (ctrs_env @ bn_fun_env) dt_env trm)) bn_eqs |
188 lthy |
161 in |
189 |> Local_Theory.theory_result (Datatype.add_datatype conf raw_dt_names raw_dts) |
162 if null bn_eqs |
190 ||>> add_primrec_wrapper raw_bn_funs raw_bn_eqs |
163 then (([], []), lthy) |
191 end |
164 else Primrec.add_primrec bn_funs' bn_eqs' lthy |
192 *} |
165 end |
|
166 *} |
|
167 |
|
168 ML {* |
|
169 fun nominal_datatype2 dts bn_funs bn_eqs binds lthy = |
|
170 let |
|
171 val conf = Datatype.default_config |
|
172 val dts_names = map (fn (_, s, _, _) => Binding.name_of s) dts |
|
173 |
|
174 val (raw_dt_names, raw_dts) = raw_dts_decl dts_names dts lthy |
|
175 |
|
176 in |
|
177 lthy |
|
178 |> Local_Theory.theory_result (Datatype.add_datatype conf raw_dt_names raw_dts) |
|
179 ||>> raw_bn_fun_decl dts_names dts bn_funs bn_eqs |
|
180 end |
|
181 *} |
|
182 |
|
183 |
193 |
184 ML {* |
194 ML {* |
185 (* parsing the datatypes and declaring *) |
195 (* parsing the datatypes and declaring *) |
186 (* constructors in the local theory *) |
196 (* constructors in the local theory *) |
187 fun prepare_dts dt_strs lthy = |
197 fun prepare_dts dt_strs lthy = |