Skip to content

Commit

Permalink
Move non-error related code to TapirController rather than `TapirEr…
Browse files Browse the repository at this point in the history
…rorHelpers`
  • Loading branch information
jnatten committed Oct 9, 2024
1 parent 9bed3ca commit 29ed0a7
Show file tree
Hide file tree
Showing 15 changed files with 52 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ trait InternController {
val internController: InternController

class InternController extends TapirController with StrictLogging {
import ErrorHelpers._

override val prefix: EndpointInput[Unit] = "intern"
override val enableSwagger = false
private val stringInternalServerError = statusCode(StatusCode.InternalServerError).and(stringBody)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,8 @@ trait SeriesController {
this: ReadService & WriteService & SeriesSearchService & SearchConverterService & ConverterService & Props &
ErrorHelpers & TapirController =>
val seriesController: SeriesController
class SeriesController() extends TapirController {

import ErrorHelpers._
import props._
class SeriesController extends TapirController {
import props.*

private val queryString = query[Option[String]]("query")
.description("Return only results with titles or tags matching the specified query.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ trait DraftConceptController {
}
}
import ConceptControllerHelpers._
import ErrorHelpers._

def getConceptById: ServerEndpoint[Any, Eff] = endpoint.get
.summary("Show concept with a specified id")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ trait InternController {
val internController: InternController

class InternController extends TapirController {
import ErrorHelpers._

override val prefix: EndpointInput[Unit] = "intern"
override val enableSwagger = false

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ trait PublishedConceptController {

class PublishedConceptController extends TapirController {
import ConceptControllerHelpers._
import ErrorHelpers._
import props._

override val serviceName: String = "concepts"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,6 @@ trait DraftController {
}
}

import ErrorHelpers._

def getTagSearch: ServerEndpoint[Any, Eff] = endpoint.get
.summary("Retrieves a list of all previously used tags in articles")
.description("Retrieves a list of all previously used tags in articles")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ trait FileController {
val fileController: FileController

class FileController extends TapirController {
import ErrorHelpers._
override val serviceName: String = "files"
override val prefix: EndpointInput[Unit] = "draft-api" / "v1" / serviceName

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ trait InternController {

class InternController extends TapirController with StrictLogging {
import props.{DraftSearchIndex, DraftTagSearchIndex, DraftGrepCodesSearchIndex}
import ErrorHelpers._

override val prefix: EndpointInput[Unit] = "intern"
override val enableSwagger = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ trait UserDataController {
val userDataController: UserDataController

class UserDataController extends TapirController {
import ErrorHelpers._
override val serviceName: String = "user-data"
override val prefix: EndpointInput[Unit] = "draft-api" / "v1" / serviceName

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ trait FilmPageController {
val filmPageController: FilmPageController

class FilmPageController extends TapirController {
import ErrorHelpers._

override val serviceName: String = "filmfrontpage"
override val prefix: EndpointInput[Unit] = "frontpage-api" / "v1" / serviceName
override val endpoints: List[ServerEndpoint[Any, Eff]] = List(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ trait FrontPageController {
override val serviceName: String = "frontpage"
override val prefix: EndpointInput[Unit] = "frontpage-api" / "v1" / serviceName

import ErrorHelpers._

def getFrontPage: ServerEndpoint[Any, Eff] = endpoint.get
.summary("Get data to display on the front page")
.out(jsonBody[FrontPage])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ trait SubjectPageController {
override val serviceName: String = "subjectpage"
override val prefix: EndpointInput[Unit] = "frontpage-api" / "v1" / serviceName

import ErrorHelpers._

def getAllSubjectPages: ServerEndpoint[Any, Eff] = endpoint.get
.summary("Fetch all subjectpages")
.in(query[Int]("page").default(1).validate(Validator.min(1)))
Expand All @@ -45,9 +43,7 @@ trait SubjectPageController {
.errorOut(errorOutputsFor(400, 404))
.out(jsonBody[List[SubjectPageData]])
.serverLogicPure { case (page, pageSize, language, fallback) =>
readService
.subjectPages(page, pageSize, language, fallback)

readService.subjectPages(page, pageSize, language, fallback)
}

def getSingleSubjectPage: ServerEndpoint[Any, Eff] = endpoint.get
Expand All @@ -58,9 +54,7 @@ trait SubjectPageController {
.out(jsonBody[SubjectPageData])
.errorOut(errorOutputsFor(400, 404))
.serverLogicPure { case (id, language, fallback) =>
readService
.subjectPage(id, language, fallback)

readService.subjectPage(id, language, fallback)
}

def getSubjectPagesByIds: ServerEndpoint[Any, Eff] = endpoint.get
Expand All @@ -76,8 +70,7 @@ trait SubjectPageController {
.serverLogicPure { case (ids, language, fallback, pageSize, page) =>
val parsedPageSize = if (pageSize < 1) props.DefaultPageSize else pageSize
val parsedPage = if (page < 1) 1 else page
readService
.getSubjectPageByIds(ids.values, language, fallback, parsedPageSize, parsedPage)
readService.getSubjectPageByIds(ids.values, language, fallback, parsedPageSize, parsedPage)

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ trait ConfigController {

override protected val prefix: EndpointInput[Unit] = "myndla-api" / "v1" / serviceName

import ErrorHelpers._

val pathConfigKey: EndpointInput.PathCapture[ConfigKey] =
path[ConfigKey]("config-key")
.description(s"The of configuration value. Can only be one of '${ConfigKey.all.mkString("', '")}'")
Expand Down
48 changes: 46 additions & 2 deletions network/src/main/scala/no/ndla/network/tapir/TapirController.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,21 @@
*/
package no.ndla.network.tapir

import cats.implicits.catsSyntaxEitherId
import com.typesafe.scalalogging.StrictLogging
import io.circe.{Decoder, Encoder}
import no.ndla.common.Clock
import no.ndla.common.configuration.HasBaseProps
import no.ndla.network.tapir.auth.{Permission, TokenUser}
import sttp.client3.Identity
import sttp.model.StatusCode
import sttp.monad.MonadError
import sttp.tapir.*
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.{PartialServerEndpoint, ServerEndpoint}
import no.ndla.network.tapir.NoNullJsonPrinter.jsonBody

trait TapirController {
this: HasBaseProps & Clock =>
this: HasBaseProps & Clock & TapirErrorHelpers =>
trait TapirController extends StrictLogging {
type Eff[A] = Identity[A]
val enableSwagger: Boolean = true
Expand All @@ -31,6 +37,44 @@ trait TapirController {
)
})
}

/** Helper to simplify returning _both_ NoContent and some json body T from an endpoint */
def noContentOrBodyOutput[T: Encoder: Decoder: Schema]: EndpointOutput.OneOf[Option[T], Option[T]] =
oneOf[Option[T]](
oneOfVariantValueMatcher(statusCode(StatusCode.Ok).and(jsonBody[Option[T]])) { case Some(_) => true },
oneOfVariantValueMatcher(statusCode(StatusCode.NoContent).and(emptyOutputAs[Option[T]](None))) { case None =>
true
}
)

/** Helper function that returns function one can pass to `serverSecurityLogicPure` to require a specific scope for
* some endpoint.
*/
def requireScope(scope: Permission*): Option[TokenUser] => Either[AllErrors, TokenUser] = {
case Some(user) if user.hasPermissions(scope) => user.asRight
case Some(_) => ErrorHelpers.forbidden.asLeft
case None => ErrorHelpers.unauthorized.asLeft
}

implicit class authlessEndpoint[A, I, E, O, R](self: Endpoint[Unit, I, AllErrors, O, R]) {
def requirePermission[F[_]](
requiredPermission: Permission*
): PartialServerEndpoint[Option[TokenUser], TokenUser, I, AllErrors, O, R, F] = {
val newEndpoint = self.securityIn(TokenUser.oauth2Input(requiredPermission))
val authFunc = requireScope(requiredPermission *)
val securityLogic = (m: MonadError[F]) => (a: Option[TokenUser]) => m.unit(authFunc(a))
PartialServerEndpoint(newEndpoint, securityLogic)
}
}

implicit class authlessErrorlessEndpoint[A, I, E, O, R, X](self: Endpoint[Unit, I, X, O, R]) {
def withOptionalUser[F[_]]: PartialServerEndpoint[Option[TokenUser], Option[TokenUser], I, X, O, R, F] = {
val newEndpoint = self.securityIn(TokenUser.oauth2Input(Seq.empty))
val authFunc = (tokenUser: Option[TokenUser]) => Right(tokenUser): Either[X, Option[TokenUser]]
val securityLogic = (m: MonadError[F]) => (a: Option[TokenUser]) => m.unit(authFunc(a))
PartialServerEndpoint(newEndpoint, securityLogic)
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,14 @@ package no.ndla.network.tapir

import cats.implicits.*
import com.typesafe.scalalogging.StrictLogging
import io.circe.{Decoder, Encoder}
import no.ndla.common.Clock
import no.ndla.common.configuration.HasBaseProps
import no.ndla.common.errors.ValidationException
import no.ndla.network.tapir.NoNullJsonPrinter.jsonBody
import no.ndla.network.tapir.auth.{Permission, TokenUser}
import sttp.model.StatusCode
import sttp.monad.MonadError
import sttp.tapir.server.PartialServerEndpoint
import sttp.tapir.{Endpoint, EndpointOutput, Schema, emptyOutputAs, oneOf, oneOfVariantValueMatcher, statusCode}

import scala.util.{Failure, Success, Try}

trait TapirErrorHelpers extends StrictLogging {
this: HasBaseProps with Clock =>
this: HasBaseProps & Clock =>

object ErrorHelpers {
val GENERIC = "GENERIC"
Expand Down Expand Up @@ -86,44 +79,6 @@ trait TapirErrorHelpers extends StrictLogging {
def errorBody(code: String, description: String, statusCode: Int): ErrorBody =
ErrorBody(code, description, clock.now(), statusCode)

/** Helper function that returns function one can pass to `serverSecurityLogicPure` to require a specific scope for
* some endpoint.
*/
def requireScope(scope: Permission*): Option[TokenUser] => Either[AllErrors, TokenUser] = {
case Some(user) if user.hasPermissions(scope) => user.asRight
case Some(_) => ErrorHelpers.forbidden.asLeft
case None => ErrorHelpers.unauthorized.asLeft
}

/** Helper to simplify returning _both_ NoContent and some json body T from an endpoint */
def noContentOrBodyOutput[T: Encoder: Decoder: Schema]: EndpointOutput.OneOf[Option[T], Option[T]] =
oneOf[Option[T]](
oneOfVariantValueMatcher(statusCode(StatusCode.Ok).and(jsonBody[Option[T]])) { case Some(_) => true },
oneOfVariantValueMatcher(statusCode(StatusCode.NoContent).and(emptyOutputAs[Option[T]](None))) { case None =>
true
}
)

implicit class authlessEndpoint[A, I, E, O, R](self: Endpoint[Unit, I, AllErrors, O, R]) {
def requirePermission[F[_]](
requiredPermission: Permission*
): PartialServerEndpoint[Option[TokenUser], TokenUser, I, AllErrors, O, R, F] = {
val newEndpoint = self.securityIn(TokenUser.oauth2Input(requiredPermission))
val authFunc = ErrorHelpers.requireScope(requiredPermission: _*)
val securityLogic = (m: MonadError[F]) => (a: Option[TokenUser]) => m.unit(authFunc(a))

PartialServerEndpoint(newEndpoint, securityLogic)
}
}

implicit class authlessErrorlessEndpoint[A, I, E, O, R, X](self: Endpoint[Unit, I, X, O, R]) {
def withOptionalUser[F[_]]: PartialServerEndpoint[Option[TokenUser], Option[TokenUser], I, X, O, R, F] = {
val newEndpoint = self.securityIn(TokenUser.oauth2Input(Seq.empty))
val authFunc = (tokenUser: Option[TokenUser]) => Right(tokenUser): Either[X, Option[TokenUser]]
val securityLogic = (m: MonadError[F]) => (a: Option[TokenUser]) => m.unit(authFunc(a))
PartialServerEndpoint(newEndpoint, securityLogic)
}
}
}

def logError(e: Throwable): Unit = {
Expand Down

0 comments on commit 29ed0a7

Please sign in to comment.