refactor: improve config parsing, allow for domain-specific secrets and multi-value secrets

This commit is contained in:
ptrcnull 2024-08-10 21:21:20 +02:00
parent d2570641ad
commit a37af17d58
6 changed files with 82 additions and 22 deletions

View file

@ -3,7 +3,6 @@ import logging
import os.path import os.path
import shutil import shutil
import subprocess import subprocess
import sys
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path

View file

@ -1,19 +1,29 @@
import logging import logging
import os import os
import sys import sys
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, TypeAlias, TypeVar
import tomllib import tomllib
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
Secret: TypeAlias = str | dict[str, str] | None
@dataclass
class DomainInfo:
handler: str
secret: Secret
def invalid_type_msg(what: str, expected: str, got: Any) -> str:
return f'invalid type for {what}: expected {expected}, got {type(got).__name__}'
class Config: class Config:
post_acquire: list[str] post_acquire: list[str]
certificates: list[str] certificates: list[str]
domains: dict[str, str] domains: dict[str, DomainInfo]
secrets: dict[str, str] secrets: dict[str, Secret]
acme_path: Path acme_path: Path
def find_zone(self, domain: str) -> str: def find_zone(self, domain: str) -> str:
@ -27,25 +37,64 @@ class Config:
def get_handler(self, domain: str) -> str: def get_handler(self, domain: str) -> str:
if domain in self.domains: if domain in self.domains:
return self.domains[domain] return self.domains[domain].handler
raise Exception(f'domain {domain} not found in the config') raise Exception(f'domain {domain} not found in the config')
def get_secret(self, handler: str) -> str: def get_secret(self, domain: str) -> Secret:
return self.secrets[handler] return self.domains[domain].secret
def get_single_secret(self, domain: str) -> str:
secret = self.get_secret(domain)
if type(secret) is not str:
raise TypeError(f'invalid type for domain "{domain}" secret: expected str, got {type(secret).__name__}')
return secret
def config_parse_dict(raw_conf: dict[str, Any], key: str) -> dict[str, str]: def get_multi_secret(self, domain: str) -> dict[str, str]:
secret = self.get_secret(domain)
if type(secret) is not dict:
raise TypeError(f'invalid type for domain "{domain}" secret: expected dict, got {type(secret).__name__}')
return secret
def config_parse_secret(raw: Any) -> Secret:
if type(raw) is str:
return raw
if type(raw) is dict:
for k, v in raw:
if type(k) is not str:
raise TypeError(f'invalid type for secret name: expected str, got {type(k).__name__}')
if type(v) is not str:
raise TypeError(f'invalid type for secret "{k}" value: expected str, got {type(v).__name__}')
return raw
raise TypeError(f'invalid type for secret: expected str or dict, got {type(raw).__name__}')
def config_parse_domain(raw: Any) -> DomainInfo:
if type(raw) is str:
return DomainInfo(handler=raw, secret=None)
if type(raw) is dict:
if 'handler' not in raw:
raise TypeError('domain info does not contain "handler"')
handler = raw['handler']
if type(handler) is not str:
raise TypeError(f'invalid type for domain info handler: expected str, got {type(handler).__name__}')
secret = None
if 'secret' in raw:
secret = config_parse_secret(raw['secret'])
return DomainInfo(handler=handler, secret=secret)
raise TypeError(f'invalid type for domain info: expected str or dict, got {type(raw).__name__}')
T = TypeVar('T')
def config_parse_dict(raw_conf: dict[str, Any], key: str, parse_item: Callable[[Any], T]) -> dict[str, T]:
if key not in raw_conf: if key not in raw_conf:
log.error(f'missing "{key}" in config') log.error(f'missing "{key}" in config')
sys.exit(1) sys.exit(1)
result: dict[str, T] = {}
for k, v in raw_conf[key].items(): for k, v in raw_conf[key].items():
if not isinstance(k, str): if not isinstance(k, str):
raise TypeError(f'"{k}" is not a string') raise TypeError(f'"{k}" is not a string')
if not isinstance(v, str): result[k] = parse_item(v)
raise TypeError(f'"{k}" value "{v}" is not a string')
result: dict[str, str] = raw_conf[key]
return result return result
@ -53,11 +102,11 @@ def config_parse_list(raw_conf: dict[str, Any], key: str) -> list[str]:
result: list[str] = [] result: list[str] = []
if not isinstance(raw_conf[key], list): if not isinstance(raw_conf[key], list):
raise TypeError(f'"{key}" must be a list, not {type(raw_conf[key]).__name__}') raise TypeError(f'invalid type for "{key}": expected list, not {type(raw_conf[key]).__name__}')
for item in raw_conf[key]: for item in raw_conf[key]:
if not isinstance(item, str): if type(item) is not str:
raise TypeError(f'"{key}" list item must be a string') raise TypeError(f'invalid type for "{key}" list item: expected str, got {type(item).__name__}')
result.append(item) result.append(item)
return result return result
@ -77,8 +126,15 @@ def read_config(path: str | None) -> Config:
c = Config() c = Config()
c.domains = config_parse_dict(raw_conf, 'domains') c.domains = config_parse_dict(raw_conf, 'domains', config_parse_domain)
c.secrets = config_parse_dict(raw_conf, 'secrets') secrets = config_parse_dict(raw_conf, 'secrets', config_parse_secret)
# fill default secrets
for domain in c.domains:
info = c.domains[domain]
if info.secret is None:
if info.handler not in secrets:
raise KeyError(f'"{info.handler}" not in secrets')
info.secret = secrets[info.handler]
post_acquire = [] post_acquire = []
if 'post_acquire' in raw_conf: if 'post_acquire' in raw_conf:

