Compare commits

..

7 commits
0.1.1 ... main

Author SHA1 Message Date
b7feb7a341 (patch) general code improvements and more documentation (#8)
All checks were successful
/ Build and Release Library (push) Successful in 2m30s
Reviewed-on: #8
2026-03-28 03:23:43 +00:00
c297f2f2ca (patch) fix incoming algorithm -- inverted boolean (#7)
All checks were successful
/ Build and Release Library (push) Successful in 2m31s
Reviewed-on: #7
2026-01-26 03:03:36 +00:00
ebd7d98ebb (patch) upgrade pre-commit (#6)
All checks were successful
/ Build and Release Library (push) Successful in 2m22s
Reviewed-on: #6
2026-01-25 03:55:13 +00:00
cf1690aebb (patch) update scala to 3.8.1 (#5)
All checks were successful
/ Build and Release Library (push) Successful in 2m22s
Reviewed-on: #5
2026-01-24 17:13:19 +00:00
05258f86f5 (patch) add incoming edges (#4)
All checks were successful
/ Build and Release Library (push) Successful in 2m21s
Reviewed-on: #4
2026-01-08 04:26:52 +00:00
5827a13f21 (patch) Add data variants to traversals in fs2 (#3)
All checks were successful
/ Build and Release Library (push) Successful in 2m24s
Reviewed-on: #3
2026-01-06 03:59:07 +00:00
db6b785c1f (patch) add bfs (#2)
All checks were successful
/ Build and Release Library (push) Successful in 2m19s
Reviewed-on: #2
2026-01-04 03:16:44 +00:00
14 changed files with 590 additions and 71 deletions

View file

@ -1,7 +1,7 @@
--- ---
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0 rev: v6.0.0
hooks: hooks:
- id: end-of-file-fixer - id: end-of-file-fixer
- id: trailing-whitespace - id: trailing-whitespace
@ -12,6 +12,6 @@ repos:
- id: trailing-whitespace - id: trailing-whitespace
- id: check-yaml - id: check-yaml
- repo: https://git.garrity.co/garrity-software/gs-pre-commit-scala - repo: https://git.garrity.co/garrity-software/gs-pre-commit-scala
rev: v1.0.1 rev: v1.0.2
hooks: hooks:
- id: scalafmt - id: scalafmt

View file

@ -1,5 +1,5 @@
// See: https://github.com/scalameta/scalafmt/tags for the latest tags. // See: https://github.com/scalameta/scalafmt/tags for the latest tags.
version = 3.10.2 version = 3.10.4
runner.dialect = scala3 runner.dialect = scala3
maxColumn = 80 maxColumn = 80

View file

@ -1,4 +1,4 @@
val scala3: String = "3.7.4" val scala3: String = "3.8.2"
ThisBuild / scalaVersion := scala3 ThisBuild / scalaVersion := scala3
ThisBuild / versionScheme := Some("semver-spec") ThisBuild / versionScheme := Some("semver-spec")
@ -31,14 +31,14 @@ val Deps = new {
} }
val Fs2 = new { val Fs2 = new {
val Core: ModuleID = "co.fs2" %% "fs2-core" % "3.12.2" val Core: ModuleID = "co.fs2" %% "fs2-core" % "3.13.0"
} }
val Gs = new { val Gs = new {
val Datagen: ModuleID = "gs" %% "gs-datagen-core-v0" % "0.3.3" val Datagen: ModuleID = "gs" %% "gs-datagen-core-v0" % "0.4.1"
} }
val MUnit: ModuleID = "org.scalameta" %% "munit" % "1.1.1" val MUnit: ModuleID = "org.scalameta" %% "munit" % "1.2.4"
} }
lazy val testSettings = Seq( lazy val testSettings = Seq(
@ -51,6 +51,7 @@ lazy val testSettings = Seq(
lazy val `gs-graph` = project lazy val `gs-graph` = project
.in(file(".")) .in(file("."))
.aggregate(core, cats, fs2) .aggregate(core, cats, fs2)
.settings(noPublishSettings)
.settings(sharedSettings) .settings(sharedSettings)
.settings(testSettings) .settings(testSettings)
.settings(name := s"${gsProjectName.value}-v${semVerMajor.value}") .settings(name := s"${gsProjectName.value}-v${semVerMajor.value}")

View file

@ -20,6 +20,26 @@ final class Adjacency(val neighbors: Vector[Vector[Vertex]]):
*/ */
def at(vertex: Vertex): Vector[Vertex] = neighbors(vertex.ordinal) def at(vertex: Vertex): Vector[Vertex] = neighbors(vertex.ordinal)
/** Get the vector of _incoming_ [[Vertex]] that point _to_ some input
* [[Vertex]].
*
* @param vertex
* The [[Vertex]] for which to retrieve the incoming vertices.
* @return
* The list of [[Vertex]] that have an edge _to_ the input [[Vertex]].
*/
def incoming(vertex: Vertex): Vector[Vertex] =
if vertex.ordinal >= neighbors.length then Vector.empty
else
neighbors.zipWithIndex
.filter {
// ignore the neighbors of the input vertex
case (_, index) => index != vertex.ordinal
}
.filter(_._1.contains(vertex))
.map(_._2)
.map(Vertex(_))
/** Express this [[Adjacency]] as a vector of [[Edge]]. /** Express this [[Adjacency]] as a vector of [[Edge]].
* *
* @return * @return
@ -31,10 +51,9 @@ final class Adjacency(val neighbors: Vector[Vector[Vertex]]):
tos.map(to => Edge(from, to)).distinct tos.map(to => Edge(from, to)).distinct
} }
/** @return /** The number of vertices represented by this adjacency list.
* The number of vertices represented by this adjacency list.
*/ */
def numberOfVertices: Size = Size.fromVector(neighbors) lazy val numberOfVertices: Size = Size.fromVector(neighbors)
/** Perform a linear traversal for each [[gs.graph.v0.Vertex]] to calculate /** Perform a linear traversal for each [[gs.graph.v0.Vertex]] to calculate
* the total number of edges in this adjacency list. * the total number of edges in this adjacency list.
@ -42,9 +61,11 @@ final class Adjacency(val neighbors: Vector[Vector[Vertex]]):
* @return * @return
* The number of edges in this adjacency list. * The number of edges in this adjacency list.
*/ */
def numberOfEdges: Size = Size(neighbors.map(_.length).reduce(_ + _)) lazy val numberOfEdges: Size = Size(neighbors.map(_.length).reduce(_ + _))
def findRoots(): Vector[Vertex] = /** All vertices that do not have inbound connections.
*/
lazy val roots: Vector[Vertex] =
val counts = Array.fill(neighbors.length)(0) val counts = Array.fill(neighbors.length)(0)
neighbors.foreach { ns => neighbors.foreach { ns =>
// Each vertex listed here is receiving an inbound connection, if we // Each vertex listed here is receiving an inbound connection, if we
@ -74,6 +95,17 @@ object Adjacency:
*/ */
def apply(adj: Vector[Vector[Vertex]]): Adjacency = new Adjacency(adj) def apply(adj: Vector[Vector[Vertex]]): Adjacency = new Adjacency(adj)
/** Create an empty adjacency list for some number of vertices. No vertex has
* any connections to another vertex.
*
* @param numberOfVertices
* The number of vertices in this disconnected graph.
* @return
* Some new empty adjacency.
*/
def empty(numberOfVertices: Size): Adjacency =
new Adjacency(Vector.fill(numberOfVertices.value)(Vector.empty))
given CanEqual[Adjacency, Adjacency] = CanEqual.derived given CanEqual[Adjacency, Adjacency] = CanEqual.derived
/** @return /** @return
@ -86,7 +118,8 @@ object Adjacency:
*/ */
final val Single: Adjacency = new Adjacency(Vector(Vector.empty)) final val Single: Adjacency = new Adjacency(Vector(Vector.empty))
/** Calculate an [[Adjacency]] from some collection of [[Edge]]. /** Calculate an [[Adjacency]] from some collection of [[Edge]], where those
* edges are assumed to be directed.
* *
* @param numberOfVertices * @param numberOfVertices
* The number of [[Vertex]] (`N`) in this graph. * The number of [[Vertex]] (`N`) in this graph.

View file

@ -1,5 +1,7 @@
package gs.graph.v0 package gs.graph.v0
import java.util.Objects
/** Represents a relationship between two [[Vertex]]. /** Represents a relationship between two [[Vertex]].
* *
* When used is a directed context, the edge goes _from_ `v1` _to_ `v2`. * When used is a directed context, the edge goes _from_ `v1` _to_ `v2`.
@ -41,6 +43,11 @@ final class Edge(
case other: Edge => v1 == other.v1 && v2 == other.v2 case other: Edge => v1 == other.v1 && v2 == other.v2
case _ => false case _ => false
/** @inheritDocs
*/
override def hashCode(): Int =
Objects.hash(v1, v2)
end Edge end Edge
object Edge: object Edge:

View file

@ -1,5 +1,10 @@
package gs.graph.v0 package gs.graph.v0
/** Describes the fundamental disposition of some graph.
*
* @param name
* The string value of this disposition.
*/
sealed abstract class GraphDisposition(val name: String): sealed abstract class GraphDisposition(val name: String):
override def equals(that: Any): Boolean = override def equals(that: Any): Boolean =
@ -15,7 +20,12 @@ object GraphDisposition:
given CanEqual[GraphDisposition, GraphDisposition] = CanEqual.derived given CanEqual[GraphDisposition, GraphDisposition] = CanEqual.derived
/** The graph has directed relationships between vertices.
*/
case object Directed extends GraphDisposition("directed") case object Directed extends GraphDisposition("directed")
/** The graph has relationships between vertices with no logical direction.
*/
case object Undirected extends GraphDisposition("undirected") case object Undirected extends GraphDisposition("undirected")
end GraphDisposition end GraphDisposition

View file

@ -2,6 +2,7 @@ package gs.graph.v0
import gs.graph.v0.data.AnyGraphWithData import gs.graph.v0.data.AnyGraphWithData
import scala.collection.mutable.ListBuffer import scala.collection.mutable.ListBuffer
import scala.collection.mutable.Queue
import scala.collection.mutable.Stack import scala.collection.mutable.Stack
/** Graph traversal algorithms including DFS and BFS. /** Graph traversal algorithms including DFS and BFS.
@ -26,9 +27,8 @@ object GraphTraversal:
val s = Stack.empty[Vertex] val s = Stack.empty[Vertex]
val discovered = Array.fill(graph.numberOfVertices.value)(false) val discovered = Array.fill(graph.numberOfVertices.value)(false)
val root = Vertex.Zero graph.selectRoots().foreach { root =>
s.push(root) val _ = s.push(root)
while !s.isEmpty while !s.isEmpty
do do
val v = s.pop() val v = s.pop()
@ -38,6 +38,7 @@ object GraphTraversal:
graph.neighbors(v).foreach(w => s.push(w)) graph.neighbors(v).foreach(w => s.push(w))
else () else ()
() ()
}
/** Depth-first search that executes a function on each [[Vertex]] to produce /** Depth-first search that executes a function on each [[Vertex]] to produce
* some output. This function will operate on _any_ [[Graph]]. * some output. This function will operate on _any_ [[Graph]].
@ -60,9 +61,8 @@ object GraphTraversal:
val s = Stack.empty[Vertex] val s = Stack.empty[Vertex]
val discovered = Array.fill(graph.numberOfVertices.value)(false) val discovered = Array.fill(graph.numberOfVertices.value)(false)
val root = Vertex.Zero graph.selectRoots().foreach { root =>
s.push(root) val _ = s.push(root)
while !s.isEmpty while !s.isEmpty
do do
val v = s.pop() val v = s.pop()
@ -71,6 +71,7 @@ object GraphTraversal:
discovered(v.ordinal) = true discovered(v.ordinal) = true
graph.neighbors(v).foreach(w => s.push(w)) graph.neighbors(v).foreach(w => s.push(w))
else () else ()
}
output.toList output.toList
@ -84,8 +85,7 @@ object GraphTraversal:
val discovered = Array.fill(graph.numberOfVertices.value)(false) val discovered = Array.fill(graph.numberOfVertices.value)(false)
graph.selectRoots().foreach { root => graph.selectRoots().foreach { root =>
s.push(root) val _ = s.push(root)
while !s.isEmpty while !s.isEmpty
do do
val v = s.pop() val v = s.pop()
@ -117,8 +117,7 @@ object GraphTraversal:
val discovered = Array.fill(graph.numberOfVertices.value)(false) val discovered = Array.fill(graph.numberOfVertices.value)(false)
graph.selectRoots().foreach { root => graph.selectRoots().foreach { root =>
s.push(root) val _ = s.push(root)
while !s.isEmpty while !s.isEmpty
do do
val v = s.pop() val v = s.pop()
@ -153,8 +152,7 @@ object GraphTraversal:
val discovered = Array.fill(graph.numberOfVertices.value)(false) val discovered = Array.fill(graph.numberOfVertices.value)(false)
graph.selectRoots().foreach { root => graph.selectRoots().foreach { root =>
s.push(root) val _ = s.push(root)
while !s.isEmpty while !s.isEmpty
do do
val v = s.pop() val v = s.pop()
@ -177,8 +175,7 @@ object GraphTraversal:
val discovered = Array.fill(graph.numberOfVertices.value)(false) val discovered = Array.fill(graph.numberOfVertices.value)(false)
graph.selectRoots().foreach { root => graph.selectRoots().foreach { root =>
s.push(root) val _ = s.push(root)
while !s.isEmpty while !s.isEmpty
do do
val v = s.pop() val v = s.pop()
@ -191,4 +188,139 @@ object GraphTraversal:
acc acc
def bfs(
graph: Graph,
visit: Vertex => Unit
): Unit =
val q = Queue.empty[Vertex]
val visited = Array.fill(graph.numberOfVertices.value)(false)
val _ = graph.selectRoots().foreach(q.enqueue)
while !q.isEmpty
do
val v = q.dequeue()
if !visited(v.ordinal) then
val _ = visit(v)
visited(v.ordinal) = true
graph.neighbors(v).foreach { neighbor =>
if !visited(neighbor.ordinal) then q.enqueue(neighbor)
else ()
}
else ()
def bfs[Out](
graph: Graph,
visit: Vertex => Out
): List[Out] =
val output = ListBuffer.empty[Out]
val q = Queue.empty[Vertex]
val visited = Array.fill(graph.numberOfVertices.value)(false)
val _ = graph.selectRoots().foreach(q.enqueue)
while !q.isEmpty
do
val v = q.dequeue()
if !visited(v.ordinal) then
val _ = output.addOne(visit(v))
visited(v.ordinal) = true
graph.neighbors(v).foreach { neighbor =>
if !visited(neighbor.ordinal) then q.enqueue(neighbor)
else ()
}
else ()
output.toList
def bfsFold[Acc](
graph: Graph,
initial: Acc,
f: (Acc, Vertex) => Acc
): Acc =
var acc = initial
val q = Queue.empty[Vertex]
val visited = Array.fill(graph.numberOfVertices.value)(false)
val _ = graph.selectRoots().foreach(q.enqueue)
while !q.isEmpty
do
val v = q.dequeue()
if !visited(v.ordinal) then
acc = f(acc, v)
visited(v.ordinal) = true
graph.neighbors(v).foreach { neighbor =>
if !visited(neighbor.ordinal) then q.enqueue(neighbor)
else ()
}
else ()
acc
def bfs[A](
graph: AnyGraphWithData[A],
visit: (Vertex, A) => Unit
): Unit =
val q = Queue.empty[Vertex]
val visited = Array.fill(graph.numberOfVertices.value)(false)
val _ = graph.selectRoots().foreach(q.enqueue)
while !q.isEmpty
do
val v = q.dequeue()
if !visited(v.ordinal) then
val _ = visit(v, graph.data(v.ordinal))
visited(v.ordinal) = true
graph.neighbors(v).foreach { neighbor =>
if !visited(neighbor.ordinal) then q.enqueue(neighbor)
else ()
}
else ()
()
def bfs[A, Out](
graph: AnyGraphWithData[A],
visit: (Vertex, A) => Out
): List[Out] =
val output = ListBuffer.empty[Out]
val q = Queue.empty[Vertex]
val visited = Array.fill(graph.numberOfVertices.value)(false)
val _ = graph.selectRoots().foreach(q.enqueue)
while !q.isEmpty
do
val v = q.dequeue()
if !visited(v.ordinal) then
val _ = output.addOne(visit(v, graph.data(v.ordinal)))
visited(v.ordinal) = true
graph.neighbors(v).foreach { neighbor =>
if !visited(neighbor.ordinal) then q.enqueue(neighbor)
else ()
}
else ()
output.toList
def bfsFold[A, Acc](
graph: AnyGraphWithData[A],
initial: Acc,
f: (Acc, A) => Acc
): Acc =
var acc = initial
val q = Queue.empty[Vertex]
val visited = Array.fill(graph.numberOfVertices.value)(false)
val _ = graph.selectRoots().foreach(q.enqueue)
while !q.isEmpty
do
val v = q.dequeue()
if !visited(v.ordinal) then
acc = f(acc, graph.data(v.ordinal))
visited(v.ordinal) = true
graph.neighbors(v).foreach { neighbor =>
if !visited(neighbor.ordinal) then q.enqueue(neighbor)
else ()
}
else ()
acc
end GraphTraversal end GraphTraversal

View file

@ -29,6 +29,8 @@ class Digraph(
*/ */
override def selectRoots(): Vector[Vertex] = roots override def selectRoots(): Vector[Vertex] = roots
/** @inheritDocs
*/
override def equals(that: Any): Boolean = override def equals(that: Any): Boolean =
that match that match
case other: Digraph => case other: Digraph =>
@ -78,7 +80,7 @@ object Digraph:
new Digraph( new Digraph(
numberOfVertices = adjacency.numberOfVertices, numberOfVertices = adjacency.numberOfVertices,
adjacency = adjacency, adjacency = adjacency,
roots = adjacency.findRoots() roots = adjacency.roots
) )
/** Find all roots for the given collection of [[Edge]]. /** Find all roots for the given collection of [[Edge]].
@ -102,7 +104,11 @@ object Digraph:
// If there are any vertices with no incoming connections, those are roots. // If there are any vertices with no incoming connections, those are roots.
// Note that these may be completely disconnected. // Note that these may be completely disconnected.
counts.filter(_ == 0).map(Vertex(_)).toVector counts.zipWithIndex
.filter(_._1 == 0)
.map { case (_, index) => Vertex(index) }
.toVector
.distinct
/** Determine whether the given [[Digraph]] has any cycles. /** Determine whether the given [[Digraph]] has any cycles.
* *

View file

@ -6,14 +6,30 @@ import gs.graph.v0.GraphException
import gs.graph.v0.Size import gs.graph.v0.Size
import gs.graph.v0.Vertex import gs.graph.v0.Vertex
/** Specialization of [[Digraph]] that always has a single root vertex.
*
* @param n
* The number of [[Vertex]] present in this graph.
* @param a
* The [[Adjacency]] that describes this graph.
* @param r
* The singular root [[Vertex]].
*/
class SingleRootDigraph( class SingleRootDigraph(
n: Size, n: Size,
a: Adjacency, a: Adjacency,
r: Vertex root: Vertex
) extends Digraph(n, a, Vector(r)) ) extends Digraph(n, a, Vector(root))
object SingleRootDigraph: object SingleRootDigraph:
/** Attempt to show that the given [[Digraph]] has a single root.
*
* @param dg
* The input [[Digraph]].
* @return
* [[SingleRootDigraph]] or `None` if the number of roots is not 1.
*/
def fromDirectedGraph(dg: Digraph): Option[SingleRootDigraph] = def fromDirectedGraph(dg: Digraph): Option[SingleRootDigraph] =
if dg.roots.size == 1 then if dg.roots.size == 1 then
Some( Some(
@ -25,6 +41,18 @@ object SingleRootDigraph:
) )
else None else None
/** Given some edges, build a [[SingleRootDigraph]]. Throw an exception if
* this operation fails.
*
* @param numberOfVertices
* The number of [[Vertex]] in the graph.
* @param edges
* The collection of [[Edge]] that describe the graph.
* @param root
* The root [[Vertex]].
* @return
* The new [[SingleRootDigraph]].
*/
def fromEdgesUnsafe( def fromEdgesUnsafe(
numberOfVertices: Size, numberOfVertices: Size,
edges: Iterable[Edge], edges: Iterable[Edge],
@ -38,6 +66,17 @@ object SingleRootDigraph:
) )
else throw GraphException.RootOutOfBounds(root, numberOfVertices) else throw GraphException.RootOutOfBounds(root, numberOfVertices)
/** Given some edges, build a [[SingleRootDigraph]] if that collection
* represents a graph with a single root.
*
* @param numberOfVertices
* The number of [[Vertex]] in the graph.
* @param edges
* The collection of [[Edge]] that describe the graph.
* @return
* The new [[SingleRootDigraph]], or `None` if the edges do not describe a
* digraph with a single root.
*/
def fromEdges( def fromEdges(
numberOfVertices: Size, numberOfVertices: Size,
edges: Iterable[Edge] edges: Iterable[Edge]
@ -53,6 +92,19 @@ object SingleRootDigraph:
) )
else None else None
/** Given some [[Adjacency]] and a given root [[Vertex]], instantiate a new
* [[SingleRootDigraph]].
*
* Throws an exception if the given root is not contained within the
* [[Adjacency]].
*
* @param adjacency
* The [[Adjacency]] which describes the graph.
* @param root
* The root [[Vertex]].
* @return
* New [[SingleRootDigraph]]
*/
def fromAdjacencyUnsafe( def fromAdjacencyUnsafe(
adjacency: Adjacency, adjacency: Adjacency,
root: Vertex root: Vertex
@ -65,10 +117,20 @@ object SingleRootDigraph:
) )
else throw GraphException.RootOutOfBounds(root, adjacency.numberOfVertices) else throw GraphException.RootOutOfBounds(root, adjacency.numberOfVertices)
/** Given some [[Adjacency]] and a given root [[Vertex]], instantiate a new
* [[SingleRootDigraph]] if the [[Vertex]] is within the graph..
*
* @param adjacency
* The [[Adjacency]] which describes the graph.
* @param root
* The root [[Vertex]].
* @return
* New [[SingleRootDigraph]], or `None` if the given root is not valid.
*/
def fromAdjacency( def fromAdjacency(
adjacency: Adjacency adjacency: Adjacency
): Option[SingleRootDigraph] = ): Option[SingleRootDigraph] =
val roots = adjacency.findRoots() val roots = adjacency.roots
if roots.size == 1 then if roots.size == 1 then
Some( Some(
new SingleRootDigraph( new SingleRootDigraph(

View file

@ -0,0 +1,27 @@
package gs.graph.v0
class AdjacencyTests extends munit.FunSuite:
test("should provide incoming connections") {
val N = Size(7)
val vs = (0 until N.value).map(Vertex(_)).toArray
val E = List(
Edge(vs(0), vs(1)),
Edge(vs(0), vs(2)),
Edge(vs(0), vs(3)),
Edge(vs(1), vs(4)),
Edge(vs(2), vs(4)),
Edge(vs(3), vs(4)),
Edge(vs(3), vs(5)),
Edge(vs(4), vs(6))
)
val A = Adjacency.fromDirectedEdges(N, E)
assertEquals(A.incoming(vs(0)), Vector.empty)
assertEquals(A.incoming(vs(1)), Vector(vs(0)))
assertEquals(A.incoming(vs(2)), Vector(vs(0)))
assertEquals(A.incoming(vs(3)), Vector(vs(0)))
assertEquals(A.incoming(vs(4)), Vector(vs(1), vs(2), vs(3)))
assertEquals(A.incoming(vs(5)), Vector(vs(3)))
assertEquals(A.incoming(vs(6)), Vector(vs(4)))
}

View file

@ -1,12 +1,15 @@
package gs.graph.v0.fs2 package gs.graph.v0.fs2
import cats.effect.Async
import cats.effect.Sync import cats.effect.Sync
import cats.effect.std.Queue
import cats.syntax.all.*
import fs2.Pull import fs2.Pull
import fs2.Stream import fs2.Stream
import gs.graph.v0.Graph import gs.graph.v0.Graph
import gs.graph.v0.Size import gs.graph.v0.Size
import gs.graph.v0.Vertex import gs.graph.v0.Vertex
import scala.collection.mutable.Stack import gs.graph.v0.data.AnyGraphWithData
object GraphTraversalFs2: object GraphTraversalFs2:
@ -19,19 +22,127 @@ object GraphTraversalFs2:
val state = new DfsState(graph.numberOfVertices) val state = new DfsState(graph.numberOfVertices)
graph graph
.selectRoots() .selectRoots()
.map(root => pull(graph, state, visit, root).stream.unNoneTerminate) .map(root => pullDfs(graph, state, visit, root).stream.unNoneTerminate)
.reduce(_ ++ _) .reduce(_ ++ _)
private def pull[F[_]: Sync, Out]( def dfs[F[_]: Sync, A, Out](
graph: AnyGraphWithData[A],
visit: A => F[Out]
): Stream[F, Out] =
if graph.isEmpty then Stream.empty
else
val state = new DfsState(graph.numberOfVertices)
graph
.selectRoots()
.map(root => pullDfs(graph, state, visit, root).stream.unNoneTerminate)
.reduce(_ ++ _)
def bfs[F[_]: Async, Out](
graph: Graph, graph: Graph,
state: DfsState, visit: Vertex => F[Out]
): F[Stream[F, Out]] =
if graph.isEmpty then Async[F].pure(Stream.empty)
else
graph
.selectRoots()
.map { root =>
BfsState.initialize[F](graph.numberOfVertices).flatMap { state =>
doBfs(graph, visit, state, root)
}
}
.sequence
.map(_.reduce(_ ++ _))
.map(_.unNone)
def bfs[F[_]: Async, A, Out](
graph: AnyGraphWithData[A],
visit: A => F[Out]
): F[Stream[F, Out]] =
if graph.isEmpty then Async[F].pure(Stream.empty)
else
graph
.selectRoots()
.map { root =>
BfsState.initialize[F](graph.numberOfVertices).flatMap { state =>
doBfs(graph, visit, state, root)
}
}
.sequence
.map(_.reduce(_ ++ _))
.map(_.unNone)
private def doBfs[F[_]: Async, Out](
graph: Graph,
visit: Vertex => F[Out],
state: BfsState[F],
root: Vertex
): F[Stream[F, Option[Out]]] =
state.enqueue(root).map { _ =>
Stream
.repeatEval(state.dequeue())
.unNoneTerminate
.evalMap { vertex =>
state.isVisited(vertex).flatMap {
case true => Async[F].delay(None)
case false =>
for
_ <- state.visit(vertex)
out <- visit(vertex)
_ <- enqueueAllNeighbors(vertex, graph, state)
yield Some(out)
}
}
}
private def doBfs[F[_]: Async, A, Out](
graph: AnyGraphWithData[A],
visit: A => F[Out],
state: BfsState[F],
root: Vertex
): F[Stream[F, Option[Out]]] =
state.enqueue(root).map { _ =>
Stream
.repeatEval(state.dequeue())
.unNoneTerminate
.evalMap { vertex =>
state.isVisited(vertex).flatMap {
case true => Async[F].delay(None)
case false =>
for
_ <- state.visit(vertex)
out <- visit(graph.data(vertex.ordinal))
_ <- enqueueAllNeighbors(vertex, graph, state)
yield Some(out)
}
}
}
private def enqueueAllNeighbors[F[_]: Async](
vertex: Vertex,
graph: Graph,
state: BfsState[F]
): F[Unit] =
graph
.neighbors(vertex)
.map { vertex =>
state.isVisited(vertex).flatMap {
case true => Async[F].unit
case false => state.enqueue(vertex)
}
}
.sequence
.as(())
private def pullDfs[F[_]: Sync, Out](
graph: Graph,
state: DfsState[F],
visit: Vertex => F[Out], visit: Vertex => F[Out],
current: Vertex current: Vertex
): Pull[F, Option[Out], Unit] = ): Pull[F, Option[Out], Unit] =
Pull.eval(Sync[F].delay(state.isDiscovered(current))).flatMap { discovered => Pull.eval(state.isVisited(current)).flatMap { visited =>
if discovered then Pull.output1(None) >> Pull.done if visited then Pull.output1(None) >> Pull.done
else else
Pull.eval(Sync[F].delay(state.discover(current))) Pull.eval(state.visit(current))
>> Pull.eval(visit(current)).flatMap(out => Pull.output1(Some(out))) >> Pull.eval(visit(current)).flatMap(out => Pull.output1(Some(out)))
>> graph >> graph
.neighbors(current) .neighbors(current)
@ -39,25 +150,62 @@ object GraphTraversalFs2:
( (
acc, acc,
neighbor neighbor
) => acc >> pull(graph, state, visit, neighbor) ) => acc >> pullDfs(graph, state, visit, neighbor)
} }
} }
final private class DfsState(n: Size): private def pullDfs[F[_]: Sync, A, Out](
graph: AnyGraphWithData[A],
state: DfsState[F],
visit: A => F[Out],
current: Vertex
): Pull[F, Option[Out], Unit] =
Pull.eval(state.isVisited(current)).flatMap { visited =>
if visited then Pull.output1(None) >> Pull.done
else
Pull.eval(state.visit(current))
>> Pull
.eval(visit(graph.data(current.ordinal)))
.flatMap(out => Pull.output1(Some(out)))
>> graph
.neighbors(current)
.foldLeft(Pull.done: Pull[F, Option[Out], Unit]) {
(
acc,
neighbor
) => acc >> pullDfs(graph, state, visit, neighbor)
}
}
val stack: Stack[Vertex] = Stack.empty abstract private class TraverseState[F[_]: Sync](n: Size):
val discovered: Array[Boolean] = Array.fill(n.value)(false) val visited: Array[Boolean] = Array.fill(n.value)(false)
def push(vertex: Vertex): Unit = stack.push(vertex) def isVisited(vertex: Vertex): F[Boolean] =
Sync[F].delay(visited(vertex.ordinal))
def pop(): Vertex = stack.pop() def visit(vertex: Vertex): F[Unit] =
Sync[F].delay(visited(vertex.ordinal) = true)
def isDiscovered(vertex: Vertex): Boolean = final private class DfsState[F[_]: Sync](n: Size) extends TraverseState[F](n)
discovered(vertex.ordinal)
def discover(vertex: Vertex): Unit =
discovered(vertex.ordinal) = true
end DfsState end DfsState
final private class BfsState[F[_]: Async](
n: Size,
queue: Queue[F, Vertex]
) extends TraverseState[F](n):
def enqueue(vertex: Vertex): F[Unit] = queue.offer(vertex)
def dequeue(): F[Option[Vertex]] = queue.tryTake
end BfsState
private object BfsState:
def initialize[F[_]: Async](n: Size): F[BfsState[F]] =
Queue.bounded[F, Vertex](n.value).map(new BfsState(n, _))
end BfsState
end GraphTraversalFs2 end GraphTraversalFs2

View file

@ -0,0 +1,93 @@
package gs.graph.v0.fs2
import cats.effect.IO
import cats.effect.unsafe.IORuntime
import gs.graph.v0.Edge
import gs.graph.v0.Size
import gs.graph.v0.UndirectedGraph
import gs.graph.v0.Vertex
import gs.graph.v0.directed.Digraph
import munit.*
class Fs2BfsTests extends FunSuite:
given IORuntime = IORuntime.global
private def iotest(
name: String
)(
body: => IO[Unit]
)(
using
Location
): Unit =
test(name)(body.unsafeRunSync())
iotest("(BFS) should return an empty stream for an empty graph") {
val s = GraphTraversalFs2.bfs(
UndirectedGraph.Empty,
_ => IO.raiseError(IllegalStateException("Should not reach this point."))
)
s.flatMap(_.compile.last.map(result => assertEquals(result, None)))
}
iotest(
"(BFS) should return a stream of one for a graph with a single vertex"
) {
val s1 = GraphTraversalFs2.bfs(UndirectedGraph.Single, v => IO(v))
val s2 = GraphTraversalFs2.bfs(Digraph.Single, v => IO(v))
for
r1 <- s1.flatMap(_.compile.toList)
r2 <- s2.flatMap(_.compile.toList)
yield
assertEquals(r1, List(Vertex.Zero))
assertEquals(r2, List(Vertex.Zero))
}
iotest(
"(BFS) should visit a graph of three nodes"
) {
val vs = Array(Vertex(0), Vertex(1), Vertex(2))
val edges = List(
Edge(vs(0), vs(1)),
Edge(vs(0), vs(2))
)
val expected = vs.toList
val graph = Digraph.fromEdges(Size(vs.length), edges)
val s = GraphTraversalFs2.bfs(graph, v => IO(v))
s.flatMap(_.compile.toList).map(assertEquals(_, expected))
}
iotest(
"(BFS) should visit a graph of five nodes"
) {
val vs = Array(Vertex(0), Vertex(1), Vertex(2), Vertex(3), Vertex(4))
val edges = List(
Edge(vs(0), vs(1)),
Edge(vs(0), vs(2)),
Edge(vs(1), vs(2)),
Edge(vs(2), vs(3)),
Edge(vs(2), vs(4))
)
val expected = vs.toList
val graph = Digraph.fromEdges(Size(vs.length), edges)
val s = GraphTraversalFs2.bfs(graph, v => IO(v))
s.flatMap(_.compile.toList).map(assertEquals(_, expected))
}
iotest(
"(BFS) should visit a multi-root graph with disconnected components"
) {
val vs = (0 until 6).map(Vertex(_)).toArray
val edges = List(
Edge(vs(0), vs(2)),
Edge(vs(0), vs(3)),
Edge(vs(2), vs(1)),
Edge(vs(4), vs(5))
)
val expected = List(vs(0), vs(2), vs(3), vs(1), vs(4), vs(5))
val graph = Digraph.fromEdges(Size(vs.length), edges)
val s = GraphTraversalFs2.bfs(graph, v => IO(v))
s.flatMap(_.compile.toList).map(assertEquals(_, expected))
}

View file

@ -1 +1 @@
sbt.version=1.11.7 sbt.version=1.12.8

View file

@ -28,6 +28,6 @@ externalResolvers := Seq(
"Garrity Software Releases" at "https://maven.garrity.co/gs" "Garrity Software Releases" at "https://maven.garrity.co/gs"
) )
addSbtPlugin("org.scoverage" % "sbt-scoverage" % "2.3.1") addSbtPlugin("org.scoverage" % "sbt-scoverage" % "2.4.4")
addSbtPlugin("gs" % "sbt-garrity-software" % "0.6.0") addSbtPlugin("gs" % "sbt-garrity-software" % "0.7.0")
addSbtPlugin("gs" % "sbt-gs-semver" % "0.3.0") addSbtPlugin("gs" % "sbt-gs-semver" % "0.3.0")