refactor: improve config parsing, allow for domain-specific secrets and multi-value secrets

This commit is contained in:
ptrcnull 2024-08-10 21:21:20 +02:00
parent d2570641ad
commit a37af17d58
6 changed files with 82 additions and 22 deletions

View file

@ -3,7 +3,6 @@ import logging
import os.path
import shutil
import subprocess
import sys
from datetime import datetime
from pathlib import Path

View file

@ -1,19 +1,29 @@
import logging
import os
import sys
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Any, TypeAlias, TypeVar
import tomllib
log = logging.getLogger(__name__)
Secret: TypeAlias = str | dict[str, str] | None
@dataclass
class DomainInfo:
handler: str
secret: Secret
def invalid_type_msg(what: str, expected: str, got: Any) -> str:
return f'invalid type for {what}: expected {expected}, got {type(got).__name__}'
class Config:
post_acquire: list[str]
certificates: list[str]
domains: dict[str, str]
secrets: dict[str, str]
domains: dict[str, DomainInfo]
secrets: dict[str, Secret]
acme_path: Path
def find_zone(self, domain: str) -> str:
@ -27,25 +37,64 @@ class Config:
def get_handler(self, domain: str) -> str:
if domain in self.domains:
return self.domains[domain]
return self.domains[domain].handler
raise Exception(f'domain {domain} not found in the config')
def get_secret(self, handler: str) -> str:
return self.secrets[handler]
def get_secret(self, domain: str) -> Secret:
return self.domains[domain].secret
def get_single_secret(self, domain: str) -> str:
secret = self.get_secret(domain)
if type(secret) is not str:
raise TypeError(f'invalid type for domain "{domain}" secret: expected str, got {type(secret).__name__}')
return secret
def config_parse_dict(raw_conf: dict[str, Any], key: str) -> dict[str, str]:
def get_multi_secret(self, domain: str) -> dict[str, str]:
secret = self.get_secret(domain)
if type(secret) is not dict:
raise TypeError(f'invalid type for domain "{domain}" secret: expected dict, got {type(secret).__name__}')
return secret
def config_parse_secret(raw: Any) -> Secret:
if type(raw) is str:
return raw
if type(raw) is dict:
for k, v in raw:
if type(k) is not str:
raise TypeError(f'invalid type for secret name: expected str, got {type(k).__name__}')
if type(v) is not str:
raise TypeError(f'invalid type for secret "{k}" value: expected str, got {type(v).__name__}')
return raw
raise TypeError(f'invalid type for secret: expected str or dict, got {type(raw).__name__}')
def config_parse_domain(raw: Any) -> DomainInfo:
if type(raw) is str:
return DomainInfo(handler=raw, secret=None)
if type(raw) is dict:
if 'handler' not in raw:
raise TypeError('domain info does not contain "handler"')
handler = raw['handler']
if type(handler) is not str:
raise TypeError(f'invalid type for domain info handler: expected str, got {type(handler).__name__}')
secret = None
if 'secret' in raw:
secret = config_parse_secret(raw['secret'])
return DomainInfo(handler=handler, secret=secret)
raise TypeError(f'invalid type for domain info: expected str or dict, got {type(raw).__name__}')
T = TypeVar('T')
def config_parse_dict(raw_conf: dict[str, Any], key: str, parse_item: Callable[[Any], T]) -> dict[str, T]:
if key not in raw_conf:
log.error(f'missing "{key}" in config')
sys.exit(1)
result: dict[str, T] = {}
for k, v in raw_conf[key].items():
if not isinstance(k, str):
raise TypeError(f'"{k}" is not a string')
if not isinstance(v, str):
raise TypeError(f'"{k}" value "{v}" is not a string')
result: dict[str, str] = raw_conf[key]
result[k] = parse_item(v)
return result
@ -53,11 +102,11 @@ def config_parse_list(raw_conf: dict[str, Any], key: str) -> list[str]:
result: list[str] = []
if not isinstance(raw_conf[key], list):
raise TypeError(f'"{key}" must be a list, not {type(raw_conf[key]).__name__}')
raise TypeError(f'invalid type for "{key}": expected list, not {type(raw_conf[key]).__name__}')
for item in raw_conf[key]:
if not isinstance(item, str):
raise TypeError(f'"{key}" list item must be a string')
if type(item) is not str:
raise TypeError(f'invalid type for "{key}" list item: expected str, got {type(item).__name__}')
result.append(item)
return result
@ -77,8 +126,15 @@ def read_config(path: str | None) -> Config:
c = Config()
c.domains = config_parse_dict(raw_conf, 'domains')
c.secrets = config_parse_dict(raw_conf, 'secrets')
c.domains = config_parse_dict(raw_conf, 'domains', config_parse_domain)
secrets = config_parse_dict(raw_conf, 'secrets', config_parse_secret)
# fill default secrets
for domain in c.domains:
info = c.domains[domain]
if info.secret is None:
if info.handler not in secrets:
raise KeyError(f'"{info.handler}" not in secrets')
info.secret = secrets[info.handler]
post_acquire = []
if 'post_acquire' in raw_conf:

View file

@ -14,7 +14,7 @@ class CloudflareHandler(Handler):
def __init__(self, zone_name: str, config: Config, token: str) -> None:
super().__init__(zone_name, config, token)
self.secret = config.get_secret('cloudflare')
self.secret = config.get_single_secret(zone_name)
servers = dns.resolver.resolve(zone_name, 'NS')
self.nameservers = [ str(rdata.target).strip('.') for rdata in servers ]

View file

@ -10,7 +10,7 @@ class HEHandler(Handler):
def __init__(self, zone_name: str, config: Config, token: str) -> None:
super().__init__(zone_name, config, token)
self.nameservers = ['ns1.he.net', 'ns2.he.net', 'ns3.he.net', 'ns4.he.net', 'ns5.he.net']
self.password = config.get_secret('he')
self.password = config.get_single_secret(zone_name)
def set_record(self, record_name: str, value: str) -> Any:
full_record_name = record_name + '.' + self.zone

View file

@ -13,7 +13,7 @@ class HetznerHandler(Handler):
def __init__(self, zone_name: str, config: Config, token: str) -> None:
super().__init__(zone_name, config, token)
self.secret = config.get_secret('hetzner')
self.secret = config.get_single_secret(zone_name)
zones = self.fetch('/zones')['zones']
for zone in zones:

View file

@ -12,8 +12,13 @@ class PorkbunHandler(Handler):
def __init__(self, zone_name: str, config: Config, token: str) -> None:
super().__init__(zone_name, config, token)
self.apikey = config.get_secret('porkbun.apikey')
self.secretapikey = config.get_secret('porkbun.secretapikey')
secret = config.get_multi_secret(zone_name)
if 'apikey' not in secret:
raise TypeError('"apikey" missing')
if 'secretapikey' not in secret:
raise TypeError('"secretapikey" missing')
self.apikey = secret['apikey']
self.secretapikey = secret['secretapikey']
self.nameservers = self.fetch(f'/domain/getNs/{self.zone}')['ns']