View file

@ -14,7 +14,7 @@ class CloudflareHandler(Handler):
def __init__(self, zone_name: str, config: Config, token: str) -> None: def __init__(self, zone_name: str, config: Config, token: str) -> None:
super().__init__(zone_name, config, token) super().__init__(zone_name, config, token)
self.secret = config.get_secret('cloudflare') self.secret = config.get_single_secret(zone_name)
servers = dns.resolver.resolve(zone_name, 'NS') servers = dns.resolver.resolve(zone_name, 'NS')
self.nameservers = [ str(rdata.target).strip('.') for rdata in servers ] self.nameservers = [ str(rdata.target).strip('.') for rdata in servers ]

View file

@ -10,7 +10,7 @@ class HEHandler(Handler):
def __init__(self, zone_name: str, config: Config, token: str) -> None: def __init__(self, zone_name: str, config: Config, token: str) -> None:
super().__init__(zone_name, config, token) super().__init__(zone_name, config, token)
self.nameservers = ['ns1.he.net', 'ns2.he.net', 'ns3.he.net', 'ns4.he.net', 'ns5.he.net'] self.nameservers = ['ns1.he.net', 'ns2.he.net', 'ns3.he.net', 'ns4.he.net', 'ns5.he.net']
self.password = config.get_secret('he') self.password = config.get_single_secret(zone_name)
def set_record(self, record_name: str, value: str) -> Any: def set_record(self, record_name: str, value: str) -> Any:
full_record_name = record_name + '.' + self.zone full_record_name = record_name + '.' + self.zone

View file

@ -13,7 +13,7 @@ class HetznerHandler(Handler):
def __init__(self, zone_name: str, config: Config, token: str) -> None: def __init__(self, zone_name: str, config: Config, token: str) -> None:
super().__init__(zone_name, config, token) super().__init__(zone_name, config, token)
self.secret = config.get_secret('hetzner') self.secret = config.get_single_secret(zone_name)
zones = self.fetch('/zones')['zones'] zones = self.fetch('/zones')['zones']
for zone in zones: for zone in zones:

View file

@ -12,8 +12,13 @@ class PorkbunHandler(Handler):
def __init__(self, zone_name: str, config: Config, token: str) -> None: def __init__(self, zone_name: str, config: Config, token: str) -> None:
super().__init__(zone_name, config, token) super().__init__(zone_name, config, token)
self.apikey = config.get_secret('porkbun.apikey') secret = config.get_multi_secret(zone_name)
self.secretapikey = config.get_secret('porkbun.secretapikey') if 'apikey' not in secret:
raise TypeError('"apikey" missing')
if 'secretapikey' not in secret:
raise TypeError('"secretapikey" missing')
self.apikey = secret['apikey']
self.secretapikey = secret['secretapikey']
self.nameservers = self.fetch(f'/domain/getNs/{self.zone}')['ns'] self.nameservers = self.fetch(f'/domain/getNs/{self.zone}')['ns']