Skip to content

Commit

Permalink
Correct live/ready checks to consider special host values (#2182)
Browse files Browse the repository at this point in the history
  • Loading branch information
steve-chavez authored Mar 11, 2022
1 parent c803c4d commit a3c1d99
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 18 deletions.
49 changes: 34 additions & 15 deletions src/PostgREST/Admin.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{-# LANGUAGE RecordWildCards #-}
module PostgREST.Admin
( postgrestAdmin
) where
Expand All @@ -21,7 +22,7 @@ import Protolude
-- | PostgREST admin application
postgrestAdmin :: AppState.AppState -> AppConfig -> Wai.Application
postgrestAdmin appState appConfig req respond = do
isMainAppReachable <- isRight <$> reachMainApp appConfig
isMainAppReachable <- any isRight <$> reachMainApp appConfig
isSchemaCacheLoaded <- isJust <$> AppState.getDbStructure appState
isConnectionUp <-
if configDbChannelEnabled appConfig
Expand All @@ -38,19 +39,37 @@ postgrestAdmin appState appConfig req respond = do

-- Try to connect to the main app socket
-- Note that it doesn't even send a valid HTTP request, we just want to check that the main app is accepting connections
reachMainApp :: AppConfig -> IO (Either IOException ())
reachMainApp appConfig =
try . withSocketsDo $ bracket open close sendEmpty
where
open = case configServerUnixSocket appConfig of
Just path -> do
sock <- socket AF_UNIX Stream 0
-- The code for resolving the "*4", "!4", "*6", "!6", "*" special values is taken from
-- https://hackage.haskell.org/package/streaming-commons-0.2.2.4/docs/src/Data.Streaming.Network.html#bindPortGenEx
reachMainApp :: AppConfig -> IO [Either IOException ()]
reachMainApp AppConfig{..} =
case configServerUnixSocket of
Just path -> do
sock <- socket AF_UNIX Stream 0
(:[]) <$> try (do
connect sock $ SockAddrUnix path
return sock
Nothing -> do
let hints = defaultHints { addrSocketType = Stream }
addr:_ <- getAddrInfo (Just hints) (Just . T.unpack $ configServerHost appConfig) (Just . show $ configServerPort appConfig)
sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
connect sock $ addrAddress addr
return sock
withSocketsDo $ bracket (pure sock) close sendEmpty)
Nothing -> do
let
host | configServerHost `elem` ["*4", "!4", "*6", "!6", "*"] = Nothing
| otherwise = Just configServerHost
filterAddrs xs =
case configServerHost of
"*4" -> ipv4Addrs xs ++ ipv6Addrs xs
"!4" -> ipv4Addrs xs
"*6" -> ipv6Addrs xs ++ ipv4Addrs xs
"!6" -> ipv6Addrs xs
_ -> xs
ipv4Addrs xs = filter ((/=) AF_INET6 . addrFamily) xs
ipv6Addrs xs = filter ((==) AF_INET6 . addrFamily) xs

addrs <- getAddrInfo (Just $ defaultHints { addrSocketType = Stream }) (T.unpack <$> host) (Just . show $ configServerPort)
tryAddr `traverse` filterAddrs addrs
where
sendEmpty sock = void $ send sock mempty
tryAddr :: AddrInfo -> IO (Either IOException ())
tryAddr addr = do
sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
try $ do
connect sock $ addrAddress addr
withSocketsDo $ bracket (pure sock) close sendEmpty
2 changes: 1 addition & 1 deletion src/PostgREST/App.hs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ run installHandlers maybeRunWithSocket appState = do
AppState.logWithZTime appState $ "Listening on unix socket " <> show socket
runWithSocket (serverSettings conf) app configServerUnixSocketMode socket
Nothing ->
panic "Cannot run with socket on non-unix plattforms."
panic "Cannot run with unix socket on non-unix plattforms."
Nothing ->
do
AppState.logWithZTime appState $ "Listening on port " <> show configServerPort
Expand Down
7 changes: 7 additions & 0 deletions test/io/fixtures.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,10 @@ invalidjointypes:
- 'left!'
- 'right'
- '.#$$%&$%/'

specialhostvalues:
- '*4'
- '!4'
- '*6'
- '!6'
- '*'
15 changes: 13 additions & 2 deletions test/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def dumpconfig(configpath=None, env=None, stdin=None):


@contextlib.contextmanager
def run(configpath=None, stdin=None, env=None, port=None):
def run(configpath=None, stdin=None, env=None, port=None, host=None):
"Run PostgREST and yield an endpoint that is ready for connections."
env = env or {}
env["PGRST_DB_POOL"] = "1"
Expand All @@ -149,7 +149,7 @@ def run(configpath=None, stdin=None, env=None, port=None):
with tempfile.TemporaryDirectory() as tmpdir:
if port:
env["PGRST_SERVER_PORT"] = str(port)
env["PGRST_SERVER_HOST"] = "localhost"
env["PGRST_SERVER_HOST"] = host or "localhost"
baseurl = f"http://localhost:{port}"
else:
socketfile = pathlib.Path(tmpdir) / "postgrest.sock"
Expand Down Expand Up @@ -883,6 +883,17 @@ def test_admin_live_dependent_on_main_app(defaultenv):
response = postgrest.admin.get("/live")
assert response.status_code == 503

@pytest.mark.parametrize("specialhostvalue", FIXTURES["specialhostvalues"])
def test_admin_works_with_host_special_values(specialhostvalue, defaultenv):
"Should get a success from the admin live and ready endpoints when using special host values for the main app"

with run(env=defaultenv, port=freeport(), host=specialhostvalue) as postgrest:

response = postgrest.admin.get("/live")
assert response.status_code == 200

response = postgrest.admin.get("/ready")
assert response.status_code == 200

@pytest.mark.parametrize(
"level, has_output",
Expand Down

0 comments on commit a3c1d99

Please sign in to comment.