Diffstat (limited to 'python/knot_resolver')
-rw-r--r--  python/knot_resolver/__init__.py  5
-rw-r--r--  python/knot_resolver/client/__init__.py  5
-rw-r--r--  python/knot_resolver/client/__main__.py  4
-rw-r--r--  python/knot_resolver/client/client.py  51
-rw-r--r--  python/knot_resolver/client/command.py  125
-rw-r--r--  python/knot_resolver/client/commands/cache.py  123
-rw-r--r--  python/knot_resolver/client/commands/completion.py  95
-rw-r--r--  python/knot_resolver/client/commands/config.py  222
-rw-r--r--  python/knot_resolver/client/commands/convert.py  85
-rw-r--r--  python/knot_resolver/client/commands/help.py  24
-rw-r--r--  python/knot_resolver/client/commands/metrics.py  67
-rw-r--r--  python/knot_resolver/client/commands/reload.py  36
-rw-r--r--  python/knot_resolver/client/commands/schema.py  55
-rw-r--r--  python/knot_resolver/client/commands/stop.py  32
-rw-r--r--  python/knot_resolver/client/commands/validate.py  63
-rw-r--r--  python/knot_resolver/client/main.py  69
-rw-r--r--  python/knot_resolver/compat/__init__.py  3
-rw-r--r--  python/knot_resolver/compat/asyncio.py  128
-rw-r--r--  python/knot_resolver/controller/__init__.py  94
-rw-r--r--  python/knot_resolver/controller/interface.py  296
-rw-r--r--  python/knot_resolver/controller/registered_workers.py  49
-rw-r--r--  python/knot_resolver/controller/supervisord/__init__.py  281
-rw-r--r--  python/knot_resolver/controller/supervisord/config_file.py  197
-rw-r--r--  python/knot_resolver/controller/supervisord/plugin/fast_rpcinterface.py  173
-rw-r--r--  python/knot_resolver/controller/supervisord/plugin/manager_integration.py  85
-rw-r--r--  python/knot_resolver/controller/supervisord/plugin/notifymodule.c  176
-rw-r--r--  python/knot_resolver/controller/supervisord/plugin/patch_logger.py  97
-rw-r--r--  python/knot_resolver/controller/supervisord/plugin/sd_notify.py  227
-rw-r--r--  python/knot_resolver/controller/supervisord/supervisord.conf.j2  93
-rw-r--r--  python/knot_resolver/datamodel/__init__.py  3
-rw-r--r--  python/knot_resolver/datamodel/cache_schema.py  139
-rw-r--r--  python/knot_resolver/datamodel/config_schema.py  242
-rw-r--r--  python/knot_resolver/datamodel/design-notes.yml  237
-rw-r--r--  python/knot_resolver/datamodel/dns64_schema.py  19
-rw-r--r--  python/knot_resolver/datamodel/dnssec_schema.py  45
-rw-r--r--  python/knot_resolver/datamodel/forward_schema.py  84
-rw-r--r--  python/knot_resolver/datamodel/globals.py  57
-rw-r--r--  python/knot_resolver/datamodel/local_data_schema.py  95
-rw-r--r--  python/knot_resolver/datamodel/logging_schema.py  153
-rw-r--r--  python/knot_resolver/datamodel/lua_schema.py  23
-rw-r--r--  python/knot_resolver/datamodel/management_schema.py  21
-rw-r--r--  python/knot_resolver/datamodel/monitoring_schema.py  25
-rw-r--r--  python/knot_resolver/datamodel/network_schema.py  181
-rw-r--r--  python/knot_resolver/datamodel/options_schema.py  36
-rw-r--r--  python/knot_resolver/datamodel/policy_schema.py  126
-rw-r--r--  python/knot_resolver/datamodel/rpz_schema.py  29
-rw-r--r--  python/knot_resolver/datamodel/slice_schema.py  21
-rw-r--r--  python/knot_resolver/datamodel/static_hints_schema.py  27
-rw-r--r--  python/knot_resolver/datamodel/stub_zone_schema.py  32
-rw-r--r--  python/knot_resolver/datamodel/templates/__init__.py  43
-rw-r--r--  python/knot_resolver/datamodel/templates/cache.lua.j2  32
-rw-r--r--  python/knot_resolver/datamodel/templates/dns64.lua.j2  17
-rw-r--r--  python/knot_resolver/datamodel/templates/dnssec.lua.j2  60
-rw-r--r--  python/knot_resolver/datamodel/templates/forward.lua.j2  9
-rw-r--r--  python/knot_resolver/datamodel/templates/local_data.lua.j2  41
-rw-r--r--  python/knot_resolver/datamodel/templates/logging.lua.j2  43
-rw-r--r--  python/knot_resolver/datamodel/templates/macros/cache_macros.lua.j2  11
-rw-r--r--  python/knot_resolver/datamodel/templates/macros/common_macros.lua.j2  101
-rw-r--r--  python/knot_resolver/datamodel/templates/macros/forward_macros.lua.j2  42
-rw-r--r--  python/knot_resolver/datamodel/templates/macros/local_data_macros.lua.j2  101
-rw-r--r--  python/knot_resolver/datamodel/templates/macros/network_macros.lua.j2  55
-rw-r--r--  python/knot_resolver/datamodel/templates/macros/policy_macros.lua.j2  279
-rw-r--r--  python/knot_resolver/datamodel/templates/macros/view_macros.lua.j2  25
-rw-r--r--  python/knot_resolver/datamodel/templates/monitoring.lua.j2  33
-rw-r--r--  python/knot_resolver/datamodel/templates/network.lua.j2  102
-rw-r--r--  python/knot_resolver/datamodel/templates/options.lua.j2  52
-rw-r--r--  python/knot_resolver/datamodel/templates/policy-config.lua.j2  40
-rw-r--r--  python/knot_resolver/datamodel/templates/static_hints.lua.j2  51
-rw-r--r--  python/knot_resolver/datamodel/templates/views.lua.j2  25
-rw-r--r--  python/knot_resolver/datamodel/templates/webmgmt.lua.j2  25
-rw-r--r--  python/knot_resolver/datamodel/templates/worker-config.lua.j2  58
-rw-r--r--  python/knot_resolver/datamodel/types/__init__.py  69
-rw-r--r--  python/knot_resolver/datamodel/types/base_types.py  227
-rw-r--r--  python/knot_resolver/datamodel/types/enums.py  153
-rw-r--r--  python/knot_resolver/datamodel/types/files.py  245
-rw-r--r--  python/knot_resolver/datamodel/types/generic_types.py  38
-rw-r--r--  python/knot_resolver/datamodel/types/types.py  526
-rw-r--r--  python/knot_resolver/datamodel/view_schema.py  45
-rw-r--r--  python/knot_resolver/datamodel/webmgmt_schema.py  27
-rw-r--r--  python/knot_resolver/manager/__init__.py  0
-rw-r--r--  python/knot_resolver/manager/__main__.py  5
-rw-r--r--  python/knot_resolver/manager/config_store.py  101
-rw-r--r--  python/knot_resolver/manager/constants.py  108
-rw-r--r--  python/knot_resolver/manager/exceptions.py  28
-rw-r--r--  python/knot_resolver/manager/kres_manager.py  429
-rw-r--r--  python/knot_resolver/manager/log.py  105
-rw-r--r--  python/knot_resolver/manager/main.py  49
-rw-r--r--  python/knot_resolver/manager/server.py  637
-rw-r--r--  python/knot_resolver/manager/statistics.py  434
-rw-r--r--  python/knot_resolver/utils/__init__.py  45
-rw-r--r--  python/knot_resolver/utils/async_utils.py  129
-rw-r--r--  python/knot_resolver/utils/custom_atexit.py  20
-rw-r--r--  python/knot_resolver/utils/etag.py  10
-rw-r--r--  python/knot_resolver/utils/functional.py  72
-rw-r--r--  python/knot_resolver/utils/modeling/README.md  155
-rw-r--r--  python/knot_resolver/utils/modeling/__init__.py  14
-rw-r--r--  python/knot_resolver/utils/modeling/base_generic_type_wrapper.py  9
-rw-r--r--  python/knot_resolver/utils/modeling/base_schema.py  816
-rw-r--r--  python/knot_resolver/utils/modeling/base_value_type.py  45
-rw-r--r--  python/knot_resolver/utils/modeling/exceptions.py  63
-rw-r--r--  python/knot_resolver/utils/modeling/json_pointer.py  88
-rw-r--r--  python/knot_resolver/utils/modeling/parsing.py  99
-rw-r--r--  python/knot_resolver/utils/modeling/query.py  183
-rw-r--r--  python/knot_resolver/utils/modeling/renaming.py  90
-rw-r--r--  python/knot_resolver/utils/modeling/types.py  105
-rw-r--r--  python/knot_resolver/utils/requests.py  135
-rw-r--r--  python/knot_resolver/utils/systemd_notify.py  54
-rw-r--r--  python/knot_resolver/utils/which.py  22
108 files changed, 11180 insertions, 0 deletions
diff --git a/python/knot_resolver/__init__.py b/python/knot_resolver/__init__.py
new file mode 100644
index 00000000..511e8d44
--- /dev/null
+++ b/python/knot_resolver/__init__.py
@@ -0,0 +1,5 @@
+from .datamodel.config_schema import KresConfig
+
+__version__ = "0.1.0"
+
+__all__ = ["KresConfig"]
diff --git a/python/knot_resolver/client/__init__.py b/python/knot_resolver/client/__init__.py
new file mode 100644
index 00000000..5b82d3be
--- /dev/null
+++ b/python/knot_resolver/client/__init__.py
@@ -0,0 +1,5 @@
+from pathlib import Path
+
+from knot_resolver.datamodel.globals import Context, set_global_validation_context
+
+set_global_validation_context(Context(Path("."), False))
diff --git a/python/knot_resolver/client/__main__.py b/python/knot_resolver/client/__main__.py
new file mode 100644
index 00000000..56200674
--- /dev/null
+++ b/python/knot_resolver/client/__main__.py
@@ -0,0 +1,4 @@
+from knot_resolver.client.main import main
+
+if __name__ == "__main__":
+ main()
diff --git a/python/knot_resolver/client/client.py b/python/knot_resolver/client/client.py
new file mode 100644
index 00000000..4e7d13ea
--- /dev/null
+++ b/python/knot_resolver/client/client.py
@@ -0,0 +1,51 @@
+import argparse
+
+from knot_resolver.client.command import CommandArgs
+
+KRES_CLIENT_NAME = "kresctl"
+
+
+class KresClient:
+ def __init__(
+ self,
+ namespace: argparse.Namespace,
+ parser: argparse.ArgumentParser,
+ prompt: str = KRES_CLIENT_NAME,
+ ) -> None:
+ self.path = None
+ self.prompt = prompt
+ self.namespace = namespace
+ self.parser = parser
+
+ def execute(self):
+ if hasattr(self.namespace, "command"):
+ args = CommandArgs(self.namespace, self.parser)
+ command = args.command(self.namespace)
+ command.run(args)
+ else:
+ self.parser.print_help()
+
+ def _prompt_format(self) -> str:
+ bolt = "\033[1m"
+ white = "\033[38;5;255m"
+ reset = "\033[0;0m"
+
+ if self.path:
+ prompt = f"{bolt}[{self.prompt} {white}{self.path}{reset}{bolt}]"
+ else:
+ prompt = f"{bolt}{self.prompt}"
+ return f"{prompt}> {reset}"
+
+ def interactive(self):
+ try:
+ while True:
+ pass
+ # TODO: not working yet
+ # cmd = input(f"{self._prompt_format()}")
+ # namespace = self.parser.parse_args(cmd.split(" "))
+ # namespace.interactive = True
+ # namespace.socket = self.namespace.socket
+ # self.namespace = namespace
+ # self.execute()
+ except KeyboardInterrupt:
+ pass
diff --git a/python/knot_resolver/client/command.py b/python/knot_resolver/client/command.py
new file mode 100644
index 00000000..af59c42e
--- /dev/null
+++ b/python/knot_resolver/client/command.py
@@ -0,0 +1,125 @@
+import argparse
+import os
+from abc import ABC, abstractmethod # pylint: disable=[no-name-in-module]
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Type, TypeVar
+from urllib.parse import quote
+
+from knot_resolver.manager.constants import API_SOCK_ENV_VAR, CONFIG_FILE_ENV_VAR, DEFAULT_MANAGER_CONFIG_FILE
+from knot_resolver.datamodel.config_schema import DEFAULT_MANAGER_API_SOCK
+from knot_resolver.datamodel.types import IPAddressPort
+from knot_resolver.utils.modeling import parsing
+from knot_resolver.utils.modeling.exceptions import DataValidationError
+from knot_resolver.utils.requests import SocketDesc
+
+T = TypeVar("T", bound=Type["Command"])
+
+CompWords = Dict[str, Optional[str]]
+
+_registered_commands: List[Type["Command"]] = []
+
+
+def register_command(cls: T) -> T:
+ _registered_commands.append(cls)
+ return cls
+
+
+def get_help_command() -> Type["Command"]:
+ for command in _registered_commands:
+ if command.__name__ == "HelpCommand":
+ return command
+ raise ValueError("missing HelpCommand")
+
+
+def install_commands_parsers(parser: argparse.ArgumentParser) -> None:
+ subparsers = parser.add_subparsers(help="command type")
+ for command in _registered_commands:
+ subparser, typ = command.register_args_subparser(subparsers)
+ subparser.set_defaults(command=typ, subparser=subparser)
+
+
+def get_socket_from_config(config: Path, optional_file: bool) -> Optional[SocketDesc]:
+ try:
+ with open(config, "r", encoding="utf8") as f:
+ data = parsing.try_to_parse(f.read())
+ mkey = "management"
+ if mkey in data:
+ management = data[mkey]
+ if "unix-socket" in management:
+ return SocketDesc(
+ f'http+unix://{quote(management["unix-socket"], safe="")}/',
+ f'Key "/management/unix-socket" in "{config}" file',
+ )
+ elif "interface" in management:
+ ip = IPAddressPort(management["interface"], object_path=f"/{mkey}/interface")
+ return SocketDesc(
+ f"http://{ip.addr}:{ip.port}",
+ f'Key "/management/interface" in "{config}" file',
+ )
+ return None
+ except ValueError as e:
+ raise DataValidationError(*e.args) from e # pylint: disable=no-value-for-parameter
+ except OSError as e:
+ if not optional_file:
+ raise e
+ return None
+
+
+def determine_socket(namespace: argparse.Namespace) -> SocketDesc:
+ # 1) socket from '--socket' argument
+ if len(namespace.socket) > 0:
+ return SocketDesc(namespace.socket[0], "--socket argument")
+
+ config_path = os.getenv(CONFIG_FILE_ENV_VAR)
+ socket_env = os.getenv(API_SOCK_ENV_VAR)
+
+ socket: Optional[SocketDesc] = None
+ # 2) socket from config file ('--config' argument)
+ if len(namespace.config) > 0:
+ socket = get_socket_from_config(namespace.config[0], False)
+ # 3) socket from config file (environment variable)
+ elif config_path:
+ socket = get_socket_from_config(Path(config_path), False)
+ # 4) socket from environment variable
+ elif socket_env:
+ socket = SocketDesc(socket_env, f'Environment variable "{API_SOCK_ENV_VAR}"')
+ # 5) socket from config file (default config file constant)
+ else:
+ socket = get_socket_from_config(DEFAULT_MANAGER_CONFIG_FILE, True)
+
+ if socket:
+ return socket
+ # 6) socket default
+ return SocketDesc(DEFAULT_MANAGER_API_SOCK, f'Default value "{DEFAULT_MANAGER_API_SOCK}"')
+
+
+class CommandArgs:
+ def __init__(self, namespace: argparse.Namespace, parser: argparse.ArgumentParser) -> None:
+ self.namespace = namespace
+ self.parser = parser
+ self.subparser: argparse.ArgumentParser = namespace.subparser
+ self.command: Type["Command"] = namespace.command
+
+ self.socket: SocketDesc = determine_socket(namespace)
+
+
+class Command(ABC):
+ @staticmethod
+ @abstractmethod
+ def register_args_subparser(
+ subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
+ ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
+ raise NotImplementedError()
+
+ @abstractmethod
+ def __init__(self, namespace: argparse.Namespace) -> None: # pylint: disable=[unused-argument]
+ super().__init__()
+
+ @abstractmethod
+ def run(self, args: CommandArgs) -> None:
+ raise NotImplementedError()
+
+ @staticmethod
+ @abstractmethod
+ def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
+ raise NotImplementedError()
diff --git a/python/knot_resolver/client/commands/cache.py b/python/knot_resolver/client/commands/cache.py
new file mode 100644
index 00000000..60417eec
--- /dev/null
+++ b/python/knot_resolver/client/commands/cache.py
@@ -0,0 +1,123 @@
+import argparse
+import sys
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command
+from knot_resolver.datamodel.cache_schema import CacheClearRPCSchema
+from knot_resolver.utils.modeling.exceptions import AggregateDataValidationError, DataValidationError
+from knot_resolver.utils.modeling.parsing import DataFormat, parse_json
+from knot_resolver.utils.requests import request
+
+
+class CacheOperations(Enum):
+ CLEAR = 0
+
+
+@register_command
+class CacheCommand(Command):
+ def __init__(self, namespace: argparse.Namespace) -> None:
+ super().__init__(namespace)
+ self.operation: Optional[CacheOperations] = namespace.operation if hasattr(namespace, "operation") else None
+ self.output_format: DataFormat = (
+ namespace.output_format if hasattr(namespace, "output_format") else DataFormat.YAML
+ )
+
+ # CLEAR operation
+ self.clear_dict: Dict[str, Any] = {}
+ if hasattr(namespace, "exact_name"):
+ self.clear_dict["exact-name"] = namespace.exact_name
+ if hasattr(namespace, "name"):
+ self.clear_dict["name"] = namespace.name
+ if hasattr(namespace, "rr_type"):
+ self.clear_dict["rr-type"] = namespace.rr_type
+ if hasattr(namespace, "chunk_size"):
+ self.clear_dict["chunk-size"] = namespace.chunk_size
+
+ @staticmethod
+ def register_args_subparser(
+ subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
+ ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
+ cache_parser = subparser.add_parser("cache", help="Performs operations on the cache of the running resolver.")
+
+ config_subparsers = cache_parser.add_subparsers(help="operation type")
+
+ # 'clear' operation
+ clear_subparser = config_subparsers.add_parser(
+ "clear", help="Purge cache records that match specified criteria."
+ )
+ clear_subparser.set_defaults(operation=CacheOperations.CLEAR, exact_name=False)
+ clear_subparser.add_argument(
+ "--exact-name",
+ help="If set, only records with the same name are purged.",
+ action="store_true",
+ dest="exact_name",
+ )
+ clear_subparser.add_argument(
+ "--rr-type",
+ help="Optional, the resource record type to purge. It is supported only with the '--exact-name' flag set.",
+ action="store",
+ type=str,
+ )
+ clear_subparser.add_argument(
+ "--chunk-size",
+ help="Optional, the number of records to remove in one round; the default is 100."
+ " The purpose is not to block the resolver for long."
+ " The resolver repeats the cache clearing after one millisecond until all matching data is cleared.",
+ action="store",
+ type=int,
+ default=100,
+ )
+ clear_subparser.add_argument(
+ "name",
+ type=str,
+ nargs="?",
+ help="Optional, subtree name to purge; if omitted, the entire cache is purged (and all other parameters are ignored).",
+ default=None,
+ )
+
+ output_format = clear_subparser.add_mutually_exclusive_group()
+ output_format_default = DataFormat.YAML
+ output_format.add_argument(
+ "--json",
+ help="Set JSON as the output format.",
+ const=DataFormat.JSON,
+ action="store_const",
+ dest="output_format",
+ default=output_format_default,
+ )
+ output_format.add_argument(
+ "--yaml",
+ help="Set YAML as the output format. YAML is the default.",
+ const=DataFormat.YAML,
+ action="store_const",
+ dest="output_format",
+ default=output_format_default,
+ )
+
+ return cache_parser, CacheCommand
+
+ @staticmethod
+ def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
+ return {}
+
+ def run(self, args: CommandArgs) -> None:
+ if not self.operation:
+ args.subparser.print_help()
+ sys.exit()
+
+ if self.operation == CacheOperations.CLEAR:
+ try:
+ validated = CacheClearRPCSchema(self.clear_dict)
+ except (AggregateDataValidationError, DataValidationError) as e:
+ print(e, file=sys.stderr)
+ sys.exit(1)
+
+ body: str = DataFormat.JSON.dict_dump(validated.get_unparsed_data())
+ response = request(args.socket, "POST", "cache/clear", body)
+ body_dict = parse_json(response.body)
+
+ if response.status != 200:
+ print(response, file=sys.stderr)
+ sys.exit(1)
+ print(self.output_format.dict_dump(body_dict, indent=4))
diff --git a/python/knot_resolver/client/commands/completion.py b/python/knot_resolver/client/commands/completion.py
new file mode 100644
index 00000000..05fdded8
--- /dev/null
+++ b/python/knot_resolver/client/commands/completion.py
@@ -0,0 +1,95 @@
+import argparse
+from enum import Enum
+from typing import List, Tuple, Type
+
+from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command
+
+
+class Shells(Enum):
+ BASH = 0
+ FISH = 1
+
+
+@register_command
+class CompletionCommand(Command):
+ def __init__(self, namespace: argparse.Namespace) -> None:
+ super().__init__(namespace)
+ self.shell: Shells = namespace.shell
+ self.space = namespace.space
+ self.comp_args: List[str] = namespace.comp_args
+
+ if self.space:
+ self.comp_args.append("")
+
+ @staticmethod
+ def register_args_subparser(
+ subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
+ ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
+ completion = subparser.add_parser("completion", help="commands auto-completion")
+ completion.add_argument(
+ "--space",
+ help="space after last word, returns all possible folowing options",
+ dest="space",
+ action="store_true",
+ default=False,
+ )
+ completion.add_argument(
+ "comp_args",
+ type=str,
+ help="arguments to complete",
+ nargs="*",
+ )
+
+ shells_dest = "shell"
+ shells = completion.add_mutually_exclusive_group()
+ shells.add_argument("--bash", action="store_const", dest=shells_dest, const=Shells.BASH, default=Shells.BASH)
+ shells.add_argument("--fish", action="store_const", dest=shells_dest, const=Shells.FISH)
+
+ return completion, CompletionCommand
+
+ @staticmethod
+ def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
+ words: CompWords = {}
+ # for action in parser._actions:
+ # for opt in action.option_strings:
+ # words[opt] = action.help
+ # return words
+ return words
+
+ def run(self, args: CommandArgs) -> None:
+ pass
+ # subparsers = args.parser._subparsers
+ # words: CompWords = {}
+
+ # if subparsers:
+ # words = parser_words(subparsers._actions)
+
+ # uargs = iter(self.comp_args)
+ # for uarg in uargs:
+ # subparser = subparser_by_name(uarg, subparsers._actions) # pylint: disable=W0212
+
+ # if subparser:
+ # cmd: Command = subparser_command(subparser)
+ # subparser_args = self.comp_args[self.comp_args.index(uarg) + 1 :]
+ # if subparser_args:
+ # words = cmd.completion(subparser_args, subparser)
+ # break
+ # elif uarg in ["-s", "--socket"]:
+ # # if arg is socket config, skip next arg
+ # next(uargs)
+ # continue
+ # elif uarg in words:
+ # # uarg is a valid arg, continue
+ # continue
+ # else:
+ # raise ValueError(f"unknown argument: {uarg}")
+
+ # # print completion words
+ # # based on required bash/fish shell format
+ # if self.shell == Shells.BASH:
+ # print(" ".join(words))
+ # elif self.shell == Shells.FISH:
+ # # TODO: FISH completion implementation
+ # pass
+ # else:
+ # raise ValueError(f"unexpected value of {Shells}: {self.shell}")
diff --git a/python/knot_resolver/client/commands/config.py b/python/knot_resolver/client/commands/config.py
new file mode 100644
index 00000000..add17272
--- /dev/null
+++ b/python/knot_resolver/client/commands/config.py
@@ -0,0 +1,222 @@
+import argparse
+import sys
+from enum import Enum
+from typing import List, Optional, Tuple, Type
+
+from typing_extensions import Literal
+
+from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command
+from knot_resolver.utils.modeling.parsing import DataFormat, parse_json, try_to_parse
+from knot_resolver.utils.requests import request
+
+
+class Operations(Enum):
+ SET = 0
+ DELETE = 1
+ GET = 2
+
+
+def operation_to_method(operation: Operations) -> Literal["PUT", "GET", "DELETE"]:
+ if operation == Operations.SET:
+ return "PUT"
+ elif operation == Operations.DELETE:
+ return "DELETE"
+ return "GET"
+
+
+# def _properties_words(props: Dict[str, Any]) -> CompWords:
+# words: CompWords = {}
+# for name, prop in props.items():
+# words[name] = prop["description"] if "description" in prop else None
+# return words
+
+
+# def _path_comp_words(node: str, nodes: List[str], props: Dict[str, Any]) -> CompWords:
+# i = nodes.index(node)
+# ln = len(nodes[i:])
+
+# # if node is last in path, return all possible words on this level
+# if ln == 1:
+# return _properties_words(props)
+# # if node is valid
+# elif node in props:
+# node_schema = props[node]
+
+# if "anyOf" in node_schema:
+# for item in node_schema["anyOf"]:
+# print(item)
+
+# elif "type" not in node_schema:
+# pass
+
+# elif node_schema["type"] == "array":
+# if ln > 2:
+# # skip index for item in array
+# return _path_comp_words(nodes[i + 2], nodes, node_schema["items"]["properties"])
+# if "enum" in node_schema["items"]:
+# print(node_schema["items"]["enum"])
+# return {"0": "first array item", "-": "last array item"}
+# elif node_schema["type"] == "object":
+# if "additionalProperties" in node_schema:
+# print(node_schema)
+# return _path_comp_words(nodes[i + 1], nodes, node_schema["properties"])
+# return {}
+
+# # arrays/lists must be handled separately
+# if node_schema["type"] == "array":
+# if ln > 2:
+# # skip index for item in array
+# return _path_comp_words(nodes[i + 2], nodes, node_schema["items"]["properties"])
+# return {"0": "first array item", "-": "last array item"}
+# return _path_comp_words(nodes[i + 1], nodes, node_schema["properties"])
+# else:
+# # if node is not last or valid, value error
+# raise ValueError(f"unknown config path node: {node}")
+
+
+@register_command
+class ConfigCommand(Command):
+ def __init__(self, namespace: argparse.Namespace) -> None:
+ super().__init__(namespace)
+ self.path: str = str(namespace.path) if hasattr(namespace, "path") else ""
+ self.format: DataFormat = namespace.format if hasattr(namespace, "format") else DataFormat.JSON
+ self.operation: Optional[Operations] = namespace.operation if hasattr(namespace, "operation") else None
+ self.file: Optional[str] = namespace.file if hasattr(namespace, "file") else None
+
+ @staticmethod
+ def register_args_subparser(
+ subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
+ ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
+ config = subparser.add_parser("config", help="Performs operations on the running resolver's configuration.")
+ path_help = "Optional, path (JSON pointer, RFC6901) to the configuration resources. By default, the entire configuration is selected."
+
+ config_subparsers = config.add_subparsers(help="operation type")
+
+ # GET operation
+ get = config_subparsers.add_parser("get", help="Get current configuration from the resolver.")
+ get.set_defaults(operation=Operations.GET, format=DataFormat.YAML)
+
+ get.add_argument(
+ "-p",
+ "--path",
+ help=path_help,
+ action="store",
+ type=str,
+ default="",
+ )
+ get.add_argument(
+ "file",
+ help="Optional, path to the file where to save exported configuration data. If not specified, data will be printed.",
+ type=str,
+ nargs="?",
+ )
+
+ get_formats = get.add_mutually_exclusive_group()
+ get_formats.add_argument(
+ "--json",
+ help="Get configuration data in JSON format.",
+ const=DataFormat.JSON,
+ action="store_const",
+ dest="format",
+ )
+ get_formats.add_argument(
+ "--yaml",
+ help="Get configuration data in YAML format, default.",
+ const=DataFormat.YAML,
+ action="store_const",
+ dest="format",
+ )
+
+ # SET operation
+ set = config_subparsers.add_parser("set", help="Set new configuration for the resolver.")
+ set.set_defaults(operation=Operations.SET)
+
+ set.add_argument(
+ "-p",
+ "--path",
+ help=path_help,
+ action="store",
+ type=str,
+ default="",
+ )
+
+ value_or_file = set.add_mutually_exclusive_group()
+ value_or_file.add_argument(
+ "file",
+ help="Optional, path to file with new configuraion.",
+ type=str,
+ nargs="?",
+ )
+ value_or_file.add_argument(
+ "value",
+ help="Optional, new configuration value.",
+ type=str,
+ nargs="?",
+ )
+
+ # DELETE operation
+ delete = config_subparsers.add_parser(
+ "delete", help="Delete given configuration property or list item at the given index."
+ )
+ delete.set_defaults(operation=Operations.DELETE)
+ delete.add_argument(
+ "-p",
+ "--path",
+ help=path_help,
+ action="store",
+ type=str,
+ default="",
+ )
+
+ return config, ConfigCommand
+
+ @staticmethod
+ def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
+ # words = parser_words(parser._actions) # pylint: disable=W0212
+
+ # for arg in args:
+ # if arg in words:
+ # continue
+ # elif arg.startswith("-"):
+ # return words
+ # elif arg == args[-1]:
+ # config_path = arg[1:].split("/") if arg.startswith("/") else arg.split("/")
+ # schema_props: Dict[str, Any] = KresConfig.json_schema()["properties"]
+ # return _path_comp_words(config_path[0], config_path, schema_props)
+ # else:
+ # break
+ return {}
+
+ def run(self, args: CommandArgs) -> None:
+ if not self.operation:
+ args.subparser.print_help()
+ sys.exit()
+
+ new_config = None
+ path = f"v1/config{self.path}"
+ method = operation_to_method(self.operation)
+
+ if self.operation == Operations.SET:
+ if self.file:
+ try:
+ with open(self.file, "r") as f:
+ new_config = f.read()
+ except FileNotFoundError:
+ new_config = self.file
+ else:
+ # use STDIN also when file is not specified
+ new_config = input("Type new configuration: ")
+
+ body = DataFormat.JSON.dict_dump(try_to_parse(new_config)) if new_config else None
+ response = request(args.socket, method, path, body)
+
+ if response.status != 200:
+ print(response, file=sys.stderr)
+ sys.exit(1)
+
+ if self.operation == Operations.GET and self.file:
+ with open(self.file, "w") as f:
+ f.write(self.format.dict_dump(parse_json(response.body), indent=4))
+ print(f"saved to: {self.file}")
+ elif response.body:
+ print(self.format.dict_dump(parse_json(response.body), indent=4))
diff --git a/python/knot_resolver/client/commands/convert.py b/python/knot_resolver/client/commands/convert.py
new file mode 100644
index 00000000..a25c5cd9
--- /dev/null
+++ b/python/knot_resolver/client/commands/convert.py
@@ -0,0 +1,85 @@
+import argparse
+import sys
+from pathlib import Path
+from typing import List, Optional, Tuple, Type
+
+from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command
+from knot_resolver.datamodel import KresConfig
+from knot_resolver.datamodel.globals import (
+ Context,
+ reset_global_validation_context,
+ set_global_validation_context,
+)
+from knot_resolver.utils.modeling import try_to_parse
+from knot_resolver.utils.modeling.exceptions import DataParsingError, DataValidationError
+
+
+@register_command
+class ConvertCommand(Command):
+ def __init__(self, namespace: argparse.Namespace) -> None:
+ super().__init__(namespace)
+ self.input_file: str = namespace.input_file
+ self.output_file: Optional[str] = namespace.output_file
+ self.strict: bool = namespace.strict
+ self.type: str = namespace.type
+
+ @staticmethod
+ def register_args_subparser(
+ subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
+ ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
+ convert = subparser.add_parser("convert", help="Converts JSON or YAML configuration to Lua script.")
+ convert.set_defaults(strict=True)
+ convert.add_argument(
+ "--no-strict",
+ help="Ignore strict rules during validation, e.g. path/file existence.",
+ action="store_false",
+ dest="strict",
+ )
+ convert.add_argument(
+ "--type", help="The type of Lua script to generate", choices=["worker", "policy-loader"], default="worker"
+ )
+ convert.add_argument(
+ "input_file",
+ type=str,
+ help="File with configuration in YAML or JSON format.",
+ )
+
+ convert.add_argument(
+ "output_file",
+ type=str,
+ nargs="?",
+ help="Optional, output file for converted configuration in Lua script. If not specified, converted configuration is printed.",
+ default=None,
+ )
+
+ return convert, ConvertCommand
+
+ @staticmethod
+ def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
+ return {}
+
+ def run(self, args: CommandArgs) -> None:
+ with open(self.input_file, "r") as f:
+ data = f.read()
+
+ try:
+ parsed = try_to_parse(data)
+ set_global_validation_context(Context(Path(Path(self.input_file).parent), self.strict))
+
+ if self.type == "worker":
+ lua = KresConfig(parsed).render_lua()
+ elif self.type == "policy-loader":
+ lua = KresConfig(parsed).render_lua_policy()
+ else:
+ raise ValueError(f"Invalid self.type={self.type}")
+
+ reset_global_validation_context()
+ except (DataParsingError, DataValidationError) as e:
+ print(e, file=sys.stderr)
+ sys.exit(1)
+
+ if self.output_file:
+ with open(self.output_file, "w") as f:
+ f.write(lua)
+ else:
+ print(lua)
diff --git a/python/knot_resolver/client/commands/help.py b/python/knot_resolver/client/commands/help.py
new file mode 100644
index 00000000..87306c2a
--- /dev/null
+++ b/python/knot_resolver/client/commands/help.py
@@ -0,0 +1,24 @@
+import argparse
+from typing import List, Tuple, Type
+
+from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command
+
+
+@register_command
+class HelpCommand(Command):
+ def __init__(self, namespace: argparse.Namespace) -> None:
+ super().__init__(namespace)
+
+ def run(self, args: CommandArgs) -> None:
+ args.parser.print_help()
+
+ @staticmethod
+ def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
+ return {}
+
+ @staticmethod
+ def register_args_subparser(
+ subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
+ ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
+ stop = subparser.add_parser("help", help="show this help message and exit")
+ return stop, HelpCommand
diff --git a/python/knot_resolver/client/commands/metrics.py b/python/knot_resolver/client/commands/metrics.py
new file mode 100644
index 00000000..058cad8b
--- /dev/null
+++ b/python/knot_resolver/client/commands/metrics.py
@@ -0,0 +1,67 @@
+import argparse
+import sys
+from typing import List, Optional, Tuple, Type
+
+from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command
+from knot_resolver.utils.modeling.parsing import DataFormat, parse_json
+from knot_resolver.utils.requests import request
+
+
+@register_command
+class MetricsCommand(Command):
+ def __init__(self, namespace: argparse.Namespace) -> None:
+ self.file: Optional[str] = namespace.file
+ self.prometheus: bool = namespace.prometheus
+
+ super().__init__(namespace)
+
+ @staticmethod
+ def register_args_subparser(
+ subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
+ ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
+ metrics = subparser.add_parser(
+ "metrics",
+ help="Get aggregated metrics from the running resolver in JSON format (default) or optionally in Prometheus format."
+ "\nThe 'prometheus-client' Python package needs to be installed if you wish to use the Prometheus format."
+ "\nRequires a connection to the management HTTP API.",
+ )
+
+ metrics.add_argument(
+ "--prometheus",
+ help="Get metrics in Prometheus format if dependencies are met in the resolver.",
+ action="store_true",
+ default=False,
+ )
+
+ metrics.add_argument(
+ "file",
+ help="Optional. The file into which metrics will be exported."
+ "\nIf not specified, the metrics are printed into stdout.",
+ nargs="?",
+ default=None,
+ )
+ return metrics, MetricsCommand
+
+ @staticmethod
+ def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
+ return {}
+
+ def run(self, args: CommandArgs) -> None:
+ response = request(args.socket, "GET", "metrics/prometheus" if self.prometheus else "metrics/json")
+
+ if response.status == 200:
+ if self.prometheus:
+ metrics = response.body
+ else:
+ metrics = DataFormat.JSON.dict_dump(parse_json(response.body), indent=4)
+
+ if self.file:
+ with open(self.file, "w") as f:
+ f.write(metrics)
+ else:
+ print(metrics)
+ else:
+ print(response, file=sys.stderr)
+ if self.prometheus and response.status == 404:
+ print("Prometheus is unavailable due to missing optional dependencies", file=sys.stderr)
+ sys.exit(1)
diff --git a/python/knot_resolver/client/commands/reload.py b/python/knot_resolver/client/commands/reload.py
new file mode 100644
index 00000000..c1350fc5
--- /dev/null
+++ b/python/knot_resolver/client/commands/reload.py
@@ -0,0 +1,36 @@
+import argparse
+import sys
+from typing import List, Tuple, Type
+
+from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command
+from knot_resolver.utils.requests import request
+
+
+@register_command
+class ReloadCommand(Command):
+ def __init__(self, namespace: argparse.Namespace) -> None:
+ super().__init__(namespace)
+
+ @staticmethod
+ def register_args_subparser(
+ subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
+ ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
+ reload = subparser.add_parser(
+ "reload",
+ help="Tells the resolver to reload YAML configuration file."
+ " Old processes are replaced by new ones (with updated configuration) using rolling restarts."
+ " So there will be no DNS service unavailability during reload operation.",
+ )
+
+ return reload, ReloadCommand
+
+ @staticmethod
+ def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
+ return {}
+
+ def run(self, args: CommandArgs) -> None:
+ response = request(args.socket, "POST", "reload")
+
+ if response.status != 200:
+ print(response, file=sys.stderr)
+ sys.exit(1)
diff --git a/python/knot_resolver/client/commands/schema.py b/python/knot_resolver/client/commands/schema.py
new file mode 100644
index 00000000..f5538424
--- /dev/null
+++ b/python/knot_resolver/client/commands/schema.py
@@ -0,0 +1,55 @@
+import argparse
+import json
+import sys
+from typing import List, Optional, Tuple, Type
+
+from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command
+from knot_resolver.datamodel.config_schema import KresConfig
+from knot_resolver.utils.requests import request
+
+
+@register_command
+class SchemaCommand(Command):
+ def __init__(self, namespace: argparse.Namespace) -> None:
+ super().__init__(namespace)
+ self.live: bool = namespace.live
+ self.file: Optional[str] = namespace.file
+
+ @staticmethod
+ def register_args_subparser(
+ subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
+ ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
+ schema = subparser.add_parser(
+ "schema", help="Shows JSON-schema repersentation of the Knot Resolver's configuration."
+ )
+ schema.add_argument(
+ "-l",
+ "--live",
+ help="Get configuration JSON-schema from the running resolver. Requires connection to the management API.",
+ action="store_true",
+ default=False,
+ )
+ schema.add_argument("file", help="Optional, file where to export JSON-schema.", nargs="?", default=None)
+
+ return schema, SchemaCommand
+
+ @staticmethod
+ def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
+ return {}
+ # return parser_words(parser._actions) # pylint: disable=W0212
+
+ def run(self, args: CommandArgs) -> None:
+ if self.live:
+ response = request(args.socket, "GET", "schema")
+ if response.status != 200:
+ print(response, file=sys.stderr)
+ sys.exit(1)
+ schema = response.body
+ else:
+ schema = json.dumps(KresConfig.json_schema(), indent=4)
+
+ if self.file:
+ with open(self.file, "w") as f:
+ f.write(schema)
+ else:
+ print(schema)
diff --git a/python/knot_resolver/client/commands/stop.py b/python/knot_resolver/client/commands/stop.py
new file mode 100644
index 00000000..35baf36c
--- /dev/null
+++ b/python/knot_resolver/client/commands/stop.py
@@ -0,0 +1,32 @@
+import argparse
+import sys
+from typing import List, Tuple, Type
+
+from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command
+from knot_resolver.utils.requests import request
+
+
+@register_command
+class StopCommand(Command):
+ def __init__(self, namespace: argparse.Namespace) -> None:
+ super().__init__(namespace)
+
+ @staticmethod
+ def register_args_subparser(
+ subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
+ ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
+ stop = subparser.add_parser(
+ "stop", help="Tells the resolver to shutdown everthing. No process will run after this command."
+ )
+ return stop, StopCommand
+
+ def run(self, args: CommandArgs) -> None:
+ response = request(args.socket, "POST", "stop")
+
+ if response.status != 200:
+ print(response, file=sys.stderr)
+ sys.exit(1)
+
+ @staticmethod
+ def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
+ return {}
diff --git a/python/knot_resolver/client/commands/validate.py b/python/knot_resolver/client/commands/validate.py
new file mode 100644
index 00000000..267bf562
--- /dev/null
+++ b/python/knot_resolver/client/commands/validate.py
@@ -0,0 +1,63 @@
+import argparse
+import sys
+from pathlib import Path
+from typing import List, Tuple, Type
+
+from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command
+from knot_resolver.datamodel import KresConfig
+from knot_resolver.datamodel.globals import (
+ Context,
+ reset_global_validation_context,
+ set_global_validation_context,
+)
+from knot_resolver.utils.modeling import try_to_parse
+from knot_resolver.utils.modeling.exceptions import DataParsingError, DataValidationError
+
+
+@register_command
+class ValidateCommand(Command):
+ def __init__(self, namespace: argparse.Namespace) -> None:
+ super().__init__(namespace)
+ self.input_file: str = namespace.input_file
+ self.strict: bool = namespace.strict
+
+ @staticmethod
+ def register_args_subparser(
+ subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
+ ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
+ validate = subparser.add_parser("validate", help="Validates configuration in JSON or YAML format.")
+ validate.set_defaults(strict=True)
+ validate.add_argument(
+ "--no-strict",
+ help="Ignore strict rules during validation, e.g. path/file existence.",
+ action="store_false",
+ dest="strict",
+ )
+ validate.add_argument(
+ "input_file",
+ type=str,
+ nargs="?",
+ help="File with configuration in YAML or JSON format.",
+ default=None,
+ )
+
+ return validate, ValidateCommand
+
+ @staticmethod
+ def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
+ return {}
+
+ def run(self, args: CommandArgs) -> None:
+ if self.input_file:
+ with open(self.input_file, "r") as f:
+ data = f.read()
+ else:
+ data = input("Type configuration to validate: ")
+
+ try:
+ set_global_validation_context(Context(Path(self.input_file).parent, self.strict))
+ KresConfig(try_to_parse(data))
+ reset_global_validation_context()
+ except (DataParsingError, DataValidationError) as e:
+ print(e, file=sys.stderr)
+ sys.exit(1)
diff --git a/python/knot_resolver/client/main.py b/python/knot_resolver/client/main.py
new file mode 100644
index 00000000..933da54d
--- /dev/null
+++ b/python/knot_resolver/client/main.py
@@ -0,0 +1,69 @@
+import argparse
+import importlib
+import os
+
+from .command import install_commands_parsers
+from .client import KresClient, KRES_CLIENT_NAME
+
+
+def auto_import_commands() -> None:
+ prefix = f"{'.'.join(__name__.split('.')[:-1])}.commands."
+ for module_name in os.listdir(os.path.dirname(__file__) + "/commands"):
+ if module_name[-3:] != ".py":
+ continue
+ importlib.import_module(f"{prefix}{module_name[:-3]}")
+
+
+def create_main_argument_parser() -> argparse.ArgumentParser:
+ parser = argparse.ArgumentParser(
+ KRES_CLIENT_NAME,
+ description="Knot Resolver command-line utility that serves as a client for communicating with the Knot Resolver management API."
+ " The utility also provides tools to work with the resolver's declarative configuration (validate, convert, ...).",
+ )
+ # parser.add_argument(
+ # "-i",
+ # "--interactive",
+ # action="store_true",
+ # help="Use the utility in interactive mode.",
+ # default=False,
+ # required=False,
+ # )
+ config_or_socket = parser.add_mutually_exclusive_group()
+ config_or_socket.add_argument(
+ "-s",
+ "--socket",
+ action="store",
+ type=str,
+ help="Optional, path to the resolver's management API, unix-domain socket, or network interface."
+ " Cannot be used together with '--config'.",
+ default=[],
+ nargs=1,
+ required=False,
+ )
+ config_or_socket.add_argument(
+ "-c",
+ "--config",
+ action="store",
+ type=str,
+ help="Optional, path to the resolver's declarative configuration to retrieve the management API configuration."
+ " Cannot be used together with '--socket'.",
+ default=[],
+ nargs=1,
+ required=False,
+ )
+ return parser
+
+
+def main() -> None:
+ auto_import_commands()
+ parser = create_main_argument_parser()
+ install_commands_parsers(parser)
+
+ namespace = parser.parse_args()
+ client = KresClient(namespace, parser)
+ client.execute()
+
+ # if namespace.interactive or len(vars(namespace)) == 2:
+ # client.interactive()
+ # else:
+ # client.execute()
diff --git a/python/knot_resolver/compat/__init__.py b/python/knot_resolver/compat/__init__.py
new file mode 100644
index 00000000..53993f6c
--- /dev/null
+++ b/python/knot_resolver/compat/__init__.py
@@ -0,0 +1,3 @@
+from . import asyncio
+
+__all__ = ["asyncio"]
diff --git a/python/knot_resolver/compat/asyncio.py b/python/knot_resolver/compat/asyncio.py
new file mode 100644
index 00000000..9e10e6c6
--- /dev/null
+++ b/python/knot_resolver/compat/asyncio.py
@@ -0,0 +1,128 @@
+# We disable pylint checks, because it can't find methods in newer Python versions.
+#
+# pylint: disable=no-member
+
+import asyncio
+import functools
+import logging
+import sys
+from asyncio import AbstractEventLoop, coroutines, events, tasks
+from typing import Any, Awaitable, Callable, Coroutine, Optional, TypeVar
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T")
+
+
+async def to_thread(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
+ # version 3.9 and higher, call directly
+ if sys.version_info.major >= 3 and sys.version_info.minor >= 9:
+ return await asyncio.to_thread(func, *args, **kwargs) # type: ignore[attr-defined]
+
+ # earlier versions, run with default executor
+ else:
+ loop = asyncio.get_event_loop()
+ pfunc = functools.partial(func, *args, **kwargs)
+ res = await loop.run_in_executor(None, pfunc)
+ return res
+
+
+def async_in_a_thread(func: Callable[..., T]) -> Callable[..., Coroutine[None, None, T]]:
+ async def wrapper(*args: Any, **kwargs: Any) -> T:
+ return await to_thread(func, *args, **kwargs)
+
+ return wrapper
+
+
+def create_task(coro: Awaitable[T], name: Optional[str] = None) -> "asyncio.Task[T]":
+ # version 3.8 and higher, call directly
+ if sys.version_info.major >= 3 and sys.version_info.minor >= 8:
+ # pylint: disable=unexpected-keyword-arg
+ return asyncio.create_task(coro, name=name) # type: ignore[attr-defined,arg-type,call-arg]
+
+ # version 3.7 and higher, call directly without the name argument
+ if sys.version_info.major >= 3 and sys.version_info.minor >= 7:
+ return asyncio.create_task(coro) # type: ignore[attr-defined,arg-type]
+
+ # earlier versions, use older function
+ else:
+ return asyncio.ensure_future(coro)
+
+
+def is_event_loop_running() -> bool:
+ loop = events._get_running_loop() # pylint: disable=protected-access
+ return loop is not None and loop.is_running()
+
+
+def run(coro: Awaitable[T], debug: Optional[bool] = None) -> T:
+ # Adapted version of this:
+ # https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py#L8
+
+ # version 3.7 and higher, call directly
+ # disabled due to incompatibilities
+ if sys.version_info.major >= 3 and sys.version_info.minor >= 7:
+ return asyncio.run(coro, debug=debug) # type: ignore[attr-defined,arg-type]
+
+ # earlier versions, use backported version of the function
+ if events._get_running_loop() is not None: # pylint: disable=protected-access
+ raise RuntimeError("asyncio.run() cannot be called from a running event loop")
+
+ if not coroutines.iscoroutine(coro):
+ raise ValueError(f"a coroutine was expected, got {repr(coro)}")
+
+ loop = events.new_event_loop()
+ try:
+ events.set_event_loop(loop)
+ if debug is not None:
+ loop.set_debug(debug)
+ return loop.run_until_complete(coro)
+ finally:
+ try:
+ _cancel_all_tasks(loop)
+ loop.run_until_complete(loop.shutdown_asyncgens())
+ if hasattr(loop, "shutdown_default_executor"):
+ loop.run_until_complete(loop.shutdown_default_executor()) # type: ignore[attr-defined]
+ finally:
+ events.set_event_loop(None)
+ loop.close()
+
+
+def _cancel_all_tasks(loop: AbstractEventLoop) -> None:
+ # Backported from:
+ # https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py#L55-L74
+ #
+ to_cancel = tasks.all_tasks(loop)
+ if not to_cancel:
+ return
+
+ for task in to_cancel:
+ task.cancel()
+
+ if sys.version_info.minor >= 7:
+ # since 3.7, the loop argument is implicitly the running loop
+ # since 3.10, the loop argument is removed
+ loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True))
+ else:
+ loop.run_until_complete(tasks.gather(*to_cancel, loop=loop, return_exceptions=True)) # type: ignore[call-overload]
+
+ for task in to_cancel:
+ if task.cancelled():
+ continue
+ if task.exception() is not None:
+ loop.call_exception_handler(
+ {
+ "message": "unhandled exception during asyncio.run() shutdown",
+ "exception": task.exception(),
+ "task": task,
+ }
+ )
+
+
+def add_async_signal_handler(signal: int, callback: Callable[[], Coroutine[Any, Any, None]]) -> None:
+ loop = asyncio.get_event_loop()
+ loop.add_signal_handler(signal, lambda: create_task(callback()))
+
+
+def remove_signal_handler(signal: int) -> bool:
+ loop = asyncio.get_event_loop()
+ return loop.remove_signal_handler(signal)
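
A brief usage sketch for the compatibility helpers above (outside this diff; the blocking_io name is made up). It shows that to_thread() and run() can be called uniformly regardless of the interpreter version:

from knot_resolver.compat import asyncio as compat


def blocking_io() -> str:
    # an ordinary blocking call, offloaded to a worker thread by to_thread()
    return "done"


async def main() -> str:
    return await compat.to_thread(blocking_io)


# run() uses asyncio.run() on Python >= 3.7 and the backported loop setup otherwise
print(compat.run(main()))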
diff --git a/python/knot_resolver/controller/__init__.py b/python/knot_resolver/controller/__init__.py
new file mode 100644
index 00000000..2398347e
--- /dev/null
+++ b/python/knot_resolver/controller/__init__.py
@@ -0,0 +1,94 @@
+"""
+This file contains autodetection logic for the available subprocess controllers. Because we have to catch
+import errors, the imports are located in functions which are invoked at the end of this file.
+
+We supported multiple subprocess controllers during development, but everything has now converged on just
+supervisord. The interface remains, however, so that different controllers can be added in the future.
+"""
+
+# pylint: disable=import-outside-toplevel
+
+import asyncio
+import logging
+from typing import List, Optional
+
+from knot_resolver.datamodel.config_schema import KresConfig
+from knot_resolver.controller.interface import SubprocessController
+
+logger = logging.getLogger(__name__)
+
+"""
+List of all subprocess controllers that are available in order of priority.
+It is filled dynamically based on available modules that do not fail to import.
+"""
+_registered_controllers: List[SubprocessController] = []
+
+
+def try_supervisord():
+ """
+ Attempt to load supervisord controllers.
+ """
+ try:
+ from knot_resolver.controller.supervisord import SupervisordSubprocessController
+
+ _registered_controllers.append(SupervisordSubprocessController())
+ except ImportError:
+ logger.error("Failed to import modules related to supervisord service manager", exc_info=True)
+
+
+async def get_best_controller_implementation(config: KresConfig) -> SubprocessController:
+ logger.info("Starting service manager auto-selection...")
+
+ if len(_registered_controllers) == 0:
+ logger.error("No controllers are available! Did you install all dependencies?")
+ raise LookupError("No service managers available!")
+
+ # check all controllers concurrently
+ res = await asyncio.gather(*(cont.is_controller_available(config) for cont in _registered_controllers))
+ logger.info(
+ "Available subprocess controllers are %s",
+ str(tuple((str(c) for r, c in zip(res, _registered_controllers) if r))),
+ )
+
+ # take the first one on the list which is available
+ for avail, controller in zip(res, _registered_controllers):
+ if avail:
+ logger.info("Selected controller '%s'", str(controller))
+ return controller
+
+ # or fail
+ raise LookupError("Can't find any available service manager!")
+
+
+def list_controller_names() -> List[str]:
+ """
+ Returns a list of names of registered controllers. The listed controllers are not necessarily functional.
+ """
+
+ return [str(controller) for controller in sorted(_registered_controllers, key=str)]
+
+
+async def get_controller_by_name(config: KresConfig, name: str) -> SubprocessController:
+ logger.debug("Subprocess controller selected manualy by the user, testing feasibility...")
+
+ controller: Optional[SubprocessController] = None
+ for c in sorted(_registered_controllers, key=str):
+ if str(c).startswith(name):
+ if str(c) != name:
+ logger.debug("Assuming '%s' is a shortcut for '%s'", name, str(c))
+ controller = c
+ break
+
+ if controller is None:
+ logger.error("Subprocess controller with name '%s' was not found", name)
+ raise LookupError(f"No subprocess controller named '{name}' found")
+
+ if await controller.is_controller_available(config):
+ logger.info("Selected controller '%s'", str(controller))
+ return controller
+ else:
+ raise LookupError("The selected subprocess controller is not available for use on this system.")
+
+
+# run the imports on module load
+try_supervisord()
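
The selection helpers above are meant to be driven from the manager's startup code. A hedged sketch (outside this diff; how the validated KresConfig instance is obtained is left out):

import asyncio

from knot_resolver.controller import get_best_controller_implementation
from knot_resolver.datamodel.config_schema import KresConfig


async def start(config: KresConfig) -> None:
    # pick the highest-priority registered controller whose availability check passes
    controller = await get_best_controller_implementation(config)
    await controller.initialize_controller(config)


# 'config' is assumed to be an already-validated KresConfig instance:
# asyncio.run(start(config))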
diff --git a/python/knot_resolver/controller/interface.py b/python/knot_resolver/controller/interface.py
new file mode 100644
index 00000000..02bbaa50
--- /dev/null
+++ b/python/knot_resolver/controller/interface.py
@@ -0,0 +1,296 @@
+import asyncio
+import itertools
+import json
+import logging
+import struct
+import sys
+from abc import ABC, abstractmethod # pylint: disable=no-name-in-module
+from enum import Enum, auto
+from pathlib import Path
+from typing import Dict, Iterable, Optional, Type, TypeVar
+from weakref import WeakValueDictionary
+
+from knot_resolver.manager.constants import kresd_config_file, policy_loader_config_file
+from knot_resolver.datamodel.config_schema import KresConfig
+from knot_resolver.manager.exceptions import SubprocessControllerException
+from knot_resolver.controller.registered_workers import register_worker, unregister_worker
+from knot_resolver.utils.async_utils import writefile
+
+logger = logging.getLogger(__name__)
+
+
+class SubprocessType(Enum):
+ KRESD = auto()
+ POLICY_LOADER = auto()
+ GC = auto()
+
+
+class SubprocessStatus(Enum):
+ RUNNING = auto()
+ FATAL = auto()
+ EXITED = auto()
+ UNKNOWN = auto()
+
+
+T = TypeVar("T", bound="KresID")
+
+
+class KresID:
+ """
+ ID object used for identifying subprocesses.
+ """
+
+ _used: "Dict[SubprocessType, WeakValueDictionary[int, KresID]]" = {k: WeakValueDictionary() for k in SubprocessType}
+
+ @classmethod
+ def alloc(cls: Type[T], typ: SubprocessType) -> T:
+ # find free ID closest to zero
+ for i in itertools.count(start=0, step=1):
+ if i not in cls._used[typ]:
+ res = cls.new(typ, i)
+ return res
+
+ raise RuntimeError("Reached an end of an infinite loop. How?")
+
+ @classmethod
+ def new(cls: "Type[T]", typ: SubprocessType, n: int) -> "T":
+ if n in cls._used[typ]:
+ # Ignoring typing here, because I can't find a way how to make the _used dict
+ # typed based on subclass. I am not even sure that it's different between subclasses,
+ # it's probably still the same dict. But we don't really care about it
+ return cls._used[typ][n] # type: ignore
+ else:
+ val = cls(typ, n, _i_know_what_i_am_doing=True)
+ cls._used[typ][n] = val
+ return val
+
+ def __init__(self, typ: SubprocessType, n: int, _i_know_what_i_am_doing: bool = False):
+ if not _i_know_what_i_am_doing:
+ raise RuntimeError("Don't do this. You seem to have no idea what it does")
+
+ self._id = n
+ self._type = typ
+
+ @property
+ def subprocess_type(self) -> SubprocessType:
+ return self._type
+
+ def __repr__(self) -> str:
+ return f"KresID({self})"
+
+ def __hash__(self) -> int:
+ return self._id
+
+ def __eq__(self, o: object) -> bool:
+ if isinstance(o, KresID):
+ return self._type == o._type and self._id == o._id
+ return False
+
+ def __str__(self) -> str:
+ """
+ Returns string representation of the ID usable directly in the underlying service manager
+ """
+ raise NotImplementedError()
+
+ @staticmethod
+ def from_string(val: str) -> "KresID":
+ """
+ Inverse of __str__
+ """
+ raise NotImplementedError()
+
+ def __int__(self) -> int:
+ return self._id
+
+
+class Subprocess(ABC):
+ """
+ One SubprocessInstance corresponds to one manager's subprocess
+ """
+
+ def __init__(self, config: KresConfig, kresid: KresID) -> None:
+ self._id = kresid
+ self._config = config
+ self._registered_worker: bool = False
+
+ async def start(self, new_config: Optional[KresConfig] = None) -> None:
+ if new_config:
+ self._config = new_config
+
+ config_file: Optional[Path] = None
+ if self.type is SubprocessType.KRESD:
+ config_lua = self._config.render_lua()
+ config_file = kresd_config_file(self._config, self.id)
+ await writefile(config_file, config_lua)
+ elif self.type is SubprocessType.POLICY_LOADER:
+ config_lua = self._config.render_lua_policy()
+ config_file = policy_loader_config_file(self._config)
+ await writefile(config_file, config_lua)
+
+ try:
+ await self._start()
+ if self.type is SubprocessType.KRESD:
+ register_worker(self)
+ self._registered_worker = True
+ except SubprocessControllerException as e:
+ if config_file:
+ config_file.unlink()
+ raise e
+
+ async def apply_new_config(self, new_config: KresConfig) -> None:
+ self._config = new_config
+
+ # update config file
+ logger.debug(f"Writing config file for {self.id}")
+
+ config_file: Optional[Path] = None
+ if self.type is SubprocessType.KRESD:
+ config_lua = self._config.render_lua()
+ config_file = kresd_config_file(self._config, self.id)
+ await writefile(config_file, config_lua)
+ elif self.type is SubprocessType.POLICY_LOADER:
+ config_lua = self._config.render_lua_policy()
+ config_file = policy_loader_config_file(self._config)
+ await writefile(config_file, config_lua)
+
+ # update runtime status
+ logger.debug(f"Restarting {self.id}")
+ await self._restart()
+
+ async def stop(self) -> None:
+ if self._registered_worker:
+ unregister_worker(self)
+ await self._stop()
+ await self.cleanup()
+
+ async def cleanup(self) -> None:
+ """
+ Remove temporary files and all traces of this instance running. It is NOT SAFE to call this while
+ the kresd is running, because it will break automatic restarts (at the very least).
+ """
+
+ if self.type is SubprocessType.KRESD:
+ config_file = kresd_config_file(self._config, self.id)
+ config_file.unlink()
+ elif self.type is SubprocessType.POLICY_LOADER:
+ config_file = policy_loader_config_file(self._config)
+ config_file.unlink()
+
+ def __eq__(self, o: object) -> bool:
+ return isinstance(o, type(self)) and o.type == self.type and o.id == self.id
+
+ def __hash__(self) -> int:
+ return hash(type(self)) ^ hash(self.type) ^ hash(self.id)
+
+ @abstractmethod
+ async def _start(self) -> None:
+ pass
+
+ @abstractmethod
+ async def _stop(self) -> None:
+ pass
+
+ @abstractmethod
+ async def _restart(self) -> None:
+ pass
+
+ @abstractmethod
+ def status(self) -> SubprocessStatus:
+ pass
+
+ @property
+ def type(self) -> SubprocessType:
+ return self.id.subprocess_type
+
+ @property
+ def id(self) -> KresID:
+ return self._id
+
+ async def command(self, cmd: str) -> object:
+ if not self._registered_worker:
+ raise RuntimeError("the command cannot be sent to a process other than the kresd worker")
+
+ reader: asyncio.StreamReader
+ writer: Optional[asyncio.StreamWriter] = None
+
+ try:
+ reader, writer = await asyncio.open_unix_connection(f"./control/{int(self.id)}")
+
+ # drop prompt
+ _ = await reader.read(2)
+
+ # switch to JSON mode
+ writer.write("__json\n".encode("utf8"))
+
+ # write command
+ writer.write(cmd.encode("utf8"))
+ writer.write(b"\n")
+ await writer.drain()
+
+ # read result
+ (msg_len,) = struct.unpack(">I", await reader.read(4))
+ result_bytes = await reader.readexactly(msg_len)
+ return json.loads(result_bytes.decode("utf8"))
+
+ finally:
+ if writer is not None:
+ writer.close()
+
+ # StreamWriter.wait_closed() is only available since Python 3.7
+ if sys.version_info >= (3, 7):
+ await writer.wait_closed() # type: ignore
+
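+# A minimal usage sketch of the control-socket protocol above (hypothetical caller,
+# not part of this patch): given a registered kresd worker `sub`, a Lua expression
+# can be evaluated and its JSON-encoded result read back, e.g.
+#
+#   stats = await sub.command("worker.stats()")  # returns the parsed JSON object
+#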
+
+class SubprocessController(ABC):
+ """
+ The common Subprocess Controller interface. This is what KresManager requires and what has to be implemented by all
+ controllers.
+ """
+
+ @abstractmethod
+ async def is_controller_available(self, config: KresConfig) -> bool:
+ """
+ Returns bool, whether the controller is available with the given config
+ """
+
+ @abstractmethod
+ async def initialize_controller(self, config: KresConfig) -> None:
+ """
+ Should be called when we want to really start using the controller with a specific configuration
+ """
+
+ @abstractmethod
+ async def get_all_running_instances(self) -> Iterable[Subprocess]:
+ """
+ Return all subprocesses that are currently running under this controller.
+
+ Must NOT be called before initialize_controller()
+ """
+
+ @abstractmethod
+ async def shutdown_controller(self) -> None:
+ """
+ Called when the manager is gracefully shutting down. Allows us to stop
+ the service manager process or simply clean up, so that we don't reuse
+ the same resources in a new run.
+
+ Must NOT be called before initialize_controller()
+ """
+
+ @abstractmethod
+ async def create_subprocess(self, subprocess_config: KresConfig, subprocess_type: SubprocessType) -> Subprocess:
+ """
+ Return a Subprocess object which can be operated on. The subprocess is not
+ started or in any way active after this call. That has to be performed manually
+ using the returned object itself.
+
+ Must NOT be called before initialize_controller()
+ """
+
+ @abstractmethod
+ async def get_subprocess_status(self) -> Dict[KresID, SubprocessStatus]:
+ """
+ Get the status of running subprocesses as seen by the controller. This method actively polls
+ for information.
+
+ Must NOT be called before initialize_controller()
+ """
diff --git a/python/knot_resolver/controller/registered_workers.py b/python/knot_resolver/controller/registered_workers.py
new file mode 100644
index 00000000..2d3176c3
--- /dev/null
+++ b/python/knot_resolver/controller/registered_workers.py
@@ -0,0 +1,49 @@
+import asyncio
+import logging
+from typing import TYPE_CHECKING, Dict, List, Tuple
+
+from knot_resolver.manager.exceptions import SubprocessControllerException
+
+if TYPE_CHECKING:
+ from knot_resolver.controller.interface import KresID, Subprocess
+
+
+logger = logging.getLogger(__name__)
+
+
+_REGISTERED_WORKERS: "Dict[KresID, Subprocess]" = {}
+
+
+def get_registered_workers_kresids() -> "List[KresID]":
+ return list(_REGISTERED_WORKERS.keys())
+
+
+async def command_single_registered_worker(cmd: str) -> "Tuple[KresID, object]":
+ for sub in _REGISTERED_WORKERS.values():
+ return sub.id, await sub.command(cmd)
+ raise SubprocessControllerException(
+ "Unable to execute the command. There is no kresd worker running to execute the command."
+ "Try start/restart the resolver.",
+ )
+
+
+async def command_registered_workers(cmd: str) -> "Dict[KresID, object]":
+ async def single_pair(sub: "Subprocess") -> "Tuple[KresID, object]":
+ return sub.id, await sub.command(cmd)
+
+ pairs = await asyncio.gather(*(single_pair(inst) for inst in _REGISTERED_WORKERS.values()))
+ return dict(pairs)
+
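+# A minimal usage sketch (hypothetical caller, not part of this patch): broadcast one
+# command to every registered worker and collect the per-worker results, e.g.
+#
+#   results = await command_registered_workers("worker.stats()")
+#   for kres_id, stats in results.items():
+#       ...  # `stats` is the JSON-decoded reply of that worker
+#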
+
+def unregister_worker(subprocess: "Subprocess") -> None:
+ """
+ Unregister kresd worker "Subprocess" from the list.
+ """
+ del _REGISTERED_WORKERS[subprocess.id]
+
+
+def register_worker(subprocess: "Subprocess") -> None:
+ """
+ Register kresd worker "Subprocess" on the list.
+ """
+ _REGISTERED_WORKERS[subprocess.id] = subprocess
diff --git a/python/knot_resolver/controller/supervisord/__init__.py b/python/knot_resolver/controller/supervisord/__init__.py
new file mode 100644
index 00000000..592b76be
--- /dev/null
+++ b/python/knot_resolver/controller/supervisord/__init__.py
@@ -0,0 +1,281 @@
+import logging
+from os import kill # pylint: disable=[no-name-in-module]
+from pathlib import Path
+from typing import Any, Dict, Iterable, NoReturn, Optional, Union, cast
+from xmlrpc.client import Fault, ServerProxy
+
+import supervisor.xmlrpc # type: ignore[import]
+
+from knot_resolver.compat.asyncio import async_in_a_thread
+from knot_resolver.manager.constants import supervisord_config_file, supervisord_pid_file, supervisord_sock_file
+from knot_resolver.datamodel.config_schema import KresConfig
+from knot_resolver.manager.exceptions import CancelStartupExecInsteadException, SubprocessControllerException
+from knot_resolver.controller.interface import (
+ KresID,
+ Subprocess,
+ SubprocessController,
+ SubprocessStatus,
+ SubprocessType,
+)
+from knot_resolver.controller.supervisord.config_file import SupervisordKresID, write_config_file
+from knot_resolver.utils import which
+from knot_resolver.utils.async_utils import call, readfile
+
+logger = logging.getLogger(__name__)
+
+
+async def _start_supervisord(config: KresConfig) -> None:
+ logger.debug("Writing supervisord config")
+ await write_config_file(config)
+ logger.debug("Starting supervisord")
+ res = await call(["supervisord", "--configuration", str(supervisord_config_file(config).absolute())])
+ if res != 0:
+ raise SubprocessControllerException(f"Supervisord exited with exit code {res}")
+
+
+async def _exec_supervisord(config: KresConfig) -> NoReturn:
+ logger.debug("Writing supervisord config")
+ await write_config_file(config)
+ logger.debug("Execing supervisord")
+ raise CancelStartupExecInsteadException(
+ [
+ str(which.which("supervisord")),
+ "supervisord",
+ "--configuration",
+ str(supervisord_config_file(config).absolute()),
+ ]
+ )
+
+
+async def _reload_supervisord(config: KresConfig) -> None:
+ await write_config_file(config)
+ try:
+ supervisord = _create_supervisord_proxy(config)
+ supervisord.reloadConfig()
+ except Fault as e:
+ raise SubprocessControllerException("supervisord reload failed") from e
+
+
+@async_in_a_thread
+def _stop_supervisord(config: KresConfig) -> None:
+ supervisord = _create_supervisord_proxy(config)
+ # pid = supervisord.getPID()
+ try:
+ # we might be trying to shut down supervisord at the very moment it is waiting
+ # for us to stop. Therefore, this shutdown request for supervisord might
+ # fail and that's not a problem.
+ supervisord.shutdown()
+ except Fault as e:
+ if e.faultCode == 6 and e.faultString == "SHUTDOWN_STATE":
+ # supervisord is already stopping, so it's fine
+ pass
+ else:
+ # something wrong happened, let's be loud about it
+ raise
+
+ # We could remove the configuration, but there is actually no specific need to do so.
+ # If we leave it behind, someone might find it and use it to start us from scratch again,
+ # which is perfectly fine.
+ # supervisord_config_file(config).unlink()
+
+
+async def _is_supervisord_available() -> bool:
+ # yes, it is! The code in this file wouldn't be running without it due to imports :)
+
+ # so let's just check that we can find supervisord and supervisorctl binaries
+ try:
+ which.which("supervisord")
+ which.which("supervisorctl")
+ except RuntimeError:
+ logger.error("Failed to find supervisord or supervisorctl executables in $PATH")
+ return False
+
+ return True
+
+
+async def _get_supervisord_pid(config: KresConfig) -> Optional[int]:
+ if not Path(supervisord_pid_file(config)).exists():
+ return None
+
+ return int(await readfile(supervisord_pid_file(config)))
+
+
+def _is_process_running(pid: int) -> bool:
+ try:
+ # kill with signal 0 is a safe way to test that a process exists
+ kill(pid, 0)
+ return True
+ except ProcessLookupError:
+ return False
+
+
+async def _is_supervisord_running(config: KresConfig) -> bool:
+ pid = await _get_supervisord_pid(config)
+ if pid is None:
+ return False
+ elif not _is_process_running(pid):
+ supervisord_pid_file(config).unlink()
+ return False
+ else:
+ return True
+
+
+def _create_proxy(config: KresConfig) -> ServerProxy:
+ return ServerProxy(
+ "http://127.0.0.1",
+ transport=supervisor.xmlrpc.SupervisorTransport(
+ None, None, serverurl="unix://" + str(supervisord_sock_file(config))
+ ),
+ )
+
+
+def _create_supervisord_proxy(config: KresConfig) -> Any:
+ proxy = _create_proxy(config)
+ return getattr(proxy, "supervisor")
+
+
+def _create_fast_proxy(config: KresConfig) -> Any:
+ proxy = _create_proxy(config)
+ return getattr(proxy, "fast")
+
+
+def _convert_subprocess_status(proc: Any) -> SubprocessStatus:
+ conversion_tbl = {
+ # "STOPPED": None, # filtered out elsewhere
+ "STARTING": SubprocessStatus.RUNNING,
+ "RUNNING": SubprocessStatus.RUNNING,
+ "BACKOFF": SubprocessStatus.RUNNING,
+ "STOPPING": SubprocessStatus.RUNNING,
+ "EXITED": SubprocessStatus.EXITED,
+ "FATAL": SubprocessStatus.FATAL,
+ "UNKNOWN": SubprocessStatus.UNKNOWN,
+ }
+
+ if proc["statename"] in conversion_tbl:
+ status = conversion_tbl[proc["statename"]]
+ else:
+ logger.warning(f"Unknown supervisord process state {proc['statename']}")
+ status = SubprocessStatus.UNKNOWN
+ return status
+
+
+def _list_running_subprocesses(config: KresConfig) -> Dict[SupervisordKresID, SubprocessStatus]:
+ try:
+ supervisord = _create_supervisord_proxy(config)
+ processes: Any = supervisord.getAllProcessInfo()
+ except Fault as e:
+ raise SubprocessControllerException(f"failed to get info from all running processes: {e}") from e
+
+ # there will be a manager process as well, but we don't want to report anything on ourselves
+ processes = [pr for pr in processes if pr["name"] != "manager"]
+
+ # convert all the names
+ return {
+ SupervisordKresID.from_string(f"{pr['group']}:{pr['name']}"): _convert_subprocess_status(pr)
+ for pr in processes
+ if pr["statename"] != "STOPPED"
+ }
+
+
+class SupervisordSubprocess(Subprocess):
+ def __init__(
+ self,
+ config: KresConfig,
+ controller: "SupervisordSubprocessController",
+ base_id: Union[SubprocessType, SupervisordKresID],
+ ):
+ if isinstance(base_id, SubprocessType):
+ super().__init__(config, SupervisordKresID.alloc(base_id))
+ else:
+ super().__init__(config, base_id)
+ self._controller: "SupervisordSubprocessController" = controller
+
+ @property
+ def name(self):
+ return str(self.id)
+
+ def status(self) -> SubprocessStatus:
+ try:
+ supervisord = _create_supervisord_proxy(self._config)
+ status = supervisord.getProcessInfo(self.name)
+ except Fault as e:
+ raise SubprocessControllerException(f"failed to get status from '{self.id}' process: {e}") from e
+ return _convert_subprocess_status(status)
+
+ @async_in_a_thread
+ def _start(self) -> None:
+ # +1 for canary process (same as in config_file.py)
+ assert int(self.id) <= int(self._config.max_workers) + 1, "trying to spawn more than allowed limit of workers"
+ try:
+ supervisord = _create_fast_proxy(self._config)
+ supervisord.startProcess(self.name)
+ except Fault as e:
+ raise SubprocessControllerException(f"failed to start '{self.id}'") from e
+
+ @async_in_a_thread
+ def _stop(self) -> None:
+ supervisord = _create_supervisord_proxy(self._config)
+ supervisord.stopProcess(self.name)
+
+ @async_in_a_thread
+ def _restart(self) -> None:
+ supervisord = _create_supervisord_proxy(self._config)
+ supervisord.stopProcess(self.name)
+ fast = _create_fast_proxy(self._config)
+ fast.startProcess(self.name)
+
+ def get_used_config(self) -> KresConfig:
+ return self._config
+
+
+class SupervisordSubprocessController(SubprocessController):
+ def __init__(self): # pylint: disable=super-init-not-called
+ self._controller_config: Optional[KresConfig] = None
+
+ def __str__(self):
+ return "supervisord"
+
+ async def is_controller_available(self, config: KresConfig) -> bool:
+ res = await _is_supervisord_available()
+ if not res:
+ logger.info("Failed to find usable supervisord.")
+
+ logger.debug("Detection - supervisord controller is available for use")
+ return res
+
+ async def get_all_running_instances(self) -> Iterable[Subprocess]:
+ assert self._controller_config is not None
+
+ if await _is_supervisord_running(self._controller_config):
+ states = _list_running_subprocesses(self._controller_config)
+ return [
+ SupervisordSubprocess(self._controller_config, self, id_)
+ for id_ in states
+ if states[id_] == SubprocessStatus.RUNNING
+ ]
+ else:
+ return []
+
+ async def initialize_controller(self, config: KresConfig) -> None:
+ self._controller_config = config
+
+ if not await _is_supervisord_running(config):
+ logger.info(
+ "We want supervisord to restart us when needed, we will therefore exec() it and let it start us again."
+ )
+ await _exec_supervisord(config)
+ else:
+ logger.info("Supervisord is already running, we will just update its config...")
+ await _reload_supervisord(config)
+
+ async def shutdown_controller(self) -> None:
+ assert self._controller_config is not None
+ await _stop_supervisord(self._controller_config)
+
+ async def create_subprocess(self, subprocess_config: KresConfig, subprocess_type: SubprocessType) -> Subprocess:
+ return SupervisordSubprocess(subprocess_config, self, subprocess_type)
+
+ @async_in_a_thread
+ def get_subprocess_status(self) -> Dict[KresID, SubprocessStatus]:
+ assert self._controller_config is not None
+ return cast(Dict[KresID, SubprocessStatus], _list_running_subprocesses(self._controller_config))
diff --git a/python/knot_resolver/controller/supervisord/config_file.py b/python/knot_resolver/controller/supervisord/config_file.py
new file mode 100644
index 00000000..d9a79f9e
--- /dev/null
+++ b/python/knot_resolver/controller/supervisord/config_file.py
@@ -0,0 +1,197 @@
+import logging
+import os
+
+from dataclasses import dataclass
+from pathlib import Path
+from jinja2 import Template
+from typing_extensions import Literal
+
+from knot_resolver.manager.constants import (
+ kres_gc_executable,
+ kresd_cache_dir,
+ kresd_config_file_supervisord_pattern,
+ kresd_executable,
+ policy_loader_config_file,
+ supervisord_config_file,
+ supervisord_config_file_tmp,
+ supervisord_pid_file,
+ supervisord_sock_file,
+ supervisord_subprocess_log_dir,
+ user_constants,
+)
+from knot_resolver.datamodel.config_schema import KresConfig
+from knot_resolver.datamodel.logging_schema import LogTargetEnum
+from knot_resolver.controller.interface import KresID, SubprocessType
+from knot_resolver.utils.async_utils import read_resource, writefile
+
+logger = logging.getLogger(__name__)
+
+
+class SupervisordKresID(KresID):
+ # WARNING: be really careful with renaming. If the naming schema changes,
+ # we must still be able to parse the old one as well, otherwise updating the manager
+ # will cause weird behavior
+
+ @staticmethod
+ def from_string(val: str) -> "SupervisordKresID":
+ # the double name is checked because that's how we read it from supervisord
+ if val in ("cache-gc", "cache-gc:cache-gc"):
+ return SupervisordKresID.new(SubprocessType.GC, 0)
+ elif val in ("policy-loader", "policy-loader:policy-loader"):
+ return SupervisordKresID.new(SubprocessType.POLICY_LOADER, 0)
+ else:
+ val = val.replace("kresd:kresd", "")
+ return SupervisordKresID.new(SubprocessType.KRESD, int(val))
+
+ def __str__(self) -> str:
+ if self.subprocess_type is SubprocessType.GC:
+ return "cache-gc"
+ elif self.subprocess_type is SubprocessType.POLICY_LOADER:
+ return "policy-loader"
+ elif self.subprocess_type is SubprocessType.KRESD:
+ return f"kresd:kresd{self._id}"
+ else:
+ raise RuntimeError(f"Unexpected subprocess type {self.subprocess_type}")
+
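+# Illustration of the naming schema above (values assumed, not part of this patch):
+# SupervisordKresID.from_string("kresd:kresd3") yields an ID that prints back as
+# "kresd:kresd3" and converts to int() == 3, while both "cache-gc" and
+# "cache-gc:cache-gc" map to the cache garbage collector process.
+#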
+
+def kres_cache_gc_args(config: KresConfig) -> str:
+ args = ""
+
+ if config.logging.level == "debug" or (config.logging.groups and "cache-gc" in config.logging.groups):
+ args += " -v"
+
+ gc_config = config.cache.garbage_collector
+ if gc_config:
+ args += (
+ f" -d {gc_config.interval.millis()}"
+ f" -u {gc_config.threshold}"
+ f" -f {gc_config.release}"
+ f" -l {gc_config.rw_deletes}"
+ f" -L {gc_config.rw_reads}"
+ f" -t {gc_config.temp_keys_space.mbytes()}"
+ f" -m {gc_config.rw_duration.micros()}"
+ f" -w {gc_config.rw_delay.micros()}"
+ )
+ if gc_config.dry_run:
+ args += " -n"
+ return args
+ raise ValueError("missing configuration for the cache garbage collector")
+
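+# For illustration (a sketch, not part of this patch): with the default
+# GarbageCollectorSchema values (interval 1s, threshold 80 %, release 10 %, ...),
+# the function above renders roughly
+#   " -d 1000 -u 80 -f 10 -l 100 -L 200 -t 0 -m 0 -w 0"
+# plus " -v" when the debug log level or the 'cache-gc' log group is enabled.
+#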
+
+@dataclass
+class ProcessTypeConfig:
+ """
+ Data structure holding data for supervisord config template
+ """
+
+ logfile: Path
+ workdir: str
+ command: str
+ environment: str
+ max_procs: int = 1
+
+ @staticmethod
+ def create_gc_config(config: KresConfig) -> "ProcessTypeConfig":
+ cwd = str(os.getcwd())
+ return ProcessTypeConfig( # type: ignore[call-arg]
+ logfile=supervisord_subprocess_log_dir(config) / "gc.log",
+ workdir=cwd,
+ command=f"{kres_gc_executable()} -c {kresd_cache_dir(config)}{kres_cache_gc_args(config)}",
+ environment="",
+ )
+
+ @staticmethod
+ def create_policy_loader_config(config: KresConfig) -> "ProcessTypeConfig":
+ cwd = str(os.getcwd())
+ return ProcessTypeConfig( # type: ignore[call-arg]
+ logfile=supervisord_subprocess_log_dir(config) / "policy-loader.log",
+ workdir=cwd,
+ command=f"{kresd_executable()} -c {(policy_loader_config_file(config))} -c - -n",
+ environment="X-SUPERVISORD-TYPE=notify",
+ )
+
+ @staticmethod
+ def create_kresd_config(config: KresConfig) -> "ProcessTypeConfig":
+ cwd = str(os.getcwd())
+ return ProcessTypeConfig( # type: ignore[call-arg]
+ logfile=supervisord_subprocess_log_dir(config) / "kresd%(process_num)d.log",
+ workdir=cwd,
+ command=f"{kresd_executable()} -c {kresd_config_file_supervisord_pattern(config)} -n",
+ environment='SYSTEMD_INSTANCE="%(process_num)d",X-SUPERVISORD-TYPE=notify',
+ max_procs=int(config.max_workers) + 1, # +1 for the canary process
+ )
+
+ @staticmethod
+ def create_manager_config(_config: KresConfig) -> "ProcessTypeConfig":
+ # read original command from /proc
+ with open("/proc/self/cmdline", "rb") as f:
+ args = [s.decode("utf-8") for s in f.read()[:-1].split(b"\0")]
+
+ # insert debugger when asked
+ if os.environ.get("KRES_DEBUG_MANAGER"):
+ logger.warning("Injecting debugger into the supervisord config")
+ # the args array looks like this:
+ # [PYTHON_PATH, "-m", "knot_resolver", ...]
+ args = args[:1] + ["-m", "debugpy", "--listen", "0.0.0.0:5678", "--wait-for-client"] + args[2:]
+
+ cmd = '"' + '" "'.join(args) + '"'
+
+ return ProcessTypeConfig( # type: ignore[call-arg]
+ workdir=user_constants().working_directory_on_startup,
+ command=cmd,
+ environment="X-SUPERVISORD-TYPE=notify",
+ logfile=Path(""), # this will be ignored
+ )
+
+
+@dataclass
+class SupervisordConfig:
+ unix_http_server: Path
+ pid_file: Path
+ workdir: str
+ logfile: Path
+ loglevel: Literal["critical", "error", "warn", "info", "debug", "trace", "blather"]
+ target: LogTargetEnum
+
+ @staticmethod
+ def create(config: KresConfig) -> "SupervisordConfig":
+ # determine the correct logging level
+ if config.logging.groups and "supervisord" in config.logging.groups:
+ loglevel = "info"
+ else:
+ loglevel = {
+ "crit": "critical",
+ "err": "error",
+ "warning": "warn",
+ "notice": "warn",
+ "info": "info",
+ "debug": "debug",
+ }[config.logging.level]
+ cwd = str(os.getcwd())
+ return SupervisordConfig( # type: ignore[call-arg]
+ unix_http_server=supervisord_sock_file(config),
+ pid_file=supervisord_pid_file(config),
+ workdir=cwd,
+ logfile=Path("syslog" if config.logging.target == "syslog" else "/dev/null"),
+ loglevel=loglevel, # type: ignore[arg-type]
+ target=config.logging.target,
+ )
+
+
+async def write_config_file(config: KresConfig) -> None:
+ if not supervisord_subprocess_log_dir(config).exists():
+ supervisord_subprocess_log_dir(config).mkdir(exist_ok=True)
+
+ template = await read_resource(__package__, "supervisord.conf.j2")
+ assert template is not None
+ template = template.decode("utf8")
+ config_string = Template(template).render(
+ gc=ProcessTypeConfig.create_gc_config(config),
+ loader=ProcessTypeConfig.create_policy_loader_config(config),
+ kresd=ProcessTypeConfig.create_kresd_config(config),
+ manager=ProcessTypeConfig.create_manager_config(config),
+ config=SupervisordConfig.create(config),
+ )
+ await writefile(supervisord_config_file_tmp(config), config_string)
+ # atomically replace (we don't technically need this right now, but better safe than sorry)
+ os.rename(supervisord_config_file_tmp(config), supervisord_config_file(config))
diff --git a/python/knot_resolver/controller/supervisord/plugin/fast_rpcinterface.py b/python/knot_resolver/controller/supervisord/plugin/fast_rpcinterface.py
new file mode 100644
index 00000000..c3834784
--- /dev/null
+++ b/python/knot_resolver/controller/supervisord/plugin/fast_rpcinterface.py
@@ -0,0 +1,173 @@
+# type: ignore
+# pylint: skip-file
+
+"""
+This file is modified version of supervisord's source code:
+https://github.com/Supervisor/supervisor/blob/5d9c39619e2e7e7fca33c890cb2a9f2d3d0ab762/supervisor/rpcinterface.py
+
+The changes made are:
+
+ - removed everything that we do not need, reformatted to fit our code style (2022-06-24)
+ - made startProcess faster by setting delay to 0 (2022-06-24)
+
+
+The original supervisord licence follows:
+--------------------------------------------------------------------
+
+Supervisor is licensed under the following license:
+
+ A copyright notice accompanies this license document that identifies
+ the copyright holders.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are
+ met:
+
+ 1. Redistributions in source code must retain the accompanying
+ copyright notice, this list of conditions, and the following
+ disclaimer.
+
+ 2. Redistributions in binary form must reproduce the accompanying
+ copyright notice, this list of conditions, and the following
+ disclaimer in the documentation and/or other materials provided
+ with the distribution.
+
+ 3. Names of the copyright holders must not be used to endorse or
+ promote products derived from this software without prior
+ written permission from the copyright holders.
+
+ 4. If any files are modified, you must cause the modified files to
+ carry prominent notices stating that you changed the files and
+ the date of any change.
+
+ Disclaimer
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND
+ ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
+ TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
+ PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ HOLDERS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
+ TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
+ TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
+ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
+ SUCH DAMAGE.
+"""
+
+from supervisor.http import NOT_DONE_YET
+from supervisor.options import BadCommand, NoPermission, NotExecutable, NotFound, split_namespec
+from supervisor.states import RUNNING_STATES, ProcessStates, SupervisorStates
+from supervisor.xmlrpc import Faults, RPCError
+
+
+class SupervisorNamespaceRPCInterface:
+ def __init__(self, supervisord):
+ self.supervisord = supervisord
+
+ def _update(self, text):
+ self.update_text = text # for unit tests, mainly
+ if isinstance(self.supervisord.options.mood, int) and self.supervisord.options.mood < SupervisorStates.RUNNING:
+ raise RPCError(Faults.SHUTDOWN_STATE)
+
+ # RPC API methods
+
+ def _getGroupAndProcess(self, name):
+ # get process to start from name
+ group_name, process_name = split_namespec(name)
+
+ group = self.supervisord.process_groups.get(group_name)
+ if group is None:
+ raise RPCError(Faults.BAD_NAME, name)
+
+ if process_name is None:
+ return group, None
+
+ process = group.processes.get(process_name)
+ if process is None:
+ raise RPCError(Faults.BAD_NAME, name)
+
+ return group, process
+
+ def startProcess(self, name, wait=True):
+ """Start a process
+
+ @param string name Process name (or ``group:name``, or ``group:*``)
+ @param boolean wait Wait for process to be fully started
+ @return boolean result Always true unless error
+
+ """
+ self._update("startProcess")
+ group, process = self._getGroupAndProcess(name)
+ if process is None:
+ group_name, process_name = split_namespec(name)
+ return self.startProcessGroup(group_name, wait)
+
+ # test filespec, don't bother trying to spawn if we know it will
+ # eventually fail
+ try:
+ filename, argv = process.get_execv_args()
+ except NotFound as why:
+ raise RPCError(Faults.NO_FILE, why.args[0])
+ except (BadCommand, NotExecutable, NoPermission) as why:
+ raise RPCError(Faults.NOT_EXECUTABLE, why.args[0])
+
+ if process.get_state() in RUNNING_STATES:
+ raise RPCError(Faults.ALREADY_STARTED, name)
+
+ if process.get_state() == ProcessStates.UNKNOWN:
+ raise RPCError(Faults.FAILED, "%s is in an unknown process state" % name)
+
+ process.spawn()
+
+ # We call reap() in order to more quickly obtain the side effects of
+ # process.finish(), which reap() eventually ends up calling. This
+ # might be the case if the spawn() was successful but then the process
+ # died before its startsecs elapsed or it exited with an unexpected
+ # exit code. In particular, finish() may set spawnerr, which we can
+ # check and immediately raise an RPCError, avoiding the need to
+ # defer by returning a callback.
+
+ self.supervisord.reap()
+
+ if process.spawnerr:
+ raise RPCError(Faults.SPAWN_ERROR, name)
+
+ # We call process.transition() in order to more quickly obtain its
+ # side effects. In particular, it might set the process' state from
+ # STARTING->RUNNING if the process has a startsecs==0.
+ process.transition()
+
+ if wait and process.get_state() != ProcessStates.RUNNING:
+ # by default, this branch will almost always be hit for processes
+ # with default startsecs configurations, because the default number
+ # of startsecs for a process is "1", and the process will not have
+ # entered the RUNNING state yet even though we've called
+ # transition() on it. This is because a process is not considered
+ # RUNNING until it has stayed up > startsecs.
+
+ def onwait():
+ if process.spawnerr:
+ raise RPCError(Faults.SPAWN_ERROR, name)
+
+ state = process.get_state()
+
+ if state not in (ProcessStates.STARTING, ProcessStates.RUNNING):
+ raise RPCError(Faults.ABNORMAL_TERMINATION, name)
+
+ if state == ProcessStates.RUNNING:
+ return True
+
+ return NOT_DONE_YET
+
+ onwait.delay = 0
+ onwait.rpcinterface = self
+ return onwait # deferred
+
+ return True
+
+
+# this is not used in code but referenced via an entry point in the conf file
+def make_main_rpcinterface(supervisord):
+ return SupervisorNamespaceRPCInterface(supervisord)
diff --git a/python/knot_resolver/controller/supervisord/plugin/manager_integration.py b/python/knot_resolver/controller/supervisord/plugin/manager_integration.py
new file mode 100644
index 00000000..2fc8cf94
--- /dev/null
+++ b/python/knot_resolver/controller/supervisord/plugin/manager_integration.py
@@ -0,0 +1,85 @@
+# type: ignore
+# pylint: disable=protected-access
+import atexit
+import os
+import signal
+from typing import Any, Optional
+
+from supervisor.compat import as_string
+from supervisor.events import ProcessStateFatalEvent, ProcessStateRunningEvent, ProcessStateStartingEvent, subscribe
+from supervisor.options import ServerOptions
+from supervisor.process import Subprocess
+from supervisor.states import SupervisorStates
+from supervisor.supervisord import Supervisor
+
+from knot_resolver.utils.systemd_notify import systemd_notify
+
+superd: Optional[Supervisor] = None
+
+
+def check_for_fatal_manager(event: ProcessStateFatalEvent) -> None:
+ assert superd is not None
+
+ proc: Subprocess = event.process
+ processname = as_string(proc.config.name)
+ if processname == "manager":
+ # stop the whole supervisord gracefully
+ superd.options.logger.critical("manager process entered FATAL state! Shutting down")
+ superd.options.mood = SupervisorStates.SHUTDOWN
+
+ # force the interpreter to exit with exit code 1
+ atexit.register(lambda: os._exit(1))
+
+
+def check_for_starting_manager(event: ProcessStateStartingEvent) -> None:
+ assert superd is not None
+
+ proc: Subprocess = event.process
+ processname = as_string(proc.config.name)
+ if processname == "manager":
+ # manager is starting, report it upstream
+ systemd_notify(STATUS="Starting services...")
+
+
+def check_for_running_manager(event: ProcessStateRunningEvent) -> None:
+ assert superd is not None
+
+ proc: Subprocess = event.process
+ processname = as_string(proc.config.name)
+ if processname == "manager":
+ # manager has successfully started, report it upstream
+ systemd_notify(READY="1", STATUS="Ready")
+
+
+def ServerOptions_get_signal(self):
+ sig = self.signal_receiver.get_signal()
+ if sig == signal.SIGHUP and superd is not None:
+ superd.options.logger.info("received SIGHUP, forwarding to the process 'manager'")
+ manager_pid = superd.process_groups["manager"].processes["manager"].pid
+ os.kill(manager_pid, signal.SIGHUP)
+ return None
+
+ return sig
+
+
+def inject(supervisord: Supervisor, **_config: Any) -> Any: # pylint: disable=useless-return
+ global superd
+ superd = supervisord
+
+ # This status notification here unsets the env variable $NOTIFY_SOCKET provided by systemd
+ # and stores it locally. Therefore, it shouldn't clash with $NOTIFY_SOCKET we are providing
+ # downstream
+ systemd_notify(STATUS="Initializing supervisord...")
+
+ # register events
+ subscribe(ProcessStateFatalEvent, check_for_fatal_manager)
+ subscribe(ProcessStateStartingEvent, check_for_starting_manager)
+ subscribe(ProcessStateRunningEvent, check_for_running_manager)
+
+ # forward SIGHUP to manager
+ ServerOptions.get_signal = ServerOptions_get_signal
+
+ # this method is called by supervisord when loading the plugin,
+ # it should return an XML-RPC object, which we don't care about.
+ # That's why we are returning just None
+ return None
diff --git a/python/knot_resolver/controller/supervisord/plugin/notifymodule.c b/python/knot_resolver/controller/supervisord/plugin/notifymodule.c
new file mode 100644
index 00000000..d56ee7d2
--- /dev/null
+++ b/python/knot_resolver/controller/supervisord/plugin/notifymodule.c
@@ -0,0 +1,176 @@
+#define PY_SSIZE_T_CLEAN
+#include <Python.h>
+
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <unistd.h>
+#include <sys/types.h>
+#include <stdlib.h>
+#include <string.h>
+#include <errno.h>
+#include <sys/socket.h>
+#include <fcntl.h>
+#include <stddef.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+
+#define CONTROL_SOCKET_NAME "knot-resolver-control-socket"
+#define NOTIFY_SOCKET_NAME "NOTIFY_SOCKET"
+#define MODULE_NAME "notify"
+#define RECEIVE_BUFFER_SIZE 2048
+
+static PyObject *NotifySocketError;
+
+static PyObject *init_control_socket(PyObject *self, PyObject *args)
+{
+ /* create socket */
+ int controlfd = socket(AF_UNIX, SOCK_DGRAM | SOCK_NONBLOCK, 0);
+ if (controlfd == -1) {
+ PyErr_SetFromErrno(NotifySocketError);
+ return NULL;
+ }
+
+ /* create address */
+ struct sockaddr_un server_addr;
+ bzero(&server_addr, sizeof(server_addr));
+ server_addr.sun_family = AF_UNIX;
+ server_addr.sun_path[0] = '\0'; // mark it as abstract namespace socket
+ strcpy(server_addr.sun_path + 1, CONTROL_SOCKET_NAME);
+ size_t addr_len = offsetof(struct sockaddr_un, sun_path) +
+ strlen(CONTROL_SOCKET_NAME) + 1;
+
+ /* bind to the address */
+ int res = bind(controlfd, (struct sockaddr *)&server_addr, addr_len);
+ if (res < 0) {
+ PyErr_SetFromErrno(NotifySocketError);
+ return NULL;
+ }
+
+ /* make sure that we receive the sender's credentials */
+ int data = (int)true;
+ res = setsockopt(controlfd, SOL_SOCKET, SO_PASSCRED, &data,
+ sizeof(data));
+ if (res < 0) {
+ PyErr_SetFromErrno(NotifySocketError);
+ return NULL;
+ }
+
+ /* store the name of the socket in env to fake systemd */
+ char *old_value = getenv(NOTIFY_SOCKET_NAME);
+ if (old_value != NULL) {
+ printf("[notify_socket] warning, running under systemd and overwriting $%s\n",
+ NOTIFY_SOCKET_NAME);
+ // fixme
+ }
+
+ res = setenv(NOTIFY_SOCKET_NAME, "@" CONTROL_SOCKET_NAME, 1);
+ if (res < 0) {
+ PyErr_SetFromErrno(NotifySocketError);
+ return NULL;
+ }
+
+ return PyLong_FromLong((long)controlfd);
+}
+
+static PyObject *handle_control_socket_connection_event(PyObject *self,
+ PyObject *args)
+{
+ int controlfd;
+ if (!PyArg_ParseTuple(args, "i", &controlfd))
+ return NULL;
+
+ /* read command assuming it fits and it was sent all at once */
+ // prepare the message header for ancillary (credentials) data
+ struct msghdr msg;
+ msg.msg_name = NULL;
+ msg.msg_namelen = 0;
+
+ // prepare a place to read the actual message
+ char place_for_data[RECEIVE_BUFFER_SIZE];
+ bzero(&place_for_data, sizeof(place_for_data));
+ struct iovec iov = { .iov_base = &place_for_data,
+ .iov_len = sizeof(place_for_data) };
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ char cmsg[CMSG_SPACE(sizeof(struct ucred))];
+ msg.msg_control = cmsg;
+ msg.msg_controllen = sizeof(cmsg);
+
+ /* Receive real plus ancillary data */
+ int len = recvmsg(controlfd, &msg, 0);
+ if (len == -1) {
+ if (errno == EWOULDBLOCK || errno == EAGAIN) {
+ Py_RETURN_NONE;
+ } else {
+ PyErr_SetFromErrno(NotifySocketError);
+ return NULL;
+ }
+ }
+
+ /* read the sender pid */
+ struct cmsghdr *cmsgp = CMSG_FIRSTHDR(&msg);
+ pid_t pid = -1;
+ while (cmsgp != NULL) {
+ if (cmsgp->cmsg_type == SCM_CREDENTIALS) {
+ if (
+ cmsgp->cmsg_len != CMSG_LEN(sizeof(struct ucred)) ||
+ cmsgp->cmsg_level != SOL_SOCKET
+ ) {
+ printf("[notify_socket] invalid cmsg data, ignoring\n");
+ Py_RETURN_NONE;
+ }
+
+ struct ucred cred;
+ memcpy(&cred, CMSG_DATA(cmsgp), sizeof(cred));
+ pid = cred.pid;
+ }
+ cmsgp = CMSG_NXTHDR(&msg, cmsgp);
+ }
+ if (pid == -1) {
+ printf("[notify_socket] ignoring received data without credentials: %s\n",
+ place_for_data);
+ Py_RETURN_NONE;
+ }
+
+ /* return received data as a tuple (pid, data bytes) */
+ return Py_BuildValue("iy", pid, place_for_data);
+}
+
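+/*
+ * Illustration (not part of this patch): a worker started with
+ * NOTIFY_SOCKET=@knot-resolver-control-socket reports readiness by sending the
+ * sd_notify-style datagram "READY=1" to the abstract socket created above;
+ * read_message() then returns the sender's PID together with those bytes.
+ */
+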
+static PyMethodDef NotifyMethods[] = {
+ { "init_socket", init_control_socket, METH_VARARGS,
+ "Init notify socket. Returns it's file descriptor." },
+ { "read_message", handle_control_socket_connection_event, METH_VARARGS,
+ "Reads datagram from notify socket. Returns tuple of PID and received bytes." },
+ { NULL, NULL, 0, NULL } /* Sentinel */
+};
+
+static struct PyModuleDef notifymodule = {
+ PyModuleDef_HEAD_INIT, MODULE_NAME, /* name of module */
+ NULL, /* module documentation, may be NULL */
+ -1, /* size of per-interpreter state of the module,
+ or -1 if the module keeps state in global variables. */
+ NotifyMethods
+};
+
+PyMODINIT_FUNC PyInit_notify(void)
+{
+ PyObject *m;
+
+ m = PyModule_Create(&notifymodule);
+ if (m == NULL)
+ return NULL;
+
+ NotifySocketError =
+ PyErr_NewException(MODULE_NAME ".error", NULL, NULL);
+ Py_XINCREF(NotifySocketError);
+ if (PyModule_AddObject(m, "error", NotifySocketError) < 0) {
+ Py_XDECREF(NotifySocketError);
+ Py_CLEAR(NotifySocketError);
+ Py_DECREF(m);
+ return NULL;
+ }
+
+ return m;
+}
diff --git a/python/knot_resolver/controller/supervisord/plugin/patch_logger.py b/python/knot_resolver/controller/supervisord/plugin/patch_logger.py
new file mode 100644
index 00000000..411f232e
--- /dev/null
+++ b/python/knot_resolver/controller/supervisord/plugin/patch_logger.py
@@ -0,0 +1,97 @@
+# type: ignore
+# pylint: disable=protected-access
+
+import os
+import sys
+import traceback
+from typing import Any
+
+from supervisor.dispatchers import POutputDispatcher
+from supervisor.loggers import LevelsByName, StreamHandler, SyslogHandler
+from supervisor.supervisord import Supervisor
+from typing_extensions import Literal
+
+FORWARD_LOG_LEVEL = LevelsByName.CRIT # to make sure it's always printed
+
+
+def empty_function(*args, **kwargs):
+ pass
+
+
+FORWARD_MSG_FORMAT: str = "%(name)s[%(pid)d]%(stream)s: %(data)s"
+
+
+def POutputDispatcher_log(self: POutputDispatcher, data: bytearray):
+ if data:
+ # parse the input
+ if not isinstance(data, bytes):
+ text = data
+ else:
+ try:
+ text = data.decode("utf-8")
+ except UnicodeDecodeError:
+ text = "Undecodable: %r" % data
+
+ # print line by line prepending correct prefix to match the style
+ config = self.process.config
+ config.options.logger.handlers = forward_handlers
+ for line in text.splitlines():
+ stream = ""
+ if self.channel == "stderr":
+ stream = " (stderr)"
+ config.options.logger.log(
+ FORWARD_LOG_LEVEL, FORWARD_MSG_FORMAT, name=config.name, stream=stream, data=line, pid=self.process.pid
+ )
+ config.options.logger.handlers = supervisord_handlers
+
+
+def _create_handler(fmt, level, target: Literal["stdout", "stderr", "syslog"]) -> StreamHandler:
+ if target == "syslog":
+ handler = SyslogHandler()
+ else:
+ handler = StreamHandler(sys.stdout if target == "stdout" else sys.stderr)
+ handler.setFormat(fmt)
+ handler.setLevel(level)
+ return handler
+
+
+supervisord_handlers = []
+forward_handlers = []
+
+
+def inject(supervisord: Supervisor, **config: Any) -> Any: # pylint: disable=useless-return
+ try:
+ # reconfigure log handlers
+ supervisord.options.logger.info("reconfiguring log handlers")
+ supervisord_handlers.append(
+ _create_handler(
+ f"%(asctime)s supervisor[{os.getpid()}]: [%(levelname)s] %(message)s\n",
+ supervisord.options.loglevel,
+ config["target"],
+ )
+ )
+ forward_handlers.append(
+ _create_handler("%(asctime)s %(message)s\n", supervisord.options.loglevel, config["target"])
+ )
+ supervisord.options.logger.handlers = supervisord_handlers
+
+ # replace output handler for subprocesses
+ POutputDispatcher._log = POutputDispatcher_log
+
+ # we forward stdio in all cases, even when logging to syslog. This should prevent the unfortunate
+ # case of swallowing an error message and leaving the users confused. To make the forwarded lines obvious,
+ # we just prepend an explanatory string at the beginning of all messages
+ if config["target"] == "syslog":
+ global FORWARD_MSG_FORMAT
+ FORWARD_MSG_FORMAT = "captured stdio output from " + FORWARD_MSG_FORMAT
+
+ # this method is called by supervisord when loading the plugin,
+ # it should return an XML-RPC object, which we don't care about.
+ # That's why we are returning just None
+ return None
+
+ # if we fail to load the module, print some explanation
+ # should not happen when run by end users
+ except BaseException:
+ traceback.print_exc()
+ raise
diff --git a/python/knot_resolver/controller/supervisord/plugin/sd_notify.py b/python/knot_resolver/controller/supervisord/plugin/sd_notify.py
new file mode 100644
index 00000000..ff32828b
--- /dev/null
+++ b/python/knot_resolver/controller/supervisord/plugin/sd_notify.py
@@ -0,0 +1,227 @@
+# type: ignore
+# pylint: disable=protected-access
+# pylint: disable=c-extension-no-member
+import os
+import signal
+import time
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
+
+from supervisor.events import ProcessStateEvent, ProcessStateStartingEvent, subscribe
+from supervisor.medusa.asyncore_25 import compact_traceback
+from supervisor.process import Subprocess
+from supervisor.states import ProcessStates
+from supervisor.supervisord import Supervisor
+
+from knot_resolver.controller.supervisord.plugin import notify
+
+starting_processes: List[Subprocess] = []
+
+
+def is_type_notify(proc: Subprocess) -> bool:
+ return proc.config.environment is not None and proc.config.environment.get("X-SUPERVISORD-TYPE", None) == "notify"
+
+
+class NotifySocketDispatcher:
+ """
+ See supervisor.dispatcher
+ """
+
+ def __init__(self, supervisor: Supervisor, fd: int):
+ self._supervisor = supervisor
+ self.fd = fd
+ self.closed = False # True if close() has been called
+
+ def __repr__(self):
+ return f"<{self.__class__.__name__} with fd={self.fd}>"
+
+ def readable(self):
+ return True
+
+ def writable(self):
+ return False
+
+ def handle_read_event(self):
+ logger: Any = self._supervisor.options.logger
+
+ res: Optional[Tuple[int, bytes]] = notify.read_message(self.fd)
+ if res is None:
+ return None # there was some junk
+ pid, data = res
+
+ # pylint: disable=undefined-loop-variable
+ for proc in starting_processes:
+ if proc.pid == pid:
+ break
+ else:
+ logger.warn(f"ignoring ready notification from unregistered PID={pid}")
+ return None
+
+ if data.startswith(b"READY=1"):
+ # handle case, when some process is really ready
+
+ if is_type_notify(proc):
+ proc._assertInState(ProcessStates.STARTING)
+ proc.change_state(ProcessStates.RUNNING)
+ logger.info(
+ f"success: {proc.config.name} entered RUNNING state, process sent notification via $NOTIFY_SOCKET"
+ )
+ else:
+ logger.warn(f"ignoring READY notification from {proc.config.name}, which is not configured to send it")
+
+ elif data.startswith(b"STOPPING=1"):
+ # just accept the message, filter unwanted notifications and do nothing else
+
+ if is_type_notify(proc):
+ logger.info(
+ f"success: {proc.config.name} entered STOPPING state, process sent notification via $NOTIFY_SOCKET"
+ )
+ else:
+ logger.warn(
+ f"ignoring STOPPING notification from {proc.config.name}, which is not configured to send it"
+ )
+
+ else:
+ # handle case, when we got something unexpected
+ logger.warn(f"ignoring unrecognized data on $NOTIFY_SOCKET sent from PID={pid}, data='{data!r}'")
+ return None
+
+ def handle_write_event(self):
+ raise ValueError("this dispatcher is not writable")
+
+ def handle_error(self):
+ _nil, t, v, tbinfo = compact_traceback()
+
+ self._supervisor.options.logger.error(
+ f"uncaptured python exception, closing notify socket {repr(self)} ({t}:{v} {tbinfo})"
+ )
+ self.close()
+
+ def close(self):
+ if not self.closed:
+ os.close(self.fd)
+ self.closed = True
+
+ def flush(self):
+ pass
+
+
+def keep_track_of_starting_processes(event: ProcessStateEvent) -> None:
+ global starting_processes
+
+ proc: Subprocess = event.process
+
+ if isinstance(event, ProcessStateStartingEvent):
+ # process is starting
+ # if proc not in starting_processes:
+ starting_processes.append(proc)
+
+ else:
+ # not starting
+ starting_processes = [p for p in starting_processes if p.pid != proc.pid]
+
+
+notify_dispatcher: Optional[NotifySocketDispatcher] = None
+
+
+def process_transition(slf: Subprocess) -> Subprocess:
+ if not is_type_notify(slf):
+ return slf
+
+ # modified version of upstream process transition code
+ if slf.state == ProcessStates.STARTING:
+ if time.time() - slf.laststart > slf.config.startsecs:
+ # STARTING -> STOPPING if the process has not sent ready notification
+ # within proc.config.startsecs
+ slf.config.options.logger.warn(
+ f"process '{slf.config.name}' did not send ready notification within {slf.config.startsecs} secs, killing"
+ )
+ slf.kill(signal.SIGKILL)
+ slf.x_notifykilled = True # used in finish() function to set to FATAL state
+ slf.laststart = time.time() + 1 # prevent immediate state transition to RUNNING from happening
+
+ # return self for chaining
+ return slf
+
+
+def subprocess_finish_tail(slf, pid, sts) -> Tuple[Any, Any, Any]:
+ if getattr(slf, "x_notifykilled", False):
+ # we want FATAL, not STOPPED state after timeout waiting for startup notification
+ # why? because it's likely not gonna help to try starting the process up again if
+ # it failed so early
+ slf.change_state(ProcessStates.FATAL)
+
+ # clear the marker value
+ del slf.x_notifykilled
+
+ # return for chaining
+ return slf, pid, sts
+
+
+def supervisord_get_process_map(supervisord: Any, mp: Dict[Any, Any]) -> Dict[Any, Any]:
+ global notify_dispatcher
+ if notify_dispatcher is None:
+ notify_dispatcher = NotifySocketDispatcher(supervisord, notify.init_socket())
+ supervisord.options.logger.info("notify: injected $NOTIFY_SOCKET into event loop")
+
+ # add our dispatcher to the result
+ assert notify_dispatcher.fd not in mp
+ mp[notify_dispatcher.fd] = notify_dispatcher
+
+ return mp
+
+
+def process_spawn_as_child_add_env(slf: Subprocess, *args: Any) -> Tuple[Any, ...]:
+ if is_type_notify(slf):
+ slf.config.environment["NOTIFY_SOCKET"] = "@knot-resolver-control-socket"
+ return (slf, *args)
+
+
+T = TypeVar("T")
+U = TypeVar("U")
+
+
+def chain(first: Callable[..., U], second: Callable[[U], T]) -> Callable[..., T]:
+ def wrapper(*args: Any, **kwargs: Any) -> T:
+ res = first(*args, **kwargs)
+ if isinstance(res, tuple):
+ return second(*res)
+ else:
+ return second(res)
+
+ return wrapper
+
+
+def append(first: Callable[..., T], second: Callable[..., None]) -> Callable[..., T]:
+ def wrapper(*args: Any, **kwargs: Any) -> T:
+ res = first(*args, **kwargs)
+ second(*args, **kwargs)
+ return res
+
+ return wrapper
+
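+# A tiny illustration of the combinators above (values assumed, not part of this patch):
+#
+#   f = chain(lambda x: (x, x + 1), lambda a, b: a + b)  # f(1) == 3
+#   g = append(lambda x: x * 2, print)                   # g(2) prints 2 and returns 4
+#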
+
+def monkeypatch(supervisord: Supervisor) -> None:
+ """Inject ourselves into supervisord code"""
+
+ # append notify socket handler to the event loop
+ supervisord.get_process_map = chain(supervisord.get_process_map, partial(supervisord_get_process_map, supervisord))
+
+ # prepend timeout handler to transition method
+ Subprocess.transition = chain(process_transition, Subprocess.transition)
+ Subprocess.finish = append(Subprocess.finish, subprocess_finish_tail)
+
+ # add environment variable $NOTIFY_SOCKET to starting processes
+ Subprocess._spawn_as_child = chain(process_spawn_as_child_add_env, Subprocess._spawn_as_child)
+
+ # keep references to starting subprocesses
+ subscribe(ProcessStateEvent, keep_track_of_starting_processes)
+
+
+def inject(supervisord: Supervisor, **_config: Any) -> Any: # pylint: disable=useless-return
+ monkeypatch(supervisord)
+
+ # this method is called by supervisord when loading the plugin,
+ # it should return an XML-RPC object, which we don't care about.
+ # That's why we are returning just None
+ return None
diff --git a/python/knot_resolver/controller/supervisord/supervisord.conf.j2 b/python/knot_resolver/controller/supervisord/supervisord.conf.j2
new file mode 100644
index 00000000..4179d522
--- /dev/null
+++ b/python/knot_resolver/controller/supervisord/supervisord.conf.j2
@@ -0,0 +1,93 @@
+[supervisord]
+pidfile = {{ config.pid_file }}
+directory = {{ config.workdir }}
+nodaemon = true
+
+{# disable initial logging until patch_logger.py takes over #}
+logfile = /dev/null
+logfile_maxbytes = 0
+silent = true
+
+{# config for patch_logger.py #}
+loglevel = {{ config.loglevel }}
+{# there are more options in the plugin section #}
+
+[unix_http_server]
+file = {{ config.unix_http_server }}
+
+[supervisorctl]
+serverurl = unix://{{ config.unix_http_server }}
+
+{# Extensions changing the supervisord behavior #}
+[rpcinterface:patch_logger]
+supervisor.rpcinterface_factory = knot_resolver.controller.supervisord.plugin.patch_logger:inject
+target = {{ config.target }}
+
+[rpcinterface:manager_integration]
+supervisor.rpcinterface_factory = knot_resolver.controller.supervisord.plugin.manager_integration:inject
+
+[rpcinterface:sd_notify]
+supervisor.rpcinterface_factory = knot_resolver.controller.supervisord.plugin.sd_notify:inject
+
+{# Extensions for actual API control #}
+[rpcinterface:supervisor]
+supervisor.rpcinterface_factory = supervisor.rpcinterface:make_main_rpcinterface
+
+[rpcinterface:fast]
+supervisor.rpcinterface_factory = knot_resolver.controller.supervisord.plugin.fast_rpcinterface:make_main_rpcinterface
+
+[program:manager]
+redirect_stderr=false
+directory={{ manager.workdir }}
+command={{ manager.command }}
+stopsignal=SIGINT
+killasgroup=true
+autorestart=true
+autostart=true
+{# Note that during startup,
+ manager will signal being ready only after sequential startup of all kresd workers,
+ i.e. it might currently take a long time if the user configured very large rulesets (e.g. a huge RPZ).
+ Let's permit it plenty of time, assuming that useful work is being done.
+#}
+startsecs=600
+environment={{ manager.environment }},KRES_SUPRESS_LOG_PREFIX=true
+stdout_logfile=NONE
+stderr_logfile=NONE
+
+[program:kresd]
+process_name=%(program_name)s%(process_num)d
+numprocs={{ kresd.max_procs }}
+directory={{ kresd.workdir }}
+command={{ kresd.command }}
+autostart=false
+autorestart=true
+stopsignal=TERM
+killasgroup=true
+startsecs=60
+environment={{ kresd.environment }}
+stdout_logfile=NONE
+stderr_logfile=NONE
+
+[program:policy-loader]
+directory={{ loader.workdir }}
+command={{ loader.command }}
+autostart=false
+stopsignal=TERM
+killasgroup=true
+startsecs=300
+environment={{ loader.environment }}
+stdout_logfile=NONE
+stderr_logfile=NONE
+
+[program:cache-gc]
+redirect_stderr=false
+directory={{ gc.workdir }}
+command={{ gc.command }}
+autostart=false
+autorestart=true
+stopsignal=TERM
+killasgroup=true
+startsecs=0
+environment={{ gc.environment }}
+stdout_logfile=NONE
+stderr_logfile=NONE
diff --git a/python/knot_resolver/datamodel/__init__.py b/python/knot_resolver/datamodel/__init__.py
new file mode 100644
index 00000000..a0174acc
--- /dev/null
+++ b/python/knot_resolver/datamodel/__init__.py
@@ -0,0 +1,3 @@
+from .config_schema import KresConfig
+
+__all__ = ["KresConfig"]
diff --git a/python/knot_resolver/datamodel/cache_schema.py b/python/knot_resolver/datamodel/cache_schema.py
new file mode 100644
index 00000000..eca36bf2
--- /dev/null
+++ b/python/knot_resolver/datamodel/cache_schema.py
@@ -0,0 +1,139 @@
+from typing import List, Optional, Union
+
+from typing_extensions import Literal
+
+from knot_resolver.datamodel.templates import template_from_str
+from knot_resolver.datamodel.types import (
+ DNSRecordTypeEnum,
+ DomainName,
+ EscapedStr,
+ IntNonNegative,
+ IntPositive,
+ Percent,
+ ReadableFile,
+ SizeUnit,
+ TimeUnit,
+ WritableDir,
+)
+from knot_resolver.utils.modeling import ConfigSchema
+from knot_resolver.utils.modeling.base_schema import lazy_default
+
+_CACHE_CLEAR_TEMPLATE = template_from_str(
+ "{% from 'macros/cache_macros.lua.j2' import cache_clear %} {{ cache_clear(params) }}"
+)
+
+
+class CacheClearRPCSchema(ConfigSchema):
+ name: Optional[DomainName] = None
+ exact_name: bool = False
+ rr_type: Optional[DNSRecordTypeEnum] = None
+ chunk_size: IntPositive = IntPositive(100)
+
+ def _validate(self) -> None:
+ if self.rr_type and not self.exact_name:
+ raise ValueError("'rr-type' is only supported with 'exact-name: true'")
+
+ def render_lua(self) -> str:
+ return _CACHE_CLEAR_TEMPLATE.render(params=self) # pyright: reportUnknownMemberType=false
+
+
+class PrefillSchema(ConfigSchema):
+ """
+ Prefill the cache periodically by importing zone data obtained over HTTP.
+
+ ---
+ origin: Origin for the imported data. Cache prefilling is only supported for the root zone ('.').
+ url: URL of the zone data to be imported.
+ refresh_interval: Time interval between consecutive refreshes of the imported zone data.
+ ca_file: Path to the file containing a CA certificate bundle that is used to authenticate the HTTPS connection.
+ """
+
+ origin: DomainName
+ url: EscapedStr
+ refresh_interval: TimeUnit = TimeUnit("1d")
+ ca_file: Optional[ReadableFile] = None
+
+ def _validate(self) -> None:
+ if str(self.origin) != ".":
+ raise ValueError("cache prefilling is not yet supported for non-root zones")
+
+
+class GarbageCollectorSchema(ConfigSchema):
+ """
+ Configuration options of the cache garbage collector (kres-cache-gc).
+
+ ---
+ interval: Time interval specifying how often the garbage collector runs.
+ threshold: Cache usage in percent that triggers the garbage collector.
+ release: Percent of used cache to be freed by the garbage collector.
+ temp_keys_space: Maximum amount of temporary memory for copied keys (0 = unlimited).
+ rw_deletes: Maximum number of deleted records per read-write transaction (0 = unlimited).
+ rw_reads: Maximum number of records read per read-write transaction (0 = unlimited).
+ rw_duration: Maximum duration of read-write transaction (0 = unlimited).
+ rw_delay: Wait time between two read-write transactions.
+ dry_run: Run the garbage collector in dry-run mode.
+ """
+
+ interval: TimeUnit = TimeUnit("1s")
+ threshold: Percent = Percent(80)
+ release: Percent = Percent(10)
+ temp_keys_space: SizeUnit = SizeUnit("0M")
+ rw_deletes: IntNonNegative = IntNonNegative(100)
+ rw_reads: IntNonNegative = IntNonNegative(200)
+ rw_duration: TimeUnit = TimeUnit("0us")
+ rw_delay: TimeUnit = TimeUnit("0us")
+ dry_run: bool = False
+
+
+class PredictionSchema(ConfigSchema):
+ """
+ Helps keep the cache hot by prefetching expiring records and learning usage patterns and repetitive queries.
+
+ ---
+ window: Sampling window length.
+ period: Number of windows that can be kept in memory.
+ """
+
+ window: TimeUnit = TimeUnit("15m")
+ period: IntPositive = IntPositive(24)
+
+
+class PrefetchSchema(ConfigSchema):
+ """
+ These options help keep the cache hot by prefetching expiring records or learning usage patterns and repetitive queries.
+
+ ---
+ expiring: Prefetch expiring records.
+ prediction: Prefetch record by predicting based on usage patterns and repetitive queries.
+ """
+
+ expiring: bool = False
+ prediction: Optional[PredictionSchema] = None
+
+
+class CacheSchema(ConfigSchema):
+ """
+ DNS resolver cache configuration.
+
+ ---
+ storage: Cache storage of the DNS resolver.
+ size_max: Maximum size of the cache.
+ garbage_collector: Use the garbage collector (kres-cache-gc) to periodically clear cache.
+ ttl_min: Minimum time-to-live for the cache entries.
+ ttl_max: Maximum time-to-live for the cache entries.
+ ns_timeout: Time interval for which a nameserver address will be ignored after determining that it does not return (useful) answers.
+ prefill: Prefill the cache periodically by importing zone data obtained over HTTP.
+ prefetch: These options help keep the cache hot by prefetching expiring records or learning usage patterns and repetitive queries.
+ """
+
+ storage: WritableDir = lazy_default(WritableDir, "/var/cache/knot-resolver")
+ size_max: SizeUnit = SizeUnit("100M")
+ garbage_collector: Union[GarbageCollectorSchema, Literal[False]] = GarbageCollectorSchema()
+ ttl_min: TimeUnit = TimeUnit("5s")
+ ttl_max: TimeUnit = TimeUnit("1d")
+ ns_timeout: TimeUnit = TimeUnit("1000ms")
+ prefill: Optional[List[PrefillSchema]] = None
+ prefetch: PrefetchSchema = PrefetchSchema()
+
+ def _validate(self):
+ if self.ttl_min.seconds() >= self.ttl_max.seconds():
+ raise ValueError("'ttl-max' must be larger then 'ttl-min'")
diff --git a/python/knot_resolver/datamodel/config_schema.py b/python/knot_resolver/datamodel/config_schema.py
new file mode 100644
index 00000000..1ee300d8
--- /dev/null
+++ b/python/knot_resolver/datamodel/config_schema.py
@@ -0,0 +1,242 @@
+import logging
+import os
+import socket
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from typing_extensions import Literal
+
+from knot_resolver.manager.constants import MAX_WORKERS
+from knot_resolver.datamodel.cache_schema import CacheSchema
+from knot_resolver.datamodel.dns64_schema import Dns64Schema
+from knot_resolver.datamodel.dnssec_schema import DnssecSchema
+from knot_resolver.datamodel.forward_schema import ForwardSchema
+from knot_resolver.datamodel.local_data_schema import LocalDataSchema, RPZSchema, RuleSchema
+from knot_resolver.datamodel.logging_schema import LoggingSchema
+from knot_resolver.datamodel.lua_schema import LuaSchema
+from knot_resolver.datamodel.management_schema import ManagementSchema
+from knot_resolver.datamodel.monitoring_schema import MonitoringSchema
+from knot_resolver.datamodel.network_schema import NetworkSchema
+from knot_resolver.datamodel.options_schema import OptionsSchema
+from knot_resolver.datamodel.templates import POLICY_CONFIG_TEMPLATE, WORKER_CONFIG_TEMPLATE
+from knot_resolver.datamodel.types import EscapedStr, IntPositive, WritableDir
+from knot_resolver.datamodel.view_schema import ViewSchema
+from knot_resolver.datamodel.webmgmt_schema import WebmgmtSchema
+from knot_resolver.utils.modeling import ConfigSchema
+from knot_resolver.utils.modeling.base_schema import lazy_default
+from knot_resolver.utils.modeling.exceptions import AggregateDataValidationError, DataValidationError
+
+_DEFAULT_RUNDIR = "/var/run/knot-resolver"
+
+DEFAULT_MANAGER_API_SOCK = _DEFAULT_RUNDIR + "/manager.sock"
+
+logger = logging.getLogger(__name__)
+
+
+def _cpu_count() -> Optional[int]:
+ try:
+ return len(os.sched_getaffinity(0))
+ except (NotImplementedError, AttributeError):
+ logger.warning("The number of usable CPUs could not be determined using 'os.sched_getaffinity()'.")
+ cpus = os.cpu_count()
+ if cpus is None:
+ logger.warning("The number of usable CPUs could not be determined using 'os.cpu_count()'.")
+ return cpus
+
+
+def _default_max_worker_count() -> int:
+ c = _cpu_count()
+ if c:
+ return c * 10
+ return MAX_WORKERS
+
+
+def _get_views_tags(views: List[ViewSchema]) -> List[str]:
+ tags = []
+ for view in views:
+ if view.tags:
+ tags += [str(tag) for tag in view.tags if tag not in tags]
+ return tags
+
+
+def _check_local_data_tags(
+ views_tags: List[str], rules_or_rpz: Union[List[RuleSchema], List[RPZSchema]]
+) -> Tuple[List[str], List[DataValidationError]]:
+ tags = []
+ errs = []
+
+ i = 0
+ for rule in rules_or_rpz:
+ tags_not_in = []
+ if rule.tags:
+ for tag in rule.tags:
+ tag_str = str(tag)
+ if tag_str not in tags:
+ tags.append(tag_str)
+ if tag_str not in views_tags:
+ tags_not_in.append(tag_str)
+ if len(tags_not_in) > 0:
+ errs.append(
+ DataValidationError(
+ f"some tags {tags_not_in} not found in '/views' tags", f"/local-data/rules[{i}]/tags"
+ )
+ )
+ i += 1
+ return tags, errs
+
+
+class KresConfig(ConfigSchema):
+ class Raw(ConfigSchema):
+ """
+ Knot Resolver declarative configuration.
+
+ ---
+ version: Version of the configuration schema. By default it is the latest supported by the resolver, but a couple of versions back are supported as well.
+ nsid: Name Server Identifier (RFC 5001) which allows DNS clients to request the resolver to send back its NSID along with the reply to a DNS request.
+ hostname: Internal DNS resolver hostname. Defaults to the machine hostname.
+ rundir: Directory where the resolver can create files and which will be its cwd.
+ workers: The number of running kresd (Knot Resolver daemon) workers. If set to 'auto', it is equal to the number of available CPUs.
+ max_workers: The maximum number of workers allowed. Cannot be changed at runtime.
+ management: Configuration of management HTTP API.
+ webmgmt: Configuration of legacy web management endpoint.
+ options: Fine-tuning global parameters of DNS resolver operation.
+ network: Network connections and protocols configuration.
+ views: List of views and their configuration.
+ local_data: Local data for forward records (A/AAAA) and reverse records (PTR).
+ forward: List of forward zones and their configuration.
+ cache: DNS resolver cache configuration.
+ dnssec: Disable DNSSEC, enable with defaults or set new configuration.
+ dns64: Disable DNS64 (RFC 6147), enable with defaults or set new configuration.
+ logging: Logging and debugging configuration.
+ monitoring: Metrics exposition configuration (Prometheus, Graphite).
+ lua: Custom Lua configuration.
+ """
+
+ version: int = 1
+ nsid: Optional[EscapedStr] = None
+ hostname: Optional[EscapedStr] = None
+ rundir: WritableDir = lazy_default(WritableDir, _DEFAULT_RUNDIR)
+ workers: Union[Literal["auto"], IntPositive] = IntPositive(1)
+ max_workers: IntPositive = IntPositive(_default_max_worker_count())
+ management: ManagementSchema = lazy_default(ManagementSchema, {"unix-socket": DEFAULT_MANAGER_API_SOCK})
+ webmgmt: Optional[WebmgmtSchema] = None
+ options: OptionsSchema = OptionsSchema()
+ network: NetworkSchema = NetworkSchema()
+ views: Optional[List[ViewSchema]] = None
+ local_data: LocalDataSchema = LocalDataSchema()
+ forward: Optional[List[ForwardSchema]] = None
+ cache: CacheSchema = lazy_default(CacheSchema, {})
+ dnssec: Union[bool, DnssecSchema] = True
+ dns64: Union[bool, Dns64Schema] = False
+ logging: LoggingSchema = LoggingSchema()
+ monitoring: MonitoringSchema = MonitoringSchema()
+ lua: LuaSchema = LuaSchema()
+
+ _LAYER = Raw
+
+ nsid: Optional[EscapedStr]
+ hostname: EscapedStr
+ rundir: WritableDir
+ workers: IntPositive
+ max_workers: IntPositive
+ management: ManagementSchema
+ webmgmt: Optional[WebmgmtSchema]
+ options: OptionsSchema
+ network: NetworkSchema
+ views: Optional[List[ViewSchema]]
+ local_data: LocalDataSchema
+ forward: Optional[List[ForwardSchema]]
+ cache: CacheSchema
+ dnssec: Union[Literal[False], DnssecSchema]
+ dns64: Union[Literal[False], Dns64Schema]
+ logging: LoggingSchema
+ monitoring: MonitoringSchema
+ lua: LuaSchema
+
+ def _hostname(self, obj: Raw) -> Any:
+ if obj.hostname is None:
+ return socket.gethostname()
+ return obj.hostname
+
+ def _workers(self, obj: Raw) -> Any:
+ if obj.workers == "auto":
+ count = _cpu_count()
+ if count:
+ return IntPositive(count)
+ raise ValueError(
+ "The number of available CPUs to automatically set the number of running 'kresd' workers could not be determined."
+ "The number of workers can be configured manually in 'workers' option."
+ )
+ return obj.workers
+
+ def _dnssec(self, obj: Raw) -> Any:
+ if obj.dnssec is True:
+ return DnssecSchema()
+ return obj.dnssec
+
+ def _dns64(self, obj: Raw) -> Any:
+ if obj.dns64 is True:
+ return Dns64Schema()
+ return obj.dns64
+
+ def _validate(self) -> None:
+ # enforce max-workers config
+ if int(self.workers) > int(self.max_workers):
+ raise ValueError(f"can't run with more workers then the configured maximum {self.max_workers}")
+
+ # sanity check
+ cpu_count = _cpu_count()
+ if cpu_count and int(self.workers) > 10 * cpu_count:
+ raise ValueError(
+ "refusing to run with more then 10 workers per cpu core, the system wouldn't behave nicely"
+ )
+
+ # get all tags from views
+ views_tags = []
+ if self.views:
+ views_tags = _get_views_tags(self.views)
+
+ # get local-data tags and check its existence in views
+ errs = []
+ local_data_tags = []
+ if self.local_data.rules:
+ rules_tags, rules_errs = _check_local_data_tags(views_tags, self.local_data.rules)
+ errs += rules_errs
+ local_data_tags += rules_tags
+ if self.local_data.rpz:
+ rpz_tags, rpz_errs = _check_local_data_tags(views_tags, self.local_data.rpz)
+ errs += rpz_errs
+ local_data_tags += rpz_tags
+
+ # look for unused tags in /views
+ unused_tags = views_tags.copy()
+ for tag in local_data_tags:
+ if tag in unused_tags:
+ unused_tags.remove(tag)
+ if len(unused_tags) > 0:
+ errs.append(DataValidationError(f"unused tags {unused_tags} found", "/views"))
+
+ # raise all validation errors
+ if len(errs) == 1:
+ raise errs[0]
+ elif len(errs) > 1:
+ raise AggregateDataValidationError("/", errs)
+
+ def render_lua(self) -> str:
+ # FIXME the `cwd` argument is used only for configuring control socket path
+ # it should be removed and relative path used instead as soon as issue
+ # https://gitlab.nic.cz/knot/knot-resolver/-/issues/720 is fixed
+ return WORKER_CONFIG_TEMPLATE.render(cfg=self, cwd=os.getcwd())
+
+ def render_lua_policy(self) -> str:
+ return POLICY_CONFIG_TEMPLATE.render(cfg=self, cwd=os.getcwd())
+
+
+def get_rundir_without_validation(data: Dict[str, Any]) -> WritableDir:
+ """
+ Without fully parsing the config, try to get a rundir from the raw config data, otherwise use the default.
+ Attempts a dir validation to produce a good error message.
+
+ Used for initial manager startup.
+ """
+
+ return WritableDir(data["rundir"] if "rundir" in data else _DEFAULT_RUNDIR, object_path="/rundir")
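A minimal sketch (illustrative only) of the early-startup use of get_rundir_without_validation(); the raw data and path are hypothetical and WritableDir is assumed to still check that the directory exists and is writable:

    from knot_resolver.datamodel.config_schema import get_rundir_without_validation

    # raw data as it would come from a just-parsed YAML/JSON document (hypothetical)
    raw = {"rundir": "/run/knot-resolver", "workers": "auto"}
    rundir = get_rundir_without_validation(raw)
    # falls back to /var/run/knot-resolver when 'rundir' is absent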
diff --git a/python/knot_resolver/datamodel/design-notes.yml b/python/knot_resolver/datamodel/design-notes.yml
new file mode 100644
index 00000000..e4424bc8
--- /dev/null
+++ b/python/knot_resolver/datamodel/design-notes.yml
@@ -0,0 +1,237 @@
+###### Working notes about configuration schema
+
+
+## TODO nit: nest one level deeper inside `dnssec`, probably
+dnssec:
+ keep-removed: 0
+ refresh-time: 10s
+ hold-down-time: 30d
+
+## TODO nit: I don't like this name, at least not for the experimental thing we have there
+network:
+ tls:
+ auto_discovery: boolean
+
+#### General questions
+Plurals: do we name attributes in the plural if they're a list?
+ Some of them even allow a non-list when using a single element.
+
+
+#### New-policy brainstorming
+
+dnssec:
+ # Convert to key: style instead of list?
+ # - easier to handle in API/CLI (which might be a common action on names with broken DNSSEC)
+ # - allows to supply a value - stamp for expiration of that NTA
+ # (absolute time, but I can imagine API/CLI converting from duration when executed)
+ # - syntax isn't really more difficult, mainly it forces one entry per line (seems OK)
+ negative-trust-anchors:
+ example.org:
+ my.example.net:
+
+
+view:
+ # When a client request arrives, based on the `view` class of rules we may either
+ # decide for a direct answer or for marking the request with a set of tags.
+ # The concepts of matching and actions are a very good fit for this,
+ # and that matches our old policy approach. Matching here should avoid QNAME+QTYPE;
+ # instead it's e.g. suitable for access control.
+ # RPZ files also support rules that fall into this `view` class.
+ #
+ # Selecting a single rule: the most specific client-IP prefix
+ # that also matches additional conditions.
+ - subnet: [ 0.0.0.0/0, ::/0 ]
+ answer: refused
+ # some might prefer `allow: refused` ?
+ # Also, RCODEs are customary in CAPITALS though maybe not in configs.
+
+ - subnet: [ 10.0.0.0/8, 192.168.0.0/16 ]
+ # Adding `tags` implies allowing the query.
+ tags: [ t1, t2, t3 ] # theoretically we could use space-separated string
+ options: # only some of the global options can be overridden in view
+ minimize: true
+ dns64: true
+ rate-limit: # future option, probably (optionally?) structured
+ # LATER: rulesets are a relatively unclear feature for now.
+ # Their main point is to allow prioritization and avoid
+ # intermixing rules that come from different sources.
+ # Also some properties might be specifiable per ruleset.
+ ruleset: tt
+
+ - subnet: [ 10.0.10.0/24 ] # maybe allow a single value instead of a list?
+ # LATER: special addresses?
+ # - for kresd-internal requests
+ # - shorthands for all private IPv4 and/or IPv6;
+ # though yaml's repeated nodes could mostly cover that
+ # or just copy&paste from docs
+ answer: allow
+
+# Or perhaps a more complex approach? Probably not.
+# We might have multiple conditions at once and multiple actions at once,
+# but I don't expect these to be common, so the complication is probably not worth it.
+# An advantage would be that the separation of the two parts would be more visible.
+view:
+ - match:
+ subnet: [ 10.0.0.0/8, 192.168.0.0/16 ]
+ do:
+ tags: [ t1, t2, t3 ]
+ options: # ...
+
+
+local-data: # TODO: name
+ #FIXME: tags - allow assigning them to (groups of) addresses/records.
+
+ addresses: # automatically adds PTR records and NODATA (LATER: overridable NODATA?)
+ foo.bar: [ 127.0.0.1, ::1 ]
+ my.pc.corp: 192.168.12.95
+ addresses-files: # files in /etc/hosts format (and semantics like `addresses`)
+ - /etc/hosts
+
+ # Zonefile format seems quite handy here. Details:
+ # - probably use `local-data.ttl` from model as the default
+ # - and . root to avoid confusion if someone misses a final dot.
+ records: |
+ example.net. TXT "foo bar"
+ A 192.168.2.3
+ A 192.168.2.4
+ local.example.org AAAA ::1
+
+ subtrees:
+ nodata: true # impl ATM: defaults to false, set (only) for each rule/name separately
+ # impl: options like `ttl` and `nodata` might make sense to be settable (only?) per ruleset
+
+ subtrees: # TODO: perhaps just allow in the -tagged style, if we can't avoid lists anyway?
+ - type: empty
+ roots: [ sub2.example.org ] # TODO: name it the same as for forwarding
+ tags: [ t2 ]
+ - type: nxdomain
+ # Will we need to support multiple file formats in future and choose here?
+ roots-file: /path/to/file.txt
+ - type: empty
+ roots-url: https://example.org/blocklist.txt
+ refresh: 1d
+ # Is it a separate ruleset? Optionally? Persistence?
+ # (probably the same questions for local files as well)
+
+ - type: redirect
+ roots: [ sub4.example.org ]
+ addresses: [ 127.0.0.1, ::1 ]
+
+local-data-tagged: # TODO: name (view?); and even structure seems unclear.
+ # TODO: allow only one "type" per list entry? (addresses / addresses-files / subtrees / ...)
+ - tags: [ t1, t2 ]
+ addresses: #... otherwise the same as local-data
+ - tags: [ t2 ]
+ records: # ...
+ - tags: [ t3 ]
+ subtrees: empty
+ roots: [ sub2.example.org ]
+
+local-data-tagged: # this avoids lists, so it's relatively easy to amend through API
+ "t1 t2": # perhaps it's not nice that tags don't form a proper list?
+ addresses:
+ foo.bar: [ 127.0.0.1, ::1 ]
+ t4:
+ addresses:
+ foo.bar: [ 127.0.0.1, ::1 ]
+local-data: # avoids lists and merges into the untagged `local-data` config subtree
+ tagged: # (getting quite deep, though)
+ t1 t2:
+ addresses:
+ foo.bar: [ 127.0.0.1, ::1 ]
+# or even this ugly thing:
+local-data-tagged t1 t2:
+ addresses:
+ foo.bar: [ 127.0.0.1, ::1 ]
+
+forward: # TODO: "name" is from Unbound, but @vcunat would prefer "subtree" or something.
+ - name: '.' # Root is the default so could be omitted?
+ servers: [2001:148f:fffe::1, 2001:148f:ffff::1, 185.43.135.1, 193.14.47.1]
+ # TLS forward, server authenticated using hostname and system-wide CA certificates
+ # https://www.knot-resolver.cz/documentation/latest/modules-policy.html?highlight=forward#tls-examples
+ - name: '.'
+ servers:
+ - address: [ 192.0.2.1, 192.0.2.2@5353 ]
+ transport: tls
+ pin-sha256: Wg==
+ - address: 2001:DB8::d0c
+ transport: tls
+ hostname: res.example.com
+ ca-file: /etc/knot-resolver/tlsca.crt
+ options:
+ # LATER: allow a subset of options here, per sub-tree?
+ # Though that's not necessarily related to forwarding (e.g. TTL limits),
+ # especially implementation-wise it probably won't matter.
+
+
+# Too confusing approach, I suppose? Different from usual way of thinking but closer to internal model.
+# Down-sides:
+# - multiple rules for the same name won't be possible (future, with different tags)
+# - loading names from a file won't be possible (or URL, etc.)
+rules:
+ example.org: &fwd_odvr
+ type: forward
+ servers: [2001:148f:fffe::1, 2001:148f:ffff::1, 185.43.135.1, 193.14.47.1]
+ sub2.example.org:
+ type: empty
+ tags: [ t3, t5 ]
+ sub3.example.org:
+ type: forward-auth
+ dnssec: no
+
+
+# @amrazek: current valid config
+
+views:
+ - subnets: [ 0.0.0.0/0, "::/0" ]
+ answer: refused
+ - subnets: [ 0.0.0.0/0, "::/0" ]
+ tags: [t01, t02, t03]
+ options:
+ minimize: true # default
+ dns64: true # default
+ - subnets: 10.0.10.0/24 # can be single value
+ answer: allow
+
+local-data:
+ ttl: 1d
+ nodata: true
+ addresses:
+ foo.bar: [ 127.0.0.1, "::1" ]
+ my.pc.corp: 192.168.12.95
+ addresses-files:
+ - /etc/hosts
+ records: |
+ example.net. TXT "foo bar"
+ A 192.168.2.3
+ A 192.168.2.4
+ local.example.org AAAA ::1
+ subtrees:
+ - type: empty
+ roots: [ sub2.example.org ]
+ tags: [ t2 ]
+ - type: nxdomain
+ roots-file: /path/to/file.txt
+ - type: empty
+ roots-url: https://example.org/blocklist.txt
+ refresh: 1d
+ - type: redirect
+ roots: [ sub4.example.org ]
+ addresses: [ 127.0.0.1, "::1" ]
+
+forward:
+ - subtree: '.'
+ servers:
+ - address: [ 192.0.2.1, 192.0.2.2@5353 ]
+ transport: tls
+ pin-sha256: Wg==
+ - address: 2001:DB8::d0c
+ transport: tls
+ hostname: res.example.com
+ ca-file: /etc/knot-resolver/tlsca.crt
+ options:
+ dnssec: true # default
+ - subtree: 1.168.192.in-addr.arpa
+ servers: [ 192.0.2.1@5353 ]
+ options:
+ dnssec: false # policy.STUB?
diff --git a/python/knot_resolver/datamodel/dns64_schema.py b/python/knot_resolver/datamodel/dns64_schema.py
new file mode 100644
index 00000000..cc0fa06a
--- /dev/null
+++ b/python/knot_resolver/datamodel/dns64_schema.py
@@ -0,0 +1,19 @@
+from typing import List, Optional
+
+from knot_resolver.datamodel.types import IPv6Network, IPv6Network96, TimeUnit
+from knot_resolver.utils.modeling import ConfigSchema
+
+
+class Dns64Schema(ConfigSchema):
+ """
+ DNS64 (RFC 6147) configuration.
+
+ ---
+ prefix: IPv6 prefix to be used for synthesizing AAAA records.
+ rev_ttl: TTL in CNAME generated in the reverse 'ip6.arpa.' subtree.
+ exclude_subnets: IPv6 subnets that are disallowed in answer.
+ """
+
+ prefix: IPv6Network96 = IPv6Network96("64:ff9b::/96")
+ rev_ttl: Optional[TimeUnit] = None
+ exclude_subnets: Optional[List[IPv6Network]] = None
diff --git a/python/knot_resolver/datamodel/dnssec_schema.py b/python/knot_resolver/datamodel/dnssec_schema.py
new file mode 100644
index 00000000..6f51d5eb
--- /dev/null
+++ b/python/knot_resolver/datamodel/dnssec_schema.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from knot_resolver.datamodel.types import DomainName, EscapedStr, IntNonNegative, ReadableFile, TimeUnit
+from knot_resolver.utils.modeling import ConfigSchema
+
+
+class TrustAnchorFileSchema(ConfigSchema):
+ """
+ Trust-anchor zonefile configuration.
+
+ ---
+ file: Path to the zonefile that stores trust-anchors.
+ read_only: Blocks zonefile updates according to RFC 5011.
+
+ """
+
+ file: ReadableFile
+ read_only: bool = False
+
+
+class DnssecSchema(ConfigSchema):
+ """
+ DNSSEC configuration.
+
+ ---
+ trust_anchor_sentinel: Allows users of a DNSSEC-validating resolver to detect which root keys are configured in the resolver's chain of trust (RFC 8509).
+ trust_anchor_signal_query: Signaling Trust Anchor Knowledge in DNSSEC using Key Tag Query (RFC 8145, section 5).
+ time_skew_detection: Detection of a difference between the local system time and the expiration time bounds in DNSSEC signatures for '. NS' records.
+ keep_removed: How many removed keys should be held in history (and key file) before being purged.
+ refresh_time: Force trust-anchors to be updated at the given interval instead of relying on RFC 5011 logic and TTLs. Intended only for testing purposes.
+ hold_down_time: Modify hold-down timer (RFC 5011). Intended only for testing purposes.
+ trust_anchors: List of trust-anchors in DS/DNSKEY records format.
+ negative_trust_anchors: List of domain names representing negative trust-anchors. (RFC 7646)
+ trust_anchors_files: List of zonefiles where trust-anchors are stored.
+ """
+
+ trust_anchor_sentinel: bool = True
+ trust_anchor_signal_query: bool = True
+ time_skew_detection: bool = True
+ keep_removed: IntNonNegative = IntNonNegative(0)
+ refresh_time: Optional[TimeUnit] = None
+ hold_down_time: TimeUnit = TimeUnit("30d")
+ trust_anchors: Optional[List[EscapedStr]] = None
+ negative_trust_anchors: Optional[List[DomainName]] = None
+ trust_anchors_files: Optional[List[TrustAnchorFileSchema]] = None
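As an illustration only (assuming the hyphenated-key dict constructor used for other schemas in this patch), the key-history and NTA options above map to something like:

    from knot_resolver.datamodel.dnssec_schema import DnssecSchema

    dnssec = DnssecSchema({
        "keep-removed": 2,
        "negative-trust-anchors": ["example.org", "my.example.net"],
    })
    print(dnssec.hold_down_time.seconds())  # the default '30d' expressed in seconds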
diff --git a/python/knot_resolver/datamodel/forward_schema.py b/python/knot_resolver/datamodel/forward_schema.py
new file mode 100644
index 00000000..96c0f048
--- /dev/null
+++ b/python/knot_resolver/datamodel/forward_schema.py
@@ -0,0 +1,84 @@
+from typing import Any, List, Optional, Union
+
+from typing_extensions import Literal
+
+from knot_resolver.datamodel.types import (
+ DomainName,
+ IPAddressOptionalPort,
+ ListOrItem,
+ PinSha256,
+ ReadableFile,
+)
+from knot_resolver.utils.modeling import ConfigSchema
+
+
+class ForwardServerSchema(ConfigSchema):
+ """
+ Forward server configuration.
+
+ ---
+ address: IP address(es) of a forward server.
+ transport: Transport protocol for a forward server.
+ pin_sha256: Hash of accepted CA certificate.
+ hostname: Hostname of the Forward server.
+ ca_file: Path to CA certificate file.
+ """
+
+ address: ListOrItem[IPAddressOptionalPort]
+ transport: Optional[Literal["tls"]] = None
+ pin_sha256: Optional[ListOrItem[PinSha256]] = None
+ hostname: Optional[DomainName] = None
+ ca_file: Optional[ReadableFile] = None
+
+ def _validate(self) -> None:
+ if self.pin_sha256 and (self.hostname or self.ca_file):
+ raise ValueError("'pin-sha256' cannot be configurad together with 'hostname' or 'ca-file'")
+
+
+class ForwardOptionsSchema(ConfigSchema):
+ """
+ Subtree(s) forward options.
+
+ ---
+ authoritative: The forwarding target is an authoritative server.
+ dnssec: Enable/disable DNSSEC.
+ """
+
+ authoritative: bool = False
+ dnssec: bool = True
+
+
+class ForwardSchema(ConfigSchema):
+ """
+ Configuration of forward subtree.
+
+ ---
+ subtree: Subtree(s) to forward.
+ servers: Forward servers configuration.
+ options: Subtree(s) forward options.
+ """
+
+ subtree: ListOrItem[DomainName]
+ servers: Union[List[IPAddressOptionalPort], List[ForwardServerSchema]]
+ options: ForwardOptionsSchema = ForwardOptionsSchema()
+
+ def _validate(self) -> None:
+ def is_port_custom(servers: List[Any]) -> bool:
+     # check every listed server, including addresses nested in ForwardServerSchema
+     for server in servers:
+         if isinstance(server, IPAddressOptionalPort) and server.port and int(server.port) != 53:
+             return True
+         if isinstance(server, ForwardServerSchema) and is_port_custom(server.address.to_std()):
+             return True
+     return False
+
+ def is_transport_tls(servers: List[Any]) -> bool:
+     return any(isinstance(server, ForwardServerSchema) and server.transport == "tls" for server in servers)
+
+ if self.options.authoritative and is_port_custom(self.servers):
+ raise ValueError("Forwarding to authoritative servers on a custom port is currently not supported.")
+
+ if self.options.authoritative and is_transport_tls(self.servers):
+ raise ValueError("Forwarding to authoritative servers using TLS protocol is not supported.")
diff --git a/python/knot_resolver/datamodel/globals.py b/python/knot_resolver/datamodel/globals.py
new file mode 100644
index 00000000..610323fa
--- /dev/null
+++ b/python/knot_resolver/datamodel/globals.py
@@ -0,0 +1,57 @@
+"""
+The parsing and validation of the datamodel is dependent on a global state:
+- a file system path used for resolving relative paths
+
+
+Commentary from @vsraier:
+=========================
+
+While this is not ideal, it is the best we can do at the moment. When I created this module,
+the datamodel was dependent on the global state implicitly. The validation procedures just read
+the current working directory. This module is the first step in removing the global dependency.
+
+At some point in the future, it might be interesting to add something like a "validation context"
+to the modelling tools. It is not technically complicated, but it requires
+massive model changes I am not willing to make at the moment. Ideally, when implementing this,
+the BaseSchema would turn into an empty class without any logic. Not even a constructor. All logic
+would be in the ObjectMapper class. Similar to how Gson works in Java or AutoMapper in C#.
+"""
+
+from pathlib import Path
+from typing import Optional
+
+
+class Context:
+ resolve_root: Optional[Path]
+ strict_validation: bool
+
+ def __init__(self, resolve_root: Optional[Path], strict_validation: bool = True) -> None:
+ self.resolve_root = resolve_root
+ self.strict_validation = strict_validation
+
+
+_global_context: Context = Context(None)
+
+
+def set_global_validation_context(context: Context) -> None:
+ global _global_context
+ _global_context = context
+
+
+def reset_global_validation_context() -> None:
+ global _global_context
+ _global_context = Context(None)
+
+
+def get_resolve_root() -> Path:
+ if _global_context.resolve_root is None:
+ raise RuntimeError(
+ "Global validation context 'resolve_root' is not set!"
+ " Before validation, you have to set it using `set_global_validation_context()` function!"
+ )
+
+ return _global_context.resolve_root
+
+
+def get_strict_validation() -> bool:
+ return _global_context.strict_validation
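The module above is self-contained, so its intended use can be sketched directly; the path is hypothetical:

    from pathlib import Path

    from knot_resolver.datamodel.globals import (
        Context,
        reset_global_validation_context,
        set_global_validation_context,
    )

    # resolve relative paths in a config file against its own directory
    set_global_validation_context(Context(Path("/etc/knot-resolver"), strict_validation=False))
    try:
        pass  # parse and validate a configuration document here
    finally:
        reset_global_validation_context()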
diff --git a/python/knot_resolver/datamodel/local_data_schema.py b/python/knot_resolver/datamodel/local_data_schema.py
new file mode 100644
index 00000000..2fe5a03a
--- /dev/null
+++ b/python/knot_resolver/datamodel/local_data_schema.py
@@ -0,0 +1,95 @@
+from typing import Dict, List, Optional
+
+from typing_extensions import Literal
+
+from knot_resolver.datamodel.types import (
+ DomainName,
+ EscapedStr,
+ IDPattern,
+ IPAddress,
+ ListOrItem,
+ ReadableFile,
+ TimeUnit,
+)
+from knot_resolver.utils.modeling import ConfigSchema
+
+
+class RuleSchema(ConfigSchema):
+ """
+ Local data advanced rule configuration.
+
+ ---
+ name: Hostname(s).
+ subtree: Type of subtree.
+ address: Address(es) to pair with hostname(s).
+ file: Path to file(s) with hostname and IP address(es) pairs in '/etc/hosts' like format.
+ records: Direct addition of records in DNS zone file format.
+ tags: Tags to link with other policy rules.
+ ttl: Optional, TTL value used for these answers.
+ nodata: Optional, use NODATA synthesis. NODATA will be synthesised for a matching name but mismatching type (e.g. AAAA query when only A exists).
+ """
+
+ name: Optional[ListOrItem[DomainName]] = None
+ subtree: Optional[Literal["empty", "nxdomain", "redirect"]] = None
+ address: Optional[ListOrItem[IPAddress]] = None
+ file: Optional[ListOrItem[ReadableFile]] = None
+ records: Optional[EscapedStr] = None
+ tags: Optional[List[IDPattern]] = None
+ ttl: Optional[TimeUnit] = None
+ nodata: Optional[bool] = None
+
+ def _validate(self) -> None:
+ options_sum = sum([bool(self.address), bool(self.subtree), bool(self.file), bool(self.records)])
+ if options_sum == 2 and bool(self.address) and self.subtree in {"empty", "redirect"}:
+ pass # these combinations still make sense
+ elif options_sum > 1:
+ raise ValueError("only one of 'address', 'subtree' or 'file' can be configured")
+ elif options_sum < 1:
+ raise ValueError("one of 'address', 'subtree', 'file' or 'records' must be configured")
+
+ options_sum2 = sum([bool(self.name), bool(self.file), bool(self.records)])
+ if options_sum2 != 1:
+ raise ValueError("one of 'name', 'file or 'records' must be configured")
+
+ if bool(self.nodata) and bool(self.subtree) and not bool(self.address):
+ raise ValueError("'nodata' defined but unused with 'subtree'")
+
+
+class RPZSchema(ConfigSchema):
+ """
+ Configuration of a Response Policy Zone (RPZ).
+
+ ---
+ file: Path to the RPZ zone file.
+ tags: Tags to link with other policy rules.
+ """
+
+ file: ReadableFile
+ tags: Optional[List[IDPattern]] = None
+
+
+class LocalDataSchema(ConfigSchema):
+ """
+ Local data for forward records (A/AAAA) and reverse records (PTR).
+
+ ---
+ ttl: Default TTL value used for added local data/records.
+ nodata: Use NODATA synthesis. NODATA will be synthesised for a matching name but mismatching type (e.g. AAAA query when only A exists).
+ root_fallback_addresses: Direct replacement of root hints.
+ root_fallback_addresses_files: Direct replacement of root hints from a zonefile.
+ addresses: Direct addition of hostname and IP address pairs.
+ addresses_files: Direct addition of hostname and IP address pairs from files in '/etc/hosts' like format.
+ records: Direct addition of records in DNS zone file format.
+ rules: Local data rules.
+ rpz: List of Response Policy Zones and their configuration.
+ """
+
+ ttl: Optional[TimeUnit] = None
+ nodata: bool = True
+ root_fallback_addresses: Optional[Dict[DomainName, ListOrItem[IPAddress]]] = None
+ root_fallback_addresses_files: Optional[List[ReadableFile]] = None
+ addresses: Optional[Dict[DomainName, ListOrItem[IPAddress]]] = None
+ addresses_files: Optional[List[ReadableFile]] = None
+ records: Optional[EscapedStr] = None
+ rules: Optional[List[RuleSchema]] = None
+ rpz: Optional[List[RPZSchema]] = None
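To illustrate the RuleSchema cross-checks above (values borrowed from the design notes, constructor behaviour assumed as elsewhere in this patch): a redirect subtree may carry addresses, but most other field combinations are mutually exclusive.

    from knot_resolver.datamodel.local_data_schema import RuleSchema

    # 'redirect' plus 'address' is one of the allowed two-field combinations
    rule = RuleSchema({
        "name": "sub4.example.org",
        "subtree": "redirect",
        "address": ["127.0.0.1", "::1"],
    })

    # rejected by _validate(): 'records' cannot be combined with 'address'
    # RuleSchema({"name": "foo.bar", "address": "127.0.0.1", "records": "foo.bar A 192.168.2.3"})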
diff --git a/python/knot_resolver/datamodel/logging_schema.py b/python/knot_resolver/datamodel/logging_schema.py
new file mode 100644
index 00000000..e2985dd1
--- /dev/null
+++ b/python/knot_resolver/datamodel/logging_schema.py
@@ -0,0 +1,153 @@
+import os
+from typing import Any, List, Optional, Set, Type, Union, cast
+
+from typing_extensions import Literal
+
+from knot_resolver.datamodel.types import TimeUnit, WritableFilePath
+from knot_resolver.utils.modeling import ConfigSchema
+from knot_resolver.utils.modeling.base_schema import is_obj_type_valid
+
+try:
+ # On Debian 10, the typing_extensions library does not contain TypeAlias.
+ # We don't strictly need the import for anything except for type checking,
+ # so this try-except makes sure it works either way.
+ from typing_extensions import TypeAlias # pylint: disable=ungrouped-imports
+except ImportError:
+ TypeAlias = None # type: ignore
+
+
+LogLevelEnum = Literal["crit", "err", "warning", "notice", "info", "debug"]
+LogTargetEnum = Literal["syslog", "stderr", "stdout"]
+LogGroupsEnum: TypeAlias = Literal[
+ "manager",
+ "supervisord",
+ "cache-gc",
+ ## Now the LOG_GRP_*_TAG defines, exactly from ../../../lib/log.h
+ "system",
+ "cache",
+ "io",
+ "net",
+ "ta",
+ "tasent",
+ "tasign",
+ "taupd",
+ "tls",
+ "gnutls",
+ "tls_cl",
+ "xdp",
+ "doh",
+ "dnssec",
+ "hint",
+ "plan",
+ "iterat",
+ "valdtr",
+ "resolv",
+ "select",
+ "zoncut",
+ "cookie",
+ "statis",
+ "rebind",
+ "worker",
+ "policy",
+ "daf",
+ "timejm",
+ "timesk",
+ "graphi",
+ "prefil",
+ "primin",
+ "srvstl",
+ "wtchdg",
+ "nsid",
+ "dnstap",
+ "tests",
+ "dotaut",
+ "http",
+ "contrl",
+ "module",
+ "devel",
+ "renum",
+ "exterr",
+ "rules",
+ "prlayr",
+ # "reqdbg",... (non-displayed section of the enum)
+]
+
+
+class DnstapSchema(ConfigSchema):
+ """
+ Logging DNS queries and responses to a unix socket.
+
+ ---
+ unix_socket: Path to unix domain socket where dnstap messages will be sent.
+ log_queries: Log queries from downstream in wire format.
+ log_responses: Log responses to downstream in wire format.
+ log_tcp_rtt: Log TCP RTT (Round-trip time).
+ """
+
+ unix_socket: WritableFilePath
+ log_queries: bool = True
+ log_responses: bool = True
+ log_tcp_rtt: bool = True
+
+
+class DebuggingSchema(ConfigSchema):
+ """
+ Advanced debugging parameters for kresd (Knot Resolver daemon).
+
+ ---
+ assertion_abort: Allow the process to be aborted in case it encounters a failed assertion.
+ assertion_fork: Fork and abort the child kresd process to obtain a coredump, while the parent process recovers and keeps running.
+ """
+
+ assertion_abort: bool = False
+ assertion_fork: TimeUnit = TimeUnit("5m")
+
+
+class LoggingSchema(ConfigSchema):
+ class Raw(ConfigSchema):
+ """
+ Logging and debugging configuration.
+
+ ---
+ level: Global logging level.
+ target: Global logging stream target. "from-env" uses $KRES_LOGGING_TARGET and defaults to "stdout".
+ groups: List of groups for which 'debug' logging level is set.
+ dnssec_bogus: Logging a message for each DNSSEC validation failure.
+ dnstap: Logging DNS requests and responses to a unix socket.
+ debugging: Advanced debugging parameters for kresd (Knot Resolver daemon).
+ """
+
+ level: LogLevelEnum = "notice"
+ target: Union[LogTargetEnum, Literal["from-env"]] = "from-env"
+ groups: Optional[List[LogGroupsEnum]] = None
+ dnssec_bogus: bool = False
+ dnstap: Union[Literal[False], DnstapSchema] = False
+ debugging: DebuggingSchema = DebuggingSchema()
+
+ _LAYER = Raw
+
+ level: LogLevelEnum
+ target: LogTargetEnum
+ groups: Optional[List[LogGroupsEnum]]
+ dnssec_bogus: bool
+ dnstap: Union[Literal[False], DnstapSchema]
+ debugging: DebuggingSchema
+
+ def _target(self, raw: Raw) -> LogTargetEnum:
+ if raw.target == "from-env":
+ target = os.environ.get("KRES_LOGGING_TARGET") or "stdout"
+ if not is_obj_type_valid(target, cast(Type[Any], LogTargetEnum)):
+ raise ValueError(f"logging target '{target}' read from $KRES_LOGGING_TARGET is invalid")
+ return cast(LogTargetEnum, target)
+ else:
+ return raw.target
+
+ def _validate(self):
+ if self.groups is None:
+ return
+
+ checked: Set[str] = set()
+ for i, g in enumerate(self.groups):
+ if g in checked:
+ raise ValueError(f"duplicate logging group '{g}' on index {i}")
+ checked.add(g)
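A small sketch of the "from-env" target resolution above; it assumes the no-argument constructor goes through the Raw layer, as its use as a default in config_schema.py suggests:

    import os

    from knot_resolver.datamodel.logging_schema import LoggingSchema

    os.environ["KRES_LOGGING_TARGET"] = "stderr"
    logging_cfg = LoggingSchema()   # 'target' defaults to "from-env"
    print(logging_cfg.target)       # resolved to "stderr" by _target()

    # an unknown value in $KRES_LOGGING_TARGET (e.g. "journal") is rejected by _target()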
diff --git a/python/knot_resolver/datamodel/lua_schema.py b/python/knot_resolver/datamodel/lua_schema.py
new file mode 100644
index 00000000..56e8ee09
--- /dev/null
+++ b/python/knot_resolver/datamodel/lua_schema.py
@@ -0,0 +1,23 @@
+from typing import Optional
+
+from knot_resolver.datamodel.types import ReadableFile
+from knot_resolver.utils.modeling import ConfigSchema
+
+
+class LuaSchema(ConfigSchema):
+ """
+ Custom Lua configuration.
+
+ ---
+ script_only: Ignore declarative configuration and use only Lua script or file defined in this section.
+ script: Custom Lua configuration script.
+ script_file: Path to file that contains Lua configuration script.
+ """
+
+ script_only: bool = False
+ script: Optional[str] = None
+ script_file: Optional[ReadableFile] = None
+
+ def _validate(self) -> None:
+ if self.script and self.script_file:
+ raise ValueError("'lua.script' and 'lua.script-file' are both defined, only one can be used")
diff --git a/python/knot_resolver/datamodel/management_schema.py b/python/knot_resolver/datamodel/management_schema.py
new file mode 100644
index 00000000..b338c32a
--- /dev/null
+++ b/python/knot_resolver/datamodel/management_schema.py
@@ -0,0 +1,21 @@
+from typing import Optional
+
+from knot_resolver.datamodel.types import WritableFilePath, IPAddressPort
+from knot_resolver.utils.modeling import ConfigSchema
+
+
+class ManagementSchema(ConfigSchema):
+ """
+ Configuration of management HTTP API.
+
+ ---
+ unix_socket: Path to unix domain socket to listen to.
+ interface: IP address and port number to listen to.
+ """
+
+ unix_socket: Optional[WritableFilePath] = None
+ interface: Optional[IPAddressPort] = None
+
+ def _validate(self) -> None:
+ if bool(self.unix_socket) == bool(self.interface):
+ raise ValueError("One of 'interface' or 'unix-socket' must be configured.")
diff --git a/python/knot_resolver/datamodel/monitoring_schema.py b/python/knot_resolver/datamodel/monitoring_schema.py
new file mode 100644
index 00000000..3b3ad6d9
--- /dev/null
+++ b/python/knot_resolver/datamodel/monitoring_schema.py
@@ -0,0 +1,25 @@
+from typing import Union
+
+from typing_extensions import Literal
+
+from knot_resolver.datamodel.types import DomainName, EscapedStr, IPAddress, PortNumber, TimeUnit
+from knot_resolver.utils.modeling import ConfigSchema
+
+
+class GraphiteSchema(ConfigSchema):
+ host: Union[IPAddress, DomainName]
+ port: PortNumber = PortNumber(2003)
+ prefix: EscapedStr = EscapedStr("")
+ interval: TimeUnit = TimeUnit("5s")
+ tcp: bool = False
+
+
+class MonitoringSchema(ConfigSchema):
+ """
+ ---
+ enabled: Configures whether the statistics module will be loaded into the resolver.
+ graphite: Optionally configures where Graphite metrics should be sent.
+ """
+
+ enabled: Literal["manager-only", "lazy", "always"] = "lazy"
+ graphite: Union[Literal[False], GraphiteSchema] = False
diff --git a/python/knot_resolver/datamodel/network_schema.py b/python/knot_resolver/datamodel/network_schema.py
new file mode 100644
index 00000000..e766d499
--- /dev/null
+++ b/python/knot_resolver/datamodel/network_schema.py
@@ -0,0 +1,181 @@
+from typing import List, Optional, Union
+
+from typing_extensions import Literal
+
+from knot_resolver.datamodel.types import (
+ EscapedStr32B,
+ WritableFilePath,
+ Int0_512,
+ Int0_65535,
+ InterfaceOptionalPort,
+ IPAddress,
+ IPAddressEM,
+ IPNetwork,
+ IPv4Address,
+ IPv6Address,
+ ListOrItem,
+ PortNumber,
+ ReadableFile,
+ SizeUnit,
+)
+from knot_resolver.utils.modeling import ConfigSchema
+
+KindEnum = Literal["dns", "xdp", "dot", "doh-legacy", "doh2"]
+
+
+class EdnsBufferSizeSchema(ConfigSchema):
+ """
+ EDNS payload size advertised in DNS packets.
+
+ ---
+ upstream: Maximum EDNS upstream (towards other DNS servers) payload size.
+ downstream: Maximum EDNS downstream (towards clients) payload size.
+ """
+
+ upstream: SizeUnit = SizeUnit("1232B")
+ downstream: SizeUnit = SizeUnit("1232B")
+
+
+class AddressRenumberingSchema(ConfigSchema):
+ """
+ Renumbers addresses in answers to different address space.
+
+ ---
+ source: Source subnet.
+ destination: Destination address prefix.
+ """
+
+ source: IPNetwork
+ destination: Union[IPAddressEM, IPAddress]
+
+
+class TLSSchema(ConfigSchema):
+ """
+ TLS configuration, also affects DNS over TLS and DNS over HTTPS.
+
+ ---
+ cert_file: Path to certificate file.
+ key_file: Path to certificate key file.
+ sticket_secret: Secret for TLS session resumption via tickets. (RFC 5077).
+ sticket_secret_file: Path to file with secret for TLS session resumption via tickets. (RFC 5077).
+ auto_discovery: Experimental automatic discovery of authoritative servers supporting DNS-over-TLS.
+ padding: EDNS(0) padding of queries and answers sent over an encrypted channel.
+ """
+
+ cert_file: Optional[ReadableFile] = None
+ key_file: Optional[ReadableFile] = None
+ sticket_secret: Optional[EscapedStr32B] = None
+ sticket_secret_file: Optional[ReadableFile] = None
+ auto_discovery: bool = False
+ padding: Union[bool, Int0_512] = True
+
+ def _validate(self):
+ if self.sticket_secret and self.sticket_secret_file:
+ raise ValueError("'sticket_secret' and 'sticket_secret_file' are both defined, only one can be used")
+
+
+class ListenSchema(ConfigSchema):
+ class Raw(ConfigSchema):
+ """
+ Configuration of listening interface.
+
+ ---
+ unix_socket: Path to unix domain socket to listen to.
+ interface: IP address or interface name with optional port number to listen to.
+ port: Port number to listen to.
+ kind: Specifies DNS query transport protocol.
+ freebind: Used for binding to non-local address.
+ """
+
+ interface: Optional[ListOrItem[InterfaceOptionalPort]] = None
+ unix_socket: Optional[ListOrItem[WritableFilePath]] = None
+ port: Optional[PortNumber] = None
+ kind: KindEnum = "dns"
+ freebind: bool = False
+
+ _LAYER = Raw
+
+ interface: Optional[ListOrItem[InterfaceOptionalPort]]
+ unix_socket: Optional[ListOrItem[WritableFilePath]]
+ port: Optional[PortNumber]
+ kind: KindEnum
+ freebind: bool
+
+ def _interface(self, origin: Raw) -> Optional[ListOrItem[InterfaceOptionalPort]]:
+ if origin.interface:
+ port_set: Optional[bool] = None
+ for intrfc in origin.interface: # type: ignore[attr-defined]
+ if origin.port and intrfc.port:
+ raise ValueError("The port number is defined in two places ('port' option and '@<port>' syntax).")
+ if port_set is not None and (bool(intrfc.port) != port_set):
+ raise ValueError(
+ "The '@<port>' syntax must be used either for all or none of the interface in the list."
+ )
+ port_set = bool(intrfc.port)
+ return origin.interface
+
+ def _port(self, origin: Raw) -> Optional[PortNumber]:
+ if origin.port:
+ return origin.port
+ # default port number based on kind
+ elif origin.interface:
+ if origin.kind == "dot":
+ return PortNumber(853)
+ elif origin.kind in ["doh-legacy", "doh2"]:
+ return PortNumber(443)
+ return PortNumber(53)
+ return None
+
+ def _validate(self) -> None:
+ if bool(self.unix_socket) == bool(self.interface):
+ raise ValueError("One of 'interface' or 'unix-socket' must be configured.")
+ if self.port and self.unix_socket:
+ raise ValueError(
+ "'unix-socket' and 'port' are not compatible options."
+ " Port configuration can only be used with 'interface' option."
+ )
+
+
+class ProxyProtocolSchema(ConfigSchema):
+ """
+ PROXYv2 protocol configuration.
+
+ ---
+ allow: Allow usage of the PROXYv2 protocol headers by clients on the specified addresses.
+ """
+
+ allow: List[Union[IPAddress, IPNetwork]]
+
+
+class NetworkSchema(ConfigSchema):
+ """
+ Network connections and protocols configuration.
+
+ ---
+ do_ipv4: Enable/disable using IPv4 for contacting upstream nameservers.
+ do_ipv6: Enable/disable using IPv6 for contacting upstream nameservers.
+ out_interface_v4: IPv4 address used to perform queries. Not set by default, which lets the OS choose any address.
+ out_interface_v6: IPv6 address used to perform queries. Not set by default, which lets the OS choose any address.
+ tcp_pipeline: TCP pipeline limit. The number of outstanding queries that a single client connection can make in parallel.
+ edns_tcp_keepalive: Allows clients to discover the connection timeout. (RFC 7828)
+ edns_buffer_size: Maximum EDNS payload size advertised in DNS packets. Different values can be configured for communication downstream (towards clients) and upstream (towards other DNS servers).
+ address_renumbering: Renumbers addresses in answers to different address space.
+ tls: TLS configuration, also affects DNS over TLS and DNS over HTTPS.
+ proxy_protocol: PROXYv2 protocol configuration.
+ listen: List of interfaces to listen to and their configuration.
+ """
+
+ do_ipv4: bool = True
+ do_ipv6: bool = True
+ out_interface_v4: Optional[IPv4Address] = None
+ out_interface_v6: Optional[IPv6Address] = None
+ tcp_pipeline: Int0_65535 = Int0_65535(100)
+ edns_tcp_keepalive: bool = True
+ edns_buffer_size: EdnsBufferSizeSchema = EdnsBufferSizeSchema()
+ address_renumbering: Optional[List[AddressRenumberingSchema]] = None
+ tls: TLSSchema = TLSSchema()
+ proxy_protocol: Union[Literal[False], ProxyProtocolSchema] = False
+ listen: List[ListenSchema] = [
+ ListenSchema({"interface": "127.0.0.1"}),
+ ListenSchema({"interface": "::1", "freebind": True}),
+ ]
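Sketch of the listen-interface defaulting above (illustrative only, same constructor assumptions as for the other schemas in this patch):

    from knot_resolver.datamodel.network_schema import ListenSchema

    dot = ListenSchema({"interface": "::1", "kind": "dot"})
    print(dot.port)  # 853, filled in by _port() when no explicit port is given

    # rejected by _interface(): the port may be set either via 'port' or via '@<port>', not both
    # ListenSchema({"interface": "127.0.0.1@53", "port": 5353})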
diff --git a/python/knot_resolver/datamodel/options_schema.py b/python/knot_resolver/datamodel/options_schema.py
new file mode 100644
index 00000000..9230c7f0
--- /dev/null
+++ b/python/knot_resolver/datamodel/options_schema.py
@@ -0,0 +1,36 @@
+from typing_extensions import Literal
+
+from knot_resolver.utils.modeling import ConfigSchema
+
+GlueCheckingEnum = Literal["normal", "strict", "permissive"]
+
+
+class OptionsSchema(ConfigSchema):
+ """
+ Fine-tuning global parameters of DNS resolver operation.
+
+ ---
+ glue_checking: Glue records strictness checking level.
+ minimize: Send minimum amount of information in recursive queries to enhance privacy.
+ query_loopback: Permits queries to loopback addresses.
+ reorder_rrset: Controls whether resource records within an RRSet are reordered each time it is served from the cache.
+ query_case_randomization: Randomize Query Character Case.
+ priming: Initializing DNS resolver cache with Priming Queries (RFC 8109)
+ rebinding_protection: Protection against DNS Rebinding attack.
+ refuse_no_rd: Queries without the RD (recursion desired) bit set are answered with REFUSED.
+ time_jump_detection: Detection of difference between local system time and expiration time bounds in DNSSEC signatures for '. NS' records.
+ violators_workarounds: Workarounds for known DNS protocol violators.
+ serve_stale: Allows using timed-out records in case DNS resolver is unable to contact upstream servers.
+ """
+
+ glue_checking: GlueCheckingEnum = "normal"
+ minimize: bool = True
+ query_loopback: bool = False
+ reorder_rrset: bool = True
+ query_case_randomization: bool = True
+ priming: bool = True
+ rebinding_protection: bool = False
+ refuse_no_rd: bool = True
+ time_jump_detection: bool = True
+ violators_workarounds: bool = False
+ serve_stale: bool = False
diff --git a/python/knot_resolver/datamodel/policy_schema.py b/python/knot_resolver/datamodel/policy_schema.py
new file mode 100644
index 00000000..8f9d8b26
--- /dev/null
+++ b/python/knot_resolver/datamodel/policy_schema.py
@@ -0,0 +1,126 @@
+from typing import List, Optional, Union
+
+from knot_resolver.datamodel.forward_schema import ForwardServerSchema
+from knot_resolver.datamodel.network_schema import AddressRenumberingSchema
+from knot_resolver.datamodel.types import (
+ DNSRecordTypeEnum,
+ IPAddressOptionalPort,
+ PolicyActionEnum,
+ PolicyFlagEnum,
+ TimeUnit,
+)
+from knot_resolver.utils.modeling import ConfigSchema
+
+
+class FilterSchema(ConfigSchema):
+ """
+ Query filtering configuration.
+
+ ---
+ suffix: Filter based on the suffix of the query name.
+ pattern: Filter based on a pattern that matches the query name.
+ qtype: Filter based on the DNS query type.
+ """
+
+ suffix: Optional[str] = None
+ pattern: Optional[str] = None
+ qtype: Optional[DNSRecordTypeEnum] = None
+
+
+class AnswerSchema(ConfigSchema):
+ """
+ Configuration of custom resource record for DNS answer.
+
+ ---
+ rtype: Type of DNS resource record.
+ rdata: Data of DNS resource record.
+ ttl: Time-to-live value for defined answer.
+ nodata: Answer with NODATA if the requested type is not configured in the answer; otherwise the policy rule is ignored.
+ """
+
+ rtype: DNSRecordTypeEnum
+ rdata: str
+ ttl: TimeUnit = TimeUnit("1s")
+ nodata: bool = False
+
+
+def _validate_policy_action(policy_action: Union["ActionSchema", "PolicySchema"]) -> None:
+ servers = ["mirror", "forward", "stub"]
+
+ def _field(ac: str) -> str:
+ if ac in servers:
+ return "servers"
+ return "message" if ac == "deny" else ac
+
+ configurable_actions = ["deny", "reroute", "answer"] + servers
+
+ # checking for missing mandatory fields for actions
+ field = _field(policy_action.action)
+ if policy_action.action in configurable_actions and not getattr(policy_action, field):
+ raise ValueError(f"missing mandatory field '{field}' for '{policy_action.action}' action")
+
+ # checking for unnecessary fields
+ for ac in configurable_actions + ["deny"]:
+ field = _field(ac)
+ if getattr(policy_action, field) and _field(policy_action.action) != field:
+ raise ValueError(f"'{field}' field can only be defined for '{ac}' action")
+
+ # ForwardServerSchema is valid only for 'forward' action
+ if policy_action.servers:
+ for server in policy_action.servers: # pylint: disable=not-an-iterable
+ if policy_action.action != "forward" and isinstance(server, ForwardServerSchema):
+ raise ValueError(
+ f"'ForwardServerSchema' in 'servers' is valid only for 'forward' action, got '{policy_action.action}'"
+ )
+
+
+class ActionSchema(ConfigSchema):
+ """
+ Configuration of policy action.
+
+ ---
+ action: Policy action.
+ message: Deny message for 'deny' action.
+ reroute: Configuration for 'reroute' action.
+ answer: Answer definition for 'answer' action.
+ servers: Servers configuration for 'mirror', 'forward' and 'stub' action.
+ """
+
+ action: PolicyActionEnum
+ message: Optional[str] = None
+ reroute: Optional[List[AddressRenumberingSchema]] = None
+ answer: Optional[AnswerSchema] = None
+ servers: Optional[Union[List[IPAddressOptionalPort], List[ForwardServerSchema]]] = None
+
+ def _validate(self) -> None:
+ _validate_policy_action(self)
+
+
+class PolicySchema(ConfigSchema):
+ """
+ Configuration of policy rule.
+
+ ---
+ action: Policy rule action.
+ priority: Policy rule priority.
+ filter: Query filtering configuration.
+ views: Use policy rule only for clients defined by views.
+ options: Configuration flags for policy rule.
+ message: Deny message for 'deny' action.
+ reroute: Configuration for 'reroute' action.
+ answer: Answer definition for 'answer' action.
+ servers: Servers configuration for 'mirror', 'forward' and 'stub' action.
+ """
+
+ action: PolicyActionEnum
+ priority: Optional[int] = None
+ filter: Optional[FilterSchema] = None
+ views: Optional[List[str]] = None
+ options: Optional[List[PolicyFlagEnum]] = None
+ message: Optional[str] = None
+ reroute: Optional[List[AddressRenumberingSchema]] = None
+ answer: Optional[AnswerSchema] = None
+ servers: Optional[Union[List[IPAddressOptionalPort], List[ForwardServerSchema]]] = None
+
+ def _validate(self) -> None:
+ _validate_policy_action(self)
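The mandatory-field check in _validate_policy_action() can be sketched as follows (illustrative only):

    from knot_resolver.datamodel.policy_schema import ActionSchema

    # a 'deny' action carries its refusal message in the 'message' field
    deny = ActionSchema({"action": "deny", "message": "domain blocked by policy"})

    # rejected: missing mandatory field 'message' for 'deny' action
    # ActionSchema({"action": "deny"})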
diff --git a/python/knot_resolver/datamodel/rpz_schema.py b/python/knot_resolver/datamodel/rpz_schema.py
new file mode 100644
index 00000000..96d79293
--- /dev/null
+++ b/python/knot_resolver/datamodel/rpz_schema.py
@@ -0,0 +1,29 @@
+from typing import List, Optional
+
+from knot_resolver.datamodel.types import PolicyActionEnum, PolicyFlagEnum, ReadableFile
+from knot_resolver.utils.modeling import ConfigSchema
+
+
+class RPZSchema(ConfigSchema):
+ """
+ Configuration of a Response Policy Zone (RPZ).
+
+ ---
+ action: RPZ rule action, typically 'deny'.
+ file: Path to the RPZ zone file.
+ watch: Reload the file when it changes.
+ views: Use RPZ rule only for clients defined by views.
+ options: Configuration flags for RPZ rule.
+ message: Deny message for 'deny' action.
+ """
+
+ action: PolicyActionEnum
+ file: ReadableFile
+ watch: bool = True
+ views: Optional[List[str]] = None
+ options: Optional[List[PolicyFlagEnum]] = None
+ message: Optional[str] = None
+
+ def _validate(self) -> None:
+ if self.message and not self.action == "deny":
+ raise ValueError("'message' field can only be defined for 'deny' action")
diff --git a/python/knot_resolver/datamodel/slice_schema.py b/python/knot_resolver/datamodel/slice_schema.py
new file mode 100644
index 00000000..1586cab7
--- /dev/null
+++ b/python/knot_resolver/datamodel/slice_schema.py
@@ -0,0 +1,21 @@
+from typing import List, Optional
+
+from typing_extensions import Literal
+
+from knot_resolver.datamodel.policy_schema import ActionSchema
+from knot_resolver.utils.modeling import ConfigSchema
+
+
+class SliceSchema(ConfigSchema):
+ """
+ Split the entire DNS namespace into distinct slices.
+
+ ---
+ function: Slicing function that returns an index based on the query.
+ views: Use this Slice only for clients defined by views.
+ actions: Actions for slice.
+ """
+
+ function: Literal["randomize-psl"] = "randomize-psl"
+ views: Optional[List[str]] = None
+ actions: List[ActionSchema]
diff --git a/python/knot_resolver/datamodel/static_hints_schema.py b/python/knot_resolver/datamodel/static_hints_schema.py
new file mode 100644
index 00000000..ac64c311
--- /dev/null
+++ b/python/knot_resolver/datamodel/static_hints_schema.py
@@ -0,0 +1,27 @@
+from typing import Dict, List, Optional
+
+from knot_resolver.datamodel.types import DomainName, IPAddress, ReadableFile, TimeUnit
+from knot_resolver.utils.modeling import ConfigSchema
+
+
+class StaticHintsSchema(ConfigSchema):
+ """
+ Static hints for forward records (A/AAAA) and reverse records (PTR)
+
+ ---
+ ttl: TTL value used for records added from static hints.
+ nodata: Use NODATA synthesis. NODATA will be synthesised for matching hint name, but mismatching type.
+ etc_hosts: Add hints from '/etc/hosts' file.
+ root_hints: Direct addition of root hints pairs (hostname, list of addresses).
+ root_hints_file: Path to root hints in zonefile. Replaces all current root hints.
+ hints: Direct addition of hints pairs (hostname, list of addresses).
+ hints_files: Path to hints in hosts-like file.
+ """
+
+ ttl: Optional[TimeUnit] = None
+ nodata: bool = True
+ etc_hosts: bool = False
+ root_hints: Optional[Dict[DomainName, List[IPAddress]]] = None
+ root_hints_file: Optional[ReadableFile] = None
+ hints: Optional[Dict[DomainName, List[IPAddress]]] = None
+ hints_files: Optional[List[ReadableFile]] = None
diff --git a/python/knot_resolver/datamodel/stub_zone_schema.py b/python/knot_resolver/datamodel/stub_zone_schema.py
new file mode 100644
index 00000000..afd1cc79
--- /dev/null
+++ b/python/knot_resolver/datamodel/stub_zone_schema.py
@@ -0,0 +1,32 @@
+from typing import List, Optional, Union
+
+from knot_resolver.datamodel.types import DomainName, IPAddressOptionalPort, PolicyFlagEnum
+from knot_resolver.utils.modeling import ConfigSchema
+
+
+class StubServerSchema(ConfigSchema):
+ """
+ Configuration of Stub server.
+
+ ---
+ address: IP address of Stub server.
+ """
+
+ address: IPAddressOptionalPort
+
+
+class StubZoneSchema(ConfigSchema):
+ """
+ Configuration of Stub Zone.
+
+ ---
+ subtree: Domain name of the zone.
+ servers: IP addresses of Stub servers.
+ views: Use this Stub Zone only for clients defined by views.
+ options: Configuration flags for Stub Zone.
+ """
+
+ subtree: DomainName
+ servers: Union[List[IPAddressOptionalPort], List[StubServerSchema]]
+ views: Optional[List[str]] = None
+ options: Optional[List[PolicyFlagEnum]] = None
diff --git a/python/knot_resolver/datamodel/templates/__init__.py b/python/knot_resolver/datamodel/templates/__init__.py
new file mode 100644
index 00000000..fdb91dd2
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/__init__.py
@@ -0,0 +1,43 @@
+import os
+import sys
+
+from jinja2 import Environment, FileSystemLoader, Template
+
+
+def _get_templates_dir() -> str:
+ module = sys.modules["knot_resolver.datamodel"].__file__
+ if module:
+ templates_dir = os.path.join(os.path.dirname(module), "templates")
+ if os.path.isdir(templates_dir):
+ return templates_dir
+ raise NotADirectoryError(f"the templates dir '{templates_dir}' is not a directory or does not exist")
+ raise OSError("package 'knot_resolver.datamodel' cannot be located or loaded")
+
+
+_TEMPLATES_DIR = _get_templates_dir()
+
+
+def _import_kresd_worker_config_template() -> Template:
+ path = os.path.join(_TEMPLATES_DIR, "worker-config.lua.j2")
+ with open(path, "r", encoding="UTF-8") as file:
+ template = file.read()
+ return template_from_str(template)
+
+
+def _import_kresd_policy_config_template() -> Template:
+ path = os.path.join(_TEMPLATES_DIR, "policy-config.lua.j2")
+ with open(path, "r", encoding="UTF-8") as file:
+ template = file.read()
+ return template_from_str(template)
+
+
+def template_from_str(template: str) -> Template:
+ ldr = FileSystemLoader(_TEMPLATES_DIR)
+ env = Environment(trim_blocks=True, lstrip_blocks=True, loader=ldr)
+ return env.from_string(template)
+
+
+WORKER_CONFIG_TEMPLATE = _import_kresd_worker_config_template()
+
+
+POLICY_CONFIG_TEMPLATE = _import_kresd_policy_config_template()
diff --git a/python/knot_resolver/datamodel/templates/cache.lua.j2 b/python/knot_resolver/datamodel/templates/cache.lua.j2
new file mode 100644
index 00000000..f0176a59
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/cache.lua.j2
@@ -0,0 +1,32 @@
+cache.open({{ cfg.cache.size_max.bytes() }}, 'lmdb://{{ cfg.cache.storage }}')
+cache.min_ttl({{ cfg.cache.ttl_min.seconds() }})
+cache.max_ttl({{ cfg.cache.ttl_max.seconds() }})
+cache.ns_tout({{ cfg.cache.ns_timeout.millis() }})
+
+{% if cfg.cache.prefill %}
+-- cache.prefill
+modules.load('prefill')
+prefill.config({
+{% for item in cfg.cache.prefill %}
+ ['{{ item.origin.punycode() }}'] = {
+ url = '{{ item.url }}',
+ interval = {{ item.refresh_interval.seconds() }},
+ {{ "ca_file = '"+item.ca_file+"'," if item.ca_file }}
+ },
+{% endfor %}
+})
+{% endif %}
+
+{% if cfg.cache.prefetch.expiring %}
+-- cache.prefetch.expiring
+modules.load('prefetch')
+{% endif %}
+
+{% if cfg.cache.prefetch.prediction %}
+-- cache.prefetch.prediction
+modules.load('predict')
+predict.config({
+ window = {{ cfg.cache.prefetch.prediction.window.minutes() }},
+ period = {{ cfg.cache.prefetch.prediction.period }},
+})
+{% endif %}
diff --git a/python/knot_resolver/datamodel/templates/dns64.lua.j2 b/python/knot_resolver/datamodel/templates/dns64.lua.j2
new file mode 100644
index 00000000..c5239f00
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/dns64.lua.j2
@@ -0,0 +1,17 @@
+{% from 'macros/common_macros.lua.j2' import string_table %}
+
+{% if cfg.dns64 %}
+-- load dns64 module
+modules.load('dns64')
+
+-- dns64.prefix
+dns64.config({
+ prefix = '{{ cfg.dns64.prefix.to_std().network_address|string }}',
+{% if cfg.dns64.rev_ttl %}
+ rev_ttl = {{ cfg.dns64.rev_ttl.seconds() }},
+{% endif %}
+{% if cfg.dns64.exclude_subnets %}
+ exclude_subnets = {{ string_table(cfg.dns64.exclude_subnets) }},
+{% endif %}
+})
+{% endif %}
\ No newline at end of file
diff --git a/python/knot_resolver/datamodel/templates/dnssec.lua.j2 b/python/knot_resolver/datamodel/templates/dnssec.lua.j2
new file mode 100644
index 00000000..05d1fa68
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/dnssec.lua.j2
@@ -0,0 +1,60 @@
+{% from 'macros/common_macros.lua.j2' import boolean %}
+
+{% if not cfg.dnssec %}
+-- disable dnssec
+trust_anchors.remove('.')
+{% endif %}
+
+-- options.trust-anchor-sentinel
+{% if cfg.dnssec.trust_anchor_sentinel %}
+modules.load('ta_sentinel')
+{% else %}
+modules.unload('ta_sentinel')
+{% endif %}
+
+-- options.trust-anchor-signal-query
+{% if cfg.dnssec.trust_anchor_signal_query %}
+modules.load('ta_signal_query')
+{% else %}
+modules.unload('ta_signal_query')
+{% endif %}
+
+-- options.time-skew-detection
+{% if cfg.dnssec.time_skew_detection %}
+modules.load('detect_time_skew')
+{% else %}
+modules.unload('detect_time_skew')
+{% endif %}
+
+{% if cfg.dnssec.keep_removed %}
+-- dnssec.keep-removed
+trust_anchors.keep_removed = {{ cfg.dnssec.keep_removed }}
+{% endif %}
+
+{% if cfg.dnssec.refresh_time %}
+-- dnssec.refresh-time
+trust_anchors.refresh_time = {{ cfg.dnssec.refresh_time.seconds()|string }}
+{% endif %}
+
+{% if cfg.dnssec.trust_anchors %}
+-- dnssec.trust-anchors
+{% for ta in cfg.dnssec.trust_anchors %}
+trust_anchors.add('{{ ta }}')
+{% endfor %}
+{% endif %}
+
+{% if cfg.dnssec.negative_trust_anchors %}
+-- dnssec.negative-trust-anchors
+trust_anchors.set_insecure({
+{% for nta in cfg.dnssec.negative_trust_anchors %}
+ '{{ nta }}',
+{% endfor %}
+})
+{% endif %}
+
+{% if cfg.dnssec.trust_anchors_files %}
+-- dnssec.trust-anchors-files
+{% for taf in cfg.dnssec.trust_anchors_files %}
+trust_anchors.add_file('{{ taf.file }}', {{ boolean(taf.read_only) }})
+{% endfor %}
+{% endif %}
\ No newline at end of file
diff --git a/python/knot_resolver/datamodel/templates/forward.lua.j2 b/python/knot_resolver/datamodel/templates/forward.lua.j2
new file mode 100644
index 00000000..24311da1
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/forward.lua.j2
@@ -0,0 +1,9 @@
+{% from 'macros/forward_macros.lua.j2' import policy_rule_forward_add %}
+
+{% if cfg.forward %}
+{% for fwd in cfg.forward %}
+{% for subtree in fwd.subtree %}
+{{ policy_rule_forward_add(subtree,fwd.options,fwd.servers) }}
+{% endfor %}
+{% endfor %}
+{% endif %}
diff --git a/python/knot_resolver/datamodel/templates/local_data.lua.j2 b/python/knot_resolver/datamodel/templates/local_data.lua.j2
new file mode 100644
index 00000000..8882471f
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/local_data.lua.j2
@@ -0,0 +1,41 @@
+{% from 'macros/local_data_macros.lua.j2' import local_data_rules, local_data_records, local_data_root_fallback_addresses, local_data_root_fallback_addresses_files, local_data_addresses, local_data_addresses_files %}
+{% from 'macros/common_macros.lua.j2' import boolean %}
+
+modules = { 'hints > iterate' }
+
+{# root-fallback-addresses #}
+{% if cfg.local_data.root_fallback_addresses -%}
+{{ local_data_root_fallback_addresses(cfg.local_data.root_fallback_addresses) }}
+{%- endif %}
+
+{# root-fallback-addresses-files #}
+{% if cfg.local_data.root_fallback_addresses_files -%}
+{{ local_data_root_fallback_addresses_files(cfg.local_data.root_fallback_addresses_files) }}
+{%- endif %}
+
+{# addresses #}
+{% if cfg.local_data.addresses -%}
+{{ local_data_addresses(cfg.local_data.addresses, cfg.local_data.nodata, cfg.local_data.ttl) }}
+{%- endif %}
+
+{# addresses-files #}
+{% if cfg.local_data.addresses_files -%}
+{{ local_data_addresses_files(cfg.local_data.addresses_files, cfg.local_data.nodata, cfg.local_data.ttl) }}
+{%- endif %}
+
+{# records #}
+{% if cfg.local_data.records -%}
+{{ local_data_records(cfg.local_data.records, false, cfg.local_data.nodata, cfg.local_data.ttl) }}
+{%- endif %}
+
+{# rules #}
+{% if cfg.local_data.rules -%}
+{{ local_data_rules(cfg.local_data.rules, cfg.local_data.nodata, cfg.local_data.ttl) }}
+{%- endif %}
+
+{# rpz #}
+{% if cfg.local_data.rpz -%}
+{% for rpz in cfg.local_data.rpz %}
+{{ local_data_records(rpz.file, true, cfg.local_data.nodata, cfg.local_data.ttl, rpz.tags) }}
+{% endfor %}
+{%- endif %}
diff --git a/python/knot_resolver/datamodel/templates/logging.lua.j2 b/python/knot_resolver/datamodel/templates/logging.lua.j2
new file mode 100644
index 00000000..2d5937a8
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/logging.lua.j2
@@ -0,0 +1,43 @@
+{% from 'macros/common_macros.lua.j2' import boolean %}
+
+-- logging.level
+log_level('{{ cfg.logging.level }}')
+
+{% if cfg.logging.target -%}
+-- logging.target
+log_target('{{ cfg.logging.target }}')
+{%- endif %}
+
+{% if cfg.logging.groups %}
+-- logging.groups
+log_groups({
+{% for g in cfg.logging.groups %}
+{% if g != "manager" and g != "supervisord" and g != "cache-gc" %}
+ '{{ g }}',
+{% endif %}
+{% endfor %}
+})
+{% endif %}
+
+{% if cfg.logging.dnssec_bogus %}
+modules.load('bogus_log')
+{% endif %}
+
+{% if cfg.logging.dnstap -%}
+-- logging.dnstap
+modules.load('dnstap')
+dnstap.config({
+ socket_path = '{{ cfg.logging.dnstap.unix_socket }}',
+ client = {
+ log_queries = {{ boolean(cfg.logging.dnstap.log_queries) }},
+ log_responses = {{ boolean(cfg.logging.dnstap.log_responses) }},
+ log_tcp_rtt = {{ boolean(cfg.logging.dnstap.log_tcp_rtt) }}
+ }
+})
+{%- endif %}
+
+-- logging.debugging.assertion-abort
+debugging.assertion_abort = {{ boolean(cfg.logging.debugging.assertion_abort) }}
+
+-- logging.debugging.assertion-fork
+debugging.assertion_fork = {{ cfg.logging.debugging.assertion_fork.millis() }}
diff --git a/python/knot_resolver/datamodel/templates/macros/cache_macros.lua.j2 b/python/knot_resolver/datamodel/templates/macros/cache_macros.lua.j2
new file mode 100644
index 00000000..51df48da
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/macros/cache_macros.lua.j2
@@ -0,0 +1,11 @@
+{% from 'macros/common_macros.lua.j2' import boolean, quotes, qtype_table %}
+
+
+{% macro cache_clear(params) -%}
+cache.clear(
+{{- quotes(params.name) if params.name else 'nil' -}},
+{{- boolean(params.exact_name) -}},
+{{- qtype_table(params.rr_type) if params.rr_type else 'nil' -}},
+{{- params.chunk_size if not params.exact_name else 'nil' -}}
+)
+{%- endmacro %}
diff --git a/python/knot_resolver/datamodel/templates/macros/common_macros.lua.j2 b/python/knot_resolver/datamodel/templates/macros/common_macros.lua.j2
new file mode 100644
index 00000000..4c2ba11a
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/macros/common_macros.lua.j2
@@ -0,0 +1,101 @@
+{% macro quotes(string) -%}
+'{{ string }}'
+{%- endmacro %}
+
+{% macro boolean(val, negation=false) -%}
+{%- if negation -%}
+{{ 'false' if val else 'true' }}
+{%- else -%}
+{{ 'true' if val else 'false' }}
+{%- endif -%}
+{%- endmacro %}
+
+{# Return string or table of strings #}
+{% macro string_table(table) -%}
+{%- if table is string -%}
+'{{ table|string }}'
+{%- else -%}
+{
+{%- for item in table -%}
+'{{ item|string }}',
+{%- endfor -%}
+}
+{%- endif -%}
+{%- endmacro %}
+
+{# Return str2ip or table of str2ip #}
+{% macro str2ip_table(table) -%}
+{%- if table is string -%}
+kres.str2ip('{{ table|string }}')
+{%- else -%}
+{
+{%- for item in table -%}
+kres.str2ip('{{ item|string }}'),
+{%- endfor -%}
+}
+{%- endif -%}
+{%- endmacro %}
+
+{# Return qtype or table of qtype #}
+{% macro qtype_table(table) -%}
+{%- if table is string -%}
+kres.type.{{ table|string }}
+{%- else -%}
+{
+{%- for item in table -%}
+kres.type.{{ item|string }},
+{%- endfor -%}
+}
+{%- endif -%}
+{%- endmacro %}
+
+{# Return server address or table of server addresses #}
+{% macro servers_table(servers) -%}
+{%- if servers is string -%}
+'{{ servers|string }}'
+{%- else -%}
+{
+{%- for item in servers -%}
+{%- if item.address -%}
+'{{ item.address|string }}',
+{%- else -%}
+'{{ item|string }}',
+{%- endif -%}
+{%- endfor -%}
+}
+{%- endif -%}
+{%- endmacro %}
+
+{# Return server address or table of server addresses #}
+{% macro tls_servers_table(servers) -%}
+{
+{%- for item in servers -%}
+{%- if item.address -%}
+{'{{ item.address|string }}',{{ tls_server_auth(item) }}},
+{%- else -%}
+'{{ item|string }}',
+{%- endif -%}
+{%- endfor -%}
+}
+{%- endmacro %}
+
+{% macro tls_server_auth(server) -%}
+{%- if server.hostname -%}
+hostname='{{ server.hostname|string }}',
+{%- endif -%}
+{%- if server.ca_file -%}
+ca_file='{{ server.ca_file|string }}',
+{%- endif -%}
+{%- if server.pin_sha256 -%}
+pin_sha256=
+{%- if server.pin_sha256 is string -%}
+'{{ server.pin_sha256|string }}',
+{%- else -%}
+{
+{%- for pin in server.pin_sha256 -%}
+'{{ pin|string }}',
+{%- endfor -%}
+}
+{%- endif -%}
+{%- endif -%}
+{%- endmacro %}
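
The macros above are building blocks imported by the higher-level templates and rendered with the validated config as context. As a hedged sketch of their behaviour (the loader path and standalone rendering below are illustrative assumptions, not how knot_resolver.datamodel.templates actually loads them), string_table turns a single value into a quoted Lua string and a list into a Lua table:

    # Minimal sketch, assuming the repository root is the working directory.
    from jinja2 import Environment, FileSystemLoader

    env = Environment(loader=FileSystemLoader("python/knot_resolver/datamodel/templates"))
    tmpl = env.from_string(
        "{% from 'macros/common_macros.lua.j2' import string_table %}"
        "{{ string_table(value) }}"
    )
    print(tmpl.render(value="192.0.2.1"))           # '192.0.2.1'
    print(tmpl.render(value=["192.0.2.1", "::1"]))  # {'192.0.2.1','::1',}
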
diff --git a/python/knot_resolver/datamodel/templates/macros/forward_macros.lua.j2 b/python/knot_resolver/datamodel/templates/macros/forward_macros.lua.j2
new file mode 100644
index 00000000..b7723fb0
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/macros/forward_macros.lua.j2
@@ -0,0 +1,42 @@
+{% from 'macros/common_macros.lua.j2' import boolean, string_table %}
+
+{% macro forward_options(options) -%}
+{dnssec={{ boolean(options.dnssec) }},auth={{ boolean(options.authoritative) }}}
+{%- endmacro %}
+
+{% macro forward_server(server) -%}
+{%- if server.address -%}
+{%- for addr in server.address -%}
+{'{{ addr }}',
+{%- if server.transport == 'tls' -%}
+tls=true,
+{%- else -%}
+tls=false,
+{%- endif -%}
+{%- if server.hostname -%}
+hostname='{{ server.hostname }}',
+{%- endif -%}
+{%- if server.pin_sha256 -%}
+pin_sha256={{ string_table(server.pin_sha256) }},
+{%- endif -%}
+{%- if server.ca_file -%}
+ca_file='{{ server.ca_file }}',
+{%- endif -%}
+},
+{%- endfor -%}
+{% else %}
+{'{{ server }}'},
+{%- endif -%}
+{%- endmacro %}
+
+{% macro forward_servers(servers) -%}
+{
+{%- for server in servers -%}
+{{ forward_server(server) }}
+{%- endfor -%}
+}
+{%- endmacro %}
+
+{% macro policy_rule_forward_add(subtree,options,servers) -%}
+policy.rule_forward_add('{{ subtree }}',{{ forward_options(options) }},{{ forward_servers(servers) }})
+{%- endmacro %}
diff --git a/python/knot_resolver/datamodel/templates/macros/local_data_macros.lua.j2 b/python/knot_resolver/datamodel/templates/macros/local_data_macros.lua.j2
new file mode 100644
index 00000000..0898571c
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/macros/local_data_macros.lua.j2
@@ -0,0 +1,101 @@
+{% from 'macros/common_macros.lua.j2' import string_table, boolean %}
+{% from 'macros/policy_macros.lua.j2' import policy_get_tagset, policy_todname %}
+
+
+{% macro local_data_root_fallback_addresses(pairs) -%}
+hints.root({
+{% for name, addresses in pairs.items() %}
+ ['{{ name }}']={{ string_table(addresses) }},
+{% endfor %}
+})
+{%- endmacro %}
+
+
+{% macro local_data_root_fallback_addresses_files(files) -%}
+{% for file in files %}
+hints.root_file('{{ file }}')
+{% endfor %}
+{%- endmacro %}
+
+{%- macro local_data_ttl(ttl) -%}
+{%- if ttl -%}
+{{ ttl.seconds() }}
+{%- else -%}
+{{ 'C.KR_RULE_TTL_DEFAULT' }}
+{%- endif -%}
+{%- endmacro -%}
+
+
+{% macro kr_rule_local_address(name, address, nodata, ttl, tags=none) -%}
+assert(C.kr_rule_local_address('{{ name }}', '{{ address }}',
+ {{ boolean(nodata) }}, {{ local_data_ttl(ttl)}}, {{ policy_get_tagset(tags) }}) == 0)
+{%- endmacro -%}
+
+
+{% macro local_data_addresses(pairs, nodata, ttl) -%}
+{% for name, addresses in pairs.items() %}
+{% for address in addresses %}
+{{ kr_rule_local_address(name, address, nodata, ttl) }}
+{% endfor %}
+{% endfor %}
+{%- endmacro %}
+
+
+{% macro kr_rule_local_hosts(file, nodata, ttl, tags=none) -%}
+assert(C.kr_rule_local_hosts('{{ file }}', {{ boolean(nodata) }},
+ {{ local_data_ttl(ttl)}}, {{ policy_get_tagset(tags) }}) == 0)
+{%- endmacro %}
+
+
+{% macro local_data_addresses_files(files, nodata, ttl, tags) -%}
+{% for file in files %}
+{{ kr_rule_local_hosts(file, nodata, ttl, tags) }}
+{% endfor %}
+{%- endmacro %}
+
+
+{% macro local_data_records(input_str, is_rpz, nodata, ttl, tags=none, id='rrs') -%}
+{{ id }} = ffi.new('struct kr_rule_zonefile_config')
+{{ id }}.ttl = {{ local_data_ttl(ttl) }}
+{{ id }}.tags = {{ policy_get_tagset(tags) }}
+{{ id }}.nodata = {{ boolean(nodata) }}
+{{ id }}.is_rpz = {{ boolean(is_rpz) }}
+{% if is_rpz -%}
+{{ id }}.filename = '{{ input_str }}'
+{% else %}
+{{ id }}.input_str = [[
+{{ input_str.multiline() }}
+]]
+{% endif %}
+assert(C.kr_rule_zonefile({{ id }})==0)
+{%- endmacro %}
+
+
+{% macro kr_rule_local_subtree(name, type, ttl, tags=none) -%}
+assert(C.kr_rule_local_subtree(todname('{{ name }}'),
+ C.KR_RULE_SUB_{{ type.upper() }}, {{ local_data_ttl(ttl) }}, {{ policy_get_tagset(tags) }}) == 0)
+{%- endmacro %}
+
+
+{% macro local_data_rules(items, nodata, ttl) -%}
+{% for item in items %}
+{% if item.name %}
+{% for name in item.name %}
+{% if item.address %}
+{% for address in item.address %}
+{{ kr_rule_local_address(name, address, nodata if item.nodata is none else item.nodata, item.ttl or ttl, item.tags) }}
+{% endfor %}
+{% endif %}
+{% if item.subtree %}
+{{ kr_rule_local_subtree(name, item.subtree, item.ttl or ttl, item.tags) }}
+{% endif %}
+{% endfor %}
+{% elif item.file %}
+{% for file in item.file %}
+{{ kr_rule_local_hosts(file, nodata if item.nodata is none else item.nodata, item.ttl or ttl, item.tags) }}
+{% endfor %}
+{% elif item.records %}
+{{ local_data_records(item.records, false, nodata if item.nodata is none else item.nodata, item.ttl or ttl, item.tags) }}
+{% endif %}
+{% endfor %}
+{%- endmacro %}
diff --git a/python/knot_resolver/datamodel/templates/macros/network_macros.lua.j2 b/python/knot_resolver/datamodel/templates/macros/network_macros.lua.j2
new file mode 100644
index 00000000..79800f7d
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/macros/network_macros.lua.j2
@@ -0,0 +1,55 @@
+{% macro http_config(http_cfg, kind, tls=true) -%}
+http.config({tls={{ 'true' if tls else 'false'}},
+{%- if http_cfg.cert_file -%}
+ cert='{{ http_cfg.cert_file }}',
+{%- endif -%}
+{%- if http_cfg.key_file -%}
+ key='{{ http_cfg.key_file }}',
+{%- endif -%}
+},'{{ kind }}')
+{%- endmacro %}
+
+
+{% macro listen_kind(kind) -%}
+{%- if kind == "dot" -%}
+'tls'
+{%- elif kind == "doh-legacy" -%}
+'doh_legacy'
+{%- else -%}
+'{{ kind }}'
+{%- endif -%}
+{%- endmacro %}
+
+
+{% macro net_listen_unix_socket(path, kind, freebind) -%}
+net.listen('{{ path }}',nil,{kind={{ listen_kind(kind) }},freebind={{ 'true' if freebind else 'false'}}})
+{%- endmacro %}
+
+
+{% macro net_listen_interface(interface, kind, freebind, port) -%}
+net.listen(
+{%- if interface.addr -%}
+'{{ interface.addr }}',
+{%- elif interface.if_name -%}
+net['{{ interface.if_name }}'],
+{%- endif -%}
+{%- if interface.port -%}
+{{ interface.port }},
+{%- else -%}
+{{ port }},
+{%- endif -%}
+{kind={{ listen_kind(kind) }},freebind={{ 'true' if freebind else 'false'}}})
+{%- endmacro %}
+
+
+{% macro network_listen(listen) -%}
+{%- if listen.unix_socket -%}
+{% for path in listen.unix_socket %}
+{{ net_listen_unix_socket(path, listen.kind, listen.freebind) }}
+{% endfor %}
+{%- elif listen.interface -%}
+{% for interface in listen.interface %}
+{{ net_listen_interface(interface, listen.kind, listen.freebind, listen.port) }}
+{% endfor %}
+{%- endif -%}
+{%- endmacro %} \ No newline at end of file
diff --git a/python/knot_resolver/datamodel/templates/macros/policy_macros.lua.j2 b/python/knot_resolver/datamodel/templates/macros/policy_macros.lua.j2
new file mode 100644
index 00000000..347532e6
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/macros/policy_macros.lua.j2
@@ -0,0 +1,279 @@
+{% from 'macros/common_macros.lua.j2' import string_table, str2ip_table, qtype_table, servers_table, tls_servers_table %}
+
+
+{# Add policy #}
+
+{% macro policy_add(rule, postrule=false) -%}
+{%- if postrule -%}
+policy.add({{ rule }},true)
+{%- else -%}
+policy.add({{ rule }})
+{%- endif -%}
+{%- endmacro %}
+
+
+{# Slice #}
+
+{% macro policy_slice_randomize_psl(seed='') -%}
+{%- if seed == '' -%}
+policy.slice_randomize_psl()
+{%- else -%}
+policy.slice_randomize_psl(seed={{ seed }})
+{%- endif -%}
+{%- endmacro %}
+
+{% macro policy_slice(func, actions) -%}
+policy.slice(
+{%- if func == 'randomize-psl' -%}
+policy.slice_randomize_psl()
+{%- else -%}
+policy.slice_randomize_psl()
+{%- endif -%}
+,{{ actions }})
+{%- endmacro %}
+
+
+{# Flags #}
+
+{% macro policy_flags(flags) -%}
+policy.FLAGS({
+{{- flags -}}
+})
+{%- endmacro %}
+
+
+{# Tags assign #}
+
+{% macro policy_tags_assign(tags) -%}
+policy.TAGS_ASSIGN({{ string_table(tags) }})
+{%- endmacro %}
+
+{% macro policy_get_tagset(tags) -%}
+{%- if tags -%}
+policy.get_tagset({{ string_table(tags) }})
+{%- else -%}
+0
+{%- endif -%}
+{%- endmacro %}
+
+
+{# Filters #}
+
+{% macro policy_all(action) -%}
+policy.all({{ action }})
+{%- endmacro %}
+
+{% macro policy_suffix(action, suffix_table) -%}
+policy.suffix({{ action }},{{ suffix_table }})
+{%- endmacro %}
+
+{% macro policy_suffix_common(action, suffix_table, common_suffix=none) -%}
+policy.suffix_common({{ action }},{{ suffix_table }}
+{%- if common_suffix -%}
+,{{ common_suffix }}
+{%- endif -%}
+)
+{%- endmacro %}
+
+{% macro policy_pattern(action, pattern) -%}
+policy.pattern({{ action }},'{{ pattern }}')
+{%- endmacro %}
+
+{% macro policy_rpz(action, path, watch=true) -%}
+policy.rpz({{ action|string }},'{{ path|string }}',{{ 'true' if watch else 'false' }})
+{%- endmacro %}
+
+
+{# Custom filters #}
+
+{% macro declare_policy_qtype_custom_filter() -%}
+function policy_qtype(action, qtype)
+
+ local function has_value (tab, val)
+ for index, value in ipairs(tab) do
+ if value == val then
+ return true
+ end
+ end
+
+ return false
+ end
+
+ return function (state, query)
+ if query.stype == qtype then
+ return action
+ elseif has_value(qtype, query.stype) then
+ return action
+ else
+ return nil
+ end
+ end
+end
+{%- endmacro %}
+
+{% macro policy_qtype_custom_filter(action, qtype) -%}
+policy_qtype({{ action }}, {{ qtype }})
+{%- endmacro %}
+
+
+{# Auto Filter #}
+
+{% macro policy_auto_filter(action, filter=none) -%}
+{%- if filter.suffix -%}
+{{ policy_suffix(action, policy_todname(filter.suffix)) }}
+{%- elif filter.pattern -%}
+{{ policy_pattern(action, filter.pattern) }}
+{%- elif filter.qtype -%}
+{{ policy_qtype_custom_filter(action, qtype_table(filter.qtype)) }}
+{%- else -%}
+{{ policy_all(action) }}
+{%- endif %}
+{%- endmacro %}
+
+
+{# Non-chain actions #}
+
+{% macro policy_pass() -%}
+policy.PASS
+{%- endmacro %}
+
+{% macro policy_deny() -%}
+policy.DENY
+{%- endmacro %}
+
+{% macro policy_deny_msg(message) -%}
+policy.DENY_MSG('{{ message|string }}')
+{%- endmacro %}
+
+{% macro policy_drop() -%}
+policy.DROP
+{%- endmacro %}
+
+{% macro policy_refuse() -%}
+policy.REFUSE
+{%- endmacro %}
+
+{% macro policy_tc() -%}
+policy.TC
+{%- endmacro %}
+
+{% macro policy_reroute(reroute) -%}
+policy.REROUTE(
+{%- for item in reroute -%}
+{['{{ item.source }}']='{{ item.destination }}'}{{ ',' if not loop.last }}
+{%- endfor -%}
+)
+{%- endmacro %}
+
+{% macro policy_answer(answer) -%}
+policy.ANSWER({[kres.type.{{ answer.rtype }}]={rdata=
+{%- if answer.rtype in ['A','AAAA'] -%}
+{{ str2ip_table(answer.rdata) }},
+{%- elif answer.rtype == '' -%}
+{# TODO: Do the same for other record types that require a special rdata type in Lua.
+By default, the raw string from config is used. #}
+{%- else -%}
+{{ string_table(answer.rdata) }},
+{%- endif -%}
+ttl={{ answer.ttl.seconds()|int }}}},{{ 'true' if answer.nodata else 'false' }})
+{%- endmacro %}
+
+{# policy.ANSWER( { [kres.type.A] = { rdata=kres.str2ip('192.0.2.7'), ttl=300 }}) #}
+
+{# Chain actions #}
+
+{% macro policy_mirror(mirror) -%}
+policy.MIRROR(
+{% if mirror is string %}
+'{{ mirror }}'
+{% else %}
+{
+{%- for addr in mirror -%}
+'{{ addr }}',
+{%- endfor -%}
+}
+{%- endif -%}
+)
+{%- endmacro %}
+
+{% macro policy_debug_always() -%}
+policy.DEBUG_ALWAYS
+{%- endmacro %}
+
+{% macro policy_debug_cache_miss() -%}
+policy.DEBUG_CACHE_MISS
+{%- endmacro %}
+
+{% macro policy_qtrace() -%}
+policy.QTRACE
+{%- endmacro %}
+
+{% macro policy_reqtrace() -%}
+policy.REQTRACE
+{%- endmacro %}
+
+{% macro policy_stub(servers) -%}
+policy.STUB({{ servers_table(servers) }})
+{%- endmacro %}
+
+{% macro policy_forward(servers) -%}
+policy.FORWARD({{ servers_table(servers) }})
+{%- endmacro %}
+
+{% macro policy_tls_forward(servers) -%}
+policy.TLS_FORWARD({{ tls_servers_table(servers) }})
+{%- endmacro %}
+
+
+{# Auto action #}
+
+{% macro policy_auto_action(rule) -%}
+{%- if rule.action == 'pass' -%}
+{{ policy_pass() }}
+{%- elif rule.action == 'deny' -%}
+{%- if rule.message -%}
+{{ policy_deny_msg(rule.message) }}
+{%- else -%}
+{{ policy_deny() }}
+{%- endif -%}
+{%- elif rule.action == 'drop' -%}
+{{ policy_drop() }}
+{%- elif rule.action == 'refuse' -%}
+{{ policy_refuse() }}
+{%- elif rule.action == 'tc' -%}
+{{ policy_tc() }}
+{%- elif rule.action == 'reroute' -%}
+{{ policy_reroute(rule.reroute) }}
+{%- elif rule.action == 'answer' -%}
+{{ policy_answer(rule.answer) }}
+{%- elif rule.action == 'mirror' -%}
+{{ policy_mirror(rule.mirror) }}
+{%- elif rule.action == 'debug-always' -%}
+{{ policy_debug_always() }}
+{%- elif rule.action == 'debug-cache-miss' -%}
+{{ policy_debug_cache_miss() }}
+{%- elif rule.action == 'qtrace' -%}
+{{ policy_qtrace() }}
+{%- elif rule.action == 'reqtrace' -%}
+{{ policy_reqtrace() }}
+{%- endif -%}
+{%- endmacro %}
+
+
+{# Other #}
+
+{% macro policy_todname(name) -%}
+todname('{{ name.punycode()|string }}')
+{%- endmacro %}
+
+{% macro policy_todnames(names) -%}
+policy.todnames({
+{%- if names is string -%}
+'{{ names.punycode()|string }}'
+{%- else -%}
+{%- for name in names -%}
+'{{ name.punycode()|string }}',
+{%- endfor -%}
+{%- endif -%}
+})
+{%- endmacro %} \ No newline at end of file
diff --git a/python/knot_resolver/datamodel/templates/macros/view_macros.lua.j2 b/python/knot_resolver/datamodel/templates/macros/view_macros.lua.j2
new file mode 100644
index 00000000..2f1a7964
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/macros/view_macros.lua.j2
@@ -0,0 +1,25 @@
+{%- macro get_proto_set(protocols) -%}
+0
+{%- for p in protocols or [] -%}
+ + 2^C.KR_PROTO_{{ p.upper() }}
+{%- endfor -%}
+{%- endmacro -%}
+
+{% macro view_flags(options) -%}
+{% if not options.minimize -%}
+"NO_MINIMIZE",
+{%- endif %}
+{% if not options.dns64 -%}
+"DNS64_DISABLE",
+{%- endif %}
+{%- endmacro %}
+
+{% macro view_answer(answer) -%}
+{%- if answer == 'allow' -%}
+policy.TAGS_ASSIGN({})
+{%- elif answer == 'refused' -%}
+'policy.REFUSE'
+{%- elif answer == 'noanswer' -%}
+'policy.NO_ANSWER'
+{%- endif -%}
+{%- endmacro %}
diff --git a/python/knot_resolver/datamodel/templates/monitoring.lua.j2 b/python/knot_resolver/datamodel/templates/monitoring.lua.j2
new file mode 100644
index 00000000..624b59ab
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/monitoring.lua.j2
@@ -0,0 +1,33 @@
+--- control socket location
+local ffi = require('ffi')
+local id = os.getenv('SYSTEMD_INSTANCE')
+if not id then
+ log_error(ffi.C.LOG_GRP_SYSTEM, 'environment variable $SYSTEMD_INSTANCE not set, which should not have been possible due to running under manager')
+else
+ -- Bind to control socket in CWD (= rundir in config)
+ -- FIXME replace with relative path after fixing https://gitlab.nic.cz/knot/knot-resolver/-/issues/720
+ local path = '{{ cwd }}/control/'..id
+ log_warn(ffi.C.LOG_GRP_SYSTEM, 'path = ' .. path)
+ local ok, err = pcall(net.listen, path, nil, { kind = 'control' })
+ if not ok then
+ log_warn(ffi.C.LOG_GRP_NETWORK, 'bind to '..path..' failed '..err)
+ end
+end
+
+{% if cfg.monitoring.enabled == "always" %}
+modules.load('stats')
+{% endif %}
+
+--- function used for statistics collection
+function collect_lazy_statistics()
+ if stats == nil then
+ modules.load('stats')
+ end
+
+ return stats.list()
+end
+
+--- function used for statistics collection
+function collect_statistics()
+ return stats.list()
+end
diff --git a/python/knot_resolver/datamodel/templates/network.lua.j2 b/python/knot_resolver/datamodel/templates/network.lua.j2
new file mode 100644
index 00000000..665ee454
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/network.lua.j2
@@ -0,0 +1,102 @@
+{% from 'macros/common_macros.lua.j2' import boolean %}
+{% from 'macros/network_macros.lua.j2' import network_listen, http_config %}
+
+-- network.do-ipv4/6
+net.ipv4 = {{ boolean(cfg.network.do_ipv4) }}
+net.ipv6 = {{ boolean(cfg.network.do_ipv6) }}
+
+{% if cfg.network.out_interface_v4 %}
+-- network.out-interface-v4
+net.outgoing_v4('{{ cfg.network.out_interface_v4 }}')
+{% endif %}
+
+{% if cfg.network.out_interface_v6 %}
+-- network.out-interface-v6
+net.outgoing_v6('{{ cfg.network.out_interface_v6 }}')
+{% endif %}
+
+-- network.tcp-pipeline
+net.tcp_pipeline({{ cfg.network.tcp_pipeline }})
+
+-- network.edns-keep-alive
+{% if cfg.network.edns_tcp_keepalive %}
+modules.load('edns_keepalive')
+{% else %}
+modules.unload('edns_keepalive')
+{% endif %}
+
+-- network.edns-buffer-size
+net.bufsize(
+ {{ cfg.network.edns_buffer_size.upstream.bytes() }},
+ {{ cfg.network.edns_buffer_size.downstream.bytes() }}
+)
+
+{% if cfg.network.tls.cert_file and cfg.network.tls.key_file %}
+-- network.tls
+net.tls('{{ cfg.network.tls.cert_file }}', '{{ cfg.network.tls.key_file }}')
+{% endif %}
+
+{% if cfg.network.tls.sticket_secret %}
+-- network.tls.sticket-secret
+net.tls_sticket_secret('{{ cfg.network.tls.sticket_secret }}')
+{% endif %}
+
+{% if cfg.network.tls.sticket_secret_file %}
+-- network.tls.sticket-secret-file
+net.tls_sticket_secret_file('{{ cfg.network.tls.sticket_secret_file }}')
+{% endif %}
+
+{% if cfg.network.tls.auto_discovery %}
+-- network.tls.auto-discovery
+modules.load('experimental_dot_auth')
+{% else %}
+-- modules.unload('experimental_dot_auth')
+{% endif %}
+
+-- network.tls.padding
+net.tls_padding(
+{%- if cfg.network.tls.padding == true -%}
+true
+{%- elif cfg.network.tls.padding == false -%}
+false
+{%- else -%}
+{{ cfg.network.tls.padding }}
+{%- endif -%}
+)
+
+{% if cfg.network.address_renumbering %}
+-- network.address-renumbering
+modules.load('renumber')
+renumber.config = {
+{% for item in cfg.network.address_renumbering %}
+ {'{{ item.source }}', '{{ item.destination }}'},
+{% endfor %}
+}
+{% endif %}
+
+{%- set vars = {'doh_legacy': False} -%}
+{% for listen in cfg.network.listen if listen.kind == "doh-legacy" -%}
+{%- if vars.update({'doh_legacy': True}) -%}{%- endif -%}
+{%- endfor %}
+
+{% if vars.doh_legacy %}
+-- doh_legacy http config
+modules.load('http')
+{{ http_config(cfg.network.tls,"doh_legacy") }}
+{% endif %}
+
+{% if cfg.network.proxy_protocol %}
+-- network.proxy-protocol
+net.proxy_allowed({
+{% for item in cfg.network.proxy_protocol.allow %}
+'{{ item }}',
+{% endfor %}
+})
+{% else %}
+net.proxy_allowed({})
+{% endif %}
+
+-- network.listen
+{% for listen in cfg.network.listen %}
+{{ network_listen(listen) }}
+{% endfor %} \ No newline at end of file
diff --git a/python/knot_resolver/datamodel/templates/options.lua.j2 b/python/knot_resolver/datamodel/templates/options.lua.j2
new file mode 100644
index 00000000..8210fb6d
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/options.lua.j2
@@ -0,0 +1,52 @@
+{% from 'macros/common_macros.lua.j2' import boolean %}
+
+-- options.glue-checking
+mode('{{ cfg.options.glue_checking }}')
+
+{% if cfg.options.rebinding_protection %}
+-- options.rebinding-protection
+modules.load('rebinding < iterate')
+{% endif %}
+
+{% if cfg.options.violators_workarounds %}
+-- options.violators-workarounds
+modules.load('workarounds < iterate')
+{% endif %}
+
+{% if cfg.options.serve_stale %}
+-- options.serve-stale
+modules.load('serve_stale < cache')
+{% endif %}
+
+-- options.query-priming
+{% if cfg.options.priming %}
+modules.load('priming')
+{% else %}
+modules.unload('priming')
+{% endif %}
+
+-- options.time-jump-detection
+{% if cfg.options.time_jump_detection %}
+modules.load('detect_time_jump')
+{% else %}
+modules.unload('detect_time_jump')
+{% endif %}
+
+-- options.refuse-no-rd
+{% if cfg.options.refuse_no_rd %}
+modules.load('refuse_nord')
+{% else %}
+modules.unload('refuse_nord')
+{% endif %}
+
+-- options.qname-minimisation
+option('NO_MINIMIZE', {{ boolean(cfg.options.minimize,true) }})
+
+-- options.query-loopback
+option('ALLOW_LOCAL', {{ boolean(cfg.options.query_loopback) }})
+
+-- options.reorder-rrset
+option('REORDER_RR', {{ boolean(cfg.options.reorder_rrset) }})
+
+-- options.query-case-randomization
+option('NO_0X20', {{ boolean(cfg.options.query_case_randomization,true) }}) \ No newline at end of file
diff --git a/python/knot_resolver/datamodel/templates/policy-config.lua.j2 b/python/knot_resolver/datamodel/templates/policy-config.lua.j2
new file mode 100644
index 00000000..4c5c9048
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/policy-config.lua.j2
@@ -0,0 +1,40 @@
+{% if not cfg.lua.script_only %}
+
+-- FFI library
+ffi = require('ffi')
+local C = ffi.C
+
+-- logging.level
+log_level('{{ cfg.logging.level }}')
+
+{% if cfg.logging.target -%}
+-- logging.target
+log_target('{{ cfg.logging.target }}')
+{%- endif %}
+
+{% if cfg.logging.groups %}
+-- logging.groups
+log_groups({
+{% for g in cfg.logging.groups %}
+{% if g != "manager" and g != "supervisord" and g != "cache-gc" %}
+ '{{ g }}',
+{% endif %}
+{% endfor %}
+})
+{% endif %}
+
+-- Config required for the cache opening
+cache.open({{ cfg.cache.size_max.bytes() }}, 'lmdb://{{ cfg.cache.storage }}')
+
+-- VIEWS section ------------------------------------
+{% include "views.lua.j2" %}
+
+-- LOCAL-DATA section -------------------------------
+{% include "local_data.lua.j2" %}
+
+-- FORWARD section ----------------------------------
+{% include "forward.lua.j2" %}
+
+{% endif %}
+
+quit()
diff --git a/python/knot_resolver/datamodel/templates/static_hints.lua.j2 b/python/knot_resolver/datamodel/templates/static_hints.lua.j2
new file mode 100644
index 00000000..130facf9
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/static_hints.lua.j2
@@ -0,0 +1,51 @@
+{% if cfg.static_hints.etc_hosts or cfg.static_hints.root_hints_file or cfg.static_hints.hints_files or cfg.static_hints.root_hints or cfg.static_hints.hints %}
+modules.load('hints > iterate')
+
+{% if cfg.static_hints.ttl %}
+-- static-hints.ttl
+hints.ttl({{ cfg.static_hints.ttl.seconds()|string }})
+{% endif %}
+
+-- static-hints.no-data
+hints.use_nodata({{ 'true' if cfg.static_hints.nodata else 'false' }})
+
+{% if cfg.static_hints.etc_hosts %}
+-- static-hints.etc-hosts
+hints.add_hosts('/etc/hosts')
+{% endif %}
+
+{% if cfg.static_hints.root_hints_file %}
+-- static-hints.root-hints-file
+hints.root_file('{{ cfg.static_hints.root_hints_file }}')
+{% endif %}
+
+{% if cfg.static_hints.hints_files %}
+-- static-hints.hints-files
+{% for item in cfg.static_hints.hints_files %}
+hints.add_hosts('{{ item }}')
+{% endfor %}
+{% endif %}
+
+{% if cfg.static_hints.root_hints %}
+-- static-hints.root-hints
+hints.root({
+{% for name, addrs in cfg.static_hints.root_hints.items() %}
+['{{ name.punycode() }}'] = {
+{% for addr in addrs %}
+ '{{ addr }}',
+{% endfor %}
+ },
+{% endfor %}
+})
+{% endif %}
+
+{% if cfg.static_hints.hints %}
+-- static-hints.hints
+{% for name, addrs in cfg.static_hints.hints.items() %}
+{% for addr in addrs %}
+hints.set('{{ name.punycode() }} {{ addr }}')
+{% endfor %}
+{% endfor %}
+{% endif %}
+
+{% endif %} \ No newline at end of file
diff --git a/python/knot_resolver/datamodel/templates/views.lua.j2 b/python/knot_resolver/datamodel/templates/views.lua.j2
new file mode 100644
index 00000000..81de8c7b
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/views.lua.j2
@@ -0,0 +1,25 @@
+{% from 'macros/common_macros.lua.j2' import quotes %}
+{% from 'macros/view_macros.lua.j2' import get_proto_set, view_flags, view_answer %}
+{% from 'macros/policy_macros.lua.j2' import policy_flags, policy_tags_assign %}
+
+{% if cfg.views %}
+{% for view in cfg.views %}
+{% for subnet in view.subnets %}
+
+assert(C.kr_view_insert_action('{{ subnet }}', '{{ view.dst_subnet or '' }}',
+ {{ get_proto_set(view.protocols) }}, policy.COMBINE({
+{%- set flags = view_flags(view.options) -%}
+{% if flags %}
+ {{ quotes(policy_flags(flags)) }},
+{%- endif %}
+
+{% if view.tags %}
+ {{ policy_tags_assign(view.tags) }},
+{% elif view.answer %}
+ {{ view_answer(view.answer) }},
+{%- endif %}
+ })) == 0)
+
+{% endfor %}
+{% endfor %}
+{% endif %}
diff --git a/python/knot_resolver/datamodel/templates/webmgmt.lua.j2 b/python/knot_resolver/datamodel/templates/webmgmt.lua.j2
new file mode 100644
index 00000000..938ea8da
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/webmgmt.lua.j2
@@ -0,0 +1,25 @@
+{% from 'macros/common_macros.lua.j2' import boolean %}
+
+{% if cfg.webmgmt -%}
+-- webmgmt
+modules.load('http')
+http.config({tls = {{ boolean(cfg.webmgmt.tls) }},
+{%- if cfg.webmgmt.cert_file -%}
+ cert = '{{ cfg.webmgmt.cert_file }}',
+{%- endif -%}
+{%- if cfg.webmgmt.key_file -%}
+ key = '{{ cfg.webmgmt.key_file }}',
+{%- endif -%}
+}, 'webmgmt')
+net.listen(
+{%- if cfg.webmgmt.unix_socket -%}
+ '{{ cfg.webmgmt.unix_socket }}',nil,
+{%- elif cfg.webmgmt.interface -%}
+ {%- if cfg.webmgmt.interface.addr -%}
+ '{{ cfg.webmgmt.interface.addr }}',{{ cfg.webmgmt.interface.port }},
+ {%- elif cfg.webmgmt.interface.if_name -%}
+ net.{{ cfg.webmgmt.interface.if_name }},{{ cfg.webmgmt.interface.port }},
+ {%- endif -%}
+{%- endif -%}
+{ kind = 'webmgmt' })
+{%- endif %} \ No newline at end of file
diff --git a/python/knot_resolver/datamodel/templates/worker-config.lua.j2 b/python/knot_resolver/datamodel/templates/worker-config.lua.j2
new file mode 100644
index 00000000..17c49fb0
--- /dev/null
+++ b/python/knot_resolver/datamodel/templates/worker-config.lua.j2
@@ -0,0 +1,58 @@
+{% if not cfg.lua.script_only %}
+
+-- FFI library
+ffi = require('ffi')
+local C = ffi.C
+
+-- Do not clear the DB with rules; we had it prepared by a different process.
+assert(C.kr_rules_init(nil, 0, false) == 0)
+
+-- hostname
+hostname('{{ cfg.hostname }}')
+
+{% if cfg.nsid %}
+-- nsid
+modules.load('nsid')
+nsid.name('{{ cfg.nsid }}' .. worker.id)
+{% endif %}
+
+-- LOGGING section ----------------------------------
+{% include "logging.lua.j2" %}
+
+-- MONITORING section -------------------------------
+{% include "monitoring.lua.j2" %}
+
+-- WEBMGMT section ----------------------------------
+{% include "webmgmt.lua.j2" %}
+
+-- OPTIONS section ----------------------------------
+{% include "options.lua.j2" %}
+
+-- NETWORK section ----------------------------------
+{% include "network.lua.j2" %}
+
+-- DNSSEC section -----------------------------------
+{% include "dnssec.lua.j2" %}
+
+-- FORWARD section ----------------------------------
+{% include "forward.lua.j2" %}
+
+-- CACHE section ------------------------------------
+{% include "cache.lua.j2" %}
+
+-- DNS64 section ------------------------------------
+{% include "dns64.lua.j2" %}
+
+{% endif %}
+
+-- LUA section --------------------------------------
+-- Custom Lua code cannot be validated
+
+{% if cfg.lua.script_file %}
+{% import cfg.lua.script_file as script_file %}
+{{ script_file }}
+{% endif %}
+
+{% if cfg.lua.script %}
+{{ cfg.lua.script }}
+{% endif %}
diff --git a/python/knot_resolver/datamodel/types/__init__.py b/python/knot_resolver/datamodel/types/__init__.py
new file mode 100644
index 00000000..a3d7db3e
--- /dev/null
+++ b/python/knot_resolver/datamodel/types/__init__.py
@@ -0,0 +1,69 @@
+from .enums import DNSRecordTypeEnum, PolicyActionEnum, PolicyFlagEnum
+from .files import AbsoluteDir, Dir, File, FilePath, ReadableFile, WritableDir, WritableFilePath
+from .generic_types import ListOrItem
+from .types import (
+ DomainName,
+ EscapedStr,
+ EscapedStr32B,
+ IDPattern,
+ Int0_512,
+ Int0_65535,
+ InterfaceName,
+ InterfaceOptionalPort,
+ InterfacePort,
+ IntNonNegative,
+ IntPositive,
+ IPAddress,
+ IPAddressEM,
+ IPAddressOptionalPort,
+ IPAddressPort,
+ IPNetwork,
+ IPv4Address,
+ IPv6Address,
+ IPv6Network,
+ IPv6Network96,
+ Percent,
+ PinSha256,
+ PortNumber,
+ SizeUnit,
+ TimeUnit,
+)
+
+__all__ = [
+ "PolicyActionEnum",
+ "PolicyFlagEnum",
+ "DNSRecordTypeEnum",
+ "DomainName",
+ "EscapedStr",
+ "EscapedStr32B",
+ "IDPattern",
+ "Int0_512",
+ "Int0_65535",
+ "InterfaceName",
+ "InterfaceOptionalPort",
+ "InterfacePort",
+ "IntNonNegative",
+ "IntPositive",
+ "IPAddress",
+ "IPAddressEM",
+ "IPAddressOptionalPort",
+ "IPAddressPort",
+ "IPNetwork",
+ "IPv4Address",
+ "IPv6Address",
+ "IPv6Network",
+ "IPv6Network96",
+ "ListOrItem",
+ "Percent",
+ "PinSha256",
+ "PortNumber",
+ "SizeUnit",
+ "TimeUnit",
+ "AbsoluteDir",
+ "ReadableFile",
+ "WritableDir",
+ "WritableFilePath",
+ "File",
+ "FilePath",
+ "Dir",
+]
diff --git a/python/knot_resolver/datamodel/types/base_types.py b/python/knot_resolver/datamodel/types/base_types.py
new file mode 100644
index 00000000..c2d60312
--- /dev/null
+++ b/python/knot_resolver/datamodel/types/base_types.py
@@ -0,0 +1,227 @@
+import re
+from typing import Any, Dict, Pattern, Type
+
+from knot_resolver.utils.modeling import BaseValueType
+
+
+class IntBase(BaseValueType):
+ """
+    Base class for working with an integer value.
+ """
+
+ _orig_value: int
+ _value: int
+
+ def __init__(self, source_value: Any, object_path: str = "/") -> None:
+ if isinstance(source_value, int) and not isinstance(source_value, bool):
+ self._orig_value = source_value
+ self._value = source_value
+ else:
+ raise ValueError(
+ f"Unexpected value for '{type(self)}'."
+ f" Expected integer, got '{source_value}' with type '{type(source_value)}'",
+ object_path,
+ )
+
+ def __int__(self) -> int:
+ return self._value
+
+ def __str__(self) -> str:
+ return str(self._value)
+
+ def __repr__(self) -> str:
+ return f'{type(self).__name__}("{self._value}")'
+
+ def __eq__(self, o: object) -> bool:
+ return isinstance(o, IntBase) and o._value == self._value
+
+ def serialize(self) -> Any:
+ return self._orig_value
+
+ @classmethod
+ def json_schema(cls: Type["IntBase"]) -> Dict[Any, Any]:
+ return {"type": "integer"}
+
+
+class StrBase(BaseValueType):
+ """
+    Base class for working with a string value.
+ """
+
+ _orig_value: str
+ _value: str
+
+ def __init__(self, source_value: Any, object_path: str = "/") -> None:
+ if isinstance(source_value, (str, int)) and not isinstance(source_value, bool):
+ self._orig_value = str(source_value)
+ self._value = str(source_value)
+ else:
+ raise ValueError(
+ f"Unexpected value for '{type(self)}'."
+ f" Expected string, got '{source_value}' with type '{type(source_value)}'",
+ object_path,
+ )
+
+ def __int__(self) -> int:
+ raise ValueError("Can't convert string to an integer.")
+
+ def __str__(self) -> str:
+ return self._value
+
+ def __repr__(self) -> str:
+ return f'{type(self).__name__}("{self._value}")'
+
+ def __hash__(self) -> int:
+ return hash(self._value)
+
+ def __eq__(self, o: object) -> bool:
+ return isinstance(o, StrBase) and o._value == self._value
+
+ def serialize(self) -> Any:
+ return self._orig_value
+
+ @classmethod
+ def json_schema(cls: Type["StrBase"]) -> Dict[Any, Any]:
+ return {"type": "string"}
+
+
+class StringLengthBase(StrBase):
+ """
+    Base class for string values with a length constrained in bytes.
+    Just inherit the class and set the values for '_min_bytes' and/or '_max_bytes'.
+
+ class String32B(StringLengthBase):
+ _min_bytes: int = 32
+ """
+
+ _min_bytes: int = 1
+ _max_bytes: int
+
+ def __init__(self, source_value: Any, object_path: str = "/") -> None:
+ super().__init__(source_value, object_path)
+ value_bytes = len(self._value.encode("utf-8"))
+ if hasattr(self, "_min_bytes") and (value_bytes < self._min_bytes):
+ raise ValueError(
+ f"the string value {source_value} is shorter than the minimum {self._min_bytes} bytes.", object_path
+ )
+ if hasattr(self, "_max_bytes") and (value_bytes > self._max_bytes):
+ raise ValueError(
+ f"the string value {source_value} is longer than the maximum {self._max_bytes} bytes.", object_path
+ )
+
+ @classmethod
+ def json_schema(cls: Type["StringLengthBase"]) -> Dict[Any, Any]:
+ typ: Dict[str, Any] = {"type": "string"}
+ if hasattr(cls, "_min_bytes"):
+ typ["minLength"] = cls._min_bytes
+ if hasattr(cls, "_max_bytes"):
+ typ["maxLength"] = cls._max_bytes
+ return typ
+
+
+class IntRangeBase(IntBase):
+ """
+    Base class for integer values restricted to a range.
+    Just inherit the class and set the values for '_min' and/or '_max'.
+
+ class IntNonNegative(IntRangeBase):
+ _min: int = 0
+ """
+
+ _min: int
+ _max: int
+
+ def __init__(self, source_value: Any, object_path: str = "/") -> None:
+ super().__init__(source_value, object_path)
+ if hasattr(self, "_min") and (self._value < self._min):
+ raise ValueError(f"value {self._value} is lower than the minimum {self._min}.", object_path)
+ if hasattr(self, "_max") and (self._value > self._max):
+ raise ValueError(f"value {self._value} is higher than the maximum {self._max}", object_path)
+
+ @classmethod
+ def json_schema(cls: Type["IntRangeBase"]) -> Dict[Any, Any]:
+ typ: Dict[str, Any] = {"type": "integer"}
+ if hasattr(cls, "_min"):
+ typ["minimum"] = cls._min
+ if hasattr(cls, "_max"):
+ typ["maximum"] = cls._max
+ return typ
+
+
+class PatternBase(StrBase):
+ """
+    Base class for string values that must match a regex pattern.
+    Just inherit the class and set the regex pattern in '_re'.
+
+ class ABPattern(PatternBase):
+ _re: Pattern[str] = re.compile(r"ab*")
+ """
+
+ _re: Pattern[str]
+
+ def __init__(self, source_value: Any, object_path: str = "/") -> None:
+ super().__init__(source_value, object_path)
+ if not type(self)._re.match(self._value):
+ raise ValueError(f"'{self._value}' does not match '{self._re.pattern}' pattern", object_path)
+
+ @classmethod
+ def json_schema(cls: Type["PatternBase"]) -> Dict[Any, Any]:
+ return {"type": "string", "pattern": rf"{cls._re.pattern}"}
+
+
+class UnitBase(StrBase):
+ """
+    Base class for string values composed of a positive integer and a unit.
+    Just inherit the class and set '_units'.
+
+    class CustomUnit(UnitBase):
+ _units = {"b": 1, "kb": 1000}
+ """
+
+ _re: Pattern[str]
+ _units: Dict[str, int]
+ _base_value: int
+
+ def __init__(self, source_value: Any, object_path: str = "/") -> None:
+ super().__init__(source_value, object_path)
+
+ type(self)._re = re.compile(rf"^(\d+)({r'|'.join(type(self)._units.keys())})$")
+ grouped = self._re.search(self._value)
+ if grouped:
+ val, unit = grouped.groups()
+ if unit is None:
+ raise ValueError(f"Missing units. Accepted units are {list(type(self)._units.keys())}", object_path)
+ elif unit not in type(self)._units:
+ raise ValueError(
+ f"Used unexpected unit '{unit}' for {type(self).__name__}."
+ f" Accepted units are {list(type(self)._units.keys())}",
+ object_path,
+ )
+ self._base_value = int(val) * type(self)._units[unit]
+ else:
+ raise ValueError(
+ f"Unexpected value for '{type(self)}'."
+ " Expected string that matches pattern " + rf"'{type(self)._re.pattern}'."
+ f" Positive integer and one of the units {list(type(self)._units.keys())}, got '{source_value}'.",
+ object_path,
+ )
+
+ def __int__(self) -> int:
+ return self._base_value
+
+ def __repr__(self) -> str:
+ return f"Unit[{type(self).__name__},{self._value}]"
+
+ def __eq__(self, o: object) -> bool:
+ """
+ Two instances are equal when they represent the same size
+ regardless of their string representation.
+ """
+ return isinstance(o, UnitBase) and o._value == self._value
+
+ def serialize(self) -> Any:
+ return self._orig_value
+
+ @classmethod
+ def json_schema(cls: Type["UnitBase"]) -> Dict[Any, Any]:
+ return {"type": "string", "pattern": rf"{cls._re.pattern}"}
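
The docstrings above describe the intended usage pattern: concrete value types are created by subclassing and filling in the class attributes. A hedged sketch of that pattern follows; the subclass names are illustrative only, the real ones live in datamodel/types/types.py:

    import re
    from typing import Dict, Pattern

    from knot_resolver.datamodel.types.base_types import IntRangeBase, PatternBase, UnitBase

    class Percentage(IntRangeBase):   # accepts integers in [0, 100]
        _min: int = 0
        _max: int = 100

    class HexId(PatternBase):         # accepts strings matching the regex
        _re: Pattern[str] = re.compile(r"^[0-9a-f]+$")

    class DataSize(UnitBase):         # accepts e.g. "4K" and converts to a base value
        _units: Dict[str, int] = {"B": 1, "K": 1024}

    assert int(Percentage(42)) == 42
    assert Percentage.json_schema() == {"type": "integer", "minimum": 0, "maximum": 100}
    assert str(HexId("c0ffee")) == "c0ffee"
    assert int(DataSize("4K")) == 4096   # __int__ returns the value converted to the base unit
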
diff --git a/python/knot_resolver/datamodel/types/enums.py b/python/knot_resolver/datamodel/types/enums.py
new file mode 100644
index 00000000..bc93ae2f
--- /dev/null
+++ b/python/knot_resolver/datamodel/types/enums.py
@@ -0,0 +1,153 @@
+from typing_extensions import Literal
+
+# Policy actions
+PolicyActionEnum = Literal[
+ # Nonchain actions
+ "pass",
+ "deny",
+ "drop",
+ "refuse",
+ "tc",
+ "reroute",
+ "answer",
+ # Chain actions
+ "mirror",
+ "forward",
+ "stub",
+ "debug-always",
+ "debug-cache-miss",
+ "qtrace",
+ "reqtrace",
+]
+
+# FLAGS from https://www.knot-resolver.cz/documentation/latest/lib.html?highlight=options#c.kr_qflags
+PolicyFlagEnum = Literal[
+ "no-minimize",
+ "no-ipv4",
+ "no-ipv6",
+ "tcp",
+ "resolved",
+ "await-ipv4",
+ "await-ipv6",
+ "await-cut",
+ "no-edns",
+ "cached",
+ "no-cache",
+ "expiring",
+ "allow_local",
+ "dnssec-want",
+ "dnssec-bogus",
+ "dnssec-insecure",
+ "dnssec-cd",
+ "stub",
+ "always-cut",
+ "dnssec-wexpand",
+ "permissive",
+ "strict",
+ "badcookie-again",
+ "cname",
+ "reorder-rr",
+ "trace",
+ "no-0x20",
+ "dnssec-nods",
+ "dnssec-optout",
+ "nonauth",
+ "forward",
+ "dns64-mark",
+ "cache-tried",
+ "no-ns-found",
+ "pkt-is-sane",
+ "dns64-disable",
+]
+
+# DNS records from 'kres.type' table
+DNSRecordTypeEnum = Literal[
+ "A",
+ "A6",
+ "AAAA",
+ "AFSDB",
+ "ANY",
+ "APL",
+ "ATMA",
+ "AVC",
+ "AXFR",
+ "CAA",
+ "CDNSKEY",
+ "CDS",
+ "CERT",
+ "CNAME",
+ "CSYNC",
+ "DHCID",
+ "DLV",
+ "DNAME",
+ "DNSKEY",
+ "DOA",
+ "DS",
+ "EID",
+ "EUI48",
+ "EUI64",
+ "GID",
+ "GPOS",
+ "HINFO",
+ "HIP",
+ "HTTPS",
+ "IPSECKEY",
+ "ISDN",
+ "IXFR",
+ "KEY",
+ "KX",
+ "L32",
+ "L64",
+ "LOC",
+ "LP",
+ "MAILA",
+ "MAILB",
+ "MB",
+ "MD",
+ "MF",
+ "MG",
+ "MINFO",
+ "MR",
+ "MX",
+ "NAPTR",
+ "NID",
+ "NIMLOC",
+ "NINFO",
+ "NS",
+ "NSAP",
+ "NSAP-PTR",
+ "NSEC",
+ "NSEC3",
+ "NSEC3PARAM",
+ "NULL",
+ "NXT",
+ "OPENPGPKEY",
+ "OPT",
+ "PTR",
+ "PX",
+ "RKEY",
+ "RP",
+ "RRSIG",
+ "RT",
+ "SIG",
+ "SINK",
+ "SMIMEA",
+ "SOA",
+ "SPF",
+ "SRV",
+ "SSHFP",
+ "SVCB",
+ "TA",
+ "TALINK",
+ "TKEY",
+ "TLSA",
+ "TSIG",
+ "TXT",
+ "UID",
+ "UINFO",
+ "UNSPEC",
+ "URI",
+ "WKS",
+ "X25",
+ "ZONEMD",
+]
diff --git a/python/knot_resolver/datamodel/types/files.py b/python/knot_resolver/datamodel/types/files.py
new file mode 100644
index 00000000..920d90b1
--- /dev/null
+++ b/python/knot_resolver/datamodel/types/files.py
@@ -0,0 +1,245 @@
+import os
+import stat
+from enum import auto, Flag
+from grp import getgrnam
+from pathlib import Path
+from pwd import getpwnam
+from typing import Any, Dict, Tuple, Type, TypeVar
+
+from knot_resolver.manager.constants import kresd_group, kresd_user
+from knot_resolver.datamodel.globals import get_resolve_root, get_strict_validation
+from knot_resolver.utils.modeling.base_value_type import BaseValueType
+
+
+class UncheckedPath(BaseValueType):
+ """
+ Wrapper around pathlib.Path object. Can represent pretty much any Path. No checks are
+ performed on the value. The value is taken as is.
+ """
+
+ _value: Path
+
+ def __init__(
+ self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/"
+ ) -> None:
+ self._object_path: str = object_path
+ self._parents: Tuple[UncheckedPath, ...] = parents
+ self.strict_validation: bool = get_strict_validation()
+
+ if isinstance(source_value, str):
+ # we do not load global validation context if the path is absolute
+ # this prevents errors when constructing defaults in the schema
+ if source_value.startswith("/"):
+ resolve_root = Path("/")
+ else:
+ resolve_root = get_resolve_root()
+
+ self._raw_value: str = source_value
+ if self._parents:
+ pp = map(lambda p: p.to_path(), self._parents)
+ self._value: Path = Path(resolve_root, *pp, source_value)
+ else:
+ self._value: Path = Path(resolve_root, source_value)
+ else:
+ raise ValueError(f"expected file path in a string, got '{source_value}' with type '{type(source_value)}'.")
+
+ def __str__(self) -> str:
+ return str(self._value)
+
+ def __eq__(self, o: object) -> bool:
+ if not isinstance(o, UncheckedPath):
+ return False
+
+ return o._value == self._value
+
+ def __int__(self) -> int:
+ raise RuntimeError("Path cannot be converted to type <int>")
+
+ def to_path(self) -> Path:
+ return self._value
+
+ def serialize(self) -> Any:
+ return self._raw_value
+
+ def relative_to(self, parent: "UncheckedPath") -> "UncheckedPath":
+ """return a path with an added parent part"""
+ return UncheckedPath(self._raw_value, parents=(parent, *self._parents), object_path=self._object_path)
+
+ UPT = TypeVar("UPT", bound="UncheckedPath")
+
+ def reconstruct(self, cls: Type[UPT]) -> UPT:
+ """
+ Rebuild this object as an instance of its subclass. Practically, allows for conversions from
+ """
+ return cls(self._raw_value, parents=self._parents, object_path=self._object_path)
+
+ @classmethod
+ def json_schema(cls: Type["UncheckedPath"]) -> Dict[Any, Any]:
+ return {
+ "type": "string",
+ }
+
+
+class Dir(UncheckedPath):
+ """
+    Path that is enforced to be:
+ - an existing directory
+ """
+
+ def __init__(
+ self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/"
+ ) -> None:
+ super().__init__(source_value, parents=parents, object_path=object_path)
+ if self.strict_validation and not self._value.is_dir():
+ raise ValueError(f"path '{self._value}' does not point to an existing directory")
+
+
+class AbsoluteDir(Dir):
+ """
+    Path that is enforced to be:
+ - absolute
+ - an existing directory
+ """
+
+ def __init__(
+ self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/"
+ ) -> None:
+ super().__init__(source_value, parents=parents, object_path=object_path)
+ if self.strict_validation and not self._value.is_absolute():
+ raise ValueError(f"path '{self._value}' is not absolute")
+
+
+class File(UncheckedPath):
+ """
+    Path that is enforced to be:
+ - an existing file
+ """
+
+ def __init__(
+ self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/"
+ ) -> None:
+ super().__init__(source_value, parents=parents, object_path=object_path)
+ if self.strict_validation and not self._value.exists():
+ raise ValueError(f"file '{self._value}' does not exist")
+ if self.strict_validation and not self._value.is_file():
+ raise ValueError(f"path '{self._value}' is not a file")
+
+
+class FilePath(UncheckedPath):
+ """
+    Path that is enforced to:
+    - have an existing directory as the parent of its last segment
+    - not point to a directory
+ """
+
+ def __init__(
+ self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/"
+ ) -> None:
+ super().__init__(source_value, parents=parents, object_path=object_path)
+ p = self._value.parent
+ if self.strict_validation and (not p.exists() or not p.is_dir()):
+ raise ValueError(f"path '{self._value}' does not point inside an existing directory")
+
+ if self.strict_validation and self._value.is_dir():
+ raise ValueError(f"path '{self._value}' points to a directory when we expected a file")
+
+
+class _PermissionMode(Flag):
+ READ = auto()
+ WRITE = auto()
+ EXECUTE = auto()
+
+
+def _kres_accessible(dest_path: Path, perm_mode: _PermissionMode) -> bool:
+ chflags = {
+ _PermissionMode.READ: [stat.S_IRUSR, stat.S_IRGRP, stat.S_IROTH],
+ _PermissionMode.WRITE: [stat.S_IWUSR, stat.S_IWGRP, stat.S_IWOTH],
+ _PermissionMode.EXECUTE: [stat.S_IXUSR, stat.S_IXGRP, stat.S_IXOTH],
+ }
+
+ username = kresd_user()
+ groupname = kresd_group()
+
+ if username is None or groupname is None:
+ return True
+
+ user_uid = getpwnam(username).pw_uid
+ user_gid = getgrnam(groupname).gr_gid
+
+ dest_stat = os.stat(dest_path)
+ dest_uid = dest_stat.st_uid
+ dest_gid = dest_stat.st_gid
+ dest_mode = dest_stat.st_mode
+
+ def accessible(perm: _PermissionMode) -> bool:
+ if user_uid == dest_uid:
+ return bool(dest_mode & chflags[perm][0])
+ b_groups = os.getgrouplist(os.getlogin(), user_gid)
+ if user_gid == dest_gid or dest_gid in b_groups:
+ return bool(dest_mode & chflags[perm][1])
+ return bool(dest_mode & chflags[perm][2])
+
+    # __iter__ for enum.Flag was added in Python 3.11,
+    # so 'for perm in perm_mode:' fails on Python < 3.11
+ for perm in _PermissionMode:
+ if perm in perm_mode:
+ if not accessible(perm):
+ return False
+ return True
+
+
+class ReadableFile(File):
+ """
+    Path that is enforced to be:
+ - an existing file
+ - readable by knot-resolver processes
+ """
+
+ def __init__(
+ self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/"
+ ) -> None:
+ super().__init__(source_value, parents=parents, object_path=object_path)
+
+ if self.strict_validation and not _kres_accessible(self._value, _PermissionMode.READ):
+ raise ValueError(f"{kresd_user()}:{kresd_group()} has insufficient permissions to read '{self._value}'")
+
+
+class WritableDir(Dir):
+ """
+    Path that is enforced to be:
+ - an existing directory
+ - writable/executable by knot-resolver processes
+ """
+
+ def __init__(
+ self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/"
+ ) -> None:
+ super().__init__(source_value, parents=parents, object_path=object_path)
+
+ if self.strict_validation and not _kres_accessible(
+ self._value, _PermissionMode.WRITE | _PermissionMode.EXECUTE
+ ):
+ raise ValueError(
+ f"{kresd_user()}:{kresd_group()} has insufficient permissions to write/execute '{self._value}'"
+ )
+
+
+class WritableFilePath(FilePath):
+ """
+    Path that is enforced to:
+    - have an existing directory as the parent of its last segment
+    - not point to a directory
+    - have that parent directory writable/executable by knot-resolver processes
+ """
+
+ def __init__(
+ self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/"
+ ) -> None:
+ super().__init__(source_value, parents=parents, object_path=object_path)
+
+ if self.strict_validation and not _kres_accessible(
+ self._value.parent, _PermissionMode.WRITE | _PermissionMode.EXECUTE
+ ):
+ raise ValueError(
+                f"{kresd_user()}:{kresd_group()} has insufficient permissions to write/execute '{self._value.parent}'"
+ )
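
Taken together, the path types form a small hierarchy of increasingly strict checks: UncheckedPath only resolves the value, Dir/File require existence, FilePath requires an existing parent, and the Readable/Writable variants additionally check permissions for the kresd user and group. A hedged usage sketch; it assumes the validation globals these classes consult (resolve root, strict-validation flag) have already been initialised by the caller, which normally happens inside the manager:

    from knot_resolver.datamodel.types.files import Dir, File, FilePath, UncheckedPath

    UncheckedPath("relative/may-not-exist")   # never validated, only resolved
    Dir("/etc")                               # must be an existing directory
    File("/etc/hosts")                        # must be an existing regular file
    FilePath("/etc/new-file.conf")            # parent must exist and must not be a directory

    try:
        File("/definitely/missing")
    except ValueError as exc:
        print(exc)                            # file '/definitely/missing' does not exist
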
diff --git a/python/knot_resolver/datamodel/types/generic_types.py b/python/knot_resolver/datamodel/types/generic_types.py
new file mode 100644
index 00000000..8649a0f0
--- /dev/null
+++ b/python/knot_resolver/datamodel/types/generic_types.py
@@ -0,0 +1,38 @@
+from typing import Any, List, TypeVar, Union
+
+from knot_resolver.utils.modeling import BaseGenericTypeWrapper
+
+T = TypeVar("T")
+
+
+class ListOrItem(BaseGenericTypeWrapper[Union[List[T], T]]):
+ _value_orig: Union[List[T], T]
+ _list: List[T]
+
+ def __init__(self, source_value: Any, object_path: str = "/") -> None: # pylint: disable=unused-argument
+ self._value_orig: Union[List[T], T] = source_value
+
+ self._list: List[T] = source_value if isinstance(source_value, list) else [source_value]
+ if len(self) == 0:
+ raise ValueError("empty list is not allowed")
+
+ def __getitem__(self, index: Any) -> T:
+ return self._list[index]
+
+ def __int__(self) -> int:
+ raise ValueError(f"Can't convert '{type(self).__name__}' to an integer.")
+
+ def __str__(self) -> str:
+ return str(self._value_orig)
+
+ def to_std(self) -> List[T]:
+ return self._list
+
+ def __eq__(self, o: object) -> bool:
+ return isinstance(o, ListOrItem) and o._value_orig == self._value_orig
+
+ def __len__(self) -> int:
+ return len(self._list)
+
+ def serialize(self) -> Union[List[T], T]:
+ return self._value_orig
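
ListOrItem wraps a value that may be given either as a single item or as a list of items, so the rest of the code always sees a list while serialization keeps the original shape. A hedged sketch of that behaviour:

    from knot_resolver.datamodel.types.generic_types import ListOrItem

    single = ListOrItem("example.com")
    many = ListOrItem(["a.example", "b.example"])

    assert single.to_std() == ["example.com"]     # normalised to a list internally
    assert len(many) == 2 and many[0] == "a.example"
    assert single.serialize() == "example.com"    # the original shape is kept for output

    try:
        ListOrItem([])
    except ValueError:
        pass                                      # empty lists are rejected
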
diff --git a/python/knot_resolver/datamodel/types/types.py b/python/knot_resolver/datamodel/types/types.py
new file mode 100644
index 00000000..ca8706d2
--- /dev/null
+++ b/python/knot_resolver/datamodel/types/types.py
@@ -0,0 +1,526 @@
+import ipaddress
+import re
+from typing import Any, Dict, Optional, Type, Union
+
+from knot_resolver.datamodel.types.base_types import (
+ IntRangeBase,
+ PatternBase,
+ StrBase,
+ StringLengthBase,
+ UnitBase,
+)
+from knot_resolver.utils.modeling import BaseValueType
+
+
+class IntNonNegative(IntRangeBase):
+ _min: int = 0
+
+
+class IntPositive(IntRangeBase):
+ _min: int = 1
+
+
+class Int0_512(IntRangeBase):
+ _min: int = 0
+ _max: int = 512
+
+
+class Int0_65535(IntRangeBase):
+ _min: int = 0
+ _max: int = 65_535
+
+
+class Percent(IntRangeBase):
+ _min: int = 0
+ _max: int = 100
+
+
+class PortNumber(IntRangeBase):
+ _min: int = 1
+ _max: int = 65_535
+
+ @classmethod
+ def from_str(cls: Type["PortNumber"], port: str, object_path: str = "/") -> "PortNumber":
+ try:
+ return cls(int(port), object_path)
+ except ValueError as e:
+ raise ValueError(f"invalid port number {port}") from e
+
+
+class SizeUnit(UnitBase):
+ _units = {"B": 1, "K": 1024, "M": 1024**2, "G": 1024**3}
+
+ def bytes(self) -> int:
+ return self._base_value
+
+ def mbytes(self) -> int:
+ return self._base_value // 1024**2
+
+
+class TimeUnit(UnitBase):
+ _units = {"us": 1, "ms": 10**3, "s": 10**6, "m": 60 * 10**6, "h": 3600 * 10**6, "d": 24 * 3600 * 10**6}
+
+ def minutes(self) -> int:
+ return self._base_value // 1000**2 // 60
+
+ def seconds(self) -> int:
+ return self._base_value // 1000**2
+
+ def millis(self) -> int:
+ return self._base_value // 1000
+
+ def micros(self) -> int:
+ return self._base_value
+
+
+class EscapedStr(StrBase):
+ """
+    A string in which quotes and control characters are escaped, so the value can be safely embedded in generated Lua.
+ """
+
+ def __init__(self, source_value: Any, object_path: str = "/") -> None:
+ super().__init__(source_value, object_path)
+
+ escape = {
+ "'": r"\'",
+ '"': r"\"",
+ "\a": r"\a",
+ "\n": r"\n",
+ "\r": r"\r",
+ "\t": r"\t",
+ "\b": r"\b",
+ "\f": r"\f",
+ "\v": r"\v",
+ "\0": r"\0",
+ }
+
+ s = list(self._value)
+ for i, c in enumerate(self._value):
+ if c in escape:
+ s[i] = escape[c]
+ elif not c.isalnum():
+ s[i] = repr(c)[1:-1]
+ self._value = "".join(s)
+
+ def multiline(self) -> str:
+ """
+        A Lua multiline string is enclosed in double square brackets '[[ ]]'.
+        This method makes sure that any double square brackets in the value are escaped.
+ """
+
+ replace = {
+ "[[": r"\[\[",
+ "]]": r"\]\]",
+ }
+
+ ml = self._orig_value
+ for s, r in replace.items():
+ ml = ml.replace(s, r)
+ return ml
+
+
+class EscapedStr32B(EscapedStr, StringLengthBase):
+ """
+    Same as 'EscapedStr', but the minimum length is 32 bytes.
+ """
+
+ _min_bytes: int = 32
+
+
+class DomainName(StrBase):
+ """
+ Fully or partially qualified domain name.
+ """
+
+ _punycode: str
+ _re = re.compile(
+ r"(?=^.{,253}\.?$)" # max 253 chars
+ r"(^(?!\.)" # do not start name with dot
+ r"((?!-)" # do not start label with hyphen
+ r"\.?[a-zA-Z0-9-]{,62}" # max 63 chars in label
+ r"[a-zA-Z0-9])+" # do not end label with hyphen
+ r"\.?$)" # end with or without '.'
+ r"|^\.$" # allow root-zone
+ )
+
+ def __init__(self, source_value: Any, object_path: str = "/") -> None:
+ super().__init__(source_value, object_path)
+ try:
+ punycode = self._value.encode("idna").decode("utf-8") if self._value != "." else "."
+ except ValueError as e:
+ raise ValueError(
+ f"conversion of '{self._value}' to IDN punycode representation failed",
+ object_path,
+ ) from e
+
+ if type(self)._re.match(punycode):
+ self._punycode = punycode
+ else:
+ raise ValueError(
+ f"'{source_value}' represented in punycode '{punycode}' does not match '{self._re.pattern}' pattern",
+ object_path,
+ )
+
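+    # Normalize the optional trailing dot so that 'example.com' and 'example.com.' hash to the same value.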
+ def __hash__(self) -> int:
+ if self._value.endswith("."):
+ return hash(self._value)
+ return hash(f"{self._value}.")
+
+ def punycode(self) -> str:
+ return self._punycode
+
+ @classmethod
+ def json_schema(cls: Type["DomainName"]) -> Dict[Any, Any]:
+ return {"type": "string", "pattern": rf"{cls._re.pattern}"}
+
+
+class InterfaceName(PatternBase):
+ """
+ Network interface name.
+ """
+
+ _re = re.compile(r"^[a-zA-Z0-9]+(?:[-_][a-zA-Z0-9]+)*$")
+
+
+class IDPattern(PatternBase):
+ """
+ Alphanumerical ID for identifying systemd slice.
+ """
+
+ _re = re.compile(r"^(?!-)[a-z0-9-]*[a-z0-9]+$")
+
+
+class PinSha256(PatternBase):
+ """
+ A string that stores base64 encoded sha256.
+ """
+
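+    # 43 base64 characters plus one '=' padding encode exactly 32 bytes, i.e. a SHA-256 digest.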
+ _re = re.compile(r"^[A-Za-z\d+/]{43}=$")
+
+
+class InterfacePort(StrBase):
+ addr: Union[None, ipaddress.IPv4Address, ipaddress.IPv6Address] = None
+ if_name: Optional[InterfaceName] = None
+ port: PortNumber
+
+ def __init__(self, source_value: Any, object_path: str = "/") -> None:
+ super().__init__(source_value, object_path)
+
+ parts = self._value.split("@")
+ if len(parts) == 2:
+            try:
+                self.addr = ipaddress.ip_address(parts[0])
+            except ValueError:
+                try:
+                    self.if_name = InterfaceName(parts[0])
+                except ValueError as e:
+                    raise ValueError(
+                        f"expected IP address or interface name, got '{parts[0]}'.", object_path
+                    ) from e
+ self.port = PortNumber.from_str(parts[1], object_path)
+ else:
+ raise ValueError(f"expected '<ip-address|interface-name>@<port>', got '{source_value}'.", object_path)
+
+
+class InterfaceOptionalPort(StrBase):
+ addr: Union[None, ipaddress.IPv4Address, ipaddress.IPv6Address] = None
+ if_name: Optional[InterfaceName] = None
+ port: Optional[PortNumber] = None
+
+ def __init__(self, source_value: Any, object_path: str = "/") -> None:
+ super().__init__(source_value, object_path)
+
+ parts = self._value.split("@")
+ if 0 < len(parts) < 3:
+            try:
+                self.addr = ipaddress.ip_address(parts[0])
+            except ValueError:
+                try:
+                    self.if_name = InterfaceName(parts[0])
+                except ValueError as e:
+                    raise ValueError(
+                        f"expected IP address or interface name, got '{parts[0]}'.", object_path
+                    ) from e
+ if len(parts) == 2:
+ self.port = PortNumber.from_str(parts[1], object_path)
+ else:
+ raise ValueError(f"expected '<ip-address|interface-name>[@<port>]', got '{parts}'.", object_path)
+
+
+class IPAddressPort(StrBase):
+ addr: Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
+ port: PortNumber
+
+ def __init__(self, source_value: Any, object_path: str = "/") -> None:
+ super().__init__(source_value, object_path)
+
+ parts = self._value.split("@")
+ if len(parts) == 2:
+ self.port = PortNumber.from_str(parts[1], object_path)
+ try:
+ self.addr = ipaddress.ip_address(parts[0])
+ except ValueError as e:
+ raise ValueError(f"failed to parse IP address '{parts[0]}'.", object_path) from e
+ else:
+ raise ValueError(f"expected '<ip-address>@<port>', got '{source_value}'.", object_path)
+
+
+class IPAddressOptionalPort(StrBase):
+ addr: Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
+ port: Optional[PortNumber] = None
+
+ def __init__(self, source_value: Any, object_path: str = "/") -> None:
+        super().__init__(source_value, object_path)
+        parts = self._value.split("@")
+ if 0 < len(parts) < 3:
+ try:
+ self.addr = ipaddress.ip_address(parts[0])
+ except ValueError as e:
+ raise ValueError(f"failed to parse IP address '{parts[0]}'.", object_path) from e
+ if len(parts) == 2:
+ self.port = PortNumber.from_str(parts[1], object_path)
+ else:
+ raise ValueError(f"expected '<ip-address>[@<port>]', got '{parts}'.", object_path)
+
+
+class IPv4Address(BaseValueType):
+ _value: ipaddress.IPv4Address
+
+ def __init__(self, source_value: Any, object_path: str = "/") -> None:
+ if isinstance(source_value, str):
+ try:
+ self._value: ipaddress.IPv4Address = ipaddress.IPv4Address(source_value)
+ except ValueError as e:
+ raise ValueError("failed to parse IPv4 address.") from e
+ else:
+ raise ValueError(
+ "Unexpected value for a IPv4 address."
+ f" Expected string, got '{source_value}' with type '{type(source_value)}'",
+ object_path,
+ )
+
+ def to_std(self) -> ipaddress.IPv4Address:
+ return self._value
+
+ def __str__(self) -> str:
+ return str(self._value)
+
+ def __int__(self) -> int:
+ raise ValueError("Can't convert IPv4 address to an integer")
+
+ def __repr__(self) -> str:
+ return f'{type(self).__name__}("{self._value}")'
+
+ def __eq__(self, o: object) -> bool:
+ """
+ Two instances of IPv4Address are equal when they represent same IPv4 address as string.
+ """
+ return isinstance(o, IPv4Address) and str(o._value) == str(self._value)
+
+ def serialize(self) -> Any:
+ return str(self._value)
+
+ @classmethod
+ def json_schema(cls: Type["IPv4Address"]) -> Dict[Any, Any]:
+ return {
+ "type": "string",
+ }
+
+
+class IPv6Address(BaseValueType):
+ _value: ipaddress.IPv6Address
+
+ def __init__(self, source_value: Any, object_path: str = "/") -> None:
+ if isinstance(source_value, str):
+ try:
+ self._value: ipaddress.IPv6Address = ipaddress.IPv6Address(source_value)
+ except ValueError as e:
+ raise ValueError("failed to parse IPv6 address.") from e
+ else:
+ raise ValueError(
+ "Unexpected value for a IPv6 address."
+ f" Expected string, got '{source_value}' with type '{type(source_value)}'",
+ object_path,
+ )
+
+ def to_std(self) -> ipaddress.IPv6Address:
+ return self._value
+
+ def __str__(self) -> str:
+ return str(self._value)
+
+ def __int__(self) -> int:
+ raise ValueError("Can't convert IPv6 address to an integer")
+
+ def __repr__(self) -> str:
+ return f'{type(self).__name__}("{self._value}")'
+
+ def __eq__(self, o: object) -> bool:
+ """
+ Two instances of IPv6Address are equal when they represent same IPv6 address as string.
+ """
+ return isinstance(o, IPv6Address) and str(o._value) == str(self._value)
+
+ def serialize(self) -> Any:
+ return str(self._value)
+
+ @classmethod
+ def json_schema(cls: Type["IPv6Address"]) -> Dict[Any, Any]:
+ return {
+ "type": "string",
+ }
+
+
+IPAddress = Union[IPv4Address, IPv6Address]
+
+
+class IPAddressEM(BaseValueType):
+ """
+ IP address with exclamation mark suffix, e.g. '127.0.0.1!'.
+ """
+
+ _value: str
+ _addr: Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
+
+ def __init__(self, source_value: Any, object_path: str = "/") -> None:
+ if isinstance(source_value, str):
+ if source_value.endswith("!"):
+ addr, suff = source_value.split("!", 1)
+ if suff != "":
+ raise ValueError(f"suffix '{suff}' found after '!'.")
+ else:
+ raise ValueError("string does not end with '!'.")
+ try:
+ self._addr: Union[ipaddress.IPv4Address, ipaddress.IPv6Address] = ipaddress.ip_address(addr)
+ self._value = source_value
+ except ValueError as e:
+ raise ValueError("failed to parse IP address.") from e
+ else:
+ raise ValueError(
+ "Unexpected value for a IPv6 address."
+ f" Expected string, got '{source_value}' with type '{type(source_value)}'",
+ object_path,
+ )
+
+ def to_std(self) -> str:
+ return self._value
+
+ def __str__(self) -> str:
+ return self._value
+
+ def __int__(self) -> int:
+ raise ValueError("Can't convert to an integer")
+
+ def __repr__(self) -> str:
+ return f'{type(self).__name__}("{self._value}")'
+
+ def __eq__(self, o: object) -> bool:
+ """
+ Two instances of IPAddressEM are equal when they represent same string.
+ """
+ return isinstance(o, IPAddressEM) and o._value == self._value
+
+ def serialize(self) -> Any:
+ return self._value
+
+ @classmethod
+ def json_schema(cls: Type["IPAddressEM"]) -> Dict[Any, Any]:
+ return {
+ "type": "string",
+ }
+
+
+class IPNetwork(BaseValueType):
+ _value: Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
+
+ def __init__(self, source_value: Any, object_path: str = "/") -> None:
+ if isinstance(source_value, str):
+ try:
+ self._value: Union[ipaddress.IPv4Network, ipaddress.IPv6Network] = ipaddress.ip_network(source_value)
+ except ValueError as e:
+ raise ValueError("failed to parse IP network.") from e
+ else:
+ raise ValueError(
+ "Unexpected value for a network subnet."
+ f" Expected string, got '{source_value}' with type '{type(source_value)}'"
+ )
+
+ def __int__(self) -> int:
+ raise ValueError("Can't convert network prefix to an integer")
+
+ def __str__(self) -> str:
+ return self._value.with_prefixlen
+
+ def __repr__(self) -> str:
+ return f'{type(self).__name__}("{self._value}")'
+
+ def __eq__(self, o: object) -> bool:
+ return isinstance(o, IPNetwork) and o._value == self._value
+
+ def to_std(self) -> Union[ipaddress.IPv4Network, ipaddress.IPv6Network]:
+ return self._value
+
+ def serialize(self) -> Any:
+ return self._value.with_prefixlen
+
+ @classmethod
+ def json_schema(cls: Type["IPNetwork"]) -> Dict[Any, Any]:
+ return {
+ "type": "string",
+ }
+
+
+class IPv6Network(BaseValueType):
+ _value: ipaddress.IPv6Network
+
+ def __init__(self, source_value: Any, object_path: str = "/") -> None:
+ if isinstance(source_value, str):
+ try:
+ self._value: ipaddress.IPv6Network = ipaddress.IPv6Network(source_value)
+ except ValueError as e:
+ raise ValueError("failed to parse IPv6 network.") from e
+ else:
+ raise ValueError(
+ "Unexpected value for a IPv6 network subnet."
+ f" Expected string, got '{source_value}' with type '{type(source_value)}'"
+ )
+
+ def to_std(self) -> ipaddress.IPv6Network:
+ return self._value
+
+ def __str__(self) -> str:
+ return self._value.with_prefixlen
+
+ def __int__(self) -> int:
+ raise ValueError("Can't convert network prefix to an integer")
+
+ def __repr__(self) -> str:
+ return f'{type(self).__name__}("{self._value}")'
+
+ def __eq__(self, o: object) -> bool:
+ return isinstance(o, IPv6Network) and o._value == self._value
+
+ def serialize(self) -> Any:
+ return self._value.with_prefixlen
+
+ @classmethod
+ def json_schema(cls: Type["IPv6Network"]) -> Dict[Any, Any]:
+ return {
+ "type": "string",
+ }
+
+
+class IPv6Network96(IPv6Network):
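+    # An IPv6 network that must use a /96 prefix (as used e.g. for DNS64 translation prefixes).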
+ def __init__(self, source_value: Any, object_path: str = "/") -> None:
+ super().__init__(source_value, object_path=object_path)
+ if self._value.prefixlen == 128:
+ raise ValueError(
+ "Expected IPv6 network address with /96 prefix length."
+ " Submitted address has been interpreted as /128."
+ " Maybe, you forgot to add /96 after the base address?"
+ )
+
+ if self._value.prefixlen != 96:
+ raise ValueError(
+ "expected IPv6 network address with /96 prefix length." f" Got prefix lenght of {self._value.prefixlen}"
+ )
diff --git a/python/knot_resolver/datamodel/view_schema.py b/python/knot_resolver/datamodel/view_schema.py
new file mode 100644
index 00000000..b1d3adbe
--- /dev/null
+++ b/python/knot_resolver/datamodel/view_schema.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from typing_extensions import Literal
+
+from knot_resolver.datamodel.types import IDPattern, IPNetwork
+from knot_resolver.utils.modeling import ConfigSchema
+
+
+class ViewOptionsSchema(ConfigSchema):
+ """
+ Configuration options for clients identified by the view.
+
+ ---
+ minimize: Send minimum amount of information in recursive queries to enhance privacy.
+ dns64: Enable/disable DNS64.
+ """
+
+ minimize: bool = True
+ dns64: bool = True
+
+
+class ViewSchema(ConfigSchema):
+ """
+    Configuration parameters that allow you to create personalized policy rules and other view-specific settings.
+
+ ---
+    subnets: Identifies the client based on its source subnet. The rule with the most precise subnet takes priority.
+ dst_subnet: Destination subnet, as an additional condition.
+ protocols: Transport protocol, as an additional condition.
+ tags: Tags to link with other policy rules.
+    answer: Direct way to answer requests from clients identified by the view.
+ options: Configuration options for clients identified by the view.
+ """
+
+ subnets: List[IPNetwork]
+ dst_subnet: Optional[IPNetwork] = None # could be a list as well, iterated in template
+ protocols: Optional[List[Literal["udp53", "tcp53", "dot", "doh", "doq"]]] = None
+
+ tags: Optional[List[IDPattern]] = None
+ answer: Optional[Literal["allow", "refused", "noanswer"]] = None
+ options: ViewOptionsSchema = ViewOptionsSchema()
+
+ def _validate(self) -> None:
+ if bool(self.tags) == bool(self.answer):
+ raise ValueError("exactly one of 'tags' and 'answer' must be configured")
diff --git a/python/knot_resolver/datamodel/webmgmt_schema.py b/python/knot_resolver/datamodel/webmgmt_schema.py
new file mode 100644
index 00000000..ce18376b
--- /dev/null
+++ b/python/knot_resolver/datamodel/webmgmt_schema.py
@@ -0,0 +1,27 @@
+from typing import Optional
+
+from knot_resolver.datamodel.types import WritableFilePath, InterfacePort, ReadableFile
+from knot_resolver.utils.modeling import ConfigSchema
+
+
+class WebmgmtSchema(ConfigSchema):
+ """
+ Configuration of legacy web management endpoint.
+
+ ---
+    unix_socket: Path to the unix-domain socket to listen on.
+    interface: IP address or interface name with port number to listen on.
+ tls: Enable/disable TLS.
+ cert_file: Path to certificate file.
+ key_file: Path to certificate key.
+ """
+
+ unix_socket: Optional[WritableFilePath] = None
+ interface: Optional[InterfacePort] = None
+ tls: bool = False
+ cert_file: Optional[ReadableFile] = None
+ key_file: Optional[ReadableFile] = None
+
+ def _validate(self) -> None:
+ if bool(self.unix_socket) == bool(self.interface):
+ raise ValueError("One of 'interface' or 'unix-socket' must be configured.")
diff --git a/python/knot_resolver/manager/__init__.py b/python/knot_resolver/manager/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/python/knot_resolver/manager/__init__.py
diff --git a/python/knot_resolver/manager/__main__.py b/python/knot_resolver/manager/__main__.py
new file mode 100644
index 00000000..26aae1d6
--- /dev/null
+++ b/python/knot_resolver/manager/__main__.py
@@ -0,0 +1,5 @@
+from knot_resolver.manager.main import main
+
+
+if __name__ == "__main__":
+ main()
diff --git a/python/knot_resolver/manager/config_store.py b/python/knot_resolver/manager/config_store.py
new file mode 100644
index 00000000..1c0174f2
--- /dev/null
+++ b/python/knot_resolver/manager/config_store.py
@@ -0,0 +1,101 @@
+import asyncio
+from asyncio import Lock
+from typing import Any, Awaitable, Callable, List, Tuple
+
+from knot_resolver.datamodel import KresConfig
+from knot_resolver.manager.exceptions import KresManagerException
+from knot_resolver.utils.functional import Result
+from knot_resolver.utils.modeling.exceptions import DataParsingError
+from knot_resolver.utils.modeling.types import NoneType
+
+VerifyCallback = Callable[[KresConfig, KresConfig], Awaitable[Result[None, str]]]
+UpdateCallback = Callable[[KresConfig], Awaitable[None]]
+
+
+class ConfigStore:
+ def __init__(self, initial_config: KresConfig):
+ self._config = initial_config
+ self._verifiers: List[VerifyCallback] = []
+ self._callbacks: List[UpdateCallback] = []
+ self._update_lock: Lock = Lock()
+
+ async def update(self, config: KresConfig) -> None:
+ # invoke pre-change verifiers
+ results: Tuple[Result[None, str], ...] = tuple(
+ await asyncio.gather(*[ver(self._config, config) for ver in self._verifiers])
+ )
+ err_res = filter(lambda r: r.is_err(), results)
+ errs = list(map(lambda r: r.unwrap_err(), err_res))
+ if len(errs) > 0:
+ raise KresManagerException("Configuration validation failed. The reasons are:\n - " + "\n - ".join(errs))
+
+ async with self._update_lock:
+ # update the stored config with the new version
+ self._config = config
+
+ # invoke change callbacks
+ for call in self._callbacks:
+ await call(config)
+
+ async def renew(self) -> None:
+ await self.update(self._config)
+
+ async def register_verifier(self, verifier: VerifyCallback) -> None:
+ self._verifiers.append(verifier)
+ res = await verifier(self.get(), self.get())
+ if res.is_err():
+ raise DataParsingError(f"Initial config verification failed with error: {res.unwrap_err()}")
+
+ async def register_on_change_callback(self, callback: UpdateCallback) -> None:
+ """
+        Registers a new callback and immediately calls it with the current config
+ """
+
+ self._callbacks.append(callback)
+ await callback(self.get())
+
+ def get(self) -> KresConfig:
+ return self._config
+
+
+def only_on_real_changes_update(selector: Callable[[KresConfig], Any]) -> Callable[[UpdateCallback], UpdateCallback]:
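+    # Decorator factory: the wrapped update callback runs on its first invocation and afterwards
+    # only when the value returned by selector(config) differs from the previously observed one.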
+ def decorator(orig_func: UpdateCallback) -> UpdateCallback:
+ original_value_set: Any = False
+ original_value: Any = None
+
+ async def new_func_update(config: KresConfig) -> None:
+ nonlocal original_value_set
+ nonlocal original_value
+ if not original_value_set:
+ original_value_set = True
+ original_value = selector(config)
+ await orig_func(config)
+ elif original_value != selector(config):
+ original_value = selector(config)
+ await orig_func(config)
+
+ return new_func_update
+
+ return decorator
+
+
+def only_on_real_changes_verifier(selector: Callable[[KresConfig], Any]) -> Callable[[VerifyCallback], VerifyCallback]:
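+    # Same gating as only_on_real_changes_update, but for verifier callbacks; when the selected
+    # part of the config is unchanged, verification is skipped and success is returned.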
+ def decorator(orig_func: VerifyCallback) -> VerifyCallback:
+ original_value_set: Any = False
+ original_value: Any = None
+
+        async def new_func_verifier(old: KresConfig, new: KresConfig) -> Result[NoneType, str]:
+            nonlocal original_value_set
+            nonlocal original_value
+            if not original_value_set:
+                original_value_set = True
+                original_value = selector(new)
+                return await orig_func(old, new)
+            if original_value != selector(new):
+                original_value = selector(new)
+                return await orig_func(old, new)
+            # nothing relevant changed, so there is nothing to verify
+            return Result.ok(None)
+
+ return new_func_verifier
+
+ return decorator
diff --git a/python/knot_resolver/manager/constants.py b/python/knot_resolver/manager/constants.py
new file mode 100644
index 00000000..832d3fa9
--- /dev/null
+++ b/python/knot_resolver/manager/constants.py
@@ -0,0 +1,108 @@
+import importlib.util
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING, Optional
+
+# Install config is semi-optional - only needed to actually run Manager, but not
+# for its unit tests.
+if importlib.util.find_spec("knot_resolver"):
+ import knot_resolver # type: ignore[import-not-found]
+else:
+ knot_resolver = None
+
+if TYPE_CHECKING:
+ from knot_resolver.manager.config_store import ConfigStore
+ from knot_resolver.datamodel.config_schema import KresConfig
+ from knot_resolver.controller.interface import KresID
+
+STARTUP_LOG_LEVEL = logging.DEBUG
+DEFAULT_MANAGER_CONFIG_FILE = Path("/etc/knot-resolver/config.yaml")
+CONFIG_FILE_ENV_VAR = "KRES_MANAGER_CONFIG"
+API_SOCK_ENV_VAR = "KRES_MANAGER_API_SOCK"
+MANAGER_FIX_ATTEMPT_MAX_COUNTER = 2
+FIX_COUNTER_DECREASE_INTERVAL_SEC = 30 * 60
+PID_FILE_NAME = "manager.pid"
+MAX_WORKERS = 256
+
+
+def kresd_executable() -> Path:
+ assert knot_resolver is not None
+ return knot_resolver.sbin_dir / "kresd"
+
+
+def kres_gc_executable() -> Path:
+ assert knot_resolver is not None
+ return knot_resolver.sbin_dir / "kres-cache-gc"
+
+
+def kresd_user():
+ return None if knot_resolver is None else knot_resolver.user
+
+
+def kresd_group():
+ return None if knot_resolver is None else knot_resolver.group
+
+
+def kresd_cache_dir(config: "KresConfig") -> Path:
+ return config.cache.storage.to_path()
+
+
+def policy_loader_config_file(_config: "KresConfig") -> Path:
+ return Path("policy-loader.conf")
+
+
+def kresd_config_file(_config: "KresConfig", kres_id: "KresID") -> Path:
+ return Path(f"kresd{int(kres_id)}.conf")
+
+
+def kresd_config_file_supervisord_pattern(_config: "KresConfig") -> Path:
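+    # The '%(process_num)d' placeholder is expanded by supervisord for each numbered kresd process.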
+ return Path("kresd%(process_num)d.conf")
+
+
+def supervisord_config_file(_config: "KresConfig") -> Path:
+ return Path("supervisord.conf")
+
+
+def supervisord_config_file_tmp(_config: "KresConfig") -> Path:
+ return Path("supervisord.conf.tmp")
+
+
+def supervisord_pid_file(_config: "KresConfig") -> Path:
+ return Path("supervisord.pid")
+
+
+def supervisord_sock_file(_config: "KresConfig") -> Path:
+ return Path("supervisord.sock")
+
+
+def supervisord_subprocess_log_dir(_config: "KresConfig") -> Path:
+ return Path("logs")
+
+
+WATCHDOG_INTERVAL: float = 5
+"""
+Used in KresManager. The number of seconds between system health checks.
+"""
+
+
+class _UserConstants:
+ """
+ Class for accessing constants, which are technically not constants as they are user configurable.
+ """
+
+ def __init__(self, config_store: "ConfigStore", working_directory_on_startup: str) -> None:
+ self._config_store = config_store
+ self.working_directory_on_startup = working_directory_on_startup
+
+
+_user_constants: Optional[_UserConstants] = None
+
+
+async def init_user_constants(config_store: "ConfigStore", working_directory_on_startup: str) -> None:
+ global _user_constants
+ _user_constants = _UserConstants(config_store, working_directory_on_startup)
+
+
+def user_constants() -> _UserConstants:
+ assert _user_constants is not None
+ return _user_constants
diff --git a/python/knot_resolver/manager/exceptions.py b/python/knot_resolver/manager/exceptions.py
new file mode 100644
index 00000000..5b05d98e
--- /dev/null
+++ b/python/knot_resolver/manager/exceptions.py
@@ -0,0 +1,28 @@
+from typing import List
+
+
+class CancelStartupExecInsteadException(Exception):
+ """
+ Exception used for terminating system startup and instead
+ causing an exec of something else. Could be used by subprocess
+ controllers such as supervisord to allow them to run as top-level
+ process in a process tree.
+ """
+
+ def __init__(self, exec_args: List[str], *args: object) -> None:
+ self.exec_args = exec_args
+ super().__init__(*args)
+
+
+class KresManagerException(Exception):
+ """
+ Base class for all custom exceptions we use in our code
+ """
+
+
+class SubprocessControllerException(KresManagerException):
+ pass
+
+
+class SubprocessControllerTimeoutException(KresManagerException):
+ pass
diff --git a/python/knot_resolver/manager/kres_manager.py b/python/knot_resolver/manager/kres_manager.py
new file mode 100644
index 00000000..3dbc1079
--- /dev/null
+++ b/python/knot_resolver/manager/kres_manager.py
@@ -0,0 +1,429 @@
+import asyncio
+import logging
+import sys
+import time
+from secrets import token_hex
+from subprocess import SubprocessError
+from typing import Any, Callable, List, Optional
+
+from knot_resolver.compat.asyncio import create_task
+from knot_resolver.manager.config_store import (
+ ConfigStore,
+ only_on_real_changes_update,
+ only_on_real_changes_verifier,
+)
+from knot_resolver.manager.constants import (
+ FIX_COUNTER_DECREASE_INTERVAL_SEC,
+ MANAGER_FIX_ATTEMPT_MAX_COUNTER,
+ WATCHDOG_INTERVAL,
+)
+from knot_resolver.manager.exceptions import SubprocessControllerException
+from knot_resolver.controller.interface import (
+ Subprocess,
+ SubprocessController,
+ SubprocessStatus,
+ SubprocessType,
+)
+from knot_resolver.controller.registered_workers import (
+ command_registered_workers,
+ get_registered_workers_kresids,
+)
+from knot_resolver.utils.functional import Result
+from knot_resolver.utils.modeling.types import NoneType
+
+from knot_resolver import KresConfig
+
+logger = logging.getLogger(__name__)
+
+
+class _FixCounter:
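+    # Counts recent attempts to automatically fix an unstable system; once it is too high,
+    # the manager gives up and shuts down instead of retrying.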
+ def __init__(self) -> None:
+ self._counter = 0
+ self._timestamp = time.time()
+
+ def increase(self) -> None:
+ self._counter += 1
+ self._timestamp = time.time()
+
+ def try_decrease(self) -> None:
+ if time.time() - self._timestamp > FIX_COUNTER_DECREASE_INTERVAL_SEC:
+ if self._counter > 0:
+ logger.info(
+ f"Enough time has passed since last detected instability, decreasing fix attempt counter to {self._counter}"
+ )
+ self._counter -= 1
+ self._timestamp = time.time()
+
+ def __str__(self) -> str:
+ return str(self._counter)
+
+ def is_too_high(self) -> bool:
+ return self._counter >= MANAGER_FIX_ATTEMPT_MAX_COUNTER
+
+
+async def _deny_max_worker_changes(config_old: KresConfig, config_new: KresConfig) -> Result[None, str]:
+ if config_old.max_workers != config_new.max_workers:
+ return Result.err(
+ "Changing 'max-workers', the maximum number of workers allowed to run, is not allowed at runtime."
+ )
+
+ return Result.ok(None)
+
+
+class KresManager: # pylint: disable=too-many-instance-attributes
+ """
+ Core of the whole operation. Orchestrates individual instances under some
+ service manager like systemd.
+
+ Instantiate with `KresManager.create()`, not with the usual constructor!
+ """
+
+ def __init__(self, shutdown_trigger: Callable[[int], None], _i_know_what_i_am_doing: bool = False):
+ if not _i_know_what_i_am_doing:
+ logger.error(
+ "Trying to create an instance of KresManager using normal constructor. Please use "
+ "`KresManager.get_instance()` instead"
+ )
+ assert False
+
+ self._workers: List[Subprocess] = []
+ self._gc: Optional[Subprocess] = None
+ self._policy_loader: Optional[Subprocess] = None
+ self._manager_lock = asyncio.Lock()
+ self._workers_reset_needed: bool = False
+ self._controller: SubprocessController
+ self._watchdog_task: Optional["asyncio.Task[None]"] = None
+ self._fix_counter: _FixCounter = _FixCounter()
+ self._config_store: ConfigStore
+ self._shutdown_trigger: Callable[[int], None] = shutdown_trigger
+
+ @staticmethod
+ async def create(
+ subprocess_controller: SubprocessController,
+ config_store: ConfigStore,
+ shutdown_trigger: Callable[[int], None],
+ ) -> "KresManager":
+ """
+ Creates new instance of KresManager.
+ """
+
+ inst = KresManager(shutdown_trigger, _i_know_what_i_am_doing=True)
+ await inst._async_init(subprocess_controller, config_store) # pylint: disable=protected-access
+ return inst
+
+ async def _async_init(self, subprocess_controller: SubprocessController, config_store: ConfigStore) -> None:
+ self._controller = subprocess_controller
+ self._config_store = config_store
+
+ # initialize subprocess controller
+ logger.debug("Starting controller")
+ await self._controller.initialize_controller(config_store.get())
+ self._watchdog_task = create_task(self._watchdog())
+ logger.debug("Looking for already running workers")
+ await self._collect_already_running_workers()
+
+ # register and immediately call a verifier that loads policy rules into the rules database
+ await config_store.register_verifier(self.load_policy_rules)
+
+ # configuration nodes that are relevant to kresd workers and the cache garbage collector
+ def config_nodes(config: KresConfig) -> List[Any]:
+ return [
+ config.nsid,
+ config.hostname,
+ config.workers,
+ config.max_workers,
+ config.webmgmt,
+ config.options,
+ config.network,
+ config.forward,
+ config.cache,
+ config.dnssec,
+ config.dns64,
+ config.logging,
+ config.monitoring,
+ config.lua,
+ ]
+
+ # register and immediately call a verifier that validates config with 'canary' kresd process
+ await config_store.register_verifier(only_on_real_changes_verifier(config_nodes)(self.validate_config))
+
+ # register and immediately call a callback to apply config to all 'kresd' workers and 'cache-gc'
+ await config_store.register_on_change_callback(only_on_real_changes_update(config_nodes)(self.apply_config))
+
+ # register callback to reset policy rules for each 'kresd' worker
+ await config_store.register_on_change_callback(self.reset_workers_policy_rules)
+
+ # register and immediately call a callback to set new TLS session ticket secret for 'kresd' workers
+ await config_store.register_on_change_callback(
+ only_on_real_changes_update(config_nodes)(self.set_new_tls_sticket_secret)
+ )
+
+ # register controller config change listeners
+ await config_store.register_verifier(_deny_max_worker_changes)
+
+ async def _spawn_new_worker(self, config: KresConfig) -> None:
+ subprocess = await self._controller.create_subprocess(config, SubprocessType.KRESD)
+ await subprocess.start()
+ self._workers.append(subprocess)
+
+ async def _stop_a_worker(self) -> None:
+ if len(self._workers) == 0:
+ raise IndexError("Can't stop a kresd when there are no running")
+
+ subprocess = self._workers.pop()
+ await subprocess.stop()
+
+ async def _collect_already_running_workers(self) -> None:
+ for subp in await self._controller.get_all_running_instances():
+ if subp.type == SubprocessType.KRESD:
+ self._workers.append(subp)
+ elif subp.type == SubprocessType.GC:
+ assert self._gc is None
+ self._gc = subp
+ elif subp.type == SubprocessType.POLICY_LOADER:
+ assert self._policy_loader is None
+ self._policy_loader = subp
+ else:
+ raise RuntimeError("unexpected subprocess type")
+
+ async def _rolling_restart(self, new_config: KresConfig) -> None:
+ for kresd in self._workers:
+ await kresd.apply_new_config(new_config)
+
+ async def _ensure_number_of_children(self, config: KresConfig, n: int) -> None:
+ # kill children that are not needed
+ while len(self._workers) > n:
+ await self._stop_a_worker()
+
+ # spawn new children if needed
+ while len(self._workers) < n:
+ await self._spawn_new_worker(config)
+
+ async def _run_policy_loader(self, config: KresConfig) -> None:
+ if self._policy_loader:
+ await self._policy_loader.start(config)
+ else:
+ subprocess = await self._controller.create_subprocess(config, SubprocessType.POLICY_LOADER)
+ await subprocess.start()
+ self._policy_loader = subprocess
+
+ def _is_policy_loader_exited(self) -> bool:
+ if self._policy_loader:
+ return self._policy_loader.status() is SubprocessStatus.EXITED
+ return False
+
+ def _is_gc_running(self) -> bool:
+ return self._gc is not None
+
+ async def _start_gc(self, config: KresConfig) -> None:
+ subprocess = await self._controller.create_subprocess(config, SubprocessType.GC)
+ await subprocess.start()
+ self._gc = subprocess
+
+ async def _stop_gc(self) -> None:
+ assert self._gc is not None
+ await self._gc.stop()
+ self._gc = None
+
+ async def validate_config(self, _old: KresConfig, new: KresConfig) -> Result[NoneType, str]:
+ async with self._manager_lock:
+ logger.debug("Testing the new config with a canary process")
+ try:
+                # technically, this has the side effect of leaving a new process running
+ # but it's practically not a problem, because
+ # if it keeps running, the config is valid and others will soon join as well
+ # if it crashes and the startup fails, then well, it's not running anymore... :)
+ await self._spawn_new_worker(new)
+ except (SubprocessError, SubprocessControllerException):
+ logger.error("Kresd with the new config failed to start, rejecting config")
+ return Result.err("canary kresd process failed to start. Config might be invalid.")
+
+ logger.debug("Canary process test passed.")
+ return Result.ok(None)
+
+ async def _reload_system_state(self) -> None:
+ async with self._manager_lock:
+ self._workers = []
+ self._policy_loader = None
+ self._gc = None
+ await self._collect_already_running_workers()
+
+ async def reset_workers_policy_rules(self, _config: KresConfig) -> None:
+
+ # command all running 'kresd' workers to reset their old policy rules,
+ # unless the workers have already been started with a new config so reset is not needed
+ if self._workers_reset_needed and get_registered_workers_kresids():
+ logger.debug("Resetting policy rules for all running 'kresd' workers")
+ cmd_results = await command_registered_workers("require('ffi').C.kr_rules_reset()")
+ for worker, res in cmd_results.items():
+ if res != 0:
+ logger.error("Failed to reset policy rules in %s: %s", worker, res)
+ else:
+ logger.debug(
+ "Skipped resetting policy rules for all running 'kresd' workers:"
+ " the workers are already running with new configuration"
+ )
+
+ async def set_new_tls_sticket_secret(self, config: KresConfig) -> None:
+
+ if config.network.tls.sticket_secret or config.network.tls.sticket_secret_file:
+ logger.debug("User-configured TLS resumption secret found - skipping auto-generation.")
+ return
+
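+        # generate one shared secret and push it to every running worker, so they all use the same ticket secret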
+ logger.debug("Creating TLS session ticket secret")
+ secret = token_hex(32)
+ logger.debug("Setting TLS session ticket secret for all running 'kresd' workers")
+ cmd_results = await command_registered_workers(f"net.tls_sticket_secret('{secret}')")
+ for worker, res in cmd_results.items():
+ if res not in (0, True):
+ logger.error("Failed to set TLS session ticket secret in %s: %s", worker, res)
+
+ async def apply_config(self, config: KresConfig, _noretry: bool = False) -> None:
+ try:
+ async with self._manager_lock:
+ logger.debug("Applying config to all workers")
+ await self._rolling_restart(config)
+ await self._ensure_number_of_children(config, int(config.workers))
+
+ if self._is_gc_running() != bool(config.cache.garbage_collector):
+ if config.cache.garbage_collector:
+ logger.debug("Starting cache GC")
+ await self._start_gc(config)
+ else:
+ logger.debug("Stopping cache GC")
+ await self._stop_gc()
+ except SubprocessControllerException as e:
+ if _noretry:
+ raise
+ elif self._fix_counter.is_too_high():
+ logger.error(f"Failed to apply config: {e}")
+ logger.error("There have already been problems recently, refusing to try to fix it.")
+ await self.forced_shutdown() # possible improvement - the person who requested this change won't get a response this way
+ else:
+ logger.error(f"Failed to apply config: {e}")
+ logger.warning("Reloading system state and trying again.")
+ self._fix_counter.increase()
+ await self._reload_system_state()
+ await self.apply_config(config, _noretry=True)
+
+ self._workers_reset_needed = False
+
+ async def load_policy_rules(self, _old: KresConfig, new: KresConfig) -> Result[NoneType, str]:
+ try:
+ async with self._manager_lock:
+ logger.debug("Running kresd 'policy-loader'")
+ await self._run_policy_loader(new)
+
+ # wait for 'policy-loader' to finish
+ logger.debug("Waiting for 'policy-loader' to finish loading policy rules")
+ while not self._is_policy_loader_exited():
+ await asyncio.sleep(1)
+
+ except (SubprocessError, SubprocessControllerException) as e:
+ logger.error(f"Failed to load policy rules: {e}")
+ return Result.err("kresd 'policy-loader' process failed to start. Config might be invalid.")
+
+ self._workers_reset_needed = True
+ logger.debug("Loading policy rules has been successfully completed")
+ return Result.ok(None)
+
+ async def stop(self):
+ if self._watchdog_task is not None:
+ self._watchdog_task.cancel() # cancel it
+ try:
+ await self._watchdog_task # and let it really finish
+ except asyncio.CancelledError:
+ pass
+
+ async with self._manager_lock:
+ # we could stop all the children one by one right now
+ # we won't do that and we leave that up to the subprocess controller to do that while it is shutting down
+ await self._controller.shutdown_controller()
+ # now, when everything is stopped, let's clean up all the remains
+ await asyncio.gather(*[w.cleanup() for w in self._workers])
+
+ async def forced_shutdown(self) -> None:
+ logger.warning("Collecting all remaining workers...")
+ await self._reload_system_state()
+ logger.warning("Terminating...")
+ self._shutdown_trigger(1)
+
+ async def _instability_handler(self) -> None:
+ if self._fix_counter.is_too_high():
+ logger.error(
+ "Already attempted too many times to fix system state. Refusing to try again and shutting down."
+ )
+ await self.forced_shutdown()
+ return
+
+ try:
+ logger.warning("Instability detected. Dropping known list of workers and reloading it from the system.")
+ self._fix_counter.increase()
+ await self._reload_system_state()
+ logger.warning("Workers reloaded. Applying old config....")
+ await self._config_store.renew()
+ logger.warning(f"System stability hopefully renewed. Fix attempt counter is currently {self._fix_counter}")
+ except BaseException:
+ logger.error("Failed attempting to fix an error. Forcefully shutting down.", exc_info=True)
+ await self.forced_shutdown()
+
+ async def _watchdog(self) -> None: # pylint: disable=too-many-branches
+ while True:
+ await asyncio.sleep(WATCHDOG_INTERVAL)
+
+ self._fix_counter.try_decrease()
+
+ try:
+ # gather current state
+ async with self._manager_lock:
+ detected_subprocesses = await self._controller.get_subprocess_status()
+ expected_ids = [x.id for x in self._workers]
+ if self._gc:
+ expected_ids.append(self._gc.id)
+
+ invoke_callback = False
+
+ if self._policy_loader:
+ expected_ids.append(self._policy_loader.id)
+
+ for eid in expected_ids:
+ if eid not in detected_subprocesses:
+ logger.error("Subprocess with id '%s' was not found in the system!", eid)
+ invoke_callback = True
+ continue
+
+ if detected_subprocesses[eid] is SubprocessStatus.FATAL:
+ if self._policy_loader and self._policy_loader.id == eid:
+ logger.info(
+ "Subprocess '%s' is skipped by WatchDog because its status is monitored in a different way.",
+ eid,
+ )
+ continue
+ logger.error("Subprocess '%s' is in FATAL state!", eid)
+ invoke_callback = True
+ continue
+
+ if detected_subprocesses[eid] is SubprocessStatus.UNKNOWN:
+ logger.warning("Subprocess '%s' is in UNKNOWN state!", eid)
+
+ non_registered_ids = detected_subprocesses.keys() - set(expected_ids)
+ if len(non_registered_ids) != 0:
+ logger.error(
+ "Found additional process in the system, which shouldn't be there - %s",
+ non_registered_ids,
+ )
+ invoke_callback = True
+
+ except asyncio.CancelledError:
+ raise
+ except BaseException:
+ invoke_callback = True
+ logger.error("Knot Resolver watchdog failed with an unexpected exception.", exc_info=True)
+
+ if invoke_callback:
+ try:
+ await self._instability_handler()
+ except Exception:
+ logger.error("Watchdog failed while invoking instability callback", exc_info=True)
+ logger.error("Violently terminating!")
+ sys.exit(1)
diff --git a/python/knot_resolver/manager/log.py b/python/knot_resolver/manager/log.py
new file mode 100644
index 00000000..a22898a5
--- /dev/null
+++ b/python/knot_resolver/manager/log.py
@@ -0,0 +1,105 @@
+import logging
+import logging.handlers
+import os
+import sys
+from typing import Optional
+
+from knot_resolver.manager.config_store import ConfigStore, only_on_real_changes_update
+from knot_resolver.manager.constants import STARTUP_LOG_LEVEL
+from knot_resolver.datamodel.config_schema import KresConfig
+from knot_resolver.datamodel.logging_schema import LogTargetEnum
+
+logger = logging.getLogger(__name__)
+
+
+def get_log_format(config: KresConfig) -> str:
+ """
+    Based on the environment variable $KRES_SUPRESS_LOG_PREFIX, returns the appropriate format string for the logger.
+ """
+
+ if os.environ.get("KRES_SUPRESS_LOG_PREFIX") == "true":
+ # In this case, we are running under supervisord and it's adding prefixes to our output
+ return "[%(levelname)s] %(name)s: %(message)s"
+ else:
+        # In this case, we are running standalone during initialization and we need to add a prefix to each line
+ # by ourselves to make it consistent
+ assert config.logging.target != "syslog"
+ stream = ""
+ if config.logging.target == "stderr":
+ stream = " (stderr)"
+
+ pid = os.getpid()
+ return f"%(asctime)s manager[{pid}]{stream}: [%(levelname)s] %(name)s: %(message)s"
+
+
+async def _set_log_level(config: KresConfig) -> None:
+ levels_map = {
+ "crit": "CRITICAL",
+ "err": "ERROR",
+ "warning": "WARNING",
+ "notice": "WARNING",
+ "info": "INFO",
+ "debug": "DEBUG",
+ }
+
+    # when the 'manager' logging group is enabled, log with DEBUG
+ if config.logging.groups and "manager" in config.logging.groups:
+ target = "DEBUG"
+ # otherwise, follow the standard log level
+ else:
+ target = levels_map[config.logging.level]
+
+ # expect exactly one existing log handler on the root
+ logger.warning(f"Changing logging level to '{target}'")
+ logging.getLogger().setLevel(target)
+
+
+async def _set_logging_handler(config: KresConfig) -> None:
+ target: Optional[LogTargetEnum] = config.logging.target
+
+ if target is None:
+ target = "stdout"
+
+ handler: logging.Handler
+ if target == "syslog":
+ handler = logging.handlers.SysLogHandler(address="/dev/log")
+ handler.setFormatter(logging.Formatter("%(name)s: %(message)s"))
+ elif target == "stdout":
+ handler = logging.StreamHandler(sys.stdout)
+ handler.setFormatter(logging.Formatter(get_log_format(config)))
+ elif target == "stderr":
+ handler = logging.StreamHandler(sys.stderr)
+ handler.setFormatter(logging.Formatter(get_log_format(config)))
+ else:
+ raise RuntimeError(f"Unexpected value '{target}' for log target in the config")
+
+ root = logging.getLogger()
+
+ # if we had a MemoryHandler before, we should give it the new handler where we can flush it
+ if isinstance(root.handlers[0], logging.handlers.MemoryHandler):
+ root.handlers[0].setTarget(handler)
+
+ # stop the old handler
+ root.handlers[0].flush()
+ root.handlers[0].close()
+ root.removeHandler(root.handlers[0])
+
+ # configure the new handler
+ root.addHandler(handler)
+
+
+@only_on_real_changes_update(lambda config: config.logging)
+async def _configure_logger(config: KresConfig) -> None:
+ await _set_logging_handler(config)
+ await _set_log_level(config)
+
+
+async def logger_init(config_store: ConfigStore) -> None:
+ await config_store.register_on_change_callback(_configure_logger)
+
+
+def logger_startup() -> None:
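+    # Buffer log records in memory (up to 10 000) until the real handler is configured;
+    # anything at ERROR level or above flushes the buffer to stderr immediately.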
+ logging.getLogger().setLevel(STARTUP_LOG_LEVEL)
+ err_handler = logging.StreamHandler(sys.stderr)
+ err_handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT))
+ logging.getLogger().addHandler(logging.handlers.MemoryHandler(10_000, logging.ERROR, err_handler))
diff --git a/python/knot_resolver/manager/main.py b/python/knot_resolver/manager/main.py
new file mode 100644
index 00000000..5facc470
--- /dev/null
+++ b/python/knot_resolver/manager/main.py
@@ -0,0 +1,49 @@
+"""
+Effectively the same as normal __main__.py. However, we moved its content over to this
+file to allow us to exclude the __main__.py file from black's autoformatting
+"""
+
+import argparse
+import os
+import sys
+from pathlib import Path
+from typing import NoReturn
+
+from knot_resolver import compat
+from knot_resolver.manager.constants import CONFIG_FILE_ENV_VAR, DEFAULT_MANAGER_CONFIG_FILE
+from knot_resolver.manager.log import logger_startup
+from knot_resolver.manager.server import start_server
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(description="Knot Resolver - caching DNS resolver")
+ parser.add_argument(
+ "-c",
+ "--config",
+ help="Config file to load. Overrides default config location at '" + str(DEFAULT_MANAGER_CONFIG_FILE) + "'",
+ type=str,
+ nargs=1,
+ required=False,
+ default=None,
+ )
+ return parser.parse_args()
+
+
+def main() -> NoReturn:
+ # initial logging is to memory until we read the config
+ logger_startup()
+
+ # parse arguments
+ args = parse_args()
+
+ # where to look for config
+ config_env = os.getenv(CONFIG_FILE_ENV_VAR)
+ if args.config is not None:
+ config_path = Path(args.config[0])
+ elif config_env is not None:
+ config_path = Path(config_env)
+ else:
+ config_path = DEFAULT_MANAGER_CONFIG_FILE
+
+ exit_code = compat.asyncio.run(start_server(config=config_path))
+ sys.exit(exit_code)
diff --git a/python/knot_resolver/manager/server.py b/python/knot_resolver/manager/server.py
new file mode 100644
index 00000000..cf31a3fc
--- /dev/null
+++ b/python/knot_resolver/manager/server.py
@@ -0,0 +1,637 @@
+import asyncio
+import errno
+import json
+import logging
+import os
+import signal
+import sys
+from functools import partial
+from http import HTTPStatus
+from pathlib import Path
+from time import time
+from typing import Any, Dict, List, Optional, Set, Union, cast
+
+from aiohttp import web
+from aiohttp.web import middleware
+from aiohttp.web_app import Application
+from aiohttp.web_response import json_response
+from aiohttp.web_runner import AppRunner, TCPSite, UnixSite
+from typing_extensions import Literal
+
+import knot_resolver.utils.custom_atexit as atexit
+from knot_resolver.manager import log, statistics
+from knot_resolver.compat import asyncio as asyncio_compat
+from knot_resolver.manager.config_store import ConfigStore
+from knot_resolver.manager.constants import DEFAULT_MANAGER_CONFIG_FILE, PID_FILE_NAME, init_user_constants
+from knot_resolver.datamodel.cache_schema import CacheClearRPCSchema
+from knot_resolver.datamodel.config_schema import KresConfig, get_rundir_without_validation
+from knot_resolver.datamodel.globals import Context, set_global_validation_context
+from knot_resolver.datamodel.management_schema import ManagementSchema
+from knot_resolver.manager.exceptions import CancelStartupExecInsteadException, KresManagerException
+from knot_resolver.controller import get_best_controller_implementation
+from knot_resolver.controller.registered_workers import command_single_registered_worker
+from knot_resolver.utils import ignore_exceptions_optional
+from knot_resolver.utils.async_utils import readfile
+from knot_resolver.utils.etag import structural_etag
+from knot_resolver.utils.functional import Result
+from knot_resolver.utils.modeling.exceptions import (
+ AggregateDataValidationError,
+ DataParsingError,
+ DataValidationError,
+)
+from knot_resolver.utils.modeling.parsing import DataFormat, try_to_parse
+from knot_resolver.utils.modeling.query import query
+from knot_resolver.utils.modeling.types import NoneType
+from knot_resolver.utils.systemd_notify import systemd_notify
+
+from .kres_manager import KresManager
+
+logger = logging.getLogger(__name__)
+
+
+@middleware
+async def error_handler(request: web.Request, handler: Any) -> web.Response:
+ """
+ Generic error handler for route handlers.
+
+ If an exception is thrown during request processing, this middleware catches it
+ and responds accordingly.
+ """
+
+ try:
+ return await handler(request)
+ except DataValidationError as e:
+ return web.Response(text=f"validation of configuration failed:\n{e}", status=HTTPStatus.BAD_REQUEST)
+ except DataParsingError as e:
+ return web.Response(text=f"request processing error:\n{e}", status=HTTPStatus.BAD_REQUEST)
+ except KresManagerException as e:
+ return web.Response(text=f"request processing failed:\n{e}", status=HTTPStatus.INTERNAL_SERVER_ERROR)
+
+
+def from_mime_type(mime_type: str) -> DataFormat:
+ formats = {
+ "application/json": DataFormat.JSON,
+ "application/octet-stream": DataFormat.JSON, # default in aiohttp
+ }
+ if mime_type not in formats:
+ raise DataParsingError(f"unsupported MIME type '{mime_type}', expected: {str(formats)[1:-1]}")
+ return formats[mime_type]
+
+
+def parse_from_mime_type(data: str, mime_type: str) -> Any:
+ return from_mime_type(mime_type).parse_to_dict(data)
+
+
+class Server:
+ # pylint: disable=too-many-instance-attributes
+ # This is top-level class containing pretty much everything. Instead of global
+ # variables, we use instance attributes. That's why there are so many and it's
+ # ok.
+ def __init__(self, store: ConfigStore, config_path: Optional[Path]):
+ # config store & server dynamic reconfiguration
+ self.config_store = store
+
+ # HTTP server
+ self.app = Application(middlewares=[error_handler])
+ self.runner = AppRunner(self.app)
+ self.listen: Optional[ManagementSchema] = None
+ self.site: Union[NoneType, TCPSite, UnixSite] = None
+ self.listen_lock = asyncio.Lock()
+ self._config_path: Optional[Path] = config_path
+ self._exit_code: int = 0
+ self._shutdown_event = asyncio.Event()
+
+ async def _reconfigure(self, config: KresConfig) -> None:
+ await self._reconfigure_listen_address(config)
+
+ async def _deny_management_changes(self, config_old: KresConfig, config_new: KresConfig) -> Result[None, str]:
+ if config_old.management != config_new.management:
+ return Result.err(
+ "/server/management: Changing management API address/unix-socket dynamically is not allowed as it's really dangerous."
+ " If you really need this feature, please contact the developers and explain why. Technically,"
+ " there are no problems in supporting it. We are only blocking the dynamic changes because"
+ " we think the consequences of leaving this footgun unprotected are worse than its usefulness."
+ )
+ return Result.ok(None)
+
+ async def _reload_config(self) -> None:
+ if self._config_path is None:
+ logger.warning("The manager was started with inlined configuration - can't reload")
+ else:
+ try:
+ data = await readfile(self._config_path)
+ config = KresConfig(try_to_parse(data))
+ await self.config_store.update(config)
+ logger.info("Configuration file successfully reloaded")
+ except FileNotFoundError:
+ logger.error(
+ f"Configuration file was not found at '{self._config_path}'."
+ " Something must have happened to it while we were running."
+ )
+ logger.error("Configuration have NOT been changed.")
+ except (DataParsingError, DataValidationError) as e:
+ logger.error(f"Failed to parse the updated configuration file: {e}")
+ logger.error("Configuration have NOT been changed.")
+ except KresManagerException as e:
+ logger.error(f"Reloading of the configuration file failed: {e}")
+ logger.error("Configuration have NOT been changed.")
+
+ async def sigint_handler(self) -> None:
+ logger.info("Received SIGINT, triggering graceful shutdown")
+ self.trigger_shutdown(0)
+
+ async def sigterm_handler(self) -> None:
+ logger.info("Received SIGTERM, triggering graceful shutdown")
+ self.trigger_shutdown(0)
+
+ async def sighup_handler(self) -> None:
+ logger.info("Received SIGHUP, reloading configuration file")
+ systemd_notify(RELOADING="1")
+ await self._reload_config()
+ systemd_notify(READY="1")
+
+ @staticmethod
+ def all_handled_signals() -> Set[signal.Signals]:
+ return {signal.SIGHUP, signal.SIGINT, signal.SIGTERM}
+
+ def bind_signal_handlers(self):
+ asyncio_compat.add_async_signal_handler(signal.SIGTERM, self.sigterm_handler)
+ asyncio_compat.add_async_signal_handler(signal.SIGINT, self.sigint_handler)
+ asyncio_compat.add_async_signal_handler(signal.SIGHUP, self.sighup_handler)
+
+ def unbind_signal_handlers(self):
+ asyncio_compat.remove_signal_handler(signal.SIGTERM)
+ asyncio_compat.remove_signal_handler(signal.SIGINT)
+ asyncio_compat.remove_signal_handler(signal.SIGHUP)
+
+ async def start(self) -> None:
+ self._setup_routes()
+ await self.runner.setup()
+ await self.config_store.register_verifier(self._deny_management_changes)
+ await self.config_store.register_on_change_callback(self._reconfigure)
+
+ async def wait_for_shutdown(self) -> None:
+ await self._shutdown_event.wait()
+
+ def trigger_shutdown(self, exit_code: int) -> None:
+ self._shutdown_event.set()
+ self._exit_code = exit_code
+
+ async def _handler_index(self, _request: web.Request) -> web.Response:
+ """
+ Dummy index handler to indicate that the server is indeed running...
+ """
+ return json_response(
+ {
+ "msg": "Knot Resolver Manager is running! The configuration endpoint is at /config",
+ "status": "RUNNING",
+ }
+ )
+
+ async def _handler_config_query(self, request: web.Request) -> web.Response:
+ """
+ Route handler for changing resolver configuration
+ """
+ # There are a lot of local variables in here, but they are usually immutable (almost SSA form :) )
+ # pylint: disable=too-many-locals
+
+ # parse the incoming data
+ if request.method == "GET":
+ update_with: Optional[Dict[str, Any]] = None
+ else:
+ update_with = parse_from_mime_type(await request.text(), request.content_type)
+ document_path = request.match_info["path"]
+ getheaders = ignore_exceptions_optional(List[str], None, KeyError)(request.headers.getall)
+ etags = getheaders("if-match")
+ not_etags = getheaders("if-none-match")
+ current_config: Dict[str, Any] = self.config_store.get().get_unparsed_data()
+
+        # stop processing if the ETag preconditions are not satisfied
+ def strip_quotes(s: str) -> str:
+ return s.strip('"')
+
+ # WARNING: this check is prone to race conditions. When changing, make sure that the current config
+ # is really the latest current config (i.e. no await in between obtaining the config and the checks)
+ status = HTTPStatus.NOT_MODIFIED if request.method in ("GET", "HEAD") else HTTPStatus.PRECONDITION_FAILED
+ if etags is not None and structural_etag(current_config) not in map(strip_quotes, etags):
+ return web.Response(status=status)
+ if not_etags is not None and structural_etag(current_config) in map(strip_quotes, not_etags):
+ return web.Response(status=status)
+
+ # run query
+ op = cast(Literal["get", "delete", "patch", "put"], request.method.lower())
+ new_config, to_return = query(current_config, op, document_path, update_with)
+
+ # update the config
+ if request.method != "GET":
+ # validate
+ config_validated = KresConfig(new_config)
+ # apply
+ await self.config_store.update(config_validated)
+
+ # serialize the response (the `to_return` object is a Dict/list/scalar, we want to return json)
+ resp_text: Optional[str] = json.dumps(to_return) if to_return is not None else None
+
+ # create the response and return it
+ res = web.Response(status=HTTPStatus.OK, text=resp_text, content_type="application/json")
+ res.headers.add("ETag", f'"{structural_etag(new_config)}"')
+ return res
+
+ async def _handler_metrics(self, request: web.Request) -> web.Response:
+ raise web.HTTPMovedPermanently("/metrics/json")
+
+ async def _handler_metrics_json(self, _request: web.Request) -> web.Response:
+ return web.Response(
+ body=await statistics.report_stats(),
+ content_type="application/json",
+ charset="utf8",
+ )
+
+ async def _handler_metrics_prometheus(self, _request: web.Request) -> web.Response:
+
+ metrics_report = await statistics.report_stats(prometheus_format=True)
+ if not metrics_report:
+ raise web.HTTPNotFound()
+
+ return web.Response(
+ body=metrics_report,
+ content_type="text/plain",
+ charset="utf8",
+ )
+
+ async def _handler_cache_clear(self, request: web.Request) -> web.Response:
+ data = parse_from_mime_type(await request.text(), request.content_type)
+
+ try:
+ config = CacheClearRPCSchema(data)
+ except (AggregateDataValidationError, DataValidationError) as e:
+            return web.Response(
+                text=str(e),
+                status=HTTPStatus.BAD_REQUEST,
+                content_type="text/plain",
+                charset="utf8",
+            )
+
+ _, result = await command_single_registered_worker(config.render_lua())
+ return web.Response(
+ body=json.dumps(result),
+ content_type="application/json",
+ charset="utf8",
+ )
+
+ async def _handler_schema(self, _request: web.Request) -> web.Response:
+ return web.json_response(
+ KresConfig.json_schema(), headers={"Access-Control-Allow-Origin": "*"}, dumps=partial(json.dumps, indent=4)
+ )
+
+ async def _handle_view_schema(self, _request: web.Request) -> web.Response:
+ """
+        Provides a UI for visualising and understanding the JSON schema.
+
+        Rendering schemas inside the Knot Resolver Manager itself is unwanted, as it's completely
+        out of scope. However, it can be convenient. We therefore rely on a public web-based viewer
+        and provide just a redirect. If this feature ever breaks due to disappearance of the public
+        service, we can fix it. But we are not guaranteeing that this will always work.
+ """
+
+ return web.Response(
+ text="""
+ <html>
+ <head><title>Redirect to schema viewer</title></head>
+ <body>
+ <script>
+ // we are using JS in order to use proper host
+ let protocol = window.location.protocol;
+ let host = window.location.host;
+ let url = encodeURIComponent(`${protocol}//${host}/schema`);
+ window.location.replace(`https://json-schema.app/view/%23?url=${url}`);
+ </script>
+ <h1>JavaScript required for a dynamic redirect...</h1>
+ </body>
+ </html>
+ """,
+ content_type="text/html",
+ )
+
+ async def _handler_stop(self, _request: web.Request) -> web.Response:
+ """
+ Route handler for shutting down the server (and whole manager)
+ """
+
+ self._shutdown_event.set()
+ logger.info("Shutdown event triggered...")
+ return web.Response(text="Shutting down...")
+
+ async def _handler_reload(self, _request: web.Request) -> web.Response:
+ """
+ Route handler for reloading the server
+ """
+
+ logger.info("Reloading event triggered...")
+ await self._reload_config()
+ return web.Response(text="Reloading...")
+
+ def _setup_routes(self) -> None:
+ self.app.add_routes(
+ [
+ web.get("/", self._handler_index),
+ web.get(r"/v1/config{path:.*}", self._handler_config_query),
+ web.put(r"/v1/config{path:.*}", self._handler_config_query),
+ web.delete(r"/v1/config{path:.*}", self._handler_config_query),
+ web.patch(r"/v1/config{path:.*}", self._handler_config_query),
+ web.post("/stop", self._handler_stop),
+ web.post("/reload", self._handler_reload),
+ web.get("/schema", self._handler_schema),
+ web.get("/schema/ui", self._handle_view_schema),
+ web.get("/metrics", self._handler_metrics),
+ web.get("/metrics/json", self._handler_metrics_json),
+ web.get("/metrics/prometheus", self._handler_metrics_prometheus),
+ web.post("/cache/clear", self._handler_cache_clear),
+ ]
+ )
+
+ async def _reconfigure_listen_address(self, config: KresConfig) -> None:
+ async with self.listen_lock:
+ mgn = config.management
+
+ # if the listen address did not change, do nothing
+ if self.listen == mgn:
+ return
+
+ # start the new listen address
+ nsite: Union[web.TCPSite, web.UnixSite]
+ if mgn.unix_socket:
+ nsite = web.UnixSite(self.runner, str(mgn.unix_socket))
+ logger.info(f"Starting API HTTP server on http+unix://{mgn.unix_socket}")
+ elif mgn.interface:
+ nsite = web.TCPSite(self.runner, str(mgn.interface.addr), int(mgn.interface.port))
+ logger.info(f"Starting API HTTP server on http://{mgn.interface.addr}:{mgn.interface.port}")
+ else:
+ raise KresManagerException("Requested API on unsupported configuration format.")
+ await nsite.start()
+
+ # stop the old listen
+ assert (self.listen is None) == (self.site is None)
+ if self.listen is not None and self.site is not None:
+ if self.listen.unix_socket:
+ logger.info(f"Stopping API HTTP server on http+unix://{mgn.unix_socket}")
+ elif self.listen.interface:
+ logger.info(
+ f"Stopping API HTTP server on http://{self.listen.interface.addr}:{self.listen.interface.port}"
+ )
+ await self.site.stop()
+
+ # save new state
+ self.listen = mgn
+ self.site = nsite
+
+ async def shutdown(self) -> None:
+ if self.site is not None:
+ await self.site.stop()
+ await self.runner.cleanup()
+
+ def get_exit_code(self) -> int:
+ return self._exit_code
+
+
+async def _load_raw_config(config: Union[Path, Dict[str, Any]]) -> Dict[str, Any]:
+ # Initial configuration of the manager
+ if isinstance(config, Path):
+ if not config.exists():
+ raise KresManagerException(
+ f"Manager is configured to load config file at {config} on startup, but the file does not exist."
+ )
+ else:
+ logger.info(f"Loading configuration from '{config}' file.")
+ config = try_to_parse(await readfile(config))
+
+ # validate the initial configuration
+ assert isinstance(config, dict)
+ return config
+
+
+async def _load_config(config: Dict[str, Any]) -> KresConfig:
+ config_validated = KresConfig(config)
+ return config_validated
+
+
+async def _init_config_store(config: Dict[str, Any]) -> ConfigStore:
+ config_validated = await _load_config(config)
+ config_store = ConfigStore(config_validated)
+ return config_store
+
+
+async def _init_manager(config_store: ConfigStore, server: Server) -> KresManager:
+ """
+ Called asynchronously when the application initializes.
+ """
+
+ # Instantiate subprocess controller (if we wanted to, we could switch it at this point)
+ controller = await get_best_controller_implementation(config_store.get())
+
+ # Create KresManager. This will perform autodetection of available service managers and
+ # select the most appropriate to use (or use the one configured directly)
+ manager = await KresManager.create(controller, config_store, server.trigger_shutdown)
+
+ logger.info("Initial configuration applied. Process manager initialized...")
+ return manager
+
+
+async def _deny_working_directory_changes(config_old: KresConfig, config_new: KresConfig) -> Result[None, str]:
+ if config_old.rundir != config_new.rundir:
+ return Result.err("Changing manager's `rundir` during runtime is not allowed.")
+
+ return Result.ok(None)
+
+
+def _set_working_directory(config_raw: Dict[str, Any]) -> None:
+ try:
+ rundir = get_rundir_without_validation(config_raw)
+ except ValueError as e:
+ raise DataValidationError(str(e), "/rundir") from e
+
+ logger.debug(f"Changing working directory to '{rundir.to_path().absolute()}'.")
+ os.chdir(rundir.to_path())
+
+
+def _lock_working_directory(attempt: int = 0) -> None:
+ # the following syscall is atomic, it's essentially the same as acquiring a lock
+ try:
+ pidfile_fd = os.open(PID_FILE_NAME, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o644)
+ except OSError as e:
+ if e.errno == errno.EEXIST and attempt == 0:
+ # the pid file exists, let's check PID
+ with open(PID_FILE_NAME, "r", encoding="utf-8") as f:
+ pid = int(f.read().strip())
+ try:
+ os.kill(pid, 0)
+ except OSError as e2:
+ if e2.errno == errno.ESRCH:
+ os.unlink(PID_FILE_NAME)
+ _lock_working_directory(attempt=attempt + 1)
+ return
+ raise KresManagerException(
+ "Another manager is running in the same working directory."
+ f" PID file is located at {os.getcwd()}/{PID_FILE_NAME}"
+ ) from e
+ else:
+ raise KresManagerException(
+ "Another manager is running in the same working directory."
+ f" PID file is located at {os.getcwd()}/{PID_FILE_NAME}"
+ ) from e
+
+ # now we know that we are the only manager running in this directory
+
+ # write PID to the pidfile and close it afterwards
+ pidfile = os.fdopen(pidfile_fd, "w")
+ pid = os.getpid()
+ pidfile.write(f"{pid}\n")
+ pidfile.close()
+
+ # make sure that the file is deleted on shutdown
+ atexit.register(lambda: os.unlink(PID_FILE_NAME))
+
+
+async def _sigint_while_shutting_down():
+ logger.warning(
+ "Received SIGINT while already shutting down. Ignoring."
+ " If you want to forcefully stop the manager right now, use SIGTERM."
+ )
+
+
+async def _sigterm_while_shutting_down():
+ logger.warning("Received SIGTERM. Invoking dirty shutdown!")
+ sys.exit(128 + signal.SIGTERM)
+
+
+async def start_server(config: Path = DEFAULT_MANAGER_CONFIG_FILE) -> int:
+ # This function is quite long, but it describes how manager runs. So let's silence pylint
+ # pylint: disable=too-many-statements
+
+ start_time = time()
+ working_directory_on_startup = os.getcwd()
+ manager: Optional[KresManager] = None
+
+ # Block signals during initialization to force their processing once everything is ready
+ signal.pthread_sigmask(signal.SIG_BLOCK, Server.all_handled_signals())
+
+ # before starting the server, initialize the subprocess controller, config store, etc. Any errors during initialization
+ # are fatal
+ try:
+ # Make sure that the config path does not change meaning when we change working directory
+ config = config.absolute()
+
+ # Preprocess config - load from file or in general take it to the last step before validation.
+ config_raw = await _load_raw_config(config)
+
+ # before processing any configuration, set validation context
+ # - resolve_root = root against which all relative paths will be resolved
+ set_global_validation_context(Context(config.parent, True))
+
+ # We want to change cwd as soon as possible. Some parts of the codebase are using os.getcwd() to get the
+ # working directory.
+ #
+ # If we fail to read rundir from unparsed config, the first config validation error comes from here
+ _set_working_directory(config_raw)
+
+ # We don't want more than one manager in a single working directory. So we lock it with a PID file.
+ # Warning - this does not prevent multiple managers with the same naming of kresd service.
+ _lock_working_directory()
+
+ # set_global_validation_context(Context(config.parent))
+
+ # After the working directory is set, we can initialize proper config store with a newly parsed configuration.
+ config_store = await _init_config_store(config_raw)
+
+ # Some "constants" need to be loaded from the initial config, some need to be stored from the initial run conditions
+ await init_user_constants(config_store, working_directory_on_startup)
+
+ # The behaviour with paths described above means that we MUST NOT allow `rundir` changes after initialization.
+ # It would cause strange problems because every other path configuration depends on it. Therefore, we have to
+ # add a check to the config store, which disallows changes.
+ await config_store.register_verifier(_deny_working_directory_changes)
+
+ # Up to this point, we have been logging to memory buffer. But now, when we have the configuration loaded, we
+ # can flush the buffer into the proper place
+ await log.logger_init(config_store)
+
+ # With configuration on hand, we can initialize monitoring. We want to do this before any subprocesses are
+ # started, therefore before initializing manager
+ await statistics.init_monitoring(config_store)
+
+ # prepare instance of the server (no side effects)
+ server = Server(config_store, config)
+
+ # After we have loaded the configuration, we can start worrying about subprocess management.
+ manager = await _init_manager(config_store, server)
+
+ except CancelStartupExecInsteadException as e:
+ # if we caught this exception, some component wants to perform a reexec during startup. Most likely, it would
+ # be a subprocess manager like supervisord, which wants to make sure the manager runs under supervisord in
+ # the process tree. So now we stop everything and exec what we are told to. We are assuming that the thing
+ # we'll exec will invoke us again.
+ logger.info("Exec requested with arguments: %s", str(e.exec_args))
+
+ # unblock signals, this could actually terminate us straight away
+ signal.pthread_sigmask(signal.SIG_UNBLOCK, Server.all_handled_signals())
+
+ # run exit functions
+ atexit.run_callbacks()
+
+ # and finally exec what we were told to exec
+ os.execl(*e.exec_args)
+
+ except KresManagerException as e:
+ # We caught an error with a pretty error message. Just print it and exit.
+ logger.error(e)
+ return 1
+
+ except BaseException:
+ logger.error("Uncaught generic exception during manager inicialization...", exc_info=True)
+ return 1
+
+ # At this point, all backend functionality-providing components are initialized. It's therefore safe to start
+ # the API server.
+ try:
+ await server.start()
+ except OSError as e:
+ if e.errno in (errno.EADDRINUSE, errno.EADDRNOTAVAIL):
+ # fancy error reporting of network binding errors
+ logger.error(str(e))
+ await manager.stop()
+ return 1
+ raise
+
+ # At this point, pretty much everything is ready to go. We should just make sure the user can shut
+ # the manager down with signals.
+ server.bind_signal_handlers()
+ signal.pthread_sigmask(signal.SIG_UNBLOCK, Server.all_handled_signals())
+
+ logger.info(f"Manager fully initialized and running in {round(time() - start_time, 3)} seconds")
+
+ # notify systemd/anything compatible that we are ready
+ systemd_notify(READY="1")
+
+ await server.wait_for_shutdown()
+
+ # notify systemd that we are shutting down
+ systemd_notify(STOPPING="1")
+
+ # Ok, now we are tearing everything down.
+
+ # First of all, let's block all unwanted interruptions. We don't want to be reconfiguring kresd's while
+ # shutting down.
+ signal.pthread_sigmask(signal.SIG_BLOCK, Server.all_handled_signals())
+ server.unbind_signal_handlers()
+ # on the other hand, we want to immediately stop when the user really wants us to stop
+ asyncio_compat.add_async_signal_handler(signal.SIGTERM, _sigterm_while_shutting_down)
+ asyncio_compat.add_async_signal_handler(signal.SIGINT, _sigint_while_shutting_down)
+ signal.pthread_sigmask(signal.SIG_UNBLOCK, {signal.SIGTERM, signal.SIGINT})
+
+ # After triggering shutdown, we need to clean everything up
+ logger.info("Stopping API service...")
+ await server.shutdown()
+ logger.info("Stopping kresd manager...")
+ await manager.stop()
+ logger.info(f"The manager run for {round(time() - start_time)} seconds...")
+ return server.get_exit_code()
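
For reference, the management HTTP API registered in `_setup_routes` above can be exercised with any HTTP client. A minimal sketch, assuming a hypothetical UNIX socket path (the real one comes from the `management` section of the configuration):

```python
import asyncio

import aiohttp

# Hypothetical socket path; the actual one is taken from the `management` config.
SOCKET = "/run/knot-resolver/kres-api.sock"


async def main() -> None:
    connector = aiohttp.UnixConnector(path=SOCKET)
    async with aiohttp.ClientSession(connector=connector) as session:
        # read the currently running configuration
        async with session.get("http://localhost/v1/config") as resp:
            print(await resp.text())
        # collect statistics from all running kresd workers
        async with session.get("http://localhost/metrics/json") as resp:
            print(await resp.text())


asyncio.run(main())
```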
diff --git a/python/knot_resolver/manager/statistics.py b/python/knot_resolver/manager/statistics.py
new file mode 100644
index 00000000..292a480d
--- /dev/null
+++ b/python/knot_resolver/manager/statistics.py
@@ -0,0 +1,434 @@
+import asyncio
+import importlib
+import json
+import logging
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
+
+from knot_resolver import compat
+from knot_resolver.manager.config_store import ConfigStore, only_on_real_changes_update
+from knot_resolver.datamodel.config_schema import KresConfig
+from knot_resolver.controller.registered_workers import (
+ command_registered_workers,
+ get_registered_workers_kresids,
+)
+from knot_resolver.utils.functional import Result
+from knot_resolver.utils.modeling.parsing import DataFormat
+
+if TYPE_CHECKING:
+ from knot_resolver.controller.interface import KresID
+
+logger = logging.getLogger(__name__)
+
+
+_prometheus_support = False
+if importlib.util.find_spec("prometheus_client"):
+ _prometheus_support = True
+
+
+if _prometheus_support:
+ from prometheus_client import exposition # type: ignore
+ from prometheus_client.bridge.graphite import GraphiteBridge # type: ignore
+ from prometheus_client.core import GaugeMetricFamily # type: ignore
+ from prometheus_client.core import REGISTRY, CounterMetricFamily, HistogramMetricFamily, Metric
+
+ def _counter(name: str, description: str, label: Tuple[str, str], value: float) -> CounterMetricFamily:
+ c = CounterMetricFamily(name, description, labels=(label[0],))
+ c.add_metric((label[1],), value) # type: ignore
+ return c
+
+ def _gauge(name: str, description: str, label: Tuple[str, str], value: float) -> GaugeMetricFamily:
+ c = GaugeMetricFamily(name, description, labels=(label[0],))
+ c.add_metric((label[1],), value) # type: ignore
+ return c
+
+ def _histogram(
+ name: str, description: str, label: Tuple[str, str], buckets: List[Tuple[str, int]], sum_value: float
+ ) -> HistogramMetricFamily:
+ c = HistogramMetricFamily(name, description, labels=(label[0],))
+ c.add_metric((label[1],), buckets, sum_value=sum_value) # type: ignore
+ return c
+
+ def _parse_resolver_metrics(instance_id: "KresID", metrics: Any) -> Generator[Metric, None, None]:
+ sid = str(instance_id)
+
+ # response latency histogram
+ BUCKET_NAMES_IN_RESOLVER = ("1ms", "10ms", "50ms", "100ms", "250ms", "500ms", "1000ms", "1500ms", "slow")
+ BUCKET_NAMES_PROMETHEUS = ("0.001", "0.01", "0.05", "0.1", "0.25", "0.5", "1.0", "1.5", "+Inf")
+ yield _histogram(
+ "resolver_response_latency",
+ "Time it takes to respond to queries in seconds",
+ label=("instance_id", sid),
+ buckets=[
+ (bnp, metrics["answer"][f"{duration}"])
+ for bnp, duration in zip(BUCKET_NAMES_PROMETHEUS, BUCKET_NAMES_IN_RESOLVER)
+ ],
+ sum_value=metrics["answer"]["sum_ms"] / 1_000,
+ )
+
+ yield _counter(
+ "resolver_request_total",
+ "total number of DNS requests (including internal client requests)",
+ label=("instance_id", sid),
+ value=metrics["request"]["total"],
+ )
+ yield _counter(
+ "resolver_request_internal",
+ "number of internal requests generated by Knot Resolver (e.g. DNSSEC trust anchor updates)",
+ label=("instance_id", sid),
+ value=metrics["request"]["internal"],
+ )
+ yield _counter(
+ "resolver_request_udp",
+ "number of external requests received over plain UDP (RFC 1035)",
+ label=("instance_id", sid),
+ value=metrics["request"]["udp"],
+ )
+ yield _counter(
+ "resolver_request_tcp",
+ "number of external requests received over plain TCP (RFC 1035)",
+ label=("instance_id", sid),
+ value=metrics["request"]["tcp"],
+ )
+ yield _counter(
+ "resolver_request_dot",
+ "number of external requests received over DNS-over-TLS (RFC 7858)",
+ label=("instance_id", sid),
+ value=metrics["request"]["dot"],
+ )
+ yield _counter(
+ "resolver_request_doh",
+ "number of external requests received over DNS-over-HTTP (RFC 8484)",
+ label=("instance_id", sid),
+ value=metrics["request"]["doh"],
+ )
+ yield _counter(
+ "resolver_request_xdp",
+ "number of external requests received over plain UDP via an AF_XDP socket",
+ label=("instance_id", sid),
+ value=metrics["request"]["xdp"],
+ )
+ yield _counter(
+ "resolver_answer_total",
+ "total number of answered queries",
+ label=("instance_id", sid),
+ value=metrics["answer"]["total"],
+ )
+ yield _counter(
+ "resolver_answer_cached",
+ "number of queries answered from cache",
+ label=("instance_id", sid),
+ value=metrics["answer"]["cached"],
+ )
+ yield _counter(
+ "resolver_answer_stale",
+ "number of queries that utilized stale data",
+ label=("instance_id", sid),
+ value=metrics["answer"]["stale"],
+ )
+ yield _counter(
+ "resolver_answer_rcode_noerror",
+ "number of NOERROR answers",
+ label=("instance_id", sid),
+ value=metrics["answer"]["noerror"],
+ )
+ yield _counter(
+ "resolver_answer_rcode_nodata",
+ "number of NOERROR answers without any data",
+ label=("instance_id", sid),
+ value=metrics["answer"]["nodata"],
+ )
+ yield _counter(
+ "resolver_answer_rcode_nxdomain",
+ "number of NXDOMAIN answers",
+ label=("instance_id", sid),
+ value=metrics["answer"]["nxdomain"],
+ )
+ yield _counter(
+ "resolver_answer_rcode_servfail",
+ "number of SERVFAIL answers",
+ label=("instance_id", sid),
+ value=metrics["answer"]["servfail"],
+ )
+ yield _counter(
+ "resolver_answer_flag_aa",
+ "number of authoritative answers",
+ label=("instance_id", sid),
+ value=metrics["answer"]["aa"],
+ )
+ yield _counter(
+ "resolver_answer_flag_tc",
+ "number of truncated answers",
+ label=("instance_id", sid),
+ value=metrics["answer"]["tc"],
+ )
+ yield _counter(
+ "resolver_answer_flag_ra",
+ "number of answers with recursion available flag",
+ label=("instance_id", sid),
+ value=metrics["answer"]["ra"],
+ )
+ yield _counter(
+ "resolver_answer_flags_rd",
+ "number of recursion desired (in answer!)",
+ label=("instance_id", sid),
+ value=metrics["answer"]["rd"],
+ )
+ yield _counter(
+ "resolver_answer_flag_ad",
+ "number of authentic data (DNSSEC) answers",
+ label=("instance_id", sid),
+ value=metrics["answer"]["ad"],
+ )
+ yield _counter(
+ "resolver_answer_flag_cd",
+ "number of checking disabled (DNSSEC) answers",
+ label=("instance_id", sid),
+ value=metrics["answer"]["cd"],
+ )
+ yield _counter(
+ "resolver_answer_flag_do",
+ "number of DNSSEC answer OK",
+ label=("instance_id", sid),
+ value=metrics["answer"]["do"],
+ )
+ yield _counter(
+ "resolver_answer_flag_edns0",
+ "number of answers with EDNS0 present",
+ label=("instance_id", sid),
+ value=metrics["answer"]["edns0"],
+ )
+ yield _counter(
+ "resolver_query_edns",
+ "number of queries with EDNS present",
+ label=("instance_id", sid),
+ value=metrics["query"]["edns"],
+ )
+ yield _counter(
+ "resolver_query_dnssec",
+ "number of queries with DNSSEC DO=1",
+ label=("instance_id", sid),
+ value=metrics["query"]["dnssec"],
+ )
+
+ if "predict" in metrics:
+ if "epoch" in metrics["predict"]:
+ yield _counter(
+ "resolver_predict_epoch",
+ "current prediction epoch (based on time of day and sampling window)",
+ label=("instance_id", sid),
+ value=metrics["predict"]["epoch"],
+ )
+ yield _counter(
+ "resolver_predict_queue",
+ "number of queued queries in current window",
+ label=("instance_id", sid),
+ value=metrics["predict"]["queue"],
+ )
+ yield _counter(
+ "resolver_predict_learned",
+ "number of learned queries in current window",
+ label=("instance_id", sid),
+ value=metrics["predict"]["learned"],
+ )
+
+ def _create_resolver_metrics_loaded_gauge(kresid: "KresID", loaded: bool) -> GaugeMetricFamily:
+ return _gauge(
+ "resolver_metrics_loaded",
+ "0 if metrics from resolver instance were not loaded, otherwise 1",
+ label=("instance_id", str(kresid)),
+ value=int(loaded),
+ )
+
+ async def _deny_turning_off_graphite_bridge(old_config: KresConfig, new_config: KresConfig) -> Result[None, str]:
+ if old_config.monitoring.graphite and not new_config.monitoring.graphite:
+ return Result.err(
+ "You can't turn off graphite monitoring dynamically. If you really want this feature, please let the developers know."
+ )
+
+ if (
+ old_config.monitoring.graphite is not None
+ and new_config.monitoring.graphite is not None
+ and old_config.monitoring.graphite != new_config.monitoring.graphite
+ ):
+ return Result.err("Changing graphite exporter configuration in runtime is not allowed.")
+
+ return Result.ok(None)
+
+ _graphite_bridge: Optional[GraphiteBridge] = None
+
+ @only_on_real_changes_update(lambda c: c.monitoring.graphite)
+ async def _configure_graphite_bridge(config: KresConfig) -> None:
+ """
+ Starts graphite bridge if required
+ """
+ global _graphite_bridge
+ if config.monitoring.graphite is not False and _graphite_bridge is None:
+ logger.info(
+ "Starting Graphite metrics exporter for [%s]:%d",
+ str(config.monitoring.graphite.host),
+ int(config.monitoring.graphite.port),
+ )
+ _graphite_bridge = GraphiteBridge(
+ (str(config.monitoring.graphite.host), int(config.monitoring.graphite.port))
+ )
+ _graphite_bridge.start( # type: ignore
+ interval=config.monitoring.graphite.interval.seconds(), prefix=str(config.monitoring.graphite.prefix)
+ )
+
+
+class ResolverCollector:
+ def __init__(self, config_store: ConfigStore) -> None:
+ self._stats_raw: "Optional[Dict[KresID, object]]" = None
+ self._config_store: ConfigStore = config_store
+ self._collection_task: "Optional[asyncio.Task[None]]" = None
+ self._skip_immediate_collection: bool = False
+
+ if _prometheus_support:
+
+ def collect(self) -> Generator[Metric, None, None]:
+ # schedule new stats collection
+ self._trigger_stats_collection()
+
+ # if we have no data, return metrics with information about it and exit
+ if self._stats_raw is None:
+ for kresid in get_registered_workers_kresids():
+ yield _create_resolver_metrics_loaded_gauge(kresid, False)
+ return
+
+ # if we have data, parse them
+ for kresid in get_registered_workers_kresids():
+ success = False
+ try:
+ if kresid in self._stats_raw:
+ metrics = self._stats_raw[kresid]
+ yield from _parse_resolver_metrics(kresid, metrics)
+ success = True
+ except json.JSONDecodeError:
+ logger.warning(
+ "Failed to load metrics from resolver instance %s: failed to parse statistics", str(kresid)
+ )
+ except KeyError as e:
+ logger.warning(
+ "Failed to load metrics from resolver instance %s: attempted to read missing statistic %s",
+ str(kresid),
+ str(e),
+ )
+
+ yield _create_resolver_metrics_loaded_gauge(kresid, success)
+
+ def describe(self) -> List[Metric]:
+ # this function prevents the collector registry from invoking the collect function on startup
+ return []
+
+ def report_json(self) -> str:
+ # schedule new stats collection
+ self._trigger_stats_collection()
+
+ # if we have no data, return metrics with information about it and exit
+ if self._stats_raw is None:
+ no_stats_dict: Dict[str, None] = {}
+ for kresid in get_registered_workers_kresids():
+ no_stats_dict[str(kresid)] = None
+ return DataFormat.JSON.dict_dump(no_stats_dict)
+
+ stats_dict: Dict[str, object] = {}
+ for kresid, stats in self._stats_raw.items():
+ stats_dict[str(kresid)] = stats
+
+ return DataFormat.JSON.dict_dump(stats_dict)
+
+ async def collect_kresd_stats(self, _triggered_from_prometheus_library: bool = False) -> None:
+ if self._skip_immediate_collection:
+ # this would happen because we are calling this function first manually before stat generation,
+ # and once again immediately afterwards caused by the prometheus library's stat collection
+ #
+ # this is code written to solve the problem of calling async functions from sync methods
+ self._skip_immediate_collection = False
+ return
+
+ config = self._config_store.get()
+
+ if config.monitoring.enabled == "manager-only":
+ logger.debug("Skipping kresd stat collection due to configuration")
+ self._stats_raw = None
+ return
+
+ lazy = config.monitoring.enabled == "lazy"
+ cmd = "collect_lazy_statistics()" if lazy else "collect_statistics()"
+ logger.debug("Collecting kresd stats with method '%s'", cmd)
+ stats_raw = await command_registered_workers(cmd)
+ self._stats_raw = stats_raw
+
+ # if this function was not called by the prometheus library and calling collect() is imminent,
+ # we should block the next collection cycle as it would be useless
+ if not _triggered_from_prometheus_library:
+ self._skip_immediate_collection = True
+
+ def _trigger_stats_collection(self) -> None:
+ # we are running inside an event loop, but in a synchronous function and that sucks a lot
+ # it means that we shouldn't block the event loop by performing a blocking stats collection
+ # but it also means that we can't yield to the event loop as this function is synchronous
+ # therefore we can only start a new task, but we can't wait for it
+ # which causes the metrics to be delayed by one collection pass (not the best, but probably good enough)
+ #
+ # this issue can be prevented by calling the `collect_kresd_stats()` function manually before entering
+ # the Prometheus library. We just have to prevent the library from invoking it again. See the mentioned
+ # function for details
+
+ if compat.asyncio.is_event_loop_running():
+ # when running, we can schedule the new data collection
+ if self._collection_task is not None and not self._collection_task.done():
+ logger.warning("Statistics collection task is still running. Skipping scheduling of a new one!")
+ else:
+ self._collection_task = compat.asyncio.create_task(
+ self.collect_kresd_stats(_triggered_from_prometheus_library=True)
+ )
+
+ else:
+ # when not running, we can start a new loop (we are not in the manager's main thread)
+ compat.asyncio.run(self.collect_kresd_stats(_triggered_from_prometheus_library=True))
+
+
+_resolver_collector: Optional[ResolverCollector] = None
+
+
+async def _collect_stats() -> None:
+ # manually trigger stat collection so that we do not have to wait for it
+ if _resolver_collector is not None:
+ await _resolver_collector.collect_kresd_stats()
+ else:
+ raise RuntimeError("Function invoked before initializing the module!")
+
+
+async def report_stats(prometheus_format: bool = False) -> Optional[bytes]:
+ """
+ Collects metrics from everything, returns data string in JSON (default) or Prometheus format.
+ """
+
+ # manually trigger stat collection so that we do not have to wait for it
+ if _resolver_collector is not None:
+ await _resolver_collector.collect_kresd_stats()
+ else:
+ raise RuntimeError("Function invoked before initializing the module!")
+
+ if prometheus_format:
+ if _prometheus_support:
+ return exposition.generate_latest() # type: ignore
+ return None
+ return _resolver_collector.report_json().encode()
+
+
+async def init_monitoring(config_store: ConfigStore) -> None:
+ """
+ Initialize monitoring. Must be called before any other function from this module.
+ """
+ global _resolver_collector
+ _resolver_collector = ResolverCollector(config_store)
+
+ if _prometheus_support:
+ # register metrics collector
+ REGISTRY.register(_resolver_collector) # type: ignore
+
+ # register graphite bridge
+ await config_store.register_verifier(_deny_turning_off_graphite_bridge)
+ await config_store.register_on_change_callback(_configure_graphite_bridge)
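
The `ResolverCollector` above plugs into `prometheus_client` via its custom-collector protocol: a plain object with a synchronous `collect()` generator (and an optional `describe()`) registered into the global `REGISTRY`. A minimal sketch of that pattern in isolation; the metric name and value are illustrative only:

```python
from prometheus_client import generate_latest
from prometheus_client.core import REGISTRY, GaugeMetricFamily


class DemoCollector:
    def collect(self):
        # called by the library whenever metrics are scraped
        g = GaugeMetricFamily("demo_gauge", "an illustrative gauge", labels=("instance_id",))
        g.add_metric(("example",), 1.0)
        yield g

    def describe(self):
        # returning an empty list prevents collect() from running at registration time
        return []


REGISTRY.register(DemoCollector())
print(generate_latest(REGISTRY).decode())
```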
diff --git a/python/knot_resolver/utils/__init__.py b/python/knot_resolver/utils/__init__.py
new file mode 100644
index 00000000..edc36fca
--- /dev/null
+++ b/python/knot_resolver/utils/__init__.py
@@ -0,0 +1,45 @@
+from typing import Any, Callable, Optional, Type, TypeVar
+
+T = TypeVar("T")
+
+
+def ignore_exceptions_optional(
+ _tp: Type[T], default: Optional[T], *exceptions: Type[BaseException]
+) -> Callable[[Callable[..., Optional[T]]], Callable[..., Optional[T]]]:
+ """
+ Decorator that wraps a function, preventing it from raising exceptions
+ and instead returning the configured default value.
+
+ :param Type[T] _tp: Return type of the function. Essentially only a template argument for type-checking
+ :param T default: The value to return as a default
+ :param List[Type[BaseException]] exceptions: The list of exceptions to catch
+ :return: value of the decorated function, or default if exception raised
+ :rtype: T
+ """
+
+ def decorator(func: Callable[..., Optional[T]]) -> Callable[..., Optional[T]]:
+ def f(*nargs: Any, **nkwargs: Any) -> Optional[T]:
+ try:
+ return func(*nargs, **nkwargs)
+ except BaseException as e:
+ if isinstance(e, exceptions):
+ return default
+ else:
+ raise e
+
+ return f
+
+ return decorator
+
+
+def ignore_exceptions(
+ default: T, *exceptions: Type[BaseException]
+) -> Callable[[Callable[..., Optional[T]]], Callable[..., Optional[T]]]:
+ return ignore_exceptions_optional(type(default), default, *exceptions)
+
+
+def phantom_use(var: Any) -> None: # pylint: disable=unused-argument
+ """
+ Function which consumes its argument, doing absolutely nothing with it. Useful
+ for convincing pylint that we need the variable even when it's unused.
+ """
diff --git a/python/knot_resolver/utils/async_utils.py b/python/knot_resolver/utils/async_utils.py
new file mode 100644
index 00000000..a5acdbd5
--- /dev/null
+++ b/python/knot_resolver/utils/async_utils.py
@@ -0,0 +1,129 @@
+import asyncio
+import os
+import pkgutil
+import signal
+import sys
+import time
+from asyncio import create_subprocess_exec, create_subprocess_shell
+from pathlib import PurePath
+from threading import Thread
+from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
+
+from knot_resolver.compat.asyncio import to_thread
+
+
+def unblock_signals():
+ if sys.version_info.major >= 3 and sys.version_info.minor >= 8:
+ signal.pthread_sigmask(signal.SIG_UNBLOCK, signal.valid_signals()) # type: ignore
+ else:
+ # the list of signals is not exhaustive, but it should cover all signals we might ever want to block
+ signal.pthread_sigmask(
+ signal.SIG_UNBLOCK,
+ {
+ signal.SIGHUP,
+ signal.SIGINT,
+ signal.SIGTERM,
+ signal.SIGUSR1,
+ signal.SIGUSR2,
+ },
+ )
+
+
+async def call(
+ cmd: Union[str, bytes, List[str], List[bytes]], shell: bool = False, discard_output: bool = False
+) -> int:
+ """
+ custom async alternative to subprocess.call()
+ """
+ kwargs: Dict[str, Any] = {
+ "preexec_fn": unblock_signals,
+ }
+ if discard_output:
+ kwargs["stdout"] = asyncio.subprocess.DEVNULL
+ kwargs["stderr"] = asyncio.subprocess.DEVNULL
+
+ if shell:
+ if isinstance(cmd, list):
+ raise RuntimeError("can't use list of arguments with shell=True")
+ proc = await create_subprocess_shell(cmd, **kwargs)
+ else:
+ if not isinstance(cmd, list):
+ raise RuntimeError(
+ "Please use list of arguments, not a single string. It will prevent ambiguity when parsing"
+ )
+ proc = await create_subprocess_exec(*cmd, **kwargs)
+
+ return await proc.wait()
+
+
+async def readfile(path: Union[str, PurePath]) -> str:
+ """
+ asynchronously read whole file and return its content
+ """
+
+ def readfile_sync(path: Union[str, PurePath]) -> str:
+ with open(path, "r", encoding="utf8") as f:
+ return f.read()
+
+ return await to_thread(readfile_sync, path)
+
+
+async def writefile(path: Union[str, PurePath], content: str) -> None:
+ """
+ asynchronously set content of a file to a given string `content`.
+ """
+
+ def writefile_sync(path: Union[str, PurePath], content: str) -> int:
+ with open(path, "w", encoding="utf8") as f:
+ return f.write(content)
+
+ await to_thread(writefile_sync, path, content)
+
+
+async def wait_for_process_termination(pid: int, sleep_sec: float = 0) -> None:
+ """
+ will wait for any process (does not have to be a child process) given by its PID to terminate
+
+ sleep_sec configures the granularity with which we poll for termination
+ """
+
+ def wait_sync(pid: int, sleep_sec: float) -> None:
+ while True:
+ try:
+ os.kill(pid, 0)
+ if sleep_sec == 0:
+ os.sched_yield()
+ else:
+ time.sleep(sleep_sec)
+ except ProcessLookupError:
+ break
+
+ await to_thread(wait_sync, pid, sleep_sec)
+
+
+async def read_resource(package: str, filename: str) -> Optional[bytes]:
+ return await to_thread(pkgutil.get_data, package, filename)
+
+
+T = TypeVar("T")
+
+
+class BlockingEventDispatcher(Thread, Generic[T]):
+ def __init__(self, name: str = "blocking_event_dispatcher") -> None:
+ super().__init__(name=name, daemon=True)
+ # warning: the asyncio queue is not thread safe
+ self._removed_unit_names: "asyncio.Queue[T]" = asyncio.Queue()
+ self._main_event_loop = asyncio.get_event_loop()
+
+ def dispatch_event(self, event: T) -> None:
+ """
+ Method to dispatch events from the blocking thread
+ """
+
+ async def add_to_queue():
+ await self._removed_unit_names.put(event)
+
+ self._main_event_loop.call_soon_threadsafe(add_to_queue)
+
+ async def next_event(self) -> T:
+ return await self._removed_unit_names.get()
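
A brief usage sketch of the helpers above, assuming a POSIX system; the file path and command are illustrative:

```python
import asyncio

from knot_resolver.utils.async_utils import call, readfile, writefile


async def main() -> None:
    await writefile("/tmp/kres-demo.txt", "hello\n")
    print(await readfile("/tmp/kres-demo.txt"))

    # run a command without blocking the event loop, discarding its output
    rc = await call(["true"], discard_output=True)
    print("exit code:", rc)


asyncio.run(main())
```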
diff --git a/python/knot_resolver/utils/custom_atexit.py b/python/knot_resolver/utils/custom_atexit.py
new file mode 100644
index 00000000..2fe55433
--- /dev/null
+++ b/python/knot_resolver/utils/custom_atexit.py
@@ -0,0 +1,20 @@
+"""
+Custom replacement for the standard module `atexit`. We use `atexit` behind the scenes; we just add the option
+to invoke the exit functions manually.
+"""
+
+import atexit
+from typing import Callable, List
+
+_at_exit_functions: List[Callable[[], None]] = []
+
+
+def register(func: Callable[[], None]) -> None:
+ _at_exit_functions.append(func)
+ atexit.register(func)
+
+
+def run_callbacks() -> None:
+ for func in _at_exit_functions:
+ func()
+ atexit.unregister(func)
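
The point of this wrapper is that registered callbacks can be flushed manually (as `start_server` does right before `os.execl`) while still running on a normal interpreter exit. A small usage sketch; the callback is illustrative:

```python
from knot_resolver.utils import custom_atexit


def cleanup() -> None:
    print("removing PID file")


custom_atexit.register(cleanup)

# ... later, before replacing the process image with os.execl(...):
custom_atexit.run_callbacks()  # runs cleanup() now and unregisters it from atexit
```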
diff --git a/python/knot_resolver/utils/etag.py b/python/knot_resolver/utils/etag.py
new file mode 100644
index 00000000..bb80700b
--- /dev/null
+++ b/python/knot_resolver/utils/etag.py
@@ -0,0 +1,10 @@
+import base64
+import json
+from hashlib import blake2b
+from typing import Any
+
+
+def structural_etag(obj: Any) -> str:
+ m = blake2b(digest_size=15)
+ m.update(json.dumps(obj, sort_keys=True).encode("utf8"))
+ return base64.urlsafe_b64encode(m.digest()).decode("utf8")
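
Because the digest is computed over `json.dumps(..., sort_keys=True)`, structurally equal objects yield the same ETag regardless of key order. A quick illustration:

```python
from knot_resolver.utils.etag import structural_etag

a = {"dns64": True, "workers": 4}
b = {"workers": 4, "dns64": True}

# same structure, different key order -> identical ETag
assert structural_etag(a) == structural_etag(b)
print(structural_etag(a))  # 20 URL-safe base64 characters (15-byte BLAKE2b digest)
```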
diff --git a/python/knot_resolver/utils/functional.py b/python/knot_resolver/utils/functional.py
new file mode 100644
index 00000000..43abd705
--- /dev/null
+++ b/python/knot_resolver/utils/functional.py
@@ -0,0 +1,72 @@
+from enum import Enum, auto
+from typing import Any, Callable, Generic, Iterable, TypeVar, Union
+
+T = TypeVar("T")
+
+
+def foldl(oper: Callable[[T, T], T], default: T, arr: Iterable[T]) -> T:
+ val = default
+ for x in arr:
+ val = oper(val, x)
+ return val
+
+
+def contains_element_matching(cond: Callable[[T], bool], arr: Iterable[T]) -> bool:
+ return foldl(lambda x, y: x or y, False, map(cond, arr))
+
+
+def all_matches(cond: Callable[[T], bool], arr: Iterable[T]) -> bool:
+ return foldl(lambda x, y: x and y, True, map(cond, arr))
+
+
+Succ = TypeVar("Succ")
+Err = TypeVar("Err")
+
+
+class _Status(Enum):
+ OK = auto()
+ ERROR = auto()
+
+
+class _ResultSentinel:
+ pass
+
+
+_RESULT_SENTINEL = _ResultSentinel()
+
+
+class Result(Generic[Succ, Err]):
+ @staticmethod
+ def ok(succ: T) -> "Result[T, Any]":
+ return Result(_Status.OK, succ=succ)
+
+ @staticmethod
+ def err(err: T) -> "Result[Any, T]":
+ return Result(_Status.ERROR, err=err)
+
+ def __init__(
+ self,
+ status: _Status,
+ succ: Union[Succ, _ResultSentinel] = _RESULT_SENTINEL,
+ err: Union[Err, _ResultSentinel] = _RESULT_SENTINEL,
+ ) -> None:
+ super().__init__()
+ self._status: _Status = status
+ self._succ: Union[_ResultSentinel, Succ] = succ
+ self._err: Union[_ResultSentinel, Err] = err
+
+ def unwrap(self) -> Succ:
+ assert self._status is _Status.OK
+ assert not isinstance(self._succ, _ResultSentinel)
+ return self._succ
+
+ def unwrap_err(self) -> Err:
+ assert self._status is _Status.ERROR
+ assert not isinstance(self._err, _ResultSentinel)
+ return self._err
+
+ def is_ok(self) -> bool:
+ return self._status is _Status.OK
+
+ def is_err(self) -> bool:
+ return self._status is _Status.ERROR
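
A short usage sketch of the `Result` type above, in the spirit of the config-store verifiers used elsewhere in this diff; the verifier itself is illustrative:

```python
from knot_resolver.utils.functional import Result


def check_workers(old: int, new: int) -> Result[None, str]:
    # illustrative check only: reject configurations with no workers
    if new < 1:
        return Result.err("at least one worker is required")
    return Result.ok(None)


res = check_workers(4, 0)
if res.is_err():
    print("validation failed:", res.unwrap_err())
```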
diff --git a/python/knot_resolver/utils/modeling/README.md b/python/knot_resolver/utils/modeling/README.md
new file mode 100644
index 00000000..97c68b54
--- /dev/null
+++ b/python/knot_resolver/utils/modeling/README.md
@@ -0,0 +1,155 @@
+# Modeling utils
+
+These utilities are used to model schemas for data stored in a Python dictionary or in YAML and JSON format.
+The utilities also take care of parsing, validating, and creating JSON schemas and basic documentation.
+
+## Creating schema
+
+A schema is created using the `ConfigSchema` class. The schema structure is specified using annotations.
+
+```python
+from .modeling import ConfigSchema
+
+class SimpleSchema(ConfigSchema):
+ integer: int = 5 # a default value can be specified
+ string: str
+ boolean: bool
+```
+Even more complex types can be used in a schema. Schemas can also be nested.
+Words in multi-word names are separated by underscore `_` (e.g. `simple_schema`).
+
+```python
+from typing import Dict, List, Optional, Union
+
+class ComplexSchema(ConfigSchema):
+ optional: Optional[str] # this field is optional
+ union: Union[int, str] # integer and string are both valid
+ list: List[int] # list of integers
+ dictionary: Dict[str, bool] = {"key": False}
+ simple_schema: SimpleSchema # nested schema
+```
+
+
+### Additional validation
+
+If some additional validation needs to be done, there is the `_validate()` method for that.
+A `ValueError` exception should be raised in case of a validation error.
+
+```python
+class FieldsSchema(ConfigSchema):
+ field1: int
+ field2: int
+
+ def _validate(self) -> None:
+ if self.field1 > self.field2:
+ raise ValueError("field1 is bigger than field2")
+```
+
+
+### Additional layer, transformation methods
+
+It is possible to add layers to a schema and use a transformation method between layers to process the value.
+The transformation method must be named after the field (`value` in this example) with an underscore `_` prefix.
+In this example, `Layer2Schema` is the structure for the input data and `Layer1Schema` is for the resulting data.
+
+```python
+class Layer1Schema(ConfigSchema):
+ class Layer2Schema(ConfigSchema):
+ value: Union[str, int]
+
+ _LAYER = Layer2Schema
+
+ value: int
+
+ def _value(self, obj: Layer2Schema) -> Any:
+ if isinstance(obj.value, str):
+ return len(obj.value) # transform str values to int; this is just an example
+ return obj.value
+```
+
+### Documentation and JSON schema
+
+A created schema can be documented using a simple docstring. A JSON schema is created by calling the `json_schema()` method on the schema class. The JSON schema includes descriptions from the docstring, defaults, etc.
+
+```python
+class SimpleSchema(ConfigSchema):
+ """
+ This is description for SimpleSchema itself.
+
+ ---
+ integer: description for integer field
+ string: description for string field
+ boolean: description for boolean field
+ """
+
+ integer: int = 5
+ string: str
+ boolean: bool
+
+json_schema = SimpleSchema.json_schema()
+```
+
+
+## Creating custom type
+
+Custom types can be made by extending the `BaseValueType` class, which is integrated into the parsing and validation process.
+Use `DataValidationError` to raise an exception during validation. `object_path` is used to track the node in more complex/nested schemas and create useful logging messages.
+
+```python
+from .modeling import BaseValueType
+from .modeling.exceptions import DataValidationError
+
+class IntNonNegative(BaseValueType):
+ def __init__(self, source_value: Any, object_path: str = "/") -> None:
+ super().__init__(source_value)
+ if isinstance(source_value, int) and not isinstance(source_value, bool):
+ if source_value < 0:
+ raise DataValidationError(f"value {source_value} is negative number.", object_path)
+ self._value = source_value
+ else:
+ raise DataValidationError(
+ f"expected integer, got '{type(source_value)}'",
+ object_path,
+ )
+```
+
+For JSON schema support you should implement the `json_schema` method.
+It should return a [JSON schema representation](https://json-schema.org/understanding-json-schema/index.html) of the custom type.
+
+```python
+ @classmethod
+ def json_schema(cls: Type["IntNonNegative"]) -> Dict[Any, Any]:
+ return {"type": "integer", "minimum": 0}
+```
+
+
+## Parsing JSON/YAML
+
+For example, YAML data for `ComplexSchema` can look like this.
+Words in multi-word names are separated by hyphen `-` (e.g. `simple-schema`).
+
+```yaml
+# data.yaml
+union: here could also be a number
+list: [1,2,3,]
+dictionary:
+ key": false
+simple-schema:
+ integer: 55
+ string: this is string
+ boolean: false
+```
+
+To parse data from the YAML format, just use the `parse_yaml` function, or `parse_json` for the JSON format.
+Parsed data are stored in a dict-like object that takes care of `-`/`_` conversion.
+
+```python
+from .modeling import parse_yaml
+
+# read data from file
+with open("data.yaml") as f:
+ str_data = f.read()
+
+dict_data = parse_yaml(str_data)
+validated_data = ComplexSchema(dict_data)
+``` \ No newline at end of file
diff --git a/python/knot_resolver/utils/modeling/__init__.py b/python/knot_resolver/utils/modeling/__init__.py
new file mode 100644
index 00000000..d16f6c12
--- /dev/null
+++ b/python/knot_resolver/utils/modeling/__init__.py
@@ -0,0 +1,14 @@
+from .base_generic_type_wrapper import BaseGenericTypeWrapper
+from .base_schema import BaseSchema, ConfigSchema
+from .base_value_type import BaseValueType
+from .parsing import parse_json, parse_yaml, try_to_parse
+
+__all__ = [
+ "BaseGenericTypeWrapper",
+ "BaseValueType",
+ "BaseSchema",
+ "ConfigSchema",
+ "parse_yaml",
+ "parse_json",
+ "try_to_parse",
+]
diff --git a/python/knot_resolver/utils/modeling/base_generic_type_wrapper.py b/python/knot_resolver/utils/modeling/base_generic_type_wrapper.py
new file mode 100644
index 00000000..1f2c1767
--- /dev/null
+++ b/python/knot_resolver/utils/modeling/base_generic_type_wrapper.py
@@ -0,0 +1,9 @@
+from typing import Generic, TypeVar
+
+from .base_value_type import BaseTypeABC
+
+T = TypeVar("T")
+
+
+class BaseGenericTypeWrapper(Generic[T], BaseTypeABC): # pylint: disable=abstract-method
+ pass
diff --git a/python/knot_resolver/utils/modeling/base_schema.py b/python/knot_resolver/utils/modeling/base_schema.py
new file mode 100644
index 00000000..aca3be05
--- /dev/null
+++ b/python/knot_resolver/utils/modeling/base_schema.py
@@ -0,0 +1,816 @@
+import enum
+import inspect
+from abc import ABC, abstractmethod # pylint: disable=[no-name-in-module]
+from typing import Any, Callable, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar, Union, cast
+
+import yaml
+
+from knot_resolver.utils.functional import all_matches
+
+from .base_generic_type_wrapper import BaseGenericTypeWrapper
+from .base_value_type import BaseValueType
+from .exceptions import AggregateDataValidationError, DataDescriptionError, DataValidationError
+from .renaming import Renamed, renamed
+from .types import (
+ get_generic_type_argument,
+ get_generic_type_arguments,
+ get_generic_type_wrapper_argument,
+ get_optional_inner_type,
+ is_dict,
+ is_enum,
+ is_generic_type_wrapper,
+ is_internal_field_name,
+ is_list,
+ is_literal,
+ is_none_type,
+ is_optional,
+ is_tuple,
+ is_union,
+)
+
+T = TypeVar("T")
+
+
+def is_obj_type(obj: Any, types: Union[type, Tuple[Any, ...], Tuple[type, ...]]) -> bool:
+ # To check specific type we are using 'type()' instead of 'isinstance()'
+ # because for example 'bool' is instance of 'int', 'isinstance(False, int)' returns True.
+ # pylint: disable=unidiomatic-typecheck
+ if isinstance(types, tuple):
+ return type(obj) in types
+ return type(obj) is types
+
+
+class Serializable(ABC):
+ """
+ An interface for making classes serializable to a dictionary (and in turn into a JSON).
+ """
+
+ @abstractmethod
+ def to_dict(self) -> Dict[Any, Any]:
+ raise NotImplementedError(f"...for class {self.__class__.__name__}")
+
+ @staticmethod
+ def is_serializable(typ: Type[Any]) -> bool:
+ return (
+ typ in {str, bool, int, float}
+ or is_none_type(typ)
+ or is_literal(typ)
+ or is_dict(typ)
+ or is_list(typ)
+ or is_generic_type_wrapper(typ)
+ or (inspect.isclass(typ) and issubclass(typ, Serializable))
+ or (inspect.isclass(typ) and issubclass(typ, BaseValueType))
+ or (inspect.isclass(typ) and issubclass(typ, BaseSchema))
+ or (is_optional(typ) and Serializable.is_serializable(get_optional_inner_type(typ)))
+ or (is_union(typ) and all_matches(Serializable.is_serializable, get_generic_type_arguments(typ)))
+ )
+
+ @staticmethod
+ def serialize(obj: Any) -> Any:
+ if isinstance(obj, Serializable):
+ return obj.to_dict()
+
+ elif isinstance(obj, (BaseValueType, BaseGenericTypeWrapper)):
+ o = obj.serialize()
+ # if Serializable.is_serializable(o):
+ return Serializable.serialize(o)
+ # return o
+
+ elif isinstance(obj, list):
+ res: List[Any] = [Serializable.serialize(i) for i in cast(List[Any], obj)]
+ return res
+
+ return obj
+
+
+class _lazy_default(Generic[T], Serializable):
+ """
+ Wrapper for default values of BaseSchema classes which defers their instantiation until the schema
+ itself is instantiated
+ """
+
+ def __init__(self, constructor: Callable[..., T], *args: Any, **kwargs: Any) -> None:
+ # pylint: disable=[super-init-not-called]
+ self._func = constructor
+ self._args = args
+ self._kwargs = kwargs
+
+ def instantiate(self) -> T:
+ return self._func(*self._args, **self._kwargs)
+
+ def to_dict(self) -> Dict[Any, Any]:
+ return Serializable.serialize(self.instantiate())
+
+
+def lazy_default(constructor: Callable[..., T], *args: Any, **kwargs: Any) -> T:
+ """We use a factory function because you can't lie about the return type in `__new__`"""
+ return _lazy_default(constructor, *args, **kwargs) # type: ignore
+
+
+def _split_docstring(docstring: str) -> Tuple[str, Optional[str]]:
+ """
+ Splits docstring into description of the class and description of attributes
+ """
+
+ if "---" not in docstring:
+ return ("\n".join([s.strip() for s in docstring.splitlines()]).strip(), None)
+
+ doc, attrs_doc = docstring.split("---", maxsplit=1)
+ return (
+ "\n".join([s.strip() for s in doc.splitlines()]).strip(),
+ attrs_doc,
+ )
+
+
+def _parse_attrs_docstrings(docstring: str) -> Optional[Dict[str, str]]:
+ """
+ Given a docstring of a BaseSchema, return a dict with descriptions of individual attributes.
+ """
+
+ _, attrs_doc = _split_docstring(docstring)
+ if attrs_doc is None:
+ return None
+
+ # try to parse it as yaml:
+ data = yaml.safe_load(attrs_doc)
+ assert isinstance(data, dict), "Invalid format of attribute description"
+ return cast(Dict[str, str], data)
+
+
+def _get_properties_schema(typ: Type[Any]) -> Dict[Any, Any]:
+ schema: Dict[Any, Any] = {}
+ annot: Dict[str, Any] = typ.__dict__.get("__annotations__", {})
+ docstring: str = typ.__dict__.get("__doc__", "") or ""
+ attribute_documentation = _parse_attrs_docstrings(docstring)
+ for field_name, python_type in annot.items():
+ name = field_name.replace("_", "-")
+ schema[name] = _describe_type(python_type)
+
+ # description
+ if attribute_documentation is not None:
+ if field_name not in attribute_documentation:
+ raise DataDescriptionError(f"The docstring does not describe field '{field_name}'", str(typ))
+ schema[name]["description"] = attribute_documentation[field_name]
+ del attribute_documentation[field_name]
+
+ # default value
+ if hasattr(typ, field_name):
+ assert Serializable.is_serializable(
+ python_type
+ ), f"Type '{python_type}' does not appear to be JSON serializable"
+ schema[name]["default"] = Serializable.serialize(getattr(typ, field_name))
+
+ if attribute_documentation is not None and len(attribute_documentation) > 0:
+ raise DataDescriptionError(
+ f"The docstring describes attributes which are not present - {tuple(attribute_documentation.keys())}",
+ str(typ),
+ )
+
+ return schema
+
+
+def _describe_type(typ: Type[Any]) -> Dict[Any, Any]:
+ # pylint: disable=too-many-branches
+
+ if inspect.isclass(typ) and issubclass(typ, BaseSchema):
+ return typ.json_schema(include_schema_definition=False)
+
+ elif inspect.isclass(typ) and issubclass(typ, BaseValueType):
+ return typ.json_schema()
+
+ elif is_generic_type_wrapper(typ):
+ wrapped = get_generic_type_wrapper_argument(typ)
+ return _describe_type(wrapped)
+
+ elif is_none_type(typ):
+ return {"type": "null"}
+
+ elif typ == int:
+ return {"type": "integer"}
+
+ elif typ == bool:
+ return {"type": "boolean"}
+
+ elif typ == str:
+ return {"type": "string"}
+
+ elif is_literal(typ):
+ lit = get_generic_type_arguments(typ)
+ return {"type": "string", "enum": lit}
+
+ elif is_optional(typ):
+ desc = _describe_type(get_optional_inner_type(typ))
+ if "type" in desc:
+ desc["type"] = [desc["type"], "null"]
+ return desc
+ else:
+ return {"anyOf": [{"type": "null"}, desc]}
+
+ elif is_union(typ):
+ variants = get_generic_type_arguments(typ)
+ return {"anyOf": [_describe_type(v) for v in variants]}
+
+ elif is_list(typ):
+ return {"type": "array", "items": _describe_type(get_generic_type_argument(typ))}
+
+ elif is_dict(typ):
+ key, val = get_generic_type_arguments(typ)
+
+ if inspect.isclass(key) and issubclass(key, BaseValueType):
+ assert (
+ key.__str__ is not BaseValueType.__str__
+ ), "To support derived 'BaseValueType', __str__ must be implemented."
+ else:
+ assert key == str, "We currently do not support any other keys than strings"
+
+ return {"type": "object", "additionalProperties": _describe_type(val)}
+
+ elif inspect.isclass(typ) and issubclass(typ, enum.Enum): # same as our is_enum(typ), but inlined for type checker
+ return {"type": "string", "enum": [str(v) for v in typ]}
+
+ raise NotImplementedError(f"Trying to get JSON schema for type '{typ}', which is not implemented")
+
+
+TSource = Union[None, "BaseSchema", Dict[str, Any]]
+
+
+def _create_untouchable(name: str) -> object:
+ class _Untouchable:
+ def __getattribute__(self, item_name: str) -> Any:
+ raise RuntimeError(f"You are not supposed to access object '{name}'.")
+
+ def __setattr__(self, item_name: str, value: Any) -> None:
+ raise RuntimeError(f"You are not supposed to access object '{name}'.")
+
+ return _Untouchable()
+
+
+class ObjectMapper:
+ def _create_tuple(self, tp: Type[Any], obj: Tuple[Any, ...], object_path: str) -> Tuple[Any, ...]:
+ types = get_generic_type_arguments(tp)
+ errs: List[DataValidationError] = []
+ res: List[Any] = []
+ for i, (t, val) in enumerate(zip(types, obj)):
+ try:
+ res.append(self.map_object(t, val, object_path=f"{object_path}[{i}]"))
+ except DataValidationError as e:
+ errs.append(e)
+ if len(errs) == 1:
+ raise errs[0]
+ elif len(errs) > 1:
+ raise AggregateDataValidationError(object_path, child_exceptions=errs)
+ return tuple(res)
+
+ def _create_dict(self, tp: Type[Any], obj: Dict[Any, Any], object_path: str) -> Dict[Any, Any]:
+ key_type, val_type = get_generic_type_arguments(tp)
+ try:
+ errs: List[DataValidationError] = []
+ res: Dict[Any, Any] = {}
+ for key, val in obj.items():
+ try:
+ nkey = self.map_object(key_type, key, object_path=f"{object_path}[{key}]")
+ nval = self.map_object(val_type, val, object_path=f"{object_path}[{key}]")
+ res[nkey] = nval
+ except DataValidationError as e:
+ errs.append(e)
+ if len(errs) == 1:
+ raise errs[0]
+ elif len(errs) > 1:
+ raise AggregateDataValidationError(object_path, child_exceptions=errs)
+ return res
+ except AttributeError as e:
+ raise DataValidationError(
+ f"Expected dict-like object, but failed to access its .items() method. Value was {obj}", object_path
+ ) from e
+
+ def _create_list(self, tp: Type[Any], obj: List[Any], object_path: str) -> List[Any]:
+ if isinstance(obj, str):
+ raise DataValidationError("expected list, got string", object_path)
+
+ inner_type = get_generic_type_argument(tp)
+ errs: List[DataValidationError] = []
+ res: List[Any] = []
+
+ try:
+ for i, val in enumerate(obj):
+ res.append(self.map_object(inner_type, val, object_path=f"{object_path}[{i}]"))
+ if len(res) == 0:
+ raise DataValidationError("empty list is not allowed", object_path)
+ except DataValidationError as e:
+ errs.append(e)
+ except TypeError as e:
+ errs.append(DataValidationError(str(e), object_path))
+
+ if len(errs) == 1:
+ raise errs[0]
+ elif len(errs) > 1:
+ raise AggregateDataValidationError(object_path, child_exceptions=errs)
+ return res
+
+ def _create_str(self, obj: Any, object_path: str) -> str:
+ # we are willing to cast any primitive value to string, but no compound values are allowed
+ if is_obj_type(obj, (str, float, int)) or isinstance(obj, BaseValueType):
+ return str(obj)
+ elif is_obj_type(obj, bool):
+ raise DataValidationError(
+ "Expected str, found bool. Be careful, that YAML parsers consider even"
+ ' "no" and "yes" as a bool. Search for the Norway Problem for more'
+ " details. And please use quotes explicitly.",
+ object_path,
+ )
+ else:
+ raise DataValidationError(
+ f"expected str (or number that would be cast to string), but found type {type(obj)}", object_path
+ )
+
+ def _create_int(self, obj: Any, object_path: str) -> int:
+ # we don't want to make an int out of anything else than other int
+ # except for BaseValueType class instances
+ if is_obj_type(obj, int) or isinstance(obj, BaseValueType):
+ return int(obj)
+ raise DataValidationError(f"expected int, found {type(obj)}", object_path)
+
+ def _create_union(self, tp: Type[T], obj: Any, object_path: str) -> T:
+ variants = get_generic_type_arguments(tp)
+ errs: List[DataValidationError] = []
+ for v in variants:
+ try:
+ return self.map_object(v, obj, object_path=object_path)
+ except DataValidationError as e:
+ errs.append(e)
+
+ raise DataValidationError("could not parse any of the possible variants", object_path, child_exceptions=errs)
+
+ def _create_optional(self, tp: Type[Optional[T]], obj: Any, object_path: str) -> Optional[T]:
+ inner: Type[Any] = get_optional_inner_type(tp)
+ if obj is None:
+ return None
+ else:
+ return self.map_object(inner, obj, object_path=object_path)
+
+ def _create_bool(self, obj: Any, object_path: str) -> bool:
+ if is_obj_type(obj, bool):
+ return obj
+ else:
+ raise DataValidationError(f"expected bool, found {type(obj)}", object_path)
+
+ def _create_literal(self, tp: Type[Any], obj: Any, object_path: str) -> Any:
+ expected = get_generic_type_arguments(tp)
+ if obj in expected:
+ return obj
+ else:
+ raise DataValidationError(f"'{obj}' does not match any of the expected values {expected}", object_path)
+
+ def _create_base_schema_object(self, tp: Type[Any], obj: Any, object_path: str) -> "BaseSchema":
+ if isinstance(obj, (dict, BaseSchema)):
+ return tp(obj, object_path=object_path)
+ raise DataValidationError(f"expected 'dict' or 'NoRenameBaseSchema' object, found '{type(obj)}'", object_path)
+
+ def create_value_type_object(self, tp: Type[Any], obj: Any, object_path: str) -> "BaseValueType":
+ if isinstance(obj, tp):
+ # if we already have a custom value type, just pass it through
+ return obj
+ else:
+ # no validation performed, the implementation does it in the constructor
+ try:
+ return tp(obj, object_path=object_path)
+ except ValueError as e:
+ if len(e.args) > 0 and isinstance(e.args[0], str):
+ msg = e.args[0]
+ else:
+ msg = f"Failed to validate value against {tp} type"
+ raise DataValidationError(msg, object_path) from e
+
+ def _create_default(self, obj: Any) -> Any:
+ if isinstance(obj, _lazy_default):
+ return obj.instantiate() # type: ignore
+ else:
+ return obj
+
+ def map_object(
+ self,
+ tp: Type[Any],
+ obj: Any,
+ default: Any = ...,
+ use_default: bool = False,
+ object_path: str = "/",
+ ) -> Any:
+ """
+ Given an expected type `tp` and a value object `obj`, return a new object of the given type and map fields of `obj` into it. During the mapping procedure,
+ runtime type checking is performed.
+ """
+
+ # Disabling these checks, because I think it's much more readable as a single function
+ # and it's not that large at this point. If it got larger, then we should definitely split it
+ # pylint: disable=too-many-branches,too-many-locals,too-many-statements
+
+ # default values
+ if obj is None and use_default:
+ return self._create_default(default)
+
+ # NoneType
+ elif is_none_type(tp):
+ if obj is None:
+ return None
+ else:
+ raise DataValidationError(f"expected None, found '{obj}'.", object_path)
+
+ # Optional[T] (could be technically handled by Union[*variants], but this way we have better error reporting)
+ elif is_optional(tp):
+ return self._create_optional(tp, obj, object_path)
+
+ # Union[*variants]
+ elif is_union(tp):
+ return self._create_union(tp, obj, object_path)
+
+ # after this, there is no place for a None object
+ elif obj is None:
+ raise DataValidationError(f"unexpected value 'None' for type {tp}", object_path)
+
+ # int
+ elif tp == int:
+ return self._create_int(obj, object_path)
+
+ # str
+ elif tp == str:
+ return self._create_str(obj, object_path)
+
+ # bool
+ elif tp == bool:
+ return self._create_bool(obj, object_path)
+
+ # float
+ elif tp == float:
+ raise NotImplementedError(
+ "Floating point values are not supported in the object mapper."
+ " Please implement them and be careful with type coercions"
+ )
+
+ # Literal[T]
+ elif is_literal(tp):
+ return self._create_literal(tp, obj, object_path)
+
+ # Dict[K,V]
+ elif is_dict(tp):
+ return self._create_dict(tp, obj, object_path)
+
+ # any Enums (probably used only internally in DataValidator)
+ elif is_enum(tp):
+ if isinstance(obj, tp):
+ return obj
+ else:
+ raise DataValidationError(f"unexpected value '{obj}' for enum '{tp}'", object_path)
+
+ # List[T]
+ elif is_list(tp):
+ return self._create_list(tp, obj, object_path)
+
+ # Tuple[A,B,C,D,...]
+ elif is_tuple(tp):
+ return self._create_tuple(tp, obj, object_path)
+
+ # type of obj and cls type match
+ elif is_obj_type(obj, tp):
+ return obj
+
+ # when the specified type is Any, just return the given value
+ # on mypy version 1.11.0 comparison-overlap error started popping up
+ # https://github.com/python/mypy/issues/17665
+ elif tp == Any: # type: ignore[comparison-overlap]
+ return obj
+
+ # BaseValueType subclasses
+ elif inspect.isclass(tp) and issubclass(tp, BaseValueType):
+ return self.create_value_type_object(tp, obj, object_path)
+
+ # BaseGenericTypeWrapper subclasses
+ elif is_generic_type_wrapper(tp):
+ inner_type = get_generic_type_wrapper_argument(tp)
+ obj_valid = self.map_object(inner_type, obj, object_path)
+ return tp(obj_valid, object_path=object_path) # type: ignore
+
+ # nested BaseSchema subclasses
+ elif inspect.isclass(tp) and issubclass(tp, BaseSchema):
+ return self._create_base_schema_object(tp, obj, object_path)
+
+ # if the object matches, just pass it through
+ elif inspect.isclass(tp) and isinstance(obj, tp):
+ return obj
+
+ # default error handler
+ else:
+ raise DataValidationError(
+ f"Type {tp} cannot be parsed. This is a implementation error. "
+ "Please fix your types in the class or improve the parser/validator.",
+ object_path,
+ )
+
+ def is_obj_type_valid(self, obj: Any, tp: Type[Any]) -> bool:
+ """
+        Runtime type checking. Validates that a given object is of a given type.
+ """
+
+ try:
+ self.map_object(tp, obj)
+ return True
+ except (DataValidationError, ValueError):
+ return False
+
+ def _assign_default(self, obj: Any, name: str, python_type: Any, object_path: str) -> None:
+ cls = obj.__class__
+
+ try:
+ default = self._create_default(getattr(cls, name, None))
+ except ValueError as e:
+ raise DataValidationError(str(e), f"{object_path}/{name}") from e
+
+ value = self.map_object(python_type, default, object_path=f"{object_path}/{name}")
+ setattr(obj, name, value)
+
+ def _assign_field(self, obj: Any, name: str, python_type: Any, value: Any, object_path: str) -> None:
+ value = self.map_object(python_type, value, object_path=f"{object_path}/{name}")
+ setattr(obj, name, value)
+
+ def _assign_fields(self, obj: Any, source: Union[Dict[str, Any], "BaseSchema", None], object_path: str) -> Set[str]:
+ """
+ Order of assignment:
+ 1. all direct assignments
+ 2. assignments with conversion method
+ """
+ cls = obj.__class__
+ annot = cls.__dict__.get("__annotations__", {})
+ errs: List[DataValidationError] = []
+
+ used_keys: Set[str] = set()
+ for name, python_type in annot.items():
+ try:
+ if is_internal_field_name(name):
+ continue
+
+ # populate field
+ if source is None:
+ self._assign_default(obj, name, python_type, object_path)
+
+ # check for invalid configuration with both transformation function and default value
+ elif hasattr(obj, f"_{name}") and hasattr(obj, name):
+ raise RuntimeError(
+                        f"Field '{obj.__class__.__name__}.{name}' has a default value and a transformation function at"
+                        " the same time. That is not allowed. Store the default in the transformation function."
+ )
+
+ # there is a transformation function to create the value
+ elif hasattr(obj, f"_{name}") and callable(getattr(obj, f"_{name}")):
+ val = self._get_converted_value(obj, name, source, object_path)
+ self._assign_field(obj, name, python_type, val, object_path)
+ used_keys.add(name)
+
+ # source just contains the value
+ elif name in source:
+ val = source[name]
+ self._assign_field(obj, name, python_type, val, object_path)
+ used_keys.add(name)
+
+ # there is a default value, or the type is optional => store the default or null
+ elif hasattr(obj, name) or is_optional(python_type):
+ self._assign_default(obj, name, python_type, object_path)
+
+ # we expected a value but it was not there
+ else:
+ errs.append(DataValidationError(f"missing attribute '{name}'.", object_path))
+ except DataValidationError as e:
+ errs.append(e)
+
+ if len(errs) == 1:
+ raise errs[0]
+ elif len(errs) > 1:
+ raise AggregateDataValidationError(object_path, errs)
+ return used_keys
+
+ def _get_converted_value(self, obj: Any, key: str, source: TSource, object_path: str) -> Any:
+ """
+        Get the value of a field by invoking the appropriate transformation function.
+ """
+ try:
+ func = getattr(obj.__class__, f"_{key}")
+ argc = len(inspect.signature(func).parameters)
+ if argc == 1:
+ # it is a static method
+ return func(source)
+ elif argc == 2:
+                # it is an instance method
+ return func(_create_untouchable("obj"), source)
+ else:
+                raise RuntimeError("Transformation function has the wrong number of arguments")
+ except ValueError as e:
+ if len(e.args) > 0 and isinstance(e.args[0], str):
+ msg = e.args[0]
+ else:
+ msg = "Failed to validate value type"
+ raise DataValidationError(msg, object_path) from e
+
+ def object_constructor(self, obj: Any, source: Union["BaseSchema", Dict[Any, Any]], object_path: str) -> None:
+ """
+        Delegated constructor for the BaseSchema class.
+
+        The constructor is delegated to the mapper because of renaming: this way we do not need a
+        separate BaseSchema subclass just to support dynamically renamed fields.
+ """
+ # As this is a delegated constructor, we must ignore protected access warnings
+ # pylint: disable=protected-access
+
+ # sanity check
+ if not isinstance(source, (BaseSchema, dict)): # type: ignore
+ raise DataValidationError(f"expected dict-like object, found '{type(source)}'", object_path)
+
+ # construct lower level schema first if configured to do so
+ if obj._LAYER is not None:
+ source = obj._LAYER(source, object_path=object_path) # pylint: disable=not-callable
+
+ # assign fields
+ used_keys = self._assign_fields(obj, source, object_path)
+
+ # check for unused keys in the source object
+ if source and not isinstance(source, BaseSchema):
+ unused = source.keys() - used_keys
+ if len(unused) > 0:
+ keys = ", ".join((f"'{u}'" for u in unused))
+ raise DataValidationError(
+ f"unexpected extra key(s) {keys}",
+ object_path,
+ )
+
+ # validate the constructed value
+ try:
+ obj._validate()
+ except ValueError as e:
+ raise DataValidationError(e.args[0] if len(e.args) > 0 else "Validation error", object_path or "/") from e
+
+
+class BaseSchema(Serializable):
+ """
+    Base class for modeling a configuration schema. It somewhat resembles standard dataclasses with additional
+ functionality:
+
+ * type validation
+ * data conversion
+
+    To create an instance of this class, you have to provide source data in the form of a dict-like object,
+    generally a raw dict or another `BaseSchema` instance. The provided data object is traversed, transformed
+    and validated before being assigned to the appropriate fields (attributes).
+
+ Fields (attributes)
+ ===================
+
+    The fields (or attributes) of the class are defined the same way as in a dataclass, by creating class-level
+    type-annotated fields. For example:
+
+ class A(BaseSchema):
+ awesome_number: int
+
+    If your `BaseSchema` instance has a field whose type is itself a BaseSchema, its value is recursively created
+    from the nested input data. This way, you can specify a complex tree of BaseSchemas and use the root
+    BaseSchema to create an instance of everything.
+
+ Transformation
+ ==============
+
+    You can provide the BaseSchema class with a field and a function of the same name, prefixed with an
+    underscore ('_'). For example, you could have a field called `awesome_number` and a function called
+    `_awesome_number(self, source)`. The function takes one argument - the source data (optionally with `self`,
+    but you are not supposed to touch that). It can read any data from the source object and return a value of
+    an appropriate type, which will then be assigned to the field `awesome_number`. To report an error
+    during validation, raise a `ValueError` exception.
+
+    Using this, you can convert any input value into any type and field you want. To make the conversion easier
+    to write, you can also specify a special class variable called `_LAYER` pointing to another
+    BaseSchema class. The source object is then first parsed as that additional BaseSchema layer and the result
+    is used as the source for this class, which allows nesting of transformation functions.
+
+ Validation
+ ==========
+
+    All assignments to fields during object construction are checked at runtime for proper types. This means
+    you can take an untrusted source object and turn it into a data structure where you know exactly what
+    is what.
+
+    You can also define a `_validate` method, which will be called once the whole data structure is built. There
+    you can validate the data and raise a `ValueError` if they are invalid.
+
+ Default values
+ ==============
+
+    If you create a field with a value, it will be used as the default whenever the corresponding key is missing
+    in the source object. As a special case, the default value for an Optional type is None unless specified
+    otherwise. A field must not have both a default value and a transformation function.
+
+ Example
+ =======
+
+ See tests/utils/test_modelling.py for example usage.
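+
+    A minimal illustrative sketch (the class and field names below are hypothetical, not part of
+    this module) combining a default value and a transformation function:
+
+        class ServerSchema(BaseSchema):
+            port: int = 53  # default used when the key is missing in the source
+            name: str
+
+            def _name(self, source):
+                # transformation function for the field 'name'
+                return str(source["name"]).lower()
+
+        server = ServerSchema({"name": "Example"})
+        # server.port == 53, server.name == "example"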
+ """
+
+ _LAYER: Optional[Type["BaseSchema"]] = None
+ _MAPPER: ObjectMapper = ObjectMapper()
+
+ def __init__(self, source: TSource = None, object_path: str = ""): # pylint: disable=[super-init-not-called]
+ # save source data (and drop information about nullness)
+ source = source or {}
+ self.__source: Union[Dict[str, Any], BaseSchema] = source
+
+ # delegate the rest of the constructor
+ self._MAPPER.object_constructor(self, source, object_path)
+
+ def get_unparsed_data(self) -> Dict[str, Any]:
+ if isinstance(self.__source, BaseSchema):
+ return self.__source.get_unparsed_data()
+ elif isinstance(self.__source, Renamed):
+ return self.__source.original()
+ else:
+ return self.__source
+
+ def __getitem__(self, key: str) -> Any:
+ if not hasattr(self, key):
+ raise RuntimeError(f"Object '{self}' of type '{type(self)}' does not have field named '{key}'")
+ return getattr(self, key)
+
+ def __contains__(self, item: Any) -> bool:
+ return hasattr(self, item)
+
+ def _validate(self) -> None:
+ """
+        Validation procedure called after all fields are assigned. Should raise a ValueError in case of failure.
+ """
+
+ def __eq__(self, o: object) -> bool:
+ cls = self.__class__
+ if not isinstance(o, cls):
+ return False
+
+ annot = cls.__dict__.get("__annotations__", {})
+ for name in annot.keys():
+ if getattr(self, name) != getattr(o, name):
+ return False
+
+ return True
+
+ @classmethod
+ def json_schema(cls: Type["BaseSchema"], include_schema_definition: bool = True) -> Dict[Any, Any]:
+ if cls._LAYER is not None:
+ return cls._LAYER.json_schema(include_schema_definition=include_schema_definition)
+
+ schema: Dict[Any, Any] = {}
+ if include_schema_definition:
+ schema["$schema"] = "https://json-schema.org/draft/2020-12/schema"
+ if cls.__doc__ is not None:
+ schema["description"] = _split_docstring(cls.__doc__)[0]
+ schema["type"] = "object"
+ schema["properties"] = _get_properties_schema(cls)
+
+ return schema
+
+ def to_dict(self) -> Dict[Any, Any]:
+ res: Dict[Any, Any] = {}
+ cls = self.__class__
+ annot = cls.__dict__.get("__annotations__", {})
+
+ for name in annot:
+ res[name] = Serializable.serialize(getattr(self, name))
+ return res
+
+
+class RenamingObjectMapper(ObjectMapper):
+ """
+    Same as ObjectMapper, but it uses the collection wrappers from the `renaming` module to perform dynamic field renaming.
+
+ More specifically:
+ - it renames all properties in (nested) objects
+ - it does not rename keys in dictionaries
+ """
+
+ def _create_dict(self, tp: Type[Any], obj: Dict[Any, Any], object_path: str) -> Dict[Any, Any]:
+ if isinstance(obj, Renamed):
+ obj = obj.original()
+ return super()._create_dict(tp, obj, object_path)
+
+ def _create_base_schema_object(self, tp: Type[Any], obj: Any, object_path: str) -> "BaseSchema":
+ if isinstance(obj, dict):
+ obj = renamed(obj)
+ return super()._create_base_schema_object(tp, obj, object_path)
+
+ def object_constructor(self, obj: Any, source: Union["BaseSchema", Dict[Any, Any]], object_path: str) -> None:
+ if isinstance(source, dict):
+ source = renamed(source)
+ return super().object_constructor(obj, source, object_path)
+
+
+# exported as standalone functions for simplicity and compatibility
+is_obj_type_valid = ObjectMapper().is_obj_type_valid
+map_object = ObjectMapper().map_object
+
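+# Illustrative usage (a sketch, not part of the API surface above): the standalone helpers
+# allow ad-hoc runtime type checks without instantiating an ObjectMapper, e.g.
+#
+#     from typing import Dict, List
+#
+#     if is_obj_type_valid(payload, Dict[str, List[int]]):          # 'payload' is hypothetical
+#         data = map_object(Dict[str, List[int]], payload)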
+
+class ConfigSchema(BaseSchema):
+ """
+    Same as BaseSchema, but maps its fields with a RenamingObjectMapper.
+ """
+
+ _MAPPER: ObjectMapper = RenamingObjectMapper()
diff --git a/python/knot_resolver/utils/modeling/base_value_type.py b/python/knot_resolver/utils/modeling/base_value_type.py
new file mode 100644
index 00000000..dff4a3fe
--- /dev/null
+++ b/python/knot_resolver/utils/modeling/base_value_type.py
@@ -0,0 +1,45 @@
+from abc import ABC, abstractmethod # pylint: disable=[no-name-in-module]
+from typing import Any, Dict, Type
+
+
+class BaseTypeABC(ABC):
+ @abstractmethod
+ def __init__(self, source_value: Any, object_path: str = "/") -> None:
+ pass
+
+ @abstractmethod
+ def __int__(self) -> int:
+        raise NotImplementedError(f"'int()' conversion for {type(self).__name__} is not implemented.")
+
+ @abstractmethod
+ def __str__(self) -> str:
+        raise NotImplementedError(f"'str()' conversion for {type(self).__name__} is not implemented.")
+
+ @abstractmethod
+ def serialize(self) -> Any:
+ """
+        Used for dumping configuration. Returns a JSON-serializable object from which the object
+        can be recreated using the constructor.
+
+        It's not necessary to return the same structure that was given as input. It only has
+        to be semantically equivalent.
+ """
+        raise NotImplementedError(f"{type(self).__name__}'s 'serialize()' is not implemented.")
+
+
+class BaseValueType(BaseTypeABC):
+ """
+    Subclasses of this class can be used as type annotations in a `BaseSchema`. When a value
+ is being parsed from a serialized format (e.g. JSON/YAML), an object will be created by
+ calling the constructor of the appropriate type on the field value. The only limitation
+ is that the value MUST NOT be `None`.
+
+ There is no validation done on the wrapped value. The only condition is that
+ it can't be `None`. If you want to perform any validation during creation,
+ raise a `ValueError` in case of errors.
+ """
+
+ @classmethod
+ @abstractmethod
+ def json_schema(cls: Type["BaseValueType"]) -> Dict[Any, Any]:
+ raise NotImplementedError()
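+
+
+# Illustrative subclass (a sketch, not part of this module; the 'Port' name is hypothetical):
+#
+#     class Port(BaseValueType):
+#         def __init__(self, source_value: Any, object_path: str = "/") -> None:
+#             if not isinstance(source_value, int) or not 0 < source_value < 65536:
+#                 raise ValueError(f"expected a port number, got '{source_value}'")
+#             self._value = source_value
+#
+#         def __int__(self) -> int:
+#             return self._value
+#
+#         def __str__(self) -> str:
+#             return str(self._value)
+#
+#         def serialize(self) -> Any:
+#             return self._value
+#
+#         @classmethod
+#         def json_schema(cls) -> Dict[Any, Any]:
+#             return {"type": "integer", "minimum": 1, "maximum": 65535}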
diff --git a/python/knot_resolver/utils/modeling/exceptions.py b/python/knot_resolver/utils/modeling/exceptions.py
new file mode 100644
index 00000000..ea057339
--- /dev/null
+++ b/python/knot_resolver/utils/modeling/exceptions.py
@@ -0,0 +1,63 @@
+from typing import Iterable, List
+
+from knot_resolver.manager.exceptions import KresManagerException
+
+
+class DataModelingBaseException(KresManagerException):
+ """
+ Base class for all exceptions used in modelling.
+ """
+
+
+class DataParsingError(DataModelingBaseException):
+ pass
+
+
+class DataDescriptionError(DataModelingBaseException):
+ pass
+
+
+class DataValidationError(DataModelingBaseException):
+ def __init__(self, msg: str, tree_path: str, child_exceptions: "Iterable[DataValidationError]" = tuple()) -> None:
+ super().__init__(msg)
+ self._tree_path = tree_path.replace("_", "-")
+ self._child_exceptions = child_exceptions
+
+ def where(self) -> str:
+ return self._tree_path
+
+    def msg(self) -> str:
+ return f"[{self.where()}] {super().__str__()}"
+
+ def recursive_msg(self, indentation_level: int = 0) -> str:
+ msg_parts: List[str] = []
+
+ if indentation_level == 0:
+ indentation_level += 1
+ msg_parts.append("Configuration validation error detected:")
+
+ INDENT = indentation_level * "\t"
+ msg_parts.append(f"{INDENT}{self.msg()}")
+
+ for c in self._child_exceptions:
+ msg_parts.append(c.recursive_msg(indentation_level + 1))
+ return "\n".join(msg_parts)
+
+ def __str__(self) -> str:
+ return self.recursive_msg()
+
+
+class AggregateDataValidationError(DataValidationError):
+ def __init__(self, object_path: str, child_exceptions: "Iterable[DataValidationError]") -> None:
+ super().__init__("error due to lower level exceptions", object_path, child_exceptions)
+
+ def recursive_msg(self, indentation_level: int = 0) -> str:
+ inc = 0
+ msg_parts: List[str] = []
+ if indentation_level == 0:
+ inc = 1
+ msg_parts.append("Configuration validation errors detected:")
+
+ for c in self._child_exceptions:
+ msg_parts.append(c.recursive_msg(indentation_level + inc))
+ return "\n".join(msg_parts)
diff --git a/python/knot_resolver/utils/modeling/json_pointer.py b/python/knot_resolver/utils/modeling/json_pointer.py
new file mode 100644
index 00000000..a60ba5d1
--- /dev/null
+++ b/python/knot_resolver/utils/modeling/json_pointer.py
@@ -0,0 +1,88 @@
+"""
+Implements JSON pointer resolution based on RFC 6901:
+https://www.rfc-editor.org/rfc/rfc6901
+"""
+
+from typing import Any, Optional, Tuple, Union
+
+# JSONPtrAddressable = Optional[Union[Dict[str, "JSONPtrAddressable"], List["JSONPtrAddressable"], int, float, bool, str, None]]
+JSONPtrAddressable = Any # the recursive definition above is not valid :(
+
+
+class _JSONPtr:
+ @staticmethod
+ def _decode_token(token: str) -> str:
+ """
+ Resolves escaped characters ~ and /
+ """
+
+ # the order of the replace statements is important, do not change without
+ # consulting the RFC
+ return token.replace("~1", "/").replace("~0", "~")
+
+ @staticmethod
+ def _encode_token(token: str) -> str:
+ return token.replace("~", "~0").replace("/", "~1")
+
+ def __init__(self, ptr: str):
+ if ptr == "":
+ # pointer to the root
+ self.tokens = []
+
+ else:
+ if ptr[0] != "/":
+ raise SyntaxError(
+ f"JSON pointer '{ptr}' invalid: the first character MUST be '/' or the pointer must be empty"
+ )
+
+ ptr = ptr[1:]
+ self.tokens = [_JSONPtr._decode_token(tok) for tok in ptr.split("/")]
+
+ def resolve(
+ self, obj: JSONPtrAddressable
+ ) -> Tuple[Optional[JSONPtrAddressable], JSONPtrAddressable, Union[str, int, None]]:
+ """
+ Returns (Optional[parent], Optional[direct value], key of value in the parent object)
+ """
+
+ parent: Optional[JSONPtrAddressable] = None
+ current = obj
+ current_ptr = ""
+ token: Union[int, str, None] = None
+
+ for token in self.tokens:
+ if current is None:
+ raise ValueError(
+ f"JSON pointer cannot reference nested non-existent object: object at ptr '{current_ptr}' already points to None, cannot nest deeper with token '{token}'"
+ )
+
+ elif isinstance(current, (bool, int, float, str)):
+ raise ValueError(f"object at '{current_ptr}' is a scalar, JSON pointer cannot point into it")
+
+ else:
+ parent = current
+ if isinstance(current, list):
+ if token == "-":
+ current = None
+ else:
+ try:
+ token = int(token)
+ current = current[token]
+ except ValueError as e:
+ raise ValueError(
+                                f"invalid JSON pointer: list '{current_ptr}' requires numbers as keys, instead got '{token}'"
+ ) from e
+
+ elif isinstance(current, dict):
+ current = current.get(token, None)
+
+ current_ptr += f"/{token}"
+
+ return parent, current, token
+
+
+def json_ptr_resolve(
+ obj: JSONPtrAddressable,
+ ptr: str,
+) -> Tuple[Optional[JSONPtrAddressable], Optional[JSONPtrAddressable], Union[str, int, None]]:
+ return _JSONPtr(ptr).resolve(obj)
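+
+
+# Illustrative resolution (a sketch; the document below is hypothetical):
+#
+#     doc = {"dns": {"servers": ["192.0.2.1", "192.0.2.2"]}}
+#     parent, value, token = json_ptr_resolve(doc, "/dns/servers/1")
+#     # parent is the inner list, value == "192.0.2.2", token == 1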
diff --git a/python/knot_resolver/utils/modeling/parsing.py b/python/knot_resolver/utils/modeling/parsing.py
new file mode 100644
index 00000000..185a53a1
--- /dev/null
+++ b/python/knot_resolver/utils/modeling/parsing.py
@@ -0,0 +1,99 @@
+import json
+from enum import Enum, auto
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import yaml
+from yaml.constructor import ConstructorError
+from yaml.nodes import MappingNode
+
+from .exceptions import DataParsingError
+from .renaming import Renamed, renamed
+
+
+# custom hook for 'json.loads()' to detect duplicate keys in data
+# source: https://stackoverflow.com/q/14902299/12858520
+def _json_raise_duplicates(pairs: List[Tuple[Any, Any]]) -> Optional[Any]:
+ dict_out: Dict[Any, Any] = {}
+ for key, val in pairs:
+ if key in dict_out:
+ raise DataParsingError(f"Duplicate attribute key detected: {key}")
+ dict_out[key] = val
+ return dict_out
+
+
+# custom loader for 'yaml.load()' to detect duplicate keys in data
+# source: https://gist.github.com/pypt/94d747fe5180851196eb
+class _RaiseDuplicatesLoader(yaml.SafeLoader):
+ def construct_mapping(self, node: Union[MappingNode, Any], deep: bool = False) -> Dict[Any, Any]:
+ if not isinstance(node, MappingNode):
+ raise ConstructorError(None, None, f"expected a mapping node, but found {node.id}", node.start_mark)
+ mapping: Dict[Any, Any] = {}
+ for key_node, value_node in node.value:
+ key = self.construct_object(key_node, deep=deep) # type: ignore
+ # we need to check, that the key object can be used in a hash table
+ try:
+ _ = hash(key) # type: ignore
+ except TypeError as exc:
+ raise ConstructorError(
+ "while constructing a mapping",
+ node.start_mark,
+ f"found unacceptable key ({exc})",
+ key_node.start_mark,
+ ) from exc
+
+ # check for duplicate keys
+ if key in mapping:
+ raise DataParsingError(f"duplicate key detected: {key_node.start_mark}")
+ value = self.construct_object(value_node, deep=deep) # type: ignore
+ mapping[key] = value
+ return mapping
+
+
+class DataFormat(Enum):
+ YAML = auto()
+ JSON = auto()
+
+ def parse_to_dict(self, text: str) -> Any:
+ if self is DataFormat.YAML:
+ # RaiseDuplicatesLoader extends yaml.SafeLoader, so this should be safe
+ # https://python.land/data-processing/python-yaml#PyYAML_safe_load_vs_load
+ return renamed(yaml.load(text, Loader=_RaiseDuplicatesLoader)) # type: ignore
+ elif self is DataFormat.JSON:
+ return renamed(json.loads(text, object_pairs_hook=_json_raise_duplicates))
+ else:
+ raise NotImplementedError(f"Parsing of format '{self}' is not implemented")
+
+ def dict_dump(self, data: Union[Dict[str, Any], Renamed], indent: Optional[int] = None) -> str:
+ if isinstance(data, Renamed):
+ data = data.original()
+
+ if self is DataFormat.YAML:
+ return yaml.safe_dump(data, indent=indent) # type: ignore
+ elif self is DataFormat.JSON:
+ return json.dumps(data, indent=indent)
+ else:
+ raise NotImplementedError(f"Exporting to '{self}' format is not implemented")
+
+
+def parse_yaml(data: str) -> Any:
+ return DataFormat.YAML.parse_to_dict(data)
+
+
+def parse_json(data: str) -> Any:
+ return DataFormat.JSON.parse_to_dict(data)
+
+
+def try_to_parse(data: str) -> Any:
+ """Attempt to parse the data as a JSON or YAML string."""
+
+ try:
+ return parse_json(data)
+ except json.JSONDecodeError as je:
+ try:
+ return parse_yaml(data)
+ except yaml.YAMLError as ye:
+ # We do not raise-from here because there are two possible causes
+ # and we may not know which one is the actual one.
+ raise DataParsingError( # pylint: disable=raise-missing-from
+ f"failed to parse data, JSON: {je}, YAML: {ye}"
+ )
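+
+
+# Illustrative usage (a sketch, not part of this module): both formats go through the
+# same DataFormat enum and keys are dynamically renamed on access, e.g.
+#
+#     cfg = try_to_parse('{"dns-servers": ["192.0.2.1"]}')
+#     assert cfg["dns_servers"] == ["192.0.2.1"]
+#     print(DataFormat.YAML.dict_dump(cfg))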
diff --git a/python/knot_resolver/utils/modeling/query.py b/python/knot_resolver/utils/modeling/query.py
new file mode 100644
index 00000000..aab0be0e
--- /dev/null
+++ b/python/knot_resolver/utils/modeling/query.py
@@ -0,0 +1,183 @@
+import copy
+from abc import ABC, abstractmethod # pylint: disable=[no-name-in-module]
+from typing import Any, List, Optional, Tuple, Union
+
+from typing_extensions import Literal
+
+from knot_resolver.utils.modeling.base_schema import BaseSchema, map_object
+from knot_resolver.utils.modeling.json_pointer import json_ptr_resolve
+
+
+class PatchError(Exception):
+ pass
+
+
+class Op(BaseSchema, ABC):
+ @abstractmethod
+ def eval(self, fakeroot: Any) -> Any:
+ """
+        Modifies the given fakeroot and returns a new one.
+ """
+
+ def _resolve_ptr(self, fakeroot: Any, ptr: str) -> Tuple[Any, Any, Union[str, int, None]]:
+ # Lookup tree part based on the given JSON pointer
+ parent, obj, token = json_ptr_resolve(fakeroot["root"], ptr)
+
+ # the lookup was on pure data, wrap the results in QueryTree
+        # the pointer referenced the root itself, so use the fakeroot as the parent
+ parent = fakeroot
+ token = "root"
+
+ assert token is not None
+
+ return parent, obj, token
+
+
+class AddOp(Op):
+ op: Literal["add"]
+ path: str
+ value: Any
+
+ def eval(self, fakeroot: Any) -> Any:
+ parent, _obj, token = self._resolve_ptr(fakeroot, self.path)
+
+ if isinstance(parent, dict):
+ parent[token] = self.value
+ elif isinstance(parent, list):
+ if token == "-":
+ parent.append(self.value)
+ else:
+ assert isinstance(token, int)
+ parent.insert(token, self.value)
+ else:
+ assert False, "never happens"
+
+ return fakeroot
+
+
+class RemoveOp(Op):
+ op: Literal["remove"]
+ path: str
+
+ def eval(self, fakeroot: Any) -> Any:
+ parent, _obj, token = self._resolve_ptr(fakeroot, self.path)
+ del parent[token]
+ return fakeroot
+
+
+class ReplaceOp(Op):
+ op: Literal["replace"]
+ path: str
+ value: str
+
+ def eval(self, fakeroot: Any) -> Any:
+ parent, obj, token = self._resolve_ptr(fakeroot, self.path)
+
+ if obj is None:
+ raise PatchError("the value you are trying to replace is null")
+ parent[token] = self.value
+ return fakeroot
+
+
+class MoveOp(Op):
+ op: Literal["move"]
+ source: str
+ path: str
+
+ def _source(self, source):
+ if "from" not in source:
+ raise ValueError("missing property 'from' in 'move' JSON patch operation")
+ return str(source["from"])
+
+ def eval(self, fakeroot: Any) -> Any:
+ if self.path.startswith(self.source):
+ raise PatchError("can't move value into itself")
+
+ _parent, obj, _token = self._resolve_ptr(fakeroot, self.source)
+ newobj = copy.deepcopy(obj)
+
+ fakeroot = RemoveOp({"op": "remove", "path": self.source}).eval(fakeroot)
+ fakeroot = AddOp({"path": self.path, "value": newobj, "op": "add"}).eval(fakeroot)
+ return fakeroot
+
+
+class CopyOp(Op):
+ op: Literal["copy"]
+ source: str
+ path: str
+
+ def _source(self, source):
+ if "from" not in source:
+ raise ValueError("missing property 'from' in 'copy' JSON patch operation")
+ return str(source["from"])
+
+ def eval(self, fakeroot: Any) -> Any:
+ _parent, obj, _token = self._resolve_ptr(fakeroot, self.source)
+ newobj = copy.deepcopy(obj)
+
+ fakeroot = AddOp({"path": self.path, "value": newobj, "op": "add"}).eval(fakeroot)
+ return fakeroot
+
+
+class TestOp(Op):
+ op: Literal["test"]
+ path: str
+ value: Any
+
+ def eval(self, fakeroot: Any) -> Any:
+ _parent, obj, _token = self._resolve_ptr(fakeroot, self.path)
+
+ if obj != self.value:
+ raise PatchError("test failed")
+
+ return fakeroot
+
+
+def query(
+ original: Any, method: Literal["get", "delete", "put", "patch"], ptr: str, payload: Any
+) -> Tuple[Any, Optional[Any]]:
+ ########################################
+ # Prepare data we will be working on
+
+    # First of all, we consider the original data to be immutable, so we make a copy
+    # that we can freely mutate.
+ dataroot = copy.deepcopy(original)
+
+ # To simplify referencing the root, create a fake root node
+ fakeroot = {"root": dataroot}
+
+ #########################################
+ # Handle the actual requested operation
+
+ # get = return what the path selector picks
+ if method == "get":
+ parent, obj, token = json_ptr_resolve(fakeroot, f"/root{ptr}")
+ return fakeroot["root"], obj
+
+ elif method == "delete":
+ fakeroot = RemoveOp({"op": "remove", "path": ptr}).eval(fakeroot)
+ return fakeroot["root"], None
+
+ elif method == "put":
+ parent, obj, token = json_ptr_resolve(fakeroot, f"/root{ptr}")
+ assert parent is not None # we know this due to the fakeroot
+ if isinstance(parent, list) and token == "-":
+ parent.append(payload)
+ else:
+ parent[token] = payload
+ return fakeroot["root"], None
+
+ elif method == "patch":
+ tp = List[Union[AddOp, RemoveOp, MoveOp, CopyOp, TestOp, ReplaceOp]]
+ transaction: tp = map_object(tp, payload)
+
+ for i, op in enumerate(transaction):
+ try:
+ fakeroot = op.eval(fakeroot)
+ except PatchError as e:
+ raise ValueError(f"json patch transaction failed on step {i}") from e
+
+ return fakeroot["root"], None
+
+ else:
+ assert False, "invalid operation, never happens"
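+
+
+# Illustrative usage (a sketch; the document below is hypothetical): applying a JSON patch
+# transaction without mutating the original document, e.g.
+#
+#     doc = {"workers": 1, "cache": {"size": "100M"}}
+#     new_doc, _ = query(doc, "patch", "", [
+#         {"op": "test", "path": "/workers", "value": 1},
+#         {"op": "replace", "path": "/cache/size", "value": "200M"},
+#     ])
+#     # new_doc["cache"]["size"] == "200M", doc is left untouched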
diff --git a/python/knot_resolver/utils/modeling/renaming.py b/python/knot_resolver/utils/modeling/renaming.py
new file mode 100644
index 00000000..2420ed04
--- /dev/null
+++ b/python/knot_resolver/utils/modeling/renaming.py
@@ -0,0 +1,90 @@
+"""
+This module implements standard dict and list alternatives which can dynamically rename their keys, replacing `-` with `_`.
+They persist in nested data structures, meaning that if you obtain a dict from a Renamed variant, you will actually
+get a RenamedDict back instead.
+
+Usage:
+
+d = dict()
+l = list()
+
+rd = renamed(d)
+rl = renamed(l)
+
+assert isinstance(rd, Renamed)
+assert l == rl.original()
+"""
+
+from abc import ABC, abstractmethod # pylint: disable=[no-name-in-module]
+from typing import Any, Dict, List, TypeVar
+
+
+class Renamed(ABC):
+ @abstractmethod
+ def original(self) -> Any:
+ """
+ Returns a data structure, which is the source without dynamic renamings
+ """
+
+ @staticmethod
+ def map_public_to_private(name: Any) -> Any:
+ if isinstance(name, str):
+ return name.replace("_", "-")
+ return name
+
+ @staticmethod
+ def map_private_to_public(name: Any) -> Any:
+ if isinstance(name, str):
+ return name.replace("-", "_")
+ return name
+
+
+K = TypeVar("K")
+V = TypeVar("V")
+
+
+class RenamedDict(Dict[K, V], Renamed):
+ def keys(self) -> Any:
+ keys = super().keys()
+ return {Renamed.map_private_to_public(key) for key in keys}
+
+ def __getitem__(self, key: K) -> V:
+ key = Renamed.map_public_to_private(key)
+ res = super().__getitem__(key)
+ return renamed(res)
+
+ def __setitem__(self, key: K, value: V) -> None:
+ key = Renamed.map_public_to_private(key)
+ return super().__setitem__(key, value)
+
+ def __contains__(self, key: object) -> bool:
+ key = Renamed.map_public_to_private(key)
+ return super().__contains__(key)
+
+ def items(self) -> Any:
+ for k, v in super().items():
+ yield Renamed.map_private_to_public(k), renamed(v)
+
+ def original(self) -> Dict[K, V]:
+ return dict(super().items())
+
+
+class RenamedList(List[V], Renamed): # type: ignore
+ def __getitem__(self, key: Any) -> Any:
+ res = super().__getitem__(key)
+ return renamed(res)
+
+ def original(self) -> Any:
+ return list(super().__iter__())
+
+
+def renamed(obj: Any) -> Any:
+ if isinstance(obj, dict):
+ return RenamedDict(**obj)
+ elif isinstance(obj, list):
+ return RenamedList(obj)
+ else:
+ return obj
+
+
+__all__ = ["renamed", "Renamed"]
diff --git a/python/knot_resolver/utils/modeling/types.py b/python/knot_resolver/utils/modeling/types.py
new file mode 100644
index 00000000..4ce9aecc
--- /dev/null
+++ b/python/knot_resolver/utils/modeling/types.py
@@ -0,0 +1,105 @@
+# pylint: disable=comparison-with-callable
+
+
+import enum
+import inspect
+import sys
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
+
+from typing_extensions import Literal
+
+from .base_generic_type_wrapper import BaseGenericTypeWrapper
+
+NoneType = type(None)
+
+
+def is_optional(tp: Any) -> bool:
+ origin = getattr(tp, "__origin__", None)
+ args = get_generic_type_arguments(tp)
+
+ return origin == Union and len(args) == 2 and args[1] == NoneType # type: ignore
+
+
+def is_dict(tp: Any) -> bool:
+ return getattr(tp, "__origin__", None) in (Dict, dict)
+
+
+def is_enum(tp: Any) -> bool:
+ return inspect.isclass(tp) and issubclass(tp, enum.Enum)
+
+
+def is_list(tp: Any) -> bool:
+ return getattr(tp, "__origin__", None) in (List, list)
+
+
+def is_tuple(tp: Any) -> bool:
+ return getattr(tp, "__origin__", None) in (Tuple, tuple)
+
+
+def is_union(tp: Any) -> bool:
+ """Returns true even for optional types, because they are just a Union[T, NoneType]"""
+ return getattr(tp, "__origin__", None) == Union # type: ignore
+
+
+def is_literal(tp: Any) -> bool:
+ if sys.version_info.minor == 6:
+ return isinstance(tp, type(Literal))
+ else:
+ return getattr(tp, "__origin__", None) == Literal
+
+
+def is_generic_type_wrapper(tp: Any) -> bool:
+ orig = getattr(tp, "__origin__", None)
+ return inspect.isclass(orig) and issubclass(orig, BaseGenericTypeWrapper)
+
+
+def get_generic_type_arguments(tp: Any) -> List[Any]:
+ default: List[Any] = []
+ if sys.version_info.minor == 6 and is_literal(tp):
+ return getattr(tp, "__values__")
+ else:
+ return getattr(tp, "__args__", default)
+
+
+def get_generic_type_argument(tp: Any) -> Any:
+    """Same as get_generic_type_arguments, but expects exactly one type argument."""
+
+ args = get_generic_type_arguments(tp)
+ assert len(args) == 1
+ return args[0]
+
+
+def get_generic_type_wrapper_argument(tp: Type["BaseGenericTypeWrapper[Any]"]) -> Any:
+ assert hasattr(tp, "__origin__")
+ origin = getattr(tp, "__origin__")
+
+ assert hasattr(origin, "__orig_bases__")
+ orig_base: List[Any] = getattr(origin, "__orig_bases__", [])[0]
+
+ arg = get_generic_type_argument(tp)
+ return get_generic_type_argument(orig_base[arg])
+
+
+def is_none_type(tp: Any) -> bool:
+ return tp is None or tp == NoneType
+
+
+def get_attr_type(obj: Any, attr_name: str) -> Any:
+ assert hasattr(obj, attr_name)
+ assert hasattr(obj, "__annotations__")
+ annot = getattr(type(obj), "__annotations__")
+ assert attr_name in annot
+ return annot[attr_name]
+
+
+T = TypeVar("T")
+
+
+def get_optional_inner_type(optional: Type[Optional[T]]) -> Type[T]:
+ assert is_optional(optional)
+ t: Type[T] = get_generic_type_arguments(optional)[0]
+ return t
+
+
+def is_internal_field_name(field_name: str) -> bool:
+ return field_name.startswith("_")
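+
+
+# Illustrative behaviour of the helpers above (a sketch, not an exhaustive list):
+#
+#     assert is_optional(Optional[int]) and not is_optional(int)
+#     assert get_optional_inner_type(Optional[int]) is int
+#     assert is_union(Union[int, str]) and is_list(List[str])
+#     assert get_generic_type_argument(List[str]) is str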
diff --git a/python/knot_resolver/utils/requests.py b/python/knot_resolver/utils/requests.py
new file mode 100644
index 00000000..e52e54a3
--- /dev/null
+++ b/python/knot_resolver/utils/requests.py
@@ -0,0 +1,135 @@
+import errno
+import socket
+import sys
+from http.client import HTTPConnection
+from typing import Any, Optional, Union
+from urllib.error import HTTPError, URLError
+from urllib.parse import quote, unquote, urlparse
+from urllib.request import AbstractHTTPHandler, Request, build_opener, install_opener, urlopen
+
+from typing_extensions import Literal
+
+
+class SocketDesc:
+ def __init__(self, socket_def: str, source: str):
+ self.source = source
+ if ":" in socket_def:
+            # `socket_def` contains a scheme, probably already URI-formatted, use it directly
+ self.uri = socket_def
+ else:
+ # `socket_def` is probably a path, convert to URI
+ self.uri = f'http+unix://{quote(socket_def, safe="")}'
+
+ while self.uri.endswith("/"):
+ self.uri = self.uri[:-1]
+
+
+class Response:
+ def __init__(self, status: int, body: str) -> None:
+ self.status = status
+ self.body = body
+
+ def __repr__(self) -> str:
+ return f"status: {self.status}\nbody:\n{self.body}"
+
+
+def _print_conn_error(error_desc: str, url: str, socket_source: str) -> None:
+ host: str
+ try:
+ parsed_url = urlparse(url)
+ host = unquote(parsed_url.hostname or "(Unknown)")
+ except Exception as e:
+ host = f"(Invalid URL: {e})"
+ msg = f"""
+{error_desc}
+\tURL: {url}
+\tHost/Path: {host}
+\tSourced from: {socket_source}
+Is the URL correct?
+\tA unix socket URL would start with http+unix:// followed by a URL-encoded path.
+\tAn inet socket URL would start with http:// followed by a domain name or IP address.
+ """
+ print(msg, file=sys.stderr)
+
+
+def request(
+ socket_desc: SocketDesc,
+ method: Literal["GET", "POST", "HEAD", "PUT", "DELETE"],
+ path: str,
+ body: Optional[str] = None,
+ content_type: str = "application/json",
+) -> Response:
+ while path.startswith("/"):
+ path = path[1:]
+ url = f"{socket_desc.uri}/{path}"
+ req = Request(
+ url,
+ method=method,
+ data=body.encode("utf8") if body is not None else None,
+ headers={"Content-Type": content_type},
+ )
+ # req.add_header("Authorization", _authorization_header)
+
+ timeout_m = 5 # minutes
+ try:
+ with urlopen(req, timeout=timeout_m * 60) as response:
+ return Response(response.status, response.read().decode("utf8"))
+ except HTTPError as err:
+ return Response(err.code, err.read().decode("utf8"))
+ except URLError as err:
+ if err.errno == errno.ECONNREFUSED or isinstance(err.reason, ConnectionRefusedError):
+ _print_conn_error("Connection refused.", url, socket_desc.source)
+ elif err.errno == errno.ENOENT or isinstance(err.reason, FileNotFoundError):
+ _print_conn_error("No such file or directory.", url, socket_desc.source)
+ else:
+ _print_conn_error(str(err), url, socket_desc.source)
+ sys.exit(1)
+ except (TimeoutError, socket.timeout):
+ _print_conn_error(
+ f"Connection timed out after {timeout_m} minutes."
+ "\nIt does not mean that the operation necessarily failed."
+ "\nSee Knot Resolver's log for more information.",
+ url,
+ socket_desc.source,
+ )
+ sys.exit(1)
+
+
+# Code heavily inspired by requests-unixsocket
+# https://github.com/msabramo/requests-unixsocket/blob/master/requests_unixsocket/adapters.py
+class UnixHTTPConnection(HTTPConnection):
+ def __init__(self, unix_socket_url: str, timeout: Union[int, float] = 60):
+        """Create an HTTP connection to a unix domain socket.
+
+        :param unix_socket_url: A URL with a scheme of 'http+unix' where the
+            netloc is a percent-encoded path to a unix domain socket, e.g.
+            'http+unix://%2Ftmp%2Fprofilesvc.sock/status/pid'
+        """
+ super().__init__("localhost", timeout=timeout)
+ self.unix_socket_path = unix_socket_url
+ self.timeout = timeout
+ self.sock: Optional[socket.socket] = None
+
+ def __del__(self): # base class does not have d'tor
+ if self.sock:
+ self.sock.close()
+
+ def connect(self):
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ sock.settimeout(self.timeout)
+ sock.connect(self.unix_socket_path)
+ self.sock = sock
+
+
+class UnixHTTPHandler(AbstractHTTPHandler):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def open_(self: UnixHTTPHandler, req: Any) -> Any:
+ return self.do_open(UnixHTTPConnection, req) # type: ignore[arg-type]
+
+ setattr(UnixHTTPHandler, "http+unix_open", open_)
+ setattr(UnixHTTPHandler, "http+unix_request", AbstractHTTPHandler.do_request_)
+
+
+opener = build_opener(UnixHTTPHandler())
+install_opener(opener)
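+
+
+# Illustrative usage (a sketch; the socket path and API path below are hypothetical):
+#
+#     sock = SocketDesc("/run/knot-resolver/manager.sock", "example")
+#     resp = request(sock, "GET", "v1/config")
+#     print(resp.status, resp.body)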
diff --git a/python/knot_resolver/utils/systemd_notify.py b/python/knot_resolver/utils/systemd_notify.py
new file mode 100644
index 00000000..44e8dee1
--- /dev/null
+++ b/python/knot_resolver/utils/systemd_notify.py
@@ -0,0 +1,54 @@
+import enum
+import logging
+import os
+import socket
+
+logger = logging.getLogger(__name__)
+
+
+class _Status(enum.Enum):
+ NOT_INITIALIZED = 1
+ FUNCTIONAL = 2
+ FAILED = 3
+
+
+_status = _Status.NOT_INITIALIZED
+_socket = None
+
+
+def systemd_notify(**values: str) -> None:
+ global _status
+ global _socket
+
+ if _status is _Status.NOT_INITIALIZED:
+ socket_addr = os.getenv("NOTIFY_SOCKET")
+ os.unsetenv("NOTIFY_SOCKET")
+ if socket_addr is None:
+ _status = _Status.FAILED
+ return
+ if socket_addr.startswith("@"):
+ socket_addr = socket_addr.replace("@", "\0", 1)
+
+ try:
+ _socket = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+ _socket.connect(socket_addr)
+ _status = _Status.FUNCTIONAL
+ except Exception:
+ _socket = None
+ _status = _Status.FAILED
+ logger.warning(f"Failed to connect to $NOTIFY_SOCKET at '{socket_addr}'", exc_info=True)
+ return
+
+ elif _status is _Status.FAILED:
+ return
+
+ if _status is _Status.FUNCTIONAL:
+ assert _socket is not None
+ payload = "\n".join((f"{key}={value}" for key, value in values.items()))
+ try:
+ _socket.send(payload.encode("utf8"))
+ except Exception:
+ logger.warning("Failed to send notification to systemd", exc_info=True)
+ _status = _Status.FAILED
+ _socket.close()
+ _socket = None
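+
+
+# Illustrative usage (a sketch, not part of this module): under a systemd service with
+# Type=notify, readiness and status can be reported like this:
+#
+#     systemd_notify(READY="1", STATUS="configuration loaded")
+#
+# When $NOTIFY_SOCKET is not set (i.e. not running under systemd), the call is a no-op.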
diff --git a/python/knot_resolver/utils/which.py b/python/knot_resolver/utils/which.py
new file mode 100644
index 00000000..450102f3
--- /dev/null
+++ b/python/knot_resolver/utils/which.py
@@ -0,0 +1,22 @@
+import functools
+import os
+from pathlib import Path
+
+
+@functools.lru_cache(maxsize=16)
+def which(binary_name: str) -> Path:
+ """
+    Given the name of an executable, search $PATH and return
+    the absolute path of that executable. The results of this function
+    are LRU cached.
+
+    If the executable is not found, a RuntimeError is raised.
+ """
+
+ possible_directories = os.get_exec_path()
+ for dr in possible_directories:
+ p = Path(dr, binary_name)
+ if p.exists():
+ return p.absolute()
+
+ raise RuntimeError(f"Executable {binary_name} was not found in $PATH")
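+
+
+# Illustrative usage (a sketch):
+#
+#     kresd = which("kresd")           # e.g. Path("/usr/sbin/kresd"), if installed
+#     which("no-such-binary")          # raises RuntimeError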