diff --git a/.coveragerc b/.coveragerc index 93f230678..2c4a8c84d 100644 --- a/.coveragerc +++ b/.coveragerc @@ -8,10 +8,11 @@ omit = src/klein/test/typing_*.py source= src/klein .tox/*/lib/python*/site-packages/klein + .tox/*/lib/pypy*/site-packages/klein .tox/*/Lib/site-packages/klein - .tox/pypy*/site-packages/klein [report] exclude_lines = pragma: no cover if TYPE_CHECKING: + \s*\.\.\.$ diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml index 9762d018f..5ebcd6312 100644 --- a/.github/workflows/cicd.yml +++ b/.github/workflows/cicd.yml @@ -194,7 +194,7 @@ jobs: matrix: os: ["ubuntu-latest"] python-version: ["3.7", "3.9", "3.10", "3.11"] - twisted-version: ["21.2", "22.1", "23.8"] + twisted-version: ["21.2", "22.1", "23.8", "23.10"] tox-prefix: ["coverage"] optional: [false] include: @@ -225,7 +225,9 @@ jobs: - name: Install Python uses: actions/setup-python@v4 with: - python-version: ${{ matrix.python-version }} + python-version: | + ${{ matrix.python-version }} + 3.11 - name: System Python Information uses: twisted/python-info-action@v1 @@ -264,6 +266,16 @@ jobs: - name: Run unit tests run: tox run -e ${TOX_ENV} + - name: Combine coverage + run: tox run -e coverage_combine,coverage_report + if: ${{ matrix.tox-prefix == 'coverage' }} + + - name: Upload Coverage XML + uses: actions/upload-artifact@v3 + with: + name: coverage-debug + path: coverage.xml + - name: Upload Trial log artifact if: ${{ failure() }} uses: actions/upload-artifact@v3 diff --git a/MANIFEST.in b/MANIFEST.in index 3e64db125..fde03f3a3 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -20,3 +20,4 @@ recursive-include docs *.rst recursive-include docs *.txt recursive-exclude docs/_build * recursive-exclude requirements * +recursive-include src *.sql diff --git a/docs/index.rst b/docs/index.rst index 980b12c0c..c74881b3e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -19,6 +19,8 @@ This is an introduction to Klein, going through from creating a simple web site introduction/1-gettingstarted introduction/2-twistdtap + introduction/3-forms + introduction/4-auth Klein Examples diff --git a/docs/introduction/1-gettingstarted.rst b/docs/introduction/1-gettingstarted.rst index 0fc1117ce..99650d1c2 100644 --- a/docs/introduction/1-gettingstarted.rst +++ b/docs/introduction/1-gettingstarted.rst @@ -103,6 +103,8 @@ If you run this example and then visit ``http://localhost:8080/``, you will get Streamlined Apps With HTML and JSON =================================== +.. _htmljson: + For a typical web application, the first order of business is generating some simple HTML pages that users can interact with and that search engines can easily index. diff --git a/docs/introduction/3-forms.rst b/docs/introduction/3-forms.rst new file mode 100644 index 000000000..be5fefe13 --- /dev/null +++ b/docs/introduction/3-forms.rst @@ -0,0 +1,332 @@ + +=================== +Handling Form Input +=================== + +Introduction +------------ + +In :ref:`“Streamlined Apps With HTML and JSON” ` we set up a basic +site that could render HTML and read data. However, for most applications, you +will need some way for users to **input** data; in other words: handling forms, +both rendering them and posting them. + +In order to handle HTML forms from the browser `securely +`_, we also have to implement +some form of authenticated session along with them. + +So let's build on top of our food-list application by letting users submit a +form that adds some foods to a list. + +Our example here will be a very simple app, where you type in the name of a +food and give it a star rating. To begin, it'll be entirely anonymous. + +If you want full, runnable examples, you can find them `in the Klein repository +on Github +`_ + +Configuration and Setup +----------------------- + +In order to provide a realistic example that actually stores state, we'll also +use Klein's integrated database access system, and simple account/session +storage with username and password authentication. However, there are +documented interfaces between each of these layers (storage, sessions, +accounts), and your application can supply its own account or session storage +as your needs for authentication evolve. But before we get into +authentication, let's get a basic system for processing forms and storing data +set up. + +To configure our system we will set up a few things: + +- First, we will adapt the synchronous ``sqlite3`` database driver to an + asynchronous one. + +- Next, we will build a *session procurer*, which is what will retrieve our + sessions from the configured database. + +- Then, we will set up a ``Requirer``, which is how each of our routes will + tell the authorization and forms systems what values our routes require to + execute. + +- Finally, we will set up a *prerequisite requirement*, a thing that all routes + in our application require, of an ``ISession``. We hook this up to our + ``Requirer`` using the ``requirer.prerequisite()`` decorator. + +.. literalinclude:: codeexamples/foodwiki/anon/foodwiki_config.py + +We'll also need some HTML templating set up to style our pages. Using what we +learned about Plating, we'll set up a basic page, use the ``fragment`` +convenience decorator to make a widget for consistently displaying a food in +the HTML UI. + +.. literalinclude:: codeexamples/foodwiki/anon/foodwiki_templates.py + +.. note:: + + ``@Plating.fragment`` functions are invoked once at the time they are + decorated, with each of their arguments being a ``slot`` object, **not** the + type that it's they're declared to have; the only thing you should do in the + body of these functions is construct a ``Tag`` object that serves as a + fragment of your resulting template. This can be a little confusing at + first, but it allows you to have a nice type-checked interface to ensure + that you're always passing the correct slots to them later. + + +Database Access with ``dbxs`` +----------------------------- + +You may have noticed that in the configuration above, we constructed our +``SQLSessionProcurer`` with a list of *authorizers*. An authorizer is a +function that can look at a database and determine if a user is authorized to +perform a task, so now we will implement the interaction with the database. + +We will use Klein's built-in lightweight asynchronous database access system, +``dbxs``, allows you to keep your queries organized and construct simple +classes from your query results, without bringing in the overhead of an ORM or +query builder. If you know SQL and you know basic Python data structures, you +allmost know how to use it already. + +First let's get started with a very basic schema; a 'food' with a name and a +rating: + +.. literalinclude:: codeexamples/foodwiki/anon/foodwiki_db.py + :lines: 14-19 + +Next, a function to apply that schema, along with Klein's own basic account & +session schema with ``session`` and ``account`` tables: + +.. literalinclude:: codeexamples/foodwiki/anon/foodwiki_db.py + :lines: 11-11 + +.. literalinclude:: codeexamples/foodwiki/anon/foodwiki_db.py + :pyobject: applySchema + +Now, let's define our basic data structure to correspond to that table: + +.. literalinclude:: codeexamples/foodwiki/anon/foodwiki_db.py + :pyobject: FoodRating + +And now we will use ``dbxs`` to specify what queries we're going to make +against that schema. + +.. literalinclude:: codeexamples/foodwiki/anon/foodwiki_db.py + :lines: 5 + +.. literalinclude:: codeexamples/foodwiki/anon/foodwiki_db.py + :pyobject: RatingsDB + +Here, we have defined a ``typing.Protocol`` whowse methods are all awaitable or +async iterables decorated with ``@query`` (for SQL expressions that we expect +results for) or ``@statement`` for those which we expect to have side effects +but not return values. We have one read operation, ``allRatings``, that gives +us all the ratings in the database, and ``addRating`` which adds a rating. All +the argument types for these methods must be things you can pass to the +database, and they are supplied to the query via the curly-braced format +specifiers included in the SQL string, whose names match the parameters +specified in your Python function arguments. + +While ``@statement`` returns no values, ``@query`` needs to know how to +interpret its query results, and it does this via its ``load`` argument. If +you pass ``load=many(YourCallable)``, the decorated function must return an +``AsyncIterable`` of ``YourCallable``'s return type. The callable itself takes +an ``AsyncConnection`` as its first argument, and the columns of the query's +results as the rest of the arguments. Here, we know that ``select name, +rating`` matches up with ``FoodRating``'s dataclass arguments, ``name: str`` +and ``rating: int``. + +If you have a query that you know should only ever return a single value, you +can use ``load=one(YourType)`` and the return type should be +``Awaitable[YourType]``, or for one-or-zero results you can use +``load=maybe(YourType)`` which should return ``Awaitable[YourType | None]``. + +These decorators provide information, but a ``Protocol`` is an abstract type; +it can't actually **do** anything on its own. We need to somehow transform an +``AsyncConnection`` into something that looks like this type and executes these +queries, and for that we use ``accessor``, which converts our ``RatingsDB`` +protocol into a callable that *takes* an ``AsyncConnection`` and *returns* an +instance of ``RatingsDB`` that can execute all those queries. + +This system will help you out by performing a few basic checks. At type-check +time, ``mypy`` will make sure that your return types correspond with the loader +type (``one``, ``many``, ``maybe``) that you've specified. At import time, you +will get an exception if the arguments specified in your function signatures +are not used in your queries, or if the queries use arguments you didn't +provide. However, you will need to verify that the SQL itself is valid; we'll +cover that in a later section on testing. + +Creating an Authorizer +---------------------- + +Now that we've got a basic data-access layer in place, let's put some access +control in place. For this simple anonymous site, the access control is pretty +lenient; everyone should bea uthorized to access these methods all the time. +However, given that we'll want to restrict that a bit in the future, we can't +use our new data-access ``RatingsDB`` ``Protocol`` directly, so we will declare +a new class. For this example it will simply forward all the methods on: + +.. literalinclude:: codeexamples/foodwiki/anon/foodwiki_db.py + :pyobject: FoodRater + +But then we will also declare an **authorizer** for it, so that Klein knows how +to determine if a user has access to it in a particular route that needs it: + +.. literalinclude:: codeexamples/foodwiki/anon/foodwiki_db.py + :pyobject: authorizeFoodRater + +SQL authorizers are passed a ``dbxs`` ``AsyncConnection``, a session store, and +the user's current session. They can then do any queries necessary to +determine if a user is authorized, and return ``None`` if they're not, which is +why we declare that we return an ``Optional[FoodRater]``, reserving the right +that we may want to return ``None`` later. However, for the time being, we use +``accessRatings`` to convert our database connection into a ``RatingsDB``, then +pass it to our ``FoodRater`` so that all sessions have access to this +functionality if they need it; no queries required just yet. + +Finally, we can build the list of authorizers that we used in the configuration +above: + +.. literalinclude:: codeexamples/foodwiki/anon/foodwiki_db.py + :lines: 70 + +Now that everything is set up, let's move on to our main application and +declare some routes! + +Handling Form Fields with Requirer +---------------------------------- + +For this quick, anonymous version of the application, let's first set up a +route to rate foods, which will serve as our form. + +.. literalinclude:: codeexamples/foodwiki/anon/foodwiki_routes.py + :pyobject: postHandler + +As before, you can see we've wrapped the ``Plating.routed`` decorator around +our klein app's ``route`` method decorator, which takes care of templating and +handling our return value as the mapping of slot names to slot values. +However, we are now wrapping an additional decorator around the ``routed`` +decorator: ``require``. + +As its first (positional) argument, ``Requirer.require`` takes a decorator, +something that expects to receive a Twisted HTTP request object (or ``self``, +then, a request, if you're using a ``Klein`` bound to an instance) as its first +argument, as well as any additional arguments. So you can use +``Plating.routed`` or ``Klein.route`` here, depending on whether your +application requires HTML templating or not. + +.. note:: + + ``Requirer.require`` *consumes* the request that it is given, so you can't + access it any more. The idea here is that interacting with ``Request`` + directly is a low-level way of expressing what values you require from the + request, and ``Requirer`` is trying to provide a high-level way to get those + requirements, where you've expressed the things you need and your route is + not even invoked if they can't be retrieved. If you need data from the + request that is not exposed by Klein, you can implement your own + ``IRequiredParameter`` to take the request and supply whatever value you + require. + +Next, it takes a set of keyword arguments. Each argument corresponds to an +argument taken by the decorated function, and is an ``IRequiredParameter`` +which describes what will be passed and how it will be fetched from either the +request in the database. + +In simpler terms, in code like this:: + + @requirer.require(..., something=SomeRequiredParameter()) + def routeHandler(something: SomeRequiredParameterType): + ... + +What is happening is that ``routeHandler`` is saying to Klein, "I take a +parameter called ``something``, which ``SomeRequiredParameter`` knows how to +supply". + +In our ``postHandler`` example above, ``require`` is given instructions to pass +3 relevant parameters to ``postHandler``. Let's look at the first two: + +1. ``name``, which is text form field +2. ``rating``, which an integer form field with a value between 1 and 5 + +The first two values here are fairly simple; ``klein.Field`` declares that +they'll be extracted from a form POST in multipart/form-data or JSON formats, +it will validate them, and then pass them along to ``postHandler`` as arguments +assuming everything looks correct. + +Using the authorizer we created with ``Authorization`` +------------------------------------------------------ + +The third requirement is ``foodRater``, which is *a request to authorize the +current session* to access a ``FoodRater`` object, using ``Authorization`` . + +Remember that ``@authorizerFor(FoodRater)`` function that we wrote before? It +pulls the ``ISession`` implementation from our ``ISession`` prerequisite, +checks if the user is authorized for ``FoodRater``, then passes the created +object along to us. In other words, to use this route, *an authorization for a +food rater is required*. + +Finally, our implementation job is very simple here. We call the ``rateFood`` +method on the ``FoodRater`` we have been passed, then format some outputs for +our template, including a synthetic redirect to send the user back over to +``/`` to look at the rating list after the form post is processed. + +Rendering an HTML form with ``Form.rendererFor`` +------------------------------------------------ + +It might have seemed slightly odd to describe the *handler* for a form before +we've even drawn the form itself, but the idea behind this is that you think +first about what you want to do with the form, what values are required, and +then the description of those values serves as the description of the form +itself. So now that we have a function decorated with ``@requirer.require`` +that takes some ``klein.Field`` parameters, we can get a renderable form out of +it, to render on the front page. + +.. literalinclude:: codeexamples/foodwiki/anon/foodwiki_routes.py + :pyobject: frontPage + +In the ``GET`` route for ``/``, we do not require any other ``Field``\ s, but we +still require the ``FoodRater`` authorization in order to use its +``allRatings`` method. Once again, we ask for it via an ``IRequiredParameter`` +passed to ``require``, by calling +``klein.Form.rendererFor(theRouteWithRequiredFields)``, which will pass along a +``RenderableForm`` object, that can be dropped into a slot in a ``Plating`` +template. + +Here, we also use the ``food`` fragment that we declared before in our template +module, which allows us to embed more complex template fragments into a +list-item slot. + +Handling Validation Errors with ``Form.onValidationFailureFor`` +--------------------------------------------------------------- + +Next, we need to do something about validation failures. We don't want our +users to see a generic error message (or worse, a traceback) when something +doesn't validate, and we'd like Klein to be able to communicate the nature of +the validation issue on a per-field basis. To do that, we use +``Form.onValidationFailureFor(theRouteWithRequiredFields)``. This decorator +functions similarly to ``app.route``, as it also handles a URL, although +*which* URL it's handling depends on the post-handling route it is wrapping. + +.. literalinclude:: codeexamples/foodwiki/anon/foodwiki_routes.py + :pyobject: validationFailed + +This route defines a template and logic to use to render the form-validation +failure on ``/rate-food``. By using ``page.routed``, we ensure that the +template used is not a generic placeholder default for the form being handled, +but contains all the relevant decorations for our page template. + +Putting it all together +----------------------- + +.. literalinclude:: codeexamples/foodwiki/anon/foodwiki_routes.py + :lines: 70-81 + +Finally, we ensure that the database schema is applied, and we start our server +up using ``app.run`` as usual. You should be able to start up a server and see +the example food-rating app there, post a form, try to post negative stars or +more than five stars, see the validation either fail or succeed depending on +those values, and see all the ratings you've put into the system. + +Next up, we will cover a modified version of this application that shows you +how to implement a signup form, a login form, and actually leverage the power +of an ``Authorizer`` when authorization is not available to every +unauthenticated user. diff --git a/docs/introduction/4-auth.rst b/docs/introduction/4-auth.rst new file mode 100644 index 000000000..f8665034c --- /dev/null +++ b/docs/introduction/4-auth.rst @@ -0,0 +1,174 @@ + +Authentication and Authorization +================================ + +Now that we can handle forms and sessions, let's build on that to build a +website with signup and login forms. + +We'll build on our food-rating wiki example, and modify it to have to have user +accounts. Let's begin by changing our schema to include the user who posted +the rating. + +.. literalinclude:: codeexamples/foodwiki/auth/foodwiki_db.py + :lines: 21-30 + +We are adding a ``rated_by`` column with a foreign key constraint on +``account``. But where did ``account`` come from? + +In order to do anything useful within a database with authentication, we need +to be able to relate to the account and session tables, so they are considered +part of Klein's public API. For reference, here is that full schema: + +.. literalinclude:: ../../src/klein/storage/sql/basic_auth_schema.sql + :language: sql + +Next, we'll need to split up our database interface. Previously, we only had +one authorized object, for all clients. However, now we have two classes of +client: those logged in to an account, and those not logged in to an account. +We only want to allow ratings for those who have signed up and logged in. + +On the front page, we will want to display a bunch of ratings by different +users, with links to their user pages. So we will need a new ``NamedRating`` +class which combines the rating with a username rather than an account ID, to +make the presentation of the URLs nice; we don't want them to include the +opaque blobs used for account IDs. We'll also need a query to build those. +So, here are our new queries; we need one for just the top 10 ratings, and then +one that gives us all the ratings by a given user: + +.. literalinclude:: codeexamples/foodwiki/auth/foodwiki_db.py + :pyobject: PublicRatingsDB + +Next, we will need our *private* queries interface, the one you only get if +you're logged in. + +.. literalinclude:: codeexamples/foodwiki/auth/foodwiki_db.py + :pyobject: RatingsDB + +Similar to before, we have an authorizer that allows everyone access to the +public ratings: + +.. literalinclude:: codeexamples/foodwiki/auth/foodwiki_db.py + :pyobject: RatingsViewer + +.. literalinclude:: codeexamples/foodwiki/auth/foodwiki_db.py + :pyobject: authorizeRatingsViewer + +But now, we have the slight additional complexity of *conditional* +authorization. Our authenticated-user authorization, ``FoodCritic``, needs to +return ``None`` from its authorizer if you're not logged in: + +.. literalinclude:: codeexamples/foodwiki/auth/foodwiki_db.py + :pyobject: FoodCritic + +.. literalinclude:: codeexamples/foodwiki/auth/foodwiki_db.py + :pyobject: authorizeFoodCritic + +``SQLSessionProcurer`` provides built-in authorizers for Klein's built-in +account functionality, ``ISimpleAccountBinding`` and ``ISimpleAccount``. So +here we ask the session to authorize us an ``ISimpleAccountBinding`` to see +which accounts our session is bound to. If it we find one, then we can return +a ``FoodCritic`` wrapped around it; the ``FoodCritic`` remembers its user and +performs all its operations with that account ID. If we can't, then we return +``None``. + +.. note:: + + The interfaces for ``ISimpleAccount`` and ``ISimpleAccountBinding`` begin + with the word "simple" because Klein's built-in account system is + deliberately simplistic. It is intended to be easy to get started with and + suitable for light production workloads, but is not intended to be an + all-encompassing way that all Klein applications should perform their + account management; not all systems have usernames, not all systems have + passwords, and not all systems use a relational database. + + If you have your own existing datastore, your own way of accessing your + RDBMS, or your own authentication system, you will want to look into + implementing your own version of the ``ISessionStore`` and ``ISession`` + interfaces; in particular ``ISession.authorize`` is the back-end for + ``Authorization``. Once you have one, you can set up your ``ISession`` + prerequisite to use ``SessionProcurer`` with your own ``ISessionStore``, and + all the route-level logic ought to look similar, modulo whatever access + pattern your data store requires. + +So now our database and model supports our new authenticated/unauthenticated +distinction. But this doesn't do us any good if we can't sign up for the site, +or log in to it. So let's make some routes that can do just that: + +.. literalinclude:: codeexamples/foodwiki/auth/foodwiki_auth_routes.py + :pyobject: signup + +.. literalinclude:: codeexamples/foodwiki/auth/foodwiki_auth_routes.py + :pyobject: showSignup + +We have another form following the example set in the previous section. +``signup`` presents a form with a username and 2 password fields. Ensuring +that those fields match is left as an exercise for the reader, but we request +an ``Authorization`` for ``ISimpleAccountBinding``. Once again, this +authorizer is built in to the SQL session store and is available to any user. +We create an account and send the user over to ``/login``. Then we render the +form in the same way as any other form, with ``Form.rendererFor``. + +Having successfully signed up, now we need to log in. + +.. literalinclude:: codeexamples/foodwiki/auth/foodwiki_auth_routes.py + :pyobject: login + +.. literalinclude:: codeexamples/foodwiki/auth/foodwiki_auth_routes.py + :pyobject: loginForm + +Our login form looks a lot like our signup form, but instead calls +``bindIfCredentialsMatch`` with the username/password credentials that we've +received. This returns the bound account if the credentials match, but +``None`` otherwise. Finally, we need a way to log out as well: + +.. literalinclude:: codeexamples/foodwiki/auth/foodwiki_auth_routes.py + :pyobject: logout + +Here we demonstrate customizing the text on the submit button for the form, +since we need *some* field to indicate this is indeed a form post processor; +including an explicit “submit” field is how you mark an effectively no-argument +form as a POSTable form route. Plus, it wouldn't make sense for the rendered +button to say “submit” with no context; “log out” makes a lot more sense. + +.. note:: + + If you want to interact with a session store directly in, i.e. an + administrative command line tool rather than a Klein route, you can + instantiate a ``klein.storage.sql.SessionStore`` directly with an + ``AsyncConnection``, rather than using ``SQLSessionProcurer``, which needs + an HTTP request. + +That's sign-up, login, and logout handled. Now we need to change the way that +our application routes actually handle authorization to deal with our new +logged-in/logged-out split. First, let's look at our food-rating post handler: + +.. literalinclude:: codeexamples/foodwiki/auth/foodwiki_routes.py + :pyobject: notLoggedIn + +.. literalinclude:: codeexamples/foodwiki/auth/foodwiki_routes.py + :pyobject: postHandler + +Not much has changed here; we still have an ``Authorization`` that requests a +``FoodCritic`` and calls a method on it. The only difference here is that +*this method will no longer be called* if the user is not logged in; instead, +the resource specified by ``whenDenied`` - in other words, the simple templated +page from ``notLoggedIn`` - will be displayed. + +But surely we don't even want to *show* the form to the user if they're not +logged in, right? Just the top ratings, with the option to log in. How can we +accomplish that? We don't want the presence of an ``Authorization`` requesting +the ``FoodCritic`` on the front page to simply *fail* and show the user an +error, that would be a pretty annoying user experience. What we use here is an +``Authorization`` with ``required=False`` ; that will give us a conditional +authorization that passes ``None`` if it cannot be authorized, so we take a +``FoodCritic | None`` as our parameter, like so: + +.. literalinclude:: codeexamples/foodwiki/auth/foodwiki_routes.py + :pyobject: frontPage + +We require an ``Authorization`` for a ``FoodCritic`` conditionally, but we +require ``RatingsViewer`` unconditionally, mirroring the way the page is +actually displayed. We want to see the top ratings regardless, but the form +only when we're logged in. Note that our ``topRatings`` method is now giving +us ``NamedRating`` objects, and thus we use a new ``linkedFood`` fragment to +display them with a hyperlink. diff --git a/docs/introduction/codeexamples/foodwiki.py b/docs/introduction/codeexamples/foodwiki.py new file mode 100644 index 000000000..869110159 --- /dev/null +++ b/docs/introduction/codeexamples/foodwiki.py @@ -0,0 +1,100 @@ +""" +Simple example of a public website. +""" +from twisted.web.template import Tag, slot, tags + +from klein import Field, Form, Klein, Plating, Requirer, SessionProcurer +from klein.interfaces import ISession +from klein.storage.memory import MemorySessionStore + + +app = Klein() + +sessions = MemorySessionStore() + +requirer = Requirer() + + +@requirer.prerequisite([ISession]) +def procurer(request): + return SessionProcurer(sessions).procureSession(request) + + +style = Plating( + tags=tags.html( + tags.head( + tags.title("Foods Example: ", slot("pageTitle")), + slot("headExtras"), + ), + tags.body(tags.div(slot(Plating.CONTENT))), + ), + defaults={"pageTitle": "Food List", "headExtras": ""}, + presentation_slots={"pageTitle", "headExtras", "addFoodForm"}, +) + +foodsList = [("test", 1)] + + +@requirer.require( + style.routed( + app.route("/", methods=["POST"]), + tags.h1("Added Food: ", slot("name")), + ), + name=Field.text(), + rating=Field.number(minimum=1, maximum=5, kind=int), +) +def postHandler(name, rating): + foodsList.append((name, rating)) + return { + "name": name, + "rating": "\N{BLACK STAR}" * rating, + "pageTitle": "Food Added", + "headExtras": tags.meta( + content="0;URL='/'", **{"http-equiv": "refresh"} + ), + } + + +@Plating.fragment +def food(name: str, rating: str) -> Tag: + return tags.div( + tags.div("food:", name), + tags.div("rating:", rating), + ) + + +@requirer.require( + style.routed( + app.route("/", methods=["GET"]), + tags.div( + tags.ul(tags.li(render="foods:list")(slot("item"))), + tags.div(slot("addFoodForm")), + ), + ), + theForm=Form.rendererFor(postHandler, action="/?post=yes"), +) +def formRenderer(theForm): + global foodsList + return { + "addFoodForm": theForm, + "foods": [ + food(name=name, rating="\N{BLACK STAR}" * rating) + for name, rating in foodsList + ], + } + + +@requirer.require( + style.routed( + Form.onValidationFailureFor(postHandler), + [tags.h1("invalid form"), tags.div(slot("the-invalid-form"))], + ), + renderer=Form.rendererFor(postHandler, action="/?post=yes"), +) +def validationFailed(values, renderer): + renderer.prevalidationValues = values.prevalidationValues + renderer.validationErrors = values.validationErrors + return {"the-invalid-form": renderer} + + +app.run("localhost", 8080) diff --git a/docs/introduction/codeexamples/foodwiki/anon/foodwiki_config.py b/docs/introduction/codeexamples/foodwiki/anon/foodwiki_config.py new file mode 100644 index 000000000..7195f629f --- /dev/null +++ b/docs/introduction/codeexamples/foodwiki/anon/foodwiki_config.py @@ -0,0 +1,33 @@ +import sqlite3 +from typing import Optional + +from foodwiki_db import allAuthorizers + +from twisted.internet.defer import Deferred, succeed +from twisted.web.iweb import IRequest + +from klein import Requirer +from klein.interfaces import ISession +from klein.storage.dbxs.dbapi_async import adaptSynchronousDriver +from klein.storage.sql import SQLSessionProcurer + + +DB_FILE = "food-wiki.sqlite" + +asyncDriver = adaptSynchronousDriver( + (lambda: sqlite3.connect(DB_FILE)), sqlite3.paramstyle +) + +sessions = SQLSessionProcurer(asyncDriver, allAuthorizers) +requirer = Requirer() + + +@requirer.prerequisite([ISession]) +def procurer(request: IRequest) -> Deferred[ISession]: + result: Optional[ISession] = ISession(request, None) + if result is not None: + # TODO: onValidationFailureFor results in one require nested inside + # another, which invokes this prerequisite twice. this mistake should + # not be easy to make + return succeed(result) + return sessions.procureSession(request) diff --git a/docs/introduction/codeexamples/foodwiki/anon/foodwiki_db.py b/docs/introduction/codeexamples/foodwiki/anon/foodwiki_db.py new file mode 100644 index 000000000..f4eafbdca --- /dev/null +++ b/docs/introduction/codeexamples/foodwiki/anon/foodwiki_db.py @@ -0,0 +1,70 @@ +from dataclasses import dataclass +from typing import AsyncIterable, Optional, Protocol + +from klein.interfaces import ISession, ISessionStore +from klein.storage.dbxs import accessor, many, query, statement +from klein.storage.dbxs.dbapi_async import ( + AsyncConnectable, + AsyncConnection, + transaction, +) +from klein.storage.sql import applyBasicSchema, authorizerFor + + +foodTable = """ +CREATE TABLE food ( + name VARCHAR NOT NULL, + rating INTEGER NOT NULL +) +""" + + +async def applySchema(connectable: AsyncConnectable) -> None: + await applyBasicSchema(connectable) + async with transaction(connectable) as c: + cur = await c.cursor() + await cur.execute(foodTable) + + +@dataclass +class FoodRating: + txn: AsyncConnection + name: str + rating: int + + +class RatingsDB(Protocol): + @query( + sql="select name, rating from food", + load=many(FoodRating), + ) + def allRatings(self) -> AsyncIterable[FoodRating]: + ... + + @statement(sql="insert into food (name, rating) values ({name}, {rating})") + async def addRating(self, name: str, rating: int) -> None: + ... + + +accessRatings = accessor(RatingsDB) + + +@dataclass +class FoodRater: + db: RatingsDB + + def allRatings(self) -> AsyncIterable[FoodRating]: + return self.db.allRatings() + + async def rateFood(self, name: str, rating: int) -> None: + return await self.db.addRating(name, rating) + + +@authorizerFor(FoodRater) +async def authorizeFoodRater( + store: ISessionStore, conn: AsyncConnection, session: ISession +) -> Optional[FoodRater]: + return FoodRater(accessRatings(conn)) + + +allAuthorizers = [authorizeFoodRater.authorizer] diff --git a/docs/introduction/codeexamples/foodwiki/anon/foodwiki_routes.py b/docs/introduction/codeexamples/foodwiki/anon/foodwiki_routes.py new file mode 100644 index 000000000..a29db733b --- /dev/null +++ b/docs/introduction/codeexamples/foodwiki/anon/foodwiki_routes.py @@ -0,0 +1,81 @@ +""" +Simple example of a public website. +""" + + +from foodwiki_config import requirer +from foodwiki_db import FoodRater +from foodwiki_templates import food, page, refresh + +from twisted.web.template import slot, tags + +from klein import Authorization, Field, FieldValues, Form, Klein, RenderableForm + + +app = Klein() + + +@requirer.require( + page.routed( + app.route("/rate-food", methods=["POST"]), + tags.h1("Rated Food: ", slot("name")), + ), + name=Field.text(), + rating=Field.number(minimum=1, maximum=5, kind=int), + foodRater=Authorization(FoodRater), +) +async def postHandler(name: str, rating: int, foodRater: FoodRater) -> dict: + await foodRater.rateFood(name, rating) + return { + "name": name, + "rating": "\N{BLACK STAR}" * rating, + "pageTitle": "Food Rated", + "headExtras": refresh("/"), + } + + +@requirer.require( + page.routed( + app.route("/", methods=["GET"]), + tags.div( + tags.ul(tags.li(render="foods:list")(slot("item"))), + tags.div(slot("rateFoodForm")), + ), + ), + ratingForm=Form.rendererFor(postHandler, action="/rate-food"), + foodRater=Authorization(FoodRater), +) +async def frontPage(foodRater: FoodRater, ratingForm: RenderableForm) -> dict: + allRatings = [] + async for eachFood in foodRater.allRatings(): + allRatings.append( + food(name=eachFood.name, rating="\N{BLACK STAR}" * eachFood.rating) + ) + return {"foods": allRatings, "rateFoodForm": ratingForm} + + +@requirer.require( + page.routed( + Form.onValidationFailureFor(postHandler), + [tags.h1("invalid form"), tags.div(slot("the-invalid-form"))], + ), + renderer=Form.rendererFor(postHandler, action="/?post=again"), +) +def validationFailed(values: FieldValues, renderer: RenderableForm) -> dict: + renderer.prevalidationValues = values.prevalidationValues + renderer.validationErrors = values.validationErrors + return {"the-invalid-form": renderer} + + +if __name__ == "__main__": + from os.path import exists + + from foodwiki_config import DB_FILE, asyncDriver + from foodwiki_db import applySchema + + from twisted.internet.defer import Deferred + + if not exists(DB_FILE): + Deferred.fromCoroutine(applySchema(asyncDriver)) + + app.run("localhost", 8080) diff --git a/docs/introduction/codeexamples/foodwiki/anon/foodwiki_templates.py b/docs/introduction/codeexamples/foodwiki/anon/foodwiki_templates.py new file mode 100644 index 000000000..a9036afe6 --- /dev/null +++ b/docs/introduction/codeexamples/foodwiki/anon/foodwiki_templates.py @@ -0,0 +1,30 @@ +from twisted.web.template import Tag, slot, tags + +from klein import Plating + + +page = Plating( + tags=tags.html( + tags.head( + tags.title("Food Ratings Example: ", slot("pageTitle")), + slot("headExtras"), + ), + tags.body( + tags.h1("Food Ratings Example: ", slot("pageTitle")), + tags.div(slot(Plating.CONTENT)), + ), + ), + defaults={"pageTitle": "", "headExtras": ""}, +) + + +@page.fragment +def food(name: str, rating: str) -> Tag: + return tags.div( + tags.div("food:", name), + tags.div("rating:", rating), + ) + + +def refresh(url: str) -> Tag: + return tags.meta(content=f"0;URL='{url}'", **{"http-equiv": "refresh"}) diff --git a/docs/introduction/codeexamples/foodwiki/auth/foodwiki_auth_routes.py b/docs/introduction/codeexamples/foodwiki/auth/foodwiki_auth_routes.py new file mode 100644 index 000000000..9a98a2279 --- /dev/null +++ b/docs/introduction/codeexamples/foodwiki/auth/foodwiki_auth_routes.py @@ -0,0 +1,127 @@ +from foodwiki_config import app, requirer +from foodwiki_db import APIKeyProvisioner +from foodwiki_templates import page, refresh + +from twisted.web.template import slot, tags + +from klein import Authorization, Field, Form, RenderableForm +from klein.interfaces import ISimpleAccountBinding + + +@requirer.require( + page.routed( + app.route("/signup", methods=["POST"]), + tags.h1("signed up", slot("signedUp")), + ), + username=Field.text(), + password=Field.password(), + password2=Field.password(), + binding=Authorization(ISimpleAccountBinding), +) +async def signup( + username: str, password: str, password2: str, binding: ISimpleAccountBinding +) -> dict: + await binding.createAccount(username, "", password) + return {"signedUp": "yep", "headExtras": refresh("/login")} + + +@requirer.require( + page.routed( + app.route("/signup", methods=["GET"]), + tags.div(tags.h1("sign up pls"), slot("signupForm")), + ), + theForm=Form.rendererFor(signup, action="/signup"), +) +async def showSignup(theForm: RenderableForm) -> dict: + return {"signupForm": theForm} + + +@requirer.require( + page.routed( + app.route("/login", methods=["POST"]), + tags.div(tags.h1("logged in", slot("didlogin"))), + ), + username=Field.text(), + password=Field.password(), + binding=Authorization(ISimpleAccountBinding), +) +async def login( + username: str, password: str, binding: ISimpleAccountBinding +) -> dict: + didLogIn = await binding.bindIfCredentialsMatch(username, password) + if didLogIn is not None: + return { + "didlogin": "yes", + "headExtras": refresh("/"), + } + else: + return { + "didlogin": "no", + "headExtras": refresh("/login"), + } + + +@requirer.require( + page.routed(app.route("/login", methods=["GET"]), slot("loginForm")), + loginForm=Form.rendererFor(login, action="/login"), +) +def loginForm(loginForm: RenderableForm) -> dict: + return {"loginForm": loginForm} + + +@requirer.require( + page.routed( + app.route("/logout", methods=["POST"]), + tags.div(tags.h1("logged out ", slot("didlogout"))), + ), + binding=Authorization(ISimpleAccountBinding), + ignored=Field.submit("log out"), +) +async def logout( + binding: ISimpleAccountBinding, + ignored: str, +) -> dict: + await binding.unbindThisSession() + return {"didlogout": "yes", "headExtras": refresh("/")} + + +@requirer.require( + page.routed( + app.route("/logout", methods=["GET"]), + tags.div(slot("button")), + ), + form=Form.rendererFor(logout, action="/logout"), +) +async def logoutView(form: RenderableForm) -> dict: + return { + "pageTitle": "log out?", + "button": form, + } + + +@requirer.require( + page.routed( + app.route("/new-api-key"), + [ + tags.div("API Key Created"), + tags.div("Copy this key; when you close this window, it's gone:"), + tags.div(tags.code(slot("key"))), + tags.div(tags.a(href="/api-keys")("back to API key management")), + ], + ), + ok=Field.submit("New API Key"), + provisioner=Authorization(APIKeyProvisioner), +) +async def createAPIKey(ok: object, provisioner: APIKeyProvisioner) -> dict: + return {"key": await provisioner.provisionAPIKey()} + + +@requirer.require( + page.routed( + app.route("/api-keys", methods=["GET"]), + tags.div(tags.h1("API Key Management"), slot("form")), + ), + form=Form.rendererFor(createAPIKey, action="/new-api-key"), +) +async def listAPIKeys(form: RenderableForm) -> dict: + return {"form": form} diff --git a/docs/introduction/codeexamples/foodwiki/auth/foodwiki_config.py b/docs/introduction/codeexamples/foodwiki/auth/foodwiki_config.py new file mode 100644 index 000000000..4ded85c5b --- /dev/null +++ b/docs/introduction/codeexamples/foodwiki/auth/foodwiki_config.py @@ -0,0 +1,35 @@ +import sqlite3 +from typing import Optional + +from foodwiki_db import allAuthorizers + +from twisted.internet.defer import Deferred, succeed +from twisted.web.iweb import IRequest + +from klein import Klein, Requirer +from klein.interfaces import ISession +from klein.storage.dbxs.dbapi_async import adaptSynchronousDriver +from klein.storage.sql import SQLSessionProcurer + + +app = Klein() + +DB_FILE = "food-wiki.sqlite" + +asyncDriver = adaptSynchronousDriver( + (lambda: sqlite3.connect(DB_FILE)), sqlite3.paramstyle +) + +sessions = SQLSessionProcurer(asyncDriver, allAuthorizers) +requirer = Requirer() + + +@requirer.prerequisite([ISession]) +def procurer(request: IRequest) -> Deferred[ISession]: + result: Optional[ISession] = ISession(request, None) + if result is not None: + # TODO: onValidationFailureFor results in one require nested inside + # another, which invokes this prerequisite twice. this mistake should + # not be easy to make + return succeed(result) + return sessions.procureSession(request) diff --git a/docs/introduction/codeexamples/foodwiki/auth/foodwiki_db.py b/docs/introduction/codeexamples/foodwiki/auth/foodwiki_db.py new file mode 100644 index 000000000..8fcb8eb1a --- /dev/null +++ b/docs/introduction/codeexamples/foodwiki/auth/foodwiki_db.py @@ -0,0 +1,181 @@ +from dataclasses import dataclass +from typing import Any, AsyncIterable, Optional, Protocol, Sequence + +from klein.interfaces import ( + ISession, + ISessionStore, + ISimpleAccount, + ISimpleAccountBinding, + SessionMechanism, +) +from klein.storage.dbxs import accessor, many, query, statement +from klein.storage.dbxs.dbapi_async import ( + AsyncConnectable, + AsyncConnection, + transaction, +) +from klein.storage.sql import applyBasicSchema, authorizerFor +from klein.storage.sql._sql_glue import SQLAuthorizer + + +foodTable = """ +CREATE TABLE food ( + name VARCHAR NOT NULL, + rating INTEGER NOT NULL, + rated_by VARCHAR NOT NULL, + FOREIGN KEY(rated_by) + REFERENCES account(account_id) + ON DELETE CASCADE +) +""" + + +async def applySchema(connectable: AsyncConnectable) -> None: + await applyBasicSchema(connectable) + async with transaction(connectable) as c: + cur = await c.cursor() + await cur.execute(foodTable) + + +@dataclass +class FoodRating: + txn: AsyncConnection + name: str + rating: int + ratedByAccountID: str + + +@dataclass +class NamedRating: + txn: AsyncConnection + name: str + rating: int + username: str + + +class PublicRatingsDB(Protocol): + @query( + sql=""" + select name, rating, rated_by from food + join account on(food.rated_by = account.account_id) + where account.username = {userName} + """, + load=many(FoodRating), + ) + def ratingsByUserName(self, userName: str) -> AsyncIterable[FoodRating]: + ... + + @query( + sql=""" + select name, rating, account.username from food + join account on(food.rated_by = account.account_id) + order by rating desc + limit 10 + """, + load=many(NamedRating), + ) + def topRatings(self) -> AsyncIterable[NamedRating]: + ... + + +accessPublicRatings = accessor(PublicRatingsDB) + + +@dataclass +class RatingsViewer: + db: PublicRatingsDB + + def ratingsByUserName(self, userName: str) -> AsyncIterable[FoodRating]: + return self.db.ratingsByUserName(userName) + + def topRatings(self) -> AsyncIterable[NamedRating]: + return self.db.topRatings() + + +@authorizerFor(RatingsViewer) +async def authorizeRatingsViewer( + store: ISessionStore, conn: AsyncConnection, session: ISession +) -> RatingsViewer: + return RatingsViewer(accessPublicRatings(conn)) + + +class RatingsDB(Protocol): + @query( + sql="select name, rating, rated_by from food" + "where rated_by = {accountID}", + load=many(FoodRating), + ) + def ratingsByUserID(self, accountID: str) -> AsyncIterable[FoodRating]: + ... + + @statement( + sql=""" + insert into food (rated_by, name, rating) + values ({accountID}, {name}, {rating}) + """ + ) + async def addRating(self, accountID: str, name: str, rating: int) -> None: + ... + + +accessRatings = accessor(RatingsDB) + + +@dataclass +class FoodCritic: + db: RatingsDB + account: ISimpleAccount + + def myRatings(self) -> AsyncIterable[FoodRating]: + return self.db.ratingsByUserID(self.account.accountID) + + async def rateFood(self, name: str, rating: int) -> None: + return await self.db.addRating(self.account.accountID, name, rating) + + +@authorizerFor(FoodCritic) +async def authorizeFoodCritic( + store: ISessionStore, conn: AsyncConnection, session: ISession +) -> Optional[FoodCritic]: + accts = await (await session.authorize([ISimpleAccountBinding]))[ + ISimpleAccountBinding + ].boundAccounts() + if not accts: + return None + return FoodCritic(accessRatings(conn), accts[0]) + + +@dataclass +class APIKeyProvisioner: + sessionStore: ISessionStore + session: ISession + account: ISimpleAccount + + async def provisionAPIKey(self) -> str: + """ + Provision a new API key for the given account. + """ + apiKeySession = await self.sessionStore.newSession( + self.session.isConfidential, SessionMechanism.Header + ) + await self.account.bindSession(apiKeySession) + return apiKeySession.identifier + + +@authorizerFor(APIKeyProvisioner) +async def authorizeProvisioner( + store: ISessionStore, conn: AsyncConnection, session: ISession +) -> Optional[APIKeyProvisioner]: + accts = await (await session.authorize([ISimpleAccountBinding]))[ + ISimpleAccountBinding + ].boundAccounts() + if not accts: + return None + return APIKeyProvisioner(store, session, accts[0]) + + +allAuthorizers: Sequence[SQLAuthorizer[Any]] = [ + authorizeFoodCritic.authorizer, + authorizeRatingsViewer.authorizer, + authorizeProvisioner.authorizer, +] diff --git a/docs/introduction/codeexamples/foodwiki/auth/foodwiki_routes.py b/docs/introduction/codeexamples/foodwiki/auth/foodwiki_routes.py new file mode 100644 index 000000000..1d918ced4 --- /dev/null +++ b/docs/introduction/codeexamples/foodwiki/auth/foodwiki_routes.py @@ -0,0 +1,151 @@ +""" +Simple example of a public website. +""" + + +from typing import Optional, Union + +from foodwiki_config import app, requirer +from foodwiki_db import FoodCritic, RatingsViewer +from foodwiki_templates import food, linkedFood, page, refresh + +from twisted.web.template import Tag, slot, tags + +from klein import ( + Authorization, + Field, + FieldValues, + Form, + Plating, + RenderableForm, +) +from klein.interfaces import ISimpleAccount + + +@page.widgeted +def notLoggedIn() -> dict: + return {Plating.CONTENT: "You are not logged in."} + + +@requirer.require( + page.routed( + app.route("/rate-food", methods=["POST"]), + tags.h1("Rated Food: ", slot("name")), + ), + name=Field.text(), + rating=Field.number(minimum=1, maximum=5, kind=int), + critic=Authorization( + FoodCritic, whenDenied=lambda interface, instance: notLoggedIn.widget() + ), +) +async def postHandler(name: str, rating: int, critic: FoodCritic) -> dict: + await critic.rateFood(name, rating) + return { + "name": name, + "rating": "\N{BLACK STAR}" * rating, + "pageTitle": "Food Rated", + "headExtras": refresh("/"), + } + + +rateFoodForm = Form.rendererFor(postHandler, action="/rate-food") + + +@requirer.require( + page.routed( + app.route("/", methods=["GET"]), + tags.div( + tags.ul(tags.li(render="foods:list")(slot("item"))), + tags.div(slot("rateFoodForm")), + ), + ), + ratingForm=rateFoodForm, + critic=Authorization(FoodCritic, required=False), + viewer=Authorization(RatingsViewer), +) +async def frontPage( + ratingForm: RenderableForm, + critic: Optional[FoodCritic], + viewer: RatingsViewer, +) -> dict: + allRatings = [] + async for eachFood in viewer.topRatings(): + allRatings.append( + linkedFood( + name=eachFood.name, + rating="\N{BLACK STAR}" * eachFood.rating, + username=eachFood.username, + ) + ) + return { + "foods": allRatings, + "rateFoodForm": "" if critic is None else ratingForm, + "pageTitle": "top-rated foods", + } + + +@requirer.require( + page.routed( + app.route("/users/", methods=["GET"]), + tags.div( + tags.ul(tags.li(render="userRatings:list")(slot("item"))), + ), + ), + viewer=Authorization(RatingsViewer), +) +async def userPage(viewer: RatingsViewer, username: str) -> dict: + userRatings = [] + async for eachFood in viewer.ratingsByUserName(username): + userRatings.append( + food(name=eachFood.name, rating="\N{BLACK STAR}" * eachFood.rating) + ) + return { + "userRatings": userRatings, + "pageTitle": f"ratings by {username}", + } + + +@requirer.require( + page.renderMethod, critic=Authorization(ISimpleAccount, required=False) +) +def whenLoggedIn(tag: Tag, critic: Optional[ISimpleAccount]) -> Union[Tag, str]: + return "" if critic is None else tag + + +@requirer.require( + page.renderMethod, critic=Authorization(FoodCritic, required=False) +) +def whenLoggedOut( + tag: Tag, critic: Optional[ISimpleAccount] +) -> Union[Tag, str]: + return "" if critic is not None else tag + + +@requirer.require( + page.routed( + Form.onValidationFailureFor(postHandler), + [tags.h1("invalid form"), tags.div(slot("the-invalid-form"))], + ), + renderer=rateFoodForm, +) +def validationFailed(values: FieldValues, renderer: RenderableForm) -> dict: + renderer.prevalidationValues = values.prevalidationValues + renderer.validationErrors = values.validationErrors + return {"the-invalid-form": renderer} + + +if __name__ == "__main__": + from os.path import exists + + from foodwiki_config import DB_FILE, asyncDriver + from foodwiki_db import applySchema + + # load other routes for side-effects of gathering them into the app object. + __import__("foodwiki_auth_routes") + + from twisted.internet.defer import Deferred + + if not exists(DB_FILE): + Deferred.fromCoroutine(applySchema(asyncDriver)) + + app.run("localhost", 8080) diff --git a/docs/introduction/codeexamples/foodwiki/auth/foodwiki_templates.py b/docs/introduction/codeexamples/foodwiki/auth/foodwiki_templates.py new file mode 100644 index 000000000..5c633cb2e --- /dev/null +++ b/docs/introduction/codeexamples/foodwiki/auth/foodwiki_templates.py @@ -0,0 +1,66 @@ +from twisted.web.template import Tag, slot, tags + +from klein import Plating + + +page = Plating( + tags=tags.html( + tags.head( + tags.title("Food Ratings Example: ", slot("pageTitle")), + slot("headExtras"), + tags.style( + """ + .nav a { + padding-left: 2em; + } + form { + border: 1px solid grey; + border-radius: 1em; + padding: 1em; + } + form label { + display: block; + padding: 0.2em; + } + """ + ), + ), + tags.body( + tags.div(class_="nav")( + "navigation:", + tags.a(href="/")("home"), + tags.a(href="/login", render="whenLoggedOut")("login"), + tags.a(href="/signup", render="whenLoggedOut")("signup"), + tags.a(href="/api-keys", render="whenLoggedIn")( + "API Key Management" + ), + tags.a(href="/logout", render="whenLoggedIn")("logout"), + ), + tags.h1("Food Ratings Example: ", slot("pageTitle")), + tags.div(slot(Plating.CONTENT)), + ), + ), + defaults={"pageTitle": "", "headExtras": ""}, + presentation_slots=["pageTitle", "headExtras"], +) + + +@page.fragment +def food(name: str, rating: str) -> Tag: + return tags.div( + tags.div("food:", name), + tags.div("rating:", rating), + ) + + +@page.fragment +def linkedFood(name: str, rating: str, username: str) -> Tag: + return tags.div( + tags.div("food:", name), + tags.div("rating:", rating), + tags.div("user:", tags.a(href=["/users/", username])(username)), + ) + + +def refresh(url: str) -> Tag: + return tags.meta(content=f"0;URL='{url}'", **{"http-equiv": "refresh"}) diff --git a/docs/introduction/codeexamples/foodwikisql.py b/docs/introduction/codeexamples/foodwikisql.py new file mode 100644 index 000000000..b523698fc --- /dev/null +++ b/docs/introduction/codeexamples/foodwikisql.py @@ -0,0 +1,311 @@ +""" +Simple example of a public website. +""" +import os +import sqlite3 +from dataclasses import dataclass +from typing import AsyncIterable, Optional, Protocol + +from twisted.internet.defer import Deferred, succeed +from twisted.web.iweb import IRequest +from twisted.web.template import Tag, slot, tags + +from klein import ( + Authorization, + Field, + FieldValues, + Form, + Klein, + Plating, + RenderableForm, + Requirer, +) +from klein.interfaces import ( + ISession, + ISessionStore, + ISimpleAccount, + ISimpleAccountBinding, +) +from klein.storage.dbxs import accessor, many, query, statement +from klein.storage.dbxs.dbapi_async import ( + AsyncConnection, + adaptSynchronousDriver, + transaction, +) +from klein.storage.sql import ( + SQLSessionProcurer, + applyBasicSchema, + authorizerFor, +) + + +app = Klein() + + +asyncDriver = adaptSynchronousDriver( + (lambda: sqlite3.connect("food-wiki.sqlite")), sqlite3.paramstyle +) + +foodTable = """ +CREATE TABLE food ( + name VARCHAR NOT NULL, + rating INTEGER NOT NULL, + rated_by VARCHAR NOT NULL, + FOREIGN KEY(rated_by) + REFERENCES account(account_id) + ON DELETE CASCADE +) +""" + + +async def applySchema() -> None: + await applyBasicSchema(asyncDriver) + async with transaction(asyncDriver) as c: + cur = await c.cursor() + await cur.execute(foodTable) + + +@dataclass +class Food: + txn: AsyncConnection + name: str + rating: int + ratedByAccountID: str + + +class FoodListSQL(Protocol): + @query( + sql=""" + select name, rating, account_id from food + where account_id = {accountID} + """, + load=many(Food), + ) + def getFoods(self, accountID: str) -> AsyncIterable[Food]: + ... + + @statement( + sql="insert into food (account_id, name, rating) values " + "({accountID}, {name}, {rating})" + ) + async def addFood(self, accountID: str, name: str, rating: int) -> None: + ... + + +FoodListQueries = accessor(FoodListSQL) + + +@dataclass +class FoodList: + account: ISimpleAccount + db: FoodListSQL + + def foodsForUser(self) -> AsyncIterable[Food]: + return self.db.getFoods(self.account.accountID) + + async def rateFood(self, name: str, rating: int) -> None: + return await self.db.addFood(self.account.accountID, name, rating) + + +@authorizerFor(FoodList) +async def authorizeFoodList( + store: ISessionStore, conn: AsyncConnection, session: ISession +) -> Optional[FoodList]: + accts = await (await session.authorize([ISimpleAccountBinding]))[ + ISimpleAccountBinding + ].boundAccounts() + if not accts: + return None + return FoodList(accts[0], FoodListQueries(conn)) + + +if not os.path.exists("food-wiki.sqlite"): + Deferred.fromCoroutine(applySchema()) + +sessions = SQLSessionProcurer(asyncDriver, [authorizeFoodList.authorizer]) +requirer = Requirer() + + +@requirer.prerequisite([ISession]) +def procurer(request: IRequest) -> Deferred[ISession]: + result: Optional[ISession] = ISession(request, None) + if result is not None: + # TODO: onValidationFailureFor results in one require nested inside + # another, which invokes this prerequisite twice. this mistake should + # not be easy to make + return succeed(result) + return sessions.procureSession(request) + + +style = Plating( + tags=tags.html( + tags.head( + tags.title("Foods Example: ", slot("pageTitle")), + slot("headExtras"), + ), + tags.body(tags.div(slot(Plating.CONTENT))), + ), + defaults={"pageTitle": "Food List", "headExtras": ""}, + presentation_slots={"pageTitle", "headExtras", "addFoodForm", "loginForm"}, +) + +foodsList = [("test", 1)] + + +def refresh(url: str) -> Tag: + return tags.meta(content=f"0;URL='{url}'", **{"http-equiv": "refresh"}) + + +@requirer.require( + style.routed( + app.route("/", methods=["POST"]), + tags.h1("Added Food: ", slot("name")), + ), + name=Field.text(), + rating=Field.number(minimum=1, maximum=5, kind=int), + foodList=Authorization(FoodList), +) +async def postHandler(name: str, rating: int, foodList: FoodList) -> dict: + await foodList.rateFood(name, rating) + return { + "name": name, + "rating": "\N{BLACK STAR}" * rating, + "pageTitle": "Food Added", + "headExtras": refresh("/"), + } + + +@requirer.require( + style.routed( + app.route("/login", methods=["POST"]), + tags.div(tags.h1("logged in", slot("didlogin"))), + ), + username=Field.text(), + password=Field.password(), + binding=Authorization(ISimpleAccountBinding), +) +async def login( + username: str, password: str, binding: ISimpleAccountBinding +) -> dict: + await binding.bindIfCredentialsMatch(username, password) + return { + "didlogin": "yes", + "headExtras": refresh("/"), + } + + +@requirer.require( + style.routed( + app.route("/logout", methods=["POST"]), + tags.div(tags.h1("logged out ", slot("didlogout"))), + ), + binding=Authorization(ISimpleAccountBinding), + ignored=Field.submit("log out"), +) +async def logout( + binding: ISimpleAccountBinding, + ignored: str, +) -> dict: + await binding.unbindThisSession() + return {"didlogout": "yes", "headExtras": refresh("/")} + + +@requirer.require( + style.routed( + app.route("/login", methods=["GET"]), + tags.div("form", slot("form")), + ), + theForm=Form.rendererFor(login, action="/login"), +) +async def showLogin(theForm: object) -> dict: + return {"form": theForm} + + +@requirer.require( + style.routed( + app.route("/signup", methods=["POST"]), + tags.h1("signed up", slot("signedUp")), + ), + username=Field.text(), + password=Field.password(), + password2=Field.password(), + binding=Authorization(ISimpleAccountBinding), +) +async def signup( + username: str, password: str, password2: str, binding: ISimpleAccountBinding +) -> dict: + await binding.createAccount(username, "", password) + return {"signedUp": "yep", "headExtras": refresh("/login")} + + +@requirer.require( + style.routed( + app.route("/signup", methods=["GET"]), + tags.div(tags.h1("sign up pls"), slot("signupForm")), + ), + binding=Authorization(ISimpleAccountBinding), + theForm=Form.rendererFor(signup, action="/signup"), +) +async def showSignup( + binding: ISimpleAccountBinding, + theForm: RenderableForm, +) -> dict: + return {"signupForm": theForm} + + +@Plating.fragment +def food(name: str, rating: str) -> Tag: + return tags.div( + tags.div("food:", name), + tags.div("rating:", rating), + ) + + +@requirer.require( + style.routed( + app.route("/", methods=["GET"]), + tags.div( + tags.ul(tags.li(render="foods:list")(slot("item"))), + tags.div(slot("addFoodForm")), + tags.div(slot("loginForm")), + ), + ), + theForm=Form.rendererFor(postHandler, action="/?post=yes"), + loginForm=Form.rendererFor(login, action="/login"), + logoutForm=Form.rendererFor(logout, action="/logout"), + foodList=Authorization(FoodList, required=False), +) +async def formRenderer( + theForm: RenderableForm, + loginForm: RenderableForm, + logoutForm: RenderableForm, + foodList: Optional[FoodList], +) -> dict: + result = [] + if foodList is not None: + async for eachFood in foodList.foodsForUser(): + result.append(eachFood) + return { + "addFoodForm": theForm if (foodList is not None) else "", + "loginForm": loginForm if (foodList is None) else logoutForm, + "foods": [ + food(name=each.name, rating="\N{BLACK STAR}" * each.rating) + for each in result + ], + } + + +@requirer.require( + style.routed( + Form.onValidationFailureFor(postHandler), + [tags.h1("invalid form"), tags.div(slot("the-invalid-form"))], + ), + renderer=Form.rendererFor(postHandler, action="/?post=yes"), +) +def validationFailed(values: FieldValues, renderer: RenderableForm) -> dict: + renderer.prevalidationValues = values.prevalidationValues + renderer.validationErrors = values.validationErrors + return {"the-invalid-form": renderer} + + +app.run("localhost", 8080) diff --git a/release.py b/release.py index 0116c3882..45907324b 100644 --- a/release.py +++ b/release.py @@ -8,7 +8,7 @@ from subprocess import CalledProcessError, run from sys import exit, stderr from tempfile import mkdtemp -from typing import Any, Dict, NoReturn, Optional, Sequence, cast +from typing import Dict, NoReturn, Optional, Sequence from click import group as commandGroup from click import option as commandOption @@ -56,7 +56,7 @@ def currentVersion() -> Version: """ # Incremental doesn't have an API to do this, so we are duplicating some # code from its source tree. Boo. - versionInfo: Dict[str, Any] = {} + versionInfo: Dict[str, Version] = {} versonFile = Path(__file__).parent / "src" / "klein" / "_version.py" exec(versonFile.read_text(), versionInfo) return versionInfo["__version__"] @@ -109,7 +109,7 @@ def releaseTagName(version: Version) -> str: """ Compute the name of the release tag for the given version. """ - return cast(str, version.public()) + return version.public() def createReleaseBranch(repository: Repository, version: Version) -> Head: diff --git a/requirements/tox-pin-base.txt b/requirements/tox-pin-base.txt index 732534da4..8d3e02468 100644 --- a/requirements/tox-pin-base.txt +++ b/requirements/tox-pin-base.txt @@ -3,7 +3,7 @@ Automat==22.10.0 characteristic==14.3.0 constantly==15.1.0 hyperlink==21.0.0 -incremental==21.3.0 +incremental==22.10.0 PyHamcrest==2.1.0 six==1.16.0 Tubes==0.2.1 diff --git a/setup.py b/setup.py index 401a619c2..1ecb7afd3 100644 --- a/setup.py +++ b/setup.py @@ -35,17 +35,34 @@ "Tubes", "Twisted>=16.6", # 16.6 introduces ensureDeferred "typing_extensions ; python_version<'3.10'", + # PyPy doesn't have hashlib.scrypt, which we need for secure + # password storage. + "cryptography ; platform_python_implementation == 'PyPy'", "Werkzeug", "zope.interface", ], keywords="twisted flask werkzeug web", license="MIT", name="klein", - packages=["klein", "klein.storage", "klein.test"], + packages=[ + "klein", + "klein.storage", + "klein.storage.memory", + "klein.storage.memory.test", + "klein.storage.dbxs", + "klein.storage.dbxs.test", + "klein.storage.passwords", + "klein.storage.passwords.test", + "klein.storage.sql", + "klein.storage.sql.test", + "klein.test", + "klein.storage.test", + ], package_dir={"": "src"}, - package_data=dict( - klein=["py.typed"], - ), + package_data={ + "klein": ["py.typed"], + "klein.storage.sql": ["basic_auth_schema.sql"], + }, url="https://github.com/twisted/klein", maintainer="Twisted Matrix Laboratories", maintainer_email="twisted-python@twistedmatrix.com", diff --git a/src/klein/_form.py b/src/klein/_form.py index 9fce977f9..adf0259f2 100644 --- a/src/klein/_form.py +++ b/src/klein/_form.py @@ -292,6 +292,7 @@ def submit(cls, value: str) -> "Field": formInputType="submit", noLabel=True, default=value, + value=value, ) diff --git a/src/klein/_isession.py b/src/klein/_isession.py index 34dfca342..c4f65e1ae 100644 --- a/src/klein/_isession.py +++ b/src/klein/_isession.py @@ -1,4 +1,20 @@ -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Sequence, Type +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, + AsyncContextManager, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + Type, + TypeVar, + Union, + overload, +) import attr from constantly import NamedConstant, Names @@ -8,6 +24,8 @@ from twisted.python.components import Componentized from twisted.web.iweb import IRequest +from ._typing_compat import Protocol + if TYPE_CHECKING: from ._app import KleinRenderable @@ -52,13 +70,37 @@ class SessionMechanism(Names): Header = NamedConstant() +T = TypeVar("T") +K = TypeVar("K") +V = TypeVar("V") + + +class AuthorizationMap(Protocol): + @overload + def get(self, key: Type[V]) -> V: + ... + + @overload + def get(self, key: Type[V], default: T) -> Union[V, T]: + ... + + def get(self, *args: Any, **kwargs: Any) -> Any: + ... + + def __getitem__(self, key: Type[V]) -> V: + ... + + def __setitem__(self, key: Type[V], value: V) -> None: + ... + + class ISession(Interface): """ An L{ISession} provider contains an identifier for the session, information about how the session was negotiated with the client software, and """ - identifier = Attribute( + identifier: str = Attribute( """ L{str} identifying a session. @@ -90,7 +132,9 @@ class ISession(Interface): """ ) - def authorize(interfaces: Iterable[Type[Interface]]) -> Deferred: + def authorize( + interfaces: Iterable[Type[object]], + ) -> Deferred[AuthorizationMap]: """ Retrieve other objects from this session. @@ -117,19 +161,26 @@ class ISessionStore(Interface): def newSession( isConfidential: bool, authenticatedBy: SessionMechanism, - ) -> Deferred: + ) -> Deferred[ISession]: """ Create a new L{ISession}. - @return: a new session with a new identifier. - @rtype: L{Deferred} firing with L{ISession}. + @param isConfidential: Is the new session being created a confidential + (i.e. “sent over HTTPS” session)? + + @param authenticatedBy: Was the request for this new session + authenticated by a header or a cookie? + + @return: a new session with a new (randomly generated) identifier that + can later be passed back to this object's + L{ISessionStore.loadSession}. """ def loadSession( identifier: str, isConfidential: bool, authenticatedBy: SessionMechanism, - ) -> Deferred: + ) -> Deferred[ISession]: """ Load a session given the given identifier and security properties. @@ -165,14 +216,16 @@ class ISimpleAccountBinding(Interface): attribute as a component. """ - def bindIfCredentialsMatch(username: str, password: str) -> None: + def bindIfCredentialsMatch( + username: str, password: str + ) -> Deferred[Optional[ISimpleAccount]]: """ Attach the session this is a component of to an account with the given username and password, if the given username and password correctly authenticate a principal. """ - def boundAccounts() -> Deferred: + def boundAccounts() -> Deferred[List[ISimpleAccount]]: """ Retrieve the accounts currently associated with the session this is a component of. @@ -180,13 +233,15 @@ def boundAccounts() -> Deferred: @return: L{Deferred} firing with a L{list} of L{ISimpleAccount}. """ - def unbindThisSession() -> None: + def unbindThisSession() -> Deferred[None]: """ Disassociate the session this is a component of from any accounts it's logged in to. """ - def createAccount(username: str, email: str, password: str) -> None: + def createAccount( + username: str, email: str, password: str + ) -> Deferred[Optional[ISimpleAccount]]: """ Create a new account with the given username, email and password. """ @@ -197,25 +252,25 @@ class ISimpleAccount(Interface): Data-store agnostic account interface. """ - username = Attribute( + username: str = Attribute( """ Unicode username. """ ) - accountID = Attribute( + accountID: str = Attribute( """ Unicode account-ID. """ ) - def bindSession(session: ISession) -> None: + def bindSession(session: ISession) -> Deferred[None]: """ Bind the given session to this account; i.e. authorize the given session to act on behalf of this account. """ - def changePassword(newPassword: str) -> None: + def changePassword(newPassword: str) -> Deferred[None]: """ Change the password of this account. """ @@ -229,7 +284,7 @@ class ISessionProcurer(Interface): def procureSession( request: IRequest, forceInsecure: bool = False - ) -> Deferred: + ) -> Deferred[ISession]: """ Retrieve a session using whatever technique is necessary. @@ -309,8 +364,8 @@ class IRequestLifecycle(Interface): def addPrepareHook( beforeHook: Callable, - requires: Sequence[Type[Interface]] = (), - provides: Sequence[Type[Interface]] = (), + requires: Sequence[Type[object]] = (), + provides: Sequence[Type[object]] = (), ) -> None: """ Add a hook that promises to prepare the request by supplying the given @@ -352,6 +407,21 @@ def registerInjector( """ +class IRequirementContext(Interface): + """ + An L{IRequirementContext} is a request component that can be used during + C{@require} dependency injection. + + In particular, this can be used to raise L{EarlyExit} from C{__exit__} to + reconsider the return value I{after} the route's returning. + """ + + async def enter_async_context(cm: AsyncContextManager[T]) -> T: + """ + Add the contextmanager to the list of context managers. + """ + + @attr.s(auto_attribs=True) class EarlyExit(Exception): """ @@ -363,4 +433,4 @@ class EarlyExit(Exception): supplied as the route's response. """ - alternateReturnValue: "KleinRenderable" + alternateReturnValue: KleinRenderable diff --git a/src/klein/_plating.py b/src/klein/_plating.py index e9133dd22..cd05faa28 100644 --- a/src/klein/_plating.py +++ b/src/klein/_plating.py @@ -6,6 +6,7 @@ from __future__ import annotations from functools import partial +from inspect import signature from json import dumps from operator import setitem from typing import Any, Callable, Generator, List, Tuple, cast @@ -15,10 +16,11 @@ from twisted.internet.defer import inlineCallbacks from twisted.web.error import MissingRenderMethod from twisted.web.iweb import IRequest -from twisted.web.template import Element, Tag, TagLoader +from twisted.web.template import Element, Tag, TagLoader, slot from ._app import _call from ._decorators import bindable, modified, originalName +from ._typing_compat import ParamSpec StackType = List[Tuple[Any, Callable[[Any], None]]] @@ -193,6 +195,9 @@ def renderList(request, tag): raise MissingRenderMethod(self, name) +P = ParamSpec("P") + + class Plating: """ A L{Plating} is a container which can be used to generate HTML from data. @@ -233,6 +238,14 @@ def mymethod( instance: Any, request: IRequest, *args: Any, **kw: Any ) -> Any: data = yield _call(instance, method, request, *args, **kw) + if not hasattr(data, "__setitem__"): + # Allow plating routes to return other forms of Klein + # renderable object, if they want to customize an HTTP + # response in specific cases, such as returning a redirect. + # This is a very narrow test rather than a more general + # isinstance(data, dict) or similar because older versions + # of Klein did not have this check. + return data if _should_return_json(request): json_data = self._defaults.copy() json_data.update(data) @@ -312,3 +325,64 @@ def widgeted(self, function): template elements. """ return self._Widget(self, function, None) + + @classmethod + def fragment(cls, f: Callable[P, Tag]) -> Callable[P, Tag]: + """ + Decorator for a function that presents a formatted view of a set of + slots. For example, if we have a page that displays a list of links to + articles:: + + page = Plating(tags=tags.html(tags.body(Plating.CONTENT))) + + @Plating.fragment + def post(title: str, author: str, + publishDate: str, url: str) -> Tag: + return tags.div( + tags.div("Title: ", tags.a(href=url)(title)), + tags.div("Author: ", author), + tags.div("Published: ", publishDate), + ) + + @page.routed(self.app.route("/"), + tags.div(render="posts:list")(slot("item"))) + def plateMe(result): + return { + "posts": [ + post("First Post", "Alice", + "2023-01-01", "http://example.com/1"), + post("Second Post", "Bob", + "2023-02-02", "http://example.com/2"), + ... + ], + } + + When viewed as HTML, this will render each post inline as a C{
}, + and in JSON, it will be an object that looks like this:: + + { + "posts": [ + {"title": "First Post", "author": "Alice", + "publishDate": "2023-01-01", + "url": "http://example.com/1"}, + ... + ] + } + + @note: The function being decorated will not be invoked each time it + appears to be called with strings, but rather, called once at + import time, and the types of its arguments will actually be + C{slot} objects, which will I{later} be filled in with the + stipulated types. + """ + sigf = signature(f) + computedArgs: P.args = [slot(name) for name in sigf.parameters] + + def makeDict(*args: object, **kwargs: object) -> dict[str, object]: + bound = sigf.bind(*args, **kwargs) + return bound.arguments + + c: Callable[P, Tag] = ( + cls(tags=f(*computedArgs)).widgeted(makeDict).widget + ) + return c diff --git a/src/klein/_requirer.py b/src/klein/_requirer.py index 0b9b98c65..def43e81d 100644 --- a/src/klein/_requirer.py +++ b/src/klein/_requirer.py @@ -1,19 +1,35 @@ -from typing import Any, Callable, Dict, Generator, List, Sequence, Type +# -*- test-case-name: klein.test.test_requirer -*- +from contextlib import AsyncExitStack +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Generator, + List, + Sequence, + Type, + TypeVar, + Union, +) import attr -from zope.interface import Interface, implementer +from zope.interface import implementer from twisted.internet.defer import inlineCallbacks from twisted.python.components import Componentized from twisted.web.iweb import IRequest +from twisted.web.server import Request from ._app import _call from ._decorators import bindable, modified +from ._util import eagerDeferredCoroutine from .interfaces import ( EarlyExit, IDependencyInjector, IRequestLifecycle, IRequiredParameter, + IRequirementContext, ) @@ -29,8 +45,8 @@ class RequestLifecycle: def addPrepareHook( self, beforeHook: Callable, - requires: Sequence[Type[Interface]] = (), - provides: Sequence[Type[Interface]] = (), + requires: Sequence[Type[object]] = (), + provides: Sequence[Type[object]] = (), ) -> None: # TODO: topological requirements sort self._prepareHooks.append(beforeHook) @@ -51,11 +67,27 @@ def runPrepareHooks( yield _call(instance, hook, request) +@implementer(IRequirementContext) +class RequirementContext(AsyncExitStack): + """ + Subclass only to mark the implementation of this interface; this is in + every way an C{ExitStack}. + """ + + _routeDecorator = Any # a decorator like @route _routeT = Any # a thing decorated by a decorator like @route _prerequisiteCallback = Callable[[IRequestLifecycle], None] +T = TypeVar("T") + + +async def _maybeAsync(v: Union[T, Awaitable[T]]) -> T: + if isinstance(v, Awaitable): + return await v + return v + @attr.s(auto_attribs=True) class Requirer: @@ -67,8 +99,8 @@ class Requirer: def prerequisite( self, - providesComponents: Sequence[Type[Interface]], - requiresComponents: Sequence[Type[Interface]] = (), + providesComponents: Sequence[Type[object]], + requiresComponents: Sequence[Type[object]] = (), ) -> Callable[[Callable], Callable]: """ Specify a component that is a pre-requisite of every request routed @@ -127,24 +159,38 @@ def decorator(functionWithRequirements: Callable) -> Callable: @modified("dependency-injecting route", functionWithRequirements) @bindable - @inlineCallbacks - def router( - instance: Any, request: IRequest, *args: Any, **routeParams: Any + @eagerDeferredCoroutine + async def router( + instance: Any, request: Request, *args: Any, **routeParams: Any ) -> Any: - injected = routeParams.copy() try: - yield lifecycle.runPrepareHooks(instance, request) - for k, injector in injectors.items(): - injected[k] = yield injector.injectValue( - instance, request, routeParams - ) + try: + shouldSet = False + async with RequirementContext() as stack: + if IRequirementContext(request, None) is None: + shouldSet = True + request.setComponent(IRequirementContext, stack) + injected = routeParams.copy() + await lifecycle.runPrepareHooks(instance, request) + for k, injector in injectors.items(): + injected[k] = await _maybeAsync( + injector.injectValue( + instance, request, routeParams + ) + ) + return await _maybeAsync( + _call( + instance, + functionWithRequirements, + *args, + **injected, + ) + ) + finally: + if shouldSet: + request.unsetComponent(IRequirementContext) except EarlyExit as ee: - result = ee.alternateReturnValue - else: - result = yield _call( - instance, functionWithRequirements, *args, **injected - ) - return result + return ee.alternateReturnValue fWR, iC = functionWithRequirements, injectionComponents fWR.injectionComponents = iC # type: ignore[attr-defined] diff --git a/src/klein/_session.py b/src/klein/_session.py index a4d843baf..26de02a0f 100644 --- a/src/klein/_session.py +++ b/src/klein/_session.py @@ -1,9 +1,21 @@ # -*- test-case-name: klein.test.test_session -*- -from typing import Any, Callable, Dict, Optional, Sequence, Type, Union, cast +from __future__ import annotations + +from typing import ( + Any, + Callable, + Dict, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) import attr -from zope.interface import Interface, implementer +from zope.interface import implementer from twisted.internet.defer import inlineCallbacks from twisted.python.components import Componentized @@ -11,7 +23,9 @@ from twisted.web.http import UNAUTHORIZED from twisted.web.iweb import IRequest from twisted.web.resource import Resource +from twisted.web.server import Request +from ._util import eagerDeferredCoroutine from .interfaces import ( EarlyExit, IDependencyInjector, @@ -26,6 +40,90 @@ ) +async def cookieLoader( + self: SessionProcurer, + request: IRequest, + token: str, + sentSecurely: bool, + cookieName: Union[str, bytes], +) -> ISession: + """ + Procuring a session from a cookie is complex. First, just try to look it + up based on the current cookie, but then, do a bunch of checks to see if we + can set up a new session, then set one up. + """ + try: + return await self._store.loadSession( + token, sentSecurely, SessionMechanism.Cookie + ) + except NoSuchSession: + pass + + # No existing session. + if request.startedWriting: # type: ignore[attr-defined] + # At this point, if the mechanism is Header, we either have + # a valid session or we bailed after NoSuchSession above. + raise TooLateForCookies( + "You tried initializing a cookie-based session too" + " late in the request pipeline; the headers" + " were already sent." + ) + if request.method != b"GET": + # Sessions should only ever be auto-created by GET + # requests; there's no way that any meaningful data + # manipulation could succeed (no CSRF token check could + # ever succeed, for example). + raise NoSuchSession( + "Can't initialize a session on a " + "{method} request.".format(method=request.method.decode("ascii")) + ) + if not self._setCookieOnGET: + # We don't have a session ID at all, and we're not allowed + # by policy to set a cookie on the client. + raise NoSuchSession( + "Cannot auto-initialize a session for this request." + ) + session = await self._store.newSession( + sentSecurely, SessionMechanism.Cookie + ) + + # https://github.com/twisted/twisted/issues/11865 + wrongSignature: Request = request # type:ignore[assignment] + wrongSignature.addCookie( + cookieName, + session.identifier, + max_age=str(self._maxAge), + domain=self._cookieDomain, + path=self._cookiePath, + secure=sentSecurely, + httpOnly=True, + ) + + return session + + +async def headerLoader( + self: SessionProcurer, + request: IRequest, + token: str, + sentSecurely: bool, + cookieName: Union[str, bytes], +) -> ISession: + """ + Procuring a session via a header API key is very simple. Just look it up + and fail if you can't find it. + """ + return await self._store.loadSession( + token, sentSecurely, SessionMechanism.Header + ) + + +loaderForMechanism = { + SessionMechanism.Cookie: cookieLoader, + SessionMechanism.Header: headerLoader, +} + + @implementer(ISessionProcurer) @attr.s(auto_attribs=True) class SessionProcurer: @@ -64,115 +162,74 @@ class SessionProcurer: _insecureTokenHeader: bytes = b"X-INSECURE-Auth-Token" _setCookieOnGET: bool = True - @inlineCallbacks - def procureSession( + def _tokenTransportAttributes( + self, request: IRequest, forceInsecure: bool + ) -> Tuple[bytes, bytes, bool]: + """ + @return: 3-tuple of header, cookie, secure + """ + secure = (self._secureTokenHeader, self._secureCookie, True) + insecure = (self._insecureTokenHeader, self._insecureCookie, False) + + if request.isSecure(): + return insecure if forceInsecure else secure + + # Have we inadvertently disclosed a secure token over an insecure + # transport, for example, due to a buggy client? + allPossibleSentTokens: Sequence[bytes] = sum( + ( + request.requestHeaders.getRawHeaders(header, []) + for header in [ + self._secureTokenHeader, + self._insecureTokenHeader, + ] + ), + [], + ) + [ + it + for it in [ + request.getCookie(cookie) + for cookie in [self._secureCookie, self._insecureCookie] + if cookie is not None + ] + if it + ] + + # Fun future feature: honeypot that does this over HTTPS, but sets + # isSecure() to return false because it serves up a cert for the + # wrong hostname or an invalid cert, to keep API clients honest + # about chain validation. + self._store.sentInsecurely( + [each.decode() for each in allPossibleSentTokens] + ) + return insecure + + @eagerDeferredCoroutine + async def procureSession( self, request: IRequest, forceInsecure: bool = False - ) -> Any: - alreadyProcured = cast(Componentized, request).getComponent(ISession) + ) -> ISession: + alreadyProcured: Optional[ISession] = ISession(request, None) if alreadyProcured is not None: if not forceInsecure or not request.isSecure(): return alreadyProcured - if request.isSecure(): - if forceInsecure: - tokenHeader = self._insecureTokenHeader - cookieName: Union[str, bytes] = self._insecureCookie - sentSecurely = False - else: - tokenHeader = self._secureTokenHeader - cookieName = self._secureCookie - sentSecurely = True - else: - # Have we inadvertently disclosed a secure token over an insecure - # transport, for example, due to a buggy client? - allPossibleSentTokens: Sequence[str] = sum( - ( - request.requestHeaders.getRawHeaders(header, []) - for header in [ - self._secureTokenHeader, - self._insecureTokenHeader, - ] - ), - [], - ) + [ - it - for it in [ - request.getCookie(cookie) - for cookie in [self._secureCookie, self._insecureCookie] - ] - if it - ] - # Does it seem like this check is expensive? It sure is! Don't want - # to do it? Turn on your dang HTTPS! - self._store.sentInsecurely(allPossibleSentTokens) - tokenHeader = self._insecureTokenHeader - cookieName = self._insecureCookie - sentSecurely = False - # Fun future feature: honeypot that does this over HTTPS, but sets - # isSecure() to return false because it serves up a cert for the - # wrong hostname or an invalid cert, to keep API clients honest - # about chain validation. + tokenHeader, cookieName, sentSecurely = self._tokenTransportAttributes( + request, forceInsecure + ) + sentHeader = (request.getHeader(tokenHeader) or b"").decode("utf-8") sentCookie = (request.getCookie(cookieName) or b"").decode("utf-8") - if sentHeader: - mechanism = SessionMechanism.Header - else: - mechanism = SessionMechanism.Cookie - if not (sentHeader or sentCookie): - session = None - else: - try: - session = yield self._store.loadSession( - sentHeader or sentCookie, sentSecurely, mechanism - ) - except NoSuchSession: - if mechanism == SessionMechanism.Header: - raise - session = None - if mechanism == SessionMechanism.Cookie and ( - session is None or session.identifier != sentCookie - ): - if session is None: - if request.startedWriting: # type: ignore[attr-defined] - # At this point, if the mechanism is Header, we either have - # a valid session or we bailed after NoSuchSession above. - raise TooLateForCookies( - "You tried initializing a cookie-based session too" - " late in the request pipeline; the headers" - " were already sent." - ) - if request.method != b"GET": - # Sessions should only ever be auto-created by GET - # requests; there's no way that any meaningful data - # manipulation could succeed (no CSRF token check could - # ever succeed, for example). - raise NoSuchSession( - "Can't initialize a session on a " - "{method} request.".format( - method=request.method.decode("ascii") - ) - ) - if not self._setCookieOnGET: - # We don't have a session ID at all, and we're not allowed - # by policy to set a cookie on the client. - raise NoSuchSession( - "Cannot auto-initialize a session for this request." - ) - session = yield self._store.newSession(sentSecurely, mechanism) - identifierInCookie = session.identifier - if not isinstance(identifierInCookie, str): - identifierInCookie = identifierInCookie.encode("ascii") - if not isinstance(cookieName, str): - cookieName = cookieName.decode("ascii") - request.addCookie( # type: ignore[call-arg] - cookieName, - identifierInCookie, - max_age=str(self._maxAge), - domain=self._cookieDomain, - path=self._cookiePath, - secure=sentSecurely, - httpOnly=True, - ) + + mechanism, token = ( + (SessionMechanism.Header, sentHeader) + if sentHeader + else (SessionMechanism.Cookie, sentCookie) + ) + + session = await loaderForMechanism[mechanism]( + self, request, token, sentSecurely, cookieName + ) + if sentSecurely or not request.isSecure(): # Do not cache the insecure session on the secure request, thanks. cast(Componentized, request).setComponent(ISession, session) @@ -180,7 +237,7 @@ def procureSession( class AuthorizationDenied(Resource): - def __init__(self, interface: Type[Interface], instance: Any) -> None: + def __init__(self, interface: Type[object], instance: Any) -> None: self._interface = interface super().__init__() @@ -244,9 +301,9 @@ def myRoute(adminPowers): C{required} is set to C{False}. """ - _interface: Type[Interface] + _interface: Type[object] _required: bool = True - _whenDenied: Callable[[Type[Interface], Any], Any] = AuthorizationDenied + _whenDenied: Callable[[Type[object], Any], Any] = AuthorizationDenied def registerInjector( self, diff --git a/src/klein/_typing_compat.py b/src/klein/_typing_compat.py index bffed3be3..26f42c299 100644 --- a/src/klein/_typing_compat.py +++ b/src/klein/_typing_compat.py @@ -6,16 +6,24 @@ import sys -if sys.version_info > (3, 8): - from typing import Protocol +if sys.version_info > (3, 10): + from typing import Concatenate, ParamSpec, Protocol else: - from typing_extensions import Protocol + # PyPy 3.9 seems to have a bonus runtime check for Protocol's generic + # arguments all being TypeVars, so lie to it about ParamSpec. + from typing import TYPE_CHECKING + from typing_extensions import Concatenate, Protocol -if sys.version_info > (3, 10): - from typing import Concatenate, ParamSpec -else: - from typing_extensions import Concatenate, ParamSpec + if TYPE_CHECKING: + from typing_extensions import ParamSpec + else: + from platform import python_implementation + + if python_implementation() == "PyPy": + from typing import TypeVar as ParamSpec + else: + from typing_extensions import ParamSpec __all__ = [ diff --git a/src/klein/_util.py b/src/klein/_util.py new file mode 100644 index 000000000..ce7193ae3 --- /dev/null +++ b/src/klein/_util.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Coroutine, TypeVar + +from twisted.internet.defer import Deferred + +from ._typing_compat import ParamSpec + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +if TYPE_CHECKING: # pragma: no cover + # https://github.com/twisted/twisted/issues/11862 + def deferToThread(f: Callable[[], _T]) -> Deferred[_T]: + ... + +else: + from twisted.internet.threads import deferToThread + + +def eagerDeferredCoroutine( + f: Callable[_P, Coroutine[Deferred[object], object, _T]] +) -> Callable[_P, Deferred[_T]]: + def inner(*args: _P.args, **kwargs: _P.kwargs) -> Deferred[_T]: + return Deferred.fromCoroutine(f(*args, **kwargs)) + + return inner + + +def threadedDeferredFunction(f: Callable[_P, _T]) -> Callable[_P, Deferred[_T]]: + """ + When the decorated function is called, always run it in a thread. + """ + + def inner(*args: _P.args, **kwargs: _P.kwargs) -> Deferred[_T]: + return deferToThread(lambda: f(*args, **kwargs)) + + return inner + + +__all__ = [ + "eagerDeferredCoroutine", + "deferToThread", + "threadedDeferredFunction", +] diff --git a/src/klein/interfaces.py b/src/klein/interfaces.py index 481799e8d..6c35553a2 100644 --- a/src/klein/interfaces.py +++ b/src/klein/interfaces.py @@ -5,6 +5,7 @@ IDependencyInjector, IRequestLifecycle, IRequiredParameter, + IRequirementContext, ISession, ISessionProcurer, ISessionStore, @@ -23,6 +24,7 @@ "IKleinRequest", "IRequestLifecycle", "IRequiredParameter", + "IRequirementContext", "ISession", "ISessionProcurer", "ISessionStore", diff --git a/src/klein/storage/dbxs/__init__.py b/src/klein/storage/dbxs/__init__.py new file mode 100644 index 000000000..e9aaa60c4 --- /dev/null +++ b/src/klein/storage/dbxs/__init__.py @@ -0,0 +1,36 @@ +""" +C{DBXS} (“database access”) is an asynchronous database access layer based on +lightly organizing queries into simple data structures rather than a more +general query builder or object-relational mapping. + +It serves as the basis for L{klein.storage.sql}. +""" + +from ._access import ( + ExtraneousMethods, + IncorrectResultCount, + NotEnoughResults, + ParamMismatch, + TooManyResults, + accessor, + many, + maybe, + one, + query, + statement, +) + + +__all__ = [ + "one", + "many", + "maybe", + "accessor", + "statement", + "query", + "ParamMismatch", + "TooManyResults", + "NotEnoughResults", + "IncorrectResultCount", + "ExtraneousMethods", +] diff --git a/src/klein/storage/dbxs/_access.py b/src/klein/storage/dbxs/_access.py new file mode 100644 index 000000000..ad91bc46b --- /dev/null +++ b/src/klein/storage/dbxs/_access.py @@ -0,0 +1,325 @@ +# -*- test-case-name: klein.storage.dbxs.test.test_access -*- +from __future__ import annotations + +from dataclasses import dataclass, field +from inspect import BoundArguments, signature +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Coroutine, + Dict, + Iterable, + List, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) + +from ..._typing_compat import ParamSpec, Protocol +from .dbapi_async import AsyncConnection, AsyncCursor + + +T = TypeVar("T") +P = ParamSpec("P") +A = TypeVar("A", bound=Union[AsyncIterable[object], Awaitable[object]]) + + +class ParamMismatch(Exception): + """ + The parameters required by the query are different than the parameters + specified by the function. + """ + + +class IncorrectResultCount(Exception): + """ + An assumption about the number of rows from a given query was violated; + there were either too many or too few. + """ + + +class NotEnoughResults(IncorrectResultCount): + """ + There were not enough results for the query to satify L{one}. + """ + + +class TooManyResults(IncorrectResultCount): + """ + There were more results for a query than expected; more than one for + L{one}, or any at all for L{zero}. + """ + + +class ExtraneousMethods(Exception): + """ + An access pattern defined extraneous methods. + """ + + +def one( + load: Callable[..., T], +) -> Callable[[object, AsyncCursor], Coroutine[object, object, T]]: + """ + Fetch a single result with a translator function. + """ + + async def translate(db: object, cursor: AsyncCursor) -> T: + rows = await cursor.fetchall() + if len(rows) < 1: + raise NotEnoughResults() + if len(rows) > 1: + raise TooManyResults() + return load(db, *rows[0]) + + return translate + + +def maybe( + load: Callable[..., T] +) -> Callable[[object, AsyncCursor], Coroutine[object, object, Optional[T]]]: + """ + Fetch a single result and pass it to a translator function, but return None + if it's not found. + """ + + async def translate(db: object, cursor: AsyncCursor) -> Optional[T]: + rows = await cursor.fetchall() + if len(rows) < 1: + return None + if len(rows) > 1: + raise TooManyResults() + return load(db, *rows[0]) + + return translate + + +def many( + load: Callable[..., T] +) -> Callable[[object, AsyncCursor], AsyncIterable[T]]: + """ + Fetch multiple results with a function to translate rows. + """ + + async def translate(db: object, cursor: AsyncCursor) -> AsyncIterable[T]: + while True: + row = await cursor.fetchone() + if row is None: + return + yield load(db, *row) + + return translate + + +async def zero(loader: object, cursor: AsyncCursor) -> None: + """ + Zero record loader. + """ + result = await cursor.fetchone() + if result is not None: + raise TooManyResults("statemnts should not return values") + return None + + +METADATA_KEY = "__query_metadata__" + + +@dataclass +class MaybeAIterable: + down: Any + + def __await__(self) -> Any: + return self.down.__await__() + + async def __aiter__(self) -> Any: + actuallyiter = await self + async for each in actuallyiter: + yield each + + +@dataclass +class QueryMetadata: + """ + Metadata defining a certain function on a protocol as a query method. + """ + + sql: str + load: Callable[[AccessProxy, AsyncCursor], A] + proxyMethod: Callable[..., Awaitable[object]] = field(init=False) + + def setOn(self, protocolMethod: Any) -> None: + """ + Attach this QueryMetadata to the given protocol method definition, + checking its arguments and computing C{proxyMethod} in the process, + raising L{ParamMismatch} if the expected parameters do not match. + """ + sig = signature(protocolMethod) + precomputedSQL: Dict[str, Tuple[str, QmarkParamstyleMap]] = {} + for style, mapFactory in styles.items(): + mapInstance = mapFactory() + styledSQL = self.sql.format_map(mapInstance) + precomputedSQL[style] = (styledSQL, mapInstance) + + sampleSQL, sampleInstance = precomputedSQL["qmark"] + selfExcluded = list(sig.parameters)[1:] + if set(sampleInstance.names) != set(selfExcluded): + raise ParamMismatch( + f"when defining {protocolMethod.__name__}(...), " + f"SQL placeholders {sampleInstance.names} != " + f"function params {selfExcluded}" + ) + + def proxyMethod( + proxySelf: AccessProxy, *args: object, **kw: object + ) -> Any: + """ + Implementation of all database-proxy methods on objects returned + from C{accessor}. + """ + + async def body() -> Any: + conn = proxySelf.__query_connection__ + styledSQL, styledMap = precomputedSQL[conn.paramstyle] + cur = await conn.cursor() + bound = sig.bind(None, *args, **kw) + await cur.execute(styledSQL, styledMap.queryArguments(bound)) + maybeAgen: Any = self.load(proxySelf, cur) + try: + # there is probably a nicer way to detect aiter-ability + return await maybeAgen + except TypeError: + return maybeAgen + + return MaybeAIterable(body()) + + self.proxyMethod = proxyMethod + setattr(protocolMethod, METADATA_KEY, self) + + @classmethod + def loadFrom(cls, f: object) -> Optional[QueryMetadata]: + """ + Load the query metadata for C{f} if it has any. + """ + self: Optional[QueryMetadata] = getattr(f, METADATA_KEY, None) + return self + + @classmethod + def filterProtocolNamespace( + cls, protocolNamespace: Iterable[Tuple[str, object]] + ) -> Iterable[Tuple[str, QueryMetadata]]: + """ + Load all QueryMetadata + """ + extraneous = [] + for attrname, value in protocolNamespace: + qm = QueryMetadata.loadFrom(value) + if qm is None: + if attrname not in PROTOCOL_IGNORED_ATTRIBUTES: + extraneous.append(attrname) + continue + yield attrname, qm + if extraneous: + raise ExtraneousMethods( + f"non-query/statement methods defined: {extraneous}" + ) + + +def query( + *, + sql: str, + load: Callable[[object, AsyncCursor], A], +) -> Callable[[Callable[P, A]], Callable[P, A]]: + """ + Declare a query method. + """ + qm = QueryMetadata(sql=sql, load=load) + + def decorator(f: Callable[P, A]) -> Callable[P, A]: + qm.setOn(f) + return f + + return decorator + + +def statement( + *, + sql: str, +) -> Callable[ + [Callable[P, Coroutine[Any, Any, None]]], + Callable[P, Coroutine[Any, Any, None]], +]: + """ + Declare a query method. + """ + return query(sql=sql, load=zero) + + +@dataclass +class DBProxy: + """ + Database Proxy + """ + + name: str + transaction: AsyncConnection + + +@dataclass +class QmarkParamstyleMap: + names: List[str] = field(default_factory=list) + + def __getitem__(self, name: str) -> str: + self.names.append(name) + return "?" + + def queryArguments(self, bound: BoundArguments) -> Sequence[object]: + """ + Compute the arguments to the query. + """ + return [bound.arguments[each] for each in self.names] + + +class _EmptyProtocol(Protocol): + """ + Empty protocol for setting a baseline of what attributes to ignore while + metaprogramming. + """ + + +PROTOCOL_IGNORED_ATTRIBUTES = set(_EmptyProtocol.__dict__.keys()) + +styles = { + "qmark": QmarkParamstyleMap, +} + + +@dataclass +class AccessProxy: + """ + Superclass of all access proxies. + """ + + __query_connection__: AsyncConnection + + +def accessor( + accessPatternProtocol: Callable[[], T] +) -> Callable[[AsyncConnection], T]: + """ + Create a factory which binds a database transaction in the form of an + AsyncConnection to a set of declared SQL methods. + """ + return type( + f"{accessPatternProtocol.__name__}DB", + tuple([AccessProxy]), + { + name: metadata.proxyMethod + for name, metadata in QueryMetadata.filterProtocolNamespace( + accessPatternProtocol.__dict__.items() + ) + }, + ) diff --git a/src/klein/storage/dbxs/_dbapi_async_protocols.py b/src/klein/storage/dbxs/_dbapi_async_protocols.py new file mode 100644 index 000000000..974eecf41 --- /dev/null +++ b/src/klein/storage/dbxs/_dbapi_async_protocols.py @@ -0,0 +1,112 @@ +from contextlib import asynccontextmanager +from typing import ( + Any, + AsyncIterator, + Mapping, + Optional, + Sequence, + TypeVar, + Union, +) + +from ..._typing_compat import Protocol +from ._dbapi_types import DBAPIColumnDescription + + +ParamStyle = str + +# Sadly, db-api modules do not restrict themselves in this way, so we can't +# specify the ParamStyle type more precisely, like so: + +# ParamStyle = Literal['qmark', 'numeric', 'named', 'format', 'pyformat'] + +T = TypeVar("T") + + +class AsyncCursor(Protocol): + """ + Asynchronous Cursor Object. + """ + + async def description(self) -> Optional[Sequence[DBAPIColumnDescription]]: + ... + + async def rowcount(self) -> int: + ... + + async def fetchone(self) -> Optional[Sequence[Any]]: + ... + + async def fetchmany( + self, size: Optional[int] = None + ) -> Sequence[Sequence[Any]]: + ... + + async def fetchall(self) -> Sequence[Sequence[Any]]: + ... + + async def execute( + self, + operation: str, + parameters: Union[Sequence[Any], Mapping[str, Any]] = (), + ) -> object: + ... + + async def executemany( + self, __operation: str, __seq_of_parameters: Sequence[Sequence[Any]] + ) -> object: + ... + + async def close(self) -> None: + ... + + +class AsyncConnection(Protocol): + """ + Asynchronous version of a DB-API connection. + """ + + @property + def paramstyle(self) -> ParamStyle: + ... + + async def cursor(self) -> AsyncCursor: + ... + + async def rollback(self) -> None: + ... + + async def commit(self) -> None: + ... + + async def close(self) -> None: + ... + + +class AsyncConnectable(Protocol): + """ + An L{AsyncConnectable} can establish and pool L{AsyncConnection} objects. + """ + + async def connect(self) -> AsyncConnection: + ... + + async def quit(self) -> None: + ... + + +@asynccontextmanager +async def transaction( + connectable: AsyncConnectable, +) -> AsyncIterator[AsyncConnection]: + """ + Connect to a given connection in a context manager. + """ + conn = await connectable.connect() + try: + yield conn + except BaseException: + await conn.rollback() + raise + else: + await conn.commit() diff --git a/src/klein/storage/dbxs/_dbapi_async_twisted.py b/src/klein/storage/dbxs/_dbapi_async_twisted.py new file mode 100644 index 000000000..b3c5f172c --- /dev/null +++ b/src/klein/storage/dbxs/_dbapi_async_twisted.py @@ -0,0 +1,387 @@ +# -*- test-case-name: klein.storage.dbxs.test.test_sync_adapter -*- +""" +Async version of db-api methods which associate each underlying db-api +connection with a specific thread, since some database drivers have issues with +sharing connections and cursors between threads. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from queue import Queue +from threading import Thread +from typing import ( + TYPE_CHECKING, + Any, + Callable, + List, + Mapping, + Optional, + Sequence, + Type, + TypeVar, +) + +from twisted._threads import AlreadyQuit, ThreadWorker +from twisted._threads._ithreads import IExclusiveWorker +from twisted.internet.defer import Deferred +from twisted.python.failure import Failure + +from ._dbapi_async_protocols import ( + AsyncConnectable, + AsyncConnection, + AsyncCursor, + ParamStyle, +) +from ._dbapi_types import DBAPIColumnDescription, DBAPIConnection, DBAPICursor + + +_T = TypeVar("_T") + +F = Callable[[], None] + + +class InvalidConnection(Exception): + """ + The connection has already been closed, or the transaction has already been + committed. + """ + + +def _newThread() -> IExclusiveWorker: + def _startThread(target: Callable[[], None]) -> Thread: + thread = Thread(target=target, daemon=True) + thread.start() + return thread + + return ThreadWorker(_startThread, Queue()) + + +@dataclass +class ExclusiveWorkQueue: + _worker: Optional[IExclusiveWorker] + _deliver: Callable[[F], None] + + def worker(self, invalidate: bool = False) -> IExclusiveWorker: + """ + Assert that the worker should still be present, then return it + (invalidating it if the flag is passed). + """ + if invalidate: + w, self._worker = self._worker, None + else: + w = self._worker + if w is None: + raise AlreadyQuit("cannot quit twice") + return w + + def perform( + self, + work: Callable[[], _T], + ) -> Deferred[_T]: + """ + Perform the given work on the underlying thread, delivering the result + back to the main thread with L{ExclusiveWorkQueue._deliver}. + """ + + deferred: Deferred[_T] = Deferred() + + def workInThread() -> None: + try: + result = work() + except BaseException: + f = Failure() + self._deliver(lambda: deferred.errback(f)) + else: + self._deliver(lambda: deferred.callback(result)) + + self.worker().do(workInThread) + + return deferred + + def quit(self) -> None: + """ + Allow this thread to stop, and invalidate this L{ExclusiveWorkQueue} by + removing its C{_worker} attribute. + """ + self.worker(True).quit() + + def __del__(self) -> None: + """ + When garbage collected make sure we kill off our underlying thread. + """ + if self._worker is None: + return + # might be nice to emit a ResourceWarning here, since __del__ is not a + # good way to clean up resources. + self.quit() + + +@dataclass +class ThreadedCursorAdapter(AsyncCursor): + """ + A cursor that can be interacted with asynchronously. + """ + + _cursor: DBAPICursor + _exclusive: ExclusiveWorkQueue + + async def description(self) -> Optional[Sequence[DBAPIColumnDescription]]: + return await self._exclusive.perform(lambda: self._cursor.description) + + async def rowcount(self) -> int: + return await self._exclusive.perform(lambda: self._cursor.rowcount) + + async def fetchone(self) -> Optional[Sequence[Any]]: + return await self._exclusive.perform(self._cursor.fetchone) + + async def fetchmany( + self, size: Optional[int] = None + ) -> Sequence[Sequence[Any]]: + a = [size] if size is not None else [] + return await self._exclusive.perform(lambda: self._cursor.fetchmany(*a)) + + async def fetchall(self) -> Sequence[Sequence[Any]]: + return await self._exclusive.perform(self._cursor.fetchall) + + async def execute( + self, + operation: str, + parameters: Sequence[Any] | Mapping[str, Any] = (), + ) -> object: + """ + Execute the given statement. + """ + + def query() -> object: + return self._cursor.execute(operation, parameters) + + return await self._exclusive.perform(query) + + async def executemany( + self, __operation: str, __seq_of_parameters: Sequence[Sequence[Any]] + ) -> object: + def query() -> object: + return self._cursor.executemany(__operation, __seq_of_parameters) + + return await self._exclusive.perform(query) + + async def close(self) -> None: + """ + Close the underlying cursor. + """ + await self._exclusive.perform(self._cursor.close) + + +@dataclass +class ThreadedConnectionAdapter: + """ + Asynchronous database connection that binds to a specific thread. + """ + + _connection: Optional[DBAPIConnection] + _exclusive: ExclusiveWorkQueue + paramstyle: ParamStyle + + def _getConnection(self, invalidate: bool = False) -> DBAPIConnection: + """ + Get the connection, raising an exception if it's already been + invalidated. + """ + c = self._connection + assert ( + c is not None + ), "should not be able to get a bad connection via public API" + if invalidate: + self._connection = None + return c + + async def close(self) -> None: + """ + Close the connection if it hasn't been closed yet. + """ + connection = self._getConnection(True) + await self._exclusive.perform(connection.close) + self._exclusive.quit() + + async def cursor(self) -> ThreadedCursorAdapter: + """ + Construct a new async cursor. + """ + c = self._getConnection() + cur = await self._exclusive.perform(c.cursor) + return ThreadedCursorAdapter(cur, self._exclusive) + + async def rollback(self) -> None: + """ + Roll back the current transaction. + """ + c = self._getConnection() + await self._exclusive.perform(c.rollback) + + async def commit(self) -> None: + """ + Roll back the current transaction. + """ + c = self._getConnection() + await self._exclusive.perform(c.commit) + + +@dataclass(eq=False) +class PooledThreadedConnectionAdapter: + """ + Pooled connection adapter that re-adds itself back to the pool upon commit + or rollback. + """ + + _adapter: Optional[ThreadedConnectionAdapter] + _pool: ThreadedConnectionPool + _cursors: List[ThreadedCursorAdapter] + + def _original(self, invalidate: bool) -> ThreadedConnectionAdapter: + """ + Check for validity, return the underlying connection, and then + optionally invalidate this adapter. + """ + a = self._adapter + if a is None: + raise InvalidConnection("The connection has already been closed.") + if invalidate: + self._adapter = None + return a + + @property + def paramstyle(self) -> str: + return self._original(False).paramstyle + + async def cursor(self) -> ThreadedCursorAdapter: + it = await self._original(False).cursor() + self._cursors.append(it) + return it + + async def rollback(self) -> None: + """ + Roll back the transaction, returning the connection to the pool. + """ + a = self._original(True) + try: + await a.rollback() + finally: + await self._pool._checkin(self, a) + + async def _closeCursors(self) -> None: + for cursor in self._cursors: + await cursor.close() + + async def commit(self) -> None: + """ + Commit the transaction, returning the connection to the pool. + """ + await self._closeCursors() + a = self._original(True) + try: + await a.commit() + finally: + await self._pool._checkin(self, a) + + async def close(self) -> None: + """ + Close the underlying connection, removing it from the pool. + """ + await self._closeCursors() + await self._original(True).close() + + +@dataclass(eq=False) +class ThreadedConnectionPool: + """ + Database engine and connection pool. + """ + + _connectCallable: Callable[[], DBAPIConnection] + paramstyle: ParamStyle + _idleMax: int + _createWorker: Callable[[], IExclusiveWorker] + _deliver: Callable[[Callable[[], None]], None] + _idlers: List[ThreadedConnectionAdapter] = field(default_factory=list) + _active: List[PooledThreadedConnectionAdapter] = field(default_factory=list) + + async def connect(self) -> PooledThreadedConnectionAdapter: + """ + Checkout a new connection from the pool, connecting to the database and + opening a thread first if necessary. + """ + if self._idlers: + conn = self._idlers.pop() + else: + e = ExclusiveWorkQueue(self._createWorker(), self._deliver) + conn = ThreadedConnectionAdapter( + await e.perform(self._connectCallable), + e, + self.paramstyle, + ) + txn = PooledThreadedConnectionAdapter(conn, self, []) + self._active.append(txn) + return txn + + async def _checkin( + self, + txn: PooledThreadedConnectionAdapter, + connection: ThreadedConnectionAdapter, + ) -> None: + """ + Check a connection back in to the pool, closing and discarding it. + """ + self._active.remove(txn) + if len(self._idlers) < self._idleMax: + self._idlers.append(connection) + else: + await connection.close() + + async def quit(self) -> None: + """ + Close all outstanding connections and shut down the underlying + threadpool. + """ + self._idleMax = 0 + while self._active: + await self._active[0].rollback() + + while self._idlers: + await self._idlers.pop().close() + + +def adaptSynchronousDriver( + connectCallable: Callable[[], DBAPIConnection], + paramstyle: ParamStyle, + *, + createWorker: Optional[Callable[[], IExclusiveWorker]] = None, + callFromThread: Optional[Callable[[F], None]] = None, + maxIdleConnections: int = 5, +) -> AsyncConnectable: + """ + Adapt a synchronous DB-API driver to be an L{AsyncConnectable}. + """ + if callFromThread is None: + reactor: Any + from twisted.internet import reactor + + callFromThread = reactor.callFromThread + + if createWorker is None: + createWorker = _newThread + + return ThreadedConnectionPool( + connectCallable, + paramstyle, + maxIdleConnections, + createWorker, + callFromThread, + ) + + +if TYPE_CHECKING: + _1: Type[AsyncCursor] = ThreadedCursorAdapter + _2: Type[AsyncConnection] = ThreadedConnectionAdapter + _4: Type[AsyncConnection] = PooledThreadedConnectionAdapter + _3: Type[AsyncConnectable] = ThreadedConnectionPool diff --git a/src/klein/storage/dbxs/_dbapi_types.py b/src/klein/storage/dbxs/_dbapi_types.py new file mode 100644 index 000000000..f8b63a6af --- /dev/null +++ b/src/klein/storage/dbxs/_dbapi_types.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from typing import Any, Mapping, Optional, Sequence, Tuple, Union + +from ..._typing_compat import Protocol + + +# PEP 249 Database API 2.0 Types +# https://www.python.org/dev/peps/pep-0249/ + + +DBAPITypeCode = Optional[Any] + +DBAPIColumnDescription = Tuple[ + str, + DBAPITypeCode, + Optional[int], + Optional[int], + Optional[int], + Optional[int], + Optional[bool], +] + + +class DBAPIConnection(Protocol): + def close(self) -> object: + ... + + def commit(self) -> object: + ... + + def rollback(self) -> Any: + ... + + def cursor(self) -> DBAPICursor: + ... + + +class DBAPICursor(Protocol): + arraysize: int + + @property + def description(self) -> Optional[Sequence[DBAPIColumnDescription]]: + ... + + @property + def rowcount(self) -> int: + ... + + def close(self) -> object: + ... + + def execute( + self, + operation: str, + parameters: Union[Sequence[Any], Mapping[str, Any]] = ..., + ) -> object: + ... + + def executemany( + self, __operation: str, __seq_of_parameters: Sequence[Sequence[Any]] + ) -> object: + ... + + def fetchone(self) -> Optional[Sequence[Any]]: + ... + + def fetchmany(self, __size: int = ...) -> Sequence[Sequence[Any]]: + ... + + def fetchall(self) -> Sequence[Sequence[Any]]: + ... diff --git a/src/klein/storage/dbxs/_testing.py b/src/klein/storage/dbxs/_testing.py new file mode 100644 index 000000000..5c02d8845 --- /dev/null +++ b/src/klein/storage/dbxs/_testing.py @@ -0,0 +1,138 @@ +# -*- test-case-name: klein.storage.dbxs.test -*- +from __future__ import annotations + +import sqlite3 +from dataclasses import dataclass +from typing import Any, Callable, Coroutine, List, TypeVar +from uuid import uuid4 + +from twisted._threads._ithreads import IExclusiveWorker +from twisted._threads._memory import createMemoryWorker +from twisted.internet.defer import Deferred +from twisted.trial.unittest import SynchronousTestCase + +from ._dbapi_types import DBAPIConnection +from .dbapi_async import AsyncConnectable, adaptSynchronousDriver + + +def sqlite3Connector() -> Callable[[], DBAPIConnection]: + """ + Create an in-memory shared-cache SQLite3 database and return a 0-argument + callable that will connect to that database. + """ + uri = f"file:{str(uuid4())}?mode=memory&cache=shared" + + held = None + + def connect() -> DBAPIConnection: + # This callable has to hang on to a connection to the underlying SQLite + # data structures, otherwise its schema and shared cache disappear as + # soon as it's garbage collected. This 'nonlocal' stateemnt adds it to + # the closure, which keeps the reference after it's created. + nonlocal held + return sqlite3.connect(uri, uri=True) + + held = connect() + return connect + + +@dataclass +class MemoryPool: + """ + An in-memory connection pool to an in-memory SQLite database which can be + controlled a single operation at a time. Each operation that would + normally be asynchronoulsy dispatched to a thread can be invoked with the + L{MemoryPool.pump} and L{MemoryPool.flush} methods. + + @ivar connectable: The L{AsyncConnectable} to be passed to the system under + test. + """ + + connectable: AsyncConnectable + _performers: List[Callable[[], bool]] + + def additionalPump(self, f: Callable[[], bool]) -> None: + """ + Add an additional callable to be called by L{MemoryPool.pump} and + L{MemoryPool.flush}. This can be used to interleave other sources of + in-memory event completion to allow test coroutines to complete, such + as needing to call L{StubTreq.flush}. + """ + self._performers.append(f) + + def pump(self) -> bool: + """ + Perform one step of pending work. + + @return: True if any work was performed and False if no work was left. + """ + for performer in self._performers: + if performer(): + return True + return False + + def flush(self) -> int: + """ + Perform all outstanding steps of work. + + @return: a count of the number of steps of work performed. + """ + steps = 0 + while self.pump(): + steps += 1 + return steps + + @classmethod + def new(cls) -> MemoryPool: + """ + Create a synchronous memory connection pool. + """ + performers = [] + + def createWorker() -> IExclusiveWorker: + worker: IExclusiveWorker + # note: createMemoryWorker actually returns IWorker, better type + # annotations may require additional shenanigans + worker, perform = createMemoryWorker() + performers.append(perform) + return worker + + return MemoryPool( + adaptSynchronousDriver( + sqlite3Connector(), + sqlite3.paramstyle, + createWorker=createWorker, + callFromThread=lambda f: f(), + maxIdleConnections=10, + ), + performers, + ) + + +AnyTestCase = TypeVar("AnyTestCase", bound=SynchronousTestCase) +syncAsyncTest = Callable[ + [AnyTestCase, MemoryPool], + Coroutine[Any, Any, None], +] +regularTest = Callable[[AnyTestCase], None] + + +def immediateTest() -> ( + Callable[[syncAsyncTest[AnyTestCase]], regularTest[AnyTestCase]] +): + """ + Decorate an C{async def} test that expects a coroutine. + """ + + def decorator(decorated: syncAsyncTest[AnyTestCase]) -> regularTest: + def regular(self: AnyTestCase) -> None: + pool = MemoryPool.new() + d = Deferred.fromCoroutine(decorated(self, pool)) + self.assertNoResult(d) + while pool.flush(): + pass + self.successResultOf(d) + + return regular + + return decorator diff --git a/src/klein/storage/dbxs/dbapi_async.py b/src/klein/storage/dbxs/dbapi_async.py new file mode 100644 index 000000000..010c051cb --- /dev/null +++ b/src/klein/storage/dbxs/dbapi_async.py @@ -0,0 +1,21 @@ +""" +Minimal asynchronous mapping of DB-API 2.0 interfaces, along with tools to +""" + +from ._dbapi_async_protocols import ( + AsyncConnectable, + AsyncConnection, + AsyncCursor, + transaction, +) +from ._dbapi_async_twisted import InvalidConnection, adaptSynchronousDriver + + +__all__ = [ + "InvalidConnection", + "AsyncConnection", + "AsyncConnectable", + "AsyncCursor", + "adaptSynchronousDriver", + "transaction", +] diff --git a/src/klein/storage/dbxs/dbapi_sync.py b/src/klein/storage/dbxs/dbapi_sync.py new file mode 100644 index 000000000..3634582e2 --- /dev/null +++ b/src/klein/storage/dbxs/dbapi_sync.py @@ -0,0 +1,14 @@ +from ._dbapi_types import ( + DBAPIColumnDescription, + DBAPIConnection, + DBAPICursor, + DBAPITypeCode, +) + + +__all__ = [ + "DBAPIConnection", + "DBAPITypeCode", + "DBAPICursor", + "DBAPIColumnDescription", +] diff --git a/src/klein/storage/dbxs/test/__init__.py b/src/klein/storage/dbxs/test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/klein/storage/dbxs/test/test_access.py b/src/klein/storage/dbxs/test/test_access.py new file mode 100644 index 000000000..d67bd8376 --- /dev/null +++ b/src/klein/storage/dbxs/test/test_access.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +from twisted.trial.unittest import SynchronousTestCase as TestCase + +from ...._typing_compat import Protocol +from .. import ( + ExtraneousMethods, + NotEnoughResults, + ParamMismatch, + TooManyResults, + accessor, + maybe, + one, + query, + statement, +) +from ..dbapi_async import AsyncConnection, transaction +from ..testing import MemoryPool, immediateTest + + +# Trying to stick to the public API for what we're testing; no underscores here. + + +@dataclass +class Foo: + db: FooAccessPattern + bar: int + baz: int + + +class FooAccessPattern(Protocol): + @query(sql="select bar, baz from foo where bar = {bar}", load=one(Foo)) + async def getFoo(self, bar: int) -> Foo: + ... + + @query(sql="select bar, baz from foo where bar = {bar}", load=maybe(Foo)) + async def maybeFoo(self, bar: int) -> Optional[Foo]: + ... + + @query(sql="select bar, baz from foo where baz = {baz}", load=one(Foo)) + async def oneFooByBaz(self, baz: int) -> Foo: + ... + + @query(sql="select bar, baz from foo where baz = {baz}", load=maybe(Foo)) + async def maybeFooByBaz(self, baz: int) -> Optional[Foo]: + ... + + @statement(sql="insert into foo (baz) values ({baz})") + async def newFoo(self, baz: int) -> None: + """ + Create a new C{Foo} + """ + + @statement(sql="select * from foo") + async def oopsQueryNotStatement(self) -> None: + """ + Oops, it's a query, not a statement, it returns values. + """ + + @query( + sql="insert into foo (baz) values ({baz}) returning bar, baz", + load=one(Foo), + ) + async def newReturnFoo(self, baz: int) -> Foo: + """ + Create a new C{Foo} and return it. + """ + + +accessFoo = accessor(FooAccessPattern) + + +async def schemaAndData(c: AsyncConnection) -> None: + """ + Create the schema for 'foo' and insert some sample data. + """ + cur = await c.cursor() + for stmt in """ + create table foo (bar integer primary key autoincrement, baz int); + insert into foo values (1, 3); + insert into foo values (2, 4); + """.split( + ";" + ): + await cur.execute(stmt) + + +class AccessTestCase(TestCase): + """ + Tests for L{accessor} and its associated functions + """ + + @immediateTest() + async def test_happyPath(self, pool: MemoryPool) -> None: + """ + Declaring a protocol with a query and executing it + """ + async with transaction(pool.connectable) as c: + await schemaAndData(c) + db = accessFoo(c) + result = await db.getFoo(1) + result2 = await db.maybeFoo(1) + self.assertEqual(result, Foo(db, 1, 3)) + self.assertEqual(result, result2) + + def test_argumentExhaustiveness(self) -> None: + """ + If a query does not use all of its arguments, or the function does not + specify all the arguments that a function uses, it will raise an + exception during definition. + """ + with self.assertRaises(ParamMismatch) as pm: + + class MissingBar(Protocol): + @statement(sql="fake sql {bar}") + async def someUnused(self) -> None: + ... + + self.assertIn("bar", str(pm.exception)) + self.assertIn("someUnused", str(pm.exception)) + with self.assertRaises(ParamMismatch): + + class DoesntUseBar(Protocol): + @statement(sql="fake sql") + async def someMissing(self, bar: str) -> None: + ... + + @immediateTest() + async def test_tooManyResults(self, pool: MemoryPool) -> None: + """ + If there are too many results for a L{one} query, then a + L{TooManyResults} exception is raised. + """ + async with transaction(pool.connectable) as c: + await schemaAndData(c) + cur = await c.cursor() + await cur.execute("insert into foo (baz) values (3)") + await cur.execute("insert into foo (baz) values (3)") + db = accessFoo(c) + with self.assertRaises(TooManyResults): + await db.oneFooByBaz(3) + with self.assertRaises(TooManyResults): + await db.maybeFooByBaz(3) + + def test_brokenProtocol(self) -> None: + """ + Using L{accessor} on a protocol with unrelated methods raises a . + """ + + class NonAccessPatternProtocol(Protocol): + def randomNonQueryMethod(self) -> None: + ... + + with self.assertRaises(ExtraneousMethods) as em: + accessor(NonAccessPatternProtocol) + self.assertIn("randomNonQueryMethod", str(em.exception)) + + @immediateTest() + async def test_notEnoughResults(self, pool: MemoryPool) -> None: + """ + If there are too many results for a L{one} query, then a + L{NotEnoughResults} exception is raised. + """ + async with transaction(pool.connectable) as c: + cur = await c.cursor() + await schemaAndData(c) + await cur.execute("delete from foo") + db = accessFoo(c) + with self.assertRaises(NotEnoughResults): + await db.getFoo(1) + self.assertIs(await db.maybeFoo(1), None) + + @immediateTest() + async def test_insertStatementWithReturn(self, pool: MemoryPool) -> None: + """ + DML statements can use RETURNING to return values. + """ + async with transaction(pool.connectable) as c: + await schemaAndData(c) + db = accessFoo(c) + self.assertEqual(await db.newReturnFoo(100), Foo(db, 3, 100)) + + @immediateTest() + async def test_statementHasNoResult(self, pool: MemoryPool) -> None: + """ + The L{statement} decorator gives a result. + """ + async with transaction(pool.connectable) as c: + await schemaAndData(c) + db = accessFoo(c) + nothing = await db.newFoo(7) # type:ignore[func-returns-value] + self.assertIs(nothing, None) + + @immediateTest() + async def test_statementWithResultIsError(self, pool: MemoryPool) -> None: + """ + The L{statement} decorator gives a result. + """ + async with transaction(pool.connectable) as c: + await schemaAndData(c) + db = accessFoo(c) + with self.assertRaises(TooManyResults) as tmr: + await db.oopsQueryNotStatement() + self.assertIn("should not return", str(tmr.exception)) diff --git a/src/klein/storage/dbxs/test/test_sync_adapter.py b/src/klein/storage/dbxs/test/test_sync_adapter.py new file mode 100644 index 000000000..3c0fb6341 --- /dev/null +++ b/src/klein/storage/dbxs/test/test_sync_adapter.py @@ -0,0 +1,613 @@ +""" +Tests for running synchronous DB-API drivers within threads. +""" +from __future__ import annotations + +import sqlite3 +from contextlib import contextmanager +from dataclasses import dataclass, field +from itertools import count +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from zope.interface import implementer + +from twisted._threads import AlreadyQuit +from twisted._threads._ithreads import IExclusiveWorker +from twisted.internet.defer import Deferred +from twisted.trial.unittest import TestCase + +from ...._util import eagerDeferredCoroutine +from .._dbapi_async_twisted import ExclusiveWorkQueue, ThreadedConnectionPool +from .._testing import sqlite3Connector +from ..dbapi_async import ( + AsyncConnectable, + InvalidConnection, + adaptSynchronousDriver, + transaction, +) +from ..dbapi_sync import DBAPIColumnDescription, DBAPIConnection, DBAPICursor + + +pretendThreadID = 0 + + +@contextmanager +def pretendToBeInThread(newThreadID: int) -> Iterator[None]: + """ + Pretend to be in the given thread while executing. + """ + global pretendThreadID + pretendThreadID = newThreadID + try: + yield + finally: + pretendThreadID = newThreadID + + +@dataclass +class FakeDBAPICursor: + """ + Fake PEP 249 cursor. + """ + + connection: FakeDBAPIConnection + arraysize: int + cursorID: int = field(default_factory=count().__next__) + + @property + def operationsByThread(self) -> List[Tuple[str, int, int]]: + """ + Delegate to connection. + """ + return self.connection.operationsByThread + + @property + def connectionID(self) -> int: + """ + Delegate to connection. + """ + return self.connection.connectionID + + @property + def description(self) -> Optional[Sequence[DBAPIColumnDescription]]: + # note sqlite actually pads out the response with Nones like this + self.operationsByThread.append( + ("description", self.connectionID, pretendThreadID) + ) + return [("stub", None, None, None, None, None, None)] + + @property + def rowcount(self) -> int: + self.operationsByThread.append( + ("rowcount", self.connectionID, pretendThreadID) + ) + return 0 + + def close(self) -> object: + self.operationsByThread.append( + ("close", self.connectionID, pretendThreadID) + ) + return None + + def execute( + self, + operation: str, + parameters: Union[Sequence[Any], Mapping[str, Any]] = (), + ) -> object: + self.operationsByThread.append( + ("execute", self.connectionID, pretendThreadID) + ) + return None + + def executemany( + self, __operation: str, __seq_of_parameters: Sequence[Sequence[Any]] + ) -> object: + self.operationsByThread.append( + ("executemany", self.connectionID, pretendThreadID) + ) + return None + + def fetchone(self) -> Optional[Sequence[Any]]: + self.operationsByThread.append( + ("fetchone", self.connectionID, pretendThreadID) + ) + return None + + def fetchmany(self, __size: int = 0) -> Sequence[Sequence[Any]]: + self.operationsByThread.append( + ("fetchmany", self.connectionID, pretendThreadID) + ) + return [] + + def fetchall(self) -> Sequence[Sequence[Any]]: + self.operationsByThread.append( + ("fetchall", self.connectionID, pretendThreadID) + ) + return [] + + +@dataclass +class FakeDBAPIConnection: + """ + Fake PEP 249 connection. + """ + + cursors: list[FakeDBAPICursor] + # (operation, connectionID, threadID) + operationsByThread: list[tuple[str, int, int]] + connectionID: int = field(default_factory=count().__next__) + + def close(self) -> None: + self.operationsByThread.append( + ("connection.close", self.connectionID, pretendThreadID) + ) + return None + + def commit(self) -> None: + self.operationsByThread.append( + ("commit", self.connectionID, pretendThreadID) + ) + return None + + def rollback(self) -> Any: + self.operationsByThread.append( + ("rollback", self.connectionID, pretendThreadID) + ) + + def cursor(self) -> DBAPICursor: + self.operationsByThread.append( + ("cursor", self.connectionID, pretendThreadID) + ) + cursor = FakeDBAPICursor(self, 3) + self.cursors.append(cursor) + return cursor + + +if TYPE_CHECKING: + _1: Type[DBAPICursor] = FakeDBAPICursor + _2: Type[DBAPIConnection] = FakeDBAPIConnection + + +def realThreadedAdapter(testCase: TestCase) -> AsyncConnectable: + """ + Create an AsyncConnectable using real threads and scheduling its + non-threaded callbacks on the Twisted reactor, suitable for using in a + real-Deferred-returning integration test. + """ + memdb = sqlite3Connector() + scon = memdb() + + scur = scon.cursor() + scur.execute( + """ + create table sample (testcol int primary key, testcol2 str) + """ + ) + scur.execute( + """ + insert into sample values (1, 'hello'), (2, 'goodbye') + """ + ) + scon.commit() + + pool = adaptSynchronousDriver(memdb, sqlite3.paramstyle) + testCase.addCleanup(lambda: Deferred.fromCoroutine(pool.quit())) + return pool + + +Thunk = Callable[[], None] + + +def wrap(t: Thunk, threadID: int) -> Thunk: + """ + Wrap the given thunk in a fake thread ID + """ + + def wrapped() -> None: + with pretendToBeInThread(threadID): + t() + + return wrapped + + +@implementer(IExclusiveWorker) +@dataclass +class FakeExclusiveWorker: + queue: List[Thunk] + threadID: int = field(default_factory=count().__next__) + quitted: bool = False + + def do(self, work: Callable[[], None]) -> None: + assert ( + not self.quitted + ), "we should never schedule work on a quitted worker" + self.queue.append(wrap(work, self.threadID)) + + def quit(self) -> None: + """ + Exit & clean up. + """ + self.quitted = True + + +class SampleError(Exception): + """ + An error occurred. + """ + + +class ResourceManagementTests(TestCase): + """ + Tests to make sure that various thread resources are managed correctly. + """ + + def setUp(self) -> None: + """ + set up a thread pool that pretends to start threads + """ + # this queue of work is both "in threads" and "not in threads" work; + # all "in threads" work is wrapped up in a thing that sets the global + # L{pretendThreadID}; all "main thread" work has it set to 0. + self.queue: List[Thunk] = [] + self.cursors: List[FakeDBAPICursor] = [] + self.dbapiops: List[Tuple[str, int, int]] = [] + self.threads: List[FakeExclusiveWorker] = [] + + def newWorker() -> IExclusiveWorker: + w = FakeExclusiveWorker(self.queue) + self.threads.append(w) + return w + + def makeConnection() -> DBAPIConnection: + conn = FakeDBAPIConnection(self.cursors, self.dbapiops) + self.dbapiops.append( + ("connect", conn.connectionID, pretendThreadID) + ) + return conn + + self.poolInternals = ThreadedConnectionPool( + makeConnection, + "qmark", + 3, + newWorker, + self.queue.append, + ) + self.pool: AsyncConnectable = self.poolInternals + + def flush(self) -> None: + """ + Perform all outstanding "threaded" work. + """ + while self.queue: + self.queue.pop(0)() + + def test_allOperations(self) -> None: + """ + All the DB-API operations are wrapped. + """ + + async def dostuff() -> None: + con = await self.pool.connect() + cur = await con.cursor() + self.assertEqual( + await cur.description(), + [("stub", None, None, None, None, None, None)], + ) + self.assertEqual(await cur.rowcount(), 0) + await cur.execute("test expr", ["some", "params"]), [] + await cur.executemany( + "lots of operations", [["parameter", "seq"], ["etc", "etc"]] + ) + self.assertIs(await cur.fetchone(), None) + self.assertEqual(await cur.fetchmany(7), []) + self.assertEqual(await cur.fetchall(), []) + await cur.close() + await con.commit() + # already committed, so we need a new connection to test rollback + con2 = await self.pool.connect() + await con2.rollback() + + d1 = Deferred.fromCoroutine(dostuff()) + self.flush() + self.successResultOf(d1) + self.assertOperations( + [ + "connect", + "cursor", + "description", + "rowcount", + "execute", + "executemany", + "fetchone", + "fetchmany", + "fetchall", + "close", + "close", + "commit", + "rollback", + ], + ) + + def test_connectionClose(self) -> None: + """ + As opposed to committing or rolling back, closing a connection will + remove it from the pool entirely. + """ + + async def dostuff() -> None: + con = await self.pool.connect() + await con.close() + with self.assertRaises(InvalidConnection): + await con.cursor() + with self.assertRaises(InvalidConnection): + await con.close() + + d1 = Deferred.fromCoroutine(dostuff()) + self.flush() + self.successResultOf(d1) + self.assertOperations(["connect", "connection.close"]) + self.assertEqual(self.poolInternals._idlers, []) + # The associated thread is also quit. + self.assertEqual([thread.quitted for thread in self.threads], [True]) + + def assertOperations(self, expectedOperations: Sequence[str]) -> None: + """ + Assert that DB-API would have performed the named operations. + """ + self.assertEqual( + expectedOperations, [first for first, _, _ in self.dbapiops] + ) + + def test_inCorrectThread(self) -> None: + """ + Each connection's operations are executed on a dedicated thread. + """ + + async def dostuff() -> None: + con = await self.pool.connect() + cur = await con.cursor() + await cur.execute("select * from what") + await con.commit() + + d1 = Deferred.fromCoroutine(dostuff()) + d2 = Deferred.fromCoroutine(dostuff()) + self.assertNoResult(d1) + self.assertNoResult(d2) + self.flush() + + self.successResultOf(d1) + self.successResultOf(d2) + + async def cleanup() -> None: + await self.pool.quit() + + cleanedup = Deferred.fromCoroutine(cleanup()) + self.flush() + + self.successResultOf(cleanedup) + self.assertEqual((self.poolInternals._idlers), []) + + threadToConnection: Dict[int, int] = {} + connectionToThread: Dict[int, int] = {} + confirmed = 0 + + for _, connectionID, threadID in self.dbapiops: + if threadID in threadToConnection: + self.assertEqual(threadToConnection[threadID], connectionID) + self.assertEqual(connectionToThread[connectionID], threadID) + confirmed += 1 + else: + threadToConnection[threadID] = connectionID + connectionToThread[connectionID] = threadID + # ops = ['connect', 'cursor', 'execute', 'close', 'commit', 'close'] + # expected = (len(ops) * 2) - 2 + expected = 10 + self.assertEqual(confirmed, expected) + self.assertEqual(len(threadToConnection), 2) + + def test_basicPooling(self) -> None: + """ + When a pooled connection is committed or rolled back, we will + invalidate it and won't allocate additional underlying connections. + """ + + async def t1() -> None: + con = await self.pool.connect() + await con.commit() + with self.assertRaises(InvalidConnection): + await con.cursor() + con = await self.pool.connect() + await con.rollback() + with self.assertRaises(InvalidConnection): + await con.cursor() + + d = Deferred.fromCoroutine(t1()) + self.flush() + self.successResultOf(d) + self.assertEqual( + len({connectionID for _, connectionID, _ in self.dbapiops}), 1 + ) + + def test_tooManyConnections(self) -> None: + """ + When we exceed the idle-max of the pool, we close connections + immediately as they are returned. + """ + + async def t1() -> None: + c1 = await self.pool.connect() + c2 = await self.pool.connect() + c3 = await self.pool.connect() + c4 = await self.pool.connect() + await c1.commit() + await c2.commit() + await c3.commit() + await c4.commit() + + d = Deferred.fromCoroutine(t1()) + self.flush() + self.successResultOf(d) + self.assertEqual(len(self.poolInternals._idlers), 3) + self.assertOperations( + [ + *["connect"] * 4, + *["commit"] * 4, + "connection.close", + ] + ) + + def test_transactionContextManager(self) -> None: + """ + C{with transaction(pool)} results in an async context manager which + will commit when exited normally and rollback when exited with an + exception. + """ + + async def t1() -> None: + # committed + async with transaction(self.pool) as t: + await (await t.cursor()).execute("hello world") + + # rolled back + with self.assertRaises(SampleError): + async with transaction(self.pool) as t2: + await (await t2.cursor()).execute("a") + raise SampleError() + + started = Deferred.fromCoroutine(t1()) + self.flush() + self.successResultOf(started) + self.assertOperations( + [ + "connect", + "cursor", + "execute", + "close", + "commit", + "cursor", + "execute", + "rollback", + ] + ) + + def test_poolQuit(self) -> None: + """ + When the pool is shut down, all idlers are closed, and all active + connections invalidated. + """ + + async def t1() -> None: + c1 = await self.pool.connect() + c2 = await self.pool.connect() + await self.pool.quit() + with self.assertRaises(InvalidConnection): + await c1.cursor() + with self.assertRaises(InvalidConnection): + await c2.cursor() + + d = Deferred.fromCoroutine(t1()) + self.flush() + self.successResultOf(d) + self.assertOperations( + [ + "connect", + "connect", + "rollback", + "connection.close", + "rollback", + "connection.close", + ] + ) + self.assertEqual(self.poolInternals._idlers, []) + self.assertEqual(len(self.threads), 2) + self.assertEqual(self.threads[0].quitted, True) + self.assertEqual(self.threads[1].quitted, True) + + +class InternalSafetyTests(TestCase): + """ + Tests for internal safety mechanisms; states which I{should} be unreachable + via the public API but should nonetheless be reported. + """ + + def test_queueQuit(self) -> None: + """ + L{ExclusiveWorkQueue} should raise L{AlreadyQuit} when interacted with + after C{quit}. + """ + stuff: List[Callable[[], None]] = [] + ewc = ExclusiveWorkQueue(FakeExclusiveWorker(stuff), stuff.append) + ewc.quit() + with self.assertRaises(AlreadyQuit): + ewc.quit() + with self.assertRaises(AlreadyQuit): + ewc.perform(int) + self.assertEqual(stuff, []) + + +class SyncAdapterTests(TestCase): + """ + Integration tests for L{adaptSynchronousDriver}. + """ + + @eagerDeferredCoroutine + async def test_execAndFetch(self) -> None: + """ + Integration test: can we use an actual DB-API module, with real threads? + """ + pool = realThreadedAdapter(self) + con = await pool.connect() + cur = await con.cursor() + + query = """ + select * from sample order by testcol asc + """ + await cur.execute(query) + self.assertEqual(await cur.fetchall(), [(1, "hello"), (2, "goodbye")]) + await cur.execute( + """ + insert into sample values (3, 'more'), (4, 'even more') + """ + ) + await cur.execute(query) + self.assertEqual( + await cur.fetchmany(3), [(1, "hello"), (2, "goodbye"), (3, "more")] + ) + self.assertEqual(await cur.fetchmany(3), [(4, "even more")]) + + @eagerDeferredCoroutine + async def test_errors(self) -> None: + """ + Integration test: do errors propagate? + """ + pool = realThreadedAdapter(self) + con = await pool.connect() + cur = await con.cursor() + later = cur.execute("select * from nonexistent") + with self.assertRaises(sqlite3.OperationalError) as oe: + await later + self.assertIn("nonexistent", str(oe.exception)) + + @eagerDeferredCoroutine + async def test_invalidateAfterCommit(self) -> None: + """ + Connections will be invalidated after they've been committed. + """ + pool = realThreadedAdapter(self) + con = await pool.connect() + await con.commit() + with self.assertRaises(InvalidConnection): + await con.cursor() diff --git a/src/klein/storage/dbxs/testing.py b/src/klein/storage/dbxs/testing.py new file mode 100644 index 000000000..a63c5474a --- /dev/null +++ b/src/klein/storage/dbxs/testing.py @@ -0,0 +1,15 @@ +""" +Testing support for L{klein.storage.dbxs}. + +L{MemoryPool} creates a synchronous, in-memory SQLite database that can be used +for testing anything that needs an +L{klein.storage.dbxs.dbapi_async.AsyncConnectable}. +""" + +from ._testing import MemoryPool, immediateTest + + +__all__ = [ + "MemoryPool", + "immediateTest", +] diff --git a/src/klein/storage/memory.py b/src/klein/storage/memory.py deleted file mode 100644 index e68e20090..000000000 --- a/src/klein/storage/memory.py +++ /dev/null @@ -1,7 +0,0 @@ -from ._memory import MemorySessionStore, declareMemoryAuthorizer - - -__all__ = [ - "declareMemoryAuthorizer", - "MemorySessionStore", -] diff --git a/src/klein/storage/memory/__init__.py b/src/klein/storage/memory/__init__.py new file mode 100644 index 000000000..a3d50e2f1 --- /dev/null +++ b/src/klein/storage/memory/__init__.py @@ -0,0 +1,16 @@ +""" +In-memory implementations of L{klein.interfaces.ISessionStore} and +L{klein.interfaces.ISimpleAccount}, usable for testing or for ephemeral +applications with static authentication requirements rather than real account +databases. +""" + +from ._memory import MemorySessionStore, declareMemoryAuthorizer +from ._memory_users import MemoryAccountStore + + +__all__ = [ + "declareMemoryAuthorizer", + "MemorySessionStore", + "MemoryAccountStore", +] diff --git a/src/klein/storage/_memory.py b/src/klein/storage/memory/_memory.py similarity index 65% rename from src/klein/storage/_memory.py rename to src/klein/storage/memory/_memory.py index 6fa1c89cd..b5251690a 100644 --- a/src/klein/storage/_memory.py +++ b/src/klein/storage/memory/_memory.py @@ -1,9 +1,22 @@ # -*- test-case-name: klein.test.test_memory -*- +from __future__ import annotations + from binascii import hexlify from os import urandom -from typing import Any, Callable, Dict, Iterable, Type, cast +from typing import ( + Any, + Callable, + Dict, + Iterable, + Optional, + Type, + TypeVar, + Union, + cast, +) import attr +from attrs import define, field from zope.interface import Interface, implementer from twisted.internet.defer import Deferred, fail, succeed @@ -16,12 +29,16 @@ SessionMechanism, ) +from ..._isession import AuthorizationMap +from ..._typing_compat import Protocol +from ..._util import eagerDeferredCoroutine + -_authCB = Callable[[Type[Interface], ISession, Componentized], Any] +_authCB = Callable[[Type[object], ISession, Componentized], Any] @implementer(ISession) -@attr.s(auto_attribs=True) +@define class MemorySession: """ An in-memory session. @@ -31,59 +48,80 @@ class MemorySession: isConfidential: bool authenticatedBy: SessionMechanism _authorizationCallback: _authCB - _components: Componentized = attr.ib(factory=Componentized) + _components: Componentized = field(factory=Componentized) - def authorize(self, interfaces: Iterable[Type[Interface]]) -> Deferred: + @eagerDeferredCoroutine + async def authorize( + self, interfaces: Iterable[Type[object]] + ) -> AuthorizationMap: """ Authorize each interface by calling back to the session store's authorization callback. """ - result = {} + result: AuthorizationMap = {} # type:ignore[assignment] for interface in interfaces: provider = self._authorizationCallback( interface, self, self._components ) + if isinstance(provider, Deferred): + provider = await provider if provider is not None: result[interface] = provider - return succeed(result) + return result + + +T = TypeVar("T") -class _MemoryAuthorizerFunction: +class _MemoryAuthorizerFunction(Protocol[T]): """ Type shadow for function with the given attribute. """ - __memoryAuthInterface__: Type[Interface] = None # type: ignore[assignment] + __memoryAuthInterface__: Type[T] def __call__( - self, interface: Type[Interface], session: ISession, data: Componentized - ) -> Any: + self, interface: Type[object], session: ISession, data: Componentized + ) -> Union[Deferred[Optional[T]], T, None]: """ Return a provider of the given interface. """ -_authFn = Callable[[Type[Interface], ISession, Componentized], Any] +_authFn = Callable[[Type[object], ISession, Componentized], Any] def declareMemoryAuthorizer( forInterface: Type[Interface], -) -> Callable[[Callable], _MemoryAuthorizerFunction]: +) -> Callable[ + [ + Callable[ + [Type[T], ISession, Componentized], + Union[Deferred[Optional[T]], T, None], + ] + ], + _MemoryAuthorizerFunction[T], +]: """ Declare that the decorated function is an authorizer usable with a memory session store. """ - def decorate(decoratee: _authFn) -> _MemoryAuthorizerFunction: - decoratee = cast(_MemoryAuthorizerFunction, decoratee) - decoratee.__memoryAuthInterface__ = forInterface - return decoratee + def decorate( + decoratee: Callable[ + [Type[T], ISession, Componentized], + Union[Deferred[Optional[T]], T, None], + ] + ) -> _MemoryAuthorizerFunction[T]: + asAuthorizer = cast(_MemoryAuthorizerFunction, decoratee) + asAuthorizer.__memoryAuthInterface__ = forInterface + return asAuthorizer return decorate def _noAuthorization( - interface: Type[Interface], session: ISession, data: Componentized + interface: Type[object], session: ISession, data: Componentized ) -> None: return None @@ -98,7 +136,7 @@ class MemorySessionStore: @classmethod def fromAuthorizers( cls, authorizers: Iterable[_MemoryAuthorizerFunction] - ) -> "MemorySessionStore": + ) -> MemorySessionStore: """ Create a L{MemorySessionStore} from a collection of callbacks which can do authorization. @@ -110,7 +148,7 @@ def fromAuthorizers( interfaceToCallable[specifiedInterface] = authorizer def authorizationCallback( - interface: Type[Interface], session: ISession, data: Componentized + interface: Type[object], session: ISession, data: Componentized ) -> Any: return interfaceToCallable.get(interface, _noAuthorization)( interface, session, data @@ -129,7 +167,7 @@ def _storage(self, isConfidential: bool) -> Dict[str, Any]: def newSession( self, isConfidential: bool, authenticatedBy: SessionMechanism - ) -> Deferred: + ) -> Deferred[ISession]: storage = self._storage(isConfidential) identifier = hexlify(urandom(32)).decode("ascii") session = MemorySession( @@ -146,7 +184,7 @@ def loadSession( identifier: str, isConfidential: bool, authenticatedBy: SessionMechanism, - ) -> Deferred: + ) -> Deferred[ISession]: storage = self._storage(isConfidential) if identifier in storage: return succeed(storage[identifier]) @@ -160,4 +198,5 @@ def loadSession( ) def sentInsecurely(self, tokens: Iterable[str]) -> None: - return + for token in tokens: + self._storage(True).pop(token, None) diff --git a/src/klein/storage/memory/_memory_users.py b/src/klein/storage/memory/_memory_users.py new file mode 100644 index 000000000..d9bc59095 --- /dev/null +++ b/src/klein/storage/memory/_memory_users.py @@ -0,0 +1,144 @@ +# -*- test-case-name: klein.test.test_form.TestForms -*- +from __future__ import annotations + +from collections import defaultdict +from typing import Dict, Iterable, List, Optional, Sequence, Type + +from attrs import Factory, define, field +from zope.interface import implementer + +from twisted.python.components import Componentized + +from ..._util import eagerDeferredCoroutine +from ...interfaces import ISession, ISimpleAccount, ISimpleAccountBinding +from ._memory import _MemoryAuthorizerFunction, declareMemoryAuthorizer + + +@implementer(ISimpleAccount) +@define +class MemoryAccount: + """ + Implementation of in-memory simple account. + """ + + store: MemoryAccountStore + accountID: str + username: str + password: str = field(repr=False) + + @eagerDeferredCoroutine + async def bindSession(self, session: ISession) -> None: + """ + Bind this account to the given session. + """ + self.store._bindings[session.identifier].append(self) + + @eagerDeferredCoroutine + async def changePassword(self, newPassword: str) -> None: + """ + Change the password of this account. + """ + self.password = newPassword + + +@implementer(ISimpleAccountBinding) +@define +class MemoryAccountBinding: + """ + Implementation of in-memory simple account binding. + """ + + store: MemoryAccountStore + session: ISession + + @eagerDeferredCoroutine + async def boundAccounts(self) -> Sequence[ISimpleAccount]: + return self.store._bindings[self.session.identifier] + + @eagerDeferredCoroutine + async def createAccount( + self, username: str, email: str, password: str + ) -> Optional[ISimpleAccount]: + """ + Refuse to create new accounts; memory accounts should be pre-created, + since they won't persist. + """ + + @eagerDeferredCoroutine + async def bindIfCredentialsMatch( + self, username: str, password: str + ) -> Optional[ISimpleAccount]: + """ + Bind if the credentials match. + """ + account = self.store._accounts.get(username) + if account is None: + return None + if account.password != password: + return None + account.bindSession(self.session) + return account + + @eagerDeferredCoroutine + async def unbindThisSession(self) -> None: + """ + Un-bind this session from all accounts. + """ + del self.store._bindings[self.session.identifier] + + +@define +class MemoryAccountStore: + """ + In-memory account store. + """ + + _accounts: Dict[str, MemoryAccount] = field(default=Factory(dict)) + _bindings: Dict[str, List[MemoryAccount]] = field(default=defaultdict(list)) + + def authorizers(self) -> Iterable[_MemoryAuthorizerFunction]: + """ + Construct the list of authorizers from the account state populated on + this store. + """ + + @declareMemoryAuthorizer(MemoryAccount) + @eagerDeferredCoroutine + async def memauth( + interface: Type[MemoryAccount], + session: ISession, + componentized: Componentized, + ) -> Optional[MemoryAccount]: + for account in self._bindings[session.identifier]: + return account + return None + + @declareMemoryAuthorizer(ISimpleAccount) + @eagerDeferredCoroutine + async def alsoSimple( + interface: Type[ISimpleAccount], + session: ISession, + componentized: Componentized, + ) -> Optional[ISimpleAccount]: + return (await session.authorize([MemoryAccount])).get(MemoryAccount) + + @declareMemoryAuthorizer(ISimpleAccountBinding) + def membind( + interface: Type[ISimpleAccountBinding], + session: ISession, + componentized: Componentized, + ) -> ISimpleAccountBinding: + """ + ISimpleAccountBinding. + """ + return MemoryAccountBinding(self, session) + + return [membind, alsoSimple, memauth] + + def addAccount(self, username: str, password: str) -> None: + """ + Add an account with the given username and password. + """ + self._accounts[username] = MemoryAccount( + self, str(len(self._accounts)), username=username, password=password + ) diff --git a/src/klein/storage/memory/test/__init__.py b/src/klein/storage/memory/test/__init__.py new file mode 100644 index 000000000..6c2cc0948 --- /dev/null +++ b/src/klein/storage/memory/test/__init__.py @@ -0,0 +1,3 @@ +""" +Tests for in-memory session storage. +""" diff --git a/src/klein/test/test_memory.py b/src/klein/storage/memory/test/test_memory.py similarity index 100% rename from src/klein/test/test_memory.py rename to src/klein/storage/memory/test/test_memory.py diff --git a/src/klein/storage/passwords/__init__.py b/src/klein/storage/passwords/__init__.py new file mode 100644 index 000000000..df17c608d --- /dev/null +++ b/src/klein/storage/passwords/__init__.py @@ -0,0 +1,13 @@ +""" +Testable, secure hashing for passwords. +""" + +from ._interfaces import PasswordEngine +from ._scrypt import InvalidPasswordRecord, defaultSecureEngine + + +__all__ = [ + "InvalidPasswordRecord", + "defaultSecureEngine", + "PasswordEngine", +] diff --git a/src/klein/storage/passwords/_interfaces.py b/src/klein/storage/passwords/_interfaces.py new file mode 100644 index 000000000..6bccbdb18 --- /dev/null +++ b/src/klein/storage/passwords/_interfaces.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from typing import Awaitable, Callable + +from ..._typing_compat import Protocol + + +class PasswordEngine(Protocol): + """ + Interface required to hash passwords for secure storage. + """ + + async def computeKeyText( + self, + passwordText: str, + ) -> str: + """ + Compute some text to store for a given plain-text password. + + @param passwordText: The text of a new password, as entered by a user. + + @return: The hashed text to store. + """ + + async def checkAndReset( + self, + storedPasswordHash: str, + providedPasswordText: str, + storeNewHash: Callable[[str], Awaitable[None]], + ) -> bool: + """ + Check the given stored password text against the given provided + password text. If password policies have changed since the given hash + was stored and C{providedPasswordText} is correct, compute a new hash + and use C{storeNewHash} to write it back to the data store. + + @param storedPasswordText: the opaque hashed output from our hash + function, stored in a datastore. + + @param providedPasswordText: the plain-text password provided by the + user. + + @param storeNewHash: A function that stores a new hash in the database. + + @return: an awaitable boolean; C{True} if the password matches (i.e, + the user has successfully authenticated) and C{False} if the + password does not match. + """ diff --git a/src/klein/storage/passwords/_scrypt.py b/src/klein/storage/passwords/_scrypt.py new file mode 100644 index 000000000..643799e4d --- /dev/null +++ b/src/klein/storage/passwords/_scrypt.py @@ -0,0 +1,193 @@ +# -*- test-case-name: klein.storage.passwords.test.test_passwords -*- +from __future__ import annotations + +from dataclasses import dataclass +from os import urandom +from re import compile as compileRE +from typing import TYPE_CHECKING, Awaitable, Callable, Type +from unicodedata import normalize + +from ..._util import threadedDeferredFunction +from ._interfaces import PasswordEngine + + +try: + from hashlib import scrypt +except ImportError: + # PyPy ships without scrypt so we need cryptography there. + from cryptography.hazmat.primitives.kdf.scrypt import Scrypt + + # The signature of C{scrypt} from the standard library has a bunch of + # additional complexity, supporting memory views and types other than + # `bytes`, but this is not a publicly exposed or particularly principled + # annotation so we ignore the minor differences in the two signatures here. + + def scrypt( # type:ignore[misc] + password: bytes, + *, + salt: bytes, + n: int, + r: int, + p: int, + maxmem: int = 0, + dklen: int = 64, + ) -> bytes: + return Scrypt(salt=salt, length=dklen, n=n, r=r, p=p).derive(password) + + +@threadedDeferredFunction +def runScrypt(password: str, salt: bytes, n: int, r: int, p: int) -> bytes: + """ + Run L{scrypt} in a thread. + """ + maxmem = (2**8) * n * r + return scrypt( + normalize("NFD", password).encode("utf-8"), + salt=salt, + n=n, + r=r, + p=p, + maxmem=maxmem, + ) + + +class InvalidPasswordRecord(Exception): + """ + A stored password was not in a valid format. + """ + + +sep = "\\$" +MARKER = "klein-scrypt" + + +HEX = "[0-9a-f]+" +INT = "[0-9]+" + + +def g(**names: str) -> str: + [[name, expression]] = list(names.items()) + return f"(?P<{name}>{expression})" + + +fields = [MARKER, g(hashed=HEX), g(salt=HEX), g(n=INT), g(r=INT), g(p=INT)] +recordRE = compileRE(sep + sep.join(fields) + sep) + + +@dataclass +class SCryptHashedPassword: + """ + a password hashed using SCrypt with certain parameters. + """ + + hashed: bytes + salt: bytes + n: int + r: int + p: int + + def serialize(self) -> str: + """ + Serialize this L{SCryptHashedPassword} to a string. Callers must + consider this opaque. + """ + return ( + f"${MARKER}${self.hashed.hex()}" + f"${self.salt.hex()}${self.n}${self.r}${self.p}$" + ) + + async def verify(self, password: str) -> bool: + """ + Compare the given password to this hash. + + @return: an awaitable True if it matches, False if it doesn't. + """ + computed = await runScrypt(password, self.salt, self.n, self.r, self.p) + return self.hashed == (computed) + + @classmethod + def load(cls, serialized: str) -> SCryptHashedPassword: + """ + Load a SCryptHashedPassword from a string produced by + L{SCryptHashedPassword.serialize}. + """ + matched = recordRE.fullmatch(serialized) + if not matched: + raise InvalidPasswordRecord("invalid password record") + return cls( + bytes.fromhex(matched["hashed"]), + bytes.fromhex(matched["salt"]), + int(matched["n"]), + int(matched["r"]), + int(matched["p"]), + ) + + @classmethod + # "If Argon2id is not available, use scrypt with a minimum CPU/memory cost + # parameter of (2^17), a minimum block size of 8 (1024 bytes), and a + # parallelization parameter of 1." - + # https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html + async def new( + cls, inputText: str, n: int = 2**18, r: int = 8, p: int = 1 + ) -> SCryptHashedPassword: + """ + Hash C{inputText} in a thread to create a new L{SCryptHashedPassword}. + """ + salt = urandom(16) + return cls(await runScrypt(inputText, salt, n, r, p), salt, n, r, p) + + +@dataclass +class KleinV1PasswordEngine: + """ + Built-in engine for hashing passwords for secure storage with basic + C{scrypt} parameters. + + Implementation of L{PasswordEngine}. + """ + + minimumN: int = 2**18 + preferredN: int = 2**19 + + async def computeKeyText(self, passwordText: str) -> str: + hashed = await SCryptHashedPassword.new(passwordText, self.preferredN) + return hashed.serialize() + + async def checkAndReset( + self, + storedPasswordHash: str, + providedPasswordText: str, + storeNewHash: Callable[[str], Awaitable[None]], + ) -> bool: + hashed = SCryptHashedPassword.load(storedPasswordHash) + if await hashed.verify(providedPasswordText): + if hashed.n < self.minimumN: + newHash = await SCryptHashedPassword.new( + providedPasswordText, self.preferredN + ) + await storeNewHash(newHash.serialize()) + return True + else: + return False + + +def defaultSecureEngine() -> PasswordEngine: + """ + Supply an implementation to the caller of L{PasswordEngine} suitable for + deployment to production. + + Presently this is an C{scrypt}-based implementation using cost parameters + recommended by OWASP as this is a least-common-denominator approach. + + However, this entrypoint is guaranteed to return a L{PasswordEngine} in the + future that can backward-compatibly parse outputs from C{computeKeyText} + and C{checkAndReset} from any previous version of Klein, as well as store + upgraded hashes whenever modern security standards are upgraded. + + @see: for testing, use L{klein.storage.passwords.testing.engineForTesting}. + """ + return KleinV1PasswordEngine() + + +if TYPE_CHECKING: + _1: Type[PasswordEngine] = KleinV1PasswordEngine diff --git a/src/klein/storage/passwords/_testing.py b/src/klein/storage/passwords/_testing.py new file mode 100644 index 000000000..e5dea2c3c --- /dev/null +++ b/src/klein/storage/passwords/_testing.py @@ -0,0 +1,85 @@ +# -*- test-case-name: klein.storage.passwords.test.test_passwords -*- + +from dataclasses import dataclass, field +from hashlib import sha256 +from os import urandom +from typing import Awaitable, Callable, Optional +from unicodedata import normalize +from unittest import TestCase + +from ._interfaces import PasswordEngine + + +@dataclass +class InsecurePasswordEngineOnlyForTesting: + """ + Very fast in-memory password engine that is suitable only for testing. + """ + + tempSalt: bytes = field(default_factory=lambda: urandom(16)) + hashVersion: int = 1 + upgradedHashes: int = 0 + + async def computeKeyText(self, passwordText: str) -> str: + # hashing here only in case someone *does* put this into production; + # the salt will be lost, and this will be garbage, so all auth will + # fail. + return ( + f"{self.hashVersion}-" + + sha256( + normalize("NFD", passwordText).encode("utf-8") + self.tempSalt + ).hexdigest() + ) + + async def checkAndReset( + self, + storedPasswordHash: str, + providedPasswordText: str, + storeNewHash: Callable[[str], Awaitable[None]], + ) -> bool: + storedVersion, storedActualHash = storedPasswordHash.split("-") + computedHash = await self.computeKeyText(providedPasswordText) + newVersion, receivedActualHash = computedHash.split("-") + valid = storedActualHash == receivedActualHash + if valid and int(newVersion) > int(storedVersion): + await storeNewHash(computedHash) + self.upgradedHashes += 1 + return valid + + +cacheAttribute = "__insecurePasswordEngine__" + + +def engineForTesting( + testCase: TestCase, *, upgradeHashes: bool = False +) -> PasswordEngine: + """ + Return an insecure password engine that is very fast, suitable for using in + unit tests. + + @param testCase: The test case for which this engine is to be used. The + engine will be cached on the test case, so that multiple calls will + return the same object. + + @param storeNewHashes: Should the engine's C{checkAndReset} method call its + C{storePasswordHash} argument? Note that this mutates the existing + engine if one has already been cached. + """ + result: Optional[InsecurePasswordEngineOnlyForTesting] = getattr( + testCase, cacheAttribute, None + ) + if result is None: + result = InsecurePasswordEngineOnlyForTesting() + setattr(testCase, cacheAttribute, result) + result.hashVersion += upgradeHashes + return result + + +def hashUpgradeCount(testCase: TestCase) -> int: + """ + How many times has the L{engineForTesting} for the given test upgraded the + hash of a stored password? + """ + engine = engineForTesting(testCase) + assert isinstance(engine, InsecurePasswordEngineOnlyForTesting) + return engine.upgradedHashes diff --git a/src/klein/storage/passwords/test/__init__.py b/src/klein/storage/passwords/test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/klein/storage/passwords/test/test_passwords.py b/src/klein/storage/passwords/test/test_passwords.py new file mode 100644 index 000000000..51ab78176 --- /dev/null +++ b/src/klein/storage/passwords/test/test_passwords.py @@ -0,0 +1,140 @@ +from typing import Awaitable, Callable, List, Tuple + +from twisted.trial.unittest import SynchronousTestCase, TestCase + +from ...._util import eagerDeferredCoroutine +from .. import InvalidPasswordRecord, PasswordEngine, defaultSecureEngine +from .._scrypt import KleinV1PasswordEngine +from ..testing import engineForTesting, hashUpgradeCount + + +def pwStorage() -> Tuple[List[str], Callable[[str], Awaitable[None]]]: + hashes = [] + + async def storeSomething(something: str) -> None: + hashes.append(something) + + return hashes, storeSomething + + +class PasswordStorageTests(TestCase): + def setUp(self) -> None: + self.engine: PasswordEngine = KleinV1PasswordEngine(2**14, 2**15) + self.newHashes, self.storeSomething = pwStorage() + + @eagerDeferredCoroutine + async def test_checkAndResetDefault(self) -> None: + """ + Tests for L{checkAndReset} and L{computeHash} functions. These are a + little slow because they're verifying the normal / good default of the + production-grade CryptContext hash. + """ + kt1 = await self.engine.computeKeyText("hello world") + bad = await self.engine.checkAndReset( + kt1, "hello wordl", self.storeSomething + ) + self.assertFalse(bad, "passwords don't match") + good = await self.engine.checkAndReset( + kt1, "hello world", self.storeSomething + ) + self.assertTrue(good, "passwords do match") + self.assertEqual(self.newHashes, []) + + @eagerDeferredCoroutine + async def test_resetOnNewRounds(self) -> None: + """ + When the supplied CryptContext requires more rounds, the store function + will be called. + """ + oldEngine = KleinV1PasswordEngine(2**10, 2**12) + kt1 = await oldEngine.computeKeyText("hello world") + check1 = await self.engine.checkAndReset( + kt1, "hello world", self.storeSomething + ) + self.assertTrue(check1) + self.assertEqual(len(self.newHashes), 1) + newHash = self.newHashes.pop() + check2 = await self.engine.checkAndReset( + newHash, "hello world", self.storeSomething + ) + self.assertTrue(check2) + self.assertEqual(self.newHashes, []) + + @eagerDeferredCoroutine + async def test_serializationErrorHandling(self) -> None: + """ + Un-parseable passwords will result in a L{BadStoredPassword} exception. + """ + engine = KleinV1PasswordEngine(2**5, 2**6) + with self.assertRaises(InvalidPasswordRecord): + await engine.checkAndReset( + "gibberish", "my-password", self.storeSomething + ) + + +class TestDefaultEntryPoint(SynchronousTestCase): + """ + Test for externally facing /public API. + """ + + def test_entryPoint(self) -> None: + """ + L{defaultSecureEngine} returns a L{KleinV1PasswordEngine} + """ + self.assertIsInstance(defaultSecureEngine(), KleinV1PasswordEngine) + + +class TestTesting(SynchronousTestCase): + """ + Tests for L{engineForTesting} + """ + + def test_testingEngine(self) -> None: + """ + A cached password engine can verify passwords. + """ + changes, storeSomething = pwStorage() + engine1 = engineForTesting(self) + engine2 = engineForTesting(self) + # The same test should get back the same engine. + self.assertIs(engine1, engine2) + # should complete synchronously + kt1 = self.successResultOf(engine1.computeKeyText("hello world")) + + self.assertEqual( + True, + self.successResultOf( + engine1.checkAndReset( + kt1, + "hello world", + storeSomething, + ) + ), + ) + self.assertEqual( + False, + self.successResultOf( + engine1.checkAndReset( + kt1, + "hello wordl", + storeSomething, + ) + ), + ) + + self.assertEqual(hashUpgradeCount(self), 0) + engine3 = engineForTesting(self, upgradeHashes=True) + # Still the same. + self.assertIs(engine2, engine3) + # But now we will upgrade hashes upon reset. + self.successResultOf( + engine1.checkAndReset(kt1, "hello world", storeSomething) + ) + self.assertEqual(len(changes), 1) + self.assertEqual(hashUpgradeCount(self), 1) + kt2 = changes.pop() + self.successResultOf( + engine1.checkAndReset(kt2, "hello world", storeSomething) + ) + self.assertEqual(len(changes), 0) + self.assertEqual(hashUpgradeCount(self), 1) diff --git a/src/klein/storage/passwords/testing.py b/src/klein/storage/passwords/testing.py new file mode 100644 index 000000000..43ee303d5 --- /dev/null +++ b/src/klein/storage/passwords/testing.py @@ -0,0 +1,19 @@ +# -*- test-case-name: klein.storage.passwords.test.test_passwords -*- +""" +Unit testing support for L{klein.storage.passwords}. + +In production, password hashing needs to be slow enough that it requires +delgation to alternate threads, and, obviously, deterministic. These testing +facilities present the same interface, but are fast so as not to slow down your +tests; for safety, they are I{not} repeatable, generating per-session state and +providing no API to serialize it, so as to avoid accidentally relying on it in +production. +""" + +from ._testing import engineForTesting, hashUpgradeCount + + +__all__ = [ + "engineForTesting", + "hashUpgradeCount", +] diff --git a/src/klein/storage/sql/__init__.py b/src/klein/storage/sql/__init__.py new file mode 100644 index 000000000..058fad743 --- /dev/null +++ b/src/klein/storage/sql/__init__.py @@ -0,0 +1,19 @@ +""" +An implementation of a basic username/password authentication database using +C{dbxs}. +""" + +from ._sql_glue import ( + SessionStore, + SQLSessionProcurer, + applyBasicSchema, + authorizerFor, +) + + +__all__ = [ + "SQLSessionProcurer", + "SessionStore", + "authorizerFor", + "applyBasicSchema", +] diff --git a/src/klein/storage/sql/_sql_dal.py b/src/klein/storage/sql/_sql_dal.py new file mode 100644 index 000000000..d20137162 --- /dev/null +++ b/src/klein/storage/sql/_sql_dal.py @@ -0,0 +1,195 @@ +# -*- test-case-name: klein.storage.sql.test,klein.storage.test.test_common -*- +from __future__ import annotations + +from datetime import datetime +from typing import AsyncIterable, Optional + +from attrs import define + +from ..._typing_compat import Protocol +from ..._util import eagerDeferredCoroutine +from ...interfaces import ISession +from ..dbxs import accessor, many, maybe, query, statement + + +@define +class SessionRecord: + """ + The fields from the session that are stored in the database. + + The distinction between a L{SessionRecord} and an L{SQLSession} is that an + L{SQLSession} binds to an actual request, and thus has an + C{authenticatedBy} attribute, which inherently cannot be stored in the + database. + """ + + db: SessionDAL + session_id: str + confidential: bool + + +@define +class AccountRecord: + """ + An implementation of L{ISimpleAccount} backed by an SQL data store. + """ + + db: SessionDAL + accountID: str + username: str + email: str + password_blob: Optional[str] = None + + @eagerDeferredCoroutine + async def bindSession(self, session: ISession) -> None: + """ + Add a session to the database. + """ + await self.db.bindAccountToSession(self.accountID, session.identifier) + + +class SessionDAL(Protocol): + """ + Data access layer for core sessions database. + """ + + @statement( + sql="delete from session where " + "session_id = {sessionID} and " + "confidential = true" + ) + async def deleteSession(self, sessionID: str) -> None: + """ + Signature for deleting a session by session ID. + """ + + @statement( + sql=""" + insert into session + ( session_id, confidential, created, mechanism ) + values + ({sessionID}, {confidential}, {created}, {mechanism}) + """ + ) + async def insertSession( + self, sessionID: str, confidential: bool, created: float, mechanism: str + ) -> None: + """ + Signature for deleting a session by session ID. + """ + + @query( + sql=""" + select session_id, confidential from session + where session_id = {session_id} and + confidential = {is_confidential} and + mechanism = {mechanism} + """, + load=maybe(SessionRecord), + ) + async def sessionByID( + self, + session_id: str, + is_confidential: bool, + mechanism: str, + ) -> Optional[SessionRecord]: + """ + Signature for getting a session by session ID. + """ + + @statement( + sql="insert into account values " + "({account_id}, {username}, {email}, {password_blob})" + ) + async def createAccount( + self, account_id: str, username: str, email: str, password_blob: str + ) -> None: + """ + Signature for creating an account. + """ + + @statement( + sql="insert into session_account values ({account_id}, {session_id})" + ) + async def bindAccountToSession( + self, account_id: str, session_id: str + ) -> None: + """ + Signature for binding an account to a session. + """ + + @query( + sql=( + "select account_id, username, email, password_blob " + "from account " + "where username = {username}" + ), + load=maybe(AccountRecord), + ) + async def accountByUsername(self, username: str) -> Optional[AccountRecord]: + """ + Load an account by username. + """ + + @statement( + sql=""" + update account + set password_blob = {newBlob} + where account_id = {accountID} + """ + ) + async def resetPassword(self, accountID: str, newBlob: str) -> None: + """ + Reset the password for the given account ID. + """ + + @query( + sql=""" + select account.account_id, + account.username, + account.email, + account.password_blob + from session_account + join account + where session_account.session_id = {session_id} + and session_account.account_id = account.account_id + """, + load=many(AccountRecord), + ) + def boundAccounts(self, session_id: str) -> AsyncIterable[AccountRecord]: + """ + Load all account objects bound to the given session id. + """ + + @statement( + sql=""" + delete from session_account where session_id = {sessionID} + """ + ) + async def unbindSession(self, sessionID: str) -> None: + """ + Un-bind the given session from the account it's currently bound to. + """ + + @statement( + sql=""" + insert into session_ip values ( + {sessionID}, {ipAddress}, {addressFamily}, {lastUsed} + ) + on conflict(session_id, ip_address, address_family) + do update set last_used = excluded.last_used + """, + ) + async def createOrUpdateIPRecord( + self, + sessionID: str, + ipAddress: str, + addressFamily: str, + lastUsed: datetime, + ) -> None: + """ + Add the given IP or update its last-used timestamp. + """ + + +SessionDB = accessor(SessionDAL) diff --git a/src/klein/storage/sql/_sql_glue.py b/src/klein/storage/sql/_sql_glue.py new file mode 100644 index 000000000..3b20e641f --- /dev/null +++ b/src/klein/storage/sql/_sql_glue.py @@ -0,0 +1,470 @@ +# -*- test-case-name: klein.storage.test.test_common -*- +""" +Glue that connects the SQL DAL to Klein's session interfaces. +""" + +from __future__ import annotations + +from binascii import hexlify +from dataclasses import dataclass, field +from os import urandom +from time import time +from typing import ( + Any, + Awaitable, + Callable, + Generic, + Iterable, + Optional, + Sequence, + Type, + TypeVar, +) +from uuid import uuid4 + +from attrs import define +from zope.interface import implementer + +from twisted.internet.defer import Deferred, gatherResults, succeed +from twisted.python.modules import getModule +from twisted.web.iweb import IRequest + +from klein.interfaces import ( + ISession, + ISessionProcurer, + ISimpleAccountBinding, + NoSuchSession, + SessionMechanism, +) + +from ... import SessionProcurer +from ..._isession import AuthorizationMap +from ..._typing_compat import Protocol +from ..._util import eagerDeferredCoroutine +from ...interfaces import ISessionStore, ISimpleAccount +from ..dbxs.dbapi_async import AsyncConnectable, AsyncConnection, transaction +from ..passwords import PasswordEngine, defaultSecureEngine +from ._sql_dal import AccountRecord, SessionDAL, SessionDB, SessionRecord +from ._transactions import requestBoundTransaction + + +T = TypeVar("T") + + +@implementer(ISession) +@define +class SQLSession: + _sessionStore: SessionStore + identifier: str + isConfidential: bool + authenticatedBy: SessionMechanism + + @eagerDeferredCoroutine + async def authorize( + self, interfaces: Iterable[Type[object]] + ) -> AuthorizationMap: + """ + Authorize all the given interfaces and return a mapping that contains + all the ones that could be authorized. + """ + authTypes = set(interfaces) + result: AuthorizationMap + result = {} # type: ignore[assignment] + # ^ mypy really wants this container to be homogenous along some axis, + # so a dict with value types that depend on keys doesn't look right to + # it. + txn = await self._sessionStore._transaction() + store = self._sessionStore + + async def doAuthorize(a: SQLAuthorizer[T]) -> None: + result[a.authorizationType] = await a.authorizationForSession( + store, txn, self + ) + + await gatherResults( + [ + Deferred.fromCoroutine(doAuthorize(each)) + for each in self._sessionStore._authorizers + if each.authorizationType in authTypes + ] + ) + return result + + @classmethod + def realize( + cls, + record: SessionRecord, + store: SessionStore, + authenticatedBy: SessionMechanism, + ) -> SQLSession: + """ + Construct a 'live' session with authentication information and a store + with authorizers from a session record. + """ + return cls( + sessionStore=store, + authenticatedBy=authenticatedBy, + isConfidential=record.confidential, + identifier=record.session_id, + ) + + +@implementer(ISessionStore) +@define +class SessionStore: + """ + An implementation of L{ISessionStore} based on an L{AsyncConnection}, that + stores sessions in a database. + """ + + _transaction: Callable[[], Awaitable[AsyncConnection]] + _authorizers: Sequence[SQLAuthorizer[object]] + _passwordEngine: PasswordEngine + + async def _sentInsecurely(self, tokens: Sequence[str]) -> None: + """ + Tokens have been sent insecurely; delete any tokens expected to be + confidential. Return a deferred that fires when they've been deleted. + + @param tokens: L{list} of L{str} + + @return: a L{Deferred} that fires when the tokens have been + invalidated. + """ + db = SessionDB(await self._transaction()) + for token in tokens: + await db.deleteSession(token) + + def sentInsecurely(self, tokens: Sequence[str]) -> None: + """ + Per the interface, fire-and-forget version of _sentInsecurely + """ + Deferred.fromCoroutine(self._sentInsecurely(tokens)) + + @eagerDeferredCoroutine + async def newSession( + self, isConfidential: bool, authenticatedBy: SessionMechanism + ) -> ISession: + identifier = hexlify(urandom(32)).decode("ascii") + db = SessionDB(await self._transaction()) + await db.insertSession( + identifier, isConfidential, time(), authenticatedBy.name + ) + result = SQLSession( + self, + identifier=identifier, + isConfidential=isConfidential, + authenticatedBy=authenticatedBy, + ) + return result + + @eagerDeferredCoroutine + async def loadSession( + self, + identifier: str, + isConfidential: bool, + authenticatedBy: SessionMechanism, + ) -> ISession: + db = SessionDB(await self._transaction()) + record = await db.sessionByID( + identifier, isConfidential, authenticatedBy.name + ) + if record is None: + raise NoSuchSession("session not found") + return SQLSession.realize(record, self, authenticatedBy) + + +@implementer(ISimpleAccount) +@dataclass +class SQLAccount: + """ + SQL-backed implementation of ISimpleAccount + """ + + _store: SessionStore + _record: AccountRecord + + @property + def accountID(self) -> str: + return self._record.accountID + + @eagerDeferredCoroutine + async def bindSession(self, session: ISession) -> None: + return await self._record.bindSession(session) + + @eagerDeferredCoroutine + async def changePassword(self, newPassword: str) -> None: + """ + @param newPassword: The text of the new password. + @type newPassword: L{unicode} + """ + computedHash = await self._store._passwordEngine.computeKeyText( + newPassword + ) + await self._record.db.resetPassword(self.accountID, computedHash) + + @property + def username(self) -> str: + return self._record.username + + +@implementer(ISimpleAccountBinding) +@define +class AccountSessionBinding: + """ + (Stateless) binding between an account and a session, so that sessions can + attach to and detach from authenticated account objects. + """ + + _store: SessionStore + _session: ISession + _transaction: AsyncConnection + + def _account(self, accountID: str, username: str, email: str) -> SQLAccount: + """ + Construct an L{SQLAccount} bound to this plugin & dataStore. + """ + return SQLAccount( + self._store, AccountRecord(self.db, accountID, username, email) + ) + + @property + def db(self) -> SessionDAL: + """ + session db + """ + return SessionDB(self._transaction) + + @eagerDeferredCoroutine + async def createAccount( + self, username: str, email: str, password: str + ) -> Optional[ISimpleAccount]: + """ + Create a new account with the given username, email and password. + + @return: an L{Account} if one could be created, L{None} if one could + not be. + """ + computedHash = await self._store._passwordEngine.computeKeyText( + password + ) + newAccountID = str(uuid4()) + try: + await self.db.createAccount( + newAccountID, username, email, computedHash + ) + except Exception: + # TODO: wrap up IntegrityError from DB binding somehow so we can be + # more selective about what we're catching. + return None + else: + accountID = newAccountID + account = self._account(accountID, username, email) + return account + + @eagerDeferredCoroutine + async def bindIfCredentialsMatch( + self, username: str, password: str + ) -> Optional[ISimpleAccount]: + """ + Associate this session with a given user account, if the password + matches. + + @param username: The username input by the user. + + @param password: The plain-text password input by the user. + """ + maybeAccountRecord = await self.db.accountByUsername(username) + if maybeAccountRecord is None: + return None + + accountRecord = maybeAccountRecord + + def storeNewBlob(newPWText: str) -> Any: + return self.db.resetPassword(accountRecord.accountID, newPWText) + + assert accountRecord.password_blob is not None + if await self._store._passwordEngine.checkAndReset( + accountRecord.password_blob, + password, + storeNewBlob, + ): + account = SQLAccount(self._store, accountRecord) + await account.bindSession(self._session) + return account + return None + + @eagerDeferredCoroutine + async def boundAccounts(self) -> Sequence[ISimpleAccount]: + """ + Retrieve the accounts currently associated with this session. + + @return: L{Deferred} firing with a L{list} of accounts. + """ + accounts = [] + async for record in self.db.boundAccounts(self._session.identifier): + accounts.append(SQLAccount(self._store, record)) + return accounts + + @eagerDeferredCoroutine + async def unbindThisSession(self) -> None: + """ + Disassociate this session from any accounts it's logged in to. + + @return: a L{Deferred} that fires when the account is logged out. + """ + await self.db.unbindSession(self._session.identifier) + + +@implementer(ISessionProcurer) +@dataclass +class SQLSessionProcurer: + """ + Alternate implementation of L{ISessionProcurer}, necessary because the + underlying L{SessionProcurer} requires an L{ISessionStore}, and our + L{ISessionStore} implementation requires a database transaction to be + associated with both it and the request. + """ + + connectable: AsyncConnectable + authorizers: Sequence[SQLAuthorizer[Any]] + passwordEngine: PasswordEngine = field(default_factory=defaultSecureEngine) + storeToProcurer: Callable[ + [ISessionStore], SessionProcurer + ] = SessionProcurer + + @eagerDeferredCoroutine + async def procureSession( + self, request: IRequest, forceInsecure: bool = False + ) -> ISession: + """ + Procure a session from the underlying procurer, keeping track of the IP + of the request object. + """ + alreadyProcured: Optional[ISession] = ISession(request, None) + + assert ( + alreadyProcured is None + ), """ + Sessions should only be procured once during the lifetime of the + request, and it should not be possible to invoke procureSession + multiple times when getting them from dependency injection. + """ + + # Deferred is declared as contravariant, but this is an error, it + # really ought to be covariant (like Awaitable) + allAuthorizers: Sequence[SQLAuthorizer[Any]] = [ + simpleAccountBinding.authorizer, + logMeIn.authorizer, + *self.authorizers, + ] + + async def getTransaction() -> AsyncConnection: + return await requestBoundTransaction(request, self.connectable) + + procurer = self.storeToProcurer( + SessionStore(getTransaction, allAuthorizers, self.passwordEngine) + ) + return await procurer.procureSession(request, forceInsecure) + + +_authorizerFunction = Callable[ + [SessionStore, AsyncConnection, ISession], "Awaitable[Optional[T]]" +] + + +class _FunctionWithAuthorizer(Protocol[T]): + authorizer: SQLAuthorizer[T] + authorizerType: Type[T] + + def __call__( + self, + sessionStore: SessionStore, + transaction: AsyncConnection, + session: ISession, + ) -> Deferred[T]: + """ + Signature for a function that can have an authorizer attached to it. + """ + + +@define +class SQLAuthorizer(Generic[T]): + authorizationType: Type[T] + _decorated: _authorizerFunction[T] + + def authorizationForSession( + self, + sessionStore: SessionStore, + transaction: AsyncConnection, + session: ISession, + ) -> Awaitable[Optional[T]]: + return self._decorated(sessionStore, transaction, session) + + +def authorizerFor( + authorizationType: Type[T], +) -> Callable[[_authorizerFunction[T]], _FunctionWithAuthorizer[T]]: + """ + Declare an authorizer. + """ + + def decorator( + decorated: _authorizerFunction[T], + ) -> _FunctionWithAuthorizer[T]: + result: _FunctionWithAuthorizer = decorated # type:ignore[assignment] + result.authorizer = SQLAuthorizer[T](authorizationType, decorated) + result.authorizerType = authorizationType + return result + + return decorator + + +@authorizerFor(ISimpleAccountBinding) +def simpleAccountBinding( + sessionStore: SessionStore, + transaction: AsyncConnection, + session: ISession, +) -> Deferred[ISimpleAccountBinding]: + """ + All sessions are authorized for access to an L{ISimpleAccountBinding}. + """ + return succeed(AccountSessionBinding(sessionStore, session, transaction)) + + +@authorizerFor(ISimpleAccount) +async def logMeIn( + sessionStore: ISessionStore, + transaction: AsyncConnection, + session: ISession, +) -> Optional[ISimpleAccount]: + """ + Retrieve an L{ISimpleAccount} authorization. + """ + binding = (await session.authorize([ISimpleAccountBinding]))[ + ISimpleAccountBinding + ] + accounts = await binding.boundAccounts() + for account in accounts: + return account + return None + + +async def applyBasicSchema(connectable: AsyncConnectable) -> None: + """ + Apply the session and authentication schema to the given database within a + dedicated transaction. + """ + async with transaction(connectable) as c: + cursor = await c.cursor() + for stmt in ( + getModule(__name__) + .filePath.parent() + .parent() + .child("sql") + .child("basic_auth_schema.sql") + .getContent() + .decode("utf-8") + .split(";") + ): + await cursor.execute(stmt) diff --git a/src/klein/storage/sql/_transactions.py b/src/klein/storage/sql/_transactions.py new file mode 100644 index 000000000..54bde5264 --- /dev/null +++ b/src/klein/storage/sql/_transactions.py @@ -0,0 +1,196 @@ +# -*- test-case-name: klein.storage.sql.test.test_transactions -*- +from __future__ import annotations + +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import Any, AsyncIterator, Awaitable, Callable, Dict + +from attrs import Factory, define, field +from zope.interface import Interface, implementer + +from twisted.internet.defer import Deferred, gatherResults, succeed +from twisted.logger import Logger +from twisted.python.components import Componentized, registerAdapter +from twisted.web.iweb import IRequest +from twisted.web.server import Request + +from klein.interfaces import ( + IDependencyInjector, + IRequestLifecycle, + IRequiredParameter, +) + +from ..._util import eagerDeferredCoroutine +from ...interfaces import IRequirementContext +from ..dbxs.dbapi_async import AsyncConnectable, AsyncConnection + + +log = Logger() + + +class ITransactionRequestAssociator(Interface): + """ + Request component which associates transactions with requests. + """ + + def transactionForConnectable( + connectable: AsyncConnectable, + ) -> Deferred[AsyncConnection]: + """ + Get the open database transaction for the given engine. + """ + + +synchronous = succeed(None) + + +@implementer(ITransactionRequestAssociator) +@define +class TransactionRequestAssociator: + """ + Associate a transaction with a request. + """ + + request: Request + map: Dict[AsyncConnectable, AsyncConnection] = field(default=Factory(dict)) + waitMap: Dict[AsyncConnectable, Awaitable[None]] = field( + default=Factory(dict) + ) + attached: bool = False + + @eagerDeferredCoroutine + async def transactionForConnectable( + self, connectable: AsyncConnectable + ) -> AsyncConnection: + """ + Retrieve a transaction from the async connection. + """ + await self.waitMap.get(connectable, synchronous) + if connectable in self.map: + return self.map[connectable] + reqctx = IRequirementContext(self.request) + waiter = self.waitMap[connectable] = Deferred() + await reqctx.enter_async_context(self.transactify()) + cxn = await connectable.connect() + self.map[connectable] = cxn + del self.waitMap[connectable] + waiter.callback(None) + return cxn + + @asynccontextmanager + async def transactify(self) -> AsyncIterator[None]: + """ + Commit all associated transactions. + + @param ignored: To be usable as a Deferred callback, accept an + argument, but discard it. + """ + try: + yield + finally: + # Break cycle, allow for sub-transactions later (i.e. in renderers) + self.request.unsetComponent(ITransactionRequestAssociator) + await gatherResults( + [ + Deferred.fromCoroutine(value.commit()) + for value in self.map.values() + ] + ) + + +@implementer(IRequiredParameter, IDependencyInjector) +@dataclass +class Transaction: + """ + Require a transaction from a specified connectable. + + Example:: + + @dataclass + class Application: + connectable: AsyncConnectable + + router: ClassVar[Klein] = Klein() + requirer: ClassVar[Requirer] = Requirer() + + def _db(self) -> AsyncConnectable: + return self.connectable + + @requirer.require(router.route("/page"), + txn=Transaction(_db)) + async def page(self, txn: AsyncConnection): + return (await (await txn.cursor()) + .execute("select * from rows")) + + """ + + getConnectable: Callable[[Any], AsyncConnectable] + + def registerInjector( + self, + injectionComponents: Componentized, + parameterName: str, + lifecycle: IRequestLifecycle, + ) -> IDependencyInjector: + """ + I am a dependency injector. + """ + return self + + async def injectValue( + self, + instance: object, + request: IRequest, + routeParams: Dict[str, object], + ) -> AsyncConnection: + """ + Get a transaction from the associated connectable. + """ + associator = ITransactionRequestAssociator(request) + connector = self.getConnectable(instance) + return await associator.transactionForConnectable(connector) + + def finalize(self) -> None: + """ + Finalize parameter injection setup. + """ + + +registerAdapter( + TransactionRequestAssociator, IRequest, ITransactionRequestAssociator +) + + +async def requestBoundTransaction( + request: IRequest, connectable: AsyncConnectable +) -> AsyncConnection: + """ + Retrieve a transaction that is bound to the lifecycle of the given request. + + There are three use-cases for this lifecycle: + + 1. 'normal CRUD' - a request begins, a transaction is associated with + it, and the transaction completes when the request completes. The + appropriate time to commit the transaction is the moment before the + first byte goes out to the client. The appropriate moment to + interpose this commit is in , since the + HTTP status code should be an indicator of whether the transaction + succeeded or failed. + + 2. 'just the session please' - a request begins, a transaction is + associated with it in order to discover the session, and the + application code in question isn't actually using the database. + (Ideally as expressed through "the dependency-declaration decorator, + such as @authorized, did not indicate that a transaction will be + required"). + + 3. 'fancy API stuff' - a request begins, a transaction is associated + with it in order to discover the session, the application code needs + to then do I{something} with that transaction in-line with the + session discovery, but then needs to commit in order to relinquish + all database locks while doing some potentially slow external API + calls, then start a I{new} transaction later in the request flow. + """ + return await ITransactionRequestAssociator( + request + ).transactionForConnectable(connectable) diff --git a/src/klein/storage/sql/basic_auth_schema.sql b/src/klein/storage/sql/basic_auth_schema.sql new file mode 100644 index 000000000..05679d5e9 --- /dev/null +++ b/src/klein/storage/sql/basic_auth_schema.sql @@ -0,0 +1,34 @@ + +-- `session` identifies individual clients with a particular set of +-- capabilities. the `session_id` is the secret held by the client. +CREATE TABLE session ( + session_id VARCHAR NOT NULL, + confidential BOOLEAN NOT NULL, + created REAL NOT NULL, + mechanism TEXT NOT NULL, + PRIMARY KEY (session_id) +); + +-- `account` is a user with a name and password. the password_blob is computed +-- by the password engine in klein.storage.passwords. +CREATE TABLE account ( + account_id VARCHAR NOT NULL, + username VARCHAR NOT NULL, + email VARCHAR NOT NULL, + password_blob VARCHAR NOT NULL, + PRIMARY KEY (account_id), + UNIQUE (username) +); + +-- `session_account` is a record of which acccount is logged in to which session. +CREATE TABLE session_account ( + account_id VARCHAR, + session_id VARCHAR, + UNIQUE (account_id, session_id), + FOREIGN KEY(account_id) + REFERENCES account (account_id) + ON DELETE CASCADE, + FOREIGN KEY(session_id) + REFERENCES session (session_id) + ON DELETE CASCADE +); diff --git a/src/klein/storage/sql/test/__init__.py b/src/klein/storage/sql/test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/klein/storage/sql/test/test_transactions.py b/src/klein/storage/sql/test/test_transactions.py new file mode 100644 index 000000000..d4f608a7c --- /dev/null +++ b/src/klein/storage/sql/test/test_transactions.py @@ -0,0 +1,87 @@ +""" +Tests for L{klein.storage.sql._transactions} +""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import ClassVar, Optional + +from treq import content +from treq.testing import StubTreq + +from twisted.internet.defer import Deferred +from twisted.trial.unittest import SynchronousTestCase + +from klein import Klein, Requirer +from klein.storage.sql._transactions import Transaction + +from ...dbxs.dbapi_async import AsyncConnectable, AsyncConnection +from ...dbxs.testing import MemoryPool + + +@dataclass +class TestObject: + """ + Object to test request commit hooks. + """ + + testCase: SynchronousTestCase + connectable: AsyncConnectable + t1: Optional[AsyncConnection] = None + t2: Optional[AsyncConnection] = None + incomplete: Deferred[None] = field(default_factory=Deferred) + + router: ClassVar[Klein] = Klein() + requirer: ClassVar[Requirer] = Requirer() + + def _getDB(self) -> AsyncConnectable: + return self.connectable + + @requirer.require( + router.route("/succeed"), + t1=Transaction(_getDB), + t2=Transaction(_getDB), + ) + async def succeed(self, t1: AsyncConnection, t2: AsyncConnection) -> str: + """ + Get a transaction that commits when the request has completed. + """ + self.t1 = t1 + await self.incomplete + self.t2 = t2 + return "Hello, world!" + + +class WriteHeadersHookTests(SynchronousTestCase): + """ + Tests for L{klein.storage.sql._transactions}. + """ + + def test_sameTransactions(self) -> None: + """ + If a transaction is required multiple times, it results in the same + object. + """ + mpool = MemoryPool.new() + to = TestObject(self, mpool.connectable) + stub = StubTreq(to.router.resource()) + inProgress = stub.get("https://localhost/succeed") + self.assertNoResult(inProgress) + mpool.flush() + to.incomplete.callback(None) + self.assertNoResult(inProgress) + mpool.flush() + stub.flush() + response = self.successResultOf(inProgress) + self.assertIsNot(to.t1, None) + self.assertIs(to.t1, to.t2) + self.assertEqual( + self.successResultOf(content(response)), b"Hello, world!" + ) + + def test_everythingCommitted(self) -> None: + """ + Completing the request commits the transaction. + """ + mpool = MemoryPool.new() + mpool.flush() diff --git a/src/klein/storage/test/__init__.py b/src/klein/storage/test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/klein/storage/test/test_common.py b/src/klein/storage/test/test_common.py new file mode 100644 index 000000000..5adf730fd --- /dev/null +++ b/src/klein/storage/test/test_common.py @@ -0,0 +1,319 @@ +from typing import Awaitable, Callable, List, Optional, TypeVar + +import attr +from treq import content +from treq.testing import StubTreq + +from twisted.internet.defer import Deferred +from twisted.python.compat import nativeString +from twisted.trial.unittest import TestCase +from twisted.web.iweb import IRequest + +from klein import Authorization, Field, Klein, Requirer, SessionProcurer +from klein.interfaces import ( + ISession, + ISessionProcurer, + ISimpleAccountBinding, + SessionMechanism, +) +from klein.storage.memory import MemoryAccountStore, MemorySessionStore +from klein.storage.sql._sql_glue import AccountSessionBinding, SessionStore + +from ...interfaces import ISimpleAccount +from ..dbxs.dbapi_async import transaction +from ..dbxs.testing import MemoryPool, immediateTest +from ..passwords.testing import engineForTesting, hashUpgradeCount +from ..sql import SQLSessionProcurer, applyBasicSchema + + +T = TypeVar("T") + + +@attr.s(auto_attribs=True, hash=False) +class TestObject: + procurer: ISessionProcurer + loggedInAs: Optional[ISimpleAccount] = None + boundAccounts: Optional[List[ISimpleAccount]] = None + + router = Klein() + requirer = Requirer() + + @requirer.prerequisite([ISession]) + async def procureASession(self, request: IRequest) -> Optional[ISession]: + return await self.procurer.procureSession(request) + + @requirer.require( + router.route("/private", methods=["get"]), + account=Authorization(ISimpleAccount), + ) + async def whenLoggedIn(self, account: ISimpleAccount) -> str: + """ + handle a login. + """ + return f"itsa me, {account.username}" + + @requirer.require( + router.route("/change-password", methods=["post"]), + acct=Authorization(ISimpleAccount), + newPassword=Field.password(), + ) + async def changePassword( + self, newPassword: str, acct: ISimpleAccount + ) -> str: + """ + Change the password on the logged in account. + """ + await acct.changePassword(newPassword) + return "changed" + + @requirer.require( + router.route("/login", methods=["post"]), + username=Field.text(), + password=Field.password(), + binder=Authorization(ISimpleAccountBinding), + ) + async def handleLogin( + self, username: str, password: str, binder: ISimpleAccountBinding + ) -> str: + """ + handle a login. + """ + account = self.loggedInAs = await binder.bindIfCredentialsMatch( + username, password + ) + self.boundAccounts = await binder.boundAccounts() + if account is None: + return "auth fail" + else: + return "logged in" + + @requirer.require( + router.route("/logout", methods=["post"]), + binder=Authorization(ISimpleAccountBinding), + ) + async def handleLogout(self, binder: ISimpleAccountBinding) -> str: + """ + handle a logout + """ + await binder.unbindThisSession() + return "unbound" + + +class CommonStoreTests(TestCase): + """ + Common interface! + """ + + async def authWithStoreTest( + self, + newSession: Callable[[bool, SessionMechanism], Awaitable[ISession]], + procurer: ISessionProcurer, + pool: Optional[MemoryPool] = None, + ) -> None: + """ + Test using a form to log in to an in-memory store. + """ + session = await newSession(True, SessionMechanism.Cookie) + otherSession = await newSession(True, SessionMechanism.Cookie) + + cookies = {"Klein-Secure-Session": nativeString(session.identifier)} + to = TestObject(procurer) + stub = StubTreq(to.router.resource()) + if pool is not None: + pool.additionalPump(stub.flush) + presponse = stub.get( + "https://localhost/private", + cookies={"Klein-Secure-Session": nativeString(session.identifier)}, + ) + response = await presponse + self.assertEqual(response.code, 401) + self.assertIn(b"DENIED", await content(response)) + + # wrong password + async def badLogin(badUsername: str, badPassword: str) -> None: + response = await stub.post( + "https://localhost/login", + data=dict( + username=badUsername, + password=badPassword, + __csrf_protection__=session.identifier, + ), + cookies=cookies, + ) + self.assertEqual(response.code, 200) + self.assertIn(b"auth fail", await content(response)) + + # still not logged in + presponse = stub.get( + "https://localhost/private", + cookies={ + "Klein-Secure-Session": nativeString(session.identifier) + }, + ) + response = await presponse + self.assertEqual(response.code, 401) + self.assertIn(b"DENIED", await content(response)) + + await badLogin("itsme", "wrongpassword") + await badLogin("wronguser", "doesntmatter") + + # correct password + response = await stub.post( + "https://localhost/login", + data=dict( + username="itsme", + password="secretstuff", + __csrf_protection__=session.identifier, + ), + cookies=cookies, + ) + self.assertEqual(response.code, 200) + self.assertIn(b"logged in", await content(response)) + toAccounts = to.boundAccounts + loggedIn = to.loggedInAs + assert toAccounts is not None + assert loggedIn is not None + self.assertEqual( + [each.username for each in toAccounts], [loggedIn.username] + ) + + async def check( + whichSession: ISession, code: int, contents: bytes + ) -> None: + response = await stub.get( + "https://localhost/private", + cookies={ + "Klein-Secure-Session": nativeString( + whichSession.identifier + ) + }, + ) + self.assertEqual(response.code, code) + self.assertIn(contents, await content(response)) + + # we can see it + await check(session, 200, b"itsa me") + # other session can't see it + await check(otherSession, 401, b"DENIED") + + # we'll use a different password in a sec + newPw = "differentstuff" + response = await stub.post( + "https://localhost/change-password", + data=dict( + newPassword=newPw, + __csrf_protection__=session.identifier, + ), + cookies=cookies, + ) + + response = await stub.post("https://localhost/logout", cookies=cookies) + self.assertEqual(200, response.code) + self.assertIn(b"unbound", await content(response)) + # log out and we can't see it again + await check(session, 401, b"DENIED") + + await badLogin("itsame", "secretstuff") + response = await stub.post( + "https://localhost/login", + data=dict( + username="itsme", + password=newPw, + __csrf_protection__=session.identifier, + ), + cookies=cookies, + ) + self.assertEqual(200, response.code) + # logged in again + self.assertIn(b"logged in", await content(response)) + self.assertEqual(to.boundAccounts, [to.loggedInAs]) + self.assertEqual( + {cookie.value for cookie in response.cookies()}, + {session.identifier}, + ) + + # sending insecure tokens should invalidate our session + response = await stub.get("http://localhost/private", cookies=cookies) + self.assertEqual(response.code, 401) + self.assertIn(b"DENIED", await content(response)) + + # sending invalid tokens insecurely should be like sending no tokens + # (i.e. this happens when you clear a database, or restart an in-memory + # server) + + response = await stub.get( + "http://localhost/private", + cookies={"Klein-Secure-Session": "never seen this session"}, + ) + self.assertEqual(response.code, 401) + self.assertIn(b"DENIED", await content(response)) + + response = await stub.get("https://localhost/private", cookies=cookies) + # jar = response.cookies() + # self.assertEqual() + body = await content(response) + self.assertEqual(response.code, 401) + self.assertIn(b"DENIED", body) + self.assertNotIn( + session.identifier, {cookie.value for cookie in response.cookies()} + ) + + def test_memoryStore(self) -> None: + """ + Test that L{MemoryAccountStore} can store simple accounts and bindings. + """ + users = MemoryAccountStore() + users.addAccount("itsme", "secretstuff") + sessions = MemorySessionStore.fromAuthorizers(users.authorizers()) + self.successResultOf( + Deferred.fromCoroutine( + self.authWithStoreTest( + sessions.newSession, SessionProcurer(sessions) + ) + ) + ) + + @immediateTest() + async def test_sqlStore(self, pool: MemoryPool) -> None: + """ + Test that L{procurerFromConnectable} gives us a usable session procurer. + """ + + await applyBasicSchema(pool.connectable) + + def asyncify(x: T) -> Callable[[], Awaitable[T]]: + async def get() -> T: + return x + + return get + + async def newSession( + isSecure: bool, mechanism: SessionMechanism + ) -> ISession: + async with transaction(pool.connectable) as c: + return await SessionStore( + asyncify(c), [], engineForTesting(self) + ).newSession(isSecure, mechanism) + + async with transaction(pool.connectable) as c: + sampleStore = SessionStore(asyncify(c), [], engineForTesting(self)) + sampleSession = await newSession(True, SessionMechanism.Cookie) + b = AccountSessionBinding(sampleStore, sampleSession, c) + self.assertIsNot( + await b.createAccount( + "itsme", "ignore@example.com", "secretstuff" + ), + None, + ) + async with transaction(pool.connectable) as c: + self.assertIs( + await b.createAccount("itsme", "somethingelse", "whatever"), + None, + ) + + self.assertEqual(hashUpgradeCount(self), 0) + proc = SQLSessionProcurer( + pool.connectable, [], engineForTesting(self, upgradeHashes=True) + ) + await self.authWithStoreTest(newSession, proc, pool) + self.assertEqual(hashUpgradeCount(self), 1) diff --git a/src/klein/test/test_plating.py b/src/klein/test/test_plating.py index 5cf795bbf..0463ef527 100644 --- a/src/klein/test/test_plating.py +++ b/src/klein/test/test_plating.py @@ -12,7 +12,7 @@ from twisted.trial.unittest import SynchronousTestCase from twisted.trial.unittest import TestCase as AsynchronousTestCase from twisted.web.error import FlattenerError, MissingRenderMethod -from twisted.web.template import slot, tags +from twisted.web.template import Tag, slot, tags from .. import Klein, Plating from .._plating import ATOM_TYPES, PlatedElement, resolveDeferredObjects @@ -552,6 +552,38 @@ def rsrc(request): }, ) + def test_platingFragmentList(self) -> None: + """ + A function decorated with L{Plating.fragment} can serve as a list. + """ + + @page.fragment + def color(r: str, g: str, b: str) -> Tag: + return tags.div("red: ", r, " green: ", g, " blue: ", b) + + @page.routed( + self.app.route("/"), tags.div(render="colors:list")(slot("item")) + ) + def plateMe(result): + return { + "colors": [ + color("1", "2", "3"), + color(r="4", g="5", b="6"), + ], + } + + _, writtenJSON = self.get(b"/?json=1") + _, writtenHTML = self.get(b"/") + self.assertEqual( + json.loads(writtenJSON)["colors"], + [{"r": "1", "g": "2", "b": "3"}, {"r": "4", "g": "5", "b": "6"}], + ) + self.assertIn( + b"
red: 1 green: 2 blue: 3
" + b"
red: 4 green: 5 blue: 6
", + writtenHTML, + ) + def test_prime_directive_return(self): """ Nothing within these Articles Of Federation shall authorize the United @@ -626,3 +658,18 @@ def no(request): test("garbage") test("garbage:missing") + + def test_alternateReturn(self) -> None: + """ + If a L{Plating.routed} route returns a klein renderable value other + than a mutable mapping, it will be returned as if the route were not + plated at all. + """ + + @page.routed(self.app.route("/"), tags.span(slot("ok"))) + def plateMe(request): + return "oops, not a dict!" + + request, written = self.get(b"/") + + self.assertEqual(b"oops, not a dict!", written) diff --git a/tox.ini b/tox.ini index 0d09311ca..e77223d95 100644 --- a/tox.ini +++ b/tox.ini @@ -3,6 +3,7 @@ envlist = lint, mypy coverage-py{37,38,39,310,311,312,py3}-tw{212,221,238,trunk} + coverage_combine coverage_report docs, docs-linkcheck packaging @@ -18,14 +19,13 @@ deps = tw212: Twisted==21.2.0 tw221: Twisted==22.1.0 twcurrent: Twisted - # See https://github.com/twisted/klein/issues/486 - twtrunk: --use-deprecated=legacy-resolver twtrunk: https://github.com/twisted/twisted/tarball/trunk#egg=Twisted -r requirements/tox-pin-base.txt {test,coverage}: -r requirements/tox-tests.txt coverage: {[testenv:coverage_report]deps} + coverage: coverage_enable_subprocess setenv = PY_MODULE=klein @@ -59,33 +59,14 @@ deps = {[default]deps} setenv = {[default]setenv} - coverage: COVERAGE_FILE={toxworkdir}/coverage.{envname} coverage: COVERAGE_PROCESS_START={toxinidir}/.coveragerc - TRIAL_ARGS={env:TRIAL_ARGS:} + TRIAL_ARGS={env:TRIAL_ARGS:--jobs=2} commands = # Run trial without coverage test: python -b "{envdir}/bin/trial" --random=0 {env:TRIAL_ARGS} --temp-directory="{envlogdir}/trial.d" {posargs:{env:PY_MODULE}} - - # Run trial with coverage - # Notes: - # - Because we run tests in parallel, which uses multiple subprocesses, - # we need to drop in a .pth file that causes coverage to start when - # Python starts. See: - # https://coverage.readthedocs.io/en/coverage-5.5/subprocess.html - # - We use coverage in parallel mode, then combine here to get the results - # to get a unified result for the current test environment. - # - Use `tox -e coverage_report` to generate a report for all environments. - coverage: python -c 'f=open("{envsitepackagesdir}/zz_coverage.pth", "w"); f.write("import coverage; coverage.process_startup()\n")' - coverage: coverage erase - coverage: python -b -m coverage run --source="{env:PY_MODULE}" "{envdir}/bin/trial" --random=0 {env:TRIAL_ARGS} --temp-directory="{envlogdir}/trial.d" {posargs:{env:PY_MODULE}} - coverage: coverage combine - coverage: coverage xml - - # Run coverage reports, ignore exit status - coverage: - coverage report --skip-covered - + coverage: coverage run "{envdir}/bin/trial" --random=0 {env:TRIAL_ARGS} --temp-directory="{envlogdir}/trial.d" {posargs:{env:PY_MODULE}} ## # Lint @@ -134,13 +115,24 @@ commands = # Coverage report ## +[testenv:coverage_combine] +commands = coverage combine +basepython={[default]basepython} +deps = coverage +depends = + coverage-py{37,38,39,310,311,py3}-tw{212,221,2310,trunk} + + [testenv:coverage_report] description = generate coverage report depends = + coverage-py{37,38,39,310,311,py3}-tw{1,2}{0,1,2,3,4,5,6,7,8,9}{0,1,2,3,4,5,6,7,8,9} + coverage-py{37,38,39,310,311,py3}-tw{2310,trunk} coverage-py{37,38,39,310,311,312,py3}-tw{1,2}{0,1,2,3,4,5,6,7,8,9}{0,1,2,3,4,5,6,7,8,9} - coverage-py{37,38,39,310,311,312,py3}-tw{current,trunk} + coverage-py{37,38,39,310,311,312,py3}-tw{2310,trunk} + coverage_combine basepython = {[default]basepython} @@ -152,11 +144,9 @@ deps = setenv = {[default]setenv} - COVERAGE_FILE={toxworkdir}/coverage - commands = - coverage combine - - coverage report + - coverage xml + - coverage report --skip-covered - coverage html @@ -243,7 +233,7 @@ description = check for potential packaging problems depends = coverage-py{37,38,39,310,311,py3}-tw{1,2}{0,1,2,3,4,5,6,7,8,9}{0,1,2,3,4,5,6,7,8,9} - coverage-py{37,38,39,310,311,py3}-tw{current,trunk} + coverage-py{37,38,39,310,311,py3}-tw{2310,trunk} basepython = {[default]basepython}