Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pytest: Add system-postgres DB provider and allow pyln-testing to run on older CLN versions too #6947

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
4 changes: 4 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jobs:
with:
ref: ${{ github.event.pull_request.head.sha }}
fetch-depth: 0
fetch-tags: true

- name: Rebase
# We can't rebase if we're on master already.
Expand Down Expand Up @@ -94,6 +95,9 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v3
with:
fetch-depth: 0
fetch-tags: true

- name: Set up Python 3.8
uses: actions/setup-python@v4
Expand Down
47 changes: 47 additions & 0 deletions contrib/pyln-testing/pyln/testing/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def stop(self):
cur = conn.cursor()
cur.execute("DROP DATABASE {};".format(self.dbname))
cur.close()
conn.close()

def wipe_db(self):
cur = self.conn.cursor()
Expand Down Expand Up @@ -240,3 +241,49 @@ def stop(self):
self.proc.send_signal(signal.SIGINT)
self.proc.wait()
shutil.rmtree(self.pgdir)


class SystemPostgresProvider(PostgresDbProvider):
"""A Postgres DB variant that uses an existing instance on the system.

The specifics of the DB admin connection are passed in via
`CLN_TEST_POSTGRES_DSN`, or uses `dbname=template1 user=postgres
host=localhost port=5432` by default. The DSN must have permission
to create new roles and schemas.

Currently only supports postgres instances running on the default
port (5432).

"""

def __init__(self, directory):
self.directory = directory
self.dbs: List[str] = []
self.admin_dsn = os.environ.get(
"CLN_TEST_POSTGRES_DSN",
"dbname=template1 user=postgres host=127.0.0.1 port=5432"
)
self.port = 5432

def start(self):
"""We assume the postgres instance is already running, so this is a no-op. """

def connect(self):
return psycopg2.connect(self.admin_dsn)

def get_db(self, node_directory, testname, node_id):
# Random suffix to avoid collisions on repeated tests
nonce = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(8))
dbname = "{}_{}_{}".format(testname, node_id, nonce)

conn = self.connect()
conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
cur = conn.cursor()
cur.execute("CREATE DATABASE {} TEMPLATE template0;".format(dbname))
cur.close()
conn.close()
db = PostgresDb(dbname, self.port)
return db

def stop(self):
"""Cleanup the schemas we created. """
3 changes: 2 additions & 1 deletion contrib/pyln-testing/pyln/testing/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from concurrent import futures
from pyln.testing.db import SqliteDbProvider, PostgresDbProvider
from pyln.testing.db import SqliteDbProvider, PostgresDbProvider, SystemPostgresProvider
from pyln.testing.utils import NodeFactory, BitcoinD, ElementsD, env, LightningNode, TEST_DEBUG, TEST_NETWORK
from pyln.client import Millisatoshi
from typing import Dict
Expand Down Expand Up @@ -618,6 +618,7 @@ def checkMemleak(node):
providers = {
'sqlite3': SqliteDbProvider,
'postgres': PostgresDbProvider,
'system-postgres': SystemPostgresProvider
}


