diff --git a/modules/core/src/main/scala/gs/graph/v0/GraphTraversal.scala b/modules/core/src/main/scala/gs/graph/v0/GraphTraversal.scala index 717be39..f26f01c 100644 --- a/modules/core/src/main/scala/gs/graph/v0/GraphTraversal.scala +++ b/modules/core/src/main/scala/gs/graph/v0/GraphTraversal.scala @@ -2,6 +2,7 @@ package gs.graph.v0 import gs.graph.v0.data.AnyGraphWithData import scala.collection.mutable.ListBuffer +import scala.collection.mutable.Queue import scala.collection.mutable.Stack /** Graph traversal algorithms including DFS and BFS. @@ -26,18 +27,18 @@ object GraphTraversal: val s = Stack.empty[Vertex] val discovered = Array.fill(graph.numberOfVertices.value)(false) - val root = Vertex.Zero - s.push(root) - - while !s.isEmpty - do - val v = s.pop() - if !discovered(v.ordinal) then - val _ = visit(v) - discovered(v.ordinal) = true - graph.neighbors(v).foreach(w => s.push(w)) - else () - () + graph.selectRoots().foreach { root => + val _ = s.push(root) + while !s.isEmpty + do + val v = s.pop() + if !discovered(v.ordinal) then + val _ = visit(v) + discovered(v.ordinal) = true + graph.neighbors(v).foreach(w => s.push(w)) + else () + () + } /** Depth-first search that executes a function on each [[Vertex]] to produce * some output. This function will operate on _any_ [[Graph]]. @@ -60,17 +61,17 @@ object GraphTraversal: val s = Stack.empty[Vertex] val discovered = Array.fill(graph.numberOfVertices.value)(false) - val root = Vertex.Zero - s.push(root) - - while !s.isEmpty - do - val v = s.pop() - if !discovered(v.ordinal) then - val _ = output.addOne(visit(v)) - discovered(v.ordinal) = true - graph.neighbors(v).foreach(w => s.push(w)) - else () + graph.selectRoots().foreach { root => + val _ = s.push(root) + while !s.isEmpty + do + val v = s.pop() + if !discovered(v.ordinal) then + val _ = output.addOne(visit(v)) + discovered(v.ordinal) = true + graph.neighbors(v).foreach(w => s.push(w)) + else () + } output.toList @@ -84,8 +85,7 @@ object GraphTraversal: val discovered = Array.fill(graph.numberOfVertices.value)(false) graph.selectRoots().foreach { root => - s.push(root) - + val _ = s.push(root) while !s.isEmpty do val v = s.pop() @@ -117,8 +117,7 @@ object GraphTraversal: val discovered = Array.fill(graph.numberOfVertices.value)(false) graph.selectRoots().foreach { root => - s.push(root) - + val _ = s.push(root) while !s.isEmpty do val v = s.pop() @@ -153,8 +152,7 @@ object GraphTraversal: val discovered = Array.fill(graph.numberOfVertices.value)(false) graph.selectRoots().foreach { root => - s.push(root) - + val _ = s.push(root) while !s.isEmpty do val v = s.pop() @@ -177,8 +175,7 @@ object GraphTraversal: val discovered = Array.fill(graph.numberOfVertices.value)(false) graph.selectRoots().foreach { root => - s.push(root) - + val _ = s.push(root) while !s.isEmpty do val v = s.pop() @@ -191,4 +188,139 @@ object GraphTraversal: acc + def bfs( + graph: Graph, + visit: Vertex => Unit + ): Unit = + val q = Queue.empty[Vertex] + val visited = Array.fill(graph.numberOfVertices.value)(false) + val _ = graph.selectRoots().foreach(q.enqueue) + + while !q.isEmpty + do + val v = q.dequeue() + if !visited(v.ordinal) then + val _ = visit(v) + visited(v.ordinal) = true + graph.neighbors(v).foreach { neighbor => + if !visited(neighbor.ordinal) then q.enqueue(neighbor) + else () + } + else () + + def bfs[Out]( + graph: Graph, + visit: Vertex => Out + ): List[Out] = + val output = ListBuffer.empty[Out] + val q = Queue.empty[Vertex] + val visited = Array.fill(graph.numberOfVertices.value)(false) + val _ = graph.selectRoots().foreach(q.enqueue) + + while !q.isEmpty + do + val v = q.dequeue() + if !visited(v.ordinal) then + val _ = output.addOne(visit(v)) + visited(v.ordinal) = true + graph.neighbors(v).foreach { neighbor => + if !visited(neighbor.ordinal) then q.enqueue(neighbor) + else () + } + else () + + output.toList + + def bfsFold[Acc]( + graph: Graph, + initial: Acc, + f: (Acc, Vertex) => Acc + ): Acc = + var acc = initial + val q = Queue.empty[Vertex] + val visited = Array.fill(graph.numberOfVertices.value)(false) + val _ = graph.selectRoots().foreach(q.enqueue) + + while !q.isEmpty + do + val v = q.dequeue() + if !visited(v.ordinal) then + acc = f(acc, v) + visited(v.ordinal) = true + graph.neighbors(v).foreach { neighbor => + if !visited(neighbor.ordinal) then q.enqueue(neighbor) + else () + } + else () + + acc + + def bfs[A]( + graph: AnyGraphWithData[A], + visit: (Vertex, A) => Unit + ): Unit = + val q = Queue.empty[Vertex] + val visited = Array.fill(graph.numberOfVertices.value)(false) + val _ = graph.selectRoots().foreach(q.enqueue) + + while !q.isEmpty + do + val v = q.dequeue() + if !visited(v.ordinal) then + val _ = visit(v, graph.data(v.ordinal)) + visited(v.ordinal) = true + graph.neighbors(v).foreach { neighbor => + if !visited(neighbor.ordinal) then q.enqueue(neighbor) + else () + } + else () + () + + def bfs[A, Out]( + graph: AnyGraphWithData[A], + visit: (Vertex, A) => Out + ): List[Out] = + val output = ListBuffer.empty[Out] + val q = Queue.empty[Vertex] + val visited = Array.fill(graph.numberOfVertices.value)(false) + val _ = graph.selectRoots().foreach(q.enqueue) + + while !q.isEmpty + do + val v = q.dequeue() + if !visited(v.ordinal) then + val _ = output.addOne(visit(v, graph.data(v.ordinal))) + visited(v.ordinal) = true + graph.neighbors(v).foreach { neighbor => + if !visited(neighbor.ordinal) then q.enqueue(neighbor) + else () + } + else () + + output.toList + + def bfsFold[A, Acc]( + graph: AnyGraphWithData[A], + initial: Acc, + f: (Acc, A) => Acc + ): Acc = + var acc = initial + val q = Queue.empty[Vertex] + val visited = Array.fill(graph.numberOfVertices.value)(false) + val _ = graph.selectRoots().foreach(q.enqueue) + + while !q.isEmpty + do + val v = q.dequeue() + if !visited(v.ordinal) then + acc = f(acc, graph.data(v.ordinal)) + visited(v.ordinal) = true + graph.neighbors(v).foreach { neighbor => + if !visited(neighbor.ordinal) then q.enqueue(neighbor) + else () + } + else () + + acc + end GraphTraversal diff --git a/modules/core/src/main/scala/gs/graph/v0/directed/Digraph.scala b/modules/core/src/main/scala/gs/graph/v0/directed/Digraph.scala index 7da4dc5..17ef0ed 100644 --- a/modules/core/src/main/scala/gs/graph/v0/directed/Digraph.scala +++ b/modules/core/src/main/scala/gs/graph/v0/directed/Digraph.scala @@ -102,7 +102,11 @@ object Digraph: // If there are any vertices with no incoming connections, those are roots. // Note that these may be completely disconnected. - counts.filter(_ == 0).map(Vertex(_)).toVector + counts.zipWithIndex + .filter(_._1 == 0) + .map { case (_, index) => Vertex(index) } + .toVector + .distinct /** Determine whether the given [[Digraph]] has any cycles. * diff --git a/modules/fs2/src/main/scala/gs/graph/v0/fs2/GraphTraversalFs2.scala b/modules/fs2/src/main/scala/gs/graph/v0/fs2/GraphTraversalFs2.scala index a01395e..abc908d 100644 --- a/modules/fs2/src/main/scala/gs/graph/v0/fs2/GraphTraversalFs2.scala +++ b/modules/fs2/src/main/scala/gs/graph/v0/fs2/GraphTraversalFs2.scala @@ -1,12 +1,14 @@ package gs.graph.v0.fs2 +import cats.effect.Async import cats.effect.Sync +import cats.effect.std.Queue +import cats.syntax.all.* import fs2.Pull import fs2.Stream import gs.graph.v0.Graph import gs.graph.v0.Size import gs.graph.v0.Vertex -import scala.collection.mutable.Stack object GraphTraversalFs2: @@ -19,19 +21,75 @@ object GraphTraversalFs2: val state = new DfsState(graph.numberOfVertices) graph .selectRoots() - .map(root => pull(graph, state, visit, root).stream.unNoneTerminate) + .map(root => pullDfs(graph, state, visit, root).stream.unNoneTerminate) .reduce(_ ++ _) - private def pull[F[_]: Sync, Out]( + def bfs[F[_]: Async, Out]( graph: Graph, - state: DfsState, + visit: Vertex => F[Out] + ): F[Stream[F, Out]] = + if graph.isEmpty then Async[F].pure(Stream.empty) + else + graph + .selectRoots() + .map { root => + BfsState.initialize[F](graph.numberOfVertices).flatMap { state => + doBfs(graph, visit, state, root) + } + } + .sequence + .map(_.reduce(_ ++ _)) + .map(_.unNone) + + private def doBfs[F[_]: Async, Out]( + graph: Graph, + visit: Vertex => F[Out], + state: BfsState[F], + root: Vertex + ): F[Stream[F, Option[Out]]] = + state.enqueue(root).map { _ => + Stream + .repeatEval(state.dequeue()) + .unNoneTerminate + .evalMap { vertex => + state.isVisited(vertex).flatMap { + case true => Async[F].delay(None) + case false => + for + _ <- state.visit(vertex) + out <- visit(vertex) + _ <- enqueueAllNeighbors(vertex, graph, state) + yield Some(out) + } + } + } + + private def enqueueAllNeighbors[F[_]: Async]( + vertex: Vertex, + graph: Graph, + state: BfsState[F] + ): F[Unit] = + graph + .neighbors(vertex) + .map { vertex => + state.isVisited(vertex).flatMap { + case true => Async[F].unit + case false => state.enqueue(vertex) + } + } + .sequence + .as(()) + + private def pullDfs[F[_]: Sync, Out]( + graph: Graph, + state: DfsState[F], visit: Vertex => F[Out], current: Vertex ): Pull[F, Option[Out], Unit] = - Pull.eval(Sync[F].delay(state.isDiscovered(current))).flatMap { discovered => - if discovered then Pull.output1(None) >> Pull.done + Pull.eval(state.isVisited(current)).flatMap { visited => + if visited then Pull.output1(None) >> Pull.done else - Pull.eval(Sync[F].delay(state.discover(current))) + Pull.eval(state.visit(current)) >> Pull.eval(visit(current)).flatMap(out => Pull.output1(Some(out))) >> graph .neighbors(current) @@ -39,25 +97,39 @@ object GraphTraversalFs2: ( acc, neighbor - ) => acc >> pull(graph, state, visit, neighbor) + ) => acc >> pullDfs(graph, state, visit, neighbor) } } - final private class DfsState(n: Size): + abstract private class TraverseState[F[_]: Sync](n: Size): + val visited: Array[Boolean] = Array.fill(n.value)(false) - val stack: Stack[Vertex] = Stack.empty - val discovered: Array[Boolean] = Array.fill(n.value)(false) + def isVisited(vertex: Vertex): F[Boolean] = + Sync[F].delay(visited(vertex.ordinal)) - def push(vertex: Vertex): Unit = stack.push(vertex) + def visit(vertex: Vertex): F[Unit] = + Sync[F].delay(visited(vertex.ordinal) = true) - def pop(): Vertex = stack.pop() - - def isDiscovered(vertex: Vertex): Boolean = - discovered(vertex.ordinal) - - def discover(vertex: Vertex): Unit = - discovered(vertex.ordinal) = true + final private class DfsState[F[_]: Sync](n: Size) extends TraverseState[F](n) end DfsState + final private class BfsState[F[_]: Async]( + n: Size, + queue: Queue[F, Vertex] + ) extends TraverseState[F](n): + + def enqueue(vertex: Vertex): F[Unit] = queue.offer(vertex) + + def dequeue(): F[Option[Vertex]] = queue.tryTake + + end BfsState + + private object BfsState: + + def initialize[F[_]: Async](n: Size): F[BfsState[F]] = + Queue.bounded[F, Vertex](n.value).map(new BfsState(n, _)) + + end BfsState + end GraphTraversalFs2 diff --git a/modules/fs2/src/test/scala/gs/graph/v0/fs2/Fs2BfsTests.scala b/modules/fs2/src/test/scala/gs/graph/v0/fs2/Fs2BfsTests.scala new file mode 100644 index 0000000..0ea1b9e --- /dev/null +++ b/modules/fs2/src/test/scala/gs/graph/v0/fs2/Fs2BfsTests.scala @@ -0,0 +1,93 @@ +package gs.graph.v0.fs2 + +import cats.effect.IO +import cats.effect.unsafe.IORuntime +import gs.graph.v0.Edge +import gs.graph.v0.Size +import gs.graph.v0.UndirectedGraph +import gs.graph.v0.Vertex +import gs.graph.v0.directed.Digraph +import munit.* + +class Fs2BfsTests extends FunSuite: + given IORuntime = IORuntime.global + + private def iotest( + name: String + )( + body: => IO[Unit] + )( + using + Location + ): Unit = + test(name)(body.unsafeRunSync()) + + iotest("(BFS) should return an empty stream for an empty graph") { + val s = GraphTraversalFs2.bfs( + UndirectedGraph.Empty, + _ => IO.raiseError(IllegalStateException("Should not reach this point.")) + ) + + s.flatMap(_.compile.last.map(result => assertEquals(result, None))) + } + + iotest( + "(BFS) should return a stream of one for a graph with a single vertex" + ) { + val s1 = GraphTraversalFs2.bfs(UndirectedGraph.Single, v => IO(v)) + val s2 = GraphTraversalFs2.bfs(Digraph.Single, v => IO(v)) + + for + r1 <- s1.flatMap(_.compile.toList) + r2 <- s2.flatMap(_.compile.toList) + yield + assertEquals(r1, List(Vertex.Zero)) + assertEquals(r2, List(Vertex.Zero)) + } + + iotest( + "(BFS) should visit a graph of three nodes" + ) { + val vs = Array(Vertex(0), Vertex(1), Vertex(2)) + val edges = List( + Edge(vs(0), vs(1)), + Edge(vs(0), vs(2)) + ) + val expected = vs.toList + val graph = Digraph.fromEdges(Size(vs.length), edges) + val s = GraphTraversalFs2.bfs(graph, v => IO(v)) + s.flatMap(_.compile.toList).map(assertEquals(_, expected)) + } + + iotest( + "(BFS) should visit a graph of five nodes" + ) { + val vs = Array(Vertex(0), Vertex(1), Vertex(2), Vertex(3), Vertex(4)) + val edges = List( + Edge(vs(0), vs(1)), + Edge(vs(0), vs(2)), + Edge(vs(1), vs(2)), + Edge(vs(2), vs(3)), + Edge(vs(2), vs(4)) + ) + val expected = vs.toList + val graph = Digraph.fromEdges(Size(vs.length), edges) + val s = GraphTraversalFs2.bfs(graph, v => IO(v)) + s.flatMap(_.compile.toList).map(assertEquals(_, expected)) + } + + iotest( + "(BFS) should visit a multi-root graph with disconnected components" + ) { + val vs = (0 until 6).map(Vertex(_)).toArray + val edges = List( + Edge(vs(0), vs(2)), + Edge(vs(0), vs(3)), + Edge(vs(2), vs(1)), + Edge(vs(4), vs(5)) + ) + val expected = List(vs(0), vs(2), vs(3), vs(1), vs(4), vs(5)) + val graph = Digraph.fromEdges(Size(vs.length), edges) + val s = GraphTraversalFs2.bfs(graph, v => IO(v)) + s.flatMap(_.compile.toList).map(assertEquals(_, expected)) + }