nyacme/nyacme/hook.py

74 lines
2.4 KiB
Python
Executable file

#!/usr/bin/env python3
import argparse
import logging
from itertools import chain
import time
import sys
import dns.resolver
from .config import read_config
from .handlers import HetznerHandler
# Configure root logging once, at import time, for the hook's console output.
logging.basicConfig(
    format='> [%(levelname)s] %(name)s: %(message)s',
    level=logging.INFO,
)
# Logger used by the hook's functions.
log = logging.getLogger('nyacme_hook')
# Maps the handler name returned by config.get_handler() to the class that
# creates/removes records for that zone.
handlers = dict(hetzner=HetznerHandler)
def _build_parser() -> argparse.ArgumentParser:
    """Return the parser for the hook's five positional arguments."""
    parser = argparse.ArgumentParser(
        prog='nyacme-hook',
        description='nyacme hook (not meant to be ran manually)',
    )
    # Restrict to the documented methods so a bad invocation fails loudly
    # instead of silently being treated as a removal.
    parser.add_argument('method', choices=('begin', 'done', 'failed'),
                        help='one of begin, done or failed')
    parser.add_argument('type', help='challenge type (dns-01, http-01 or tls-alpn-01)')
    parser.add_argument('domain', help='the identifier the challenge refers to (domain name)')
    parser.add_argument('token', help='the challenge token')
    parser.add_argument('auth', help='the key authorization (DNS record contents, etc.)')
    return parser


def _wait_for_dns(handler, record_name: str, auth: str, present: bool,
                  attempts: int = 5, interval: float = 5.0) -> bool:
    """Poll the handler's nameservers until ``auth`` is (or is no longer) published.

    The TXT record at ``record_name`` is queried directly against the
    addresses of ``handler.nameservers`` (presumably to bypass caching
    resolvers). Returns True as soon as the published values contain
    ``auth`` (when ``present`` is true) or no longer contain it (when
    ``present`` is false). Returns False after ``attempts`` unsuccessful
    checks, sleeping ``interval`` seconds between consecutive checks.
    """
    resolver = dns.resolver.Resolver('', configure=False)
    # All IPv4 addresses of every nameserver first, then all IPv6 addresses.
    resolver.nameservers = [
        ip
        for lookup in (resolve4, resolve6)
        for ns in handler.nameservers
        for ip in lookup(ns)
    ]

    for attempt in range(1, attempts + 1):
        log.info('checking DNS (attempt %d/%d)', attempt, attempts)
        try:
            answer = resolver.resolve(record_name, 'TXT')
        except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
            # The name does not exist, or exists with no TXT data (e.g. only
            # other record types, or a wildcard, match it): nothing published.
            values = []
        else:
            # A TXT record may consist of several character-strings; join the
            # raw strings instead of stripping quotes off the presentation form.
            values = [b''.join(rdata.strings).decode(errors='replace')
                      for rdata in answer]
        log.info('response from DNS: %s', values)
        if (auth in values) == present:
            return True
        if attempt < attempts:
            time.sleep(interval)
    return False


def main() -> None:
    """Entry point of the nyacme hook.

    Arguments (positional): ``method`` (begin/done/failed), challenge
    ``type``, ``domain``, challenge ``token`` and key authorization ``auth``.
    On ``begin`` the ``_acme-challenge.<domain>`` TXT record is created
    through the zone's handler; on any other method it is removed. For
    ``dns-01`` challenges the zone's nameservers are then polled until the
    change is visible, exiting with status 0 as soon as it is. Otherwise a
    warning is logged.
    """
    args = _build_parser().parse_args()
    config = read_config(None)

    record_name = f'_acme-challenge.{args.domain}'
    zone_name = config.find_zone(args.domain)
    # Strip only the trailing ".<zone>" to get the zone-relative name;
    # str.replace would also remove an identical substring earlier in the name.
    short_record_name = record_name.removesuffix('.' + zone_name)

    handler_name = config.get_handler(zone_name)
    try:
        handler_cls = handlers[handler_name]
    except KeyError:
        log.error('unknown DNS handler %r for zone %s', handler_name, zone_name)
        sys.exit(1)
    # NOTE(review): the original passed an undefined name `token` here
    # (NameError at runtime); `args.token` is the only token in scope.
    # Confirm this is what the handler constructor expects.
    handler = handler_cls(zone_name, config, args.token)

    creating = args.method == 'begin'
    if creating:
        handler.create(short_record_name, args.auth)
    else:
        handler.remove(short_record_name)

    if args.type == 'dns-01':
        if _wait_for_dns(handler, record_name, args.auth, present=creating):
            sys.exit(0)
        log.warning('could not ensure the DNS record was %s!!',
                    'created' if creating else 'removed')
def resolve4(addr: str) -> list[str]:
    """Return the IPv4 addresses (A records) of host name ``addr`` as strings.

    A host without A records yields an empty list instead of raising
    ``dns.resolver.NoAnswer``, so an IPv6-only nameserver does not abort
    the caller. Other resolution errors (e.g. NXDOMAIN) still propagate.
    """
    try:
        answer = dns.resolver.resolve(addr, 'A')
    except dns.resolver.NoAnswer:
        return []
    return [str(rdata) for rdata in answer]
def resolve6(addr: str) -> list[str]:
    """Return the IPv6 addresses (AAAA records) of host name ``addr`` as strings.

    A host without AAAA records yields an empty list instead of raising
    ``dns.resolver.NoAnswer``, so an IPv4-only nameserver does not abort
    the caller. Other resolution errors (e.g. NXDOMAIN) still propagate.
    """
    try:
        answer = dns.resolver.resolve(addr, 'AAAA')
    except dns.resolver.NoAnswer:
        return []
    return [str(rdata) for rdata in answer]
if __name__ == '__main__':
    # Run the hook when this file is executed as a script.
    main()