diff --git a/modules/auth/src/main/scala/gs/smolban/auth/Argon2Hash.scala b/modules/auth/src/main/scala/gs/smolban/auth/Argon2Hash.scala index 787be73..0c9cca6 100644 --- a/modules/auth/src/main/scala/gs/smolban/auth/Argon2Hash.scala +++ b/modules/auth/src/main/scala/gs/smolban/auth/Argon2Hash.scala @@ -1,6 +1,7 @@ package gs.smolban.auth import java.util.Base64 +import java.util.Objects import scala.util.Try /** Represents an Argon2 hash packed with the parameters that produced it. @@ -30,6 +31,30 @@ final class Argon2Hash( val hash: Array[Byte] ): + override def equals(obj: Any): Boolean = + obj match + case other: Argon2Hash => + (algorithmVersion == other.algorithmVersion) + && (algorithmType == other.algorithmType) + && (iterations == other.iterations) + && (parallelism == other.parallelism) + && (memoryInKb == other.memoryInKb) + && (salt.sameElements(other.salt)) + && (hash.sameElements(other.hash)) + + override def hashCode(): Int = + Objects.hash( + algorithmVersion, + algorithmType, + iterations, + parallelism, + memoryInKb, + salt, + hash + ) + + override def toString(): String = encode() + /** Encode this hash as a '$' delimited string that includes all parameters. * This string can be parsed by using the `decode` function. * @@ -37,7 +62,7 @@ final class Argon2Hash( * The encoded hash string. */ def encode(): String = - s"v=$algorithmVersion$$t=$algorithmType$$i=$iterations$$p=$parallelism$$m=$memoryInKb}$$${encodedSalt()}$$${encodedHash()}" + s"v=$algorithmVersion$$t=$algorithmType$$i=$iterations$$p=$parallelism$$m=$memoryInKb$$${encodedSalt()}$$${encodedHash()}" private def encodedSalt(): String = Base64.getEncoder().encodeToString(salt) diff --git a/modules/auth/src/test/scala/gs/smolban/auth/Argon2Tests.scala b/modules/auth/src/test/scala/gs/smolban/auth/Argon2Tests.scala new file mode 100644 index 0000000..8734fe8 --- /dev/null +++ b/modules/auth/src/test/scala/gs/smolban/auth/Argon2Tests.scala @@ -0,0 +1,70 @@ +package gs.smolban.auth + +import cats.effect.IO +import cats.effect.unsafe.IORuntime +import munit.Location +import org.bouncycastle.crypto.params.Argon2Parameters + +class Argon2Tests extends munit.FunSuite: + given IORuntime = IORuntime.global + + def iotest( + name: String + )( + f: => IO[Unit] + )( + using + Location + ): Unit = + test(name)(f.unsafeRunSync()) + + val rng: RandomByteProvider[IO] = RandomByteProvider.secureRandom[IO] + + val altConfig: Argon2.Config = Argon2.Config( + algorithmVersion = Argon2Parameters.ARGON2_VERSION_10, + algorithmType = Argon2Parameters.ARGON2_i, + iterations = 2, + parallelism = 1, + memoryInKb = 1024, + saltLengthInBytes = 8, + hashLengthInBytes = 64 + ) + + iotest("should calculate a hash and verify against that hash") { + val input = "some Complex password!1" + for + secret <- Argon2Secret.generate(32, rng) + argon2 <- IO(new Argon2[IO](Argon2.defaultConfig(), secret, rng)) + hash <- argon2.calculateHash(input) + matched <- argon2.doesInputMatch(input, hash) + encoded <- IO(hash.encode()) + decoded <- IO(Argon2Hash.decode(encoded)) + yield + assertEquals(matched, true) + assertEquals(hash.algorithmType, Argon2.Defaults.AlgorithmType) + assertEquals(hash.algorithmVersion, Argon2.Defaults.AlgorithmVersion) + assertEquals(hash.iterations, Argon2.Defaults.Iterations) + assertEquals(hash.parallelism, Argon2.Defaults.Parallelism) + assertEquals(hash.memoryInKb, Argon2.Defaults.MemoryInKB) + assertEquals(Some(hash), decoded) + } + + iotest("should match using stored params, not configured params") { + val input = "Another super s3cr3t pass@" + for + secret <- Argon2Secret.generate(32, rng) + altArgon2 <- IO(new Argon2[IO](altConfig, secret, rng)) + defArgon2 <- IO(new Argon2[IO](Argon2.defaultConfig(), secret, rng)) + altHash <- altArgon2.calculateHash(input) + // We're using the default configuration to run the match, and it should + // still match because we have the same secret and the params are + // extracted from the hash rather than the configuration. + matched <- defArgon2.doesInputMatch(input, altHash) + yield + assertEquals(matched, true) + assertEquals(altHash.algorithmType, altConfig.algorithmType) + assertEquals(altHash.algorithmVersion, altConfig.algorithmVersion) + assertEquals(altHash.iterations, altConfig.iterations) + assertEquals(altHash.parallelism, altConfig.parallelism) + assertEquals(altHash.memoryInKb, altConfig.memoryInKb) + }