Compare commits

...

3 commits

8 changed files with 61 additions and 30 deletions

View file

@ -1,3 +1,4 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from nyacme.hook import main from nyacme.hook import main
main() main()

View file

@ -1,9 +1,10 @@
import argparse import argparse
import os.path
import subprocess
import logging import logging
import os.path
import shutil import shutil
import subprocess
from datetime import datetime from datetime import datetime
from pathlib import Path
from .config import read_config from .config import read_config
@ -21,6 +22,7 @@ def main() -> None:
args = parser.parse_args() args = parser.parse_args()
config = read_config(args.config) config = read_config(args.config)
output_dir = Path(args.output)
acquired = False acquired = False
@ -31,9 +33,10 @@ def main() -> None:
uacme_domains = [ domain[2:], domain ] uacme_domains = [ domain[2:], domain ]
domain = domain[2:] domain = domain[2:]
cert_path = f'{args.output}/{domain}/cert.pem' cert_path = output_dir / domain / 'cert.pem'
if os.path.exists(cert_path): if cert_path.is_file():
out = subprocess.run([ 'openssl', 'x509', '-enddate', '-noout', '-in', cert_path ], stdout=subprocess.PIPE, check=True).stdout.decode('utf-8').strip() cmd = [ 'openssl', 'x509', '-enddate', '-noout', '-in', cert_path ]
out = subprocess.run(cmd, stdout=subprocess.PIPE, check=True).stdout.decode('utf-8').strip()
date = datetime.strptime(out, 'notAfter=%b %d %H:%M:%S %Y %Z') date = datetime.strptime(out, 'notAfter=%b %d %H:%M:%S %Y %Z')
# if more than 1 month, skip # if more than 1 month, skip
delta = date - datetime.now() delta = date - datetime.now()
@ -62,23 +65,23 @@ def main() -> None:
if res.returncode == 0: if res.returncode == 0:
acquired = True acquired = True
private_key = os.path.join(args.output, f'private/{domain}/key.pem') private_key = output_dir / 'private' / domain / 'key.pem'
domain_key = os.path.join(args.output, f'{domain}/cert.pem.key') domain_key = output_dir / domain / 'cert.pem.key'
domain_pem = os.path.join(args.output, f'{domain}/cert.pem') domain_pem = output_dir / domain / 'cert.pem'
os.unlink(domain_key) domain_key.unlink(missing_ok=True)
os.link(private_key, domain_key) private_key.hardlink_to(domain_key)
# TODO: add user/group to config # TODO: add user/group to config
shutil.chown(domain_key, 'acme', 'acme') shutil.chown(domain_key, 'acme', 'acme')
os.chmod(domain_key, 0o440) domain_key.chmod(0o440)
all_pem = os.path.join(args.output, f'all/{domain}.pem') all_pem = output_dir / 'all' / f'{domain}.pem'
all_key = os.path.join(args.output, f'all/{domain}.pem.key') all_key = output_dir / 'all' / f'{domain}.pem.key'
os.unlink(all_pem) all_pem.unlink(missing_ok=True)
os.link(domain_pem, all_pem) domain_pem.hardlink_to(all_pem)
os.unlink(all_key) all_key.unlink(missing_ok=True)
os.link(domain_key, all_key) domain_key.hardlink_to(all_key)
if acquired: if acquired:
for cmd in config.post_acquire: for cmd in config.post_acquire:

View file

@ -1,8 +1,9 @@
from typing import Optional
import tomllib
import logging import logging
import sys
import os import os
import sys
from typing import Optional
import tomllib
log = logging.getLogger(__name__) log = logging.getLogger(__name__)

View file

@ -2,6 +2,7 @@ import logging
from ..config import Config from ..config import Config
class Handler: class Handler:
zone: str zone: str
config: Config config: Config

View file

@ -1,9 +1,10 @@
import urllib.request
import json import json
from typing import Optional, Any import urllib.request
from typing import Any, Optional
from .base import Handler
from ..config import Config from ..config import Config
from .base import Handler
class HetznerHandler(Handler): class HetznerHandler(Handler):
# discovered # discovered
@ -41,7 +42,7 @@ class HetznerHandler(Handler):
raise Exception(json.loads(res)['error']) raise Exception(json.loads(res)['error'])
except Exception: except Exception:
raise Exception(res) raise Exception(res)
def create(self, record_name: str, record_value: str) -> None: def create(self, record_name: str, record_value: str) -> None:
self.remove(record_name) self.remove(record_name)
self.log.info('creating %s with value %s', record_name, record_value) self.log.info('creating %s with value %s', record_name, record_value)

View file

@ -1,11 +1,12 @@
import os import os
from .base import Handler
from ..config import Config from ..config import Config
from .base import Handler
class HTTPHandler(Handler): class HTTPHandler(Handler):
def __init__(self, zone: str, config: Config, token: str) -> None: def __init__(self, zone: str, config: Config, token: str) -> None:
super().__init__(zone_name, config, token) super().__init__(zone, config, token)
self.filepath = os.path.join(config.acme_path, token) self.filepath = os.path.join(config.acme_path, token)
def create(self, record_name: str, record_value: str) -> None: def create(self, record_name: str, record_value: str) -> None:

View file

@ -1,15 +1,14 @@
import argparse import argparse
import logging import logging
from itertools import chain
import time
import sys import sys
import time
from itertools import chain
import dns.resolver import dns.resolver
from .config import read_config from .config import read_config
from .handlers import HetznerHandler from .handlers import HetznerHandler
logging.basicConfig(level=logging.INFO, format='> [%(levelname)s] %(name)s: %(message)s') logging.basicConfig(level=logging.INFO, format='> [%(levelname)s] %(name)s: %(message)s')
log = logging.getLogger('nyacme_hook') log = logging.getLogger('nyacme_hook')
@ -49,7 +48,9 @@ def main() -> None:
if args.type == 'dns-01': if args.type == 'dns-01':
resolver = dns.resolver.Resolver('', configure=False) resolver = dns.resolver.Resolver('', configure=False)
resolver.nameservers = list(chain.from_iterable(list(map(resolve4, handler.nameservers)) + list(map(resolve6, handler.nameservers)))) ns4 = list(map(resolve4, handler.nameservers))
ns6 = list(map(resolve6, handler.nameservers))
resolver.nameservers = list(chain.from_iterable(ns4 + ns6))
for i in range(5): for i in range(5):
log.info('checking DNS (attempt %d/5)', i+1) log.info('checking DNS (attempt %d/5)', i+1)
try: try:

22
ruff.toml Normal file
View file

@ -0,0 +1,22 @@
line-length = 120
# https://docs.astral.sh/ruff/rules/
[lint]
extend-select = [
"Q", # quotes
"PT", # pytest
"I", # isort
"F", # pyflakes
"E", # pycodestyle (errors)
"W", # pycodestyle (warnings)
"UP", # pyupgrade
"ISC", # implicit string concat
"T20", # no print/pprint
"G001", # no str.format in logging
"C901", # check for complexity
]
[lint.flake8-quotes]
inline-quotes = "single"
docstring-quotes = "single"