Skip to content

Commit

Permalink
added CsrfHeaderMiddlaware
Browse files Browse the repository at this point in the history
  • Loading branch information
olegbaturin committed Sep 27, 2024
1 parent b26cf68 commit 78afec8
Show file tree
Hide file tree
Showing 6 changed files with 360 additions and 1 deletion.
92 changes: 92 additions & 0 deletions src/CsrfHeaderMiddleware.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
<?php

declare(strict_types=1);

namespace Yiisoft\Csrf;

use Psr\Http\Message\ResponseFactoryInterface;
use Psr\Http\Message\ResponseInterface;
use Psr\Http\Message\ServerRequestInterface;
use Psr\Http\Server\MiddlewareInterface;
use Psr\Http\Server\RequestHandlerInterface;
use Yiisoft\Http\Method;
use Yiisoft\Http\Status;

use function count;
use function in_array;

/**
* PSR-15 middleware that takes care of HTTP header validation.
*
* @link https://www.php-fig.org/psr/psr-15/
* @link https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#employing-custom-request-headers-for-ajaxapi
*/
final class CsrfHeaderMiddleware implements MiddlewareInterface
{
public const HEADER_NAME = 'X-CSRF-Token';

private string $headerName = self::HEADER_NAME;
private array $safeMethods = [Method::OPTIONS];

private ResponseFactoryInterface $responseFactory;
private ?RequestHandlerInterface $failureHandler;

public function __construct(
ResponseFactoryInterface $responseFactory,
RequestHandlerInterface $failureHandler = null
) {
$this->responseFactory = $responseFactory;
$this->failureHandler = $failureHandler;
}

public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface
{
if ($this->validateCsrfToken($request)) {
return $handler->handle($request);
}

if ($this->failureHandler !== null) {
return $this->failureHandler->handle($request);
}

$response = $this->responseFactory->createResponse(Status::UNPROCESSABLE_ENTITY);
$response
->getBody()
->write(Status::TEXTS[Status::UNPROCESSABLE_ENTITY]);
return $response;
}

public function withHeaderName(string $name): self
{
$new = clone $this;
$new->headerName = $name;
return $new;
}

public function withSafeMethods(array $methods): self
{
$new = clone $this;
$new->safeMethods = $methods;
return $new;
}

public function getHeaderName(): string
{
return $this->headerName;
}

public function getSafeMethods(): array
{
return $this->safeMethods;
}

private function validateCsrfToken(ServerRequestInterface $request): bool
{
if (in_array($request->getMethod(), $this->safeMethods, true)) {
return true;
}

$headers = $request->getHeader($this->headerName);
return (bool) count($headers);
}
}
15 changes: 14 additions & 1 deletion src/CsrfMiddleware.php
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ final class CsrfMiddleware implements MiddlewareInterface

private string $parameterName = self::PARAMETER_NAME;
private string $headerName = self::HEADER_NAME;
private array $safeMethods = [Method::GET, Method::HEAD, Method::OPTIONS];

private ResponseFactoryInterface $responseFactory;
private CsrfTokenInterface $token;
Expand Down Expand Up @@ -73,6 +74,13 @@ public function withHeaderName(string $name): self
return $new;
}

public function withSafeMethods(array $methods): self
{
$new = clone $this;
$new->safeMethods = $methods;
return $new;
}

public function getParameterName(): string
{
return $this->parameterName;
Expand All @@ -83,9 +91,14 @@ public function getHeaderName(): string
return $this->headerName;
}

public function getSafeMethods(): array
{
return $this->safeMethods;
}

