diff --git a/modules/fs2/src/main/scala/gs/graph/v0/fs2/GraphTraversalFs2.scala b/modules/fs2/src/main/scala/gs/graph/v0/fs2/GraphTraversalFs2.scala index abc908d..f1480fe 100644 --- a/modules/fs2/src/main/scala/gs/graph/v0/fs2/GraphTraversalFs2.scala +++ b/modules/fs2/src/main/scala/gs/graph/v0/fs2/GraphTraversalFs2.scala @@ -9,6 +9,7 @@ import fs2.Stream import gs.graph.v0.Graph import gs.graph.v0.Size import gs.graph.v0.Vertex +import gs.graph.v0.data.AnyGraphWithData object GraphTraversalFs2: @@ -24,6 +25,18 @@ object GraphTraversalFs2: .map(root => pullDfs(graph, state, visit, root).stream.unNoneTerminate) .reduce(_ ++ _) + def dfs[F[_]: Sync, A, Out]( + graph: AnyGraphWithData[A], + visit: A => F[Out] + ): Stream[F, Out] = + if graph.isEmpty then Stream.empty + else + val state = new DfsState(graph.numberOfVertices) + graph + .selectRoots() + .map(root => pullDfs(graph, state, visit, root).stream.unNoneTerminate) + .reduce(_ ++ _) + def bfs[F[_]: Async, Out]( graph: Graph, visit: Vertex => F[Out] @@ -41,6 +54,23 @@ object GraphTraversalFs2: .map(_.reduce(_ ++ _)) .map(_.unNone) + def bfs[F[_]: Async, A, Out]( + graph: AnyGraphWithData[A], + visit: A => F[Out] + ): F[Stream[F, Out]] = + if graph.isEmpty then Async[F].pure(Stream.empty) + else + graph + .selectRoots() + .map { root => + BfsState.initialize[F](graph.numberOfVertices).flatMap { state => + doBfs(graph, visit, state, root) + } + } + .sequence + .map(_.reduce(_ ++ _)) + .map(_.unNone) + private def doBfs[F[_]: Async, Out]( graph: Graph, visit: Vertex => F[Out], @@ -64,6 +94,29 @@ object GraphTraversalFs2: } } + private def doBfs[F[_]: Async, A, Out]( + graph: AnyGraphWithData[A], + visit: A => F[Out], + state: BfsState[F], + root: Vertex + ): F[Stream[F, Option[Out]]] = + state.enqueue(root).map { _ => + Stream + .repeatEval(state.dequeue()) + .unNoneTerminate + .evalMap { vertex => + state.isVisited(vertex).flatMap { + case true => Async[F].delay(None) + case false => + for + _ <- state.visit(vertex) + out <- visit(graph.data(vertex.ordinal)) + _ <- enqueueAllNeighbors(vertex, graph, state) + yield Some(out) + } + } + } + private def enqueueAllNeighbors[F[_]: Async]( vertex: Vertex, graph: Graph, @@ -101,6 +154,29 @@ object GraphTraversalFs2: } } + private def pullDfs[F[_]: Sync, A, Out]( + graph: AnyGraphWithData[A], + state: DfsState[F], + visit: A => F[Out], + current: Vertex + ): Pull[F, Option[Out], Unit] = + Pull.eval(state.isVisited(current)).flatMap { visited => + if visited then Pull.output1(None) >> Pull.done + else + Pull.eval(state.visit(current)) + >> Pull + .eval(visit(graph.data(current.ordinal))) + .flatMap(out => Pull.output1(Some(out))) + >> graph + .neighbors(current) + .foldLeft(Pull.done: Pull[F, Option[Out], Unit]) { + ( + acc, + neighbor + ) => acc >> pullDfs(graph, state, visit, neighbor) + } + } + abstract private class TraverseState[F[_]: Sync](n: Size): val visited: Array[Boolean] = Array.fill(n.value)(false)