(patch) Add data variants to traversals in fs2 (#3)
All checks were successful
/ Build and Release Library (push) Successful in 2m24s

Reviewed-on: #3
This commit is contained in:
Pat Garrity 2026-01-06 03:59:07 +00:00
parent db6b785c1f
commit 5827a13f21

View file

@ -9,6 +9,7 @@ import fs2.Stream
import gs.graph.v0.Graph import gs.graph.v0.Graph
import gs.graph.v0.Size import gs.graph.v0.Size
import gs.graph.v0.Vertex import gs.graph.v0.Vertex
import gs.graph.v0.data.AnyGraphWithData
object GraphTraversalFs2: object GraphTraversalFs2:
@ -24,6 +25,18 @@ object GraphTraversalFs2:
.map(root => pullDfs(graph, state, visit, root).stream.unNoneTerminate) .map(root => pullDfs(graph, state, visit, root).stream.unNoneTerminate)
.reduce(_ ++ _) .reduce(_ ++ _)
def dfs[F[_]: Sync, A, Out](
graph: AnyGraphWithData[A],
visit: A => F[Out]
): Stream[F, Out] =
if graph.isEmpty then Stream.empty
else
val state = new DfsState(graph.numberOfVertices)
graph
.selectRoots()
.map(root => pullDfs(graph, state, visit, root).stream.unNoneTerminate)
.reduce(_ ++ _)
def bfs[F[_]: Async, Out]( def bfs[F[_]: Async, Out](
graph: Graph, graph: Graph,
visit: Vertex => F[Out] visit: Vertex => F[Out]
@ -41,6 +54,23 @@ object GraphTraversalFs2:
.map(_.reduce(_ ++ _)) .map(_.reduce(_ ++ _))
.map(_.unNone) .map(_.unNone)
def bfs[F[_]: Async, A, Out](
graph: AnyGraphWithData[A],
visit: A => F[Out]
): F[Stream[F, Out]] =
if graph.isEmpty then Async[F].pure(Stream.empty)
else
graph
.selectRoots()
.map { root =>
BfsState.initialize[F](graph.numberOfVertices).flatMap { state =>
doBfs(graph, visit, state, root)
}
}
.sequence
.map(_.reduce(_ ++ _))
.map(_.unNone)
private def doBfs[F[_]: Async, Out]( private def doBfs[F[_]: Async, Out](
graph: Graph, graph: Graph,
visit: Vertex => F[Out], visit: Vertex => F[Out],
@ -64,6 +94,29 @@ object GraphTraversalFs2:
} }
} }
private def doBfs[F[_]: Async, A, Out](
graph: AnyGraphWithData[A],
visit: A => F[Out],
state: BfsState[F],
root: Vertex
): F[Stream[F, Option[Out]]] =
state.enqueue(root).map { _ =>
Stream
.repeatEval(state.dequeue())
.unNoneTerminate
.evalMap { vertex =>
state.isVisited(vertex).flatMap {
case true => Async[F].delay(None)
case false =>
for
_ <- state.visit(vertex)
out <- visit(graph.data(vertex.ordinal))
_ <- enqueueAllNeighbors(vertex, graph, state)
yield Some(out)
}
}
}
private def enqueueAllNeighbors[F[_]: Async]( private def enqueueAllNeighbors[F[_]: Async](
vertex: Vertex, vertex: Vertex,
graph: Graph, graph: Graph,
@ -101,6 +154,29 @@ object GraphTraversalFs2:
} }
} }
private def pullDfs[F[_]: Sync, A, Out](
graph: AnyGraphWithData[A],
state: DfsState[F],
visit: A => F[Out],
current: Vertex
): Pull[F, Option[Out], Unit] =
Pull.eval(state.isVisited(current)).flatMap { visited =>
if visited then Pull.output1(None) >> Pull.done
else
Pull.eval(state.visit(current))
>> Pull
.eval(visit(graph.data(current.ordinal)))
.flatMap(out => Pull.output1(Some(out)))
>> graph
.neighbors(current)
.foldLeft(Pull.done: Pull[F, Option[Out], Unit]) {
(
acc,
neighbor
) => acc >> pullDfs(graph, state, visit, neighbor)
}
}
abstract private class TraverseState[F[_]: Sync](n: Size): abstract private class TraverseState[F[_]: Sync](n: Size):
val visited: Array[Boolean] = Array.fill(n.value)(false) val visited: Array[Boolean] = Array.fill(n.value)(false)