(patch) add bfs (#2)
All checks were successful
/ Build and Release Library (push) Successful in 2m19s

Reviewed-on: #2
This commit is contained in:
Pat Garrity 2026-01-04 03:16:44 +00:00
parent ee97bfb65d
commit db6b785c1f
4 changed files with 352 additions and 51 deletions

View file

@ -2,6 +2,7 @@ package gs.graph.v0
import gs.graph.v0.data.AnyGraphWithData
import scala.collection.mutable.ListBuffer
import scala.collection.mutable.Queue
import scala.collection.mutable.Stack
/** Graph traversal algorithms including DFS and BFS.
@ -26,18 +27,18 @@ object GraphTraversal:
val s = Stack.empty[Vertex]
val discovered = Array.fill(graph.numberOfVertices.value)(false)
val root = Vertex.Zero
s.push(root)
while !s.isEmpty
do
val v = s.pop()
if !discovered(v.ordinal) then
val _ = visit(v)
discovered(v.ordinal) = true
graph.neighbors(v).foreach(w => s.push(w))
else ()
()
graph.selectRoots().foreach { root =>
val _ = s.push(root)
while !s.isEmpty
do
val v = s.pop()
if !discovered(v.ordinal) then
val _ = visit(v)
discovered(v.ordinal) = true
graph.neighbors(v).foreach(w => s.push(w))
else ()
()
}
/** Depth-first search that executes a function on each [[Vertex]] to produce
* some output. This function will operate on _any_ [[Graph]].
@ -60,17 +61,17 @@ object GraphTraversal:
val s = Stack.empty[Vertex]
val discovered = Array.fill(graph.numberOfVertices.value)(false)
val root = Vertex.Zero
s.push(root)
while !s.isEmpty
do
val v = s.pop()
if !discovered(v.ordinal) then
val _ = output.addOne(visit(v))
discovered(v.ordinal) = true
graph.neighbors(v).foreach(w => s.push(w))
else ()
graph.selectRoots().foreach { root =>
val _ = s.push(root)
while !s.isEmpty
do
val v = s.pop()
if !discovered(v.ordinal) then
val _ = output.addOne(visit(v))
discovered(v.ordinal) = true
graph.neighbors(v).foreach(w => s.push(w))
else ()
}
output.toList
@ -84,8 +85,7 @@ object GraphTraversal:
val discovered = Array.fill(graph.numberOfVertices.value)(false)
graph.selectRoots().foreach { root =>
s.push(root)
val _ = s.push(root)
while !s.isEmpty
do
val v = s.pop()
@ -117,8 +117,7 @@ object GraphTraversal:
val discovered = Array.fill(graph.numberOfVertices.value)(false)
graph.selectRoots().foreach { root =>
s.push(root)
val _ = s.push(root)
while !s.isEmpty
do
val v = s.pop()
@ -153,8 +152,7 @@ object GraphTraversal:
val discovered = Array.fill(graph.numberOfVertices.value)(false)
graph.selectRoots().foreach { root =>
s.push(root)
val _ = s.push(root)
while !s.isEmpty
do
val v = s.pop()
@ -177,8 +175,7 @@ object GraphTraversal:
val discovered = Array.fill(graph.numberOfVertices.value)(false)
graph.selectRoots().foreach { root =>
s.push(root)
val _ = s.push(root)
while !s.isEmpty
do
val v = s.pop()
@ -191,4 +188,139 @@ object GraphTraversal:
acc
def bfs(
graph: Graph,
visit: Vertex => Unit
): Unit =
val q = Queue.empty[Vertex]
val visited = Array.fill(graph.numberOfVertices.value)(false)
val _ = graph.selectRoots().foreach(q.enqueue)
while !q.isEmpty
do
val v = q.dequeue()
if !visited(v.ordinal) then
val _ = visit(v)
visited(v.ordinal) = true
graph.neighbors(v).foreach { neighbor =>
if !visited(neighbor.ordinal) then q.enqueue(neighbor)
else ()
}
else ()
def bfs[Out](
graph: Graph,
visit: Vertex => Out
): List[Out] =
val output = ListBuffer.empty[Out]
val q = Queue.empty[Vertex]
val visited = Array.fill(graph.numberOfVertices.value)(false)
val _ = graph.selectRoots().foreach(q.enqueue)
while !q.isEmpty
do
val v = q.dequeue()
if !visited(v.ordinal) then
val _ = output.addOne(visit(v))
visited(v.ordinal) = true
graph.neighbors(v).foreach { neighbor =>
if !visited(neighbor.ordinal) then q.enqueue(neighbor)
else ()
}
else ()
output.toList
def bfsFold[Acc](
graph: Graph,
initial: Acc,
f: (Acc, Vertex) => Acc
): Acc =
var acc = initial
val q = Queue.empty[Vertex]
val visited = Array.fill(graph.numberOfVertices.value)(false)
val _ = graph.selectRoots().foreach(q.enqueue)
while !q.isEmpty
do
val v = q.dequeue()
if !visited(v.ordinal) then
acc = f(acc, v)
visited(v.ordinal) = true
graph.neighbors(v).foreach { neighbor =>
if !visited(neighbor.ordinal) then q.enqueue(neighbor)
else ()
}
else ()
acc
def bfs[A](
graph: AnyGraphWithData[A],
visit: (Vertex, A) => Unit
): Unit =
val q = Queue.empty[Vertex]
val visited = Array.fill(graph.numberOfVertices.value)(false)
val _ = graph.selectRoots().foreach(q.enqueue)
while !q.isEmpty
do
val v = q.dequeue()
if !visited(v.ordinal) then
val _ = visit(v, graph.data(v.ordinal))
visited(v.ordinal) = true
graph.neighbors(v).foreach { neighbor =>
if !visited(neighbor.ordinal) then q.enqueue(neighbor)
else ()
}
else ()
()
def bfs[A, Out](
graph: AnyGraphWithData[A],
visit: (Vertex, A) => Out
): List[Out] =
val output = ListBuffer.empty[Out]
val q = Queue.empty[Vertex]
val visited = Array.fill(graph.numberOfVertices.value)(false)
val _ = graph.selectRoots().foreach(q.enqueue)
while !q.isEmpty
do
val v = q.dequeue()
if !visited(v.ordinal) then
val _ = output.addOne(visit(v, graph.data(v.ordinal)))
visited(v.ordinal) = true
graph.neighbors(v).foreach { neighbor =>
if !visited(neighbor.ordinal) then q.enqueue(neighbor)
else ()
}
else ()
output.toList
def bfsFold[A, Acc](
graph: AnyGraphWithData[A],
initial: Acc,
f: (Acc, A) => Acc
): Acc =
var acc = initial
val q = Queue.empty[Vertex]
val visited = Array.fill(graph.numberOfVertices.value)(false)
val _ = graph.selectRoots().foreach(q.enqueue)
while !q.isEmpty
do
val v = q.dequeue()
if !visited(v.ordinal) then
acc = f(acc, graph.data(v.ordinal))
visited(v.ordinal) = true
graph.neighbors(v).foreach { neighbor =>
if !visited(neighbor.ordinal) then q.enqueue(neighbor)
else ()
}
else ()
acc
end GraphTraversal

View file

@ -102,7 +102,11 @@ object Digraph:
// If there are any vertices with no incoming connections, those are roots.
// Note that these may be completely disconnected.
counts.filter(_ == 0).map(Vertex(_)).toVector
counts.zipWithIndex
.filter(_._1 == 0)
.map { case (_, index) => Vertex(index) }
.toVector
.distinct
/** Determine whether the given [[Digraph]] has any cycles.
*

View file

@ -1,12 +1,14 @@
package gs.graph.v0.fs2
import cats.effect.Async
import cats.effect.Sync
import cats.effect.std.Queue
import cats.syntax.all.*
import fs2.Pull
import fs2.Stream
import gs.graph.v0.Graph
import gs.graph.v0.Size
import gs.graph.v0.Vertex
import scala.collection.mutable.Stack
object GraphTraversalFs2:
@ -19,19 +21,75 @@ object GraphTraversalFs2:
val state = new DfsState(graph.numberOfVertices)
graph
.selectRoots()
.map(root => pull(graph, state, visit, root).stream.unNoneTerminate)
.map(root => pullDfs(graph, state, visit, root).stream.unNoneTerminate)
.reduce(_ ++ _)
private def pull[F[_]: Sync, Out](
def bfs[F[_]: Async, Out](
graph: Graph,
state: DfsState,
visit: Vertex => F[Out]
): F[Stream[F, Out]] =
if graph.isEmpty then Async[F].pure(Stream.empty)
else
graph
.selectRoots()
.map { root =>
BfsState.initialize[F](graph.numberOfVertices).flatMap { state =>
doBfs(graph, visit, state, root)
}
}
.sequence
.map(_.reduce(_ ++ _))
.map(_.unNone)
private def doBfs[F[_]: Async, Out](
graph: Graph,
visit: Vertex => F[Out],
state: BfsState[F],
root: Vertex
): F[Stream[F, Option[Out]]] =
state.enqueue(root).map { _ =>
Stream
.repeatEval(state.dequeue())
.unNoneTerminate
.evalMap { vertex =>
state.isVisited(vertex).flatMap {
case true => Async[F].delay(None)
case false =>
for
_ <- state.visit(vertex)
out <- visit(vertex)
_ <- enqueueAllNeighbors(vertex, graph, state)
yield Some(out)
}
}
}
private def enqueueAllNeighbors[F[_]: Async](
vertex: Vertex,
graph: Graph,
state: BfsState[F]
): F[Unit] =
graph
.neighbors(vertex)
.map { vertex =>
state.isVisited(vertex).flatMap {
case true => Async[F].unit
case false => state.enqueue(vertex)
}
}
.sequence
.as(())
private def pullDfs[F[_]: Sync, Out](
graph: Graph,
state: DfsState[F],
visit: Vertex => F[Out],
current: Vertex
): Pull[F, Option[Out], Unit] =
Pull.eval(Sync[F].delay(state.isDiscovered(current))).flatMap { discovered =>
if discovered then Pull.output1(None) >> Pull.done
Pull.eval(state.isVisited(current)).flatMap { visited =>
if visited then Pull.output1(None) >> Pull.done
else
Pull.eval(Sync[F].delay(state.discover(current)))
Pull.eval(state.visit(current))
>> Pull.eval(visit(current)).flatMap(out => Pull.output1(Some(out)))
>> graph
.neighbors(current)
@ -39,25 +97,39 @@ object GraphTraversalFs2:
(
acc,
neighbor
) => acc >> pull(graph, state, visit, neighbor)
) => acc >> pullDfs(graph, state, visit, neighbor)
}
}
final private class DfsState(n: Size):
abstract private class TraverseState[F[_]: Sync](n: Size):
val visited: Array[Boolean] = Array.fill(n.value)(false)
val stack: Stack[Vertex] = Stack.empty
val discovered: Array[Boolean] = Array.fill(n.value)(false)
def isVisited(vertex: Vertex): F[Boolean] =
Sync[F].delay(visited(vertex.ordinal))
def push(vertex: Vertex): Unit = stack.push(vertex)
def visit(vertex: Vertex): F[Unit] =
Sync[F].delay(visited(vertex.ordinal) = true)
def pop(): Vertex = stack.pop()
def isDiscovered(vertex: Vertex): Boolean =
discovered(vertex.ordinal)
def discover(vertex: Vertex): Unit =
discovered(vertex.ordinal) = true
final private class DfsState[F[_]: Sync](n: Size) extends TraverseState[F](n)
end DfsState
final private class BfsState[F[_]: Async](
n: Size,
queue: Queue[F, Vertex]
) extends TraverseState[F](n):
def enqueue(vertex: Vertex): F[Unit] = queue.offer(vertex)
def dequeue(): F[Option[Vertex]] = queue.tryTake
end BfsState
private object BfsState:
def initialize[F[_]: Async](n: Size): F[BfsState[F]] =
Queue.bounded[F, Vertex](n.value).map(new BfsState(n, _))
end BfsState
end GraphTraversalFs2

View file

@ -0,0 +1,93 @@
package gs.graph.v0.fs2
import cats.effect.IO
import cats.effect.unsafe.IORuntime
import gs.graph.v0.Edge
import gs.graph.v0.Size
import gs.graph.v0.UndirectedGraph
import gs.graph.v0.Vertex
import gs.graph.v0.directed.Digraph
import munit.*
class Fs2BfsTests extends FunSuite:
given IORuntime = IORuntime.global
private def iotest(
name: String
)(
body: => IO[Unit]
)(
using
Location
): Unit =
test(name)(body.unsafeRunSync())
iotest("(BFS) should return an empty stream for an empty graph") {
val s = GraphTraversalFs2.bfs(
UndirectedGraph.Empty,
_ => IO.raiseError(IllegalStateException("Should not reach this point."))
)
s.flatMap(_.compile.last.map(result => assertEquals(result, None)))
}
iotest(
"(BFS) should return a stream of one for a graph with a single vertex"
) {
val s1 = GraphTraversalFs2.bfs(UndirectedGraph.Single, v => IO(v))
val s2 = GraphTraversalFs2.bfs(Digraph.Single, v => IO(v))
for
r1 <- s1.flatMap(_.compile.toList)
r2 <- s2.flatMap(_.compile.toList)
yield
assertEquals(r1, List(Vertex.Zero))
assertEquals(r2, List(Vertex.Zero))
}
iotest(
"(BFS) should visit a graph of three nodes"
) {
val vs = Array(Vertex(0), Vertex(1), Vertex(2))
val edges = List(
Edge(vs(0), vs(1)),
Edge(vs(0), vs(2))
)
val expected = vs.toList
val graph = Digraph.fromEdges(Size(vs.length), edges)
val s = GraphTraversalFs2.bfs(graph, v => IO(v))
s.flatMap(_.compile.toList).map(assertEquals(_, expected))
}
iotest(
"(BFS) should visit a graph of five nodes"
) {
val vs = Array(Vertex(0), Vertex(1), Vertex(2), Vertex(3), Vertex(4))
val edges = List(
Edge(vs(0), vs(1)),
Edge(vs(0), vs(2)),
Edge(vs(1), vs(2)),
Edge(vs(2), vs(3)),
Edge(vs(2), vs(4))
)
val expected = vs.toList
val graph = Digraph.fromEdges(Size(vs.length), edges)
val s = GraphTraversalFs2.bfs(graph, v => IO(v))
s.flatMap(_.compile.toList).map(assertEquals(_, expected))
}
iotest(
"(BFS) should visit a multi-root graph with disconnected components"
) {
val vs = (0 until 6).map(Vertex(_)).toArray
val edges = List(
Edge(vs(0), vs(2)),
Edge(vs(0), vs(3)),
Edge(vs(2), vs(1)),
Edge(vs(4), vs(5))
)
val expected = List(vs(0), vs(2), vs(3), vs(1), vs(4), vs(5))
val graph = Digraph.fromEdges(Size(vs.length), edges)
val s = GraphTraversalFs2.bfs(graph, v => IO(v))
s.flatMap(_.compile.toList).map(assertEquals(_, expected))
}