// Sudoku
//========
//
// Run the parallel version with
//
//   scala -cp scala-parallel-collections_2.13-0.2.0.jar sudoku_test.scala
//
// or compile and run it with
//
//   scalac -cp scala-parallel-collections_2.13-0.2.0.jar sudoku_test.scala
//   java -cp .:scala-library-2.13.0.jar:scala-parallel-collections_2.13-0.2.0.jar Sudoku

/** Brute-force Sudoku solver that explores the candidates for each empty
  * cell in parallel (needs the scala-parallel-collections module for `.par`).
  *
  * A game is an 81-character string in row-major order, with `'.'` marking
  * an empty cell. Positions are `(x, y)` = (column, row).
  *
  * This object defines an explicit `main` instead of extending `App`: with
  * `App` (via `DelayedInit`) every `val` below would only be initialised when
  * `main` runs, so calling e.g. `Sudoku.search` from other code would see
  * `emptyValue == '\u0000'`, `maxValue == 0` and `allValues == null`.
  */
object Sudoku {

  import scala.collection.parallel.CollectionConverters._

  /** A board position as (column, row). */
  type Pos = (Int, Int)

  /** Character that marks an empty cell. */
  val emptyValue: Char = '.'

  /** Board width/height; also used as the row stride of the game string. */
  val maxValue: Int = 9

  /** Every digit that may be placed in a cell. */
  val allValues: List[Char] = "123456789".toList

  /** Row/column indexes 0 to 8. */
  val indexes: List[Int] = (0 to 8).toList

  /** Index of the first empty cell in `game`, or -1 if there is none. */
  def empty(game: String): Int = game.indexOf(emptyValue)

  /** True when `game` contains no empty cell. */
  def isDone(game: String): Boolean = empty(game) == -1

  /** (column, row) of the first empty cell.
    * Only meaningful when `!isDone(game)` (a full board yields `(-1, 0)`).
    */
  def emptyPosition(game: String): Pos = {
    val e = empty(game)
    (e % maxValue, e / maxValue)
  }

  /** The cells of row `y`, left to right. */
  def get_row(game: String, y: Int): List[Char] =
    indexes.map(col => game(y * maxValue + col))

  /** The cells of column `x`, top to bottom. */
  def get_col(game: String, x: Int): List[Char] =
    indexes.map(row => game(x + row * maxValue))

  /** The cells of the 3x3 box that contains `pos`. */
  def get_box(game: String, pos: Pos): List[Char] = {
    // top-left coordinate of the box containing p
    def base(p: Int): Int = (p / 3) * 3
    val x0 = base(pos._1)
    val y0 = base(pos._2)
    for (x <- (x0 until x0 + 3).toList;
         y <- (y0 until y0 + 3).toList) yield game(x + y * maxValue)
  }

  /** Copy of `game` with the cell at string index `pos` set to `value`. */
  def update(game: String, pos: Int, value: Char): String =
    game.updated(pos, value)

  /** All characters already present in the column, row and box of `pos`
    * (may contain duplicates and `emptyValue`).
    */
  def toAvoid(game: String, pos: Pos): List[Char] =
    get_col(game, pos._1) ++ get_row(game, pos._2) ++ get_box(game, pos)

  /** Digits that do not yet appear in the column, row or box of `pos`. */
  def candidates(game: String, pos: Pos): List[Char] =
    allValues.diff(toAvoid(game, pos))

  /** All completions of `game`, found by filling the first empty cell with
    * each remaining candidate (in parallel) and recursing.
    *
    * Returns `Nil` when some empty cell has no candidate. A board without
    * empty cells is returned as-is; the given clues are not validated.
    */
  def search(game: String): List[String] =
    if (isDone(game)) List(game)
    else {
      // the cell being filled is the same for every candidate, so look it up once
      val idx = empty(game)
      candidates(game, emptyPosition(game)).par
        .map(c => search(update(game, idx, c)))
        .toList
        .flatten
    }

  /** Renders a game as one row per line, rows separated by ",\n". */
  def pretty(game: String): String =
    "\n" + game.grouped(maxValue).mkString(",\n")

  // linesIterator strips both "\n" and "\r\n", so the board stays 81 chars
  // even if this file is checked out with CRLF line endings.
  val game2: String = """8........
                        |..36.....
                        |.7..9.2..
                        |.5...7...
                        |....457..
                        |...1...3.
                        |..1....68
                        |..85...1.
                        |.9....4..""".stripMargin.linesIterator.mkString

  /** Evaluates `code` `i` times and reports the TOTAL elapsed wall-clock
    * time (not the per-run average) in seconds.
    */
  def time_needed[T](i: Int, code: => T): String = {
    val start = System.nanoTime()
    for (_ <- 1 to i) code
    val end = System.nanoTime()
    s"${(end - start) / 1.0e9} secs"
  }

  def main(args: Array[String]): Unit =
    println(time_needed(10, search(game2)))

}