private function validateCsrfToken(ServerRequestInterface $request): bool
{
if (in_array($request->getMethod(), [Method::GET, Method::HEAD, Method::OPTIONS], true)) {
if (in_array($request->getMethod(), $this->safeMethods, true)) {
return true;
}

Expand Down
147 changes: 147 additions & 0 deletions tests/CsrfHeaderMiddlewareProcessTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
<?php

declare(strict_types=1);

namespace Yiisoft\Csrf\Tests;

use Nyholm\Psr7\Factory\Psr17Factory;
use Nyholm\Psr7\Response;
use Nyholm\Psr7\ServerRequest;
use PHPUnit\Framework\TestCase;
use Psr\Http\Message\ResponseInterface;
use Psr\Http\Message\ServerRequestInterface;
use Psr\Http\Server\RequestHandlerInterface;
use Yiisoft\Csrf\CsrfHeaderMiddleware;
use Yiisoft\Http\Method;
use Yiisoft\Http\Status;
use Yiisoft\Security\Random;

final class CsrfHeaderMiddlewareProcessTest extends TestCase
{
public function testOptionsIsAlwaysAllowed(): void
{
$middleware = $this->createMiddleware();
$response = $middleware->process(
$this->createServerRequest(Method::OPTIONS),
$this->createRequestHandler()
);
$this->assertEquals(200, $response->getStatusCode());
}

public function testCustomSafeGetRequestResultIn200(): void
{
$middleware = $this
->createMiddleware()
->withSafeMethods([Method::GET, Method::HEAD, Method::OPTIONS]);
$response = $middleware->process(
$this->createServerRequest(Method::GET),
$this->createRequestHandler()
);
$this->assertEquals(200, $response->getStatusCode());
}

public function testUnsafeMethodGetRequestResultIn422(): void
{
$middleware = $this->createMiddleware();
$response = $middleware->process(
$this->createServerRequest(Method::GET),
$this->createRequestHandler()
);
$this->assertEquals(Status::TEXTS[Status::UNPROCESSABLE_ENTITY], $response->getBody());
$this->assertEquals(Status::UNPROCESSABLE_ENTITY, $response->getStatusCode());
}

public function testCustomUnsafeMethodPostRequestResultIn422(): void
{
$middleware = $this
->createMiddleware()
->withSafeMethods([Method::GET, Method::HEAD, Method::OPTIONS]);
$response = $middleware->process(
$this->createServerRequest(Method::POST),
$this->createRequestHandler()
);
$this->assertEquals(Status::TEXTS[Status::UNPROCESSABLE_ENTITY], $response->getBody());
$this->assertEquals(Status::UNPROCESSABLE_ENTITY, $response->getStatusCode());
}

public function testValidCustomHeaderResultIn200(): void
{
$headerName = 'X-MY-CSRF';

$middleware = $this
->createMiddleware()
->withHeaderName($headerName)
->withSafeMethods([Method::GET]);
$response = $middleware->process(
$this->createServerRequest(Method::GET, [$headerName => Random::string()]),
$this->createRequestHandler()
);
$this->assertEquals(200, $response->getStatusCode());
}

public function testEmptyTokenInRequestResultIn200(): void
{
$middleware = $this->createMiddleware();
$response = $middleware->process(
$this->createServerRequest(Method::GET, [CsrfHeaderMiddleware::HEADER_NAME => '']),
$this->createRequestHandler()
);
$this->assertEquals(200, $response->getStatusCode());
}

public function testInvalidHeaderResultIn422(): void
{
$middleware = $this->createMiddleware();
$response = $middleware->process(
$this->createServerRequest(Method::POST, ['X-MY-CSRF' => '']),
$this->createRequestHandler()
);
$this->assertEquals(Status::UNPROCESSABLE_ENTITY, $response->getStatusCode());
$this->assertEquals(Status::TEXTS[Status::UNPROCESSABLE_ENTITY], $response->getBody());
}

public function testInvalidHeaderResultWithCustomFailureHandler(): void
{
$failureHandler = new class () implements RequestHandlerInterface {
public function handle(ServerRequestInterface $request): ResponseInterface
{
$response = new Response(Status::BAD_REQUEST);
$response
->getBody()
->write(Status::TEXTS[Status::BAD_REQUEST]);
return $response;
}
};
$middleware = $this->createMiddleware($failureHandler);
$response = $middleware->process(
$this->createServerRequest(Method::POST, ['X-MY-CSRF' => '']),
$this->createRequestHandler(),
);
$this->assertEquals(Status::BAD_REQUEST, $response->getStatusCode());
$this->assertEquals(Status::TEXTS[Status::BAD_REQUEST], $response->getBody());
}

private function createMiddleware(
RequestHandlerInterface $failureHandler = null
): CsrfHeaderMiddleware
{
return new CsrfHeaderMiddleware(new Psr17Factory(), $failureHandler);
}

private function createRequestHandler(): RequestHandlerInterface
{
$requestHandler = $this->createMock(RequestHandlerInterface::class);
$requestHandler
->method('handle')
->willReturn(new Response(200));

return $requestHandler;
}

private function createServerRequest(
string $method = Method::GET,
array $headParams = []
): ServerRequestInterface {
return new ServerRequest($method, '/', $headParams);
}
}
54 changes: 54 additions & 0 deletions tests/CsrfHeaderMiddlewareTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
<?php

declare(strict_types=1);

namespace Yiisoft\Csrf\Tests;

use Nyholm\Psr7\Factory\Psr17Factory;
use PHPUnit\Framework\TestCase;
use Yiisoft\Csrf\CsrfHeaderMiddleware;
use Yiisoft\Http\Method;

final class CsrfHeaderMiddlewareTest extends TestCase
{
public function testDefaultHeaderName(): void
{
$middleware = $this->createMiddleware();
$this->assertSame(CsrfHeaderMiddleware::HEADER_NAME, $middleware->getHeaderName());
}

public function testGetHeaderName(): void
{
$middleware = $this
->createMiddleware()
->withHeaderName('X-MY-CSRF');
$this->assertSame('X-MY-CSRF', $middleware->getHeaderName());
}

public function testImmutability(): void
{
$original = $this->createMiddleware();
$this->assertNotSame($original, $original->withHeaderName('X-MY-CSRF'));
$this->assertNotSame($original, $original->withSafeMethods([Method::HEAD]));
}

public function testDefaultSafeMethods(): void
{
$middleware = $this->createMiddleware();
$this->assertSame([Method::OPTIONS], $middleware->getSafeMethods());
}

public function testGetSafeMethods(): void
{
$methods = [Method::GET, Method::HEAD, Method::OPTIONS];
$middleware = $this
->createMiddleware()
->withSafeMethods($methods);
$this->assertSame($methods, $middleware->getSafeMethods());
}

private function createMiddleware(): CsrfHeaderMiddleware
{
return new CsrfHeaderMiddleware(new Psr17Factory());
}
}
17 changes: 17 additions & 0 deletions tests/CsrfMiddlewareTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
use Yiisoft\Csrf\Synchronizer\Generator\RandomCsrfTokenGenerator;
use Yiisoft\Csrf\Synchronizer\SynchronizerCsrfToken;
use Yiisoft\Csrf\Tests\Synchronizer\Storage\MockCsrfTokenStorage;
use Yiisoft\Http\Method;

final class CsrfMiddlewareTest extends TestCase
{
Expand Down Expand Up @@ -46,6 +47,22 @@ public function testImmutability(): void
$original = $this->createMiddleware();
$this->assertNotSame($original, $original->withHeaderName('csrf'));
$this->assertNotSame($original, $original->withParameterName('csrf'));
$this->assertNotSame($original, $original->withSafeMethods([Method::HEAD]));
}

public function testDefaultSafeMethods(): void
{
$middleware = $this->createMiddleware();
$this->assertSame([Method::GET, Method::HEAD, Method::OPTIONS], $middleware->getSafeMethods());
}

public function testGetSafeMethods(): void
{
$methods = [Method::OPTIONS];
$middleware = $this
->createMiddleware()
->withSafeMethods($methods);
$this->assertSame($methods, $middleware->getSafeMethods());
}

private function createMiddleware(): CsrfMiddleware
Expand Down
Loading

0 comments on commit 78afec8

Please sign in to comment.