148 map (map (map (map (fn (opt_trm, i, j) => |
148 map (map (map (map (fn (opt_trm, i, j) => |
149 (Option.map (apfst (replace_term (cnstrs_env @ bn_fun_env) dts_env)) opt_trm, i, j))))) binds |
149 (Option.map (apfst (replace_term (cnstrs_env @ bn_fun_env) dts_env)) opt_trm, i, j))))) binds |
150 *} |
150 *} |
151 |
151 |
152 ML {* |
152 ML {* |
153 fun prep_bn dt_names eqs lthy = |
153 fun find [] _ = error ("cannot find element") |
|
154 | find ((x, z)::xs) y = if (Long_Name.base_name x) = y then z else find xs y |
|
155 *} |
|
156 |
|
157 ML {* |
|
158 fun prep_bn dt_names dts eqs lthy = |
154 let |
159 let |
155 fun aux eq = |
160 fun aux eq = |
156 let |
161 let |
157 val (lhs, rhs) = eq |
162 val (lhs, rhs) = eq |
158 |> strip_qnt_body "all" |
163 |> strip_qnt_body "all" |
159 |> HOLogic.dest_Trueprop |
164 |> HOLogic.dest_Trueprop |
160 |> HOLogic.dest_eq |
165 |> HOLogic.dest_eq |
161 val (head, [cnstr]) = strip_comb lhs |
166 val (bn_fun, [cnstr]) = strip_comb lhs |
162 val (_, ty) = dest_Free head |
167 val (_, ty) = dest_Free bn_fun |
163 val (ty_name, _) = dest_Type (domain_type ty) |
168 val (ty_name, _) = dest_Type (domain_type ty) |
164 val dt_index = find_index (fn x => x = ty_name) dt_names |
169 val dt_index = find_index (fn x => x = ty_name) dt_names |
165 val (_, cnstr_args) = strip_comb cnstr |
170 val (cnstr_head, cnstr_args) = strip_comb cnstr |
166 val included = map (fn i => length (cnstr_args) - i - 1) (loose_bnos rhs) |
171 val included = map (fn i => length (cnstr_args) - i - 1) (loose_bnos rhs) |
167 in |
172 in |
168 (head, dt_index, included) |
173 (dt_index, (bn_fun, (cnstr_head, included))) |
169 end |
174 end |
170 in |
175 |
171 map aux eqs |
176 fun order dts i ts = |
|
177 let |
|
178 val dt = nth dts i |
|
179 val cts = map (fn (x, _, _) => Binding.name_of x) ((fn (_, _, _, x) => x) dt) |
|
180 val ts' = map (fn (x, y) => (fst (dest_Const x), y)) ts |
|
181 in |
|
182 map (find ts') cts |
|
183 end |
|
184 |
|
185 val unordered = AList.group (op=) (map aux eqs) |
|
186 val unordered' = map (fn (x, y) => (x, AList.group (op=) y)) unordered |
|
187 val ordered = map (fn (x, y) => (x, map (fn (v, z) => (v, order dts x z)) y)) unordered' |
|
188 in |
|
189 ordered |
172 end |
190 end |
173 *} |
191 *} |
174 |
192 |
175 ML {* |
193 ML {* |
176 fun add_primrec_wrapper funs eqs lthy = |
194 fun add_primrec_wrapper funs eqs lthy = |
221 |
239 |
222 val (raw_bn_funs, raw_bn_eqs) = rawify_bn_funs dts_env cnstrs_env bn_fun_env bn_funs bn_eqs |
240 val (raw_bn_funs, raw_bn_eqs) = rawify_bn_funs dts_env cnstrs_env bn_fun_env bn_funs bn_eqs |
223 |
241 |
224 val raw_binds = rawify_binds dts_env cnstrs_env bn_fun_full_env binds |
242 val raw_binds = rawify_binds dts_env cnstrs_env bn_fun_full_env binds |
225 |
243 |
226 val raw_bns = prep_bn dt_full_names' (map snd raw_bn_eqs) lthy |
244 val raw_bns = prep_bn dt_full_names' raw_dts (map snd raw_bn_eqs) lthy |
227 |
245 |
228 val _ = tracing (cat_lines (map PolyML.makestring raw_bns)) |
246 val _ = tracing (cat_lines (map PolyML.makestring raw_bns)) |
229 in |
247 in |
230 lthy |
248 lthy |
231 |> add_datatype_wrapper raw_dt_names raw_dts |
249 |> add_datatype_wrapper raw_dt_names raw_dts |