Skip to content

Commit

Permalink
add type parameter ERROR_OUTPUT to anti csrf methods
Browse files Browse the repository at this point in the history
  • Loading branch information
fupelaqu committed Jul 7, 2023
1 parent 098fb4e commit 05f52c1
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ import scala.concurrent.{ExecutionContext, Future}

trait CsrfEndpoints {

def hmacTokenCsrfProtection[T, SECURITY_INPUT, PRINCIPAL, SECURITY_OUTPUT](
def hmacTokenCsrfProtection[T, SECURITY_INPUT, PRINCIPAL, ERROR_OUTPUT, SECURITY_OUTPUT](
checkMode: TapirCsrfCheckMode[T]
)(
body: => PartialServerEndpointWithSecurityOutput[
SECURITY_INPUT,
PRINCIPAL,
Unit,
Unit,
ERROR_OUTPUT,
SECURITY_OUTPUT,
Unit,
Any,
Expand All @@ -36,15 +36,22 @@ trait CsrfEndpoints {
body
}

def hmacTokenCsrfProtectionWithFormOrMultipart[T, SECURITY_INPUT, PRINCIPAL, SECURITY_OUTPUT, F](
def hmacTokenCsrfProtectionWithFormOrMultipart[
T,
SECURITY_INPUT,
PRINCIPAL,
ERROR_OUTPUT,
SECURITY_OUTPUT,
F
](
checkMode: TapirCsrfCheckMode[T],
form: Either[EndpointIO.Body[String, F], EndpointIO.Body[Seq[RawPart], F]]
)(
body: => PartialServerEndpointWithSecurityOutput[
SECURITY_INPUT,
PRINCIPAL,
Unit,
Unit,
ERROR_OUTPUT,
SECURITY_OUTPUT,
Unit,
Any,
Expand All @@ -64,14 +71,14 @@ trait CsrfEndpoints {
body
}

def setNewCsrfToken[T, SECURITY_INPUT, PRINCIPAL, SECURITY_OUTPUT](
def setNewCsrfToken[T, SECURITY_INPUT, PRINCIPAL, ERROR_OUTPUT, SECURITY_OUTPUT](
checkMode: TapirCsrfCheckMode[T]
)(
body: => PartialServerEndpointWithSecurityOutput[
SECURITY_INPUT,
PRINCIPAL,
Unit,
Unit,
ERROR_OUTPUT,
SECURITY_OUTPUT,
Unit,
Any,
Expand All @@ -81,7 +88,7 @@ trait CsrfEndpoints {
SECURITY_INPUT,
PRINCIPAL,
Unit,
Unit,
ERROR_OUTPUT,
(SECURITY_OUTPUT, Option[CookieValueWithMeta]),
Unit,
Any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,55 +32,52 @@ private[session] trait OneOffTapirSession[T] {
header[Option[String]](manager.config.sessionHeaderConfig.sendToClientHeaderName)
}

def setOneOffSession[SECURITY_INPUT, SECURITY_OUTPUT](st: SetSessionTransport)(
def setOneOffSession[SECURITY_INPUT, ERROR_OUTPUT, SECURITY_OUTPUT](st: SetSessionTransport)(
body: => PartialServerEndpointWithSecurityOutput[
SECURITY_INPUT,
Option[T],
Unit,
Unit,
ERROR_OUTPUT,
SECURITY_OUTPUT,
Unit,
Any,
Future
]
): PartialServerEndpointWithSecurityOutput[(SECURITY_INPUT, Seq[Option[String]]), Option[
T
], Unit, Unit, (SECURITY_OUTPUT, Seq[Option[String]]), Unit, Any, Future] =
], Unit, ERROR_OUTPUT, (SECURITY_OUTPUT, Seq[Option[String]]), Unit, Any, Future] =
st match {
case CookieST => setOneOffCookieSession(body)
case HeaderST => setOneOffHeaderSession(body)
}

private[this] def setOneOffSessionLogic(
private[this] def setOneOffSessionLogic[ERROR_OUTPUT](
option: Option[T],
existing: Option[String]
): Either[Unit, Option[String]] =
): Either[ERROR_OUTPUT, Option[String]] =
existing match {
case Some(value) =>
Right(
Some(value)
)
case _ =>
case None =>
option match {
case Some(v) => Right(Some(manager.clientSessionManager.encode(v)))
case _ => Left(())
case _ => Right(None)
}
case some => Right(some)
}

def setOneOffCookieSession[SECURITY_INPUT, SECURITY_OUTPUT](
def setOneOffCookieSession[SECURITY_INPUT, ERROR_OUTPUT, SECURITY_OUTPUT](
body: => PartialServerEndpointWithSecurityOutput[
SECURITY_INPUT,
Option[T],
Unit,
Unit,
ERROR_OUTPUT,
SECURITY_OUTPUT,
Unit,
Any,
Future
]
): PartialServerEndpointWithSecurityOutput[(SECURITY_INPUT, Seq[Option[String]]), Option[
T
], Unit, Unit, (SECURITY_OUTPUT, Seq[Option[String]]), Unit, Any, Future] =
], Unit, ERROR_OUTPUT, (SECURITY_OUTPUT, Seq[Option[String]]), Unit, Any, Future] =
body.endpoint
.securityIn(getSessionFromClientAsCookie.map(Seq(_))(_.head))
.out(body.securityOutput)
Expand All @@ -102,20 +99,20 @@ private[session] trait OneOffTapirSession[T] {
}
}

def setOneOffHeaderSession[SECURITY_INPUT, SECURITY_OUTPUT](
def setOneOffHeaderSession[SECURITY_INPUT, ERROR_OUTPUT, SECURITY_OUTPUT](
body: => PartialServerEndpointWithSecurityOutput[
SECURITY_INPUT,
Option[T],
Unit,
Unit,
ERROR_OUTPUT,
SECURITY_OUTPUT,
Unit,
Any,
Future
]
): PartialServerEndpointWithSecurityOutput[(SECURITY_INPUT, Seq[Option[String]]), Option[
T
], Unit, Unit, (SECURITY_OUTPUT, Seq[Option[String]]), Unit, Any, Future] =
], Unit, ERROR_OUTPUT, (SECURITY_OUTPUT, Seq[Option[String]]), Unit, Any, Future] =
body.endpoint
.securityIn(getSessionFromClientAsHeader.map(Seq(_))(_.head))
.out(body.securityOutput)
Expand Down Expand Up @@ -263,11 +260,11 @@ private[session] trait OneOffTapirSession[T] {
)
}

private[this] def invalidateOneOffSessionLogic[SECURITY_OUTPUT, PRINCIPAL](
private[this] def invalidateOneOffSessionLogic[SECURITY_OUTPUT, PRINCIPAL, ERROR_OUTPUT](
result: (SECURITY_OUTPUT, PRINCIPAL),
maybeCookie: Option[String],
maybeHeader: Option[String]
): Either[Unit, (Seq[Option[String]], PRINCIPAL)] = {
): Either[ERROR_OUTPUT, (Seq[Option[String]], PRINCIPAL)] = {
val principal = result._2
maybeCookie match {
case Some(_) =>
Expand Down Expand Up @@ -303,13 +300,14 @@ private[session] trait OneOffTapirSession[T] {

def invalidateOneOffSession[
SECURITY_INPUT,
PRINCIPAL
PRINCIPAL,
ERROR_OUTPUT
](st: GetSessionTransport)(
body: => PartialServerEndpointWithSecurityOutput[
SECURITY_INPUT,
PRINCIPAL,
Unit,
Unit,
ERROR_OUTPUT,
_,
Unit,
Any,
Expand All @@ -319,7 +317,7 @@ private[session] trait OneOffTapirSession[T] {
(SECURITY_INPUT, Seq[Option[String]]),
PRINCIPAL,
Unit,
Unit,
ERROR_OUTPUT,
Seq[Option[String]],
Unit,
Any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,20 @@ private[session] trait RefreshableTapirSession[T] extends Completion {
header[Option[String]](manager.config.refreshTokenHeaderConfig.sendToClientHeaderName)
}

def setRefreshableSession[SECURITY_INPUT, SECURITY_OUTPUT](st: SetSessionTransport)(
def setRefreshableSession[SECURITY_INPUT, ERROR_OUTPUT, SECURITY_OUTPUT](st: SetSessionTransport)(
body: => PartialServerEndpointWithSecurityOutput[
SECURITY_INPUT,
Option[T],
Unit,
Unit,
ERROR_OUTPUT,
SECURITY_OUTPUT,
Unit,
Any,
Future
]
): PartialServerEndpointWithSecurityOutput[(SECURITY_INPUT, Seq[Option[String]]), Option[
T
], Unit, Unit, (SECURITY_OUTPUT, Seq[Option[String]]), Unit, Any, Future] =
], Unit, ERROR_OUTPUT, (SECURITY_OUTPUT, Seq[Option[String]]), Unit, Any, Future] =
st match {
case CookieST => setRefreshableCookieSession(body)
case HeaderST => setRefreshableHeaderSession(body)
Expand All @@ -67,29 +67,29 @@ private[session] trait RefreshableTapirSession[T] extends Completion {
}
}

def setRefreshableSessionLogic(
def setRefreshableSessionLogic[ERROR_OUTPUT](
option: Option[T],
existing: Option[String]
): Either[Unit, Option[String]] =
): Either[ERROR_OUTPUT, Option[String]] =
option match {
case Some(v) => Right(rotateToken(v, existing))
case _ => Left(())
case _ => Right(None)
}

def setRefreshableCookieSession[SECURITY_INPUT, SECURITY_OUTPUT](
def setRefreshableCookieSession[SECURITY_INPUT, ERROR_OUTPUT, SECURITY_OUTPUT](
body: => PartialServerEndpointWithSecurityOutput[
SECURITY_INPUT,
Option[T],
Unit,
Unit,
ERROR_OUTPUT,
SECURITY_OUTPUT,
Unit,
Any,
Future
]
): PartialServerEndpointWithSecurityOutput[(SECURITY_INPUT, Seq[Option[String]]), Option[
T
], Unit, Unit, (SECURITY_OUTPUT, Seq[Option[String]]), Unit, Any, Future] = {
], Unit, ERROR_OUTPUT, (SECURITY_OUTPUT, Seq[Option[String]]), Unit, Any, Future] = {
val partial =
setOneOffSession(CookieST) {
body
Expand Down Expand Up @@ -125,20 +125,20 @@ private[session] trait RefreshableTapirSession[T] extends Completion {
}
}

def setRefreshableHeaderSession[SECURITY_INPUT, SECURITY_OUTPUT](
def setRefreshableHeaderSession[SECURITY_INPUT, ERROR_OUTPUT, SECURITY_OUTPUT](
body: => PartialServerEndpointWithSecurityOutput[
SECURITY_INPUT,
Option[T],
Unit,
Unit,
ERROR_OUTPUT,
SECURITY_OUTPUT,
Unit,
Any,
Future
]
): PartialServerEndpointWithSecurityOutput[(SECURITY_INPUT, Seq[Option[String]]), Option[
T
], Unit, Unit, (SECURITY_OUTPUT, Seq[Option[String]]), Unit, Any, Future] = {
], Unit, ERROR_OUTPUT, (SECURITY_OUTPUT, Seq[Option[String]]), Unit, Any, Future] = {
val partial = setOneOffSession(HeaderST) {
body
}
Expand Down Expand Up @@ -415,12 +415,12 @@ private[session] trait RefreshableTapirSession[T] extends Completion {
}
}

private[this] def invalidateRefreshableSessionLogic[PRINCIPAL](
private[this] def invalidateRefreshableSessionLogic[PRINCIPAL, ERROR_OUTPUT](
result: (Seq[Option[String]], PRINCIPAL),
cookie: Option[String],
header: Option[String]
): Either[
Nothing,
ERROR_OUTPUT,
(
Seq[Option[String]],
PRINCIPAL
Expand Down Expand Up @@ -449,13 +449,14 @@ private[session] trait RefreshableTapirSession[T] extends Completion {

def invalidateRefreshableSession[
SECURITY_INPUT,
PRINCIPAL
PRINCIPAL,
ERROR_OUTPUT
](st: GetSessionTransport)(
body: => PartialServerEndpointWithSecurityOutput[
SECURITY_INPUT,
PRINCIPAL,
Unit,
Unit,
ERROR_OUTPUT,
_,
Unit,
Any,
Expand All @@ -465,7 +466,7 @@ private[session] trait RefreshableTapirSession[T] extends Completion {
(SECURITY_INPUT, Seq[Option[String]]),
PRINCIPAL,
Unit,
Unit,
ERROR_OUTPUT,
Seq[Option[String]],
Unit,
Any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,32 @@ import scala.concurrent.{ExecutionContext, Future}

trait SessionEndpoints {

def setSessionEndpoint[T, SECURITY_INPUT](
endpoint: => Endpoint[SECURITY_INPUT, Unit, Unit, Unit, Any]
def setSessionEndpoint[T, SECURITY_INPUT, ERROR_OUTPUT](
endpoint: => Endpoint[SECURITY_INPUT, Unit, ERROR_OUTPUT, Unit, Any]
)(implicit
f: SECURITY_INPUT => Option[T]
): PartialServerEndpointWithSecurityOutput[SECURITY_INPUT, Option[
T
], Unit, Unit, Unit, Unit, Any, Future] =
endpoint.serverSecurityLogicSuccessWithOutput(si => Future.successful(((), f(si))))
], Unit, ERROR_OUTPUT, Unit, Unit, Any, Future] =
endpoint
.serverSecurityLogicSuccessWithOutput(si => Future.successful(((), f(si))))

/** Set the session cookie with the session content. The content is signed, optionally encrypted
* and with an optional expiry date.
*
* If refreshable, generates a new token (removing old ones) and stores it in the refresh token
* cookie.
*/
def setSession[T, SECURITY_INPUT, SECURITY_OUTPUT](
def setSession[T, SECURITY_INPUT, SECURITY_OUTPUT, ERROR_OUTPUT](
sc: TapirSessionContinuity[T],
st: SetSessionTransport
)(
body: => PartialServerEndpointWithSecurityOutput[SECURITY_INPUT, Option[
T
], Unit, Unit, SECURITY_OUTPUT, Unit, Any, Future]
], Unit, ERROR_OUTPUT, SECURITY_OUTPUT, Unit, Any, Future]
): PartialServerEndpointWithSecurityOutput[(SECURITY_INPUT, Seq[Option[String]]), Option[
T
], Unit, Unit, (SECURITY_OUTPUT, Seq[Option[String]]), Unit, Any, Future] =
], Unit, ERROR_OUTPUT, (SECURITY_OUTPUT, Seq[Option[String]]), Unit, Any, Future] =
sc.setSession(st)(body)

def setSessionWithAuth[T, A](sc: TapirSessionContinuity[T], st: SetSessionTransport)(
Expand All @@ -53,7 +54,7 @@ trait SessionEndpoints {
Any,
Future
] =
setSession[T, A, Unit](sc, st) {
setSession[T, A, Unit, Unit](sc, st) {
setSessionEndpoint {
endpoint.securityIn(auth)
}
Expand Down Expand Up @@ -81,15 +82,15 @@ trait SessionEndpoints {
* Note that you should use `refreshable` if you use refreshable systems even only for some
* users.
*/
def invalidateSession[T, SECURITY_INPUT, PRINCIPAL](
def invalidateSession[T, SECURITY_INPUT, PRINCIPAL, ERROR_OUTPUT](
sc: TapirSessionContinuity[T],
st: GetSessionTransport
)(
body: => PartialServerEndpointWithSecurityOutput[
SECURITY_INPUT,
PRINCIPAL,
Unit,
Unit,
ERROR_OUTPUT,
_,
Unit,
Any,
Expand All @@ -99,7 +100,7 @@ trait SessionEndpoints {
(SECURITY_INPUT, Seq[Option[String]]),
PRINCIPAL,
Unit,
Unit,
ERROR_OUTPUT,
Seq[Option[String]],
Unit,
Any,
Expand Down
Loading

0 comments on commit 05f52c1

Please sign in to comment.