Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for pgpass #666

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 67 additions & 4 deletions connectorx-python/connectorx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import importlib
import pathlib
import urllib.parse

from importlib.metadata import version
Expand All @@ -24,16 +25,16 @@
# only for typing hints
from .connectorx import _DataframeInfos, _ArrowInfos


__version__ = version(__name__)

import os
import sys

dir_path = os.path.dirname(os.path.realpath(__file__))
# check whether it is in development env or installed
if (
not os.path.basename(os.path.abspath(os.path.join(dir_path, "..")))
== "connectorx-python"
not os.path.basename(os.path.abspath(os.path.join(dir_path, "..")))
== "connectorx-python"
):
os.environ.setdefault("J4RS_BASE_PATH", os.path.join(dir_path, "dependencies"))

Expand All @@ -43,7 +44,6 @@

Protocol = Literal["csv", "binary", "cursor", "simple", "text"]


_BackendT = TypeVar("_BackendT")


Expand Down Expand Up @@ -247,6 +247,66 @@ def read_sql(
) -> pl.DataFrame: ...


def get_passfile_content(path: pathlib.Path) -> dict:
# host:port:db_name:user_name:password
with open(path) as f:
Copy link
Contributor

@aimtsou aimtsou Jul 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we do something different? The only content we need from pgpass is the password; the rest of the elements are only used to look it up. So:

            for line in f:
                contents = line.read().split(':')
                if len(contents) != 5:
                    raise Exception('Pgpass content should follow: host:port:db_name:user_name:password')
                else:
                    if ((contents['host'] == host or contents['host'] == '*') and
                        (contents['port'] == port or contents['port'] == '*') and
                        (contents['database'] == database or contents['database'] == '*') and
                        (contents['user'] == user or contents['user'] == '*')):
                        return contents['password']

This way we also support multiple lines, and we return the password from the first line that matches.

contents = f.read().split(':')
if len(contents) != 5:
raise Exception('Pgpass content should follow: host:port:db_name:user_name:password')
return dict(zip(['hostname', 'port', 'path', 'username', 'password'], contents))


def get_pgpass(conn) -> dict:
# check param or (PGPASS env or DEFAULT)
# DEFAULT = linux or windows
passfile = '%APPDATA%\postgresql\pgpass.conf' if sys.platform == 'windows' else os.path.expanduser(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of sys.platform maybe we can use os.name == 'nt'

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it better?

'~/') + '.pgpass'
if 'passfile' in conn:
parsed_conn = urllib.parse.urlparse(conn)

# test if there is no '&'
for param in parsed_conn.params.split('&'):
k, v = param.split('=')
if k == 'passfile':
passfile = k

passfile_path = pathlib.Path(passfile)

if sys.platform != 'windows':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here: os.name != 'nt'

# In *nix platforms we check that the file is safe; it has to have 0600 permission.
# https://www.postgresql.org/docs/current/libpq-pgpass.html

try:
# We trust that the last three digits of st_mode are the permissions, e.g.
st_mode = int(oct(passfile_path.stat().st_mode)[-3:])

except Exception as e:
raise Exception(
'Could not check if file is safe, report to this to the maintainers please.') from e

if st_mode != 600:
raise Exception(
f'pgpass file does not have safe permissions (0600), it currently has "0{st_mode}" you can fix this buy running: $ chmod 0600 PASSFILE_PATH')

return get_passfile_content(passfile)


def replace_conn_content(conn: str, contents: dict) -> str:
    """Return ``conn`` with its connection details overridden by ``contents``.

    Args:
        conn: Connection URL, e.g. ``postgresql://user@host:5432/db``.
        contents: Mapping that may hold any of the keys ``hostname``,
            ``port``, ``path`` (the database name), ``username`` and
            ``password``. Other keys are ignored. A value that is ``None``,
            empty, or the pgpass wildcard ``*`` leaves the corresponding
            part of ``conn`` untouched.

    Returns:
        The rebuilt connection URL as a string.

    Raises:
        ValueError: If ``conn`` carries a port that is not a valid integer
            (raised by ``urllib.parse`` when the port is read).
    """
    parsed_conn = urllib.parse.urlparse(conn)

    def override(key):
        # Return the usable override for ``key`` or None when there is none.
        value = contents.get(key)
        if value is None:
            return None
        # Values read straight from a pgpass line may end with a newline.
        value = str(value).rstrip("\r\n")
        return None if value in ("", "*") else value

    def quoted_override(key, current):
        # ``current`` comes from the URL and is already percent-encoded, so
        # only a fresh override needs quoting.
        value = override(key)
        return current if value is None else urllib.parse.quote(value, safe="")

    username = quoted_override("username", parsed_conn.username)
    password = quoted_override("password", parsed_conn.password)
    hostname = override("hostname") or parsed_conn.hostname or ""
    port = override("port") or parsed_conn.port

    # ``hostname``/``username``/... are read-only properties of ParseResult,
    # not namedtuple fields, so the netloc has to be rebuilt by hand.
    netloc = f"[{hostname}]" if ":" in hostname else hostname
    if port:
        netloc = f"{netloc}:{port}"
    if username or password:
        userinfo = username or ""
        if password:
            userinfo = f"{userinfo}:{password}"
        netloc = f"{userinfo}@{netloc}"

    path = parsed_conn.path
    database = override("path")
    if database is not None:
        path = "/" + urllib.parse.quote(database.lstrip("/"), safe="")

    # geturl() serialises back to a URL; str() would give the tuple repr.
    return parsed_conn._replace(netloc=netloc, path=path).geturl()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can do the following:
database = parsed_conn.path[1:] or os.environ.get('PGDATABASE', os.getenv('USER'))
host = parsed_conn.hostname or os.environ.get('PGHOST', 'localhost')
port = parsed_conn.port or os.environ.get('PGPORT', '5432')
user = parsed_conn.username or os.environ.get('PGUSER', os.getenv('USER'))
password = parsed_conn.password

We parse the URI twice, so maybe we need to rethink that. Could we call pgpass from inside the function that reconstructs the connection string? That would also allow something like this:

query_params = parse_qs(parsed_con.query)

passfile = query_params.get('passfile', [None])[0] or os.environ.get('PGPASSFILE')
default_pgpass_path = os.path.expanduser('~/.pgpass') if os.name != 'nt' else os.path.join(os.getenv('APPDATA', ''), 'postgresql', 'pgpass.conf')

# Determine the password precedence
if not password:
    if 'PGPASSWORD' in os.environ:
        password = os.environ['PGPASSWORD']
    else:
        pgpass_path = passfile or default_pgpass_path
        password = get_passfile_content(pgpass_path, host or '*', str(port or '*'), database or '*', user or '*')

And with all this information we can build the new connection string.



def run_per_database(conn):
# Todo rename to something better.
if 'postgresql' in conn:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then we ought to change it to postgres

contents = get_pgpass(conn)
conn = replace_conn_content(conn, contents)
return conn
return conn


def read_sql(
conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl],
query: list[str] | str,
Expand Down Expand Up @@ -332,6 +392,9 @@ def read_sql(
df = pl.DataFrame.from_arrow(df)
return df

# Rewrite conn.
conn = run_per_database(conn)

if isinstance(query, str):
query = remove_ending_semicolon(query)

Expand Down
Loading