progs/sudoku_test.scala
author Christian Urban <christian.urban@kcl.ac.uk>
Mon, 07 Dec 2020 01:25:41 +0000
changeset 383 c02929f2647c
permissions -rw-r--r--
updated

// Sudoku
//========

// call parallel version with
//
// scala -cp scala-parallel-collections_2.13-0.2.0.jar sudoku_test.scala 
//
// or
//
// scalac -cp scala-parallel-collections_2.13-0.2.0.jar sudoku_test.scala
// java -cp .:scala-library-2.13.0.jar:scala-parallel-collections_2.13-0.2.0.jar Sudoku

/** Brute-force Sudoku solver that explores the candidates for each empty
  * cell in parallel, plus a small timing harness.
  *
  * Uses an explicit `main` instead of `extends App`. With `App`, the vals
  * below are only assigned once `main` runs, so calling `search` or
  * `candidates` from outside `main` (tests, REPL) would see an
  * uninitialised `maxValue`/`indexes`.
  */
object Sudoku {

import scala.collection.parallel.CollectionConverters._

/** A board position as (column, row), both 0-based. */
type Pos = (Int, Int)

/** Character that marks an unfilled cell. */
val emptyValue: Char = '.'

/** Number of cells per row (and per column). */
val maxValue: Int = 9

/** Digits that may be placed into a cell. */
val allValues: List[Char] = "123456789".toList

/** Row/column indexes 0 to 8. */
val indexes: List[Int] = (0 to 8).toList


/** Row-major index of the first empty cell, or -1 if the board has none. */
def empty(game: String): Int = game.indexOf(emptyValue)

/** True when the board contains no empty cell. */
def isDone(game: String): Boolean = empty(game) == -1

// Converts a row-major board index into a (column, row) position.
private def toPos(index: Int): Pos = (index % maxValue, index / maxValue)

/** (column, row) of the first empty cell; only meaningful when `!isDone(game)`. */
def emptyPosition(game: String): Pos = toPos(empty(game))


/** The cells of row `y`, left to right. */
def get_row(game: String, y: Int): List[Char] = indexes.map(col => game(y * maxValue + col))

/** The cells of column `x`, top to bottom. */
def get_col(game: String, x: Int): List[Char] = indexes.map(row => game(x + row * maxValue))

/** The nine cells of the 3x3 box that contains `pos`, listed column by column. */
def get_box(game: String, pos: Pos): List[Char] = {
  // top-left coordinate of the box containing p
  def base(p: Int): Int = (p / 3) * 3
  val x0 = base(pos._1)
  val y0 = base(pos._2)
  for (x <- (x0 until x0 + 3).toList;
       y <- (y0 until y0 + 3).toList) yield game(x + y * maxValue)
}


/** Returns a copy of `game` with the cell at row-major index `pos` set to `value`. */
def update(game: String, pos: Int, value: Char): String =
  game.updated(pos, value)

/** All characters already present in the column, row and box of `pos` (may contain duplicates). */
def toAvoid(game: String, pos: Pos): List[Char] =
  get_col(game, pos._1) ++ get_row(game, pos._2) ++ get_box(game, pos)

/** Digits that can still legally be placed at `pos`, in ascending order. */
def candidates(game: String, pos: Pos): List[Char] = {
  val used = toAvoid(game, pos).toSet
  allValues.filterNot(used)
}

/** All completions of `game`, trying the candidates of the first empty cell in parallel.
  *
  * @return every solved board reachable from `game`; `List(game)` if it is already full,
  *         `Nil` if it is unsolvable
  */
def search(game: String): List[String] = {
  // locate the first empty cell once and reuse it for every candidate
  val e = empty(game)
  if (e == -1) List(game)
  else {
    val cs = candidates(game, toPos(e))
    cs.par.map(c => search(update(game, e, c))).toList.flatten
  }
}


/** The board as nine comma-separated rows, preceded by a newline. */
def pretty(game: String): String =
  "\n" + (game.grouped(maxValue).mkString(",\n"))


val game2: String = """8........
              |..36.....
              |.7..9.2..
              |.5...7...
              |....457..
              |...1...3.
              |..1....68
              |..85...1.
              |.9....4..""".stripMargin.replaceAll("\\n", "")

/** Evaluates `code` `i` times and reports the TOTAL elapsed wall-clock time
  * in seconds (not the per-run average).
  */
def time_needed[T](i: Int, code: => T): String = {
  val start = System.nanoTime()
  for (_ <- 1 to i) code
  val end = System.nanoTime()
  s"${(end - start) / 1.0e9} secs"
}


/** Benchmarks solving `game2` ten times. */
def main(args: Array[String]): Unit =
  println(time_needed(10, search(game2)))

}