From f3ac5598448a0c2080ee267f00931ed556a6779d Mon Sep 17 00:00:00 2001 From: "g.denis" Date: Fri, 17 Apr 2020 23:52:19 +0200 Subject: Add pg_service.conf handling (#1155) * add parse_service_info * added tests * changelog + AUTHORS * py35 --- changelog.rst | 1 + pgcli/main.py | 51 +++++++++++++++++++++++++++++++++++++++++++++++---- tests/test_main.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 96 insertions(+), 4 deletions(-) diff --git a/changelog.rst b/changelog.rst index c197de2f..62a68b41 100644 --- a/changelog.rst +++ b/changelog.rst @@ -7,6 +7,7 @@ Features: * Add `__main__.py` file to execute pgcli as a package directly (#1123). * Add support for ANSI escape sequences for coloring the prompt (#1122). * Add support for partitioned tables (relkind "p"). +* Add support for `pg_service.conf` files Bug fixes: diff --git a/pgcli/main.py b/pgcli/main.py index 08c6c491..cf21bc75 100644 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -1,5 +1,8 @@ +import platform import warnings +from os.path import expanduser +from configobj import ConfigObj from pgspecial.namedqueries import NamedQueries warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2") @@ -470,6 +473,21 @@ class PGCli(object): def connect_dsn(self, dsn, **kwargs): self.connect(dsn=dsn, **kwargs) + def connect_service(self, service, user): + service_config, file = parse_service_info(service) + if service_config is None: + click.secho( + "service '%s' was not found in %s" % (service, file), err=True, fg="red" + ) + exit(1) + self.connect( + database=service_config.get("dbname"), + host=service_config.get("host"), + user=user or service_config.get("user"), + port=service_config.get("port"), + passwd=service_config.get("password"), + ) + def connect_uri(self, uri): kwargs = psycopg2.extensions.parse_dsn(uri) remap = {"dbname": "database", "password": "passwd"} @@ -1248,7 +1266,11 @@ def cli( username = dbname database = dbname_opt or dbname or "" user = username_opt or username - + service = None + if database.startswith("service="): + service = database[8:] + elif os.getenv("PGSERVICE") is not None: + service = os.getenv("PGSERVICE") # because option --list or -l are not supposed to have a db name if list_databases: database = "postgres" @@ -1269,10 +1291,10 @@ def cli( pgcli.dsn_alias = dsn elif "://" in database: pgcli.connect_uri(database) - elif "=" in database: + elif "=" in database and service is None: pgcli.connect_dsn(database, user=user) - elif os.environ.get("PGSERVICE", None): - pgcli.connect_dsn("service={0}".format(os.environ["PGSERVICE"])) + elif service is not None: + pgcli.connect_service(service, user) else: pgcli.connect(database, host, user, port) @@ -1446,5 +1468,26 @@ def format_output(title, cur, headers, status, settings): return output +def parse_service_info(service): + service = service or os.getenv("PGSERVICE") + service_file = os.getenv("PGSERVICEFILE") + if not service_file: + # try ~/.pg_service.conf (if that exists) + if platform.system() == "Windows": + service_file = os.getenv("PGSYSCONFDIR") + "\\pg_service.conf" + elif os.getenv("PGSYSCONFDIR"): + service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf") + else: + service_file = expanduser("~/.pg_service.conf") + if not service: + # nothing to do + return None, service_file + service_file_config = ConfigObj(service_file) + if service not in service_file_config: + return None, service_file + service_conf = service_file_config.get(service) + return service_conf, service_file + + if __name__ == "__main__": cli() diff --git a/tests/test_main.py b/tests/test_main.py index 044181b1..9b85a34b 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -282,6 +282,54 @@ def test_quoted_db_uri(tmpdir): ) +def test_pg_service_file(tmpdir): + + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + with open(tmpdir.join(".pg_service.conf").strpath, "w") as service_conf: + service_conf.write( + """[myservice] + host=a_host + user=a_user + port=5433 + password=much_secure + dbname=a_dbname + + [my_other_service] + host=b_host + user=b_user + port=5435 + dbname=b_dbname + """ + ) + os.environ["PGSERVICEFILE"] = tmpdir.join(".pg_service.conf").strpath + cli.connect_service("myservice", "another_user") + mock_connect.assert_called_with( + database="a_dbname", + host="a_host", + user="another_user", + port="5433", + passwd="much_secure", + ) + + with mock.patch.object(PGExecute, "__init__") as mock_pgexecute: + mock_pgexecute.return_value = None + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + os.environ["PGPASSWORD"] = "very_secure" + cli.connect_service("my_other_service", None) + mock_pgexecute.assert_called_with( + "b_dbname", + "b_user", + "very_secure", + "b_host", + "5435", + "", + application_name="pgcli", + ) + del os.environ["PGPASSWORD"] + del os.environ["PGSERVICEFILE"] + + def test_ssl_db_uri(tmpdir): with mock.patch.object(PGCli, "connect") as mock_connect: cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) -- cgit v1.2.3