Expand Down
21 changes: 17 additions & 4 deletions contrib/pyln-testing/pyln/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pyln.client import RpcError
from pyln.testing.btcproxy import BitcoinRpcProxy
from pyln.testing.gossip import GossipStore
from pyln.testing.version import Version
from collections import OrderedDict
from decimal import Decimal
from pyln.client import LightningRpc
Expand Down Expand Up @@ -598,6 +599,7 @@ def __init__(
self.rpcproxy = bitcoindproxy
self.env['CLN_PLUGIN_LOG'] = "cln_plugin=trace,cln_rpc=trace,cln_grpc=trace,debug"

self.early_opts = {}
self.opts = LIGHTNINGD_CONFIG.copy()
opts = {
'lightning-dir': lightning_dir,
Expand All @@ -613,6 +615,11 @@ def __init__(
'bitcoin-datadir': lightning_dir,
}

# Options that must be early in the command line can be stored
# in `early_args`. They will be passed first to the
# executable.
self.early_opts = {}

if grpc_port is not None:
opts['grpc-port'] = grpc_port

Expand All @@ -633,8 +640,11 @@ def __init__(
# Log to stdout so we see it in failure cases, and log file for TailableProc.
self.opts['log-file'] = ['-', os.path.join(lightning_dir, "log")]
self.opts['log-prefix'] = self.prefix + ' '
# In case you want specific ordering!
self.early_opts = ['--developer']

@property
def version(self) -> Version:
v = subprocess.check_output([self.executable, "--version"]).decode('ASCII')
return Version.from_str(v)

def cleanup(self):
# To force blackhole to exit, disconnect file must be truncated!
Expand All @@ -644,9 +654,12 @@ def cleanup(self):

@property
def cmd_line(self):
if self.version >= Version.from_str('v23.11'):
# Starting with v23.11 we ahve the `--developer` flag
self.early_opts = {'developer': None}

opts = []
for k, v in self.opts.items():
for k, v in list(self.early_opts.items()) + list(self.opts.items()):
if v is None:
opts.append("--{}".format(k))
elif isinstance(v, list):
Expand All @@ -655,7 +668,7 @@ def cmd_line(self):
else:
opts.append("--{}={}".format(k, v))

return self.cmd_prefix + [self.executable] + self.early_opts + opts
return self.cmd_prefix + [self.executable] + opts

def start(self, stdin=None, wait_for_initialized=True, stderr_redir=False):
self.opts['bitcoin-rpcport'] = self.rpcproxy.rpcport
Expand Down
39 changes: 39 additions & 0 deletions contrib/pyln-testing/pyln/testing/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

from dataclasses import dataclass
import re


@dataclass
class Version:
year: int
month: int
patch: int = 0

def __lt__(self, other):
return [self.year, self.month, self.patch] < [other.year, other.month, other.patch]

def __gt__(self, other):
return other < self

def __le__(self, other):
return [self.year, self.month, self.patch] <= [other.year, other.month, other.patch]

def __ge__(self, other):
return other <= self

def __eq__(self, other):
return [self.year, self.month] == [other.year, other.month]

@classmethod
def from_str(cls, s: str) -> "Version":
m = re.search(r'^v(\d+).(\d+).?(\d+)?(rc\d+)?', s)
if m is None:
raise ValueError(f"Could not parse version {s}")
parts = [int(m.group(i)) for i in range(1, 4) if m.group(i) is not None]
year, month = parts[0], parts[1]
if len(parts) == 3:
patch = parts[2]
else:
patch = 0

return Version(year=year, month=month, patch=patch)
12 changes: 12 additions & 0 deletions contrib/pyln-testing/tests/test_fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from pyln.testing.version import Version


def test_version_parsing():
cases = [
("v24.02", Version(24, 2)),
("v23.11.2", Version(23, 11, 2)),
]

for test_in, test_out in cases:
v = Version.from_str(test_in)
assert test_out == v
11 changes: 8 additions & 3 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ def node_cls():

class LightningNode(utils.LightningNode):
def __init__(self, *args, **kwargs):
# Yes, we really want to test the local development version, not
# something in out path.
self.old_path = os.environ['PATH']
binpath = Path(__file__).parent / ".." / "lightningd"
os.environ['PATH'] = f"{binpath}:{self.old_path}"

utils.LightningNode.__init__(self, *args, **kwargs)

# We have some valgrind suppressions in the `tests/`
Expand Down Expand Up @@ -47,9 +53,8 @@ def __init__(self, *args, **kwargs):
accts_db = self.db.provider.get_db('', 'accounts', 0)
self.daemon.opts['bookkeeper-db'] = accts_db.get_dsn()

# Yes, we really want to test the local development version, not
# something in out path.
self.daemon.executable = 'lightningd/lightningd'
def __del__(self):
os.environ['PATH'] = self.old_path


class CompatLevel(object):
Expand Down
Loading