Diffstat (limited to 'python/knot_resolver')
108 files changed, 11180 insertions, 0 deletions
diff --git a/python/knot_resolver/__init__.py b/python/knot_resolver/__init__.py new file mode 100644 index 00000000..511e8d44 --- /dev/null +++ b/python/knot_resolver/__init__.py @@ -0,0 +1,5 @@ +from .datamodel.config_schema import KresConfig + +__version__ = "0.1.0" + +__all__ = ["KresConfig"] diff --git a/python/knot_resolver/client/__init__.py b/python/knot_resolver/client/__init__.py new file mode 100644 index 00000000..5b82d3be --- /dev/null +++ b/python/knot_resolver/client/__init__.py @@ -0,0 +1,5 @@ +from pathlib import Path + +from knot_resolver.datamodel.globals import Context, set_global_validation_context + +set_global_validation_context(Context(Path("."), False)) diff --git a/python/knot_resolver/client/__main__.py b/python/knot_resolver/client/__main__.py new file mode 100644 index 00000000..56200674 --- /dev/null +++ b/python/knot_resolver/client/__main__.py @@ -0,0 +1,4 @@ +from knot_resolver.client.main import main + +if __name__ == "__main__": + main() diff --git a/python/knot_resolver/client/client.py b/python/knot_resolver/client/client.py new file mode 100644 index 00000000..4e7d13ea --- /dev/null +++ b/python/knot_resolver/client/client.py @@ -0,0 +1,51 @@ +import argparse + +from knot_resolver.client.command import CommandArgs + +KRES_CLIENT_NAME = "kresctl" + + +class KresClient: + def __init__( + self, + namespace: argparse.Namespace, + parser: argparse.ArgumentParser, + prompt: str = KRES_CLIENT_NAME, + ) -> None: + self.path = None + self.prompt = prompt + self.namespace = namespace + self.parser = parser + + def execute(self): + if hasattr(self.namespace, "command"): + args = CommandArgs(self.namespace, self.parser) + command = args.command(self.namespace) + command.run(args) + else: + self.parser.print_help() + + def _prompt_format(self) -> str: + bolt = "\033[1m" + white = "\033[38;5;255m" + reset = "\033[0;0m" + + if self.path: + prompt = f"{bolt}[{self.prompt} {white}{self.path}{reset}{bolt}]" + else: + prompt = f"{bolt}{self.prompt}" + return f"{prompt}> {reset}" + + def interactive(self): + try: + while True: + pass + # TODO: not working yet + # cmd = input(f"{self._prompt_format()}") + # namespace = self.parser.parse_args(cmd.split(" ")) + # namespace.interactive = True + # namespace.socket = self.namespace.socket + # self.namespace = namespace + # self.execute() + except KeyboardInterrupt: + pass diff --git a/python/knot_resolver/client/command.py b/python/knot_resolver/client/command.py new file mode 100644 index 00000000..af59c42e --- /dev/null +++ b/python/knot_resolver/client/command.py @@ -0,0 +1,125 @@ +import argparse +import os +from abc import ABC, abstractmethod # pylint: disable=[no-name-in-module] +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Type, TypeVar +from urllib.parse import quote + +from knot_resolver.manager.constants import API_SOCK_ENV_VAR, CONFIG_FILE_ENV_VAR, DEFAULT_MANAGER_CONFIG_FILE +from knot_resolver.datamodel.config_schema import DEFAULT_MANAGER_API_SOCK +from knot_resolver.datamodel.types import IPAddressPort +from knot_resolver.utils.modeling import parsing +from knot_resolver.utils.modeling.exceptions import DataValidationError +from knot_resolver.utils.requests import SocketDesc + +T = TypeVar("T", bound=Type["Command"]) + +CompWords = Dict[str, Optional[str]] + +_registered_commands: List[Type["Command"]] = [] + + +def register_command(cls: T) -> T: + _registered_commands.append(cls) + return cls + + +def get_help_command() -> Type["Command"]: + for command in _registered_commands: + if 
command.__name__ == "HelpCommand": + return command + raise ValueError("missing HelpCommand") + + +def install_commands_parsers(parser: argparse.ArgumentParser) -> None: + subparsers = parser.add_subparsers(help="command type") + for command in _registered_commands: + subparser, typ = command.register_args_subparser(subparsers) + subparser.set_defaults(command=typ, subparser=subparser) + + +def get_socket_from_config(config: Path, optional_file: bool) -> Optional[SocketDesc]: + try: + with open(config, "r", encoding="utf8") as f: + data = parsing.try_to_parse(f.read()) + mkey = "management" + if mkey in data: + management = data[mkey] + if "unix-socket" in management: + return SocketDesc( + f'http+unix://{quote(management["unix-socket"], safe="")}/', + f'Key "/management/unix-socket" in "{config}" file', + ) + elif "interface" in management: + ip = IPAddressPort(management["interface"], object_path=f"/{mkey}/interface") + return SocketDesc( + f"http://{ip.addr}:{ip.port}", + f'Key "/management/interface" in "{config}" file', + ) + return None + except ValueError as e: + raise DataValidationError(*e.args) from e # pylint: disable=no-value-for-parameter + except OSError as e: + if not optional_file: + raise e + return None + + +def determine_socket(namespace: argparse.Namespace) -> SocketDesc: + # 1) socket from '--socket' argument + if len(namespace.socket) > 0: + return SocketDesc(namespace.socket[0], "--socket argument") + + config_path = os.getenv(CONFIG_FILE_ENV_VAR) + socket_env = os.getenv(API_SOCK_ENV_VAR) + + socket: Optional[SocketDesc] = None + # 2) socket from config file ('--config' argument) + if len(namespace.config) > 0: + socket = get_socket_from_config(namespace.config[0], False) + # 3) socket from config file (environment variable) + elif config_path: + socket = get_socket_from_config(Path(config_path), False) + # 4) socket from environment variable + elif socket_env: + socket = SocketDesc(socket_env, f'Environment variable "{API_SOCK_ENV_VAR}"') + # 5) socket from config file (default config file constant) + else: + socket = get_socket_from_config(DEFAULT_MANAGER_CONFIG_FILE, True) + + if socket: + return socket + # 6) socket default + return SocketDesc(DEFAULT_MANAGER_API_SOCK, f'Default value "{DEFAULT_MANAGER_API_SOCK}"') + + +class CommandArgs: + def __init__(self, namespace: argparse.Namespace, parser: argparse.ArgumentParser) -> None: + self.namespace = namespace + self.parser = parser + self.subparser: argparse.ArgumentParser = namespace.subparser + self.command: Type["Command"] = namespace.command + + self.socket: SocketDesc = determine_socket(namespace) + + +class Command(ABC): + @staticmethod + @abstractmethod + def register_args_subparser( + subparser: "argparse._SubParsersAction[argparse.ArgumentParser]", + ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]: + raise NotImplementedError() + + @abstractmethod + def __init__(self, namespace: argparse.Namespace) -> None: # pylint: disable=[unused-argument] + super().__init__() + + @abstractmethod + def run(self, args: CommandArgs) -> None: + raise NotImplementedError() + + @staticmethod + @abstractmethod + def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords: + raise NotImplementedError() diff --git a/python/knot_resolver/client/commands/cache.py b/python/knot_resolver/client/commands/cache.py new file mode 100644 index 00000000..60417eec --- /dev/null +++ b/python/knot_resolver/client/commands/cache.py @@ -0,0 +1,123 @@ +import argparse +import sys +from enum import Enum +from typing 
import Any, Dict, List, Optional, Tuple, Type + +from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command +from knot_resolver.datamodel.cache_schema import CacheClearRPCSchema +from knot_resolver.utils.modeling.exceptions import AggregateDataValidationError, DataValidationError +from knot_resolver.utils.modeling.parsing import DataFormat, parse_json +from knot_resolver.utils.requests import request + + +class CacheOperations(Enum): + CLEAR = 0 + + +@register_command +class CacheCommand(Command): + def __init__(self, namespace: argparse.Namespace) -> None: + super().__init__(namespace) + self.operation: Optional[CacheOperations] = namespace.operation if hasattr(namespace, "operation") else None + self.output_format: DataFormat = ( + namespace.output_format if hasattr(namespace, "output_format") else DataFormat.YAML + ) + + # CLEAR operation + self.clear_dict: Dict[str, Any] = {} + if hasattr(namespace, "exact_name"): + self.clear_dict["exact-name"] = namespace.exact_name + if hasattr(namespace, "name"): + self.clear_dict["name"] = namespace.name + if hasattr(namespace, "rr_type"): + self.clear_dict["rr-type"] = namespace.rr_type + if hasattr(namespace, "chunk_size"): + self.clear_dict["chunk-size"] = namespace.chunk_size + + @staticmethod + def register_args_subparser( + subparser: "argparse._SubParsersAction[argparse.ArgumentParser]", + ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]: + cache_parser = subparser.add_parser("cache", help="Performs operations on the cache of the running resolver.") + + config_subparsers = cache_parser.add_subparsers(help="operation type") + + # 'clear' operation + clear_subparser = config_subparsers.add_parser( + "clear", help="Purge cache records that match specified criteria." + ) + clear_subparser.set_defaults(operation=CacheOperations.CLEAR, exact_name=False) + clear_subparser.add_argument( + "--exact-name", + help="If set, only records with the same name are purged.", + action="store_true", + dest="exact_name", + ) + clear_subparser.add_argument( + "--rr-type", + help="Optional, the resource record type to purge. It is supported only with the '--exact-name' flag set.", + action="store", + type=str, + ) + clear_subparser.add_argument( + "--chunk-size", + help="Optional, the number of records to remove in one round; the default is 100." + " The purpose is not to block the resolver for long." + " The resolver repeats the cache clearing after one millisecond until all matching data is cleared.", + action="store", + type=int, + default=100, + ) + clear_subparser.add_argument( + "name", + type=str, + nargs="?", + help="Optional, subtree name to purge; if omitted, the entire cache is purged (and all other parameters are ignored).", + default=None, + ) + + output_format = clear_subparser.add_mutually_exclusive_group() + output_format_default = DataFormat.YAML + output_format.add_argument( + "--json", + help="Set JSON as the output format.", + const=DataFormat.JSON, + action="store_const", + dest="output_format", + default=output_format_default, + ) + output_format.add_argument( + "--yaml", + help="Set YAML as the output format. 
YAML is the default.", + const=DataFormat.YAML, + action="store_const", + dest="output_format", + default=output_format_default, + ) + + return cache_parser, CacheCommand + + @staticmethod + def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords: + return {} + + def run(self, args: CommandArgs) -> None: + if not self.operation: + args.subparser.print_help() + sys.exit() + + if self.operation == CacheOperations.CLEAR: + try: + validated = CacheClearRPCSchema(self.clear_dict) + except (AggregateDataValidationError, DataValidationError) as e: + print(e, file=sys.stderr) + sys.exit(1) + + body: str = DataFormat.JSON.dict_dump(validated.get_unparsed_data()) + response = request(args.socket, "POST", "cache/clear", body) + body_dict = parse_json(response.body) + + if response.status != 200: + print(response, file=sys.stderr) + sys.exit(1) + print(self.output_format.dict_dump(body_dict, indent=4)) diff --git a/python/knot_resolver/client/commands/completion.py b/python/knot_resolver/client/commands/completion.py new file mode 100644 index 00000000..05fdded8 --- /dev/null +++ b/python/knot_resolver/client/commands/completion.py @@ -0,0 +1,95 @@ +import argparse +from enum import Enum +from typing import List, Tuple, Type + +from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command + + +class Shells(Enum): + BASH = 0 + FISH = 1 + + +@register_command +class CompletionCommand(Command): + def __init__(self, namespace: argparse.Namespace) -> None: + super().__init__(namespace) + self.shell: Shells = namespace.shell + self.space = namespace.space + self.comp_args: List[str] = namespace.comp_args + + if self.space: + self.comp_args.append("") + + @staticmethod + def register_args_subparser( + subparser: "argparse._SubParsersAction[argparse.ArgumentParser]", + ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]: + completion = subparser.add_parser("completion", help="commands auto-completion") + completion.add_argument( + "--space", + help="space after last word, returns all possible folowing options", + dest="space", + action="store_true", + default=False, + ) + completion.add_argument( + "comp_args", + type=str, + help="arguments to complete", + nargs="*", + ) + + shells_dest = "shell" + shells = completion.add_mutually_exclusive_group() + shells.add_argument("--bash", action="store_const", dest=shells_dest, const=Shells.BASH, default=Shells.BASH) + shells.add_argument("--fish", action="store_const", dest=shells_dest, const=Shells.FISH) + + return completion, CompletionCommand + + @staticmethod + def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords: + words: CompWords = {} + # for action in parser._actions: + # for opt in action.option_strings: + # words[opt] = action.help + # return words + return words + + def run(self, args: CommandArgs) -> None: + pass + # subparsers = args.parser._subparsers + # words: CompWords = {} + + # if subparsers: + # words = parser_words(subparsers._actions) + + # uargs = iter(self.comp_args) + # for uarg in uargs: + # subparser = subparser_by_name(uarg, subparsers._actions) # pylint: disable=W0212 + + # if subparser: + # cmd: Command = subparser_command(subparser) + # subparser_args = self.comp_args[self.comp_args.index(uarg) + 1 :] + # if subparser_args: + # words = cmd.completion(subparser_args, subparser) + # break + # elif uarg in ["-s", "--socket"]: + # # if arg is socket config, skip next arg + # next(uargs) + # continue + # elif uarg in words: + # # uarg is walid arg, continue + # 
continue + # else: + # raise ValueError(f"unknown argument: {uarg}") + + # # print completion words + # # based on required bash/fish shell format + # if self.shell == Shells.BASH: + # print(" ".join(words)) + # elif self.shell == Shells.FISH: + # # TODO: FISH completion implementation + # pass + # else: + # raise ValueError(f"unexpected value of {Shells}: {self.shell}") diff --git a/python/knot_resolver/client/commands/config.py b/python/knot_resolver/client/commands/config.py new file mode 100644 index 00000000..add17272 --- /dev/null +++ b/python/knot_resolver/client/commands/config.py @@ -0,0 +1,222 @@ +import argparse +import sys +from enum import Enum +from typing import List, Optional, Tuple, Type + +from typing_extensions import Literal + +from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command +from knot_resolver.utils.modeling.parsing import DataFormat, parse_json, try_to_parse +from knot_resolver.utils.requests import request + + +class Operations(Enum): + SET = 0 + DELETE = 1 + GET = 2 + + +def operation_to_method(operation: Operations) -> Literal["PUT", "GET", "DELETE"]: + if operation == Operations.SET: + return "PUT" + elif operation == Operations.DELETE: + return "DELETE" + return "GET" + + +# def _properties_words(props: Dict[str, Any]) -> CompWords: +# words: CompWords = {} +# for name, prop in props.items(): +# words[name] = prop["description"] if "description" in prop else None +# return words + + +# def _path_comp_words(node: str, nodes: List[str], props: Dict[str, Any]) -> CompWords: +# i = nodes.index(node) +# ln = len(nodes[i:]) + +# # if node is last in path, return all possible words on thi level +# if ln == 1: +# return _properties_words(props) +# # if node is valid +# elif node in props: +# node_schema = props[node] + +# if "anyOf" in node_schema: +# for item in node_schema["anyOf"]: +# print(item) + +# elif "type" not in node_schema: +# pass + +# elif node_schema["type"] == "array": +# if ln > 2: +# # skip index for item in array +# return _path_comp_words(nodes[i + 2], nodes, node_schema["items"]["properties"]) +# if "enum" in node_schema["items"]: +# print(node_schema["items"]["enum"]) +# return {"0": "first array item", "-": "last array item"} +# elif node_schema["type"] == "object": +# if "additionalProperties" in node_schema: +# print(node_schema) +# return _path_comp_words(nodes[i + 1], nodes, node_schema["properties"]) +# return {} + +# # arrays/lists must be handled sparately +# if node_schema["type"] == "array": +# if ln > 2: +# # skip index for item in array +# return _path_comp_words(nodes[i + 2], nodes, node_schema["items"]["properties"]) +# return {"0": "first array item", "-": "last array item"} +# return _path_comp_words(nodes[i + 1], nodes, node_schema["properties"]) +# else: +# # if node is not last or valid, value error +# raise ValueError(f"unknown config path node: {node}") + + +@register_command +class ConfigCommand(Command): + def __init__(self, namespace: argparse.Namespace) -> None: + super().__init__(namespace) + self.path: str = str(namespace.path) if hasattr(namespace, "path") else "" + self.format: DataFormat = namespace.format if hasattr(namespace, "format") else DataFormat.JSON + self.operation: Optional[Operations] = namespace.operation if hasattr(namespace, "operation") else None + self.file: Optional[str] = namespace.file if hasattr(namespace, "file") else None + + @staticmethod + def register_args_subparser( + subparser: "argparse._SubParsersAction[argparse.ArgumentParser]", + ) -> 
Tuple[argparse.ArgumentParser, "Type[Command]"]: + config = subparser.add_parser("config", help="Performs operations on the running resolver's configuration.") + path_help = "Optional, path (JSON pointer, RFC6901) to the configuration resources. By default, the entire configuration is selected." + + config_subparsers = config.add_subparsers(help="operation type") + + # GET operation + get = config_subparsers.add_parser("get", help="Get current configuration from the resolver.") + get.set_defaults(operation=Operations.GET, format=DataFormat.YAML) + + get.add_argument( + "-p", + "--path", + help=path_help, + action="store", + type=str, + default="", + ) + get.add_argument( + "file", + help="Optional, path to the file where to save exported configuration data. If not specified, data will be printed.", + type=str, + nargs="?", + ) + + get_formats = get.add_mutually_exclusive_group() + get_formats.add_argument( + "--json", + help="Get configuration data in JSON format.", + const=DataFormat.JSON, + action="store_const", + dest="format", + ) + get_formats.add_argument( + "--yaml", + help="Get configuration data in YAML format, default.", + const=DataFormat.YAML, + action="store_const", + dest="format", + ) + + # SET operation + set = config_subparsers.add_parser("set", help="Set new configuration for the resolver.") + set.set_defaults(operation=Operations.SET) + + set.add_argument( + "-p", + "--path", + help=path_help, + action="store", + type=str, + default="", + ) + + value_or_file = set.add_mutually_exclusive_group() + value_or_file.add_argument( + "file", + help="Optional, path to file with new configuraion.", + type=str, + nargs="?", + ) + value_or_file.add_argument( + "value", + help="Optional, new configuration value.", + type=str, + nargs="?", + ) + + # DELETE operation + delete = config_subparsers.add_parser( + "delete", help="Delete given configuration property or list item at the given index." 
+ ) + delete.set_defaults(operation=Operations.DELETE) + delete.add_argument( + "-p", + "--path", + help=path_help, + action="store", + type=str, + default="", + ) + + return config, ConfigCommand + + @staticmethod + def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords: + # words = parser_words(parser._actions) # pylint: disable=W0212 + + # for arg in args: + # if arg in words: + # continue + # elif arg.startswith("-"): + # return words + # elif arg == args[-1]: + # config_path = arg[1:].split("/") if arg.startswith("/") else arg.split("/") + # schema_props: Dict[str, Any] = KresConfig.json_schema()["properties"] + # return _path_comp_words(config_path[0], config_path, schema_props) + # else: + # break + return {} + + def run(self, args: CommandArgs) -> None: + if not self.operation: + args.subparser.print_help() + sys.exit() + + new_config = None + path = f"v1/config{self.path}" + method = operation_to_method(self.operation) + + if self.operation == Operations.SET: + if self.file: + try: + with open(self.file, "r") as f: + new_config = f.read() + except FileNotFoundError: + new_config = self.file + else: + # use STDIN also when file is not specified + new_config = input("Type new configuration: ") + + body = DataFormat.JSON.dict_dump(try_to_parse(new_config)) if new_config else None + response = request(args.socket, method, path, body) + + if response.status != 200: + print(response, file=sys.stderr) + sys.exit(1) + + if self.operation == Operations.GET and self.file: + with open(self.file, "w") as f: + f.write(self.format.dict_dump(parse_json(response.body), indent=4)) + print(f"saved to: {self.file}") + elif response.body: + print(self.format.dict_dump(parse_json(response.body), indent=4)) diff --git a/python/knot_resolver/client/commands/convert.py b/python/knot_resolver/client/commands/convert.py new file mode 100644 index 00000000..a25c5cd9 --- /dev/null +++ b/python/knot_resolver/client/commands/convert.py @@ -0,0 +1,85 @@ +import argparse +import sys +from pathlib import Path +from typing import List, Optional, Tuple, Type + +from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command +from knot_resolver.datamodel import KresConfig +from knot_resolver.datamodel.globals import ( + Context, + reset_global_validation_context, + set_global_validation_context, +) +from knot_resolver.utils.modeling import try_to_parse +from knot_resolver.utils.modeling.exceptions import DataParsingError, DataValidationError + + +@register_command +class ConvertCommand(Command): + def __init__(self, namespace: argparse.Namespace) -> None: + super().__init__(namespace) + self.input_file: str = namespace.input_file + self.output_file: Optional[str] = namespace.output_file + self.strict: bool = namespace.strict + self.type: str = namespace.type + + @staticmethod + def register_args_subparser( + subparser: "argparse._SubParsersAction[argparse.ArgumentParser]", + ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]: + convert = subparser.add_parser("convert", help="Converts JSON or YAML configuration to Lua script.") + convert.set_defaults(strict=True) + convert.add_argument( + "--no-strict", + help="Ignore strict rules during validation, e.g. 
path/file existence.", + action="store_false", + dest="strict", + ) + convert.add_argument( + "--type", help="The type of Lua script to generate", choices=["worker", "policy-loader"], default="worker" + ) + convert.add_argument( + "input_file", + type=str, + help="File with configuration in YAML or JSON format.", + ) + + convert.add_argument( + "output_file", + type=str, + nargs="?", + help="Optional, output file for converted configuration in Lua script. If not specified, converted configuration is printed.", + default=None, + ) + + return convert, ConvertCommand + + @staticmethod + def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords: + return {} + + def run(self, args: CommandArgs) -> None: + with open(self.input_file, "r") as f: + data = f.read() + + try: + parsed = try_to_parse(data) + set_global_validation_context(Context(Path(Path(self.input_file).parent), self.strict)) + + if self.type == "worker": + lua = KresConfig(parsed).render_lua() + elif self.type == "policy-loader": + lua = KresConfig(parsed).render_lua_policy() + else: + raise ValueError(f"Invalid self.type={self.type}") + + reset_global_validation_context() + except (DataParsingError, DataValidationError) as e: + print(e, file=sys.stderr) + sys.exit(1) + + if self.output_file: + with open(self.output_file, "w") as f: + f.write(lua) + else: + print(lua) diff --git a/python/knot_resolver/client/commands/help.py b/python/knot_resolver/client/commands/help.py new file mode 100644 index 00000000..87306c2a --- /dev/null +++ b/python/knot_resolver/client/commands/help.py @@ -0,0 +1,24 @@ +import argparse +from typing import List, Tuple, Type + +from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command + + +@register_command +class HelpCommand(Command): + def __init__(self, namespace: argparse.Namespace) -> None: + super().__init__(namespace) + + def run(self, args: CommandArgs) -> None: + args.parser.print_help() + + @staticmethod + def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords: + return {} + + @staticmethod + def register_args_subparser( + subparser: "argparse._SubParsersAction[argparse.ArgumentParser]", + ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]: + stop = subparser.add_parser("help", help="show this help message and exit") + return stop, HelpCommand diff --git a/python/knot_resolver/client/commands/metrics.py b/python/knot_resolver/client/commands/metrics.py new file mode 100644 index 00000000..058cad8b --- /dev/null +++ b/python/knot_resolver/client/commands/metrics.py @@ -0,0 +1,67 @@ +import argparse +import sys +from typing import List, Optional, Tuple, Type + +from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command +from knot_resolver.utils.modeling.parsing import DataFormat, parse_json +from knot_resolver.utils.requests import request + + +@register_command +class MetricsCommand(Command): + def __init__(self, namespace: argparse.Namespace) -> None: + self.file: Optional[str] = namespace.file + self.prometheus: bool = namespace.prometheus + + super().__init__(namespace) + + @staticmethod + def register_args_subparser( + subparser: "argparse._SubParsersAction[argparse.ArgumentParser]", + ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]: + metrics = subparser.add_parser( + "metrics", + help="Get aggregated metrics from the running resolver in JSON format (default) or optionally in Prometheus format." 
+ "\nThe 'prometheus-client' Python package needs to be installed if you wish to use the Prometheus format." + "\nRequires a connection to the management HTTP API.", + ) + + metrics.add_argument( + "--prometheus", + help="Get metrics in Prometheus format if dependencies are met in the resolver.", + action="store_true", + default=False, + ) + + metrics.add_argument( + "file", + help="Optional. The file into which metrics will be exported." + "\nIf not specified, the metrics are printed into stdout.", + nargs="?", + default=None, + ) + return metrics, MetricsCommand + + @staticmethod + def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords: + return {} + + def run(self, args: CommandArgs) -> None: + response = request(args.socket, "GET", "metrics/prometheus" if self.prometheus else "metrics/json") + + if response.status == 200: + if self.prometheus: + metrics = response.body + else: + metrics = DataFormat.JSON.dict_dump(parse_json(response.body), indent=4) + + if self.file: + with open(self.file, "w") as f: + f.write(metrics) + else: + print(metrics) + else: + print(response, file=sys.stderr) + if self.prometheus and response.status == 404: + print("Prometheus is unavailable due to missing optional dependencies", file=sys.stderr) + sys.exit(1) diff --git a/python/knot_resolver/client/commands/reload.py b/python/knot_resolver/client/commands/reload.py new file mode 100644 index 00000000..c1350fc5 --- /dev/null +++ b/python/knot_resolver/client/commands/reload.py @@ -0,0 +1,36 @@ +import argparse +import sys +from typing import List, Tuple, Type + +from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command +from knot_resolver.utils.requests import request + + +@register_command +class ReloadCommand(Command): + def __init__(self, namespace: argparse.Namespace) -> None: + super().__init__(namespace) + + @staticmethod + def register_args_subparser( + subparser: "argparse._SubParsersAction[argparse.ArgumentParser]", + ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]: + reload = subparser.add_parser( + "reload", + help="Tells the resolver to reload YAML configuration file." + " Old processes are replaced by new ones (with updated configuration) using rolling restarts." 
+ " So there will be no DNS service unavailability during reload operation.", + ) + + return reload, ReloadCommand + + @staticmethod + def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords: + return {} + + def run(self, args: CommandArgs) -> None: + response = request(args.socket, "POST", "reload") + + if response.status != 200: + print(response, file=sys.stderr) + sys.exit(1) diff --git a/python/knot_resolver/client/commands/schema.py b/python/knot_resolver/client/commands/schema.py new file mode 100644 index 00000000..f5538424 --- /dev/null +++ b/python/knot_resolver/client/commands/schema.py @@ -0,0 +1,55 @@ +import argparse +import json +import sys +from typing import List, Optional, Tuple, Type + +from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command +from knot_resolver.datamodel.config_schema import KresConfig +from knot_resolver.utils.requests import request + + +@register_command +class SchemaCommand(Command): + def __init__(self, namespace: argparse.Namespace) -> None: + super().__init__(namespace) + self.live: bool = namespace.live + self.file: Optional[str] = namespace.file + + @staticmethod + def register_args_subparser( + subparser: "argparse._SubParsersAction[argparse.ArgumentParser]", + ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]: + schema = subparser.add_parser( + "schema", help="Shows JSON-schema repersentation of the Knot Resolver's configuration." + ) + schema.add_argument( + "-l", + "--live", + help="Get configuration JSON-schema from the running resolver. Requires connection to the management API.", + action="store_true", + default=False, + ) + schema.add_argument("file", help="Optional, file where to export JSON-schema.", nargs="?", default=None) + + return schema, SchemaCommand + + @staticmethod + def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords: + return {} + # return parser_words(parser._actions) # pylint: disable=W0212 + + def run(self, args: CommandArgs) -> None: + if self.live: + response = request(args.socket, "GET", "schema") + if response.status != 200: + print(response, file=sys.stderr) + sys.exit(1) + schema = response.body + else: + schema = json.dumps(KresConfig.json_schema(), indent=4) + + if self.file: + with open(self.file, "w") as f: + f.write(schema) + else: + print(schema) diff --git a/python/knot_resolver/client/commands/stop.py b/python/knot_resolver/client/commands/stop.py new file mode 100644 index 00000000..35baf36c --- /dev/null +++ b/python/knot_resolver/client/commands/stop.py @@ -0,0 +1,32 @@ +import argparse +import sys +from typing import List, Tuple, Type + +from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command +from knot_resolver.utils.requests import request + + +@register_command +class StopCommand(Command): + def __init__(self, namespace: argparse.Namespace) -> None: + super().__init__(namespace) + + @staticmethod + def register_args_subparser( + subparser: "argparse._SubParsersAction[argparse.ArgumentParser]", + ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]: + stop = subparser.add_parser( + "stop", help="Tells the resolver to shutdown everthing. No process will run after this command." 
+ ) + return stop, StopCommand + + def run(self, args: CommandArgs) -> None: + response = request(args.socket, "POST", "stop") + + if response.status != 200: + print(response, file=sys.stderr) + sys.exit(1) + + @staticmethod + def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords: + return {} diff --git a/python/knot_resolver/client/commands/validate.py b/python/knot_resolver/client/commands/validate.py new file mode 100644 index 00000000..267bf562 --- /dev/null +++ b/python/knot_resolver/client/commands/validate.py @@ -0,0 +1,63 @@ +import argparse +import sys +from pathlib import Path +from typing import List, Tuple, Type + +from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command +from knot_resolver.datamodel import KresConfig +from knot_resolver.datamodel.globals import ( + Context, + reset_global_validation_context, + set_global_validation_context, +) +from knot_resolver.utils.modeling import try_to_parse +from knot_resolver.utils.modeling.exceptions import DataParsingError, DataValidationError + + +@register_command +class ValidateCommand(Command): + def __init__(self, namespace: argparse.Namespace) -> None: + super().__init__(namespace) + self.input_file: str = namespace.input_file + self.strict: bool = namespace.strict + + @staticmethod + def register_args_subparser( + subparser: "argparse._SubParsersAction[argparse.ArgumentParser]", + ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]: + validate = subparser.add_parser("validate", help="Validates configuration in JSON or YAML format.") + validate.set_defaults(strict=True) + validate.add_argument( + "--no-strict", + help="Ignore strict rules during validation, e.g. path/file existence.", + action="store_false", + dest="strict", + ) + validate.add_argument( + "input_file", + type=str, + nargs="?", + help="File with configuration in YAML or JSON format.", + default=None, + ) + + return validate, ValidateCommand + + @staticmethod + def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords: + return {} + + def run(self, args: CommandArgs) -> None: + if self.input_file: + with open(self.input_file, "r") as f: + data = f.read() + else: + data = input("Type configuration to validate: ") + + try: + set_global_validation_context(Context(Path(self.input_file).parent, self.strict)) + KresConfig(try_to_parse(data)) + reset_global_validation_context() + except (DataParsingError, DataValidationError) as e: + print(e, file=sys.stderr) + sys.exit(1) diff --git a/python/knot_resolver/client/main.py b/python/knot_resolver/client/main.py new file mode 100644 index 00000000..933da54d --- /dev/null +++ b/python/knot_resolver/client/main.py @@ -0,0 +1,69 @@ +import argparse +import importlib +import os + +from .command import install_commands_parsers +from .client import KresClient, KRES_CLIENT_NAME + + +def auto_import_commands() -> None: + prefix = f"{'.'.join(__name__.split('.')[:-1])}.commands." + for module_name in os.listdir(os.path.dirname(__file__) + "/commands"): + if module_name[-3:] != ".py": + continue + importlib.import_module(f"{prefix}{module_name[:-3]}") + + +def create_main_argument_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + KRES_CLIENT_NAME, + description="Knot Resolver command-line utility that serves as a client for communicating with the Knot Resolver management API." 
+ " The utility also provides tools to work with the resolver's declarative configuration (validate, convert, ...).", + ) + # parser.add_argument( + # "-i", + # "--interactive", + # action="store_true", + # help="Use the utility in interactive mode.", + # default=False, + # required=False, + # ) + config_or_socket = parser.add_mutually_exclusive_group() + config_or_socket.add_argument( + "-s", + "--socket", + action="store", + type=str, + help="Optional, path to the resolver's management API, unix-domain socket, or network interface." + " Cannot be used together with '--config'.", + default=[], + nargs=1, + required=False, + ) + config_or_socket.add_argument( + "-c", + "--config", + action="store", + type=str, + help="Optional, path to the resolver's declarative configuration to retrieve the management API configuration." + " Cannot be used together with '--socket'.", + default=[], + nargs=1, + required=False, + ) + return parser + + +def main() -> None: + auto_import_commands() + parser = create_main_argument_parser() + install_commands_parsers(parser) + + namespace = parser.parse_args() + client = KresClient(namespace, parser) + client.execute() + + # if namespace.interactive or len(vars(namespace)) == 2: + # client.interactive() + # else: + # client.execute() diff --git a/python/knot_resolver/compat/__init__.py b/python/knot_resolver/compat/__init__.py new file mode 100644 index 00000000..53993f6c --- /dev/null +++ b/python/knot_resolver/compat/__init__.py @@ -0,0 +1,3 @@ +from . import asyncio + +__all__ = ["asyncio"] diff --git a/python/knot_resolver/compat/asyncio.py b/python/knot_resolver/compat/asyncio.py new file mode 100644 index 00000000..9e10e6c6 --- /dev/null +++ b/python/knot_resolver/compat/asyncio.py @@ -0,0 +1,128 @@ +# We disable pylint checks, because it can't find methods in newer Python versions. 
+# +# pylint: disable=no-member + +import asyncio +import functools +import logging +import sys +from asyncio import AbstractEventLoop, coroutines, events, tasks +from typing import Any, Awaitable, Callable, Coroutine, Optional, TypeVar + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +async def to_thread(func: Callable[..., T], *args: Any, **kwargs: Any) -> T: + # version 3.9 and higher, call directly + if sys.version_info.major >= 3 and sys.version_info.minor >= 9: + return await asyncio.to_thread(func, *args, **kwargs) # type: ignore[attr-defined] + + # earlier versions, run with default executor + else: + loop = asyncio.get_event_loop() + pfunc = functools.partial(func, *args, **kwargs) + res = await loop.run_in_executor(None, pfunc) + return res + + +def async_in_a_thread(func: Callable[..., T]) -> Callable[..., Coroutine[None, None, T]]: + async def wrapper(*args: Any, **kwargs: Any) -> T: + return await to_thread(func, *args, **kwargs) + + return wrapper + + +def create_task(coro: Awaitable[T], name: Optional[str] = None) -> "asyncio.Task[T]": + # version 3.8 and higher, call directly + if sys.version_info.major >= 3 and sys.version_info.minor >= 8: + # pylint: disable=unexpected-keyword-arg + return asyncio.create_task(coro, name=name) # type: ignore[attr-defined,arg-type,call-arg] + + # version 3.7 and higher, call directly without the name argument + if sys.version_info.major >= 3 and sys.version_info.minor >= 8: + return asyncio.create_task(coro) # type: ignore[attr-defined,arg-type] + + # earlier versions, use older function + else: + return asyncio.ensure_future(coro) + + +def is_event_loop_running() -> bool: + loop = events._get_running_loop() # pylint: disable=protected-access + return loop is not None and loop.is_running() + + +def run(coro: Awaitable[T], debug: Optional[bool] = None) -> T: + # Adapted version of this: + # https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py#L8 + + # version 3.7 and higher, call directly + # disabled due to incompatibilities + if sys.version_info.major >= 3 and sys.version_info.minor >= 7: + return asyncio.run(coro, debug=debug) # type: ignore[attr-defined,arg-type] + + # earlier versions, use backported version of the function + if events._get_running_loop() is not None: # pylint: disable=protected-access + raise RuntimeError("asyncio.run() cannot be called from a running event loop") + + if not coroutines.iscoroutine(coro): + raise ValueError(f"a coroutine was expected, got {repr(coro)}") + + loop = events.new_event_loop() + try: + events.set_event_loop(loop) + if debug is not None: + loop.set_debug(debug) + return loop.run_until_complete(coro) + finally: + try: + _cancel_all_tasks(loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + if hasattr(loop, "shutdown_default_executor"): + loop.run_until_complete(loop.shutdown_default_executor()) # type: ignore[attr-defined] + finally: + events.set_event_loop(None) + loop.close() + + +def _cancel_all_tasks(loop: AbstractEventLoop) -> None: + # Backported from: + # https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py#L55-L74 + # + to_cancel = tasks.all_tasks(loop) + if not to_cancel: + return + + for task in to_cancel: + task.cancel() + + if sys.version_info.minor >= 7: + # since 3.7, the loop argument is implicitely the running loop + # since 3.10, the loop argument is removed + loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True)) + else: + loop.run_until_complete(tasks.gather(*to_cancel, loop=loop, return_exceptions=True)) # 
type: ignore[call-overload] + + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler( + { + "message": "unhandled exception during asyncio.run() shutdown", + "exception": task.exception(), + "task": task, + } + ) + + +def add_async_signal_handler(signal: int, callback: Callable[[], Coroutine[Any, Any, None]]) -> None: + loop = asyncio.get_event_loop() + loop.add_signal_handler(signal, lambda: create_task(callback())) + + +def remove_signal_handler(signal: int) -> bool: + loop = asyncio.get_event_loop() + return loop.remove_signal_handler(signal) diff --git a/python/knot_resolver/controller/__init__.py b/python/knot_resolver/controller/__init__.py new file mode 100644 index 00000000..2398347e --- /dev/null +++ b/python/knot_resolver/controller/__init__.py @@ -0,0 +1,94 @@ +""" +This file contains autodetection logic for available subprocess controllers. Because we have to catch errors +from imports, they are located in functions which are invoked at the end of this file. + +We supported multiple subprocess controllers while developing it. It now all converged onto just supervisord. +The interface however remains so that different controllers can be added in the future. +""" + +# pylint: disable=import-outside-toplevel + +import asyncio +import logging +from typing import List, Optional + +from knot_resolver.datamodel.config_schema import KresConfig +from knot_resolver.controller.interface import SubprocessController + +logger = logging.getLogger(__name__) + +""" +List of all subprocess controllers that are available in order of priority. +It is filled dynamically based on available modules that do not fail to import. +""" +_registered_controllers: List[SubprocessController] = [] + + +def try_supervisord(): + """ + Attempt to load supervisord controllers. + """ + try: + from knot_resolver.controller.supervisord import SupervisordSubprocessController + + _registered_controllers.append(SupervisordSubprocessController()) + except ImportError: + logger.error("Failed to import modules related to supervisord service manager", exc_info=True) + + +async def get_best_controller_implementation(config: KresConfig) -> SubprocessController: + logger.info("Starting service manager auto-selection...") + + if len(_registered_controllers) == 0: + logger.error("No controllers are available! Did you install all dependencies?") + raise LookupError("No service managers available!") + + # check all controllers concurrently + res = await asyncio.gather(*(cont.is_controller_available(config) for cont in _registered_controllers)) + logger.info( + "Available subprocess controllers are %s", + str(tuple((str(c) for r, c in zip(res, _registered_controllers) if r))), + ) + + # take the first one on the list which is available + for avail, controller in zip(res, _registered_controllers): + if avail: + logger.info("Selected controller '%s'", str(controller)) + return controller + + # or fail + raise LookupError("Can't find any available service manager!") + + +def list_controller_names() -> List[str]: + """ + Returns a list of names of registered controllers. The listed controllers are not necessarly functional. 
+ """ + + return [str(controller) for controller in sorted(_registered_controllers, key=str)] + + +async def get_controller_by_name(config: KresConfig, name: str) -> SubprocessController: + logger.debug("Subprocess controller selected manualy by the user, testing feasibility...") + + controller: Optional[SubprocessController] = None + for c in sorted(_registered_controllers, key=str): + if str(c).startswith(name): + if str(c) != name: + logger.debug("Assuming '%s' is a shortcut for '%s'", name, str(c)) + controller = c + break + + if controller is None: + logger.error("Subprocess controller with name '%s' was not found", name) + raise LookupError(f"No subprocess controller named '{name}' found") + + if await controller.is_controller_available(config): + logger.info("Selected controller '%s'", str(controller)) + return controller + else: + raise LookupError("The selected subprocess controller is not available for use on this system.") + + +# run the imports on module load +try_supervisord() diff --git a/python/knot_resolver/controller/interface.py b/python/knot_resolver/controller/interface.py new file mode 100644 index 00000000..02bbaa50 --- /dev/null +++ b/python/knot_resolver/controller/interface.py @@ -0,0 +1,296 @@ +import asyncio +import itertools +import json +import logging +import struct +import sys +from abc import ABC, abstractmethod # pylint: disable=no-name-in-module +from enum import Enum, auto +from pathlib import Path +from typing import Dict, Iterable, Optional, Type, TypeVar +from weakref import WeakValueDictionary + +from knot_resolver.manager.constants import kresd_config_file, policy_loader_config_file +from knot_resolver.datamodel.config_schema import KresConfig +from knot_resolver.manager.exceptions import SubprocessControllerException +from knot_resolver.controller.registered_workers import register_worker, unregister_worker +from knot_resolver.utils.async_utils import writefile + +logger = logging.getLogger(__name__) + + +class SubprocessType(Enum): + KRESD = auto() + POLICY_LOADER = auto() + GC = auto() + + +class SubprocessStatus(Enum): + RUNNING = auto() + FATAL = auto() + EXITED = auto() + UNKNOWN = auto() + + +T = TypeVar("T", bound="KresID") + + +class KresID: + """ + ID object used for identifying subprocesses. + """ + + _used: "Dict[SubprocessType, WeakValueDictionary[int, KresID]]" = {k: WeakValueDictionary() for k in SubprocessType} + + @classmethod + def alloc(cls: Type[T], typ: SubprocessType) -> T: + # find free ID closest to zero + for i in itertools.count(start=0, step=1): + if i not in cls._used[typ]: + res = cls.new(typ, i) + return res + + raise RuntimeError("Reached an end of an infinite loop. How?") + + @classmethod + def new(cls: "Type[T]", typ: SubprocessType, n: int) -> "T": + if n in cls._used[typ]: + # Ignoring typing here, because I can't find a way how to make the _used dict + # typed based on subclass. I am not even sure that it's different between subclasses, + # it's probably still the same dict. But we don't really care about it + return cls._used[typ][n] # type: ignore + else: + val = cls(typ, n, _i_know_what_i_am_doing=True) + cls._used[typ][n] = val + return val + + def __init__(self, typ: SubprocessType, n: int, _i_know_what_i_am_doing: bool = False): + if not _i_know_what_i_am_doing: + raise RuntimeError("Don't do this. 
You seem to have no idea what it does") + + self._id = n + self._type = typ + + @property + def subprocess_type(self) -> SubprocessType: + return self._type + + def __repr__(self) -> str: + return f"KresID({self})" + + def __hash__(self) -> int: + return self._id + + def __eq__(self, o: object) -> bool: + if isinstance(o, KresID): + return self._type == o._type and self._id == o._id + return False + + def __str__(self) -> str: + """ + Returns string representation of the ID usable directly in the underlying service manager + """ + raise NotImplementedError() + + @staticmethod + def from_string(val: str) -> "KresID": + """ + Inverse of __str__ + """ + raise NotImplementedError() + + def __int__(self) -> int: + return self._id + + +class Subprocess(ABC): + """ + One SubprocessInstance corresponds to one manager's subprocess + """ + + def __init__(self, config: KresConfig, kresid: KresID) -> None: + self._id = kresid + self._config = config + self._registered_worker: bool = False + + async def start(self, new_config: Optional[KresConfig] = None) -> None: + if new_config: + self._config = new_config + + config_file: Optional[Path] = None + if self.type is SubprocessType.KRESD: + config_lua = self._config.render_lua() + config_file = kresd_config_file(self._config, self.id) + await writefile(config_file, config_lua) + elif self.type is SubprocessType.POLICY_LOADER: + config_lua = self._config.render_lua_policy() + config_file = policy_loader_config_file(self._config) + await writefile(config_file, config_lua) + + try: + await self._start() + if self.type is SubprocessType.KRESD: + register_worker(self) + self._registered_worker = True + except SubprocessControllerException as e: + if config_file: + config_file.unlink() + raise e + + async def apply_new_config(self, new_config: KresConfig) -> None: + self._config = new_config + + # update config file + logger.debug(f"Writing config file for {self.id}") + + config_file: Optional[Path] = None + if self.type is SubprocessType.KRESD: + config_lua = self._config.render_lua() + config_file = kresd_config_file(self._config, self.id) + await writefile(config_file, config_lua) + elif self.type is SubprocessType.POLICY_LOADER: + config_lua = self._config.render_lua_policy() + config_file = policy_loader_config_file(self._config) + await writefile(config_file, config_lua) + + # update runtime status + logger.debug(f"Restarting {self.id}") + await self._restart() + + async def stop(self) -> None: + if self._registered_worker: + unregister_worker(self) + await self._stop() + await self.cleanup() + + async def cleanup(self) -> None: + """ + Remove temporary files and all traces of this instance running. It is NOT SAFE to call this while + the kresd is running, because it will break automatic restarts (at the very least). 
+ """ + + if self.type is SubprocessType.KRESD: + config_file = kresd_config_file(self._config, self.id) + config_file.unlink() + elif self.type is SubprocessType.POLICY_LOADER: + config_file = policy_loader_config_file(self._config) + config_file.unlink() + + def __eq__(self, o: object) -> bool: + return isinstance(o, type(self)) and o.type == self.type and o.id == self.id + + def __hash__(self) -> int: + return hash(type(self)) ^ hash(self.type) ^ hash(self.id) + + @abstractmethod + async def _start(self) -> None: + pass + + @abstractmethod + async def _stop(self) -> None: + pass + + @abstractmethod + async def _restart(self) -> None: + pass + + @abstractmethod + def status(self) -> SubprocessStatus: + pass + + @property + def type(self) -> SubprocessType: + return self.id.subprocess_type + + @property + def id(self) -> KresID: + return self._id + + async def command(self, cmd: str) -> object: + if not self._registered_worker: + raise RuntimeError("the command cannot be sent to a process other than the kresd worker") + + reader: asyncio.StreamReader + writer: Optional[asyncio.StreamWriter] = None + + try: + reader, writer = await asyncio.open_unix_connection(f"./control/{int(self.id)}") + + # drop prompt + _ = await reader.read(2) + + # switch to JSON mode + writer.write("__json\n".encode("utf8")) + + # write command + writer.write(cmd.encode("utf8")) + writer.write(b"\n") + await writer.drain() + + # read result + (msg_len,) = struct.unpack(">I", await reader.read(4)) + result_bytes = await reader.readexactly(msg_len) + return json.loads(result_bytes.decode("utf8")) + + finally: + if writer is not None: + writer.close() + + # proper closing of the socket is only implemented in later versions of python + if sys.version_info.minor >= 7: + await writer.wait_closed() # type: ignore + + +class SubprocessController(ABC): + """ + The common Subprocess Controller interface. This is what KresManager requires and what has to be implemented by all + controllers. + """ + + @abstractmethod + async def is_controller_available(self, config: KresConfig) -> bool: + """ + Returns bool, whether the controller is available with the given config + """ + + @abstractmethod + async def initialize_controller(self, config: KresConfig) -> None: + """ + Should be called when we want to really start using the controller with a specific configuration + """ + + @abstractmethod + async def get_all_running_instances(self) -> Iterable[Subprocess]: + """ + + Must NOT be called before initialize_controller() + """ + + @abstractmethod + async def shutdown_controller(self) -> None: + """ + Called when the manager is gracefully shutting down. Allows us to stop + the service manager process or simply cleanup, so that we don't reuse + the same resources in a new run. + + Must NOT be called before initialize_controller() + """ + + @abstractmethod + async def create_subprocess(self, subprocess_config: KresConfig, subprocess_type: SubprocessType) -> Subprocess: + """ + Return a Subprocess object which can be operated on. The subprocess is not + started or in any way active after this call. That has to be performaed manually + using the returned object itself. + + Must NOT be called before initialize_controller() + """ + + @abstractmethod + async def get_subprocess_status(self) -> Dict[KresID, SubprocessStatus]: + """ + Get a status of running subprocesses as seen by the controller. This method actively polls + for information. 
+ + Must NOT be called before initialize_controller() + """ diff --git a/python/knot_resolver/controller/registered_workers.py b/python/knot_resolver/controller/registered_workers.py new file mode 100644 index 00000000..2d3176c3 --- /dev/null +++ b/python/knot_resolver/controller/registered_workers.py @@ -0,0 +1,49 @@ +import asyncio +import logging +from typing import TYPE_CHECKING, Dict, List, Tuple + +from knot_resolver.manager.exceptions import SubprocessControllerException + +if TYPE_CHECKING: + from knot_resolver.controller.interface import KresID, Subprocess + + +logger = logging.getLogger(__name__) + + +_REGISTERED_WORKERS: "Dict[KresID, Subprocess]" = {} + + +def get_registered_workers_kresids() -> "List[KresID]": + return list(_REGISTERED_WORKERS.keys()) + + +async def command_single_registered_worker(cmd: str) -> "Tuple[KresID, object]": + for sub in _REGISTERED_WORKERS.values(): + return sub.id, await sub.command(cmd) + raise SubprocessControllerException( + "Unable to execute the command. There is no kresd worker running to execute the command." + "Try start/restart the resolver.", + ) + + +async def command_registered_workers(cmd: str) -> "Dict[KresID, object]": + async def single_pair(sub: "Subprocess") -> "Tuple[KresID, object]": + return sub.id, await sub.command(cmd) + + pairs = await asyncio.gather(*(single_pair(inst) for inst in _REGISTERED_WORKERS.values())) + return dict(pairs) + + +def unregister_worker(subprocess: "Subprocess") -> None: + """ + Unregister kresd worker "Subprocess" from the list. + """ + del _REGISTERED_WORKERS[subprocess.id] + + +def register_worker(subprocess: "Subprocess") -> None: + """ + Register kresd worker "Subprocess" on the list. + """ + _REGISTERED_WORKERS[subprocess.id] = subprocess diff --git a/python/knot_resolver/controller/supervisord/__init__.py b/python/knot_resolver/controller/supervisord/__init__.py new file mode 100644 index 00000000..592b76be --- /dev/null +++ b/python/knot_resolver/controller/supervisord/__init__.py @@ -0,0 +1,281 @@ +import logging +from os import kill # pylint: disable=[no-name-in-module] +from pathlib import Path +from typing import Any, Dict, Iterable, NoReturn, Optional, Union, cast +from xmlrpc.client import Fault, ServerProxy + +import supervisor.xmlrpc # type: ignore[import] + +from knot_resolver.compat.asyncio import async_in_a_thread +from knot_resolver.manager.constants import supervisord_config_file, supervisord_pid_file, supervisord_sock_file +from knot_resolver.datamodel.config_schema import KresConfig +from knot_resolver.manager.exceptions import CancelStartupExecInsteadException, SubprocessControllerException +from knot_resolver.controller.interface import ( + KresID, + Subprocess, + SubprocessController, + SubprocessStatus, + SubprocessType, +) +from knot_resolver.controller.supervisord.config_file import SupervisordKresID, write_config_file +from knot_resolver.utils import which +from knot_resolver.utils.async_utils import call, readfile + +logger = logging.getLogger(__name__) + + +async def _start_supervisord(config: KresConfig) -> None: + logger.debug("Writing supervisord config") + await write_config_file(config) + logger.debug("Starting supervisord") + res = await call(["supervisord", "--configuration", str(supervisord_config_file(config).absolute())]) + if res != 0: + raise SubprocessControllerException(f"Supervisord exited with exit code {res}") + + +async def _exec_supervisord(config: KresConfig) -> NoReturn: + logger.debug("Writing supervisord config") + await write_config_file(config) 
+ logger.debug("Execing supervisord") + raise CancelStartupExecInsteadException( + [ + str(which.which("supervisord")), + "supervisord", + "--configuration", + str(supervisord_config_file(config).absolute()), + ] + ) + + +async def _reload_supervisord(config: KresConfig) -> None: + await write_config_file(config) + try: + supervisord = _create_supervisord_proxy(config) + supervisord.reloadConfig() + except Fault as e: + raise SubprocessControllerException("supervisord reload failed") from e + + +@async_in_a_thread +def _stop_supervisord(config: KresConfig) -> None: + supervisord = _create_supervisord_proxy(config) + # pid = supervisord.getPID() + try: + # we might be trying to shut down supervisord at a moment, when it's waiting + # for us to stop. Therefore, this shutdown request for supervisord might + # die and it's not a problem. + supervisord.shutdown() + except Fault as e: + if e.faultCode == 6 and e.faultString == "SHUTDOWN_STATE": + # supervisord is already stopping, so it's fine + pass + else: + # something wrong happened, let's be loud about it + raise + + # We could remove the configuration, but there is actually no specific need to do so. + # If we leave it behind, someone might find it and use it to start us from scratch again, + # which is perfectly fine. + # supervisord_config_file(config).unlink() + + +async def _is_supervisord_available() -> bool: + # yes, it is! The code in this file wouldn't be running without it due to imports :) + + # so let's just check that we can find supervisord and supervisorctl binaries + try: + which.which("supervisord") + which.which("supervisorctl") + except RuntimeError: + logger.error("Failed to find supervisord or supervisorctl executables in $PATH") + return False + + return True + + +async def _get_supervisord_pid(config: KresConfig) -> Optional[int]: + if not Path(supervisord_pid_file(config)).exists(): + return None + + return int(await readfile(supervisord_pid_file(config))) + + +def _is_process_runinng(pid: int) -> bool: + try: + # kill with signal 0 is a safe way to test that a process exists + kill(pid, 0) + return True + except ProcessLookupError: + return False + + +async def _is_supervisord_running(config: KresConfig) -> bool: + pid = await _get_supervisord_pid(config) + if pid is None: + return False + elif not _is_process_runinng(pid): + supervisord_pid_file(config).unlink() + return False + else: + return True + + +def _create_proxy(config: KresConfig) -> ServerProxy: + return ServerProxy( + "http://127.0.0.1", + transport=supervisor.xmlrpc.SupervisorTransport( + None, None, serverurl="unix://" + str(supervisord_sock_file(config)) + ), + ) + + +def _create_supervisord_proxy(config: KresConfig) -> Any: + proxy = _create_proxy(config) + return getattr(proxy, "supervisor") + + +def _create_fast_proxy(config: KresConfig) -> Any: + proxy = _create_proxy(config) + return getattr(proxy, "fast") + + +def _convert_subprocess_status(proc: Any) -> SubprocessStatus: + conversion_tbl = { + # "STOPPED": None, # filtered out elsewhere + "STARTING": SubprocessStatus.RUNNING, + "RUNNING": SubprocessStatus.RUNNING, + "BACKOFF": SubprocessStatus.RUNNING, + "STOPPING": SubprocessStatus.RUNNING, + "EXITED": SubprocessStatus.EXITED, + "FATAL": SubprocessStatus.FATAL, + "UNKNOWN": SubprocessStatus.UNKNOWN, + } + + if proc["statename"] in conversion_tbl: + status = conversion_tbl[proc["statename"]] + else: + logger.warning(f"Unknown supervisord process state {proc['statename']}") + status = SubprocessStatus.UNKNOWN + return status + + +def 
_list_running_subprocesses(config: KresConfig) -> Dict[SupervisordKresID, SubprocessStatus]: + try: + supervisord = _create_supervisord_proxy(config) + processes: Any = supervisord.getAllProcessInfo() + except Fault as e: + raise SubprocessControllerException(f"failed to get info from all running processes: {e}") from e + + # there will be a manager process as well, but we don't want to report anything on ourselves + processes = [pr for pr in processes if pr["name"] != "manager"] + + # convert all the names + return { + SupervisordKresID.from_string(f"{pr['group']}:{pr['name']}"): _convert_subprocess_status(pr) + for pr in processes + if pr["statename"] != "STOPPED" + } + + +class SupervisordSubprocess(Subprocess): + def __init__( + self, + config: KresConfig, + controller: "SupervisordSubprocessController", + base_id: Union[SubprocessType, SupervisordKresID], + ): + if isinstance(base_id, SubprocessType): + super().__init__(config, SupervisordKresID.alloc(base_id)) + else: + super().__init__(config, base_id) + self._controller: "SupervisordSubprocessController" = controller + + @property + def name(self): + return str(self.id) + + def status(self) -> SubprocessStatus: + try: + supervisord = _create_supervisord_proxy(self._config) + status = supervisord.getProcessInfo(self.name) + except Fault as e: + raise SubprocessControllerException(f"failed to get status from '{self.id}' process: {e}") from e + return _convert_subprocess_status(status) + + @async_in_a_thread + def _start(self) -> None: + # +1 for canary process (same as in config_file.py) + assert int(self.id) <= int(self._config.max_workers) + 1, "trying to spawn more than allowed limit of workers" + try: + supervisord = _create_fast_proxy(self._config) + supervisord.startProcess(self.name) + except Fault as e: + raise SubprocessControllerException(f"failed to start '{self.id}'") from e + + @async_in_a_thread + def _stop(self) -> None: + supervisord = _create_supervisord_proxy(self._config) + supervisord.stopProcess(self.name) + + @async_in_a_thread + def _restart(self) -> None: + supervisord = _create_supervisord_proxy(self._config) + supervisord.stopProcess(self.name) + fast = _create_fast_proxy(self._config) + fast.startProcess(self.name) + + def get_used_config(self) -> KresConfig: + return self._config + + +class SupervisordSubprocessController(SubprocessController): + def __init__(self): # pylint: disable=super-init-not-called + self._controller_config: Optional[KresConfig] = None + + def __str__(self): + return "supervisord" + + async def is_controller_available(self, config: KresConfig) -> bool: + res = await _is_supervisord_available() + if not res: + logger.info("Failed to find usable supervisord.") + + logger.debug("Detection - supervisord controller is available for use") + return res + + async def get_all_running_instances(self) -> Iterable[Subprocess]: + assert self._controller_config is not None + + if await _is_supervisord_running(self._controller_config): + states = _list_running_subprocesses(self._controller_config) + return [ + SupervisordSubprocess(self._controller_config, self, id_) + for id_ in states + if states[id_] == SubprocessStatus.RUNNING + ] + else: + return [] + + async def initialize_controller(self, config: KresConfig) -> None: + self._controller_config = config + + if not await _is_supervisord_running(config): + logger.info( + "We want supervisord to restart us when needed, we will therefore exec() it and let it start us again." 
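(Aside, not part of the diff: a rough sketch of the intended call order for the controller defined here. It assumes a parsed KresConfig is at hand and glosses over the startup path where initialize_controller() raises CancelStartupExecInsteadException so that supervisord gets exec()ed.)

    from knot_resolver.controller.supervisord import SupervisordSubprocessController
    from knot_resolver.datamodel.config_schema import KresConfig

    async def show_running_workers(config: KresConfig) -> None:
        controller = SupervisordSubprocessController()
        if not await controller.is_controller_available(config):
            raise RuntimeError("supervisord is not usable on this system")
        await controller.initialize_controller(config)  # may hand control over to supervisord instead
        for subprocess in await controller.get_all_running_instances():
            print(subprocess.id, subprocess.status())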
+ ) + await _exec_supervisord(config) + else: + logger.info("Supervisord is already running, we will just update its config...") + await _reload_supervisord(config) + + async def shutdown_controller(self) -> None: + assert self._controller_config is not None + await _stop_supervisord(self._controller_config) + + async def create_subprocess(self, subprocess_config: KresConfig, subprocess_type: SubprocessType) -> Subprocess: + return SupervisordSubprocess(subprocess_config, self, subprocess_type) + + @async_in_a_thread + def get_subprocess_status(self) -> Dict[KresID, SubprocessStatus]: + assert self._controller_config is not None + return cast(Dict[KresID, SubprocessStatus], _list_running_subprocesses(self._controller_config)) diff --git a/python/knot_resolver/controller/supervisord/config_file.py b/python/knot_resolver/controller/supervisord/config_file.py new file mode 100644 index 00000000..d9a79f9e --- /dev/null +++ b/python/knot_resolver/controller/supervisord/config_file.py @@ -0,0 +1,197 @@ +import logging +import os + +from dataclasses import dataclass +from pathlib import Path +from jinja2 import Template +from typing_extensions import Literal + +from knot_resolver.manager.constants import ( + kres_gc_executable, + kresd_cache_dir, + kresd_config_file_supervisord_pattern, + kresd_executable, + policy_loader_config_file, + supervisord_config_file, + supervisord_config_file_tmp, + supervisord_pid_file, + supervisord_sock_file, + supervisord_subprocess_log_dir, + user_constants, +) +from knot_resolver.datamodel.config_schema import KresConfig +from knot_resolver.datamodel.logging_schema import LogTargetEnum +from knot_resolver.controller.interface import KresID, SubprocessType +from knot_resolver.utils.async_utils import read_resource, writefile + +logger = logging.getLogger(__name__) + + +class SupervisordKresID(KresID): + # WARNING: be really careful with renaming. 
If the naming schema is changing, + # we should be able to parse the old one as well, otherwise updating manager will + # cause weird behavior + + @staticmethod + def from_string(val: str) -> "SupervisordKresID": + # the double name is checked because thats how we read it from supervisord + if val in ("cache-gc", "cache-gc:cache-gc"): + return SupervisordKresID.new(SubprocessType.GC, 0) + elif val in ("policy-loader", "policy-loader:policy-loader"): + return SupervisordKresID.new(SubprocessType.POLICY_LOADER, 0) + else: + val = val.replace("kresd:kresd", "") + return SupervisordKresID.new(SubprocessType.KRESD, int(val)) + + def __str__(self) -> str: + if self.subprocess_type is SubprocessType.GC: + return "cache-gc" + elif self.subprocess_type is SubprocessType.POLICY_LOADER: + return "policy-loader" + elif self.subprocess_type is SubprocessType.KRESD: + return f"kresd:kresd{self._id}" + else: + raise RuntimeError(f"Unexpected subprocess type {self.subprocess_type}") + + +def kres_cache_gc_args(config: KresConfig) -> str: + args = "" + + if config.logging.level == "debug" or (config.logging.groups and "cache-gc" in config.logging.groups): + args += " -v" + + gc_config = config.cache.garbage_collector + if gc_config: + args += ( + f" -d {gc_config.interval.millis()}" + f" -u {gc_config.threshold}" + f" -f {gc_config.release}" + f" -l {gc_config.rw_deletes}" + f" -L {gc_config.rw_reads}" + f" -t {gc_config.temp_keys_space.mbytes()}" + f" -m {gc_config.rw_duration.micros()}" + f" -w {gc_config.rw_delay.micros()}" + ) + if gc_config.dry_run: + args += " -n" + return args + raise ValueError("missing configuration for the cache garbage collector") + + +@dataclass +class ProcessTypeConfig: + """ + Data structure holding data for supervisord config template + """ + + logfile: Path + workdir: str + command: str + environment: str + max_procs: int = 1 + + @staticmethod + def create_gc_config(config: KresConfig) -> "ProcessTypeConfig": + cwd = str(os.getcwd()) + return ProcessTypeConfig( # type: ignore[call-arg] + logfile=supervisord_subprocess_log_dir(config) / "gc.log", + workdir=cwd, + command=f"{kres_gc_executable()} -c {kresd_cache_dir(config)}{kres_cache_gc_args(config)}", + environment="", + ) + + @staticmethod + def create_policy_loader_config(config: KresConfig) -> "ProcessTypeConfig": + cwd = str(os.getcwd()) + return ProcessTypeConfig( # type: ignore[call-arg] + logfile=supervisord_subprocess_log_dir(config) / "policy-loader.log", + workdir=cwd, + command=f"{kresd_executable()} -c {(policy_loader_config_file(config))} -c - -n", + environment="X-SUPERVISORD-TYPE=notify", + ) + + @staticmethod + def create_kresd_config(config: KresConfig) -> "ProcessTypeConfig": + cwd = str(os.getcwd()) + return ProcessTypeConfig( # type: ignore[call-arg] + logfile=supervisord_subprocess_log_dir(config) / "kresd%(process_num)d.log", + workdir=cwd, + command=f"{kresd_executable()} -c {kresd_config_file_supervisord_pattern(config)} -n", + environment='SYSTEMD_INSTANCE="%(process_num)d",X-SUPERVISORD-TYPE=notify', + max_procs=int(config.max_workers) + 1, # +1 for the canary process + ) + + @staticmethod + def create_manager_config(_config: KresConfig) -> "ProcessTypeConfig": + # read original command from /proc + with open("/proc/self/cmdline", "rb") as f: + args = [s.decode("utf-8") for s in f.read()[:-1].split(b"\0")] + + # insert debugger when asked + if os.environ.get("KRES_DEBUG_MANAGER"): + logger.warning("Injecting debugger into the supervisord config") + # the args array looks like this: + # 
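(Aside, not part of the diff: the naming scheme round-trips between supervisord process names and KresIDs, which is what from_string() and __str__() above implement.)

    from knot_resolver.controller.interface import SubprocessType
    from knot_resolver.controller.supervisord.config_file import SupervisordKresID

    kid = SupervisordKresID.from_string("kresd:kresd3")      # name as reported by supervisord
    print(str(kid))                                          # -> "kresd:kresd3"
    print(str(SupervisordKresID.new(SubprocessType.GC, 0)))  # -> "cache-gc"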
[PYTHON_PATH, "-m", "knot_resolver", ...] + args = args[:1] + ["-m", "debugpy", "--listen", "0.0.0.0:5678", "--wait-for-client"] + args[2:] + + cmd = '"' + '" "'.join(args) + '"' + + return ProcessTypeConfig( # type: ignore[call-arg] + workdir=user_constants().working_directory_on_startup, + command=cmd, + environment="X-SUPERVISORD-TYPE=notify", + logfile=Path(""), # this will be ignored + ) + + +@dataclass +class SupervisordConfig: + unix_http_server: Path + pid_file: Path + workdir: str + logfile: Path + loglevel: Literal["critical", "error", "warn", "info", "debug", "trace", "blather"] + target: LogTargetEnum + + @staticmethod + def create(config: KresConfig) -> "SupervisordConfig": + # determine the correct logging level + if config.logging.groups and "supervisord" in config.logging.groups: + loglevel = "info" + else: + loglevel = { + "crit": "critical", + "err": "error", + "warning": "warn", + "notice": "warn", + "info": "info", + "debug": "debug", + }[config.logging.level] + cwd = str(os.getcwd()) + return SupervisordConfig( # type: ignore[call-arg] + unix_http_server=supervisord_sock_file(config), + pid_file=supervisord_pid_file(config), + workdir=cwd, + logfile=Path("syslog" if config.logging.target == "syslog" else "/dev/null"), + loglevel=loglevel, # type: ignore[arg-type] + target=config.logging.target, + ) + + +async def write_config_file(config: KresConfig) -> None: + if not supervisord_subprocess_log_dir(config).exists(): + supervisord_subprocess_log_dir(config).mkdir(exist_ok=True) + + template = await read_resource(__package__, "supervisord.conf.j2") + assert template is not None + template = template.decode("utf8") + config_string = Template(template).render( + gc=ProcessTypeConfig.create_gc_config(config), + loader=ProcessTypeConfig.create_policy_loader_config(config), + kresd=ProcessTypeConfig.create_kresd_config(config), + manager=ProcessTypeConfig.create_manager_config(config), + config=SupervisordConfig.create(config), + ) + await writefile(supervisord_config_file_tmp(config), config_string) + # atomically replace (we don't technically need this right now, but better safe then sorry) + os.rename(supervisord_config_file_tmp(config), supervisord_config_file(config)) diff --git a/python/knot_resolver/controller/supervisord/plugin/fast_rpcinterface.py b/python/knot_resolver/controller/supervisord/plugin/fast_rpcinterface.py new file mode 100644 index 00000000..c3834784 --- /dev/null +++ b/python/knot_resolver/controller/supervisord/plugin/fast_rpcinterface.py @@ -0,0 +1,173 @@ +# type: ignore +# pylint: skip-file + +""" +This file is modified version of supervisord's source code: +https://github.com/Supervisor/supervisor/blob/5d9c39619e2e7e7fca33c890cb2a9f2d3d0ab762/supervisor/rpcinterface.py + +The changes made are: + + - removed everything that we do not need, reformatted to fit our code stylepo (2022-06-24) + - made startProcess faster by setting delay to 0 (2022-06-24) + + +The original supervisord licence follows: +-------------------------------------------------------------------- + +Supervisor is licensed under the following license: + + A copyright notice accompanies this license document that identifies + the copyright holders. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + + 1. Redistributions in source code must retain the accompanying + copyright notice, this list of conditions, and the following + disclaimer. + + 2. 
Redistributions in binary form must reproduce the accompanying + copyright notice, this list of conditions, and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + 3. Names of the copyright holders must not be used to endorse or + promote products derived from this software without prior + written permission from the copyright holders. + + 4. If any files are modified, you must cause the modified files to + carry prominent notices stating that you changed the files and + the date of any change. + + Disclaimer + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND + ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A + PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDERS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON + ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR + TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF + THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + SUCH DAMAGE. +""" + +from supervisor.http import NOT_DONE_YET +from supervisor.options import BadCommand, NoPermission, NotExecutable, NotFound, split_namespec +from supervisor.states import RUNNING_STATES, ProcessStates, SupervisorStates +from supervisor.xmlrpc import Faults, RPCError + + +class SupervisorNamespaceRPCInterface: + def __init__(self, supervisord): + self.supervisord = supervisord + + def _update(self, text): + self.update_text = text # for unit tests, mainly + if isinstance(self.supervisord.options.mood, int) and self.supervisord.options.mood < SupervisorStates.RUNNING: + raise RPCError(Faults.SHUTDOWN_STATE) + + # RPC API methods + + def _getGroupAndProcess(self, name): + # get process to start from name + group_name, process_name = split_namespec(name) + + group = self.supervisord.process_groups.get(group_name) + if group is None: + raise RPCError(Faults.BAD_NAME, name) + + if process_name is None: + return group, None + + process = group.processes.get(process_name) + if process is None: + raise RPCError(Faults.BAD_NAME, name) + + return group, process + + def startProcess(self, name, wait=True): + """Start a process + + @param string name Process name (or ``group:name``, or ``group:*``) + @param boolean wait Wait for process to be fully started + @return boolean result Always true unless error + + """ + self._update("startProcess") + group, process = self._getGroupAndProcess(name) + if process is None: + group_name, process_name = split_namespec(name) + return self.startProcessGroup(group_name, wait) + + # test filespec, don't bother trying to spawn if we know it will + # eventually fail + try: + filename, argv = process.get_execv_args() + except NotFound as why: + raise RPCError(Faults.NO_FILE, why.args[0]) + except (BadCommand, NotExecutable, NoPermission) as why: + raise RPCError(Faults.NOT_EXECUTABLE, why.args[0]) + + if process.get_state() in RUNNING_STATES: + raise RPCError(Faults.ALREADY_STARTED, name) + + if process.get_state() == ProcessStates.UNKNOWN: + raise RPCError(Faults.FAILED, "%s is in an unknown process state" % name) + + process.spawn() + + # We call reap() in order to more quickly obtain the side effects of + # process.finish(), which reap() eventually ends up 
calling. This + # might be the case if the spawn() was successful but then the process + # died before its startsecs elapsed or it exited with an unexpected + # exit code. In particular, finish() may set spawnerr, which we can + # check and immediately raise an RPCError, avoiding the need to + # defer by returning a callback. + + self.supervisord.reap() + + if process.spawnerr: + raise RPCError(Faults.SPAWN_ERROR, name) + + # We call process.transition() in order to more quickly obtain its + # side effects. In particular, it might set the process' state from + # STARTING->RUNNING if the process has a startsecs==0. + process.transition() + + if wait and process.get_state() != ProcessStates.RUNNING: + # by default, this branch will almost always be hit for processes + # with default startsecs configurations, because the default number + # of startsecs for a process is "1", and the process will not have + # entered the RUNNING state yet even though we've called + # transition() on it. This is because a process is not considered + # RUNNING until it has stayed up > startsecs. + + def onwait(): + if process.spawnerr: + raise RPCError(Faults.SPAWN_ERROR, name) + + state = process.get_state() + + if state not in (ProcessStates.STARTING, ProcessStates.RUNNING): + raise RPCError(Faults.ABNORMAL_TERMINATION, name) + + if state == ProcessStates.RUNNING: + return True + + return NOT_DONE_YET + + onwait.delay = 0 + onwait.rpcinterface = self + return onwait # deferred + + return True + + +# this is not used in code but referenced via an entry point in the conf file +def make_main_rpcinterface(supervisord): + return SupervisorNamespaceRPCInterface(supervisord) diff --git a/python/knot_resolver/controller/supervisord/plugin/manager_integration.py b/python/knot_resolver/controller/supervisord/plugin/manager_integration.py new file mode 100644 index 00000000..2fc8cf94 --- /dev/null +++ b/python/knot_resolver/controller/supervisord/plugin/manager_integration.py @@ -0,0 +1,85 @@ +# type: ignore +# pylint: disable=protected-access +import atexit +import os +import signal +from typing import Any, Optional + +from supervisor.compat import as_string +from supervisor.events import ProcessStateFatalEvent, ProcessStateRunningEvent, ProcessStateStartingEvent, subscribe +from supervisor.options import ServerOptions +from supervisor.process import Subprocess +from supervisor.states import SupervisorStates +from supervisor.supervisord import Supervisor + +from knot_resolver.utils.systemd_notify import systemd_notify + +superd: Optional[Supervisor] = None + + +def check_for_fatal_manager(event: ProcessStateFatalEvent) -> None: + assert superd is not None + + proc: Subprocess = event.process + processname = as_string(proc.config.name) + if processname == "manager": + # stop the whole supervisord gracefully + superd.options.logger.critical("manager process entered FATAL state! 
Shutting down") + superd.options.mood = SupervisorStates.SHUTDOWN + + # force the interpreter to exit with exit code 1 + atexit.register(lambda: os._exit(1)) + + +def check_for_starting_manager(event: ProcessStateStartingEvent) -> None: + assert superd is not None + + proc: Subprocess = event.process + processname = as_string(proc.config.name) + if processname == "manager": + # manager has sucessfully started, report it upstream + systemd_notify(STATUS="Starting services...") + + +def check_for_runnning_manager(event: ProcessStateRunningEvent) -> None: + assert superd is not None + + proc: Subprocess = event.process + processname = as_string(proc.config.name) + if processname == "manager": + # manager has sucessfully started, report it upstream + systemd_notify(READY="1", STATUS="Ready") + + +def ServerOptions_get_signal(self): + sig = self.signal_receiver.get_signal() + if sig == signal.SIGHUP and superd is not None: + superd.options.logger.info("received SIGHUP, forwarding to the process 'manager'") + manager_pid = superd.process_groups["manager"].processes["manager"].pid + os.kill(manager_pid, signal.SIGHUP) + return None + + return sig + + +def inject(supervisord: Supervisor, **_config: Any) -> Any: # pylint: disable=useless-return + global superd + superd = supervisord + + # This status notification here unsets the env variable $NOTIFY_SOCKET provided by systemd + # and stores it locally. Therefore, it shouldn't clash with $NOTIFY_SOCKET we are providing + # downstream + systemd_notify(STATUS="Initializing supervisord...") + + # register events + subscribe(ProcessStateFatalEvent, check_for_fatal_manager) + subscribe(ProcessStateStartingEvent, check_for_starting_manager) + subscribe(ProcessStateRunningEvent, check_for_runnning_manager) + + # forward SIGHUP to manager + ServerOptions.get_signal = ServerOptions_get_signal + + # this method is called by supervisord when loading the plugin, + # it should return XML-RPC object, which we don't care about + # That's why why are returning just None + return None diff --git a/python/knot_resolver/controller/supervisord/plugin/notifymodule.c b/python/knot_resolver/controller/supervisord/plugin/notifymodule.c new file mode 100644 index 00000000..d56ee7d2 --- /dev/null +++ b/python/knot_resolver/controller/supervisord/plugin/notifymodule.c @@ -0,0 +1,176 @@ +#define PY_SSIZE_T_CLEAN +#include <Python.h> + +#include <stdbool.h> +#include <stdint.h> +#include <stdio.h> +#include <unistd.h> +#include <sys/types.h> +#include <stdlib.h> +#include <string.h> +#include <errno.h> +#include <sys/socket.h> +#include <fcntl.h> +#include <stddef.h> +#include <sys/socket.h> +#include <sys/un.h> + +#define CONTROL_SOCKET_NAME "knot-resolver-control-socket" +#define NOTIFY_SOCKET_NAME "NOTIFY_SOCKET" +#define MODULE_NAME "notify" +#define RECEIVE_BUFFER_SIZE 2048 + +static PyObject *NotifySocketError; + +static PyObject *init_control_socket(PyObject *self, PyObject *args) +{ + /* create socket */ + int controlfd = socket(AF_UNIX, SOCK_DGRAM | SOCK_NONBLOCK, 0); + if (controlfd == -1) { + PyErr_SetFromErrno(NotifySocketError); + return NULL; + } + + /* create address */ + struct sockaddr_un server_addr; + bzero(&server_addr, sizeof(server_addr)); + server_addr.sun_family = AF_UNIX; + server_addr.sun_path[0] = '\0'; // mark it as abstract namespace socket + strcpy(server_addr.sun_path + 1, CONTROL_SOCKET_NAME); + size_t addr_len = offsetof(struct sockaddr_un, sun_path) + + strlen(CONTROL_SOCKET_NAME) + 1; + + /* bind to the address */ + int res = bind(controlfd, 
(struct sockaddr *)&server_addr, addr_len); + if (res < 0) { + PyErr_SetFromErrno(NotifySocketError); + return NULL; + } + + /* make sure that we are send credentials */ + int data = (int)true; + res = setsockopt(controlfd, SOL_SOCKET, SO_PASSCRED, &data, + sizeof(data)); + if (res < 0) { + PyErr_SetFromErrno(NotifySocketError); + return NULL; + } + + /* store the name of the socket in env to fake systemd */ + char *old_value = getenv(NOTIFY_SOCKET_NAME); + if (old_value != NULL) { + printf("[notify_socket] warning, running under systemd and overwriting $%s\n", + NOTIFY_SOCKET_NAME); + // fixme + } + + res = setenv(NOTIFY_SOCKET_NAME, "@" CONTROL_SOCKET_NAME, 1); + if (res < 0) { + PyErr_SetFromErrno(NotifySocketError); + return NULL; + } + + return PyLong_FromLong((long)controlfd); +} + +static PyObject *handle_control_socket_connection_event(PyObject *self, + PyObject *args) +{ + long controlfd; + if (!PyArg_ParseTuple(args, "i", &controlfd)) + return NULL; + + /* read command assuming it fits and it was sent all at once */ + // prepare space to read filedescriptors + struct msghdr msg; + msg.msg_name = NULL; + msg.msg_namelen = 0; + + // prepare a place to read the actual message + char place_for_data[RECEIVE_BUFFER_SIZE]; + bzero(&place_for_data, sizeof(place_for_data)); + struct iovec iov = { .iov_base = &place_for_data, + .iov_len = sizeof(place_for_data) }; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + char cmsg[CMSG_SPACE(sizeof(struct ucred))]; + msg.msg_control = cmsg; + msg.msg_controllen = sizeof(cmsg); + + /* Receive real plus ancillary data */ + int len = recvmsg(controlfd, &msg, 0); + if (len == -1) { + if (errno == EWOULDBLOCK || errno == EAGAIN) { + Py_RETURN_NONE; + } else { + PyErr_SetFromErrno(NotifySocketError); + return NULL; + } + } + + /* read the sender pid */ + struct cmsghdr *cmsgp = CMSG_FIRSTHDR(&msg); + pid_t pid = -1; + while (cmsgp != NULL) { + if (cmsgp->cmsg_type == SCM_CREDENTIALS) { + if ( + cmsgp->cmsg_len != CMSG_LEN(sizeof(struct ucred)) || + cmsgp->cmsg_level != SOL_SOCKET + ) { + printf("[notify_socket] invalid cmsg data, ignoring\n"); + Py_RETURN_NONE; + } + + struct ucred cred; + memcpy(&cred, CMSG_DATA(cmsgp), sizeof(cred)); + pid = cred.pid; + } + cmsgp = CMSG_NXTHDR(&msg, cmsgp); + } + if (pid == -1) { + printf("[notify_socket] ignoring received data without credentials: %s\n", + place_for_data); + Py_RETURN_NONE; + } + + /* return received data as a tuple (pid, data bytes) */ + return Py_BuildValue("iy", pid, place_for_data); +} + +static PyMethodDef NotifyMethods[] = { + { "init_socket", init_control_socket, METH_VARARGS, + "Init notify socket. Returns it's file descriptor." }, + { "read_message", handle_control_socket_connection_event, METH_VARARGS, + "Reads datagram from notify socket. Returns tuple of PID and received bytes." }, + { NULL, NULL, 0, NULL } /* Sentinel */ +}; + +static struct PyModuleDef notifymodule = { + PyModuleDef_HEAD_INIT, MODULE_NAME, /* name of module */ + NULL, /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. 
*/ + NotifyMethods +}; + +PyMODINIT_FUNC PyInit_notify(void) +{ + PyObject *m; + + m = PyModule_Create(¬ifymodule); + if (m == NULL) + return NULL; + + NotifySocketError = + PyErr_NewException(MODULE_NAME ".error", NULL, NULL); + Py_XINCREF(NotifySocketError); + if (PyModule_AddObject(m, "error", NotifySocketError) < 0) { + Py_XDECREF(NotifySocketError); + Py_CLEAR(NotifySocketError); + Py_DECREF(m); + return NULL; + } + + return m; +} diff --git a/python/knot_resolver/controller/supervisord/plugin/patch_logger.py b/python/knot_resolver/controller/supervisord/plugin/patch_logger.py new file mode 100644 index 00000000..411f232e --- /dev/null +++ b/python/knot_resolver/controller/supervisord/plugin/patch_logger.py @@ -0,0 +1,97 @@ +# type: ignore +# pylint: disable=protected-access + +import os +import sys +import traceback +from typing import Any + +from supervisor.dispatchers import POutputDispatcher +from supervisor.loggers import LevelsByName, StreamHandler, SyslogHandler +from supervisor.supervisord import Supervisor +from typing_extensions import Literal + +FORWARD_LOG_LEVEL = LevelsByName.CRIT # to make sure it's always printed + + +def empty_function(*args, **kwargs): + pass + + +FORWARD_MSG_FORMAT: str = "%(name)s[%(pid)d]%(stream)s: %(data)s" + + +def POutputDispatcher_log(self: POutputDispatcher, data: bytearray): + if data: + # parse the input + if not isinstance(data, bytes): + text = data + else: + try: + text = data.decode("utf-8") + except UnicodeDecodeError: + text = "Undecodable: %r" % data + + # print line by line prepending correct prefix to match the style + config = self.process.config + config.options.logger.handlers = forward_handlers + for line in text.splitlines(): + stream = "" + if self.channel == "stderr": + stream = " (stderr)" + config.options.logger.log( + FORWARD_LOG_LEVEL, FORWARD_MSG_FORMAT, name=config.name, stream=stream, data=line, pid=self.process.pid + ) + config.options.logger.handlers = supervisord_handlers + + +def _create_handler(fmt, level, target: Literal["stdout", "stderr", "syslog"]) -> StreamHandler: + if target == "syslog": + handler = SyslogHandler() + else: + handler = StreamHandler(sys.stdout if target == "stdout" else sys.stderr) + handler.setFormat(fmt) + handler.setLevel(level) + return handler + + +supervisord_handlers = [] +forward_handlers = [] + + +def inject(supervisord: Supervisor, **config: Any) -> Any: # pylint: disable=useless-return + try: + # reconfigure log handlers + supervisord.options.logger.info("reconfiguring log handlers") + supervisord_handlers.append( + _create_handler( + f"%(asctime)s supervisor[{os.getpid()}]: [%(levelname)s] %(message)s\n", + supervisord.options.loglevel, + config["target"], + ) + ) + forward_handlers.append( + _create_handler("%(asctime)s %(message)s\n", supervisord.options.loglevel, config["target"]) + ) + supervisord.options.logger.handlers = supervisord_handlers + + # replace output handler for subprocesses + POutputDispatcher._log = POutputDispatcher_log + + # we forward stdio in all cases, even when logging to syslog. This should prevent the unforturtunate + # case of swallowing an error message leaving the users confused. 
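(Aside, not part of the diff: what a subprocess line forwarded by POutputDispatcher_log above ends up looking like, using FORWARD_MSG_FORMAT with illustrative values.)

    fmt = "%(name)s[%(pid)d]%(stream)s: %(data)s"  # FORWARD_MSG_FORMAT
    print(fmt % {"name": "kresd1", "pid": 4321, "stream": " (stderr)", "data": "some captured line"})
    # -> kresd1[4321] (stderr): some captured line
    # with the syslog target, the "captured stdio output from " prefix is prepended to the format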
To make the forwarded lines obvious + # we just prepend a explanatory string at the beginning of all messages + if config["target"] == "syslog": + global FORWARD_MSG_FORMAT + FORWARD_MSG_FORMAT = "captured stdio output from " + FORWARD_MSG_FORMAT + + # this method is called by supervisord when loading the plugin, + # it should return XML-RPC object, which we don't care about + # That's why why are returning just None + return None + + # if we fail to load the module, print some explanation + # should not happen when run by endusers + except BaseException: + traceback.print_exc() + raise diff --git a/python/knot_resolver/controller/supervisord/plugin/sd_notify.py b/python/knot_resolver/controller/supervisord/plugin/sd_notify.py new file mode 100644 index 00000000..ff32828b --- /dev/null +++ b/python/knot_resolver/controller/supervisord/plugin/sd_notify.py @@ -0,0 +1,227 @@ +# type: ignore +# pylint: disable=protected-access +# pylint: disable=c-extension-no-member +import os +import signal +import time +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar + +from supervisor.events import ProcessStateEvent, ProcessStateStartingEvent, subscribe +from supervisor.medusa.asyncore_25 import compact_traceback +from supervisor.process import Subprocess +from supervisor.states import ProcessStates +from supervisor.supervisord import Supervisor + +from knot_resolver.controller.supervisord.plugin import notify + +starting_processes: List[Subprocess] = [] + + +def is_type_notify(proc: Subprocess) -> bool: + return proc.config.environment is not None and proc.config.environment.get("X-SUPERVISORD-TYPE", None) == "notify" + + +class NotifySocketDispatcher: + """ + See supervisor.dispatcher + """ + + def __init__(self, supervisor: Supervisor, fd: int): + self._supervisor = supervisor + self.fd = fd + self.closed = False # True if close() has been called + + def __repr__(self): + return f"<{self.__class__.__name__} with fd={self.fd}>" + + def readable(self): + return True + + def writable(self): + return False + + def handle_read_event(self): + logger: Any = self._supervisor.options.logger + + res: Optional[Tuple[int, bytes]] = notify.read_message(self.fd) + if res is None: + return None # there was some junk + pid, data = res + + # pylint: disable=undefined-loop-variable + for proc in starting_processes: + if proc.pid == pid: + break + else: + logger.warn(f"ignoring ready notification from unregistered PID={pid}") + return None + + if data.startswith(b"READY=1"): + # handle case, when some process is really ready + + if is_type_notify(proc): + proc._assertInState(ProcessStates.STARTING) + proc.change_state(ProcessStates.RUNNING) + logger.info( + f"success: {proc.config.name} entered RUNNING state, process sent notification via $NOTIFY_SOCKET" + ) + else: + logger.warn(f"ignoring READY notification from {proc.config.name}, which is not configured to send it") + + elif data.startswith(b"STOPPING=1"): + # just accept the message, filter unwanted notifications and do nothing else + + if is_type_notify(proc): + logger.info( + f"success: {proc.config.name} entered STOPPING state, process sent notification via $NOTIFY_SOCKET" + ) + else: + logger.warn( + f"ignoring STOPPING notification from {proc.config.name}, which is not configured to send it" + ) + + else: + # handle case, when we got something unexpected + logger.warn(f"ignoring unrecognized data on $NOTIFY_SOCKET sent from PID={pid}, data='{data!r}'") + return None + + def handle_write_event(self): + raise 
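(Aside, not part of the diff: the C module above and the dispatcher below form the receiving side of a systemd-style readiness protocol. The sending side, performed by a process started with X-SUPERVISORD-TYPE=notify, is roughly the following stdlib-only sketch; the socket name matches CONTROL_SOCKET_NAME.)

    import os
    import socket

    addr = os.environ["NOTIFY_SOCKET"]  # "@knot-resolver-control-socket", injected by the plugin
    if addr.startswith("@"):
        addr = "\0" + addr[1:]          # abstract-namespace unix socket
    sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
    sock.sendto(b"READY=1", addr)       # sender credentials are attached by the kernel (SO_PASSCRED)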
ValueError("this dispatcher is not writable") + + def handle_error(self): + _nil, t, v, tbinfo = compact_traceback() + + self._supervisor.options.logger.error( + f"uncaptured python exception, closing notify socket {repr(self)} ({t}:{v} {tbinfo})" + ) + self.close() + + def close(self): + if not self.closed: + os.close(self.fd) + self.closed = True + + def flush(self): + pass + + +def keep_track_of_starting_processes(event: ProcessStateEvent) -> None: + global starting_processes + + proc: Subprocess = event.process + + if isinstance(event, ProcessStateStartingEvent): + # process is starting + # if proc not in starting_processes: + starting_processes.append(proc) + + else: + # not starting + starting_processes = [p for p in starting_processes if p.pid is not proc.pid] + + +notify_dispatcher: Optional[NotifySocketDispatcher] = None + + +def process_transition(slf: Subprocess) -> None: + if not is_type_notify(slf): + return slf + + # modified version of upstream process transition code + if slf.state == ProcessStates.STARTING: + if time.time() - slf.laststart > slf.config.startsecs: + # STARTING -> STOPPING if the process has not sent ready notification + # within proc.config.startsecs + slf.config.options.logger.warn( + f"process '{slf.config.name}' did not send ready notification within {slf.config.startsecs} secs, killing" + ) + slf.kill(signal.SIGKILL) + slf.x_notifykilled = True # used in finish() function to set to FATAL state + slf.laststart = time.time() + 1 # prevent immediate state transition to RUNNING from happening + + # return self for chaining + return slf + + +def subprocess_finish_tail(slf, pid, sts) -> Tuple[Any, Any, Any]: + if getattr(slf, "x_notifykilled", False): + # we want FATAL, not STOPPED state after timeout waiting for startup notification + # why? 
because it's likely not gonna help to try starting the process up again if + # it failed so early + slf.change_state(ProcessStates.FATAL) + + # clear the marker value + del slf.x_notifykilled + + # return for chaining + return slf, pid, sts + + +def supervisord_get_process_map(supervisord: Any, mp: Dict[Any, Any]) -> Dict[Any, Any]: + global notify_dispatcher + if notify_dispatcher is None: + notify_dispatcher = NotifySocketDispatcher(supervisord, notify.init_socket()) + supervisord.options.logger.info("notify: injected $NOTIFY_SOCKET into event loop") + + # add our dispatcher to the result + assert notify_dispatcher.fd not in mp + mp[notify_dispatcher.fd] = notify_dispatcher + + return mp + + +def process_spawn_as_child_add_env(slf: Subprocess, *args: Any) -> Tuple[Any, ...]: + if is_type_notify(slf): + slf.config.environment["NOTIFY_SOCKET"] = "@knot-resolver-control-socket" + return (slf, *args) + + +T = TypeVar("T") +U = TypeVar("U") + + +def chain(first: Callable[..., U], second: Callable[[U], T]) -> Callable[..., T]: + def wrapper(*args: Any, **kwargs: Any) -> T: + res = first(*args, **kwargs) + if isinstance(res, tuple): + return second(*res) + else: + return second(res) + + return wrapper + + +def append(first: Callable[..., T], second: Callable[..., None]) -> Callable[..., T]: + def wrapper(*args: Any, **kwargs: Any) -> T: + res = first(*args, **kwargs) + second(*args, **kwargs) + return res + + return wrapper + + +def monkeypatch(supervisord: Supervisor) -> None: + """Inject ourselves into supervisord code""" + + # append notify socket handler to event loopo + supervisord.get_process_map = chain(supervisord.get_process_map, partial(supervisord_get_process_map, supervisord)) + + # prepend timeout handler to transition method + Subprocess.transition = chain(process_transition, Subprocess.transition) + Subprocess.finish = append(Subprocess.finish, subprocess_finish_tail) + + # add environment variable $NOTIFY_SOCKET to starting processes + Subprocess._spawn_as_child = chain(process_spawn_as_child_add_env, Subprocess._spawn_as_child) + + # keep references to starting subprocesses + subscribe(ProcessStateEvent, keep_track_of_starting_processes) + + +def inject(supervisord: Supervisor, **_config: Any) -> Any: # pylint: disable=useless-return + monkeypatch(supervisord) + + # this method is called by supervisord when loading the plugin, + # it should return XML-RPC object, which we don't care about + # That's why why are returning just None + return None diff --git a/python/knot_resolver/controller/supervisord/supervisord.conf.j2 b/python/knot_resolver/controller/supervisord/supervisord.conf.j2 new file mode 100644 index 00000000..4179d522 --- /dev/null +++ b/python/knot_resolver/controller/supervisord/supervisord.conf.j2 @@ -0,0 +1,93 @@ +[supervisord] +pidfile = {{ config.pid_file }} +directory = {{ config.workdir }} +nodaemon = true + +{# disable initial logging until patch_logger.py takes over #} +logfile = /dev/null +logfile_maxbytes = 0 +silent = true + +{# config for patch_logger.py #} +loglevel = {{ config.loglevel }} +{# there are more options in the plugin section #} + +[unix_http_server] +file = {{ config.unix_http_server }} + +[supervisorctl] +serverurl = unix://{{ config.unix_http_server }} + +{# Extensions to changing the supervisord behavior #} +[rpcinterface:patch_logger] +supervisor.rpcinterface_factory = knot_resolver.controller.supervisord.plugin.patch_logger:inject +target = {{ config.target }} + +[rpcinterface:manager_integration] +supervisor.rpcinterface_factory = 
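(Aside, not part of the diff: the chain()/append() helpers above compose the monkey-patched supervisord methods; their semantics in isolation.)

    from knot_resolver.controller.supervisord.plugin.sd_notify import append, chain

    def parse(raw: str) -> int:
        return int(raw)

    def double(value: int) -> int:
        return value * 2

    parse_then_double = chain(parse, double)
    print(parse_then_double("21"))  # 42 -- chain() feeds parse()'s result into double()

    def log_call(raw: str) -> None:
        print(f"called with {raw!r}")

    parse_and_log = append(parse, log_call)
    print(parse_and_log("7"))       # prints the log line, then returns parse()'s result: 7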
knot_resolver.controller.supervisord.plugin.manager_integration:inject + +[rpcinterface:sd_notify] +supervisor.rpcinterface_factory = knot_resolver.controller.supervisord.plugin.sd_notify:inject + +{# Extensions for actual API control #} +[rpcinterface:supervisor] +supervisor.rpcinterface_factory = supervisor.rpcinterface:make_main_rpcinterface + +[rpcinterface:fast] +supervisor.rpcinterface_factory = knot_resolver.controller.supervisord.plugin.fast_rpcinterface:make_main_rpcinterface + +[program:manager] +redirect_stderr=false +directory={{ manager.workdir }} +command={{ manager.command }} +stopsignal=SIGINT +killasgroup=true +autorestart=true +autostart=true +{# Note that during startup, + manager will signal being ready only after sequential startup of all kresd workers, + i.e. it might take lots of time currently, if the user configured very large rulesets (e.g. huge RPZ). + Let's permit it lots of time, assuming that useful work is being done. +#} +startsecs=600 +environment={{ manager.environment }},KRES_SUPRESS_LOG_PREFIX=true +stdout_logfile=NONE +stderr_logfile=NONE + +[program:kresd] +process_name=%(program_name)s%(process_num)d +numprocs={{ kresd.max_procs }} +directory={{ kresd.workdir }} +command={{ kresd.command }} +autostart=false +autorestart=true +stopsignal=TERM +killasgroup=true +startsecs=60 +environment={{ kresd.environment }} +stdout_logfile=NONE +stderr_logfile=NONE + +[program:policy-loader] +directory={{ loader.workdir }} +command={{ loader.command }} +autostart=false +stopsignal=TERM +killasgroup=true +startsecs=300 +environment={{ loader.environment }} +stdout_logfile=NONE +stderr_logfile=NONE + +[program:cache-gc] +redirect_stderr=false +directory={{ gc.workdir }} +command={{ gc.command }} +autostart=false +autorestart=true +stopsignal=TERM +killasgroup=true +startsecs=0 +environment={{ gc.environment }} +stdout_logfile=NONE +stderr_logfile=NONE diff --git a/python/knot_resolver/datamodel/__init__.py b/python/knot_resolver/datamodel/__init__.py new file mode 100644 index 00000000..a0174acc --- /dev/null +++ b/python/knot_resolver/datamodel/__init__.py @@ -0,0 +1,3 @@ +from .config_schema import KresConfig + +__all__ = ["KresConfig"] diff --git a/python/knot_resolver/datamodel/cache_schema.py b/python/knot_resolver/datamodel/cache_schema.py new file mode 100644 index 00000000..eca36bf2 --- /dev/null +++ b/python/knot_resolver/datamodel/cache_schema.py @@ -0,0 +1,139 @@ +from typing import List, Optional, Union + +from typing_extensions import Literal + +from knot_resolver.datamodel.templates import template_from_str +from knot_resolver.datamodel.types import ( + DNSRecordTypeEnum, + DomainName, + EscapedStr, + IntNonNegative, + IntPositive, + Percent, + ReadableFile, + SizeUnit, + TimeUnit, + WritableDir, +) +from knot_resolver.utils.modeling import ConfigSchema +from knot_resolver.utils.modeling.base_schema import lazy_default + +_CACHE_CLEAR_TEMPLATE = template_from_str( + "{% from 'macros/cache_macros.lua.j2' import cache_clear %} {{ cache_clear(params) }}" +) + + +class CacheClearRPCSchema(ConfigSchema): + name: Optional[DomainName] = None + exact_name: bool = False + rr_type: Optional[DNSRecordTypeEnum] = None + chunk_size: IntPositive = IntPositive(100) + + def _validate(self) -> None: + if self.rr_type and not self.exact_name: + raise ValueError("'rr-type' is only supported with 'exact-name: true'") + + def render_lua(self) -> str: + return _CACHE_CLEAR_TEMPLATE.render(params=self) # pyright: reportUnknownMemberType=false + + +class 
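(Aside, not part of the diff: a sketch of exercising the cache-clear schema above on its own. It assumes ConfigSchema instances can be built directly from a plain dict with hyphenated keys, as the defaults elsewhere in the datamodel suggest.)

    from knot_resolver.datamodel.cache_schema import CacheClearRPCSchema

    # rr-type is only accepted together with exact-name: true (enforced by _validate() above)
    request = CacheClearRPCSchema({"name": "example.com.", "exact-name": True, "rr-type": "AAAA"})
    print(request.render_lua())  # Lua call rendered from macros/cache_macros.lua.j2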
PrefillSchema(ConfigSchema): + """ + Prefill the cache periodically by importing zone data obtained over HTTP. + + --- + origin: Origin for the imported data. Cache prefilling is only supported for the root zone ('.'). + url: URL of the zone data to be imported. + refresh_interval: Time interval between consecutive refreshes of the imported zone data. + ca_file: Path to the file containing a CA certificate bundle that is used to authenticate the HTTPS connection. + """ + + origin: DomainName + url: EscapedStr + refresh_interval: TimeUnit = TimeUnit("1d") + ca_file: Optional[ReadableFile] = None + + def _validate(self) -> None: + if str(self.origin) != ".": + raise ValueError("cache prefilling is not yet supported for non-root zones") + + +class GarbageCollectorSchema(ConfigSchema): + """ + Configuration options of the cache garbage collector (kres-cache-gc). + + --- + interval: Time interval how often the garbage collector will be run. + threshold: Cache usage in percent that triggers the garbage collector. + release: Percent of used cache to be freed by the garbage collector. + temp_keys_space: Maximum amount of temporary memory for copied keys (0 = unlimited). + rw_deletes: Maximum number of deleted records per read-write transaction (0 = unlimited). + rw_reads: Maximum number of readed records per read-write transaction (0 = unlimited). + rw_duration: Maximum duration of read-write transaction (0 = unlimited). + rw_delay: Wait time between two read-write transactions. + dry_run: Run the garbage collector in dry-run mode. + """ + + interval: TimeUnit = TimeUnit("1s") + threshold: Percent = Percent(80) + release: Percent = Percent(10) + temp_keys_space: SizeUnit = SizeUnit("0M") + rw_deletes: IntNonNegative = IntNonNegative(100) + rw_reads: IntNonNegative = IntNonNegative(200) + rw_duration: TimeUnit = TimeUnit("0us") + rw_delay: TimeUnit = TimeUnit("0us") + dry_run: bool = False + + +class PredictionSchema(ConfigSchema): + """ + Helps keep the cache hot by prefetching expiring records and learning usage patterns and repetitive queries. + + --- + window: Sampling window length. + period: Number of windows that can be kept in memory. + """ + + window: TimeUnit = TimeUnit("15m") + period: IntPositive = IntPositive(24) + + +class PrefetchSchema(ConfigSchema): + """ + These options help keep the cache hot by prefetching expiring records or learning usage patterns and repetitive queries. + --- + expiring: Prefetch expiring records. + prediction: Prefetch record by predicting based on usage patterns and repetitive queries. + """ + + expiring: bool = False + prediction: Optional[PredictionSchema] = None + + +class CacheSchema(ConfigSchema): + """ + DNS resolver cache configuration. + + --- + storage: Cache storage of the DNS resolver. + size_max: Maximum size of the cache. + garbage_collector: Use the garbage collector (kres-cache-gc) to periodically clear cache. + ttl_min: Minimum time-to-live for the cache entries. + ttl_max: Maximum time-to-live for the cache entries. + ns_timeout: Time interval for which a nameserver address will be ignored after determining that it does not return (useful) answers. + prefill: Prefill the cache periodically by importing zone data obtained over HTTP. + prefetch: These options help keep the cache hot by prefetching expiring records or learning usage patterns and repetitive queries. 
+ """ + + storage: WritableDir = lazy_default(WritableDir, "/var/cache/knot-resolver") + size_max: SizeUnit = SizeUnit("100M") + garbage_collector: Union[GarbageCollectorSchema, Literal[False]] = GarbageCollectorSchema() + ttl_min: TimeUnit = TimeUnit("5s") + ttl_max: TimeUnit = TimeUnit("1d") + ns_timeout: TimeUnit = TimeUnit("1000ms") + prefill: Optional[List[PrefillSchema]] = None + prefetch: PrefetchSchema = PrefetchSchema() + + def _validate(self): + if self.ttl_min.seconds() >= self.ttl_max.seconds(): + raise ValueError("'ttl-max' must be larger then 'ttl-min'") diff --git a/python/knot_resolver/datamodel/config_schema.py b/python/knot_resolver/datamodel/config_schema.py new file mode 100644 index 00000000..1ee300d8 --- /dev/null +++ b/python/knot_resolver/datamodel/config_schema.py @@ -0,0 +1,242 @@ +import logging +import os +import socket +from typing import Any, Dict, List, Optional, Tuple, Union + +from typing_extensions import Literal + +from knot_resolver.manager.constants import MAX_WORKERS +from knot_resolver.datamodel.cache_schema import CacheSchema +from knot_resolver.datamodel.dns64_schema import Dns64Schema +from knot_resolver.datamodel.dnssec_schema import DnssecSchema +from knot_resolver.datamodel.forward_schema import ForwardSchema +from knot_resolver.datamodel.local_data_schema import LocalDataSchema, RPZSchema, RuleSchema +from knot_resolver.datamodel.logging_schema import LoggingSchema +from knot_resolver.datamodel.lua_schema import LuaSchema +from knot_resolver.datamodel.management_schema import ManagementSchema +from knot_resolver.datamodel.monitoring_schema import MonitoringSchema +from knot_resolver.datamodel.network_schema import NetworkSchema +from knot_resolver.datamodel.options_schema import OptionsSchema +from knot_resolver.datamodel.templates import POLICY_CONFIG_TEMPLATE, WORKER_CONFIG_TEMPLATE +from knot_resolver.datamodel.types import EscapedStr, IntPositive, WritableDir +from knot_resolver.datamodel.view_schema import ViewSchema +from knot_resolver.datamodel.webmgmt_schema import WebmgmtSchema +from knot_resolver.utils.modeling import ConfigSchema +from knot_resolver.utils.modeling.base_schema import lazy_default +from knot_resolver.utils.modeling.exceptions import AggregateDataValidationError, DataValidationError + +_DEFAULT_RUNDIR = "/var/run/knot-resolver" + +DEFAULT_MANAGER_API_SOCK = _DEFAULT_RUNDIR + "/manager.sock" + +logger = logging.getLogger(__name__) + + +def _cpu_count() -> Optional[int]: + try: + return len(os.sched_getaffinity(0)) + except (NotImplementedError, AttributeError): + logger.warning("The number of usable CPUs could not be determined using 'os.sched_getaffinity()'.") + cpus = os.cpu_count() + if cpus is None: + logger.warning("The number of usable CPUs could not be determined using 'os.cpu_count()'.") + return cpus + + +def _default_max_worker_count() -> int: + c = _cpu_count() + if c: + return c * 10 + return MAX_WORKERS + + +def _get_views_tags(views: List[ViewSchema]) -> List[str]: + tags = [] + for view in views: + if view.tags: + tags += [str(tag) for tag in view.tags if tag not in tags] + return tags + + +def _check_local_data_tags( + views_tags: List[str], rules_or_rpz: Union[List[RuleSchema], List[RPZSchema]] +) -> Tuple[List[str], List[DataValidationError]]: + tags = [] + errs = [] + + i = 0 + for rule in rules_or_rpz: + tags_not_in = [] + if rule.tags: + for tag in rule.tags: + tag_str = str(tag) + if tag_str not in tags: + tags.append(tag_str) + if tag_str not in views_tags: + tags_not_in.append(tag_str) + if 
len(tags_not_in) > 0: + errs.append( + DataValidationError( + f"some tags {tags_not_in} not found in '/views' tags", f"/local-data/rules[{i}]/tags" + ) + ) + i += 1 + return tags, errs + + +class KresConfig(ConfigSchema): + class Raw(ConfigSchema): + """ + Knot Resolver declarative configuration. + + --- + version: Version of the configuration schema. By default it is the latest supported by the resolver, but couple of versions back are be supported as well. + nsid: Name Server Identifier (RFC 5001) which allows DNS clients to request resolver to send back its NSID along with the reply to a DNS request. + hostname: Internal DNS resolver hostname. Default is machine hostname. + rundir: Directory where the resolver can create files and which will be it's cwd. + workers: The number of running kresd (Knot Resolver daemon) workers. If set to 'auto', it is equal to number of CPUs available. + max_workers: The maximum number of workers allowed. Cannot be changed in runtime. + management: Configuration of management HTTP API. + webmgmt: Configuration of legacy web management endpoint. + options: Fine-tuning global parameters of DNS resolver operation. + network: Network connections and protocols configuration. + views: List of views and its configuration. + local_data: Local data for forward records (A/AAAA) and reverse records (PTR). + forward: List of Forward Zones and its configuration. + cache: DNS resolver cache configuration. + dnssec: Disable DNSSEC, enable with defaults or set new configuration. + dns64: Disable DNS64 (RFC 6147), enable with defaults or set new configuration. + logging: Logging and debugging configuration. + monitoring: Metrics exposisition configuration (Prometheus, Graphite) + lua: Custom Lua configuration. + """ + + version: int = 1 + nsid: Optional[EscapedStr] = None + hostname: Optional[EscapedStr] = None + rundir: WritableDir = lazy_default(WritableDir, _DEFAULT_RUNDIR) + workers: Union[Literal["auto"], IntPositive] = IntPositive(1) + max_workers: IntPositive = IntPositive(_default_max_worker_count()) + management: ManagementSchema = lazy_default(ManagementSchema, {"unix-socket": DEFAULT_MANAGER_API_SOCK}) + webmgmt: Optional[WebmgmtSchema] = None + options: OptionsSchema = OptionsSchema() + network: NetworkSchema = NetworkSchema() + views: Optional[List[ViewSchema]] = None + local_data: LocalDataSchema = LocalDataSchema() + forward: Optional[List[ForwardSchema]] = None + cache: CacheSchema = lazy_default(CacheSchema, {}) + dnssec: Union[bool, DnssecSchema] = True + dns64: Union[bool, Dns64Schema] = False + logging: LoggingSchema = LoggingSchema() + monitoring: MonitoringSchema = MonitoringSchema() + lua: LuaSchema = LuaSchema() + + _LAYER = Raw + + nsid: Optional[EscapedStr] + hostname: EscapedStr + rundir: WritableDir + workers: IntPositive + max_workers: IntPositive + management: ManagementSchema + webmgmt: Optional[WebmgmtSchema] + options: OptionsSchema + network: NetworkSchema + views: Optional[List[ViewSchema]] + local_data: LocalDataSchema + forward: Optional[List[ForwardSchema]] + cache: CacheSchema + dnssec: Union[Literal[False], DnssecSchema] + dns64: Union[Literal[False], Dns64Schema] + logging: LoggingSchema + monitoring: MonitoringSchema + lua: LuaSchema + + def _hostname(self, obj: Raw) -> Any: + if obj.hostname is None: + return socket.gethostname() + return obj.hostname + + def _workers(self, obj: Raw) -> Any: + if obj.workers == "auto": + count = _cpu_count() + if count: + return IntPositive(count) + raise ValueError( + "The number of available CPUs 
to automatically set the number of running 'kresd' workers could not be determined." + "The number of workers can be configured manually in 'workers' option." + ) + return obj.workers + + def _dnssec(self, obj: Raw) -> Any: + if obj.dnssec is True: + return DnssecSchema() + return obj.dnssec + + def _dns64(self, obj: Raw) -> Any: + if obj.dns64 is True: + return Dns64Schema() + return obj.dns64 + + def _validate(self) -> None: + # enforce max-workers config + if int(self.workers) > int(self.max_workers): + raise ValueError(f"can't run with more workers then the configured maximum {self.max_workers}") + + # sanity check + cpu_count = _cpu_count() + if cpu_count and int(self.workers) > 10 * cpu_count: + raise ValueError( + "refusing to run with more then 10 workers per cpu core, the system wouldn't behave nicely" + ) + + # get all tags from views + views_tags = [] + if self.views: + views_tags = _get_views_tags(self.views) + + # get local-data tags and check its existence in views + errs = [] + local_data_tags = [] + if self.local_data.rules: + rules_tags, rules_errs = _check_local_data_tags(views_tags, self.local_data.rules) + errs += rules_errs + local_data_tags += rules_tags + if self.local_data.rpz: + rpz_tags, rpz_errs = _check_local_data_tags(views_tags, self.local_data.rpz) + errs += rpz_errs + local_data_tags += rpz_tags + + # look for unused tags in /views + unused_tags = views_tags.copy() + for tag in local_data_tags: + if tag in unused_tags: + unused_tags.remove(tag) + if len(unused_tags) > 1: + errs.append(DataValidationError(f"unused tags {unused_tags} found", "/views")) + + # raise all validation errors + if len(errs) == 1: + raise errs[0] + elif len(errs) > 1: + raise AggregateDataValidationError("/", errs) + + def render_lua(self) -> str: + # FIXME the `cwd` argument is used only for configuring control socket path + # it should be removed and relative path used instead as soon as issue + # https://gitlab.nic.cz/knot/knot-resolver/-/issues/720 is fixed + return WORKER_CONFIG_TEMPLATE.render(cfg=self, cwd=os.getcwd()) + + def render_lua_policy(self) -> str: + return POLICY_CONFIG_TEMPLATE.render(cfg=self, cwd=os.getcwd()) + + +def get_rundir_without_validation(data: Dict[str, Any]) -> WritableDir: + """ + Without fully parsing, try to get a rundir from a raw config data, otherwise use default. + Attempts a dir validation to produce a good error message. + + Used for initial manager startup. + """ + + return WritableDir(data["rundir"] if "rundir" in data else _DEFAULT_RUNDIR, object_path="/rundir") diff --git a/python/knot_resolver/datamodel/design-notes.yml b/python/knot_resolver/datamodel/design-notes.yml new file mode 100644 index 00000000..e4424bc8 --- /dev/null +++ b/python/knot_resolver/datamodel/design-notes.yml @@ -0,0 +1,237 @@ +###### Working notes about configuration schema + + +## TODO nit: nest one level deeper inside `dnssec`, probably +dnssec: + keep-removed: 0 + refresh-time: 10s + hold-down-time: 30d + +## TODO nit: I don't like this name, at least not for the experimental thing we have there +network: + tls: + auto_discovery: boolean + +#### General questions +Plurals: do we name attributes in plural if they're a list; + some of them even allow a non-list if using a single element. + + +#### New-policy brainstorming + +dnssec: + # Convert to key: style instead of list? 
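(Aside, not part of the diff: a sketch of the top-level flow that KresConfig above enables. It assumes the schema can be instantiated directly from a plain dict with hyphenated keys; path-like fields such as rundir are validated against the local filesystem, so the values may need adjusting.)

    from knot_resolver import KresConfig

    config = KresConfig({"rundir": "/tmp", "workers": 2, "cache": {"size-max": "200M"}})
    print(int(config.workers), int(config.max_workers))
    lua_for_kresd = config.render_lua()          # worker configuration handed to each kresd
    lua_for_loader = config.render_lua_policy()  # configuration handed to the policy-loader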
+ # - easier to handle in API/CLI (which might be a common action on names with broken DNSSEC) + # - allows to supply a value - stamp for expiration of that NTA + # (absolute time, but I can imagine API/CLI converting from duration when executed) + # - syntax isn't really more difficult, mainly it forces one entry per line (seems OK) + negative-trust-anchors: + example.org: + my.example.net: + + +view: + # When a client request arrives, based on the `view` class of rules we may either + # decide for a direct answer or for marking the request with a set of tags. + # The concepts of matching and actions are a very good fit for this, + # and that matches our old policy approach. Matching here should avoid QNAME+QTYPE; + # instead it's e.g. suitable for access control. + # RPZ files also support rules that fall into this `view` class. + # + # Selecting a single rule: the most specific client-IP prefix + # that also matches additional conditions. + - subnet: [ 0.0.0.0/0, ::/0 ] + answer: refused + # some might prefer `allow: refused` ? + # Also, RCODEs are customary in CAPITALS though maybe not in configs. + + - subnet: [ 10.0.0.0/8, 192.168.0.0/16 ] + # Adding `tags` implies allowing the query. + tags: [ t1, t2, t3 ] # theoretically we could use space-separated string + options: # only some of the global options can be overridden in view + minimize: true + dns64: true + rate-limit: # future option, probably (optionally?) structured + # LATER: rulesets are a relatively unclear feature for now. + # Their main point is to allow prioritization and avoid + # intermixing rules that come from different sources. + # Also some properties might be specifyable per ruleset. + ruleset: tt + + - subnet: [ 10.0.10.0/24 ] # maybe allow a single value instead of a list? + # LATER: special addresses? + # - for kresd-internal requests + # - shorthands for all private IPv4 and/or IPv6; + # though yaml's repeated nodes could mostly cover that + # or just copy&paste from docs + answer: allow + +# Or perhaps a more complex approach? Probably not. +# We might have multiple conditions at once and multiple actions at once, +# but I don't expect these to be common, so the complication is probably not worth it. +# An advantage would be that the separation of the two parts would be more visible. +view: + - match: + subnet: [ 10.0.0.0/8, 192.168.0.0/16 ] + do: + tags: [ t1, t2, t3 ] + options: # ... + + +local-data: # TODO: name + #FIXME: tags - allow assigning them to (groups of) addresses/records. + + addresses: # automatically adds PTR records and NODATA (LATER: overridable NODATA?) + foo.bar: [ 127.0.0.1, ::1 ] + my.pc.corp: 192.168.12.95 + addresses-files: # files in /etc/hosts format (and semantics like `addresses`) + - /etc/hosts + + # Zonefile format seems quite handy here. Details: + # - probably use `local-data.ttl` from model as the default + # - and . root to avoid confusion if someone misses a final dot. + records: | + example.net. TXT "foo bar" + A 192.168.2.3 + A 192.168.2.4 + local.example.org AAAA ::1 + + subtrees: + nodata: true # impl ATM: defaults to false, set (only) for each rule/name separately + # impl: options like `ttl` and `nodata` might make sense to be settable (only?) per ruleset + + subtrees: # TODO: perhaps just allow in the -tagged style, if we can't avoid lists anyway? + - type: empty + roots: [ sub2.example.org ] # TODO: name it the same as for forwarding + tags: [ t2 ] + - type: nxdomain + # Will we need to support multiple file formats in future and choose here? 
+ roots-file: /path/to/file.txt + - type: empty + roots-url: https://example.org/blocklist.txt + refresh: 1d + # Is it a separate ruleset? Optionally? Persistence? + # (probably the same questions for local files as well) + + - type: redirect + roots: [ sub4.example.org ] + addresses: [ 127.0.0.1, ::1 ] + +local-data-tagged: # TODO: name (view?); and even structure seems unclear. + # TODO: allow only one "type" per list entry? (addresses / addresses-files / subtrees / ...) + - tags: [ t1, t2 ] + addresses: #... otherwise the same as local-data + - tags: [ t2 ] + records: # ... + - tags: [ t3 ] + subtrees: empty + roots: [ sub2.example.org ] + +local-data-tagged: # this avoids lists, so it's relatively easy to amend through API + "t1 t2": # perhaps it's not nice that tags don't form a proper list? + addresses: + foo.bar: [ 127.0.0.1, ::1 ] + t4: + addresses: + foo.bar: [ 127.0.0.1, ::1 ] +local-data: # avoids lists and merges into the untagged `local-data` config subtree + tagged: # (getting quite deep, though) + t1 t2: + addresses: + foo.bar: [ 127.0.0.1, ::1 ] +# or even this ugly thing: +local-data-tagged t1 t2: + addresses: + foo.bar: [ 127.0.0.1, ::1 ] + +forward: # TODO: "name" is from Unbound, but @vcunat would prefer "subtree" or something. + - name: '.' # Root is the default so could be omitted? + servers: [2001:148f:fffe::1, 2001:148f:ffff::1, 185.43.135.1, 193.14.47.1] + # TLS forward, server authenticated using hostname and system-wide CA certificates + # https://www.knot-resolver.cz/documentation/latest/modules-policy.html?highlight=forward#tls-examples + - name: '.' + servers: + - address: [ 192.0.2.1, 192.0.2.2@5353 ] + transport: tls + pin-sha256: Wg== + - address: 2001:DB8::d0c + transport: tls + hostname: res.example.com + ca-file: /etc/knot-resolver/tlsca.crt + options: + # LATER: allow a subset of options here, per sub-tree? + # Though that's not necessarily related to forwarding (e.g. TTL limits), + # especially implementation-wise it probably won't matter. + + +# Too confusing approach, I suppose? Different from usual way of thinking but closer to internal model. +# Down-sides: +# - multiple rules for the same name won't be possible (future, with different tags) +# - loading names from a file won't be possible (or URL, etc.) +rules: + example.org: &fwd_odvr + type: forward + servers: [2001:148f:fffe::1, 2001:148f:ffff::1, 185.43.135.1, 193.14.47.1] + sub2.example.org: + type: empty + tags: [ t3, t5 ] + sub3.example.org: + type: forward-auth + dnssec: no + + +# @amrazek: current valid config + +views: + - subnets: [ 0.0.0.0/0, "::/0" ] + answer: refused + - subnets: [ 0.0.0.0/0, "::/0" ] + tags: [t01, t02, t03] + options: + minimize: true # default + dns64: true # default + - subnets: 10.0.10.0/24 # can be single value + answer: allow + +local-data: + ttl: 1d + nodata: true + addresses: + foo.bar: [ 127.0.0.1, "::1" ] + my.pc.corp: 192.168.12.95 + addresses-files: + - /etc/hosts + records: | + example.net. TXT "foo bar" + A 192.168.2.3 + A 192.168.2.4 + local.example.org AAAA ::1 + subtrees: + - type: empty + roots: [ sub2.example.org ] + tags: [ t2 ] + - type: nxdomain + roots-file: /path/to/file.txt + - type: empty + roots-url: https://example.org/blocklist.txt + refresh: 1d + - type: redirect + roots: [ sub4.example.org ] + addresses: [ 127.0.0.1, "::1" ] + +forward: + - subtree: '.' 
+ servers: + - address: [ 192.0.2.1, 192.0.2.2@5353 ] + transport: tls + pin-sha256: Wg== + - address: 2001:DB8::d0c + transport: tls + hostname: res.example.com + ca-file: /etc/knot-resolver/tlsca.crt + options: + dnssec: true # default + - subtree: 1.168.192.in-addr.arpa + servers: [ 192.0.2.1@5353 ] + options: + dnssec: false # policy.STUB? diff --git a/python/knot_resolver/datamodel/dns64_schema.py b/python/knot_resolver/datamodel/dns64_schema.py new file mode 100644 index 00000000..cc0fa06a --- /dev/null +++ b/python/knot_resolver/datamodel/dns64_schema.py @@ -0,0 +1,19 @@ +from typing import List, Optional + +from knot_resolver.datamodel.types import IPv6Network, IPv6Network96, TimeUnit +from knot_resolver.utils.modeling import ConfigSchema + + +class Dns64Schema(ConfigSchema): + """ + DNS64 (RFC 6147) configuration. + + --- + prefix: IPv6 prefix to be used for synthesizing AAAA records. + rev_ttl: TTL in CNAME generated in the reverse 'ip6.arpa.' subtree. + exclude_subnets: IPv6 subnets that are disallowed in answer. + """ + + prefix: IPv6Network96 = IPv6Network96("64:ff9b::/96") + rev_ttl: Optional[TimeUnit] = None + exclude_subnets: Optional[List[IPv6Network]] = None diff --git a/python/knot_resolver/datamodel/dnssec_schema.py b/python/knot_resolver/datamodel/dnssec_schema.py new file mode 100644 index 00000000..6f51d5eb --- /dev/null +++ b/python/knot_resolver/datamodel/dnssec_schema.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from knot_resolver.datamodel.types import DomainName, EscapedStr, IntNonNegative, ReadableFile, TimeUnit +from knot_resolver.utils.modeling import ConfigSchema + + +class TrustAnchorFileSchema(ConfigSchema): + """ + Trust-anchor zonefile configuration. + + --- + file: Path to the zonefile that stores trust-anchors. + read_only: Blocks zonefile updates according to RFC 5011. + + """ + + file: ReadableFile + read_only: bool = False + + +class DnssecSchema(ConfigSchema): + """ + DNSSEC configuration. + + --- + trust_anchor_sentinel: Allows users of DNSSEC validating resolver to detect which root keys are configured in resolver's chain of trust. (RFC 8509) + trust_anchor_signal_query: Signaling Trust Anchor Knowledge in DNSSEC Using Key Tag Query, according to (RFC 8145#section-5). + time_skew_detection: Detection of difference between local system time and expiration time bounds in DNSSEC signatures for '. NS' records. + keep_removed: How many removed keys should be held in history (and key file) before being purged. + refresh_time: Force trust-anchors to be updated every defined time periodically instead of relying on (RFC 5011) logic and TTLs. Intended only for testing purposes. + hold_down_time: Modify hold-down timer (RFC 5011). Intended only for testing purposes. + trust_anchors: List of trust-anchors in DS/DNSKEY records format. + negative_trust_anchors: List of domain names representing negative trust-anchors. (RFC 7646) + trust_anchors_files: List of zonefiles where trust-anchors are stored. 
+ """ + + trust_anchor_sentinel: bool = True + trust_anchor_signal_query: bool = True + time_skew_detection: bool = True + keep_removed: IntNonNegative = IntNonNegative(0) + refresh_time: Optional[TimeUnit] = None + hold_down_time: TimeUnit = TimeUnit("30d") + trust_anchors: Optional[List[EscapedStr]] = None + negative_trust_anchors: Optional[List[DomainName]] = None + trust_anchors_files: Optional[List[TrustAnchorFileSchema]] = None diff --git a/python/knot_resolver/datamodel/forward_schema.py b/python/knot_resolver/datamodel/forward_schema.py new file mode 100644 index 00000000..96c0f048 --- /dev/null +++ b/python/knot_resolver/datamodel/forward_schema.py @@ -0,0 +1,84 @@ +from typing import Any, List, Optional, Union + +from typing_extensions import Literal + +from knot_resolver.datamodel.types import ( + DomainName, + IPAddressOptionalPort, + ListOrItem, + PinSha256, + ReadableFile, +) +from knot_resolver.utils.modeling import ConfigSchema + + +class ForwardServerSchema(ConfigSchema): + """ + Forward server configuration. + + --- + address: IP address(es) of a forward server. + transport: Transport protocol for a forward server. + pin_sha256: Hash of accepted CA certificate. + hostname: Hostname of the Forward server. + ca_file: Path to CA certificate file. + """ + + address: ListOrItem[IPAddressOptionalPort] + transport: Optional[Literal["tls"]] = None + pin_sha256: Optional[ListOrItem[PinSha256]] = None + hostname: Optional[DomainName] = None + ca_file: Optional[ReadableFile] = None + + def _validate(self) -> None: + if self.pin_sha256 and (self.hostname or self.ca_file): + raise ValueError("'pin-sha256' cannot be configurad together with 'hostname' or 'ca-file'") + + +class ForwardOptionsSchema(ConfigSchema): + """ + Subtree(s) forward options. + + --- + authoritative: The forwarding target is an authoritative server. + dnssec: Enable/disable DNSSEC. + """ + + authoritative: bool = False + dnssec: bool = True + + +class ForwardSchema(ConfigSchema): + """ + Configuration of forward subtree. + + --- + subtree: Subtree(s) to forward. + servers: Forward servers configuration. + options: Subtree(s) forward options. 
+ """ + + subtree: ListOrItem[DomainName] + servers: Union[List[IPAddressOptionalPort], List[ForwardServerSchema]] + options: ForwardOptionsSchema = ForwardOptionsSchema() + + def _validate(self) -> None: + def is_port_custom(servers: List[Any]) -> bool: + for server in servers: + if isinstance(server, IPAddressOptionalPort) and server.port: + return int(server.port) != 53 + elif isinstance(server, ForwardServerSchema): + return is_port_custom(server.address.to_std()) + return False + + def is_transport_tls(servers: List[Any]) -> bool: + for server in servers: + if isinstance(server, ForwardServerSchema): + return server.transport == "tls" + return False + + if self.options.authoritative and is_port_custom(self.servers): + raise ValueError("Forwarding to authoritative servers on a custom port is currently not supported.") + + if self.options.authoritative and is_transport_tls(self.servers): + raise ValueError("Forwarding to authoritative servers using TLS protocol is not supported.") diff --git a/python/knot_resolver/datamodel/globals.py b/python/knot_resolver/datamodel/globals.py new file mode 100644 index 00000000..610323fa --- /dev/null +++ b/python/knot_resolver/datamodel/globals.py @@ -0,0 +1,57 @@ +""" +The parsing and validation of the datamodel is dependent on a global state: +- a file system path used for resolving relative paths + + +Commentary from @vsraier: +========================= + +While this is not ideal, it is the best we can do at the moment. When I created this module, +the datamodel was dependent on the global state implicitely. The validation procedures just read +the current working directory. This module is the first step in removing the global dependency. + +At some point in the future, it might be interesting to add something like a "validation context" +to the modelling tools. It is not technically complicated, but it requires +massive model changes I am not willing to make at the moment. Ideally, when implementing this, +the BaseSchema would turn into an empty class without any logic. Not even a constructor. All logic +would be in the ObjectMapper class. Similar to how Gson works in Java or AutoMapper in C#. +""" + +from pathlib import Path +from typing import Optional + + +class Context: + resolve_root: Optional[Path] + strict_validation: bool + + def __init__(self, resolve_root: Optional[Path], strict_validation: bool = True) -> None: + self.resolve_root = resolve_root + self.strict_validation = strict_validation + + +_global_context: Context = Context(None) + + +def set_global_validation_context(context: Context) -> None: + global _global_context + _global_context = context + + +def reset_global_validation_context() -> None: + global _global_context + _global_context = Context(None) + + +def get_resolve_root() -> Path: + if _global_context.resolve_root is None: + raise RuntimeError( + "Global validation context 'resolve_root' is not set!" + " Before validation, you have to set it using `set_global_validation_context()` function!" 
+ ) + + return _global_context.resolve_root + + +def get_strict_validation() -> bool: + return _global_context.strict_validation diff --git a/python/knot_resolver/datamodel/local_data_schema.py b/python/knot_resolver/datamodel/local_data_schema.py new file mode 100644 index 00000000..2fe5a03a --- /dev/null +++ b/python/knot_resolver/datamodel/local_data_schema.py @@ -0,0 +1,95 @@ +from typing import Dict, List, Optional + +from typing_extensions import Literal + +from knot_resolver.datamodel.types import ( + DomainName, + EscapedStr, + IDPattern, + IPAddress, + ListOrItem, + ReadableFile, + TimeUnit, +) +from knot_resolver.utils.modeling import ConfigSchema + + +class RuleSchema(ConfigSchema): + """ + Local data advanced rule configuration. + + --- + name: Hostname(s). + subtree: Type of subtree. + address: Address(es) to pair with hostname(s). + file: Path to file(s) with hostname and IP address(es) pairs in '/etc/hosts' like format. + records: Direct addition of records in DNS zone file format. + tags: Tags to link with other policy rules. + ttl: Optional, TTL value used for these answers. + nodata: Optional, use NODATA synthesis. NODATA will be synthesised for matching name, but mismatching type(e.g. AAAA query when only A exists). + """ + + name: Optional[ListOrItem[DomainName]] = None + subtree: Optional[Literal["empty", "nxdomain", "redirect"]] = None + address: Optional[ListOrItem[IPAddress]] = None + file: Optional[ListOrItem[ReadableFile]] = None + records: Optional[EscapedStr] = None + tags: Optional[List[IDPattern]] = None + ttl: Optional[TimeUnit] = None + nodata: Optional[bool] = None + + def _validate(self) -> None: + options_sum = sum([bool(self.address), bool(self.subtree), bool(self.file), bool(self.records)]) + if options_sum == 2 and bool(self.address) and self.subtree in {"empty", "redirect"}: + pass # these combinations still make sense + elif options_sum > 1: + raise ValueError("only one of 'address', 'subtree' or 'file' can be configured") + elif options_sum < 1: + raise ValueError("one of 'address', 'subtree', 'file' or 'records' must be configured") + + options_sum2 = sum([bool(self.name), bool(self.file), bool(self.records)]) + if options_sum2 != 1: + raise ValueError("one of 'name', 'file or 'records' must be configured") + + if bool(self.nodata) and bool(self.subtree) and not bool(self.address): + raise ValueError("'nodata' defined but unused with 'subtree'") + + +class RPZSchema(ConfigSchema): + """ + Configuration or Response Policy Zone (RPZ). + + --- + file: Path to the RPZ zone file. + tags: Tags to link with other policy rules. + """ + + file: ReadableFile + tags: Optional[List[IDPattern]] = None + + +class LocalDataSchema(ConfigSchema): + """ + Local data for forward records (A/AAAA) and reverse records (PTR). + + --- + ttl: Default TTL value used for added local data/records. + nodata: Use NODATA synthesis. NODATA will be synthesised for matching name, but mismatching type(e.g. AAAA query when only A exists). + root_fallback_addresses: Direct replace of root hints. + root_fallback_addresses_files: Direct replace of root hints from a zonefile. + addresses: Direct addition of hostname and IP addresses pairs. + addresses_files: Direct addition of hostname and IP addresses pairs from files in '/etc/hosts' like format. + records: Direct addition of records in DNS zone file format. + rules: Local data rules. + rpz: List of Response Policy Zones and its configuration. 
+ """ + + ttl: Optional[TimeUnit] = None + nodata: bool = True + root_fallback_addresses: Optional[Dict[DomainName, ListOrItem[IPAddress]]] = None + root_fallback_addresses_files: Optional[List[ReadableFile]] = None + addresses: Optional[Dict[DomainName, ListOrItem[IPAddress]]] = None + addresses_files: Optional[List[ReadableFile]] = None + records: Optional[EscapedStr] = None + rules: Optional[List[RuleSchema]] = None + rpz: Optional[List[RPZSchema]] = None diff --git a/python/knot_resolver/datamodel/logging_schema.py b/python/knot_resolver/datamodel/logging_schema.py new file mode 100644 index 00000000..e2985dd1 --- /dev/null +++ b/python/knot_resolver/datamodel/logging_schema.py @@ -0,0 +1,153 @@ +import os +from typing import Any, List, Optional, Set, Type, Union, cast + +from typing_extensions import Literal + +from knot_resolver.datamodel.types import TimeUnit, WritableFilePath +from knot_resolver.utils.modeling import ConfigSchema +from knot_resolver.utils.modeling.base_schema import is_obj_type_valid + +try: + # On Debian 10, the typing_extensions library does not contain TypeAlias. + # We don't strictly need the import for anything except for type checking, + # so this try-except makes sure it works either way. + from typing_extensions import TypeAlias # pylint: disable=ungrouped-imports +except ImportError: + TypeAlias = None # type: ignore + + +LogLevelEnum = Literal["crit", "err", "warning", "notice", "info", "debug"] +LogTargetEnum = Literal["syslog", "stderr", "stdout"] +LogGroupsEnum: TypeAlias = Literal[ + "manager", + "supervisord", + "cache-gc", + ## Now the LOG_GRP_*_TAG defines, exactly from ../../../lib/log.h + "system", + "cache", + "io", + "net", + "ta", + "tasent", + "tasign", + "taupd", + "tls", + "gnutls", + "tls_cl", + "xdp", + "doh", + "dnssec", + "hint", + "plan", + "iterat", + "valdtr", + "resolv", + "select", + "zoncut", + "cookie", + "statis", + "rebind", + "worker", + "policy", + "daf", + "timejm", + "timesk", + "graphi", + "prefil", + "primin", + "srvstl", + "wtchdg", + "nsid", + "dnstap", + "tests", + "dotaut", + "http", + "contrl", + "module", + "devel", + "renum", + "exterr", + "rules", + "prlayr", + # "reqdbg",... (non-displayed section of the enum) +] + + +class DnstapSchema(ConfigSchema): + """ + Logging DNS queries and responses to a unix socket. + + --- + unix_socket: Path to unix domain socket where dnstap messages will be sent. + log_queries: Log queries from downstream in wire format. + log_responses: Log responses to downstream in wire format. + log_tcp_rtt: Log TCP RTT (Round-trip time). + """ + + unix_socket: WritableFilePath + log_queries: bool = True + log_responses: bool = True + log_tcp_rtt: bool = True + + +class DebuggingSchema(ConfigSchema): + """ + Advanced debugging parameters for kresd (Knot Resolver daemon). + + --- + assertion_abort: Allow the process to be aborted in case it encounters a failed assertion. + assertion_fork: Fork and abord child kresd process to obtain a coredump, while the parent process recovers and keeps running. + """ + + assertion_abort: bool = False + assertion_fork: TimeUnit = TimeUnit("5m") + + +class LoggingSchema(ConfigSchema): + class Raw(ConfigSchema): + """ + Logging and debugging configuration. + + --- + level: Global logging level. + target: Global logging stream target. "from-env" uses $KRES_LOGGING_TARGET and defaults to "stdout". + groups: List of groups for which 'debug' logging level is set. + dnssec_bogus: Logging a message for each DNSSEC validation failure. 
+ dnstap: Logging DNS requests and responses to a unix socket. + debugging: Advanced debugging parameters for kresd (Knot Resolver daemon). + """ + + level: LogLevelEnum = "notice" + target: Union[LogTargetEnum, Literal["from-env"]] = "from-env" + groups: Optional[List[LogGroupsEnum]] = None + dnssec_bogus: bool = False + dnstap: Union[Literal[False], DnstapSchema] = False + debugging: DebuggingSchema = DebuggingSchema() + + _LAYER = Raw + + level: LogLevelEnum + target: LogTargetEnum + groups: Optional[List[LogGroupsEnum]] + dnssec_bogus: bool + dnstap: Union[Literal[False], DnstapSchema] + debugging: DebuggingSchema + + def _target(self, raw: Raw) -> LogTargetEnum: + if raw.target == "from-env": + target = os.environ.get("KRES_LOGGING_TARGET") or "stdout" + if not is_obj_type_valid(target, cast(Type[Any], LogTargetEnum)): + raise ValueError(f"logging target '{target}' read from $KRES_LOGGING_TARGET is invalid") + return cast(LogTargetEnum, target) + else: + return raw.target + + def _validate(self): + if self.groups is None: + return + + checked: Set[str] = set() + for i, g in enumerate(self.groups): + if g in checked: + raise ValueError(f"duplicate logging group '{g}' on index {i}") + checked.add(g) diff --git a/python/knot_resolver/datamodel/lua_schema.py b/python/knot_resolver/datamodel/lua_schema.py new file mode 100644 index 00000000..56e8ee09 --- /dev/null +++ b/python/knot_resolver/datamodel/lua_schema.py @@ -0,0 +1,23 @@ +from typing import Optional + +from knot_resolver.datamodel.types import ReadableFile +from knot_resolver.utils.modeling import ConfigSchema + + +class LuaSchema(ConfigSchema): + """ + Custom Lua configuration. + + --- + script_only: Ignore declarative configuration and use only Lua script or file defined in this section. + script: Custom Lua configuration script. + script_file: Path to file that contains Lua configuration script. + """ + + script_only: bool = False + script: Optional[str] = None + script_file: Optional[ReadableFile] = None + + def _validate(self) -> None: + if self.script and self.script_file: + raise ValueError("'lua.script' and 'lua.script-file' are both defined, only one can be used") diff --git a/python/knot_resolver/datamodel/management_schema.py b/python/knot_resolver/datamodel/management_schema.py new file mode 100644 index 00000000..b338c32a --- /dev/null +++ b/python/knot_resolver/datamodel/management_schema.py @@ -0,0 +1,21 @@ +from typing import Optional + +from knot_resolver.datamodel.types import WritableFilePath, IPAddressPort +from knot_resolver.utils.modeling import ConfigSchema + + +class ManagementSchema(ConfigSchema): + """ + Configuration of management HTTP API. + + --- + unix_socket: Path to unix domain socket to listen to. + interface: IP address and port number to listen to. 
+ """ + + unix_socket: Optional[WritableFilePath] = None + interface: Optional[IPAddressPort] = None + + def _validate(self) -> None: + if bool(self.unix_socket) == bool(self.interface): + raise ValueError("One of 'interface' or 'unix-socket' must be configured.") diff --git a/python/knot_resolver/datamodel/monitoring_schema.py b/python/knot_resolver/datamodel/monitoring_schema.py new file mode 100644 index 00000000..3b3ad6d9 --- /dev/null +++ b/python/knot_resolver/datamodel/monitoring_schema.py @@ -0,0 +1,25 @@ +from typing import Union + +from typing_extensions import Literal + +from knot_resolver.datamodel.types import DomainName, EscapedStr, IPAddress, PortNumber, TimeUnit +from knot_resolver.utils.modeling import ConfigSchema + + +class GraphiteSchema(ConfigSchema): + host: Union[IPAddress, DomainName] + port: PortNumber = PortNumber(2003) + prefix: EscapedStr = EscapedStr("") + interval: TimeUnit = TimeUnit("5s") + tcp: bool = False + + +class MonitoringSchema(ConfigSchema): + """ + --- + enabled: configures, whether statistics module will be loaded into resolver + graphite: optionally configures where should graphite metrics be sent to + """ + + enabled: Literal["manager-only", "lazy", "always"] = "lazy" + graphite: Union[Literal[False], GraphiteSchema] = False diff --git a/python/knot_resolver/datamodel/network_schema.py b/python/knot_resolver/datamodel/network_schema.py new file mode 100644 index 00000000..e766d499 --- /dev/null +++ b/python/knot_resolver/datamodel/network_schema.py @@ -0,0 +1,181 @@ +from typing import List, Optional, Union + +from typing_extensions import Literal + +from knot_resolver.datamodel.types import ( + EscapedStr32B, + WritableFilePath, + Int0_512, + Int0_65535, + InterfaceOptionalPort, + IPAddress, + IPAddressEM, + IPNetwork, + IPv4Address, + IPv6Address, + ListOrItem, + PortNumber, + ReadableFile, + SizeUnit, +) +from knot_resolver.utils.modeling import ConfigSchema + +KindEnum = Literal["dns", "xdp", "dot", "doh-legacy", "doh2"] + + +class EdnsBufferSizeSchema(ConfigSchema): + """ + EDNS payload size advertised in DNS packets. + + --- + upstream: Maximum EDNS upstream (towards other DNS servers) payload size. + downstream: Maximum EDNS downstream (towards clients) payload size for communication. + """ + + upstream: SizeUnit = SizeUnit("1232B") + downstream: SizeUnit = SizeUnit("1232B") + + +class AddressRenumberingSchema(ConfigSchema): + """ + Renumbers addresses in answers to different address space. + + --- + source: Source subnet. + destination: Destination address prefix. + """ + + source: IPNetwork + destination: Union[IPAddressEM, IPAddress] + + +class TLSSchema(ConfigSchema): + """ + TLS configuration, also affects DNS over TLS and DNS over HTTPS. + + --- + cert_file: Path to certificate file. + key_file: Path to certificate key file. + sticket_secret: Secret for TLS session resumption via tickets. (RFC 5077). + sticket_secret_file: Path to file with secret for TLS session resumption via tickets. (RFC 5077). + auto_discovery: Experimental automatic discovery of authoritative servers supporting DNS-over-TLS. + padding: EDNS(0) padding of queries and answers sent over an encrypted channel. 
+ """ + + cert_file: Optional[ReadableFile] = None + key_file: Optional[ReadableFile] = None + sticket_secret: Optional[EscapedStr32B] = None + sticket_secret_file: Optional[ReadableFile] = None + auto_discovery: bool = False + padding: Union[bool, Int0_512] = True + + def _validate(self): + if self.sticket_secret and self.sticket_secret_file: + raise ValueError("'sticket_secret' and 'sticket_secret_file' are both defined, only one can be used") + + +class ListenSchema(ConfigSchema): + class Raw(ConfigSchema): + """ + Configuration of listening interface. + + --- + unix_socket: Path to unix domain socket to listen to. + interface: IP address or interface name with optional port number to listen to. + port: Port number to listen to. + kind: Specifies DNS query transport protocol. + freebind: Used for binding to non-local address. + """ + + interface: Optional[ListOrItem[InterfaceOptionalPort]] = None + unix_socket: Optional[ListOrItem[WritableFilePath]] = None + port: Optional[PortNumber] = None + kind: KindEnum = "dns" + freebind: bool = False + + _LAYER = Raw + + interface: Optional[ListOrItem[InterfaceOptionalPort]] + unix_socket: Optional[ListOrItem[WritableFilePath]] + port: Optional[PortNumber] + kind: KindEnum + freebind: bool + + def _interface(self, origin: Raw) -> Optional[ListOrItem[InterfaceOptionalPort]]: + if origin.interface: + port_set: Optional[bool] = None + for intrfc in origin.interface: # type: ignore[attr-defined] + if origin.port and intrfc.port: + raise ValueError("The port number is defined in two places ('port' option and '@<port>' syntax).") + if port_set is not None and (bool(intrfc.port) != port_set): + raise ValueError( + "The '@<port>' syntax must be used either for all or none of the interface in the list." + ) + port_set = bool(intrfc.port) + return origin.interface + + def _port(self, origin: Raw) -> Optional[PortNumber]: + if origin.port: + return origin.port + # default port number based on kind + elif origin.interface: + if origin.kind == "dot": + return PortNumber(853) + elif origin.kind in ["doh-legacy", "doh2"]: + return PortNumber(443) + return PortNumber(53) + return None + + def _validate(self) -> None: + if bool(self.unix_socket) == bool(self.interface): + raise ValueError("One of 'interface' or 'unix-socket' must be configured.") + if self.port and self.unix_socket: + raise ValueError( + "'unix-socket' and 'port' are not compatible options." + " Port configuration can only be used with 'interface' option." + ) + + +class ProxyProtocolSchema(ConfigSchema): + """ + PROXYv2 protocol configuration. + + --- + allow: Allow usage of the PROXYv2 protocol headers by clients on the specified addresses. + """ + + allow: List[Union[IPAddress, IPNetwork]] + + +class NetworkSchema(ConfigSchema): + """ + Network connections and protocols configuration. + + --- + do_ipv4: Enable/disable using IPv4 for contacting upstream nameservers. + do_ipv6: Enable/disable using IPv6 for contacting upstream nameservers. + out_interface_v4: IPv4 address used to perform queries. Not set by default, which lets the OS choose any address. + out_interface_v6: IPv6 address used to perform queries. Not set by default, which lets the OS choose any address. + tcp_pipeline: TCP pipeline limit. The number of outstanding queries that a single client connection can make in parallel. + edns_tcp_keepalive: Allows clients to discover the connection timeout. (RFC 7828) + edns_buffer_size: Maximum EDNS payload size advertised in DNS packets. 
Different values can be configured for communication downstream (towards clients) and upstream (towards other DNS servers). + address_renumbering: Renumbers addresses in answers to different address space. + tls: TLS configuration, also affects DNS over TLS and DNS over HTTPS. + proxy_protocol: PROXYv2 protocol configuration. + listen: List of interfaces to listen to and its configuration. + """ + + do_ipv4: bool = True + do_ipv6: bool = True + out_interface_v4: Optional[IPv4Address] = None + out_interface_v6: Optional[IPv6Address] = None + tcp_pipeline: Int0_65535 = Int0_65535(100) + edns_tcp_keepalive: bool = True + edns_buffer_size: EdnsBufferSizeSchema = EdnsBufferSizeSchema() + address_renumbering: Optional[List[AddressRenumberingSchema]] = None + tls: TLSSchema = TLSSchema() + proxy_protocol: Union[Literal[False], ProxyProtocolSchema] = False + listen: List[ListenSchema] = [ + ListenSchema({"interface": "127.0.0.1"}), + ListenSchema({"interface": "::1", "freebind": True}), + ] diff --git a/python/knot_resolver/datamodel/options_schema.py b/python/knot_resolver/datamodel/options_schema.py new file mode 100644 index 00000000..9230c7f0 --- /dev/null +++ b/python/knot_resolver/datamodel/options_schema.py @@ -0,0 +1,36 @@ +from typing_extensions import Literal + +from knot_resolver.utils.modeling import ConfigSchema + +GlueCheckingEnum = Literal["normal", "strict", "permissive"] + + +class OptionsSchema(ConfigSchema): + """ + Fine-tuning global parameters of DNS resolver operation. + + --- + glue_checking: Glue records scrictness checking level. + minimize: Send minimum amount of information in recursive queries to enhance privacy. + query_loopback: Permits queries to loopback addresses. + reorder_rrset: Controls whether resource records within a RRSet are reordered each time it is served from the cache. + query_case_randomization: Randomize Query Character Case. + priming: Initializing DNS resolver cache with Priming Queries (RFC 8109) + rebinding_protection: Protection against DNS Rebinding attack. + refuse_no_rd: Queries without RD (recursion desired) bit set in query are answered with REFUSED. + time_jump_detection: Detection of difference between local system time and expiration time bounds in DNSSEC signatures for '. NS' records. + violators_workarounds: Workarounds for known DNS protocol violators. + serve_stale: Allows using timed-out records in case DNS resolver is unable to contact upstream servers. + """ + + glue_checking: GlueCheckingEnum = "normal" + minimize: bool = True + query_loopback: bool = False + reorder_rrset: bool = True + query_case_randomization: bool = True + priming: bool = True + rebinding_protection: bool = False + refuse_no_rd: bool = True + time_jump_detection: bool = True + violators_workarounds: bool = False + serve_stale: bool = False diff --git a/python/knot_resolver/datamodel/policy_schema.py b/python/knot_resolver/datamodel/policy_schema.py new file mode 100644 index 00000000..8f9d8b26 --- /dev/null +++ b/python/knot_resolver/datamodel/policy_schema.py @@ -0,0 +1,126 @@ +from typing import List, Optional, Union + +from knot_resolver.datamodel.forward_schema import ForwardServerSchema +from knot_resolver.datamodel.network_schema import AddressRenumberingSchema +from knot_resolver.datamodel.types import ( + DNSRecordTypeEnum, + IPAddressOptionalPort, + PolicyActionEnum, + PolicyFlagEnum, + TimeUnit, +) +from knot_resolver.utils.modeling import ConfigSchema + + +class FilterSchema(ConfigSchema): + """ + Query filtering configuration. 
+ + --- + suffix: Filter based on the suffix of the query name. + pattern: Filter based on the pattern that match query name. + qtype: Filter based on the DNS query type. + """ + + suffix: Optional[str] = None + pattern: Optional[str] = None + qtype: Optional[DNSRecordTypeEnum] = None + + +class AnswerSchema(ConfigSchema): + """ + Configuration of custom resource record for DNS answer. + + --- + rtype: Type of DNS resource record. + rdata: Data of DNS resource record. + ttl: Time-to-live value for defined answer. + nodata: Answer with NODATA If requested type is not configured in the answer. Otherwise policy rule is ignored. + """ + + rtype: DNSRecordTypeEnum + rdata: str + ttl: TimeUnit = TimeUnit("1s") + nodata: bool = False + + +def _validate_policy_action(policy_action: Union["ActionSchema", "PolicySchema"]) -> None: + servers = ["mirror", "forward", "stub"] + + def _field(ac: str) -> str: + if ac in servers: + return "servers" + return "message" if ac == "deny" else ac + + configurable_actions = ["deny", "reroute", "answer"] + servers + + # checking for missing mandatory fields for actions + field = _field(policy_action.action) + if policy_action.action in configurable_actions and not getattr(policy_action, field): + raise ValueError(f"missing mandatory field '{field}' for '{policy_action.action}' action") + + # checking for unnecessary fields + for ac in configurable_actions + ["deny"]: + field = _field(ac) + if getattr(policy_action, field) and _field(policy_action.action) != field: + raise ValueError(f"'{field}' field can only be defined for '{ac}' action") + + # ForwardServerSchema is valid only for 'forward' action + if policy_action.servers: + for server in policy_action.servers: # pylint: disable=not-an-iterable + if policy_action.action != "forward" and isinstance(server, ForwardServerSchema): + raise ValueError( + f"'ForwardServerSchema' in 'servers' is valid only for 'forward' action, got '{policy_action.action}'" + ) + + +class ActionSchema(ConfigSchema): + """ + Configuration of policy action. + + --- + action: Policy action. + message: Deny message for 'deny' action. + reroute: Configuration for 'reroute' action. + answer: Answer definition for 'answer' action. + servers: Servers configuration for 'mirror', 'forward' and 'stub' action. + """ + + action: PolicyActionEnum + message: Optional[str] = None + reroute: Optional[List[AddressRenumberingSchema]] = None + answer: Optional[AnswerSchema] = None + servers: Optional[Union[List[IPAddressOptionalPort], List[ForwardServerSchema]]] = None + + def _validate(self) -> None: + _validate_policy_action(self) + + +class PolicySchema(ConfigSchema): + """ + Configuration of policy rule. + + --- + action: Policy rule action. + priority: Policy rule priority. + filter: Query filtering configuration. + views: Use policy rule only for clients defined by views. + options: Configuration flags for policy rule. + message: Deny message for 'deny' action. + reroute: Configuration for 'reroute' action. + answer: Answer definition for 'answer' action. + servers: Servers configuration for 'mirror', 'forward' and 'stub' action. 
+ """ + + action: PolicyActionEnum + priority: Optional[int] = None + filter: Optional[FilterSchema] = None + views: Optional[List[str]] = None + options: Optional[List[PolicyFlagEnum]] = None + message: Optional[str] = None + reroute: Optional[List[AddressRenumberingSchema]] = None + answer: Optional[AnswerSchema] = None + servers: Optional[Union[List[IPAddressOptionalPort], List[ForwardServerSchema]]] = None + + def _validate(self) -> None: + _validate_policy_action(self) diff --git a/python/knot_resolver/datamodel/rpz_schema.py b/python/knot_resolver/datamodel/rpz_schema.py new file mode 100644 index 00000000..96d79293 --- /dev/null +++ b/python/knot_resolver/datamodel/rpz_schema.py @@ -0,0 +1,29 @@ +from typing import List, Optional + +from knot_resolver.datamodel.types import PolicyActionEnum, PolicyFlagEnum, ReadableFile +from knot_resolver.utils.modeling import ConfigSchema + + +class RPZSchema(ConfigSchema): + """ + Configuration or Response Policy Zone (RPZ). + + --- + action: RPZ rule action, typically 'deny'. + file: Path to the RPZ zone file. + watch: Reload the file when it changes. + views: Use RPZ rule only for clients defined by views. + options: Configuration flags for RPZ rule. + message: Deny message for 'deny' action. + """ + + action: PolicyActionEnum + file: ReadableFile + watch: bool = True + views: Optional[List[str]] = None + options: Optional[List[PolicyFlagEnum]] = None + message: Optional[str] = None + + def _validate(self) -> None: + if self.message and not self.action == "deny": + raise ValueError("'message' field can only be defined for 'deny' action") diff --git a/python/knot_resolver/datamodel/slice_schema.py b/python/knot_resolver/datamodel/slice_schema.py new file mode 100644 index 00000000..1586cab7 --- /dev/null +++ b/python/knot_resolver/datamodel/slice_schema.py @@ -0,0 +1,21 @@ +from typing import List, Optional + +from typing_extensions import Literal + +from knot_resolver.datamodel.policy_schema import ActionSchema +from knot_resolver.utils.modeling import ConfigSchema + + +class SliceSchema(ConfigSchema): + """ + Split the entire DNS namespace into distinct slices. + + --- + function: Slicing function that returns index based on query + views: Use this Slice only for clients defined by views. + actions: Actions for slice. + """ + + function: Literal["randomize-psl"] = "randomize-psl" + views: Optional[List[str]] = None + actions: List[ActionSchema] diff --git a/python/knot_resolver/datamodel/static_hints_schema.py b/python/knot_resolver/datamodel/static_hints_schema.py new file mode 100644 index 00000000..ac64c311 --- /dev/null +++ b/python/knot_resolver/datamodel/static_hints_schema.py @@ -0,0 +1,27 @@ +from typing import Dict, List, Optional + +from knot_resolver.datamodel.types import DomainName, IPAddress, ReadableFile, TimeUnit +from knot_resolver.utils.modeling import ConfigSchema + + +class StaticHintsSchema(ConfigSchema): + """ + Static hints for forward records (A/AAAA) and reverse records (PTR) + + --- + ttl: TTL value used for records added from static hints. + nodata: Use NODATA synthesis. NODATA will be synthesised for matching hint name, but mismatching type. + etc_hosts: Add hints from '/etc/hosts' file. + root_hints: Direct addition of root hints pairs (hostname, list of addresses). + root_hints_file: Path to root hints in zonefile. Replaces all current root hints. + hints: Direct addition of hints pairs (hostname, list of addresses). + hints_files: Path to hints in hosts-like file. 
+ """ + + ttl: Optional[TimeUnit] = None + nodata: bool = True + etc_hosts: bool = False + root_hints: Optional[Dict[DomainName, List[IPAddress]]] = None + root_hints_file: Optional[ReadableFile] = None + hints: Optional[Dict[DomainName, List[IPAddress]]] = None + hints_files: Optional[List[ReadableFile]] = None diff --git a/python/knot_resolver/datamodel/stub_zone_schema.py b/python/knot_resolver/datamodel/stub_zone_schema.py new file mode 100644 index 00000000..afd1cc79 --- /dev/null +++ b/python/knot_resolver/datamodel/stub_zone_schema.py @@ -0,0 +1,32 @@ +from typing import List, Optional, Union + +from knot_resolver.datamodel.types import DomainName, IPAddressOptionalPort, PolicyFlagEnum +from knot_resolver.utils.modeling import ConfigSchema + + +class StubServerSchema(ConfigSchema): + """ + Configuration of Stub server. + + --- + address: IP address of Stub server. + """ + + address: IPAddressOptionalPort + + +class StubZoneSchema(ConfigSchema): + """ + Configuration of Stub Zone. + + --- + subtree: Domain name of the zone. + servers: IP address of Stub server. + views: Use this Stub Zone only for clients defined by views. + options: Configuration flags for Stub Zone. + """ + + subtree: DomainName + servers: Union[List[IPAddressOptionalPort], List[StubServerSchema]] + views: Optional[List[str]] = None + options: Optional[List[PolicyFlagEnum]] = None diff --git a/python/knot_resolver/datamodel/templates/__init__.py b/python/knot_resolver/datamodel/templates/__init__.py new file mode 100644 index 00000000..fdb91dd2 --- /dev/null +++ b/python/knot_resolver/datamodel/templates/__init__.py @@ -0,0 +1,43 @@ +import os +import sys + +from jinja2 import Environment, FileSystemLoader, Template + + +def _get_templates_dir() -> str: + module = sys.modules["knot_resolver.datamodel"].__file__ + if module: + templates_dir = os.path.join(os.path.dirname(module), "templates") + if os.path.isdir(templates_dir): + return templates_dir + raise NotADirectoryError(f"the templates dir '{templates_dir}' is not a directory or does not exist") + raise OSError("package 'knot_resolver.datamodel' cannot be located or loaded") + + +_TEMPLATES_DIR = _get_templates_dir() + + +def _import_kresd_worker_config_template() -> Template: + path = os.path.join(_TEMPLATES_DIR, "worker-config.lua.j2") + with open(path, "r", encoding="UTF-8") as file: + template = file.read() + return template_from_str(template) + + +def _import_kresd_policy_config_template() -> Template: + path = os.path.join(_TEMPLATES_DIR, "policy-config.lua.j2") + with open(path, "r", encoding="UTF-8") as file: + template = file.read() + return template_from_str(template) + + +def template_from_str(template: str) -> Template: + ldr = FileSystemLoader(_TEMPLATES_DIR) + env = Environment(trim_blocks=True, lstrip_blocks=True, loader=ldr) + return env.from_string(template) + + +WORKER_CONFIG_TEMPLATE = _import_kresd_worker_config_template() + + +POLICY_CONFIG_TEMPLATE = _import_kresd_policy_config_template() diff --git a/python/knot_resolver/datamodel/templates/cache.lua.j2 b/python/knot_resolver/datamodel/templates/cache.lua.j2 new file mode 100644 index 00000000..f0176a59 --- /dev/null +++ b/python/knot_resolver/datamodel/templates/cache.lua.j2 @@ -0,0 +1,32 @@ +cache.open({{ cfg.cache.size_max.bytes() }}, 'lmdb://{{ cfg.cache.storage }}') +cache.min_ttl({{ cfg.cache.ttl_min.seconds() }}) +cache.max_ttl({{ cfg.cache.ttl_max.seconds() }}) +cache.ns_tout({{ cfg.cache.ns_timeout.millis() }}) + +{% if cfg.cache.prefill %} +-- cache.prefill 
+modules.load('prefill') +prefill.config({ +{% for item in cfg.cache.prefill %} + ['{{ item.origin.punycode() }}'] = { + url = '{{ item.url }}', + interval = {{ item.refresh_interval.seconds() }} + {{ "ca_file = '"+item.ca_file+"'," if item.ca_file }} + } +{% endfor %} +}) +{% endif %} + +{% if cfg.cache.prefetch.expiring %} +-- cache.prefetch.expiring +modules.load('prefetch') +{% endif %} + +{% if cfg.cache.prefetch.prediction %} +-- cache.prefetch.prediction +modules.load('predict') +predict.config({ + window = {{ cfg.cache.prefetch.prediction.window.minutes() }}, + period = {{ cfg.cache.prefetch.prediction.period }}, +}) +{% endif %} diff --git a/python/knot_resolver/datamodel/templates/dns64.lua.j2 b/python/knot_resolver/datamodel/templates/dns64.lua.j2 new file mode 100644 index 00000000..c5239f00 --- /dev/null +++ b/python/knot_resolver/datamodel/templates/dns64.lua.j2 @@ -0,0 +1,17 @@ +{% from 'macros/common_macros.lua.j2' import string_table %} + +{% if cfg.dns64 %} +-- load dns64 module +modules.load('dns64') + +-- dns64.prefix +dns64.config({ + prefix = '{{ cfg.dns64.prefix.to_std().network_address|string }}', +{% if cfg.dns64.rev_ttl %} + rev_ttl = {{ cfg.dns64.rev_ttl.seconds() }}, +{% endif %} +{% if cfg.dns64.exclude_subnets %} + exclude_subnets = {{ string_table(cfg.dns64.exclude_subnets) }}, +{% endif %} +}) +{% endif %}
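The dns64 template above only interpolates values that Dns64Schema has already validated and defaulted. A minimal sketch of that hand-off, assuming the ConfigSchema subclasses in this diff accept a plain dict (as NetworkSchema's default listen entries do) and that rendering a custom template string through template_from_str works outside the full config pipeline:

from knot_resolver.datamodel.dns64_schema import Dns64Schema
from knot_resolver.datamodel.templates import template_from_str

# Build the schema the same way the top-level _dns64() transform would.
dns64 = Dns64Schema({"prefix": "64:ff9b::/96"})

# Stand-in snippet; the real dns64.lua.j2 additionally normalizes the prefix
# via .to_std().network_address before printing it.
snippet = template_from_str("dns64.config({ prefix = '{{ cfg.prefix }}' })")
print(snippet.render(cfg=dns64))  # expected: dns64.config({ prefix = '64:ff9b::/96' })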
\ No newline at end of file diff --git a/python/knot_resolver/datamodel/templates/dnssec.lua.j2 b/python/knot_resolver/datamodel/templates/dnssec.lua.j2 new file mode 100644 index 00000000..05d1fa68 --- /dev/null +++ b/python/knot_resolver/datamodel/templates/dnssec.lua.j2 @@ -0,0 +1,60 @@ +{% from 'macros/common_macros.lua.j2' import boolean %} + +{% if not cfg.dnssec %} +-- disable dnssec +trust_anchors.remove('.') +{% endif %} + +-- options.trust-anchor-sentinel +{% if cfg.dnssec.trust_anchor_sentinel %} +modules.load('ta_sentinel') +{% else %} +modules.unload('ta_sentinel') +{% endif %} + +-- options.trust-anchor-signal-query +{% if cfg.dnssec.trust_anchor_signal_query %} +modules.load('ta_signal_query') +{% else %} +modules.unload('ta_signal_query') +{% endif %} + +-- options.time-skew-detection +{% if cfg.dnssec.time_skew_detection %} +modules.load('detect_time_skew') +{% else %} +modules.unload('detect_time_skew') +{% endif %} + +{% if cfg.dnssec.keep_removed %} +-- dnssec.keep-removed +trust_anchors.keep_removed = {{ cfg.dnssec.keep_removed }} +{% endif %} + +{% if cfg.dnssec.refresh_time %} +-- dnssec.refresh-time +trust_anchors.refresh_time = {{ cfg.dnssec.refresh_time.seconds()|string }} +{% endif %} + +{% if cfg.dnssec.trust_anchors %} +-- dnssec.trust-anchors +{% for ta in cfg.dnssec.trust_anchors %} +trust_anchors.add('{{ ta }}') +{% endfor %} +{% endif %} + +{% if cfg.dnssec.negative_trust_anchors %} +-- dnssec.negative-trust-anchors +trust_anchors.set_insecure({ +{% for nta in cfg.dnssec.negative_trust_anchors %} + '{{ nta }}', +{% endfor %} +}) +{% endif %} + +{% if cfg.dnssec.trust_anchors_files %} +-- dnssec.trust-anchors-files +{% for taf in cfg.dnssec.trust_anchors_files %} +trust_anchors.add_file('{{ taf.file }}', readonly = {{ boolean(taf.read_only) }}) +{% endfor %} +{% endif %}
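For context on what feeds the dnssec template: the _dnssec() transform shown earlier in config_schema.py expands a bare `dnssec: true` into a fully defaulted DnssecSchema, while `false` makes the template emit trust_anchors.remove('.') instead. A small sketch of that expansion, directly mirroring the `DnssecSchema()` call in _dnssec():

from knot_resolver.datamodel.dnssec_schema import DnssecSchema

# What `dnssec: true` in the YAML config turns into.
dnssec = DnssecSchema()
print(dnssec.trust_anchor_sentinel)     # True -> template loads 'ta_sentinel'
print(dnssec.hold_down_time.seconds())  # 30 days expressed in seconds
print(dnssec.keep_removed)              # 0 -> the keep-removed block is skipped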
\ No newline at end of file diff --git a/python/knot_resolver/datamodel/templates/forward.lua.j2 b/python/knot_resolver/datamodel/templates/forward.lua.j2 new file mode 100644 index 00000000..24311da1 --- /dev/null +++ b/python/knot_resolver/datamodel/templates/forward.lua.j2 @@ -0,0 +1,9 @@ +{% from 'macros/forward_macros.lua.j2' import policy_rule_forward_add %} + +{% if cfg.forward %} +{% for fwd in cfg.forward %} +{% for subtree in fwd.subtree %} +{{ policy_rule_forward_add(subtree,fwd.options,fwd.servers) }} +{% endfor %} +{% endfor %} +{% endif %} diff --git a/python/knot_resolver/datamodel/templates/local_data.lua.j2 b/python/knot_resolver/datamodel/templates/local_data.lua.j2 new file mode 100644 index 00000000..8882471f --- /dev/null +++ b/python/knot_resolver/datamodel/templates/local_data.lua.j2 @@ -0,0 +1,41 @@ +{% from 'macros/local_data_macros.lua.j2' import local_data_rules, local_data_records, local_data_root_fallback_addresses, local_data_root_fallback_addresses_files, local_data_addresses, local_data_addresses_files %} +{% from 'macros/common_macros.lua.j2' import boolean %} + +modules = { 'hints > iterate' } + +{# root-fallback-addresses #} +{% if cfg.local_data.root_fallback_addresses -%} +{{ local_data_root_fallback_addresses(cfg.local_data.root_fallback_addresses) }} +{%- endif %} + +{# root-fallback-addresses-files #} +{% if cfg.local_data.root_fallback_addresses_files -%} +{{ local_data_root_fallback_addresses_files(cfg.local_data.root_fallback_addresses_files) }} +{%- endif %} + +{# addresses #} +{% if cfg.local_data.addresses -%} +{{ local_data_addresses(cfg.local_data.addresses, cfg.local_data.nodata, cfg.local_data.ttl) }} +{%- endif %} + +{# addresses-files #} +{% if cfg.local_data.addresses_files -%} +{{ local_data_addresses_files(cfg.local_data.addresses_files, cfg.local_data.nodata, cfg.local_data.ttl) }} +{%- endif %} + +{# records #} +{% if cfg.local_data.records -%} +{{ local_data_records(cfg.local_data.records, false, cfg.local_data.nodata, cfg.local_data.ttl) }} +{%- endif %} + +{# rules #} +{% if cfg.local_data.rules -%} +{{ local_data_rules(cfg.local_data.rules, cfg.local_data.nodata, cfg.local_data.ttl) }} +{%- endif %} + +{# rpz #} +{% if cfg.local_data.rpz -%} +{% for rpz in cfg.local_data.rpz %} +{{ local_data_records(rpz.file, true, cfg.local_data.nodata, cfg.local_data.ttl, rpz.tags) }} +{% endfor %} +{%- endif %} diff --git a/python/knot_resolver/datamodel/templates/logging.lua.j2 b/python/knot_resolver/datamodel/templates/logging.lua.j2 new file mode 100644 index 00000000..2d5937a8 --- /dev/null +++ b/python/knot_resolver/datamodel/templates/logging.lua.j2 @@ -0,0 +1,43 @@ +{% from 'macros/common_macros.lua.j2' import boolean %} + +-- logging.level +log_level('{{ cfg.logging.level }}') + +{% if cfg.logging.target -%} +-- logging.target +log_target('{{ cfg.logging.target }}') +{%- endif %} + +{% if cfg.logging.groups %} +-- logging.groups +log_groups({ +{% for g in cfg.logging.groups %} +{% if g != "manager" and g != "supervisord" and g != "cache-gc" %} + '{{ g }}', +{% endif %} +{% endfor %} +}) +{% endif %} + +{% if cfg.logging.dnssec_bogus %} +modules.load('bogus_log') +{% endif %} + +{% if cfg.logging.dnstap -%} +-- logging.dnstap +modules.load('dnstap') +dnstap.config({ + socket_path = '{{ cfg.logging.dnstap.unix_socket }}', + client = { + log_queries = {{ boolean(cfg.logging.dnstap.log_queries) }}, + log_responses = {{ boolean(cfg.logging.dnstap.log_responses) }}, + log_tcp_rtt = {{ boolean(cfg.logging.dnstap.log_tcp_rtt) }} + } +}) +{%- 
endif %} + +-- logging.debugging.assertion-abort +debugging.assertion_abort = {{ boolean(cfg.logging.debugging.assertion_abort) }} + +-- logging.debugging.assertion-fork +debugging.assertion_fork = {{ cfg.logging.debugging.assertion_fork.millis() }} diff --git a/python/knot_resolver/datamodel/templates/macros/cache_macros.lua.j2 b/python/knot_resolver/datamodel/templates/macros/cache_macros.lua.j2 new file mode 100644 index 00000000..51df48da --- /dev/null +++ b/python/knot_resolver/datamodel/templates/macros/cache_macros.lua.j2 @@ -0,0 +1,11 @@ +{% from 'macros/common_macros.lua.j2' import boolean, quotes, qtype_table %} + + +{% macro cache_clear(params) -%} +cache.clear( +{{- quotes(params.name) if params.name else 'nil' -}}, +{{- boolean(params.exact_name) -}}, +{{- qtype_table(params.rr_type) if params.rr_type else 'nil' -}}, +{{- params.chunk_size if not params.exact_name else 'nil' -}} +) +{%- endmacro %} diff --git a/python/knot_resolver/datamodel/templates/macros/common_macros.lua.j2 b/python/knot_resolver/datamodel/templates/macros/common_macros.lua.j2 new file mode 100644 index 00000000..4c2ba11a --- /dev/null +++ b/python/knot_resolver/datamodel/templates/macros/common_macros.lua.j2 @@ -0,0 +1,101 @@ +{% macro quotes(string) -%} +'{{ string }}' +{%- endmacro %} + +{% macro boolean(val, negation=false) -%} +{%- if negation -%} +{{ 'false' if val else 'true' }} +{%- else-%} +{{ 'true' if val else 'false' }} +{%- endif -%} +{%- endmacro %} + +{# Return string or table of strings #} +{% macro string_table(table) -%} +{%- if table is string -%} +'{{ table|string }}' +{%- else-%} +{ +{%- for item in table -%} +'{{ item|string }}', +{%- endfor -%} +} +{%- endif -%} +{%- endmacro %} + +{# Return str2ip or table of str2ip #} +{% macro str2ip_table(table) -%} +{%- if table is string -%} +kres.str2ip('{{ table|string }}') +{%- else-%} +{ +{%- for item in table -%} +kres.str2ip('{{ item|string }}'), +{%- endfor -%} +} +{%- endif -%} +{%- endmacro %} + +{# Return qtype or table of qtype #} +{% macro qtype_table(table) -%} +{%- if table is string -%} +kres.type.{{ table|string }} +{%- else-%} +{ +{%- for item in table -%} +kres.type.{{ item|string }}, +{%- endfor -%} +} +{%- endif -%} +{%- endmacro %} + +{# Return server address or table of server addresses #} +{% macro servers_table(servers) -%} +{%- if servers is string -%} +'{{ servers|string }}' +{%- else-%} +{ +{%- for item in servers -%} +{%- if item.address -%} +'{{ item.address|string }}', +{%- else -%} +'{{ item|string }}', +{%- endif -%} +{%- endfor -%} +} +{%- endif -%} +{%- endmacro %} + +{# Return server address or table of server addresses #} +{% macro tls_servers_table(servers) -%} +{ +{%- for item in servers -%} +{%- if item.address -%} +{'{{ item.address|string }}',{{ tls_server_auth(item) }}}, +{%- else -%} +'{{ item|string }}', +{%- endif -%} +{%- endfor -%} +} +{%- endmacro %} + +{% macro tls_server_auth(server) -%} +{%- if server.hostname -%} +hostname='{{ server.hostname|string }}', +{%- endif -%} +{%- if server.ca_file -%} +ca_file='{{ server.ca_file|string }}', +{%- endif -%} +{%- if server.pin_sha256 -%} +pin_sha256= +{%- if server.pin_sha256 is string -%} +'{{ server.pin_sha256|string }}', +{%- else -%} +{ +{%- for pin in server.pin_sha256 -%} +'{{ pin|string }}', +{%- endfor -%} +} +{%- endif -%} +{%- endif -%} +{%- endmacro %} diff --git a/python/knot_resolver/datamodel/templates/macros/forward_macros.lua.j2 b/python/knot_resolver/datamodel/templates/macros/forward_macros.lua.j2 new file mode 100644 index 
00000000..b7723fb0 --- /dev/null +++ b/python/knot_resolver/datamodel/templates/macros/forward_macros.lua.j2 @@ -0,0 +1,42 @@ +{% from 'macros/common_macros.lua.j2' import boolean, string_table %} + +{% macro forward_options(options) -%} +{dnssec={{ boolean(options.dnssec) }},auth={{ boolean(options.authoritative) }}} +{%- endmacro %} + +{% macro forward_server(server) -%} +{%- if server.address -%} +{%- for addr in server.address -%} +{'{{ addr }}', +{%- if server.transport == 'tls' -%} +tls=true, +{%- else -%} +tls=false, +{%- endif -%} +{%- if server.hostname -%} +hostname='{{ server.hostname }}', +{%- endif -%} +{%- if server.pin_sha256 -%} +pin_sha256={{ string_table(server.pin_sha256) }}, +{%- endif -%} +{%- if server.ca_file -%} +ca_file='{{ server.ca_file }}', +{%- endif -%} +}, +{%- endfor -%} +{% else %} +{'{{ server }}'}, +{%- endif -%} +{%- endmacro %} + +{% macro forward_servers(servers) -%} +{ +{%- for server in servers -%} +{{ forward_server(server) }} +{%- endfor -%} +} +{%- endmacro %} + +{% macro policy_rule_forward_add(subtree,options,servers) -%} +policy.rule_forward_add('{{ subtree }}',{{ forward_options(options) }},{{ forward_servers(servers) }}) +{%- endmacro %} diff --git a/python/knot_resolver/datamodel/templates/macros/local_data_macros.lua.j2 b/python/knot_resolver/datamodel/templates/macros/local_data_macros.lua.j2 new file mode 100644 index 00000000..0898571c --- /dev/null +++ b/python/knot_resolver/datamodel/templates/macros/local_data_macros.lua.j2 @@ -0,0 +1,101 @@ +{% from 'macros/common_macros.lua.j2' import string_table, boolean %} +{% from 'macros/policy_macros.lua.j2' import policy_get_tagset, policy_todname %} + + +{% macro local_data_root_fallback_addresses(pairs) -%} +hints.root({ +{% for name, addresses in pairs.items() %} + ['{{ name }}']={{ string_table(addresses) }}, +{% endfor %} +}) +{%- endmacro %} + + +{% macro local_data_root_fallback_addresses_files(files) -%} +{% for file in files %} +hints.root_file('{{ file }}') +{% endfor %} +{%- endmacro %} + +{%- macro local_data_ttl(ttl) -%} +{%- if ttl -%} +{{ ttl.seconds() }} +{%- else -%} +{{ 'C.KR_RULE_TTL_DEFAULT' }} +{%- endif -%} +{%- endmacro -%} + + +{% macro kr_rule_local_address(name, address, nodata, ttl, tags=none) -%} +assert(C.kr_rule_local_address('{{ name }}', '{{ address }}', + {{ boolean(nodata) }}, {{ local_data_ttl(ttl)}}, {{ policy_get_tagset(tags) }}) == 0) +{%- endmacro -%} + + +{% macro local_data_addresses(pairs, nodata, ttl) -%} +{% for name, addresses in pairs.items() %} +{% for address in addresses %} +{{ kr_rule_local_address(name, address, nodata, ttl) }} +{% endfor %} +{% endfor%} +{%- endmacro %} + + +{% macro kr_rule_local_hosts(file, nodata, ttl, tags=none) -%} +assert(C.kr_rule_local_hosts('{{ file }}', {{ boolean(nodata) }}, + {{ local_data_ttl(ttl)}}, {{ policy_get_tagset(tags) }}) == 0) +{%- endmacro %} + + +{% macro local_data_addresses_files(files, nodata, ttl, tags) -%} +{% for file in files %} +{{ kr_rule_local_hosts(file, nodata, ttl, tags) }} +{% endfor %} +{%- endmacro %} + + +{% macro local_data_records(input_str, is_rpz, nodata, ttl, tags=none, id='rrs') -%} +{{ id }} = ffi.new('struct kr_rule_zonefile_config') +{{ id }}.ttl = {{ local_data_ttl(ttl) }} +{{ id }}.tags = {{ policy_get_tagset(tags) }} +{{ id }}.nodata = {{ boolean(nodata) }} +{{ id }}.is_rpz = {{ boolean(is_rpz) }} +{% if is_rpz -%} +{{ id }}.filename = '{{ input_str }}' +{% else %} +{{ id }}.input_str = [[ +{{ input_str.multiline() }} +]] +{% endif %} +assert(C.kr_rule_zonefile({{ id 
}})==0) +{%- endmacro %} + + +{% macro kr_rule_local_subtree(name, type, ttl, tags=none) -%} +assert(C.kr_rule_local_subtree(todname('{{ name }}'), + C.KR_RULE_SUB_{{ type.upper() }}, {{ local_data_ttl(ttl) }}, {{ policy_get_tagset(tags) }}) == 0) +{%- endmacro %} + + +{% macro local_data_rules(items, nodata, ttl) -%} +{% for item in items %} +{% if item.name %} +{% for name in item.name %} +{% if item.address %} +{% for address in item.address %} +{{ kr_rule_local_address(name, address, nodata if item.nodata is none else item.nodata, item.ttl or ttl, item.tags) }} +{% endfor %} +{% endif %} +{% if item.subtree %} +{{ kr_rule_local_subtree(name, item.subtree, item.ttl or ttl, item.tags) }} +{% endif %} +{% endfor %} +{% elif item.file %} +{% for file in item.file %} +{{ kr_rule_local_hosts(file, nodata if item.nodata is none else item.nodata, item.ttl or ttl, item.tags) }} +{% endfor %} +{% elif item.records %} +{{ local_data_records(item.records, false, nodata if item.nodata is none else item.nodata, item.ttl or ttl, tags) }} +{% endif %} +{% endfor %} +{%- endmacro %} diff --git a/python/knot_resolver/datamodel/templates/macros/network_macros.lua.j2 b/python/knot_resolver/datamodel/templates/macros/network_macros.lua.j2 new file mode 100644 index 00000000..79800f7d --- /dev/null +++ b/python/knot_resolver/datamodel/templates/macros/network_macros.lua.j2 @@ -0,0 +1,55 @@ +{% macro http_config(http_cfg, kind, tls=true) -%} +http.config({tls={{ 'true' if tls else 'false'}}, +{%- if http_cfg.cert_file -%} + cert='{{ http_cfg.cert_file }}', +{%- endif -%} +{%- if http_cfg.key_file -%} + key='{{ http_cfg.key_file }}', +{%- endif -%} +},'{{ kind }}') +{%- endmacro %} + + +{% macro listen_kind(kind) -%} +{%- if kind == "dot" -%} +'tls' +{%- elif kind == "doh-legacy" -%} +'doh_legacy' +{%- else -%} +'{{ kind }}' +{%- endif -%} +{%- endmacro %} + + +{% macro net_listen_unix_socket(path, kind, freebind) -%} +net.listen('{{ path }}',nil,{kind={{ listen_kind(kind) }},freebind={{ 'true' if freebind else 'false'}}}) +{%- endmacro %} + + +{% macro net_listen_interface(interface, kind, freebind, port) -%} +net.listen( +{%- if interface.addr -%} +'{{ interface.addr }}', +{%- elif interface.if_name -%} +net['{{ interface.if_name }}'], +{%- endif -%} +{%- if interface.port -%} +{{ interface.port }}, +{%- else -%} +{{ port }}, +{%- endif -%} +{kind={{ listen_kind(kind) }},freebind={{ 'true' if freebind else 'false'}}}) +{%- endmacro %} + + +{% macro network_listen(listen) -%} +{%- if listen.unix_socket -%} +{% for path in listen.unix_socket %} +{{ net_listen_unix_socket(path, listen.kind, listen.freebind) }} +{% endfor %} +{%- elif listen.interface -%} +{% for interface in listen.interface %} +{{ net_listen_interface(interface, listen.kind, listen.freebind, listen.port) }} +{% endfor %} +{%- endif -%} +{%- endmacro %}
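The net_listen_* macros above assume ListenSchema has already resolved which port to use; _port() in network_schema.py picks 853/443/53 by kind when no explicit port is given. A short sketch of that defaulting, assuming dict construction as used for NetworkSchema's default listen entries:

from knot_resolver.datamodel.network_schema import ListenSchema

dot = ListenSchema({"interface": "127.0.0.1", "kind": "dot"})
print(dot.port)  # expected: 853, filled in by _port() because kind == "dot"

doh = ListenSchema({"interface": "127.0.0.1", "kind": "doh2"})
print(doh.port)  # expected: 443; plain "dns" would fall back to 53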
\ No newline at end of file diff --git a/python/knot_resolver/datamodel/templates/macros/policy_macros.lua.j2 b/python/knot_resolver/datamodel/templates/macros/policy_macros.lua.j2 new file mode 100644 index 00000000..347532e6 --- /dev/null +++ b/python/knot_resolver/datamodel/templates/macros/policy_macros.lua.j2 @@ -0,0 +1,279 @@ +{% from 'macros/common_macros.lua.j2' import string_table, str2ip_table, qtype_table, servers_table, tls_servers_table %} + + +{# Add policy #} + +{% macro policy_add(rule, postrule=false) -%} +{%- if postrule -%} +policy.add({{ rule }},true) +{%- else -%} +policy.add({{ rule }}) +{%- endif -%} +{%- endmacro %} + + +{# Slice #} + +{% macro policy_slice_randomize_psl(seed='') -%} +{%- if seed == '' -%} +policy.slice_randomize_psl() +{%- else -%} +policy.slice_randomize_psl(seed={{ seed }}) +{%- endif -%} +{%- endmacro %} + +{% macro policy_slice(func, actions) -%} +policy.slice( +{%- if func == 'randomize-psl' -%} +policy.slice_randomize_psl() +{%- else -%} +policy.slice_randomize_psl() +{%- endif -%} +,{{ actions }}) +{%- endmacro %} + + +{# Flags #} + +{% macro policy_flags(flags) -%} +policy.FLAGS({ +{{- flags -}} +}) +{%- endmacro %} + + +{# Tags assign #} + +{% macro policy_tags_assign(tags) -%} +policy.TAGS_ASSIGN({{ string_table(tags) }}) +{%- endmacro %} + +{% macro policy_get_tagset(tags) -%} +{%- if tags -%} +policy.get_tagset({{ string_table(tags) }}) +{%- else -%} +0 +{%- endif -%} +{%- endmacro %} + + +{# Filters #} + +{% macro policy_all(action) -%} +policy.all({{ action }}) +{%- endmacro %} + +{% macro policy_suffix(action, suffix_table) -%} +policy.suffix({{ action }},{{ suffix_table }}) +{%- endmacro %} + +{% macro policy_suffix_common(action, suffix_table, common_suffix=none) -%} +policy.suffix_common({{ action }},{{ suffix_table }} +{%- if common_suffix -%} +,{{ common_suffix }} +{%- endif -%} +) +{%- endmacro %} + +{% macro policy_pattern(action, pattern) -%} +policy.pattern({{ action }},'{{ pattern }}') +{%- endmacro %} + +{% macro policy_rpz(action, path, watch=true) -%} +policy.rpz({{ action|string }},'{{ path|string }}',{{ 'true' if watch else 'false' }}) +{%- endmacro %} + + +{# Custom filters #} + +{% macro declare_policy_qtype_custom_filter() -%} +function policy_qtype(action, qtype) + + local function has_value (tab, val) + for index, value in ipairs(tab) do + if value == val then + return true + end + end + + return false + end + + return function (state, query) + if query.stype == qtype then + return action + elseif has_value(qtype, query.stype) then + return action + else + return nil + end + end +end +{%- endmacro %} + +{% macro policy_qtype_custom_filter(action, qtype) -%} +policy_qtype({{ action }}, {{ qtype }}) +{%- endmacro %} + + +{# Auto Filter #} + +{% macro policy_auto_filter(action, filter=none) -%} +{%- if filter.suffix -%} +{{ policy_suffix(action, policy_todname(filter.suffix)) }} +{%- elif filter.pattern -%} +{{ policy_pattern(action, filter.pattern) }} +{%- elif filter.qtype -%} +{{ policy_qtype_custom_filter(action, qtype_table(filter.qtype)) }} +{%- else -%} +{{ policy_all(action) }} +{%- endif %} +{%- endmacro %} + + +{# Non-chain actions #} + +{% macro policy_pass() -%} +policy.PASS +{%- endmacro %} + +{% macro policy_deny() -%} +policy.DENY +{%- endmacro %} + +{% macro policy_deny_msg(message) -%} +policy.DENY_MSG('{{ message|string }}') +{%- endmacro %} + +{% macro policy_drop() -%} +policy.DROP +{%- endmacro %} + +{% macro policy_refuse() -%} +policy.REFUSE +{%- endmacro %} + +{% macro policy_tc() -%} 
+policy.TC +{%- endmacro %} + +{% macro policy_reroute(reroute) -%} +policy.REROUTE( +{%- for item in reroute -%} +{['{{ item.source }}']='{{ item.destination }}'}, +{%- endfor -%} +) +{%- endmacro %} + +{% macro policy_answer(answer) -%} +policy.ANSWER({[kres.type.{{ answer.rtype }}]={rdata= +{%- if answer.rtype in ['A','AAAA'] -%} +{{ str2ip_table(answer.rdata) }}, +{%- elif answer.rtype == '' -%} +{# TODO: Do the same for other record types that require a special rdata type in Lua. +By default, the raw string from config is used. #} +{%- else -%} +{{ string_table(answer.rdata) }}, +{%- endif -%} +ttl={{ answer.ttl.seconds()|int }}}},{{ 'true' if answer.nodata else 'false' }}) +{%- endmacro %} + +{# policy.ANSWER( { [kres.type.A] = { rdata=kres.str2ip('192.0.2.7'), ttl=300 }}) #} + +{# Chain actions #} + +{% macro policy_mirror(mirror) -%} +policy.MIRROR( +{% if mirror is string %} +'{{ mirror }}' +{% else %} +{ +{%- for addr in mirror -%} +'{{ addr }}', +{%- endfor -%} +} +{%- endif -%} +) +{%- endmacro %} + +{% macro policy_debug_always() -%} +policy.DEBUG_ALWAYS +{%- endmacro %} + +{% macro policy_debug_cache_miss() -%} +policy.DEBUG_CACHE_MISS +{%- endmacro %} + +{% macro policy_qtrace() -%} +policy.QTRACE +{%- endmacro %} + +{% macro policy_reqtrace() -%} +policy.REQTRACE +{%- endmacro %} + +{% macro policy_stub(servers) -%} +policy.STUB({{ servers_table(servers) }}) +{%- endmacro %} + +{% macro policy_forward(servers) -%} +policy.FORWARD({{ servers_table(servers) }}) +{%- endmacro %} + +{% macro policy_tls_forward(servers) -%} +policy.TLS_FORWARD({{ tls_servers_table(servers) }}) +{%- endmacro %} + + +{# Auto action #} + +{% macro policy_auto_action(rule) -%} +{%- if rule.action == 'pass' -%} +{{ policy_pass() }} +{%- elif rule.action == 'deny' -%} +{%- if rule.message -%} +{{ policy_deny_msg(rule.message) }} +{%- else -%} +{{ policy_deny() }} +{%- endif -%} +{%- elif rule.action == 'drop' -%} +{{ policy_drop() }} +{%- elif rule.action == 'refuse' -%} +{{ policy_refuse() }} +{%- elif rule.action == 'tc' -%} +{{ policy_tc() }} +{%- elif rule.action == 'reroute' -%} +{{ policy_reroute(rule.reroute) }} +{%- elif rule.action == 'answer' -%} +{{ policy_answer(rule.answer) }} +{%- elif rule.action == 'mirror' -%} +{{ policy_mirror(rule.mirror) }} +{%- elif rule.action == 'debug-always' -%} +{{ policy_debug_always() }} +{%- elif rule.action == 'debug-cache-miss' -%} +{{ policy_sebug_cache_miss() }} +{%- elif rule.action == 'qtrace' -%} +{{ policy_qtrace() }} +{%- elif rule.action == 'reqtrace' -%} +{{ policy_reqtrace() }} +{%- endif -%} +{%- endmacro %} + + +{# Other #} + +{% macro policy_todname(name) -%} +todname('{{ name.punycode()|string }}') +{%- endmacro %} + +{% macro policy_todnames(names) -%} +policy.todnames({ +{%- if names is string -%} +'{{ names.punycode()|string }}' +{%- else -%} +{%- for name in names -%} +'{{ name.punycode()|string }}', +{%- endfor -%} +{%- endif -%} +}) +{%- endmacro %}
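To show how these macros compose, here is one plausible rendering of a deny rule limited to a domain suffix, and of the same action applied to all queries. The domain and message are made-up examples, and the nesting of filter and action macros is assumed from their signatures rather than taken from a rendered config.

-- policy_add(policy_auto_filter(action, filter)) with action 'deny' plus message and a suffix filter
policy.add(policy.suffix(policy.DENY_MSG('blocked by policy'),todname('example.com')))
-- the same action with no filter falls through to policy_all()
policy.add(policy.all(policy.DENY))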
\ No newline at end of file diff --git a/python/knot_resolver/datamodel/templates/macros/view_macros.lua.j2 b/python/knot_resolver/datamodel/templates/macros/view_macros.lua.j2 new file mode 100644 index 00000000..2f1a7964 --- /dev/null +++ b/python/knot_resolver/datamodel/templates/macros/view_macros.lua.j2 @@ -0,0 +1,25 @@ +{%- macro get_proto_set(protocols) -%} +0 +{%- for p in protocols or [] -%} + + 2^C.KR_PROTO_{{ p.upper() }} +{%- endfor -%} +{%- endmacro -%} + +{% macro view_flags(options) -%} +{% if not options.minimize -%} +"NO_MINIMIZE", +{%- endif %} +{% if not options.dns64 -%} +"DNS64_DISABLE", +{%- endif %} +{%- endmacro %} + +{% macro view_answer(answer) -%} +{%- if answer == 'allow' -%} +policy.TAGS_ASSIGN({}) +{%- elif answer == 'refused' -%} +'policy.REFUSE' +{%- elif answer == 'noanswer' -%} +'policy.NO_ANSWER' +{%- endif -%} +{%- endmacro %} diff --git a/python/knot_resolver/datamodel/templates/monitoring.lua.j2 b/python/knot_resolver/datamodel/templates/monitoring.lua.j2 new file mode 100644 index 00000000..624b59ab --- /dev/null +++ b/python/knot_resolver/datamodel/templates/monitoring.lua.j2 @@ -0,0 +1,33 @@ +--- control socket location +local ffi = require('ffi') +local id = os.getenv('SYSTEMD_INSTANCE') +if not id then + log_error(ffi.C.LOG_GRP_SYSTEM, 'environment variable $SYSTEMD_INSTANCE not set, which should not have been possible due to running under manager') +else + -- Bind to control socket in CWD (= rundir in config) + -- FIXME replace with relative path after fixing https://gitlab.nic.cz/knot/knot-resolver/-/issues/720 + local path = '{{ cwd }}/control/'..id + log_warn(ffi.C.LOG_GRP_SYSTEM, 'path = ' .. path) + local ok, err = pcall(net.listen, path, nil, { kind = 'control' }) + if not ok then + log_warn(ffi.C.LOG_GRP_NETWORK, 'bind to '..path..' 
failed '..err) + end +end + +{% if cfg.monitoring.enabled == "always" %} +modules.load('stats') +{% endif %} + +--- function used for statistics collection +function collect_lazy_statistics() + if stats == nil then + modules.load('stats') + end + + return stats.list() +end + +--- function used for statistics collection +function collect_statistics() + return stats.list() +end diff --git a/python/knot_resolver/datamodel/templates/network.lua.j2 b/python/knot_resolver/datamodel/templates/network.lua.j2 new file mode 100644 index 00000000..665ee454 --- /dev/null +++ b/python/knot_resolver/datamodel/templates/network.lua.j2 @@ -0,0 +1,102 @@ +{% from 'macros/common_macros.lua.j2' import boolean %} +{% from 'macros/network_macros.lua.j2' import network_listen, http_config %} + +-- network.do-ipv4/6 +net.ipv4 = {{ boolean(cfg.network.do_ipv4) }} +net.ipv6 = {{ boolean(cfg.network.do_ipv6) }} + +{% if cfg.network.out_interface_v4 %} +-- network.out-interface-v4 +net.outgoing_v4('{{ cfg.network.out_interface_v4 }}') +{% endif %} + +{% if cfg.network.out_interface_v6 %} +-- network.out-interface-v6 +net.outgoing_v6('{{ cfg.network.out_interface_v6 }}') +{% endif %} + +-- network.tcp-pipeline +net.tcp_pipeline({{ cfg.network.tcp_pipeline }}) + +-- network.edns-keep-alive +{% if cfg.network.edns_tcp_keepalive %} +modules.load('edns_keepalive') +{% else %} +modules.unload('edns_keepalive') +{% endif %} + +-- network.edns-buffer-size +net.bufsize( + {{ cfg.network.edns_buffer_size.upstream.bytes() }}, + {{ cfg.network.edns_buffer_size.downstream.bytes() }} +) + +{% if cfg.network.tls.cert_file and cfg.network.tls.key_file %} +-- network.tls +net.tls('{{ cfg.network.tls.cert_file }}', '{{ cfg.network.tls.key_file }}') +{% endif %} + +{% if cfg.network.tls.sticket_secret %} +-- network.tls.sticket-secret +net.tls_sticket_secret('{{ cfg.network.tls.sticket_secret }}') +{% endif %} + +{% if cfg.network.tls.sticket_secret_file %} +-- network.tls.sticket-secret-file +net.tls_sticket_secret_file('{{ cfg.network.tls.sticket_secret_file }}') +{% endif %} + +{% if cfg.network.tls.auto_discovery %} +-- network.tls.auto-discovery +modules.load('experimental_dot_auth') +{% else %} +-- modules.unload('experimental_dot_auth') +{% endif %} + +-- network.tls.padding +net.tls_padding( +{%- if cfg.network.tls.padding == true -%} +true +{%- elif cfg.network.tls.padding == false -%} +false +{%- else -%} +{{ cfg.network.tls.padding }} +{%- endif -%} +) + +{% if cfg.network.address_renumbering %} +-- network.address-renumbering +modules.load('renumber') +renumber.config = { +{% for item in cfg.network.address_renumbering %} + {'{{ item.source }}', '{{ item.destination }}'}, +{% endfor %} +} +{% endif %} + +{%- set vars = {'doh_legacy': False} -%} +{% for listen in cfg.network.listen if listen.kind == "doh-legacy" -%} +{%- if vars.update({'doh_legacy': True}) -%}{%- endif -%} +{%- endfor %} + +{% if vars.doh_legacy %} +-- doh_legacy http config +modules.load('http') +{{ http_config(cfg.network.tls,"doh_legacy") }} +{% endif %} + +{% if cfg.network.proxy_protocol %} +-- network.proxy-protocol +net.proxy_allowed({ +{% for item in cfg.network.proxy_protocol.allow %} +'{{ item }}', +{% endfor %} +}) +{% else %} +net.proxy_allowed({}) +{% endif %} + +-- network.listen +{% for listen in cfg.network.listen %} +{{ network_listen(listen) }} +{% endfor %}
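Rendered against a simple network section, the template above produces plain Lua calls such as the excerpt below; the pipeline depth, buffer sizes and padding value are assumed example values, not authoritative defaults.

net.ipv4 = true
net.ipv6 = true
net.tcp_pipeline(100)
net.bufsize(1232, 1232)
net.tls_padding(true)
net.proxy_allowed({})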
\ No newline at end of file diff --git a/python/knot_resolver/datamodel/templates/options.lua.j2 b/python/knot_resolver/datamodel/templates/options.lua.j2 new file mode 100644 index 00000000..8210fb6d --- /dev/null +++ b/python/knot_resolver/datamodel/templates/options.lua.j2 @@ -0,0 +1,52 @@ +{% from 'macros/common_macros.lua.j2' import boolean %} + +-- options.glue-checking +mode('{{ cfg.options.glue_checking }}') + +{% if cfg.options.rebinding_protection %} +-- options.rebinding-protection +modules.load('rebinding < iterate') +{% endif %} + +{% if cfg.options.violators_workarounds %} +-- options.violators-workarounds +modules.load('workarounds < iterate') +{% endif %} + +{% if cfg.options.serve_stale %} +-- options.serve-stale +modules.load('serve_stale < cache') +{% endif %} + +-- options.query-priming +{% if cfg.options.priming %} +modules.load('priming') +{% else %} +modules.unload('priming') +{% endif %} + +-- options.time-jump-detection +{% if cfg.options.time_jump_detection %} +modules.load('detect_time_jump') +{% else %} +modules.unload('detect_time_jump') +{% endif %} + +-- options.refuse-no-rd +{% if cfg.options.refuse_no_rd %} +modules.load('refuse_nord') +{% else %} +modules.unload('refuse_nord') +{% endif %} + +-- options.qname-minimisation +option('NO_MINIMIZE', {{ boolean(cfg.options.minimize,true) }}) + +-- options.query-loopback +option('ALLOW_LOCAL', {{ boolean(cfg.options.query_loopback) }}) + +-- options.reorder-rrset +option('REORDER_RR', {{ boolean(cfg.options.reorder_rrset) }}) + +-- options.query-case-randomization +option('NO_0X20', {{ boolean(cfg.options.query_case_randomization,true) }})
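Each option either toggles a module via modules.load()/modules.unload() or sets a kresd option() flag; the flags whose Lua names are negations (NO_MINIMIZE, NO_0X20) go through boolean() with its second argument set, which is assumed here to invert the value. An illustrative rendering for one example configuration:

mode('normal')
modules.load('priming')
modules.load('detect_time_jump')
modules.unload('refuse_nord')
option('NO_MINIMIZE', false)
option('ALLOW_LOCAL', false)
option('REORDER_RR', true)
option('NO_0X20', false)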
\ No newline at end of file diff --git a/python/knot_resolver/datamodel/templates/policy-config.lua.j2 b/python/knot_resolver/datamodel/templates/policy-config.lua.j2 new file mode 100644 index 00000000..4c5c9048 --- /dev/null +++ b/python/knot_resolver/datamodel/templates/policy-config.lua.j2 @@ -0,0 +1,40 @@ +{% if not cfg.lua.script_only %} + +-- FFI library +ffi = require('ffi') +local C = ffi.C + +-- logging.level +log_level('{{ cfg.logging.level }}') + +{% if cfg.logging.target -%} +-- logging.target +log_target('{{ cfg.logging.target }}') +{%- endif %} + +{% if cfg.logging.groups %} +-- logging.groups +log_groups({ +{% for g in cfg.logging.groups %} +{% if g != "manager" and g != "supervisord" and g != "cache-gc" %} + '{{ g }}', +{% endif %} +{% endfor %} +}) +{% endif %} + +-- Config required for the cache opening +cache.open({{ cfg.cache.size_max.bytes() }}, 'lmdb://{{ cfg.cache.storage }}') + +-- VIEWS section ------------------------------------ +{% include "views.lua.j2" %} + +-- LOCAL-DATA section ------------------------------- +{% include "local_data.lua.j2" %} + +-- FORWARD section ---------------------------------- +{% include "forward.lua.j2" %} + +{% endif %} + +quit() diff --git a/python/knot_resolver/datamodel/templates/static_hints.lua.j2 b/python/knot_resolver/datamodel/templates/static_hints.lua.j2 new file mode 100644 index 00000000..130facf9 --- /dev/null +++ b/python/knot_resolver/datamodel/templates/static_hints.lua.j2 @@ -0,0 +1,51 @@ +{% if cfg.static_hints.etc_hosts or cfg.static_hints.root_hints_file or cfg.static_hints.hints_files or cfg.static_hints.root_hints or cfg.static_hints.hints %} +modules.load('hints > iterate') + +{% if cfg.static_hints.ttl %} +-- static-hints.ttl +hints.ttl({{ cfg.static_hints.ttl.seconds()|string }}) +{% endif %} + +-- static-hints.no-data +hints.use_nodata({{ 'true' if cfg.static_hints.nodata else 'false' }}) + +{% if cfg.static_hints.etc_hosts %} +-- static-hints.etc-hosts +hints.add_hosts('/etc/hosts') +{% endif %} + +{% if cfg.static_hints.root_hints_file %} +-- static-hints.root-hints-file +hints.root_file('{{ cfg.static_hints.root_hints_file }}') +{% endif %} + +{% if cfg.static_hints.hints_files %} +-- static-hints.hints-files +{% for item in cfg.static_hints.hints_files %} +hints.add_hosts('{{ item }}') +{% endfor %} +{% endif %} + +{% if cfg.static_hints.root_hints %} +-- static-hints.root-hints +hints.root({ +{% for name, addrs in cfg.static_hints.root_hints.items() %} +['{{ name.punycode() }}'] = { +{% for addr in addrs %} + '{{ addr }}', +{% endfor %} + }, +{% endfor %} +}) +{% endif %} + +{% if cfg.static_hints.hints %} +-- static-hints.hints +{% for name, addrs in cfg.static_hints.hints.items() %} +{% for addr in addrs %} +hints.set('{{ name.punycode() }} {{ addr }}') +{% endfor %} +{% endfor %} +{% endif %} + +{% endif %}
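The static-hints template loads the hints module only when at least one hints option is present and otherwise translates the configuration one-to-one into hints.* calls; a sketch of the rendered output with made-up paths, names and addresses:

modules.load('hints > iterate')
hints.ttl(300)
hints.use_nodata(true)
hints.add_hosts('/etc/hosts')
hints.root_file('/usr/share/dns/root.hints')
hints.set('example.net. 192.0.2.1')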
\ No newline at end of file diff --git a/python/knot_resolver/datamodel/templates/views.lua.j2 b/python/knot_resolver/datamodel/templates/views.lua.j2 new file mode 100644 index 00000000..81de8c7b --- /dev/null +++ b/python/knot_resolver/datamodel/templates/views.lua.j2 @@ -0,0 +1,25 @@ +{% from 'macros/common_macros.lua.j2' import quotes %} +{% from 'macros/view_macros.lua.j2' import get_proto_set, view_flags, view_answer %} +{% from 'macros/policy_macros.lua.j2' import policy_flags, policy_tags_assign %} + +{% if cfg.views %} +{% for view in cfg.views %} +{% for subnet in view.subnets %} + +assert(C.kr_view_insert_action('{{ subnet }}', '{{ view.dst_subnet or '' }}', + {{ get_proto_set(view.protocols) }}, policy.COMBINE({ +{%- set flags = view_flags(view.options) -%} +{% if flags %} + {{ quotes(policy_flags(flags)) }}, +{%- endif %} + +{% if view.tags %} + {{ policy_tags_assign(view.tags) }}, +{% elif view.answer %} + {{ view_answer(view.answer) }}, +{%- endif %} + })) == 0) + +{% endfor %} +{% endfor %} +{% endif %} diff --git a/python/knot_resolver/datamodel/templates/webmgmt.lua.j2 b/python/knot_resolver/datamodel/templates/webmgmt.lua.j2 new file mode 100644 index 00000000..938ea8da --- /dev/null +++ b/python/knot_resolver/datamodel/templates/webmgmt.lua.j2 @@ -0,0 +1,25 @@ +{% from 'macros/common_macros.lua.j2' import boolean %} + +{% if cfg.webmgmt -%} +-- webmgmt +modules.load('http') +http.config({tls = {{ boolean(cfg.webmgmt.tls) }}, +{%- if cfg.webmgmt.cert_file -%} + cert = '{{ cfg.webmgmt.cert_file }}', +{%- endif -%} +{%- if cfg.webmgmt.cert_file -%} + key = '{{ cfg.webmgmt.key_file }}', +{%- endif -%} +}, 'webmgmt') +net.listen( +{%- if cfg.webmgmt.unix_socket -%} + '{{ cfg.webmgmt.unix_socket }}',nil, +{%- elif cfg.webmgmt.interface -%} + {%- if cfg.webmgmt.interface.addr -%} + '{{ cfg.webmgmt.interface.addr }}',{{ cfg.webmgmt.interface.port }}, + {%- elif cfg.webmgmt.interface.if_name -%} + net.{{ cfg.webmgmt.interface.if_name }},{{ cfg.webmgmt.interface.port }}, + {%- endif -%} +{%- endif -%} +{ kind = 'webmgmt' }) +{%- endif %}
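For the legacy web management endpoint the template expands to an http module load, an http.config() call and a single listener; a rendering sketch assuming an interface address with TLS and placeholder certificate paths:

modules.load('http')
http.config({tls = true, cert = '/etc/knot-resolver/mgmt.crt', key = '/etc/knot-resolver/mgmt.key',}, 'webmgmt')
net.listen('192.0.2.10',8453,{ kind = 'webmgmt' })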
\ No newline at end of file diff --git a/python/knot_resolver/datamodel/templates/worker-config.lua.j2 b/python/knot_resolver/datamodel/templates/worker-config.lua.j2 new file mode 100644 index 00000000..17c49fb0 --- /dev/null +++ b/python/knot_resolver/datamodel/templates/worker-config.lua.j2 @@ -0,0 +1,58 @@ +{% if not cfg.lua.script_only %} + +-- FFI library +ffi = require('ffi') +local C = ffi.C + +-- Do not clear the DB with rules; we had it prepared by a different process. +assert(C.kr_rules_init(nil, 0, false) == 0) + +-- hostname +hostname('{{ cfg.hostname }}') + +{% if cfg.nsid %} +-- nsid +modules.load('nsid') +nsid.name('{{ cfg.nsid }}' .. worker.id) +{% endif %} + +-- LOGGING section ---------------------------------- +{% include "logging.lua.j2" %} + +-- MONITORING section ------------------------------- +{% include "monitoring.lua.j2" %} + +-- WEBMGMT section ---------------------------------- +{% include "webmgmt.lua.j2" %} + +-- OPTIONS section ---------------------------------- +{% include "options.lua.j2" %} + +-- NETWORK section ---------------------------------- +{% include "network.lua.j2" %} + +-- DNSSEC section ----------------------------------- +{% include "dnssec.lua.j2" %} + +-- FORWARD section ---------------------------------- +{% include "forward.lua.j2" %} + +-- CACHE section ------------------------------------ +{% include "cache.lua.j2" %} + +-- DNS64 section ------------------------------------ +{% include "dns64.lua.j2" %} + +{% endif %} + +-- LUA section -------------------------------------- +-- Custom Lua code cannot be validated + +{% if cfg.lua.script_file %} +{% import cfg.lua.script_file as script_file %} +{{ script_file }} +{% endif %} + +{% if cfg.lua.script %} +{{ cfg.lua.script }} +{% endif %} diff --git a/python/knot_resolver/datamodel/types/__init__.py b/python/knot_resolver/datamodel/types/__init__.py new file mode 100644 index 00000000..a3d7db3e --- /dev/null +++ b/python/knot_resolver/datamodel/types/__init__.py @@ -0,0 +1,69 @@ +from .enums import DNSRecordTypeEnum, PolicyActionEnum, PolicyFlagEnum +from .files import AbsoluteDir, Dir, File, FilePath, ReadableFile, WritableDir, WritableFilePath +from .generic_types import ListOrItem +from .types import ( + DomainName, + EscapedStr, + EscapedStr32B, + IDPattern, + Int0_512, + Int0_65535, + InterfaceName, + InterfaceOptionalPort, + InterfacePort, + IntNonNegative, + IntPositive, + IPAddress, + IPAddressEM, + IPAddressOptionalPort, + IPAddressPort, + IPNetwork, + IPv4Address, + IPv6Address, + IPv6Network, + IPv6Network96, + Percent, + PinSha256, + PortNumber, + SizeUnit, + TimeUnit, +) + +__all__ = [ + "PolicyActionEnum", + "PolicyFlagEnum", + "DNSRecordTypeEnum", + "DomainName", + "EscapedStr", + "EscapedStr32B", + "IDPattern", + "Int0_512", + "Int0_65535", + "InterfaceName", + "InterfaceOptionalPort", + "InterfacePort", + "IntNonNegative", + "IntPositive", + "IPAddress", + "IPAddressEM", + "IPAddressOptionalPort", + "IPAddressPort", + "IPNetwork", + "IPv4Address", + "IPv6Address", + "IPv6Network", + "IPv6Network96", + "ListOrItem", + "Percent", + "PinSha256", + "PortNumber", + "SizeUnit", + "TimeUnit", + "AbsoluteDir", + "ReadableFile", + "WritableDir", + "WritableFilePath", + "File", + "FilePath", + "Dir", +] diff --git a/python/knot_resolver/datamodel/types/base_types.py b/python/knot_resolver/datamodel/types/base_types.py new file mode 100644 index 00000000..c2d60312 --- /dev/null +++ b/python/knot_resolver/datamodel/types/base_types.py @@ -0,0 +1,227 @@ +import re +from typing import 
Any, Dict, Pattern, Type + +from knot_resolver.utils.modeling import BaseValueType + + +class IntBase(BaseValueType): + """ + Base class to work with integer value. + """ + + _orig_value: int + _value: int + + def __init__(self, source_value: Any, object_path: str = "/") -> None: + if isinstance(source_value, int) and not isinstance(source_value, bool): + self._orig_value = source_value + self._value = source_value + else: + raise ValueError( + f"Unexpected value for '{type(self)}'." + f" Expected integer, got '{source_value}' with type '{type(source_value)}'", + object_path, + ) + + def __int__(self) -> int: + return self._value + + def __str__(self) -> str: + return str(self._value) + + def __repr__(self) -> str: + return f'{type(self).__name__}("{self._value}")' + + def __eq__(self, o: object) -> bool: + return isinstance(o, IntBase) and o._value == self._value + + def serialize(self) -> Any: + return self._orig_value + + @classmethod + def json_schema(cls: Type["IntBase"]) -> Dict[Any, Any]: + return {"type": "integer"} + + +class StrBase(BaseValueType): + """ + Base class to work with string value. + """ + + _orig_value: str + _value: str + + def __init__(self, source_value: Any, object_path: str = "/") -> None: + if isinstance(source_value, (str, int)) and not isinstance(source_value, bool): + self._orig_value = str(source_value) + self._value = str(source_value) + else: + raise ValueError( + f"Unexpected value for '{type(self)}'." + f" Expected string, got '{source_value}' with type '{type(source_value)}'", + object_path, + ) + + def __int__(self) -> int: + raise ValueError("Can't convert string to an integer.") + + def __str__(self) -> str: + return self._value + + def __repr__(self) -> str: + return f'{type(self).__name__}("{self._value}")' + + def __hash__(self) -> int: + return hash(self._value) + + def __eq__(self, o: object) -> bool: + return isinstance(o, StrBase) and o._value == self._value + + def serialize(self) -> Any: + return self._orig_value + + @classmethod + def json_schema(cls: Type["StrBase"]) -> Dict[Any, Any]: + return {"type": "string"} + + +class StringLengthBase(StrBase): + """ + Base class to work with string value length. + Just inherit the class and set the values for '_min_bytes' and '_max_bytes'. + + class String32B(StringLengthBase): + _min_bytes: int = 32 + """ + + _min_bytes: int = 1 + _max_bytes: int + + def __init__(self, source_value: Any, object_path: str = "/") -> None: + super().__init__(source_value, object_path) + value_bytes = len(self._value.encode("utf-8")) + if hasattr(self, "_min_bytes") and (value_bytes < self._min_bytes): + raise ValueError( + f"the string value {source_value} is shorter than the minimum {self._min_bytes} bytes.", object_path + ) + if hasattr(self, "_max_bytes") and (value_bytes > self._max_bytes): + raise ValueError( + f"the string value {source_value} is longer than the maximum {self._max_bytes} bytes.", object_path + ) + + @classmethod + def json_schema(cls: Type["StringLengthBase"]) -> Dict[Any, Any]: + typ: Dict[str, Any] = {"type": "string"} + if hasattr(cls, "_min_bytes"): + typ["minLength"] = cls._min_bytes + if hasattr(cls, "_max_bytes"): + typ["maxLength"] = cls._max_bytes + return typ + + +class IntRangeBase(IntBase): + """ + Base class to work with integer value in range. + Just inherit the class and set the values for '_min' and '_max'. 
+ + class IntNonNegative(IntRangeBase): + _min: int = 0 + """ + + _min: int + _max: int + + def __init__(self, source_value: Any, object_path: str = "/") -> None: + super().__init__(source_value, object_path) + if hasattr(self, "_min") and (self._value < self._min): + raise ValueError(f"value {self._value} is lower than the minimum {self._min}.", object_path) + if hasattr(self, "_max") and (self._value > self._max): + raise ValueError(f"value {self._value} is higher than the maximum {self._max}", object_path) + + @classmethod + def json_schema(cls: Type["IntRangeBase"]) -> Dict[Any, Any]: + typ: Dict[str, Any] = {"type": "integer"} + if hasattr(cls, "_min"): + typ["minimum"] = cls._min + if hasattr(cls, "_max"): + typ["maximum"] = cls._max + return typ + + +class PatternBase(StrBase): + """ + Base class to work with string value that match regex pattern. + Just inherit the class and set regex pattern for '_re'. + + class ABPattern(PatternBase): + _re: Pattern[str] = re.compile(r"ab*") + """ + + _re: Pattern[str] + + def __init__(self, source_value: Any, object_path: str = "/") -> None: + super().__init__(source_value, object_path) + if not type(self)._re.match(self._value): + raise ValueError(f"'{self._value}' does not match '{self._re.pattern}' pattern", object_path) + + @classmethod + def json_schema(cls: Type["PatternBase"]) -> Dict[Any, Any]: + return {"type": "string", "pattern": rf"{cls._re.pattern}"} + + +class UnitBase(StrBase): + """ + Base class to work with string value that match regex pattern. + Just inherit the class and set '_units'. + + class CustomUnit(PatternBase): + _units = {"b": 1, "kb": 1000} + """ + + _re: Pattern[str] + _units: Dict[str, int] + _base_value: int + + def __init__(self, source_value: Any, object_path: str = "/") -> None: + super().__init__(source_value, object_path) + + type(self)._re = re.compile(rf"^(\d+)({r'|'.join(type(self)._units.keys())})$") + grouped = self._re.search(self._value) + if grouped: + val, unit = grouped.groups() + if unit is None: + raise ValueError(f"Missing units. Accepted units are {list(type(self)._units.keys())}", object_path) + elif unit not in type(self)._units: + raise ValueError( + f"Used unexpected unit '{unit}' for {type(self).__name__}." + f" Accepted units are {list(type(self)._units.keys())}", + object_path, + ) + self._base_value = int(val) * type(self)._units[unit] + else: + raise ValueError( + f"Unexpected value for '{type(self)}'." + " Expected string that matches pattern " + rf"'{type(self)._re.pattern}'." + f" Positive integer and one of the units {list(type(self)._units.keys())}, got '{source_value}'.", + object_path, + ) + + def __int__(self) -> int: + return self._base_value + + def __repr__(self) -> str: + return f"Unit[{type(self).__name__},{self._value}]" + + def __eq__(self, o: object) -> bool: + """ + Two instances are equal when they represent the same size + regardless of their string representation. 
+ """ + return isinstance(o, UnitBase) and o._value == self._value + + def serialize(self) -> Any: + return self._orig_value + + @classmethod + def json_schema(cls: Type["UnitBase"]) -> Dict[Any, Any]: + return {"type": "string", "pattern": rf"{cls._re.pattern}"} diff --git a/python/knot_resolver/datamodel/types/enums.py b/python/knot_resolver/datamodel/types/enums.py new file mode 100644 index 00000000..bc93ae2f --- /dev/null +++ b/python/knot_resolver/datamodel/types/enums.py @@ -0,0 +1,153 @@ +from typing_extensions import Literal + +# Policy actions +PolicyActionEnum = Literal[ + # Nonchain actions + "pass", + "deny", + "drop", + "refuse", + "tc", + "reroute", + "answer", + # Chain actions + "mirror", + "forward", + "stub", + "debug-always", + "debug-cache-miss", + "qtrace", + "reqtrace", +] + +# FLAGS from https://www.knot-resolver.cz/documentation/latest/lib.html?highlight=options#c.kr_qflags +PolicyFlagEnum = Literal[ + "no-minimize", + "no-ipv4", + "no-ipv6", + "tcp", + "resolved", + "await-ipv4", + "await-ipv6", + "await-cut", + "no-edns", + "cached", + "no-cache", + "expiring", + "allow_local", + "dnssec-want", + "dnssec-bogus", + "dnssec-insecure", + "dnssec-cd", + "stub", + "always-cut", + "dnssec-wexpand", + "permissive", + "strict", + "badcookie-again", + "cname", + "reorder-rr", + "trace", + "no-0x20", + "dnssec-nods", + "dnssec-optout", + "nonauth", + "forward", + "dns64-mark", + "cache-tried", + "no-ns-found", + "pkt-is-sane", + "dns64-disable", +] + +# DNS records from 'kres.type' table +DNSRecordTypeEnum = Literal[ + "A", + "A6", + "AAAA", + "AFSDB", + "ANY", + "APL", + "ATMA", + "AVC", + "AXFR", + "CAA", + "CDNSKEY", + "CDS", + "CERT", + "CNAME", + "CSYNC", + "DHCID", + "DLV", + "DNAME", + "DNSKEY", + "DOA", + "DS", + "EID", + "EUI48", + "EUI64", + "GID", + "GPOS", + "HINFO", + "HIP", + "HTTPS", + "IPSECKEY", + "ISDN", + "IXFR", + "KEY", + "KX", + "L32", + "L64", + "LOC", + "LP", + "MAILA", + "MAILB", + "MB", + "MD", + "MF", + "MG", + "MINFO", + "MR", + "MX", + "NAPTR", + "NID", + "NIMLOC", + "NINFO", + "NS", + "NSAP", + "NSAP-PTR", + "NSEC", + "NSEC3", + "NSEC3PARAM", + "NULL", + "NXT", + "OPENPGPKEY", + "OPT", + "PTR", + "PX", + "RKEY", + "RP", + "RRSIG", + "RT", + "SIG", + "SINK", + "SMIMEA", + "SOA", + "SPF", + "SRV", + "SSHFP", + "SVCB", + "TA", + "TALINK", + "TKEY", + "TLSA", + "TSIG", + "TXT", + "UID", + "UINFO", + "UNSPEC", + "URI", + "WKS", + "X25", + "ZONEMD", +] diff --git a/python/knot_resolver/datamodel/types/files.py b/python/knot_resolver/datamodel/types/files.py new file mode 100644 index 00000000..920d90b1 --- /dev/null +++ b/python/knot_resolver/datamodel/types/files.py @@ -0,0 +1,245 @@ +import os +import stat +from enum import auto, Flag +from grp import getgrnam +from pathlib import Path +from pwd import getpwnam +from typing import Any, Dict, Tuple, Type, TypeVar + +from knot_resolver.manager.constants import kresd_group, kresd_user +from knot_resolver.datamodel.globals import get_resolve_root, get_strict_validation +from knot_resolver.utils.modeling.base_value_type import BaseValueType + + +class UncheckedPath(BaseValueType): + """ + Wrapper around pathlib.Path object. Can represent pretty much any Path. No checks are + performed on the value. The value is taken as is. + """ + + _value: Path + + def __init__( + self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/" + ) -> None: + self._object_path: str = object_path + self._parents: Tuple[UncheckedPath, ...] 
= parents + self.strict_validation: bool = get_strict_validation() + + if isinstance(source_value, str): + # we do not load global validation context if the path is absolute + # this prevents errors when constructing defaults in the schema + if source_value.startswith("/"): + resolve_root = Path("/") + else: + resolve_root = get_resolve_root() + + self._raw_value: str = source_value + if self._parents: + pp = map(lambda p: p.to_path(), self._parents) + self._value: Path = Path(resolve_root, *pp, source_value) + else: + self._value: Path = Path(resolve_root, source_value) + else: + raise ValueError(f"expected file path in a string, got '{source_value}' with type '{type(source_value)}'.") + + def __str__(self) -> str: + return str(self._value) + + def __eq__(self, o: object) -> bool: + if not isinstance(o, UncheckedPath): + return False + + return o._value == self._value + + def __int__(self) -> int: + raise RuntimeError("Path cannot be converted to type <int>") + + def to_path(self) -> Path: + return self._value + + def serialize(self) -> Any: + return self._raw_value + + def relative_to(self, parent: "UncheckedPath") -> "UncheckedPath": + """return a path with an added parent part""" + return UncheckedPath(self._raw_value, parents=(parent, *self._parents), object_path=self._object_path) + + UPT = TypeVar("UPT", bound="UncheckedPath") + + def reconstruct(self, cls: Type[UPT]) -> UPT: + """ + Rebuild this object as an instance of its subclass. Practically, allows for conversions from + """ + return cls(self._raw_value, parents=self._parents, object_path=self._object_path) + + @classmethod + def json_schema(cls: Type["UncheckedPath"]) -> Dict[Any, Any]: + return { + "type": "string", + } + + +class Dir(UncheckedPath): + """ + Path, that is enforced to be: + - an existing directory + """ + + def __init__( + self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/" + ) -> None: + super().__init__(source_value, parents=parents, object_path=object_path) + if self.strict_validation and not self._value.is_dir(): + raise ValueError(f"path '{self._value}' does not point to an existing directory") + + +class AbsoluteDir(Dir): + """ + Path, that is enforced to be: + - absolute + - an existing directory + """ + + def __init__( + self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/" + ) -> None: + super().__init__(source_value, parents=parents, object_path=object_path) + if self.strict_validation and not self._value.is_absolute(): + raise ValueError(f"path '{self._value}' is not absolute") + + +class File(UncheckedPath): + """ + Path, that is enforced to be: + - an existing file + """ + + def __init__( + self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/" + ) -> None: + super().__init__(source_value, parents=parents, object_path=object_path) + if self.strict_validation and not self._value.exists(): + raise ValueError(f"file '{self._value}' does not exist") + if self.strict_validation and not self._value.is_file(): + raise ValueError(f"path '{self._value}' is not a file") + + +class FilePath(UncheckedPath): + """ + Path, that is enforced to be: + - parent of the last path segment is an existing directory + - it does not point to a dir + """ + + def __init__( + self, source_value: Any, parents: Tuple["UncheckedPath", ...] 
= tuple(), object_path: str = "/" + ) -> None: + super().__init__(source_value, parents=parents, object_path=object_path) + p = self._value.parent + if self.strict_validation and (not p.exists() or not p.is_dir()): + raise ValueError(f"path '{self._value}' does not point inside an existing directory") + + if self.strict_validation and self._value.is_dir(): + raise ValueError(f"path '{self._value}' points to a directory when we expected a file") + + +class _PermissionMode(Flag): + READ = auto() + WRITE = auto() + EXECUTE = auto() + + +def _kres_accessible(dest_path: Path, perm_mode: _PermissionMode) -> bool: + chflags = { + _PermissionMode.READ: [stat.S_IRUSR, stat.S_IRGRP, stat.S_IROTH], + _PermissionMode.WRITE: [stat.S_IWUSR, stat.S_IWGRP, stat.S_IWOTH], + _PermissionMode.EXECUTE: [stat.S_IXUSR, stat.S_IXGRP, stat.S_IXOTH], + } + + username = kresd_user() + groupname = kresd_group() + + if username is None or groupname is None: + return True + + user_uid = getpwnam(username).pw_uid + user_gid = getgrnam(groupname).gr_gid + + dest_stat = os.stat(dest_path) + dest_uid = dest_stat.st_uid + dest_gid = dest_stat.st_gid + dest_mode = dest_stat.st_mode + + def accessible(perm: _PermissionMode) -> bool: + if user_uid == dest_uid: + return bool(dest_mode & chflags[perm][0]) + b_groups = os.getgrouplist(os.getlogin(), user_gid) + if user_gid == dest_gid or dest_gid in b_groups: + return bool(dest_mode & chflags[perm][1]) + return bool(dest_mode & chflags[perm][2]) + + # __iter__ for class enum.Flag added in python3.11 + # 'for perm in perm_mode:' failes for <=python3.11 + for perm in _PermissionMode: + if perm in perm_mode: + if not accessible(perm): + return False + return True + + +class ReadableFile(File): + """ + Path, that is enforced to be: + - an existing file + - readable by knot-resolver processes + """ + + def __init__( + self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/" + ) -> None: + super().__init__(source_value, parents=parents, object_path=object_path) + + if self.strict_validation and not _kres_accessible(self._value, _PermissionMode.READ): + raise ValueError(f"{kresd_user()}:{kresd_group()} has insufficient permissions to read '{self._value}'") + + +class WritableDir(Dir): + """ + Path, that is enforced to be: + - an existing directory + - writable/executable by knot-resolver processes + """ + + def __init__( + self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/" + ) -> None: + super().__init__(source_value, parents=parents, object_path=object_path) + + if self.strict_validation and not _kres_accessible( + self._value, _PermissionMode.WRITE | _PermissionMode.EXECUTE + ): + raise ValueError( + f"{kresd_user()}:{kresd_group()} has insufficient permissions to write/execute '{self._value}'" + ) + + +class WritableFilePath(FilePath): + """ + Path, that is enforced to be: + - parent of the last path segment is an existing directory + - it does not point to a dir + - writable/executable parent directory by knot-resolver processes + """ + + def __init__( + self, source_value: Any, parents: Tuple["UncheckedPath", ...] 
= tuple(), object_path: str = "/" + ) -> None: + super().__init__(source_value, parents=parents, object_path=object_path) + + if self.strict_validation and not _kres_accessible( + self._value.parent, _PermissionMode.WRITE | _PermissionMode.EXECUTE + ): + raise ValueError( + f"{kresd_user()}:{kresd_group()} has insufficient permissions to write/execute'{self._value.parent}'" + ) diff --git a/python/knot_resolver/datamodel/types/generic_types.py b/python/knot_resolver/datamodel/types/generic_types.py new file mode 100644 index 00000000..8649a0f0 --- /dev/null +++ b/python/knot_resolver/datamodel/types/generic_types.py @@ -0,0 +1,38 @@ +from typing import Any, List, TypeVar, Union + +from knot_resolver.utils.modeling import BaseGenericTypeWrapper + +T = TypeVar("T") + + +class ListOrItem(BaseGenericTypeWrapper[Union[List[T], T]]): + _value_orig: Union[List[T], T] + _list: List[T] + + def __init__(self, source_value: Any, object_path: str = "/") -> None: # pylint: disable=unused-argument + self._value_orig: Union[List[T], T] = source_value + + self._list: List[T] = source_value if isinstance(source_value, list) else [source_value] + if len(self) == 0: + raise ValueError("empty list is not allowed") + + def __getitem__(self, index: Any) -> T: + return self._list[index] + + def __int__(self) -> int: + raise ValueError(f"Can't convert '{type(self).__name__}' to an integer.") + + def __str__(self) -> str: + return str(self._value_orig) + + def to_std(self) -> List[T]: + return self._list + + def __eq__(self, o: object) -> bool: + return isinstance(o, ListOrItem) and o._value_orig == self._value_orig + + def __len__(self) -> int: + return len(self._list) + + def serialize(self) -> Union[List[T], T]: + return self._value_orig diff --git a/python/knot_resolver/datamodel/types/types.py b/python/knot_resolver/datamodel/types/types.py new file mode 100644 index 00000000..ca8706d2 --- /dev/null +++ b/python/knot_resolver/datamodel/types/types.py @@ -0,0 +1,526 @@ +import ipaddress +import re +from typing import Any, Dict, Optional, Type, Union + +from knot_resolver.datamodel.types.base_types import ( + IntRangeBase, + PatternBase, + StrBase, + StringLengthBase, + UnitBase, +) +from knot_resolver.utils.modeling import BaseValueType + + +class IntNonNegative(IntRangeBase): + _min: int = 0 + + +class IntPositive(IntRangeBase): + _min: int = 1 + + +class Int0_512(IntRangeBase): + _min: int = 0 + _max: int = 512 + + +class Int0_65535(IntRangeBase): + _min: int = 0 + _max: int = 65_535 + + +class Percent(IntRangeBase): + _min: int = 0 + _max: int = 100 + + +class PortNumber(IntRangeBase): + _min: int = 1 + _max: int = 65_535 + + @classmethod + def from_str(cls: Type["PortNumber"], port: str, object_path: str = "/") -> "PortNumber": + try: + return cls(int(port), object_path) + except ValueError as e: + raise ValueError(f"invalid port number {port}") from e + + +class SizeUnit(UnitBase): + _units = {"B": 1, "K": 1024, "M": 1024**2, "G": 1024**3} + + def bytes(self) -> int: + return self._base_value + + def mbytes(self) -> int: + return self._base_value // 1024**2 + + +class TimeUnit(UnitBase): + _units = {"us": 1, "ms": 10**3, "s": 10**6, "m": 60 * 10**6, "h": 3600 * 10**6, "d": 24 * 3600 * 10**6} + + def minutes(self) -> int: + return self._base_value // 1000**2 // 60 + + def seconds(self) -> int: + return self._base_value // 1000**2 + + def millis(self) -> int: + return self._base_value // 1000 + + def micros(self) -> int: + return self._base_value + + +class EscapedStr(StrBase): + """ + A string where escape 
sequences are ignored and quotes are escaped. + """ + + def __init__(self, source_value: Any, object_path: str = "/") -> None: + super().__init__(source_value, object_path) + + escape = { + "'": r"\'", + '"': r"\"", + "\a": r"\a", + "\n": r"\n", + "\r": r"\r", + "\t": r"\t", + "\b": r"\b", + "\f": r"\f", + "\v": r"\v", + "\0": r"\0", + } + + s = list(self._value) + for i, c in enumerate(self._value): + if c in escape: + s[i] = escape[c] + elif not c.isalnum(): + s[i] = repr(c)[1:-1] + self._value = "".join(s) + + def multiline(self) -> str: + """ + Lua multiline string is enclosed in double square brackets '[[ ]]'. + This method makes sure that double square brackets are escaped. + """ + + replace = { + "[[": r"\[\[", + "]]": r"\]\]", + } + + ml = self._orig_value + for s, r in replace.items(): + ml = ml.replace(s, r) + return ml + + +class EscapedStr32B(EscapedStr, StringLengthBase): + """ + Same as 'EscapedStr', but minimal length is 32 bytes. + """ + + _min_bytes: int = 32 + + +class DomainName(StrBase): + """ + Fully or partially qualified domain name. + """ + + _punycode: str + _re = re.compile( + r"(?=^.{,253}\.?$)" # max 253 chars + r"(^(?!\.)" # do not start name with dot + r"((?!-)" # do not start label with hyphen + r"\.?[a-zA-Z0-9-]{,62}" # max 63 chars in label + r"[a-zA-Z0-9])+" # do not end label with hyphen + r"\.?$)" # end with or without '.' + r"|^\.$" # allow root-zone + ) + + def __init__(self, source_value: Any, object_path: str = "/") -> None: + super().__init__(source_value, object_path) + try: + punycode = self._value.encode("idna").decode("utf-8") if self._value != "." else "." + except ValueError as e: + raise ValueError( + f"conversion of '{self._value}' to IDN punycode representation failed", + object_path, + ) from e + + if type(self)._re.match(punycode): + self._punycode = punycode + else: + raise ValueError( + f"'{source_value}' represented in punycode '{punycode}' does not match '{self._re.pattern}' pattern", + object_path, + ) + + def __hash__(self) -> int: + if self._value.endswith("."): + return hash(self._value) + return hash(f"{self._value}.") + + def punycode(self) -> str: + return self._punycode + + @classmethod + def json_schema(cls: Type["DomainName"]) -> Dict[Any, Any]: + return {"type": "string", "pattern": rf"{cls._re.pattern}"} + + +class InterfaceName(PatternBase): + """ + Network interface name. + """ + + _re = re.compile(r"^[a-zA-Z0-9]+(?:[-_][a-zA-Z0-9]+)*$") + + +class IDPattern(PatternBase): + """ + Alphanumerical ID for identifying systemd slice. + """ + + _re = re.compile(r"^(?!-)[a-z0-9-]*[a-z0-9]+$") + + +class PinSha256(PatternBase): + """ + A string that stores base64 encoded sha256. 
+ """ + + _re = re.compile(r"^[A-Za-z\d+/]{43}=$") + + +class InterfacePort(StrBase): + addr: Union[None, ipaddress.IPv4Address, ipaddress.IPv6Address] = None + if_name: Optional[InterfaceName] = None + port: PortNumber + + def __init__(self, source_value: Any, object_path: str = "/") -> None: + super().__init__(source_value, object_path) + + parts = self._value.split("@") + if len(parts) == 2: + try: + self.addr = ipaddress.ip_address(parts[0]) + except ValueError as e1: + try: + self.if_name = InterfaceName(parts[0]) + except ValueError as e2: + raise ValueError( + f"expected IP address or interface name, got '{parts[0]}'.", object_path + ) from e1 and e2 + self.port = PortNumber.from_str(parts[1], object_path) + else: + raise ValueError(f"expected '<ip-address|interface-name>@<port>', got '{source_value}'.", object_path) + + +class InterfaceOptionalPort(StrBase): + addr: Union[None, ipaddress.IPv4Address, ipaddress.IPv6Address] = None + if_name: Optional[InterfaceName] = None + port: Optional[PortNumber] = None + + def __init__(self, source_value: Any, object_path: str = "/") -> None: + super().__init__(source_value, object_path) + + parts = self._value.split("@") + if 0 < len(parts) < 3: + try: + self.addr = ipaddress.ip_address(parts[0]) + except ValueError as e1: + try: + self.if_name = InterfaceName(parts[0]) + except ValueError as e2: + raise ValueError( + f"expected IP address or interface name, got '{parts[0]}'.", object_path + ) from e1 and e2 + if len(parts) == 2: + self.port = PortNumber.from_str(parts[1], object_path) + else: + raise ValueError(f"expected '<ip-address|interface-name>[@<port>]', got '{parts}'.", object_path) + + +class IPAddressPort(StrBase): + addr: Union[ipaddress.IPv4Address, ipaddress.IPv6Address] + port: PortNumber + + def __init__(self, source_value: Any, object_path: str = "/") -> None: + super().__init__(source_value, object_path) + + parts = self._value.split("@") + if len(parts) == 2: + self.port = PortNumber.from_str(parts[1], object_path) + try: + self.addr = ipaddress.ip_address(parts[0]) + except ValueError as e: + raise ValueError(f"failed to parse IP address '{parts[0]}'.", object_path) from e + else: + raise ValueError(f"expected '<ip-address>@<port>', got '{source_value}'.", object_path) + + +class IPAddressOptionalPort(StrBase): + addr: Union[ipaddress.IPv4Address, ipaddress.IPv6Address] + port: Optional[PortNumber] = None + + def __init__(self, source_value: Any, object_path: str = "/") -> None: + super().__init__(source_value) + parts = source_value.split("@") + if 0 < len(parts) < 3: + try: + self.addr = ipaddress.ip_address(parts[0]) + except ValueError as e: + raise ValueError(f"failed to parse IP address '{parts[0]}'.", object_path) from e + if len(parts) == 2: + self.port = PortNumber.from_str(parts[1], object_path) + else: + raise ValueError(f"expected '<ip-address>[@<port>]', got '{parts}'.", object_path) + + +class IPv4Address(BaseValueType): + _value: ipaddress.IPv4Address + + def __init__(self, source_value: Any, object_path: str = "/") -> None: + if isinstance(source_value, str): + try: + self._value: ipaddress.IPv4Address = ipaddress.IPv4Address(source_value) + except ValueError as e: + raise ValueError("failed to parse IPv4 address.") from e + else: + raise ValueError( + "Unexpected value for a IPv4 address." 
+ f" Expected string, got '{source_value}' with type '{type(source_value)}'", + object_path, + ) + + def to_std(self) -> ipaddress.IPv4Address: + return self._value + + def __str__(self) -> str: + return str(self._value) + + def __int__(self) -> int: + raise ValueError("Can't convert IPv4 address to an integer") + + def __repr__(self) -> str: + return f'{type(self).__name__}("{self._value}")' + + def __eq__(self, o: object) -> bool: + """ + Two instances of IPv4Address are equal when they represent same IPv4 address as string. + """ + return isinstance(o, IPv4Address) and str(o._value) == str(self._value) + + def serialize(self) -> Any: + return str(self._value) + + @classmethod + def json_schema(cls: Type["IPv4Address"]) -> Dict[Any, Any]: + return { + "type": "string", + } + + +class IPv6Address(BaseValueType): + _value: ipaddress.IPv6Address + + def __init__(self, source_value: Any, object_path: str = "/") -> None: + if isinstance(source_value, str): + try: + self._value: ipaddress.IPv6Address = ipaddress.IPv6Address(source_value) + except ValueError as e: + raise ValueError("failed to parse IPv6 address.") from e + else: + raise ValueError( + "Unexpected value for a IPv6 address." + f" Expected string, got '{source_value}' with type '{type(source_value)}'", + object_path, + ) + + def to_std(self) -> ipaddress.IPv6Address: + return self._value + + def __str__(self) -> str: + return str(self._value) + + def __int__(self) -> int: + raise ValueError("Can't convert IPv6 address to an integer") + + def __repr__(self) -> str: + return f'{type(self).__name__}("{self._value}")' + + def __eq__(self, o: object) -> bool: + """ + Two instances of IPv6Address are equal when they represent same IPv6 address as string. + """ + return isinstance(o, IPv6Address) and str(o._value) == str(self._value) + + def serialize(self) -> Any: + return str(self._value) + + @classmethod + def json_schema(cls: Type["IPv6Address"]) -> Dict[Any, Any]: + return { + "type": "string", + } + + +IPAddress = Union[IPv4Address, IPv6Address] + + +class IPAddressEM(BaseValueType): + """ + IP address with exclamation mark suffix, e.g. '127.0.0.1!'. + """ + + _value: str + _addr: Union[ipaddress.IPv4Address, ipaddress.IPv6Address] + + def __init__(self, source_value: Any, object_path: str = "/") -> None: + if isinstance(source_value, str): + if source_value.endswith("!"): + addr, suff = source_value.split("!", 1) + if suff != "": + raise ValueError(f"suffix '{suff}' found after '!'.") + else: + raise ValueError("string does not end with '!'.") + try: + self._addr: Union[ipaddress.IPv4Address, ipaddress.IPv6Address] = ipaddress.ip_address(addr) + self._value = source_value + except ValueError as e: + raise ValueError("failed to parse IP address.") from e + else: + raise ValueError( + "Unexpected value for a IPv6 address." + f" Expected string, got '{source_value}' with type '{type(source_value)}'", + object_path, + ) + + def to_std(self) -> str: + return self._value + + def __str__(self) -> str: + return self._value + + def __int__(self) -> int: + raise ValueError("Can't convert to an integer") + + def __repr__(self) -> str: + return f'{type(self).__name__}("{self._value}")' + + def __eq__(self, o: object) -> bool: + """ + Two instances of IPAddressEM are equal when they represent same string. 
+ """ + return isinstance(o, IPAddressEM) and o._value == self._value + + def serialize(self) -> Any: + return self._value + + @classmethod + def json_schema(cls: Type["IPAddressEM"]) -> Dict[Any, Any]: + return { + "type": "string", + } + + +class IPNetwork(BaseValueType): + _value: Union[ipaddress.IPv4Network, ipaddress.IPv6Network] + + def __init__(self, source_value: Any, object_path: str = "/") -> None: + if isinstance(source_value, str): + try: + self._value: Union[ipaddress.IPv4Network, ipaddress.IPv6Network] = ipaddress.ip_network(source_value) + except ValueError as e: + raise ValueError("failed to parse IP network.") from e + else: + raise ValueError( + "Unexpected value for a network subnet." + f" Expected string, got '{source_value}' with type '{type(source_value)}'" + ) + + def __int__(self) -> int: + raise ValueError("Can't convert network prefix to an integer") + + def __str__(self) -> str: + return self._value.with_prefixlen + + def __repr__(self) -> str: + return f'{type(self).__name__}("{self._value}")' + + def __eq__(self, o: object) -> bool: + return isinstance(o, IPNetwork) and o._value == self._value + + def to_std(self) -> Union[ipaddress.IPv4Network, ipaddress.IPv6Network]: + return self._value + + def serialize(self) -> Any: + return self._value.with_prefixlen + + @classmethod + def json_schema(cls: Type["IPNetwork"]) -> Dict[Any, Any]: + return { + "type": "string", + } + + +class IPv6Network(BaseValueType): + _value: ipaddress.IPv6Network + + def __init__(self, source_value: Any, object_path: str = "/") -> None: + if isinstance(source_value, str): + try: + self._value: ipaddress.IPv6Network = ipaddress.IPv6Network(source_value) + except ValueError as e: + raise ValueError("failed to parse IPv6 network.") from e + else: + raise ValueError( + "Unexpected value for a IPv6 network subnet." + f" Expected string, got '{source_value}' with type '{type(source_value)}'" + ) + + def to_std(self) -> ipaddress.IPv6Network: + return self._value + + def __str__(self) -> str: + return self._value.with_prefixlen + + def __int__(self) -> int: + raise ValueError("Can't convert network prefix to an integer") + + def __repr__(self) -> str: + return f'{type(self).__name__}("{self._value}")' + + def __eq__(self, o: object) -> bool: + return isinstance(o, IPv6Network) and o._value == self._value + + def serialize(self) -> Any: + return self._value.with_prefixlen + + @classmethod + def json_schema(cls: Type["IPv6Network"]) -> Dict[Any, Any]: + return { + "type": "string", + } + + +class IPv6Network96(IPv6Network): + def __init__(self, source_value: Any, object_path: str = "/") -> None: + super().__init__(source_value, object_path=object_path) + if self._value.prefixlen == 128: + raise ValueError( + "Expected IPv6 network address with /96 prefix length." + " Submitted address has been interpreted as /128." + " Maybe, you forgot to add /96 after the base address?" + ) + + if self._value.prefixlen != 96: + raise ValueError( + "expected IPv6 network address with /96 prefix length." 
f" Got prefix lenght of {self._value.prefixlen}" + ) diff --git a/python/knot_resolver/datamodel/view_schema.py b/python/knot_resolver/datamodel/view_schema.py new file mode 100644 index 00000000..b1d3adbe --- /dev/null +++ b/python/knot_resolver/datamodel/view_schema.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from typing_extensions import Literal + +from knot_resolver.datamodel.types import IDPattern, IPNetwork +from knot_resolver.utils.modeling import ConfigSchema + + +class ViewOptionsSchema(ConfigSchema): + """ + Configuration options for clients identified by the view. + + --- + minimize: Send minimum amount of information in recursive queries to enhance privacy. + dns64: Enable/disable DNS64. + """ + + minimize: bool = True + dns64: bool = True + + +class ViewSchema(ConfigSchema): + """ + Configuration parameters that allow you to create personalized policy rules and other. + + --- + subnets: Identifies the client based on his subnet. Rule with more precise subnet takes priority. + dst_subnet: Destination subnet, as an additional condition. + protocols: Transport protocol, as an additional condition. + tags: Tags to link with other policy rules. + answer: Direct approach how to handle request from clients identified by the view. + options: Configuration options for clients identified by the view. + """ + + subnets: List[IPNetwork] + dst_subnet: Optional[IPNetwork] = None # could be a list as well, iterated in template + protocols: Optional[List[Literal["udp53", "tcp53", "dot", "doh", "doq"]]] = None + + tags: Optional[List[IDPattern]] = None + answer: Optional[Literal["allow", "refused", "noanswer"]] = None + options: ViewOptionsSchema = ViewOptionsSchema() + + def _validate(self) -> None: + if bool(self.tags) == bool(self.answer): + raise ValueError("exactly one of 'tags' and 'answer' must be configured") diff --git a/python/knot_resolver/datamodel/webmgmt_schema.py b/python/knot_resolver/datamodel/webmgmt_schema.py new file mode 100644 index 00000000..ce18376b --- /dev/null +++ b/python/knot_resolver/datamodel/webmgmt_schema.py @@ -0,0 +1,27 @@ +from typing import Optional + +from knot_resolver.datamodel.types import WritableFilePath, InterfacePort, ReadableFile +from knot_resolver.utils.modeling import ConfigSchema + + +class WebmgmtSchema(ConfigSchema): + """ + Configuration of legacy web management endpoint. + + --- + unix_socket: Path to unix domain socket to listen to. + interface: IP address or interface name with port number to listen to. + tls: Enable/disable TLS. + cert_file: Path to certificate file. + key_file: Path to certificate key. 
+ """ + + unix_socket: Optional[WritableFilePath] = None + interface: Optional[InterfacePort] = None + tls: bool = False + cert_file: Optional[ReadableFile] = None + key_file: Optional[ReadableFile] = None + + def _validate(self) -> None: + if bool(self.unix_socket) == bool(self.interface): + raise ValueError("One of 'interface' or 'unix-socket' must be configured.") diff --git a/python/knot_resolver/manager/__init__.py b/python/knot_resolver/manager/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/python/knot_resolver/manager/__init__.py diff --git a/python/knot_resolver/manager/__main__.py b/python/knot_resolver/manager/__main__.py new file mode 100644 index 00000000..26aae1d6 --- /dev/null +++ b/python/knot_resolver/manager/__main__.py @@ -0,0 +1,5 @@ +from knot_resolver.manager.main import main + + +if __name__ == "__main__": + main() diff --git a/python/knot_resolver/manager/config_store.py b/python/knot_resolver/manager/config_store.py new file mode 100644 index 00000000..1c0174f2 --- /dev/null +++ b/python/knot_resolver/manager/config_store.py @@ -0,0 +1,101 @@ +import asyncio +from asyncio import Lock +from typing import Any, Awaitable, Callable, List, Tuple + +from knot_resolver.datamodel import KresConfig +from knot_resolver.manager.exceptions import KresManagerException +from knot_resolver.utils.functional import Result +from knot_resolver.utils.modeling.exceptions import DataParsingError +from knot_resolver.utils.modeling.types import NoneType + +VerifyCallback = Callable[[KresConfig, KresConfig], Awaitable[Result[None, str]]] +UpdateCallback = Callable[[KresConfig], Awaitable[None]] + + +class ConfigStore: + def __init__(self, initial_config: KresConfig): + self._config = initial_config + self._verifiers: List[VerifyCallback] = [] + self._callbacks: List[UpdateCallback] = [] + self._update_lock: Lock = Lock() + + async def update(self, config: KresConfig) -> None: + # invoke pre-change verifiers + results: Tuple[Result[None, str], ...] = tuple( + await asyncio.gather(*[ver(self._config, config) for ver in self._verifiers]) + ) + err_res = filter(lambda r: r.is_err(), results) + errs = list(map(lambda r: r.unwrap_err(), err_res)) + if len(errs) > 0: + raise KresManagerException("Configuration validation failed. 
The reasons are:\n - " + "\n - ".join(errs)) + + async with self._update_lock: + # update the stored config with the new version + self._config = config + + # invoke change callbacks + for call in self._callbacks: + await call(config) + + async def renew(self) -> None: + await self.update(self._config) + + async def register_verifier(self, verifier: VerifyCallback) -> None: + self._verifiers.append(verifier) + res = await verifier(self.get(), self.get()) + if res.is_err(): + raise DataParsingError(f"Initial config verification failed with error: {res.unwrap_err()}") + + async def register_on_change_callback(self, callback: UpdateCallback) -> None: + """ + Registers new callback and immediatelly calls it with current config + """ + + self._callbacks.append(callback) + await callback(self.get()) + + def get(self) -> KresConfig: + return self._config + + +def only_on_real_changes_update(selector: Callable[[KresConfig], Any]) -> Callable[[UpdateCallback], UpdateCallback]: + def decorator(orig_func: UpdateCallback) -> UpdateCallback: + original_value_set: Any = False + original_value: Any = None + + async def new_func_update(config: KresConfig) -> None: + nonlocal original_value_set + nonlocal original_value + if not original_value_set: + original_value_set = True + original_value = selector(config) + await orig_func(config) + elif original_value != selector(config): + original_value = selector(config) + await orig_func(config) + + return new_func_update + + return decorator + + +def only_on_real_changes_verifier(selector: Callable[[KresConfig], Any]) -> Callable[[VerifyCallback], VerifyCallback]: + def decorator(orig_func: VerifyCallback) -> VerifyCallback: + original_value_set: Any = False + original_value: Any = None + + async def new_func_verifier(old: KresConfig, new: KresConfig) -> Result[NoneType, str]: + nonlocal original_value_set + nonlocal original_value + if not original_value_set: + original_value_set = True + original_value = selector(new) + await orig_func(old, new) + elif original_value != selector(new): + original_value = selector(new) + await orig_func(old, new) + return Result.ok(None) + + return new_func_verifier + + return decorator diff --git a/python/knot_resolver/manager/constants.py b/python/knot_resolver/manager/constants.py new file mode 100644 index 00000000..832d3fa9 --- /dev/null +++ b/python/knot_resolver/manager/constants.py @@ -0,0 +1,108 @@ +import importlib.util +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Optional + +# Install config is semi-optional - only needed to actually run Manager, but not +# for its unit tests. 
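# A minimal usage sketch (not part of the diff) for the ConfigStore callback
# machinery defined in config_store.py above, assuming a store already built
# from a valid KresConfig. The config.logging selector is only an illustrative
# choice; any config subtree works the same way.
from knot_resolver.datamodel import KresConfig
from knot_resolver.manager.config_store import ConfigStore, only_on_real_changes_update


@only_on_real_changes_update(lambda config: config.logging)
async def reconfigure_logging(config: KresConfig) -> None:
    # invoked once when registered, then only when the selected subtree differs
    print("logging configuration changed, reapplying handlers")


async def demo(store: ConfigStore) -> None:
    # register_on_change_callback() immediately calls the callback with the current config
    await store.register_on_change_callback(reconfigure_logging)
    # renew() re-applies the stored config; the decorated callback is skipped here
    # because selector(config) is unchanged
    await store.renew()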
+if importlib.util.find_spec("knot_resolver"): + import knot_resolver # type: ignore[import-not-found] +else: + knot_resolver = None + +if TYPE_CHECKING: + from knot_resolver.manager.config_store import ConfigStore + from knot_resolver.datamodel.config_schema import KresConfig + from knot_resolver.controller.interface import KresID + +STARTUP_LOG_LEVEL = logging.DEBUG +DEFAULT_MANAGER_CONFIG_FILE = Path("/etc/knot-resolver/config.yaml") +CONFIG_FILE_ENV_VAR = "KRES_MANAGER_CONFIG" +API_SOCK_ENV_VAR = "KRES_MANAGER_API_SOCK" +MANAGER_FIX_ATTEMPT_MAX_COUNTER = 2 +FIX_COUNTER_DECREASE_INTERVAL_SEC = 30 * 60 +PID_FILE_NAME = "manager.pid" +MAX_WORKERS = 256 + + +def kresd_executable() -> Path: + assert knot_resolver is not None + return knot_resolver.sbin_dir / "kresd" + + +def kres_gc_executable() -> Path: + assert knot_resolver is not None + return knot_resolver.sbin_dir / "kres-cache-gc" + + +def kresd_user(): + return None if knot_resolver is None else knot_resolver.user + + +def kresd_group(): + return None if knot_resolver is None else knot_resolver.group + + +def kresd_cache_dir(config: "KresConfig") -> Path: + return config.cache.storage.to_path() + + +def policy_loader_config_file(_config: "KresConfig") -> Path: + return Path("policy-loader.conf") + + +def kresd_config_file(_config: "KresConfig", kres_id: "KresID") -> Path: + return Path(f"kresd{int(kres_id)}.conf") + + +def kresd_config_file_supervisord_pattern(_config: "KresConfig") -> Path: + return Path("kresd%(process_num)d.conf") + + +def supervisord_config_file(_config: "KresConfig") -> Path: + return Path("supervisord.conf") + + +def supervisord_config_file_tmp(_config: "KresConfig") -> Path: + return Path("supervisord.conf.tmp") + + +def supervisord_pid_file(_config: "KresConfig") -> Path: + return Path("supervisord.pid") + + +def supervisord_sock_file(_config: "KresConfig") -> Path: + return Path("supervisord.sock") + + +def supervisord_subprocess_log_dir(_config: "KresConfig") -> Path: + return Path("logs") + + +WATCHDOG_INTERVAL: float = 5 +""" +Used in KresdManager. It's a number of seconds in between system health checks. +""" + + +class _UserConstants: + """ + Class for accessing constants, which are technically not constants as they are user configurable. + """ + + def __init__(self, config_store: "ConfigStore", working_directory_on_startup: str) -> None: + self._config_store = config_store + self.working_directory_on_startup = working_directory_on_startup + + +_user_constants: Optional[_UserConstants] = None + + +async def init_user_constants(config_store: "ConfigStore", working_directory_on_startup: str) -> None: + global _user_constants + _user_constants = _UserConstants(config_store, working_directory_on_startup) + + +def user_constants() -> _UserConstants: + assert _user_constants is not None + return _user_constants diff --git a/python/knot_resolver/manager/exceptions.py b/python/knot_resolver/manager/exceptions.py new file mode 100644 index 00000000..5b05d98e --- /dev/null +++ b/python/knot_resolver/manager/exceptions.py @@ -0,0 +1,28 @@ +from typing import List + + +class CancelStartupExecInsteadException(Exception): + """ + Exception used for terminating system startup and instead + causing an exec of something else. Could be used by subprocess + controllers such as supervisord to allow them to run as top-level + process in a process tree. 
+ """ + + def __init__(self, exec_args: List[str], *args: object) -> None: + self.exec_args = exec_args + super().__init__(*args) + + +class KresManagerException(Exception): + """ + Base class for all custom exceptions we use in our code + """ + + +class SubprocessControllerException(KresManagerException): + pass + + +class SubprocessControllerTimeoutException(KresManagerException): + pass diff --git a/python/knot_resolver/manager/kres_manager.py b/python/knot_resolver/manager/kres_manager.py new file mode 100644 index 00000000..3dbc1079 --- /dev/null +++ b/python/knot_resolver/manager/kres_manager.py @@ -0,0 +1,429 @@ +import asyncio +import logging +import sys +import time +from secrets import token_hex +from subprocess import SubprocessError +from typing import Any, Callable, List, Optional + +from knot_resolver.compat.asyncio import create_task +from knot_resolver.manager.config_store import ( + ConfigStore, + only_on_real_changes_update, + only_on_real_changes_verifier, +) +from knot_resolver.manager.constants import ( + FIX_COUNTER_DECREASE_INTERVAL_SEC, + MANAGER_FIX_ATTEMPT_MAX_COUNTER, + WATCHDOG_INTERVAL, +) +from knot_resolver.manager.exceptions import SubprocessControllerException +from knot_resolver.controller.interface import ( + Subprocess, + SubprocessController, + SubprocessStatus, + SubprocessType, +) +from knot_resolver.controller.registered_workers import ( + command_registered_workers, + get_registered_workers_kresids, +) +from knot_resolver.utils.functional import Result +from knot_resolver.utils.modeling.types import NoneType + +from knot_resolver import KresConfig + +logger = logging.getLogger(__name__) + + +class _FixCounter: + def __init__(self) -> None: + self._counter = 0 + self._timestamp = time.time() + + def increase(self) -> None: + self._counter += 1 + self._timestamp = time.time() + + def try_decrease(self) -> None: + if time.time() - self._timestamp > FIX_COUNTER_DECREASE_INTERVAL_SEC: + if self._counter > 0: + logger.info( + f"Enough time has passed since last detected instability, decreasing fix attempt counter to {self._counter}" + ) + self._counter -= 1 + self._timestamp = time.time() + + def __str__(self) -> str: + return str(self._counter) + + def is_too_high(self) -> bool: + return self._counter >= MANAGER_FIX_ATTEMPT_MAX_COUNTER + + +async def _deny_max_worker_changes(config_old: KresConfig, config_new: KresConfig) -> Result[None, str]: + if config_old.max_workers != config_new.max_workers: + return Result.err( + "Changing 'max-workers', the maximum number of workers allowed to run, is not allowed at runtime." + ) + + return Result.ok(None) + + +class KresManager: # pylint: disable=too-many-instance-attributes + """ + Core of the whole operation. Orchestrates individual instances under some + service manager like systemd. + + Instantiate with `KresManager.create()`, not with the usual constructor! + """ + + def __init__(self, shutdown_trigger: Callable[[int], None], _i_know_what_i_am_doing: bool = False): + if not _i_know_what_i_am_doing: + logger.error( + "Trying to create an instance of KresManager using normal constructor. 
Please use " + "`KresManager.get_instance()` instead" + ) + assert False + + self._workers: List[Subprocess] = [] + self._gc: Optional[Subprocess] = None + self._policy_loader: Optional[Subprocess] = None + self._manager_lock = asyncio.Lock() + self._workers_reset_needed: bool = False + self._controller: SubprocessController + self._watchdog_task: Optional["asyncio.Task[None]"] = None + self._fix_counter: _FixCounter = _FixCounter() + self._config_store: ConfigStore + self._shutdown_trigger: Callable[[int], None] = shutdown_trigger + + @staticmethod + async def create( + subprocess_controller: SubprocessController, + config_store: ConfigStore, + shutdown_trigger: Callable[[int], None], + ) -> "KresManager": + """ + Creates new instance of KresManager. + """ + + inst = KresManager(shutdown_trigger, _i_know_what_i_am_doing=True) + await inst._async_init(subprocess_controller, config_store) # pylint: disable=protected-access + return inst + + async def _async_init(self, subprocess_controller: SubprocessController, config_store: ConfigStore) -> None: + self._controller = subprocess_controller + self._config_store = config_store + + # initialize subprocess controller + logger.debug("Starting controller") + await self._controller.initialize_controller(config_store.get()) + self._watchdog_task = create_task(self._watchdog()) + logger.debug("Looking for already running workers") + await self._collect_already_running_workers() + + # register and immediately call a verifier that loads policy rules into the rules database + await config_store.register_verifier(self.load_policy_rules) + + # configuration nodes that are relevant to kresd workers and the cache garbage collector + def config_nodes(config: KresConfig) -> List[Any]: + return [ + config.nsid, + config.hostname, + config.workers, + config.max_workers, + config.webmgmt, + config.options, + config.network, + config.forward, + config.cache, + config.dnssec, + config.dns64, + config.logging, + config.monitoring, + config.lua, + ] + + # register and immediately call a verifier that validates config with 'canary' kresd process + await config_store.register_verifier(only_on_real_changes_verifier(config_nodes)(self.validate_config)) + + # register and immediately call a callback to apply config to all 'kresd' workers and 'cache-gc' + await config_store.register_on_change_callback(only_on_real_changes_update(config_nodes)(self.apply_config)) + + # register callback to reset policy rules for each 'kresd' worker + await config_store.register_on_change_callback(self.reset_workers_policy_rules) + + # register and immediately call a callback to set new TLS session ticket secret for 'kresd' workers + await config_store.register_on_change_callback( + only_on_real_changes_update(config_nodes)(self.set_new_tls_sticket_secret) + ) + + # register controller config change listeners + await config_store.register_verifier(_deny_max_worker_changes) + + async def _spawn_new_worker(self, config: KresConfig) -> None: + subprocess = await self._controller.create_subprocess(config, SubprocessType.KRESD) + await subprocess.start() + self._workers.append(subprocess) + + async def _stop_a_worker(self) -> None: + if len(self._workers) == 0: + raise IndexError("Can't stop a kresd when there are no running") + + subprocess = self._workers.pop() + await subprocess.stop() + + async def _collect_already_running_workers(self) -> None: + for subp in await self._controller.get_all_running_instances(): + if subp.type == SubprocessType.KRESD: + self._workers.append(subp) + elif 
subp.type == SubprocessType.GC: + assert self._gc is None + self._gc = subp + elif subp.type == SubprocessType.POLICY_LOADER: + assert self._policy_loader is None + self._policy_loader = subp + else: + raise RuntimeError("unexpected subprocess type") + + async def _rolling_restart(self, new_config: KresConfig) -> None: + for kresd in self._workers: + await kresd.apply_new_config(new_config) + + async def _ensure_number_of_children(self, config: KresConfig, n: int) -> None: + # kill children that are not needed + while len(self._workers) > n: + await self._stop_a_worker() + + # spawn new children if needed + while len(self._workers) < n: + await self._spawn_new_worker(config) + + async def _run_policy_loader(self, config: KresConfig) -> None: + if self._policy_loader: + await self._policy_loader.start(config) + else: + subprocess = await self._controller.create_subprocess(config, SubprocessType.POLICY_LOADER) + await subprocess.start() + self._policy_loader = subprocess + + def _is_policy_loader_exited(self) -> bool: + if self._policy_loader: + return self._policy_loader.status() is SubprocessStatus.EXITED + return False + + def _is_gc_running(self) -> bool: + return self._gc is not None + + async def _start_gc(self, config: KresConfig) -> None: + subprocess = await self._controller.create_subprocess(config, SubprocessType.GC) + await subprocess.start() + self._gc = subprocess + + async def _stop_gc(self) -> None: + assert self._gc is not None + await self._gc.stop() + self._gc = None + + async def validate_config(self, _old: KresConfig, new: KresConfig) -> Result[NoneType, str]: + async with self._manager_lock: + logger.debug("Testing the new config with a canary process") + try: + # technically, this has side effects of leaving a new process runnning + # but it's practically not a problem, because + # if it keeps running, the config is valid and others will soon join as well + # if it crashes and the startup fails, then well, it's not running anymore... :) + await self._spawn_new_worker(new) + except (SubprocessError, SubprocessControllerException): + logger.error("Kresd with the new config failed to start, rejecting config") + return Result.err("canary kresd process failed to start. 
Config might be invalid.") + + logger.debug("Canary process test passed.") + return Result.ok(None) + + async def _reload_system_state(self) -> None: + async with self._manager_lock: + self._workers = [] + self._policy_loader = None + self._gc = None + await self._collect_already_running_workers() + + async def reset_workers_policy_rules(self, _config: KresConfig) -> None: + + # command all running 'kresd' workers to reset their old policy rules, + # unless the workers have already been started with a new config so reset is not needed + if self._workers_reset_needed and get_registered_workers_kresids(): + logger.debug("Resetting policy rules for all running 'kresd' workers") + cmd_results = await command_registered_workers("require('ffi').C.kr_rules_reset()") + for worker, res in cmd_results.items(): + if res != 0: + logger.error("Failed to reset policy rules in %s: %s", worker, res) + else: + logger.debug( + "Skipped resetting policy rules for all running 'kresd' workers:" + " the workers are already running with new configuration" + ) + + async def set_new_tls_sticket_secret(self, config: KresConfig) -> None: + + if config.network.tls.sticket_secret or config.network.tls.sticket_secret_file: + logger.debug("User-configured TLS resumption secret found - skipping auto-generation.") + return + + logger.debug("Creating TLS session ticket secret") + secret = token_hex(32) + logger.debug("Setting TLS session ticket secret for all running 'kresd' workers") + cmd_results = await command_registered_workers(f"net.tls_sticket_secret('{secret}')") + for worker, res in cmd_results.items(): + if res not in (0, True): + logger.error("Failed to set TLS session ticket secret in %s: %s", worker, res) + + async def apply_config(self, config: KresConfig, _noretry: bool = False) -> None: + try: + async with self._manager_lock: + logger.debug("Applying config to all workers") + await self._rolling_restart(config) + await self._ensure_number_of_children(config, int(config.workers)) + + if self._is_gc_running() != bool(config.cache.garbage_collector): + if config.cache.garbage_collector: + logger.debug("Starting cache GC") + await self._start_gc(config) + else: + logger.debug("Stopping cache GC") + await self._stop_gc() + except SubprocessControllerException as e: + if _noretry: + raise + elif self._fix_counter.is_too_high(): + logger.error(f"Failed to apply config: {e}") + logger.error("There have already been problems recently, refusing to try to fix it.") + await self.forced_shutdown() # possible improvement - the person who requested this change won't get a response this way + else: + logger.error(f"Failed to apply config: {e}") + logger.warning("Reloading system state and trying again.") + self._fix_counter.increase() + await self._reload_system_state() + await self.apply_config(config, _noretry=True) + + self._workers_reset_needed = False + + async def load_policy_rules(self, _old: KresConfig, new: KresConfig) -> Result[NoneType, str]: + try: + async with self._manager_lock: + logger.debug("Running kresd 'policy-loader'") + await self._run_policy_loader(new) + + # wait for 'policy-loader' to finish + logger.debug("Waiting for 'policy-loader' to finish loading policy rules") + while not self._is_policy_loader_exited(): + await asyncio.sleep(1) + + except (SubprocessError, SubprocessControllerException) as e: + logger.error(f"Failed to load policy rules: {e}") + return Result.err("kresd 'policy-loader' process failed to start. 
Config might be invalid.") + + self._workers_reset_needed = True + logger.debug("Loading policy rules has been successfully completed") + return Result.ok(None) + + async def stop(self): + if self._watchdog_task is not None: + self._watchdog_task.cancel() # cancel it + try: + await self._watchdog_task # and let it really finish + except asyncio.CancelledError: + pass + + async with self._manager_lock: + # we could stop all the children one by one right now + # we won't do that and we leave that up to the subprocess controller to do that while it is shutting down + await self._controller.shutdown_controller() + # now, when everything is stopped, let's clean up all the remains + await asyncio.gather(*[w.cleanup() for w in self._workers]) + + async def forced_shutdown(self) -> None: + logger.warning("Collecting all remaining workers...") + await self._reload_system_state() + logger.warning("Terminating...") + self._shutdown_trigger(1) + + async def _instability_handler(self) -> None: + if self._fix_counter.is_too_high(): + logger.error( + "Already attempted too many times to fix system state. Refusing to try again and shutting down." + ) + await self.forced_shutdown() + return + + try: + logger.warning("Instability detected. Dropping known list of workers and reloading it from the system.") + self._fix_counter.increase() + await self._reload_system_state() + logger.warning("Workers reloaded. Applying old config....") + await self._config_store.renew() + logger.warning(f"System stability hopefully renewed. Fix attempt counter is currently {self._fix_counter}") + except BaseException: + logger.error("Failed attempting to fix an error. Forcefully shutting down.", exc_info=True) + await self.forced_shutdown() + + async def _watchdog(self) -> None: # pylint: disable=too-many-branches + while True: + await asyncio.sleep(WATCHDOG_INTERVAL) + + self._fix_counter.try_decrease() + + try: + # gather current state + async with self._manager_lock: + detected_subprocesses = await self._controller.get_subprocess_status() + expected_ids = [x.id for x in self._workers] + if self._gc: + expected_ids.append(self._gc.id) + + invoke_callback = False + + if self._policy_loader: + expected_ids.append(self._policy_loader.id) + + for eid in expected_ids: + if eid not in detected_subprocesses: + logger.error("Subprocess with id '%s' was not found in the system!", eid) + invoke_callback = True + continue + + if detected_subprocesses[eid] is SubprocessStatus.FATAL: + if self._policy_loader and self._policy_loader.id == eid: + logger.info( + "Subprocess '%s' is skipped by WatchDog because its status is monitored in a different way.", + eid, + ) + continue + logger.error("Subprocess '%s' is in FATAL state!", eid) + invoke_callback = True + continue + + if detected_subprocesses[eid] is SubprocessStatus.UNKNOWN: + logger.warning("Subprocess '%s' is in UNKNOWN state!", eid) + + non_registered_ids = detected_subprocesses.keys() - set(expected_ids) + if len(non_registered_ids) != 0: + logger.error( + "Found additional process in the system, which shouldn't be there - %s", + non_registered_ids, + ) + invoke_callback = True + + except asyncio.CancelledError: + raise + except BaseException: + invoke_callback = True + logger.error("Knot Resolver watchdog failed with an unexpected exception.", exc_info=True) + + if invoke_callback: + try: + await self._instability_handler() + except Exception: + logger.error("Watchdog failed while invoking instability callback", exc_info=True) + logger.error("Violently terminating!") + sys.exit(1) 
diff --git a/python/knot_resolver/manager/log.py b/python/knot_resolver/manager/log.py new file mode 100644 index 00000000..a22898a5 --- /dev/null +++ b/python/knot_resolver/manager/log.py @@ -0,0 +1,105 @@ +import logging +import logging.handlers +import os +import sys +from typing import Optional + +from knot_resolver.manager.config_store import ConfigStore, only_on_real_changes_update +from knot_resolver.manager.constants import STARTUP_LOG_LEVEL +from knot_resolver.datamodel.config_schema import KresConfig +from knot_resolver.datamodel.logging_schema import LogTargetEnum + +logger = logging.getLogger(__name__) + + +def get_log_format(config: KresConfig) -> str: + """ + Based on an environment variable $KRES_SUPRESS_LOG_PREFIX, returns the appropriate format string for logger. + """ + + if os.environ.get("KRES_SUPRESS_LOG_PREFIX") == "true": + # In this case, we are running under supervisord and it's adding prefixes to our output + return "[%(levelname)s] %(name)s: %(message)s" + else: + # In this case, we are running standalone during inicialization and we need to add a prefix to each line + # by ourselves to make it consistent + assert config.logging.target != "syslog" + stream = "" + if config.logging.target == "stderr": + stream = " (stderr)" + + pid = os.getpid() + return f"%(asctime)s manager[{pid}]{stream}: [%(levelname)s] %(name)s: %(message)s" + + +async def _set_log_level(config: KresConfig) -> None: + levels_map = { + "crit": "CRITICAL", + "err": "ERROR", + "warning": "WARNING", + "notice": "WARNING", + "info": "INFO", + "debug": "DEBUG", + } + + # when logging group is set to make us log with DEBUG + if config.logging.groups and "manager" in config.logging.groups: + target = "DEBUG" + # otherwise, follow the standard log level + else: + target = levels_map[config.logging.level] + + # expect exactly one existing log handler on the root + logger.warning(f"Changing logging level to '{target}'") + logging.getLogger().setLevel(target) + + +async def _set_logging_handler(config: KresConfig) -> None: + target: Optional[LogTargetEnum] = config.logging.target + + if target is None: + target = "stdout" + + handler: logging.Handler + if target == "syslog": + handler = logging.handlers.SysLogHandler(address="/dev/log") + handler.setFormatter(logging.Formatter("%(name)s: %(message)s")) + elif target == "stdout": + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter(get_log_format(config))) + elif target == "stderr": + handler = logging.StreamHandler(sys.stderr) + handler.setFormatter(logging.Formatter(get_log_format(config))) + else: + raise RuntimeError(f"Unexpected value '{target}' for log target in the config") + + root = logging.getLogger() + + # if we had a MemoryHandler before, we should give it the new handler where we can flush it + if isinstance(root.handlers[0], logging.handlers.MemoryHandler): + root.handlers[0].setTarget(handler) + + # stop the old handler + root.handlers[0].flush() + root.handlers[0].close() + root.removeHandler(root.handlers[0]) + + # configure the new handler + root.addHandler(handler) + + +@only_on_real_changes_update(lambda config: config.logging) +async def _configure_logger(config: KresConfig) -> None: + await _set_logging_handler(config) + await _set_log_level(config) + + +async def logger_init(config_store: ConfigStore) -> None: + await config_store.register_on_change_callback(_configure_logger) + + +def logger_startup() -> None: + logging.getLogger().setLevel(STARTUP_LOG_LEVEL) + err_handler = 
logging.StreamHandler(sys.stderr) + err_handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT)) + logging.getLogger().addHandler(logging.handlers.MemoryHandler(10_000, logging.ERROR, err_handler)) diff --git a/python/knot_resolver/manager/main.py b/python/knot_resolver/manager/main.py new file mode 100644 index 00000000..5facc470 --- /dev/null +++ b/python/knot_resolver/manager/main.py @@ -0,0 +1,49 @@ +""" +Effectively the same as normal __main__.py. However, we moved it's content over to this +file to allow us to exclude the __main__.py file from black's autoformatting +""" + +import argparse +import os +import sys +from pathlib import Path +from typing import NoReturn + +from knot_resolver import compat +from knot_resolver.manager.constants import CONFIG_FILE_ENV_VAR, DEFAULT_MANAGER_CONFIG_FILE +from knot_resolver.manager.log import logger_startup +from knot_resolver.manager.server import start_server + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Knot Resolver - caching DNS resolver") + parser.add_argument( + "-c", + "--config", + help="Config file to load. Overrides default config location at '" + str(DEFAULT_MANAGER_CONFIG_FILE) + "'", + type=str, + nargs=1, + required=False, + default=None, + ) + return parser.parse_args() + + +def main() -> NoReturn: + # initial logging is to memory until we read the config + logger_startup() + + # parse arguments + args = parse_args() + + # where to look for config + config_env = os.getenv(CONFIG_FILE_ENV_VAR) + if args.config is not None: + config_path = Path(args.config[0]) + elif config_env is not None: + config_path = Path(config_env) + else: + config_path = DEFAULT_MANAGER_CONFIG_FILE + + exit_code = compat.asyncio.run(start_server(config=config_path)) + sys.exit(exit_code) diff --git a/python/knot_resolver/manager/server.py b/python/knot_resolver/manager/server.py new file mode 100644 index 00000000..cf31a3fc --- /dev/null +++ b/python/knot_resolver/manager/server.py @@ -0,0 +1,637 @@ +import asyncio +import errno +import json +import logging +import os +import signal +import sys +from functools import partial +from http import HTTPStatus +from pathlib import Path +from time import time +from typing import Any, Dict, List, Optional, Set, Union, cast + +from aiohttp import web +from aiohttp.web import middleware +from aiohttp.web_app import Application +from aiohttp.web_response import json_response +from aiohttp.web_runner import AppRunner, TCPSite, UnixSite +from typing_extensions import Literal + +import knot_resolver.utils.custom_atexit as atexit +from knot_resolver.manager import log, statistics +from knot_resolver.compat import asyncio as asyncio_compat +from knot_resolver.manager.config_store import ConfigStore +from knot_resolver.manager.constants import DEFAULT_MANAGER_CONFIG_FILE, PID_FILE_NAME, init_user_constants +from knot_resolver.datamodel.cache_schema import CacheClearRPCSchema +from knot_resolver.datamodel.config_schema import KresConfig, get_rundir_without_validation +from knot_resolver.datamodel.globals import Context, set_global_validation_context +from knot_resolver.datamodel.management_schema import ManagementSchema +from knot_resolver.manager.exceptions import CancelStartupExecInsteadException, KresManagerException +from knot_resolver.controller import get_best_controller_implementation +from knot_resolver.controller.registered_workers import command_single_registered_worker +from knot_resolver.utils import ignore_exceptions_optional +from 
knot_resolver.utils.async_utils import readfile +from knot_resolver.utils.etag import structural_etag +from knot_resolver.utils.functional import Result +from knot_resolver.utils.modeling.exceptions import ( + AggregateDataValidationError, + DataParsingError, + DataValidationError, +) +from knot_resolver.utils.modeling.parsing import DataFormat, try_to_parse +from knot_resolver.utils.modeling.query import query +from knot_resolver.utils.modeling.types import NoneType +from knot_resolver.utils.systemd_notify import systemd_notify + +from .kres_manager import KresManager + +logger = logging.getLogger(__name__) + + +@middleware +async def error_handler(request: web.Request, handler: Any) -> web.Response: + """ + Generic error handler for route handlers. + + If an exception is thrown during request processing, this middleware catches it + and responds accordingly. + """ + + try: + return await handler(request) + except DataValidationError as e: + return web.Response(text=f"validation of configuration failed:\n{e}", status=HTTPStatus.BAD_REQUEST) + except DataParsingError as e: + return web.Response(text=f"request processing error:\n{e}", status=HTTPStatus.BAD_REQUEST) + except KresManagerException as e: + return web.Response(text=f"request processing failed:\n{e}", status=HTTPStatus.INTERNAL_SERVER_ERROR) + + +def from_mime_type(mime_type: str) -> DataFormat: + formats = { + "application/json": DataFormat.JSON, + "application/octet-stream": DataFormat.JSON, # default in aiohttp + } + if mime_type not in formats: + raise DataParsingError(f"unsupported MIME type '{mime_type}', expected: {str(formats)[1:-1]}") + return formats[mime_type] + + +def parse_from_mime_type(data: str, mime_type: str) -> Any: + return from_mime_type(mime_type).parse_to_dict(data) + + +class Server: + # pylint: disable=too-many-instance-attributes + # This is top-level class containing pretty much everything. Instead of global + # variables, we use instance attributes. That's why there are so many and it's + # ok. + def __init__(self, store: ConfigStore, config_path: Optional[Path]): + # config store & server dynamic reconfiguration + self.config_store = store + + # HTTP server + self.app = Application(middlewares=[error_handler]) + self.runner = AppRunner(self.app) + self.listen: Optional[ManagementSchema] = None + self.site: Union[NoneType, TCPSite, UnixSite] = None + self.listen_lock = asyncio.Lock() + self._config_path: Optional[Path] = config_path + self._exit_code: int = 0 + self._shutdown_event = asyncio.Event() + + async def _reconfigure(self, config: KresConfig) -> None: + await self._reconfigure_listen_address(config) + + async def _deny_management_changes(self, config_old: KresConfig, config_new: KresConfig) -> Result[None, str]: + if config_old.management != config_new.management: + return Result.err( + "/server/management: Changing management API address/unix-socket dynamically is not allowed as it's really dangerous." + " If you really need this feature, please contact the developers and explain why. Technically," + " there are no problems in supporting it. We are only blocking the dynamic changes because" + " we think the consequences of leaving this footgun unprotected are worse than its usefulness." 
+ ) + return Result.ok(None) + + async def _reload_config(self) -> None: + if self._config_path is None: + logger.warning("The manager was started with inlined configuration - can't reload") + else: + try: + data = await readfile(self._config_path) + config = KresConfig(try_to_parse(data)) + await self.config_store.update(config) + logger.info("Configuration file successfully reloaded") + except FileNotFoundError: + logger.error( + f"Configuration file was not found at '{self._config_path}'." + " Something must have happened to it while we were running." + ) + logger.error("Configuration have NOT been changed.") + except (DataParsingError, DataValidationError) as e: + logger.error(f"Failed to parse the updated configuration file: {e}") + logger.error("Configuration have NOT been changed.") + except KresManagerException as e: + logger.error(f"Reloading of the configuration file failed: {e}") + logger.error("Configuration have NOT been changed.") + + async def sigint_handler(self) -> None: + logger.info("Received SIGINT, triggering graceful shutdown") + self.trigger_shutdown(0) + + async def sigterm_handler(self) -> None: + logger.info("Received SIGTERM, triggering graceful shutdown") + self.trigger_shutdown(0) + + async def sighup_handler(self) -> None: + logger.info("Received SIGHUP, reloading configuration file") + systemd_notify(RELOADING="1") + await self._reload_config() + systemd_notify(READY="1") + + @staticmethod + def all_handled_signals() -> Set[signal.Signals]: + return {signal.SIGHUP, signal.SIGINT, signal.SIGTERM} + + def bind_signal_handlers(self): + asyncio_compat.add_async_signal_handler(signal.SIGTERM, self.sigterm_handler) + asyncio_compat.add_async_signal_handler(signal.SIGINT, self.sigint_handler) + asyncio_compat.add_async_signal_handler(signal.SIGHUP, self.sighup_handler) + + def unbind_signal_handlers(self): + asyncio_compat.remove_signal_handler(signal.SIGTERM) + asyncio_compat.remove_signal_handler(signal.SIGINT) + asyncio_compat.remove_signal_handler(signal.SIGHUP) + + async def start(self) -> None: + self._setup_routes() + await self.runner.setup() + await self.config_store.register_verifier(self._deny_management_changes) + await self.config_store.register_on_change_callback(self._reconfigure) + + async def wait_for_shutdown(self) -> None: + await self._shutdown_event.wait() + + def trigger_shutdown(self, exit_code: int) -> None: + self._shutdown_event.set() + self._exit_code = exit_code + + async def _handler_index(self, _request: web.Request) -> web.Response: + """ + Dummy index handler to indicate that the server is indeed running... + """ + return json_response( + { + "msg": "Knot Resolver Manager is running! 
The configuration endpoint is at /config", + "status": "RUNNING", + } + ) + + async def _handler_config_query(self, request: web.Request) -> web.Response: + """ + Route handler for changing resolver configuration + """ + # There are a lot of local variables in here, but they are usually immutable (almost SSA form :) ) + # pylint: disable=too-many-locals + + # parse the incoming data + if request.method == "GET": + update_with: Optional[Dict[str, Any]] = None + else: + update_with = parse_from_mime_type(await request.text(), request.content_type) + document_path = request.match_info["path"] + getheaders = ignore_exceptions_optional(List[str], None, KeyError)(request.headers.getall) + etags = getheaders("if-match") + not_etags = getheaders("if-none-match") + current_config: Dict[str, Any] = self.config_store.get().get_unparsed_data() + + # stop processing if etags + def strip_quotes(s: str) -> str: + return s.strip('"') + + # WARNING: this check is prone to race conditions. When changing, make sure that the current config + # is really the latest current config (i.e. no await in between obtaining the config and the checks) + status = HTTPStatus.NOT_MODIFIED if request.method in ("GET", "HEAD") else HTTPStatus.PRECONDITION_FAILED + if etags is not None and structural_etag(current_config) not in map(strip_quotes, etags): + return web.Response(status=status) + if not_etags is not None and structural_etag(current_config) in map(strip_quotes, not_etags): + return web.Response(status=status) + + # run query + op = cast(Literal["get", "delete", "patch", "put"], request.method.lower()) + new_config, to_return = query(current_config, op, document_path, update_with) + + # update the config + if request.method != "GET": + # validate + config_validated = KresConfig(new_config) + # apply + await self.config_store.update(config_validated) + + # serialize the response (the `to_return` object is a Dict/list/scalar, we want to return json) + resp_text: Optional[str] = json.dumps(to_return) if to_return is not None else None + + # create the response and return it + res = web.Response(status=HTTPStatus.OK, text=resp_text, content_type="application/json") + res.headers.add("ETag", f'"{structural_etag(new_config)}"') + return res + + async def _handler_metrics(self, request: web.Request) -> web.Response: + raise web.HTTPMovedPermanently("/metrics/json") + + async def _handler_metrics_json(self, _request: web.Request) -> web.Response: + return web.Response( + body=await statistics.report_stats(), + content_type="application/json", + charset="utf8", + ) + + async def _handler_metrics_prometheus(self, _request: web.Request) -> web.Response: + + metrics_report = await statistics.report_stats(prometheus_format=True) + if not metrics_report: + raise web.HTTPNotFound() + + return web.Response( + body=metrics_report, + content_type="text/plain", + charset="utf8", + ) + + async def _handler_cache_clear(self, request: web.Request) -> web.Response: + data = parse_from_mime_type(await request.text(), request.content_type) + + try: + config = CacheClearRPCSchema(data) + except (AggregateDataValidationError, DataValidationError) as e: + return web.Response( + body=e, + status=HTTPStatus.BAD_REQUEST, + content_type="text/plain", + charset="utf8", + ) + + _, result = await command_single_registered_worker(config.render_lua()) + return web.Response( + body=json.dumps(result), + content_type="application/json", + charset="utf8", + ) + + async def _handler_schema(self, _request: web.Request) -> web.Response: + return 
web.json_response( + KresConfig.json_schema(), headers={"Access-Control-Allow-Origin": "*"}, dumps=partial(json.dumps, indent=4) + ) + + async def _handle_view_schema(self, _request: web.Request) -> web.Response: + """ + Provides a UI for visuallising and understanding JSON schema. + + The feature in the Knot Resolver Manager to render schemas is unwanted, as it's completely + out of scope. However, it can be convinient. We therefore rely on a public web-based viewers + and provide just a redirect. If this feature ever breaks due to disapearance of the public + service, we can fix it. But we are not guaranteeing, that this will always work. + """ + + return web.Response( + text=""" + <html> + <head><title>Redirect to schema viewer</title></head> + <body> + <script> + // we are using JS in order to use proper host + let protocol = window.location.protocol; + let host = window.location.host; + let url = encodeURIComponent(`${protocol}//${host}/schema`); + window.location.replace(`https://json-schema.app/view/%23?url=${url}`); + </script> + <h1>JavaScript required for a dynamic redirect...</h1> + </body> + </html> + """, + content_type="text/html", + ) + + async def _handler_stop(self, _request: web.Request) -> web.Response: + """ + Route handler for shutting down the server (and whole manager) + """ + + self._shutdown_event.set() + logger.info("Shutdown event triggered...") + return web.Response(text="Shutting down...") + + async def _handler_reload(self, _request: web.Request) -> web.Response: + """ + Route handler for reloading the server + """ + + logger.info("Reloading event triggered...") + await self._reload_config() + return web.Response(text="Reloading...") + + def _setup_routes(self) -> None: + self.app.add_routes( + [ + web.get("/", self._handler_index), + web.get(r"/v1/config{path:.*}", self._handler_config_query), + web.put(r"/v1/config{path:.*}", self._handler_config_query), + web.delete(r"/v1/config{path:.*}", self._handler_config_query), + web.patch(r"/v1/config{path:.*}", self._handler_config_query), + web.post("/stop", self._handler_stop), + web.post("/reload", self._handler_reload), + web.get("/schema", self._handler_schema), + web.get("/schema/ui", self._handle_view_schema), + web.get("/metrics", self._handler_metrics), + web.get("/metrics/json", self._handler_metrics_json), + web.get("/metrics/prometheus", self._handler_metrics_prometheus), + web.post("/cache/clear", self._handler_cache_clear), + ] + ) + + async def _reconfigure_listen_address(self, config: KresConfig) -> None: + async with self.listen_lock: + mgn = config.management + + # if the listen address did not change, do nothing + if self.listen == mgn: + return + + # start the new listen address + nsite: Union[web.TCPSite, web.UnixSite] + if mgn.unix_socket: + nsite = web.UnixSite(self.runner, str(mgn.unix_socket)) + logger.info(f"Starting API HTTP server on http+unix://{mgn.unix_socket}") + elif mgn.interface: + nsite = web.TCPSite(self.runner, str(mgn.interface.addr), int(mgn.interface.port)) + logger.info(f"Starting API HTTP server on http://{mgn.interface.addr}:{mgn.interface.port}") + else: + raise KresManagerException("Requested API on unsupported configuration format.") + await nsite.start() + + # stop the old listen + assert (self.listen is None) == (self.site is None) + if self.listen is not None and self.site is not None: + if self.listen.unix_socket: + logger.info(f"Stopping API HTTP server on http+unix://{mgn.unix_socket}") + elif self.listen.interface: + logger.info( + f"Stopping API HTTP server on 
http://{self.listen.interface.addr}:{self.listen.interface.port}" + ) + await self.site.stop() + + # save new state + self.listen = mgn + self.site = nsite + + async def shutdown(self) -> None: + if self.site is not None: + await self.site.stop() + await self.runner.cleanup() + + def get_exit_code(self) -> int: + return self._exit_code + + +async def _load_raw_config(config: Union[Path, Dict[str, Any]]) -> Dict[str, Any]: + # Initial configuration of the manager + if isinstance(config, Path): + if not config.exists(): + raise KresManagerException( + f"Manager is configured to load config file at {config} on startup, but the file does not exist." + ) + else: + logger.info(f"Loading configuration from '{config}' file.") + config = try_to_parse(await readfile(config)) + + # validate the initial configuration + assert isinstance(config, dict) + return config + + +async def _load_config(config: Dict[str, Any]) -> KresConfig: + config_validated = KresConfig(config) + return config_validated + + +async def _init_config_store(config: Dict[str, Any]) -> ConfigStore: + config_validated = await _load_config(config) + config_store = ConfigStore(config_validated) + return config_store + + +async def _init_manager(config_store: ConfigStore, server: Server) -> KresManager: + """ + Called asynchronously when the application initializes. + """ + + # Instantiate subprocess controller (if we wanted to, we could switch it at this point) + controller = await get_best_controller_implementation(config_store.get()) + + # Create KresManager. This will perform autodetection of available service managers and + # select the most appropriate to use (or use the one configured directly) + manager = await KresManager.create(controller, config_store, server.trigger_shutdown) + + logger.info("Initial configuration applied. Process manager initialized...") + return manager + + +async def _deny_working_directory_changes(config_old: KresConfig, config_new: KresConfig) -> Result[None, str]: + if config_old.rundir != config_new.rundir: + return Result.err("Changing manager's `rundir` during runtime is not allowed.") + + return Result.ok(None) + + +def _set_working_directory(config_raw: Dict[str, Any]) -> None: + try: + rundir = get_rundir_without_validation(config_raw) + except ValueError as e: + raise DataValidationError(str(e), "/rundir") from e + + logger.debug(f"Changing working directory to '{rundir.to_path().absolute()}'.") + os.chdir(rundir.to_path()) + + +def _lock_working_directory(attempt: int = 0) -> None: + # the following syscall is atomic, it's essentially the same as acquiring a lock + try: + pidfile_fd = os.open(PID_FILE_NAME, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o644) + except OSError as e: + if e.errno == errno.EEXIST and attempt == 0: + # the pid file exists, let's check PID + with open(PID_FILE_NAME, "r", encoding="utf-8") as f: + pid = int(f.read().strip()) + try: + os.kill(pid, 0) + except OSError as e2: + if e2.errno == errno.ESRCH: + os.unlink(PID_FILE_NAME) + _lock_working_directory(attempt=attempt + 1) + return + raise KresManagerException( + "Another manager is running in the same working directory." + f" PID file is located at {os.getcwd()}/{PID_FILE_NAME}" + ) from e + else: + raise KresManagerException( + "Another manager is running in the same working directory." 
+ f" PID file is located at {os.getcwd()}/{PID_FILE_NAME}" + ) from e + + # now we know that we are the only manager running in this directory + + # write PID to the pidfile and close it afterwards + pidfile = os.fdopen(pidfile_fd, "w") + pid = os.getpid() + pidfile.write(f"{pid}\n") + pidfile.close() + + # make sure that the file is deleted on shutdown + atexit.register(lambda: os.unlink(PID_FILE_NAME)) + + +async def _sigint_while_shutting_down(): + logger.warning( + "Received SIGINT while already shutting down. Ignoring." + " If you want to forcefully stop the manager right now, use SIGTERM." + ) + + +async def _sigterm_while_shutting_down(): + logger.warning("Received SIGTERM. Invoking dirty shutdown!") + sys.exit(128 + signal.SIGTERM) + + +async def start_server(config: Path = DEFAULT_MANAGER_CONFIG_FILE) -> int: + # This function is quite long, but it describes how manager runs. So let's silence pylint + # pylint: disable=too-many-statements + + start_time = time() + working_directory_on_startup = os.getcwd() + manager: Optional[KresManager] = None + + # Block signals during initialization to force their processing once everything is ready + signal.pthread_sigmask(signal.SIG_BLOCK, Server.all_handled_signals()) + + # before starting server, initialize the subprocess controller, config store, etc. Any errors during inicialization + # are fatal + try: + # Make sure that the config path does not change meaning when we change working directory + config = config.absolute() + + # Preprocess config - load from file or in general take it to the last step before validation. + config_raw = await _load_raw_config(config) + + # before processing any configuration, set validation context + # - resolve_root = root against which all relative paths will be resolved + set_global_validation_context(Context(config.parent, True)) + + # We want to change cwd as soon as possible. Some parts of the codebase are using os.getcwd() to get the + # working directory. + # + # If we fail to read rundir from unparsed config, the first config validation error comes from here + _set_working_directory(config_raw) + + # We don't want more than one manager in a single working directory. So we lock it with a PID file. + # Warning - this does not prevent multiple managers with the same naming of kresd service. + _lock_working_directory() + + # set_global_validation_context(Context(config.parent)) + + # After the working directory is set, we can initialize proper config store with a newly parsed configuration. + config_store = await _init_config_store(config_raw) + + # Some "constants" need to be loaded from the initial config, some need to be stored from the initial run conditions + await init_user_constants(config_store, working_directory_on_startup) + + # This behaviour described above with paths means, that we MUST NOT allow `rundir` change after initialization. + # It would cause strange problems because every other path configuration depends on it. Therefore, we have to + # add a check to the config store, which disallows changes. + await config_store.register_verifier(_deny_working_directory_changes) + + # Up to this point, we have been logging to memory buffer. But now, when we have the configuration loaded, we + # can flush the buffer into the proper place + await log.logger_init(config_store) + + # With configuration on hand, we can initialize monitoring. 
We want to do this before any subprocesses are + # started, therefore before initializing manager + await statistics.init_monitoring(config_store) + + # prepare instance of the server (no side effects) + server = Server(config_store, config) + + # After we have loaded the configuration, we can start worring about subprocess management. + manager = await _init_manager(config_store, server) + + except CancelStartupExecInsteadException as e: + # if we caught this exception, some component wants to perform a reexec during startup. Most likely, it would + # be a subprocess manager like supervisord, which wants to make sure the manager runs under supervisord in + # the process tree. So now we stop everything, and exec what we are told to. We are assuming, that the thing + # we'll exec will invoke us again. + logger.info("Exec requested with arguments: %s", str(e.exec_args)) + + # unblock signals, this could actually terminate us straight away + signal.pthread_sigmask(signal.SIG_UNBLOCK, Server.all_handled_signals()) + + # run exit functions + atexit.run_callbacks() + + # and finally exec what we were told to exec + os.execl(*e.exec_args) + + except KresManagerException as e: + # We caught an error with a pretty error message. Just print it and exit. + logger.error(e) + return 1 + + except BaseException: + logger.error("Uncaught generic exception during manager inicialization...", exc_info=True) + return 1 + + # At this point, all backend functionality-providing components are initialized. It's therefore save to start + # the API server. + try: + await server.start() + except OSError as e: + if e.errno in (errno.EADDRINUSE, errno.EADDRNOTAVAIL): + # fancy error reporting of network binding errors + logger.error(str(e)) + await manager.stop() + return 1 + raise + + # At this point, pretty much everything is ready to go. We should just make sure the user can shut + # the manager down with signals. + server.bind_signal_handlers() + signal.pthread_sigmask(signal.SIG_UNBLOCK, Server.all_handled_signals()) + + logger.info(f"Manager fully initialized and running in {round(time() - start_time, 3)} seconds") + + # notify systemd/anything compatible that we are ready + systemd_notify(READY="1") + + await server.wait_for_shutdown() + + # notify systemd that we are shutting down + systemd_notify(STOPPING="1") + + # Ok, now we are tearing everything down. + + # First of all, let's block all unwanted interruptions. We don't want to be reconfiguring kresd's while + # shutting down. 
+ signal.pthread_sigmask(signal.SIG_BLOCK, Server.all_handled_signals()) + server.unbind_signal_handlers() + # on the other hand, we want to immediatelly stop when the user really wants us to stop + asyncio_compat.add_async_signal_handler(signal.SIGTERM, _sigterm_while_shutting_down) + asyncio_compat.add_async_signal_handler(signal.SIGINT, _sigint_while_shutting_down) + signal.pthread_sigmask(signal.SIG_UNBLOCK, {signal.SIGTERM, signal.SIGINT}) + + # After triggering shutdown, we neet to clean everything up + logger.info("Stopping API service...") + await server.shutdown() + logger.info("Stopping kresd manager...") + await manager.stop() + logger.info(f"The manager run for {round(time() - start_time)} seconds...") + return server.get_exit_code() diff --git a/python/knot_resolver/manager/statistics.py b/python/knot_resolver/manager/statistics.py new file mode 100644 index 00000000..292a480d --- /dev/null +++ b/python/knot_resolver/manager/statistics.py @@ -0,0 +1,434 @@ +import asyncio +import importlib +import json +import logging +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple + +from knot_resolver import compat +from knot_resolver.manager.config_store import ConfigStore, only_on_real_changes_update +from knot_resolver.datamodel.config_schema import KresConfig +from knot_resolver.controller.registered_workers import ( + command_registered_workers, + get_registered_workers_kresids, +) +from knot_resolver.utils.functional import Result +from knot_resolver.utils.modeling.parsing import DataFormat + +if TYPE_CHECKING: + from knot_resolver.controller.interface import KresID + +logger = logging.getLogger(__name__) + + +_prometheus_support = False +if importlib.util.find_spec("prometheus_client"): + _prometheus_support = True + + +if _prometheus_support: + from prometheus_client import exposition # type: ignore + from prometheus_client.bridge.graphite import GraphiteBridge # type: ignore + from prometheus_client.core import GaugeMetricFamily # type: ignore + from prometheus_client.core import REGISTRY, CounterMetricFamily, HistogramMetricFamily, Metric + + def _counter(name: str, description: str, label: Tuple[str, str], value: float) -> CounterMetricFamily: + c = CounterMetricFamily(name, description, labels=(label[0],)) + c.add_metric((label[1],), value) # type: ignore + return c + + def _gauge(name: str, description: str, label: Tuple[str, str], value: float) -> GaugeMetricFamily: + c = GaugeMetricFamily(name, description, labels=(label[0],)) + c.add_metric((label[1],), value) # type: ignore + return c + + def _histogram( + name: str, description: str, label: Tuple[str, str], buckets: List[Tuple[str, int]], sum_value: float + ) -> HistogramMetricFamily: + c = HistogramMetricFamily(name, description, labels=(label[0],)) + c.add_metric((label[1],), buckets, sum_value=sum_value) # type: ignore + return c + + def _parse_resolver_metrics(instance_id: "KresID", metrics: Any) -> Generator[Metric, None, None]: + sid = str(instance_id) + + # response latency histogram + BUCKET_NAMES_IN_RESOLVER = ("1ms", "10ms", "50ms", "100ms", "250ms", "500ms", "1000ms", "1500ms", "slow") + BUCKET_NAMES_PROMETHEUS = ("0.001", "0.01", "0.05", "0.1", "0.25", "0.5", "1.0", "1.5", "+Inf") + yield _histogram( + "resolver_response_latency", + "Time it takes to respond to queries in seconds", + label=("instance_id", sid), + buckets=[ + (bnp, metrics["answer"][f"{duration}"]) + for bnp, duration in zip(BUCKET_NAMES_PROMETHEUS, BUCKET_NAMES_IN_RESOLVER) + ], + 
sum_value=metrics["answer"]["sum_ms"] / 1_000, + ) + + yield _counter( + "resolver_request_total", + "total number of DNS requests (including internal client requests)", + label=("instance_id", sid), + value=metrics["request"]["total"], + ) + yield _counter( + "resolver_request_internal", + "number of internal requests generated by Knot Resolver (e.g. DNSSEC trust anchor updates)", + label=("instance_id", sid), + value=metrics["request"]["internal"], + ) + yield _counter( + "resolver_request_udp", + "number of external requests received over plain UDP (RFC 1035)", + label=("instance_id", sid), + value=metrics["request"]["udp"], + ) + yield _counter( + "resolver_request_tcp", + "number of external requests received over plain TCP (RFC 1035)", + label=("instance_id", sid), + value=metrics["request"]["tcp"], + ) + yield _counter( + "resolver_request_dot", + "number of external requests received over DNS-over-TLS (RFC 7858)", + label=("instance_id", sid), + value=metrics["request"]["dot"], + ) + yield _counter( + "resolver_request_doh", + "number of external requests received over DNS-over-HTTP (RFC 8484)", + label=("instance_id", sid), + value=metrics["request"]["doh"], + ) + yield _counter( + "resolver_request_xdp", + "number of external requests received over plain UDP via an AF_XDP socket", + label=("instance_id", sid), + value=metrics["request"]["xdp"], + ) + yield _counter( + "resolver_answer_total", + "total number of answered queries", + label=("instance_id", sid), + value=metrics["answer"]["total"], + ) + yield _counter( + "resolver_answer_cached", + "number of queries answered from cache", + label=("instance_id", sid), + value=metrics["answer"]["cached"], + ) + yield _counter( + "resolver_answer_stale", + "number of queries that utilized stale data", + label=("instance_id", sid), + value=metrics["answer"]["stale"], + ) + yield _counter( + "resolver_answer_rcode_noerror", + "number of NOERROR answers", + label=("instance_id", sid), + value=metrics["answer"]["noerror"], + ) + yield _counter( + "resolver_answer_rcode_nodata", + "number of NOERROR answers without any data", + label=("instance_id", sid), + value=metrics["answer"]["nodata"], + ) + yield _counter( + "resolver_answer_rcode_nxdomain", + "number of NXDOMAIN answers", + label=("instance_id", sid), + value=metrics["answer"]["nxdomain"], + ) + yield _counter( + "resolver_answer_rcode_servfail", + "number of SERVFAIL answers", + label=("instance_id", sid), + value=metrics["answer"]["servfail"], + ) + yield _counter( + "resolver_answer_flag_aa", + "number of authoritative answers", + label=("instance_id", sid), + value=metrics["answer"]["aa"], + ) + yield _counter( + "resolver_answer_flag_tc", + "number of truncated answers", + label=("instance_id", sid), + value=metrics["answer"]["tc"], + ) + yield _counter( + "resolver_answer_flag_ra", + "number of answers with recursion available flag", + label=("instance_id", sid), + value=metrics["answer"]["ra"], + ) + yield _counter( + "resolver_answer_flags_rd", + "number of recursion desired (in answer!)", + label=("instance_id", sid), + value=metrics["answer"]["rd"], + ) + yield _counter( + "resolver_answer_flag_ad", + "number of authentic data (DNSSEC) answers", + label=("instance_id", sid), + value=metrics["answer"]["ad"], + ) + yield _counter( + "resolver_answer_flag_cd", + "number of checking disabled (DNSSEC) answers", + label=("instance_id", sid), + value=metrics["answer"]["cd"], + ) + yield _counter( + "resolver_answer_flag_do", + "number of DNSSEC answer OK", + 
label=("instance_id", sid), + value=metrics["answer"]["do"], + ) + yield _counter( + "resolver_answer_flag_edns0", + "number of answers with EDNS0 present", + label=("instance_id", sid), + value=metrics["answer"]["edns0"], + ) + yield _counter( + "resolver_query_edns", + "number of queries with EDNS present", + label=("instance_id", sid), + value=metrics["query"]["edns"], + ) + yield _counter( + "resolver_query_dnssec", + "number of queries with DNSSEC DO=1", + label=("instance_id", sid), + value=metrics["query"]["dnssec"], + ) + + if "predict" in metrics: + if "epoch" in metrics["predict"]: + yield _counter( + "resolver_predict_epoch", + "current prediction epoch (based on time of day and sampling window)", + label=("instance_id", sid), + value=metrics["predict"]["epoch"], + ) + yield _counter( + "resolver_predict_queue", + "number of queued queries in current window", + label=("instance_id", sid), + value=metrics["predict"]["queue"], + ) + yield _counter( + "resolver_predict_learned", + "number of learned queries in current window", + label=("instance_id", sid), + value=metrics["predict"]["learned"], + ) + + def _create_resolver_metrics_loaded_gauge(kresid: "KresID", loaded: bool) -> GaugeMetricFamily: + return _gauge( + "resolver_metrics_loaded", + "0 if metrics from resolver instance were not loaded, otherwise 1", + label=("instance_id", str(kresid)), + value=int(loaded), + ) + + async def _deny_turning_off_graphite_bridge(old_config: KresConfig, new_config: KresConfig) -> Result[None, str]: + if old_config.monitoring.graphite and not new_config.monitoring.graphite: + return Result.err( + "You can't turn off graphite monitoring dynamically. If you really want this feature, please let the developers know." + ) + + if ( + old_config.monitoring.graphite is not None + and new_config.monitoring.graphite is not None + and old_config.monitoring.graphite != new_config.monitoring.graphite + ): + return Result.err("Changing graphite exporter configuration in runtime is not allowed.") + + return Result.ok(None) + + _graphite_bridge: Optional[GraphiteBridge] = None + + @only_on_real_changes_update(lambda c: c.monitoring.graphite) + async def _configure_graphite_bridge(config: KresConfig) -> None: + """ + Starts graphite bridge if required + """ + global _graphite_bridge + if config.monitoring.graphite is not False and _graphite_bridge is None: + logger.info( + "Starting Graphite metrics exporter for [%s]:%d", + str(config.monitoring.graphite.host), + int(config.monitoring.graphite.port), + ) + _graphite_bridge = GraphiteBridge( + (str(config.monitoring.graphite.host), int(config.monitoring.graphite.port)) + ) + _graphite_bridge.start( # type: ignore + interval=config.monitoring.graphite.interval.seconds(), prefix=str(config.monitoring.graphite.prefix) + ) + + +class ResolverCollector: + def __init__(self, config_store: ConfigStore) -> None: + self._stats_raw: "Optional[Dict[KresID, object]]" = None + self._config_store: ConfigStore = config_store + self._collection_task: "Optional[asyncio.Task[None]]" = None + self._skip_immediate_collection: bool = False + + if _prometheus_support: + + def collect(self) -> Generator[Metric, None, None]: + # schedule new stats collection + self._trigger_stats_collection() + + # if we have no data, return metrics with information about it and exit + if self._stats_raw is None: + for kresid in get_registered_workers_kresids(): + yield _create_resolver_metrics_loaded_gauge(kresid, False) + return + + # if we have data, parse them + for kresid in 
get_registered_workers_kresids(): + success = False + try: + if kresid in self._stats_raw: + metrics = self._stats_raw[kresid] + yield from _parse_resolver_metrics(kresid, metrics) + success = True + except json.JSONDecodeError: + logger.warning( + "Failed to load metrics from resolver instance %s: failed to parse statistics", str(kresid) + ) + except KeyError as e: + logger.warning( + "Failed to load metrics from resolver instance %s: attempted to read missing statistic %s", + str(kresid), + str(e), + ) + + yield _create_resolver_metrics_loaded_gauge(kresid, success) + + def describe(self) -> List[Metric]: + # this function prevents the collector registry from invoking the collect function on startup + return [] + + def report_json(self) -> str: + # schedule new stats collection + self._trigger_stats_collection() + + # if we have no data, return metrics with information about it and exit + if self._stats_raw is None: + no_stats_dict: Dict[str, None] = {} + for kresid in get_registered_workers_kresids(): + no_stats_dict[str(kresid)] = None + return DataFormat.JSON.dict_dump(no_stats_dict) + + stats_dict: Dict[str, object] = {} + for kresid, stats in self._stats_raw.items(): + stats_dict[str(kresid)] = stats + + return DataFormat.JSON.dict_dump(stats_dict) + + async def collect_kresd_stats(self, _triggered_from_prometheus_library: bool = False) -> None: + if self._skip_immediate_collection: + # this would happen because we are calling this function first manually before stat generation, + # and once again immediately afterwards caused by the prometheus library's stat collection + # + # this is a code made to solve problem with calling async functions from sync methods + self._skip_immediate_collection = False + return + + config = self._config_store.get() + + if config.monitoring.enabled == "manager-only": + logger.debug("Skipping kresd stat collection due to configuration") + self._stats_raw = None + return + + lazy = config.monitoring.enabled == "lazy" + cmd = "collect_lazy_statistics()" if lazy else "collect_statistics()" + logger.debug("Collecting kresd stats with method '%s'", cmd) + stats_raw = await command_registered_workers(cmd) + self._stats_raw = stats_raw + + # if this function was not called by the prometheus library and calling collect() is imminent, + # we should block the next collection cycle as it would be useless + if not _triggered_from_prometheus_library: + self._skip_immediate_collection = True + + def _trigger_stats_collection(self) -> None: + # we are running inside an event loop, but in a synchronous function and that sucks a lot + # it means that we shouldn't block the event loop by performing a blocking stats collection + # but it also means that we can't yield to the event loop as this function is synchronous + # therefore we can only start a new task, but we can't wait for it + # which causes the metrics to be delayed by one collection pass (not the best, but probably good enough) + # + # this issue can be prevented by calling the `collect_kresd_stats()` function manually before entering + # the Prometheus library. We just have to prevent the library from invoking it again. See the mentioned + # function for details + + if compat.asyncio.is_event_loop_running(): + # when running, we can schedule the new data collection + if self._collection_task is not None and not self._collection_task.done(): + logger.warning("Statistics collection task is still running. 
Skipping scheduling of a new one!") + else: + self._collection_task = compat.asyncio.create_task( + self.collect_kresd_stats(_triggered_from_prometheus_library=True) + ) + + else: + # when not running, we can start a new loop (we are not in the manager's main thread) + compat.asyncio.run(self.collect_kresd_stats(_triggered_from_prometheus_library=True)) + + +_resolver_collector: Optional[ResolverCollector] = None + + +async def _collect_stats() -> None: + # manually trigger stat collection so that we do not have to wait for it + if _resolver_collector is not None: + await _resolver_collector.collect_kresd_stats() + else: + raise RuntimeError("Function invoked before initializing the module!") + + +async def report_stats(prometheus_format: bool = False) -> Optional[bytes]: + """ + Collects metrics from everything, returns data string in JSON (default) or Prometheus format. + """ + + # manually trigger stat collection so that we do not have to wait for it + if _resolver_collector is not None: + await _resolver_collector.collect_kresd_stats() + else: + raise RuntimeError("Function invoked before initializing the module!") + + if prometheus_format: + if _prometheus_support: + return exposition.generate_latest() # type: ignore + return None + return _resolver_collector.report_json().encode() + + +async def init_monitoring(config_store: ConfigStore) -> None: + """ + Initialize monitoring. Must be called before any other function from this module. + """ + global _resolver_collector + _resolver_collector = ResolverCollector(config_store) + + if _prometheus_support: + # register metrics collector + REGISTRY.register(_resolver_collector) # type: ignore + + # register graphite bridge + await config_store.register_verifier(_deny_turning_off_graphite_bridge) + await config_store.register_on_change_callback(_configure_graphite_bridge) diff --git a/python/knot_resolver/utils/__init__.py b/python/knot_resolver/utils/__init__.py new file mode 100644 index 00000000..edc36fca --- /dev/null +++ b/python/knot_resolver/utils/__init__.py @@ -0,0 +1,45 @@ +from typing import Any, Callable, Optional, Type, TypeVar + +T = TypeVar("T") + + +def ignore_exceptions_optional( + _tp: Type[T], default: Optional[T], *exceptions: Type[BaseException] +) -> Callable[[Callable[..., Optional[T]]], Callable[..., Optional[T]]]: + """ + Decorator, that wraps around a function preventing it from raising exceptions + and instead returning the configured default value. + + :param Type[T] _tp: Return type of the function. Essentialy only a template argument for type-checking + :param T default: The value to return as a default + :param List[Type[BaseException]] exceptions: The list of exceptions to catch + :return: value of the decorated function, or default if exception raised + :rtype: T + """ + + def decorator(func: Callable[..., Optional[T]]) -> Callable[..., Optional[T]]: + def f(*nargs: Any, **nkwargs: Any) -> Optional[T]: + try: + return func(*nargs, **nkwargs) + except BaseException as e: + if isinstance(e, exceptions): + return default + else: + raise e + + return f + + return decorator + + +def ignore_exceptions( + default: T, *exceptions: Type[BaseException] +) -> Callable[[Callable[..., Optional[T]]], Callable[..., Optional[T]]]: + return ignore_exceptions_optional(type(default), default, *exceptions) + + +def phantom_use(var: Any) -> None: # pylint: disable=unused-argument + """ + Function, which consumes its argument doing absolutely nothing with it. 
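To illustrate the two decorators above: a minimal usage sketch, assuming only that the `knot_resolver.utils` package from this patch is importable; the `read_pid` helpers and the PID-file path are invented for the example.

```python
from typing import Optional

from knot_resolver.utils import ignore_exceptions, ignore_exceptions_optional


# Invented helper: return -1 instead of raising when the PID file is missing or malformed.
@ignore_exceptions(-1, FileNotFoundError, ValueError)
def read_pid(path: str) -> int:
    with open(path, "r", encoding="utf8") as f:
        return int(f.read().strip())


# Same idea with an explicit None default; the return type becomes Optional[int].
@ignore_exceptions_optional(int, None, FileNotFoundError, ValueError)
def read_pid_or_none(path: str) -> Optional[int]:
    with open(path, "r", encoding="utf8") as f:
        return int(f.read().strip())


print(read_pid("/nonexistent/kresd.pid"))          # -1
print(read_pid_or_none("/nonexistent/kresd.pid"))  # None
```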
Useful + for convincing pylint, that we need the variable even when its unused. + """ diff --git a/python/knot_resolver/utils/async_utils.py b/python/knot_resolver/utils/async_utils.py new file mode 100644 index 00000000..a5acdbd5 --- /dev/null +++ b/python/knot_resolver/utils/async_utils.py @@ -0,0 +1,129 @@ +import asyncio +import os +import pkgutil +import signal +import sys +import time +from asyncio import create_subprocess_exec, create_subprocess_shell +from pathlib import PurePath +from threading import Thread +from typing import Any, Dict, Generic, List, Optional, TypeVar, Union + +from knot_resolver.compat.asyncio import to_thread + + +def unblock_signals(): + if sys.version_info.major >= 3 and sys.version_info.minor >= 8: + signal.pthread_sigmask(signal.SIG_UNBLOCK, signal.valid_signals()) # type: ignore + else: + # the list of signals is not exhaustive, but it should cover all signals we might ever want to block + signal.pthread_sigmask( + signal.SIG_UNBLOCK, + { + signal.SIGHUP, + signal.SIGINT, + signal.SIGTERM, + signal.SIGUSR1, + signal.SIGUSR2, + }, + ) + + +async def call( + cmd: Union[str, bytes, List[str], List[bytes]], shell: bool = False, discard_output: bool = False +) -> int: + """ + custom async alternative to subprocess.call() + """ + kwargs: Dict[str, Any] = { + "preexec_fn": unblock_signals, + } + if discard_output: + kwargs["stdout"] = asyncio.subprocess.DEVNULL + kwargs["stderr"] = asyncio.subprocess.DEVNULL + + if shell: + if isinstance(cmd, list): + raise RuntimeError("can't use list of arguments with shell=True") + proc = await create_subprocess_shell(cmd, **kwargs) + else: + if not isinstance(cmd, list): + raise RuntimeError( + "Please use list of arguments, not a single string. It will prevent ambiguity when parsing" + ) + proc = await create_subprocess_exec(*cmd, **kwargs) + + return await proc.wait() + + +async def readfile(path: Union[str, PurePath]) -> str: + """ + asynchronously read whole file and return its content + """ + + def readfile_sync(path: Union[str, PurePath]) -> str: + with open(path, "r", encoding="utf8") as f: + return f.read() + + return await to_thread(readfile_sync, path) + + +async def writefile(path: Union[str, PurePath], content: str) -> None: + """ + asynchronously set content of a file to a given string `content`. 
+ """ + + def writefile_sync(path: Union[str, PurePath], content: str) -> int: + with open(path, "w", encoding="utf8") as f: + return f.write(content) + + await to_thread(writefile_sync, path, content) + + +async def wait_for_process_termination(pid: int, sleep_sec: float = 0) -> None: + """ + will wait for any process (does not have to be a child process) given by its PID to terminate + + sleep_sec configures the granularity, with which we should return + """ + + def wait_sync(pid: int, sleep_sec: float) -> None: + while True: + try: + os.kill(pid, 0) + if sleep_sec == 0: + os.sched_yield() + else: + time.sleep(sleep_sec) + except ProcessLookupError: + break + + await to_thread(wait_sync, pid, sleep_sec) + + +async def read_resource(package: str, filename: str) -> Optional[bytes]: + return await to_thread(pkgutil.get_data, package, filename) + + +T = TypeVar("T") + + +class BlockingEventDispatcher(Thread, Generic[T]): + def __init__(self, name: str = "blocking_event_dispatcher") -> None: + super().__init__(name=name, daemon=True) + # warning: the asyncio queue is not thread safe + self._removed_unit_names: "asyncio.Queue[T]" = asyncio.Queue() + self._main_event_loop = asyncio.get_event_loop() + + def dispatch_event(self, event: T) -> None: + """ + Method to dispatch events from the blocking thread + """ + + async def add_to_queue(): + await self._removed_unit_names.put(event) + + self._main_event_loop.call_soon_threadsafe(add_to_queue) + + async def next_event(self) -> T: + return await self._removed_unit_names.get() diff --git a/python/knot_resolver/utils/custom_atexit.py b/python/knot_resolver/utils/custom_atexit.py new file mode 100644 index 00000000..2fe55433 --- /dev/null +++ b/python/knot_resolver/utils/custom_atexit.py @@ -0,0 +1,20 @@ +""" +Custom replacement for standard module `atexit`. We use `atexit` behind the scenes, we just add the option +to invoke the exit functions manually. 
+""" + +import atexit +from typing import Callable, List + +_at_exit_functions: List[Callable[[], None]] = [] + + +def register(func: Callable[[], None]) -> None: + _at_exit_functions.append(func) + atexit.register(func) + + +def run_callbacks() -> None: + for func in _at_exit_functions: + func() + atexit.unregister(func) diff --git a/python/knot_resolver/utils/etag.py b/python/knot_resolver/utils/etag.py new file mode 100644 index 00000000..bb80700b --- /dev/null +++ b/python/knot_resolver/utils/etag.py @@ -0,0 +1,10 @@ +import base64 +import json +from hashlib import blake2b +from typing import Any + + +def structural_etag(obj: Any) -> str: + m = blake2b(digest_size=15) + m.update(json.dumps(obj, sort_keys=True).encode("utf8")) + return base64.urlsafe_b64encode(m.digest()).decode("utf8") diff --git a/python/knot_resolver/utils/functional.py b/python/knot_resolver/utils/functional.py new file mode 100644 index 00000000..43abd705 --- /dev/null +++ b/python/knot_resolver/utils/functional.py @@ -0,0 +1,72 @@ +from enum import Enum, auto +from typing import Any, Callable, Generic, Iterable, TypeVar, Union + +T = TypeVar("T") + + +def foldl(oper: Callable[[T, T], T], default: T, arr: Iterable[T]) -> T: + val = default + for x in arr: + val = oper(val, x) + return val + + +def contains_element_matching(cond: Callable[[T], bool], arr: Iterable[T]) -> bool: + return foldl(lambda x, y: x or y, False, map(cond, arr)) + + +def all_matches(cond: Callable[[T], bool], arr: Iterable[T]) -> bool: + return foldl(lambda x, y: x and y, True, map(cond, arr)) + + +Succ = TypeVar("Succ") +Err = TypeVar("Err") + + +class _Status(Enum): + OK = auto() + ERROR = auto() + + +class _ResultSentinel: + pass + + +_RESULT_SENTINEL = _ResultSentinel() + + +class Result(Generic[Succ, Err]): + @staticmethod + def ok(succ: T) -> "Result[T, Any]": + return Result(_Status.OK, succ=succ) + + @staticmethod + def err(err: T) -> "Result[Any, T]": + return Result(_Status.ERROR, err=err) + + def __init__( + self, + status: _Status, + succ: Union[Succ, _ResultSentinel] = _RESULT_SENTINEL, + err: Union[Err, _ResultSentinel] = _RESULT_SENTINEL, + ) -> None: + super().__init__() + self._status: _Status = status + self._succ: Union[_ResultSentinel, Succ] = succ + self._err: Union[_ResultSentinel, Err] = err + + def unwrap(self) -> Succ: + assert self._status is _Status.OK + assert not isinstance(self._succ, _ResultSentinel) + return self._succ + + def unwrap_err(self) -> Err: + assert self._status is _Status.ERROR + assert not isinstance(self._err, _ResultSentinel) + return self._err + + def is_ok(self) -> bool: + return self._status is _Status.OK + + def is_err(self) -> bool: + return self._status is _Status.ERROR diff --git a/python/knot_resolver/utils/modeling/README.md b/python/knot_resolver/utils/modeling/README.md new file mode 100644 index 00000000..97c68b54 --- /dev/null +++ b/python/knot_resolver/utils/modeling/README.md @@ -0,0 +1,155 @@ +# Modeling utils + +These utilities are used to model schemas for data stored in a python dictionary or YAML and JSON format. +The utilities also take care of parsing, validating and creating JSON schemas and basic documentation. + +## Creating schema + +Schema is created using `ConfigSchema` class. Schema structure is specified using annotations. + +```python +from .modeling import ConfigSchema + +class SimpleSchema(ConfigSchema): + integer: int = 5 # a default value can be specified + string: str + boolean: bool +``` +Even more complex types can be used in a schema. 
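A small sketch (not in the original README) of what instantiating the `SimpleSchema` above looks like; it assumes the class definition from the previous snippet is in scope.

```python
simple = SimpleSchema({"string": "hello", "boolean": True})

print(simple.integer)  # 5 -- the default value was applied
print(simple.string)   # 'hello'
print(simple.boolean)  # True
```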
Schemas can be also nested. +Words in multi-word names are separated by underscore `_` (e.g. `simple_schema`). + +```python +from typing import Dict, List, Optional, Union + +class ComplexSchema(ConfigSchema): + optional: Optional[str] # this field is optional + union: Union[int, str] # integer and string are both valid + list: List[int] # list of integers + dictionary: Dict[str, bool] = {"key": False} + simple_schema: SimpleSchema # nested schema +``` + + +### Additianal validation + +If a some additional validation needs to be done, there is `_validate()` method for that. +`ValueError` exception should be raised in case of validation error. + +```python +class FieldsSchema(ConfigSchema): + field1: int + field2: int + + def _validate(self) -> None: + if self.field1 > self.field2: + raise ValueError("field1 is bigger than field2") +``` + + +### Additional layer, transformation methods + +It is possible to add layers to schema and use a transformation method between layers to process the value. +Transformation method must be named based on field (`value` in this example) with `_` underscore prefix. +In this example, the `Layer2Schema` is structure for input data and `Layer1Schema` is for result data. + +```python +class Layer1Schema(ConfigSchema): + class Layer2Schema(ConfigSchema): + value: Union[str, int] + + _LAYER = Layer2Schema + + value: int + + def _value(self, obj: Layer2Schema) -> Any: + if isinstance(str, obj.value): + return len(obj.value) # transform str values to int; this is just example + return obj.value +``` + +### Documentation and JSON schema + +Created schema can be documented using simple docstring. Json schema is created by calling `json_schema()` method on schema class. JSON schema includes description from docstring, defaults, etc. + +```python +SimpleSchema(ConfigSchema): + """ + This is description for SimpleSchema itself. + + --- + integer: description for integer field + string: description for string field + boolean: description for boolean field + """ + + integer: int = 5 + string: str + boolean: bool + +json_schema = SimpleSchema.json_schema() +``` + + +## Creating custom type + +Custom types can be made by extending `BaseValueType` class which is integrated to parsing and validating process. +Use `DataValidationError` to rase exception during validation. `object_path` is used to track node in more complex/nested schemas and create useful logging message. + +```python +from .modeling import BaseValueType +from .modeling.exceptions import DataValidationError + +class IntNonNegative(BaseValueType): + def __init__(self, source_value: Any, object_path: str = "/") -> None: + super().__init__(source_value) + if isinstance(source_value, int) and not isinstance(source_value, bool): + if source_value < 0: + raise DataValidationError(f"value {source_value} is negative number.", object_path) + self._value = source_value + else: + raise DataValidationError( + f"expected integer, got '{type(source_value)}'", + object_path, + ) +``` + +For JSON schema you should implement `json_schema` method. +It should return [JSON schema representation](https://json-schema.org/understanding-json-schema/index.html) of the custom type. + +```python + @classmethod + def json_schema(cls: Type["IntNonNegative"]) -> Dict[Any, Any]: + return {"type": "integer", "minimum": 0} +``` + + +## Parsing JSON/YAML + +For example, YAML data for `ComplexSchema` can look like this. +Words in multi-word names are separated by hyphen `-` (e.g. `simple-schema`). 
+
+```yaml
+# data.yaml
+union: here could also be a number
+list: [1, 2, 3]
+dictionary:
+  key: false
+simple-schema:
+  integer: 55
+  string: this is string
+  boolean: false
+```
+
+To parse data in the YAML format, just use the `parse_yaml` function; for JSON, use `parse_json`.
+Parsed data is stored in a dict-like object that takes care of the `-`/`_` conversion.
+
+```python
+from .modeling import parse_yaml
+
+# read data from file
+with open("data.yaml") as f:
+    str_data = f.read()
+
+dict_data = parse_yaml(str_data)
+validated_data = ComplexSchema(dict_data)
+```
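Not from the original README: a sketch of what a validation failure and the generated JSON schema look like, reusing the `ComplexSchema` defined earlier. Field paths and error wording come from the object mapper in `base_schema.py` and may differ in detail.

```python
from .modeling import parse_yaml
from .modeling.exceptions import DataValidationError

broken = parse_yaml(
    """
union: 42
list: []                # empty lists are rejected by the object mapper
dictionary:
  key: false
simple-schema:
  integer: fifty-five   # not an integer
  string: hello
  boolean: true
"""
)

try:
    ComplexSchema(broken)
except DataValidationError as e:
    # prints a tree of errors, one line per offending node,
    # e.g. "[/list] empty list is not allowed"
    print(e)

# machine-readable JSON Schema (draft 2020-12) generated from the class
schema = ComplexSchema.json_schema()
print(sorted(schema["properties"]))  # ['dictionary', 'list', 'optional', 'simple-schema', 'union']
```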
\ No newline at end of file diff --git a/python/knot_resolver/utils/modeling/__init__.py b/python/knot_resolver/utils/modeling/__init__.py new file mode 100644 index 00000000..d16f6c12 --- /dev/null +++ b/python/knot_resolver/utils/modeling/__init__.py @@ -0,0 +1,14 @@ +from .base_generic_type_wrapper import BaseGenericTypeWrapper +from .base_schema import BaseSchema, ConfigSchema +from .base_value_type import BaseValueType +from .parsing import parse_json, parse_yaml, try_to_parse + +__all__ = [ + "BaseGenericTypeWrapper", + "BaseValueType", + "BaseSchema", + "ConfigSchema", + "parse_yaml", + "parse_json", + "try_to_parse", +] diff --git a/python/knot_resolver/utils/modeling/base_generic_type_wrapper.py b/python/knot_resolver/utils/modeling/base_generic_type_wrapper.py new file mode 100644 index 00000000..1f2c1767 --- /dev/null +++ b/python/knot_resolver/utils/modeling/base_generic_type_wrapper.py @@ -0,0 +1,9 @@ +from typing import Generic, TypeVar + +from .base_value_type import BaseTypeABC + +T = TypeVar("T") + + +class BaseGenericTypeWrapper(Generic[T], BaseTypeABC): # pylint: disable=abstract-method + pass diff --git a/python/knot_resolver/utils/modeling/base_schema.py b/python/knot_resolver/utils/modeling/base_schema.py new file mode 100644 index 00000000..aca3be05 --- /dev/null +++ b/python/knot_resolver/utils/modeling/base_schema.py @@ -0,0 +1,816 @@ +import enum +import inspect +from abc import ABC, abstractmethod # pylint: disable=[no-name-in-module] +from typing import Any, Callable, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar, Union, cast + +import yaml + +from knot_resolver.utils.functional import all_matches + +from .base_generic_type_wrapper import BaseGenericTypeWrapper +from .base_value_type import BaseValueType +from .exceptions import AggregateDataValidationError, DataDescriptionError, DataValidationError +from .renaming import Renamed, renamed +from .types import ( + get_generic_type_argument, + get_generic_type_arguments, + get_generic_type_wrapper_argument, + get_optional_inner_type, + is_dict, + is_enum, + is_generic_type_wrapper, + is_internal_field_name, + is_list, + is_literal, + is_none_type, + is_optional, + is_tuple, + is_union, +) + +T = TypeVar("T") + + +def is_obj_type(obj: Any, types: Union[type, Tuple[Any, ...], Tuple[type, ...]]) -> bool: + # To check specific type we are using 'type()' instead of 'isinstance()' + # because for example 'bool' is instance of 'int', 'isinstance(False, int)' returns True. + # pylint: disable=unidiomatic-typecheck + if isinstance(types, tuple): + return type(obj) in types + return type(obj) is types + + +class Serializable(ABC): + """ + An interface for making classes serializable to a dictionary (and in turn into a JSON). 
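A short illustration (plain Python, nothing from this patch) of the bool-vs-int pitfall that the comment in `is_obj_type` above refers to:

```python
# bool is a subclass of int, so isinstance() cannot tell False apart from an integer value,
# while an exact type() comparison can -- which is what strict config validation needs.
print(isinstance(False, int))  # True
print(type(False) is int)      # False
print(type(False) is bool)     # True
```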
+ """ + + @abstractmethod + def to_dict(self) -> Dict[Any, Any]: + raise NotImplementedError(f"...for class {self.__class__.__name__}") + + @staticmethod + def is_serializable(typ: Type[Any]) -> bool: + return ( + typ in {str, bool, int, float} + or is_none_type(typ) + or is_literal(typ) + or is_dict(typ) + or is_list(typ) + or is_generic_type_wrapper(typ) + or (inspect.isclass(typ) and issubclass(typ, Serializable)) + or (inspect.isclass(typ) and issubclass(typ, BaseValueType)) + or (inspect.isclass(typ) and issubclass(typ, BaseSchema)) + or (is_optional(typ) and Serializable.is_serializable(get_optional_inner_type(typ))) + or (is_union(typ) and all_matches(Serializable.is_serializable, get_generic_type_arguments(typ))) + ) + + @staticmethod + def serialize(obj: Any) -> Any: + if isinstance(obj, Serializable): + return obj.to_dict() + + elif isinstance(obj, (BaseValueType, BaseGenericTypeWrapper)): + o = obj.serialize() + # if Serializable.is_serializable(o): + return Serializable.serialize(o) + # return o + + elif isinstance(obj, list): + res: List[Any] = [Serializable.serialize(i) for i in cast(List[Any], obj)] + return res + + return obj + + +class _lazy_default(Generic[T], Serializable): + """ + Wrapper for default values BaseSchema classes which deffers their instantiation until the schema + itself is being instantiated + """ + + def __init__(self, constructor: Callable[..., T], *args: Any, **kwargs: Any) -> None: + # pylint: disable=[super-init-not-called] + self._func = constructor + self._args = args + self._kwargs = kwargs + + def instantiate(self) -> T: + return self._func(*self._args, **self._kwargs) + + def to_dict(self) -> Dict[Any, Any]: + return Serializable.serialize(self.instantiate()) + + +def lazy_default(constructor: Callable[..., T], *args: Any, **kwargs: Any) -> T: + """We use a factory function because you can't lie about the return type in `__new__`""" + return _lazy_default(constructor, *args, **kwargs) # type: ignore + + +def _split_docstring(docstring: str) -> Tuple[str, Optional[str]]: + """ + Splits docstring into description of the class and description of attributes + """ + + if "---" not in docstring: + return ("\n".join([s.strip() for s in docstring.splitlines()]).strip(), None) + + doc, attrs_doc = docstring.split("---", maxsplit=1) + return ( + "\n".join([s.strip() for s in doc.splitlines()]).strip(), + attrs_doc, + ) + + +def _parse_attrs_docstrings(docstring: str) -> Optional[Dict[str, str]]: + """ + Given a docstring of a BaseSchema, return a dict with descriptions of individual attributes. 
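To make the purpose of `lazy_default` above concrete: a minimal sketch in which the schema classes are invented for illustration and only the imports reflect this patch. The point is that the default value is not built at class-definition (import) time, but only when a schema instance actually needs it.

```python
from knot_resolver.utils.modeling import ConfigSchema
from knot_resolver.utils.modeling.base_schema import lazy_default


class CacheSchema(ConfigSchema):   # invented example schema
    size_mb: int = 100


class RootSchema(ConfigSchema):    # invented example schema
    # The CacheSchema default is constructed lazily, when a RootSchema instance
    # is created, not when this class body is evaluated.
    cache: CacheSchema = lazy_default(CacheSchema)


print(RootSchema({}).cache.size_mb)  # 100
```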
+ """ + + _, attrs_doc = _split_docstring(docstring) + if attrs_doc is None: + return None + + # try to parse it as yaml: + data = yaml.safe_load(attrs_doc) + assert isinstance(data, dict), "Invalid format of attribute description" + return cast(Dict[str, str], data) + + +def _get_properties_schema(typ: Type[Any]) -> Dict[Any, Any]: + schema: Dict[Any, Any] = {} + annot: Dict[str, Any] = typ.__dict__.get("__annotations__", {}) + docstring: str = typ.__dict__.get("__doc__", "") or "" + attribute_documentation = _parse_attrs_docstrings(docstring) + for field_name, python_type in annot.items(): + name = field_name.replace("_", "-") + schema[name] = _describe_type(python_type) + + # description + if attribute_documentation is not None: + if field_name not in attribute_documentation: + raise DataDescriptionError(f"The docstring does not describe field '{field_name}'", str(typ)) + schema[name]["description"] = attribute_documentation[field_name] + del attribute_documentation[field_name] + + # default value + if hasattr(typ, field_name): + assert Serializable.is_serializable( + python_type + ), f"Type '{python_type}' does not appear to be JSON serializable" + schema[name]["default"] = Serializable.serialize(getattr(typ, field_name)) + + if attribute_documentation is not None and len(attribute_documentation) > 0: + raise DataDescriptionError( + f"The docstring describes attributes which are not present - {tuple(attribute_documentation.keys())}", + str(typ), + ) + + return schema + + +def _describe_type(typ: Type[Any]) -> Dict[Any, Any]: + # pylint: disable=too-many-branches + + if inspect.isclass(typ) and issubclass(typ, BaseSchema): + return typ.json_schema(include_schema_definition=False) + + elif inspect.isclass(typ) and issubclass(typ, BaseValueType): + return typ.json_schema() + + elif is_generic_type_wrapper(typ): + wrapped = get_generic_type_wrapper_argument(typ) + return _describe_type(wrapped) + + elif is_none_type(typ): + return {"type": "null"} + + elif typ == int: + return {"type": "integer"} + + elif typ == bool: + return {"type": "boolean"} + + elif typ == str: + return {"type": "string"} + + elif is_literal(typ): + lit = get_generic_type_arguments(typ) + return {"type": "string", "enum": lit} + + elif is_optional(typ): + desc = _describe_type(get_optional_inner_type(typ)) + if "type" in desc: + desc["type"] = [desc["type"], "null"] + return desc + else: + return {"anyOf": [{"type": "null"}, desc]} + + elif is_union(typ): + variants = get_generic_type_arguments(typ) + return {"anyOf": [_describe_type(v) for v in variants]} + + elif is_list(typ): + return {"type": "array", "items": _describe_type(get_generic_type_argument(typ))} + + elif is_dict(typ): + key, val = get_generic_type_arguments(typ) + + if inspect.isclass(key) and issubclass(key, BaseValueType): + assert ( + key.__str__ is not BaseValueType.__str__ + ), "To support derived 'BaseValueType', __str__ must be implemented." 
+ else: + assert key == str, "We currently do not support any other keys then strings" + + return {"type": "object", "additionalProperties": _describe_type(val)} + + elif inspect.isclass(typ) and issubclass(typ, enum.Enum): # same as our is_enum(typ), but inlined for type checker + return {"type": "string", "enum": [str(v) for v in typ]} + + raise NotImplementedError(f"Trying to get JSON schema for type '{typ}', which is not implemented") + + +TSource = Union[None, "BaseSchema", Dict[str, Any]] + + +def _create_untouchable(name: str) -> object: + class _Untouchable: + def __getattribute__(self, item_name: str) -> Any: + raise RuntimeError(f"You are not supposed to access object '{name}'.") + + def __setattr__(self, item_name: str, value: Any) -> None: + raise RuntimeError(f"You are not supposed to access object '{name}'.") + + return _Untouchable() + + +class ObjectMapper: + def _create_tuple(self, tp: Type[Any], obj: Tuple[Any, ...], object_path: str) -> Tuple[Any, ...]: + types = get_generic_type_arguments(tp) + errs: List[DataValidationError] = [] + res: List[Any] = [] + for i, (t, val) in enumerate(zip(types, obj)): + try: + res.append(self.map_object(t, val, object_path=f"{object_path}[{i}]")) + except DataValidationError as e: + errs.append(e) + if len(errs) == 1: + raise errs[0] + elif len(errs) > 1: + raise AggregateDataValidationError(object_path, child_exceptions=errs) + return tuple(res) + + def _create_dict(self, tp: Type[Any], obj: Dict[Any, Any], object_path: str) -> Dict[Any, Any]: + key_type, val_type = get_generic_type_arguments(tp) + try: + errs: List[DataValidationError] = [] + res: Dict[Any, Any] = {} + for key, val in obj.items(): + try: + nkey = self.map_object(key_type, key, object_path=f"{object_path}[{key}]") + nval = self.map_object(val_type, val, object_path=f"{object_path}[{key}]") + res[nkey] = nval + except DataValidationError as e: + errs.append(e) + if len(errs) == 1: + raise errs[0] + elif len(errs) > 1: + raise AggregateDataValidationError(object_path, child_exceptions=errs) + return res + except AttributeError as e: + raise DataValidationError( + f"Expected dict-like object, but failed to access its .items() method. Value was {obj}", object_path + ) from e + + def _create_list(self, tp: Type[Any], obj: List[Any], object_path: str) -> List[Any]: + if isinstance(obj, str): + raise DataValidationError("expected list, got string", object_path) + + inner_type = get_generic_type_argument(tp) + errs: List[DataValidationError] = [] + res: List[Any] = [] + + try: + for i, val in enumerate(obj): + res.append(self.map_object(inner_type, val, object_path=f"{object_path}[{i}]")) + if len(res) == 0: + raise DataValidationError("empty list is not allowed", object_path) + except DataValidationError as e: + errs.append(e) + except TypeError as e: + errs.append(DataValidationError(str(e), object_path)) + + if len(errs) == 1: + raise errs[0] + elif len(errs) > 1: + raise AggregateDataValidationError(object_path, child_exceptions=errs) + return res + + def _create_str(self, obj: Any, object_path: str) -> str: + # we are willing to cast any primitive value to string, but no compound values are allowed + if is_obj_type(obj, (str, float, int)) or isinstance(obj, BaseValueType): + return str(obj) + elif is_obj_type(obj, bool): + raise DataValidationError( + "Expected str, found bool. Be careful, that YAML parsers consider even" + ' "no" and "yes" as a bool. Search for the Norway Problem for more' + " details. 
And please use quotes explicitly.", + object_path, + ) + else: + raise DataValidationError( + f"expected str (or number that would be cast to string), but found type {type(obj)}", object_path + ) + + def _create_int(self, obj: Any, object_path: str) -> int: + # we don't want to make an int out of anything else than other int + # except for BaseValueType class instances + if is_obj_type(obj, int) or isinstance(obj, BaseValueType): + return int(obj) + raise DataValidationError(f"expected int, found {type(obj)}", object_path) + + def _create_union(self, tp: Type[T], obj: Any, object_path: str) -> T: + variants = get_generic_type_arguments(tp) + errs: List[DataValidationError] = [] + for v in variants: + try: + return self.map_object(v, obj, object_path=object_path) + except DataValidationError as e: + errs.append(e) + + raise DataValidationError("could not parse any of the possible variants", object_path, child_exceptions=errs) + + def _create_optional(self, tp: Type[Optional[T]], obj: Any, object_path: str) -> Optional[T]: + inner: Type[Any] = get_optional_inner_type(tp) + if obj is None: + return None + else: + return self.map_object(inner, obj, object_path=object_path) + + def _create_bool(self, obj: Any, object_path: str) -> bool: + if is_obj_type(obj, bool): + return obj + else: + raise DataValidationError(f"expected bool, found {type(obj)}", object_path) + + def _create_literal(self, tp: Type[Any], obj: Any, object_path: str) -> Any: + expected = get_generic_type_arguments(tp) + if obj in expected: + return obj + else: + raise DataValidationError(f"'{obj}' does not match any of the expected values {expected}", object_path) + + def _create_base_schema_object(self, tp: Type[Any], obj: Any, object_path: str) -> "BaseSchema": + if isinstance(obj, (dict, BaseSchema)): + return tp(obj, object_path=object_path) + raise DataValidationError(f"expected 'dict' or 'NoRenameBaseSchema' object, found '{type(obj)}'", object_path) + + def create_value_type_object(self, tp: Type[Any], obj: Any, object_path: str) -> "BaseValueType": + if isinstance(obj, tp): + # if we already have a custom value type, just pass it through + return obj + else: + # no validation performed, the implementation does it in the constuctor + try: + return tp(obj, object_path=object_path) + except ValueError as e: + if len(e.args) > 0 and isinstance(e.args[0], str): + msg = e.args[0] + else: + msg = f"Failed to validate value against {tp} type" + raise DataValidationError(msg, object_path) from e + + def _create_default(self, obj: Any) -> Any: + if isinstance(obj, _lazy_default): + return obj.instantiate() # type: ignore + else: + return obj + + def map_object( + self, + tp: Type[Any], + obj: Any, + default: Any = ..., + use_default: bool = False, + object_path: str = "/", + ) -> Any: + """ + Given an expected type `cls` and a value object `obj`, return a new object of the given type and map fields of `obj` into it. During the mapping procedure, + runtime type checking is performed. + """ + + # Disabling these checks, because I think it's much more readable as a single function + # and it's not that large at this point. 
If it got larger, then we should definitely split it + # pylint: disable=too-many-branches,too-many-locals,too-many-statements + + # default values + if obj is None and use_default: + return self._create_default(default) + + # NoneType + elif is_none_type(tp): + if obj is None: + return None + else: + raise DataValidationError(f"expected None, found '{obj}'.", object_path) + + # Optional[T] (could be technically handled by Union[*variants], but this way we have better error reporting) + elif is_optional(tp): + return self._create_optional(tp, obj, object_path) + + # Union[*variants] + elif is_union(tp): + return self._create_union(tp, obj, object_path) + + # after this, there is no place for a None object + elif obj is None: + raise DataValidationError(f"unexpected value 'None' for type {tp}", object_path) + + # int + elif tp == int: + return self._create_int(obj, object_path) + + # str + elif tp == str: + return self._create_str(obj, object_path) + + # bool + elif tp == bool: + return self._create_bool(obj, object_path) + + # float + elif tp == float: + raise NotImplementedError( + "Floating point values are not supported in the object mapper." + " Please implement them and be careful with type coercions" + ) + + # Literal[T] + elif is_literal(tp): + return self._create_literal(tp, obj, object_path) + + # Dict[K,V] + elif is_dict(tp): + return self._create_dict(tp, obj, object_path) + + # any Enums (probably used only internally in DataValidator) + elif is_enum(tp): + if isinstance(obj, tp): + return obj + else: + raise DataValidationError(f"unexpected value '{obj}' for enum '{tp}'", object_path) + + # List[T] + elif is_list(tp): + return self._create_list(tp, obj, object_path) + + # Tuple[A,B,C,D,...] + elif is_tuple(tp): + return self._create_tuple(tp, obj, object_path) + + # type of obj and cls type match + elif is_obj_type(obj, tp): + return obj + + # when the specified type is Any, just return the given value + # on mypy version 1.11.0 comparison-overlap error started popping up + # https://github.com/python/mypy/issues/17665 + elif tp == Any: # type: ignore[comparison-overlap] + return obj + + # BaseValueType subclasses + elif inspect.isclass(tp) and issubclass(tp, BaseValueType): + return self.create_value_type_object(tp, obj, object_path) + + # BaseGenericTypeWrapper subclasses + elif is_generic_type_wrapper(tp): + inner_type = get_generic_type_wrapper_argument(tp) + obj_valid = self.map_object(inner_type, obj, object_path) + return tp(obj_valid, object_path=object_path) # type: ignore + + # nested BaseSchema subclasses + elif inspect.isclass(tp) and issubclass(tp, BaseSchema): + return self._create_base_schema_object(tp, obj, object_path) + + # if the object matches, just pass it through + elif inspect.isclass(tp) and isinstance(obj, tp): + return obj + + # default error handler + else: + raise DataValidationError( + f"Type {tp} cannot be parsed. This is a implementation error. " + "Please fix your types in the class or improve the parser/validator.", + object_path, + ) + + def is_obj_type_valid(self, obj: Any, tp: Type[Any]) -> bool: + """ + Runtime type checking. Validate, that a given object is of a given type. 
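The mapping rules above are easiest to see on a few toy calls. A minimal sketch, assuming only that `ObjectMapper` and the modeling exceptions from this patch are importable:

```python
from typing import Dict, List, Optional, Union

from knot_resolver.utils.modeling.base_schema import ObjectMapper
from knot_resolver.utils.modeling.exceptions import DataValidationError

mapper = ObjectMapper()

# Type-driven mapping with runtime checks:
print(mapper.map_object(List[int], [1, 2, 3]))               # [1, 2, 3]
print(mapper.map_object(Optional[str], None))                # None
print(mapper.map_object(Union[int, str], "hello"))           # 'hello' (int variant fails, str succeeds)
print(mapper.map_object(Dict[str, bool], {"dnssec": True}))  # {'dnssec': True}

# Type mismatches are reported together with the path of the offending node:
try:
    mapper.map_object(List[int], [1, "two", 3])
except DataValidationError as e:
    print(e)  # mentions the invalid element, here at path '/[1]'
```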
+ """ + + try: + self.map_object(tp, obj) + return True + except (DataValidationError, ValueError): + return False + + def _assign_default(self, obj: Any, name: str, python_type: Any, object_path: str) -> None: + cls = obj.__class__ + + try: + default = self._create_default(getattr(cls, name, None)) + except ValueError as e: + raise DataValidationError(str(e), f"{object_path}/{name}") from e + + value = self.map_object(python_type, default, object_path=f"{object_path}/{name}") + setattr(obj, name, value) + + def _assign_field(self, obj: Any, name: str, python_type: Any, value: Any, object_path: str) -> None: + value = self.map_object(python_type, value, object_path=f"{object_path}/{name}") + setattr(obj, name, value) + + def _assign_fields(self, obj: Any, source: Union[Dict[str, Any], "BaseSchema", None], object_path: str) -> Set[str]: + """ + Order of assignment: + 1. all direct assignments + 2. assignments with conversion method + """ + cls = obj.__class__ + annot = cls.__dict__.get("__annotations__", {}) + errs: List[DataValidationError] = [] + + used_keys: Set[str] = set() + for name, python_type in annot.items(): + try: + if is_internal_field_name(name): + continue + + # populate field + if source is None: + self._assign_default(obj, name, python_type, object_path) + + # check for invalid configuration with both transformation function and default value + elif hasattr(obj, f"_{name}") and hasattr(obj, name): + raise RuntimeError( + f"Field '{obj.__class__.__name__}.{name}' has default value and transformation function at" + " the same time. That is now allowed. Store the default in the transformation function." + ) + + # there is a transformation function to create the value + elif hasattr(obj, f"_{name}") and callable(getattr(obj, f"_{name}")): + val = self._get_converted_value(obj, name, source, object_path) + self._assign_field(obj, name, python_type, val, object_path) + used_keys.add(name) + + # source just contains the value + elif name in source: + val = source[name] + self._assign_field(obj, name, python_type, val, object_path) + used_keys.add(name) + + # there is a default value, or the type is optional => store the default or null + elif hasattr(obj, name) or is_optional(python_type): + self._assign_default(obj, name, python_type, object_path) + + # we expected a value but it was not there + else: + errs.append(DataValidationError(f"missing attribute '{name}'.", object_path)) + except DataValidationError as e: + errs.append(e) + + if len(errs) == 1: + raise errs[0] + elif len(errs) > 1: + raise AggregateDataValidationError(object_path, errs) + return used_keys + + def _get_converted_value(self, obj: Any, key: str, source: TSource, object_path: str) -> Any: + """ + Get a value of a field by invoking appropriate transformation function. + """ + try: + func = getattr(obj.__class__, f"_{key}") + argc = len(inspect.signature(func).parameters) + if argc == 1: + # it is a static method + return func(source) + elif argc == 2: + # it is a instance method + return func(_create_untouchable("obj"), source) + else: + raise RuntimeError("Transformation function has wrong number of arguments") + except ValueError as e: + if len(e.args) > 0 and isinstance(e.args[0], str): + msg = e.args[0] + else: + msg = "Failed to validate value type" + raise DataValidationError(msg, object_path) from e + + def object_constructor(self, obj: Any, source: Union["BaseSchema", Dict[Any, Any]], object_path: str) -> None: + """ + Delegated constructor for the NoRenameBaseSchema class. 
+ + The reason this method is delegated to the mapper is due to renaming. Like this, we don't have to + worry about a different BaseSchema class, when we want to have dynamically renamed fields. + """ + # As this is a delegated constructor, we must ignore protected access warnings + # pylint: disable=protected-access + + # sanity check + if not isinstance(source, (BaseSchema, dict)): # type: ignore + raise DataValidationError(f"expected dict-like object, found '{type(source)}'", object_path) + + # construct lower level schema first if configured to do so + if obj._LAYER is not None: + source = obj._LAYER(source, object_path=object_path) # pylint: disable=not-callable + + # assign fields + used_keys = self._assign_fields(obj, source, object_path) + + # check for unused keys in the source object + if source and not isinstance(source, BaseSchema): + unused = source.keys() - used_keys + if len(unused) > 0: + keys = ", ".join((f"'{u}'" for u in unused)) + raise DataValidationError( + f"unexpected extra key(s) {keys}", + object_path, + ) + + # validate the constructed value + try: + obj._validate() + except ValueError as e: + raise DataValidationError(e.args[0] if len(e.args) > 0 else "Validation error", object_path or "/") from e + + +class BaseSchema(Serializable): + """ + Base class for modeling configuration schema. It somewhat resembles standard dataclasses with additional + functionality: + + * type validation + * data conversion + + To create an instance of this class, you have to provide source data in the form of dict-like object. + Generally, raw dict or another `BaseSchema` instance. The provided data object is traversed, transformed + and validated before assigned to the appropriate fields (attributes). + + Fields (attributes) + =================== + + The fields (or attributes) of the class are defined the same way as in a dataclass by creating a class-level + type-annotated fields. An example of that is: + + class A(BaseSchema): + awesome_number: int + + If your `BaseSchema` instance has a field with type of a BaseSchema, its value is recursively created + from the nested input data. This way, you can specify a complex tree of BaseSchema's and use the root + BaseSchema to create instance of everything. + + Transformation + ============== + + You can provide the BaseSchema class with a field and a function with the same name, but starting with + underscore ('_'). For example, you could have field called `awesome_number` and function called + `_awesome_number(self, source)`. The function takes one argument - the source data (optionally with self, + but you are not supposed to touch that). It can read any data from the source object and return a value of + an appropriate type, which will be assigned to the field `awesome_number`. If you want to report an error + during validation, raise a `ValueError` exception. + + Using this, you can convert any input values into any type and field you want. To make the conversion easier + to write, you could also specify a special class variable called `_LAYER` pointing to another + BaseSchema class. This causes the source object to be first parsed as the specified additional layer of BaseSchema and after that + used a source for this class. This therefore allows nesting of transformation functions. + + Validation + ========== + + All assignments to fields during object construction are checked at runtime for proper types. 
This means, + you are free to use an untrusted source object and turn it into a data structure, where you are sure what + is what. + + You can also define a `_validate` method, which will be called once the whole data structure is built. You + can validate the data in there and raise a `ValueError`, if they are invalid. + + Default values + ============== + + If you create a field with a value, it will be used as a default value whenever the data in source object + are not present. As a special case, default value for Optional type is None if not specified otherwise. You + are not allowed to have a field with a default value and a transformation function at once. + + Example + ======= + + See tests/utils/test_modelling.py for example usage. + """ + + _LAYER: Optional[Type["BaseSchema"]] = None + _MAPPER: ObjectMapper = ObjectMapper() + + def __init__(self, source: TSource = None, object_path: str = ""): # pylint: disable=[super-init-not-called] + # save source data (and drop information about nullness) + source = source or {} + self.__source: Union[Dict[str, Any], BaseSchema] = source + + # delegate the rest of the constructor + self._MAPPER.object_constructor(self, source, object_path) + + def get_unparsed_data(self) -> Dict[str, Any]: + if isinstance(self.__source, BaseSchema): + return self.__source.get_unparsed_data() + elif isinstance(self.__source, Renamed): + return self.__source.original() + else: + return self.__source + + def __getitem__(self, key: str) -> Any: + if not hasattr(self, key): + raise RuntimeError(f"Object '{self}' of type '{type(self)}' does not have field named '{key}'") + return getattr(self, key) + + def __contains__(self, item: Any) -> bool: + return hasattr(self, item) + + def _validate(self) -> None: + """ + Validation procedure called after all field are assigned. Should throw a ValueError in case of failure. + """ + + def __eq__(self, o: object) -> bool: + cls = self.__class__ + if not isinstance(o, cls): + return False + + annot = cls.__dict__.get("__annotations__", {}) + for name in annot.keys(): + if getattr(self, name) != getattr(o, name): + return False + + return True + + @classmethod + def json_schema(cls: Type["BaseSchema"], include_schema_definition: bool = True) -> Dict[Any, Any]: + if cls._LAYER is not None: + return cls._LAYER.json_schema(include_schema_definition=include_schema_definition) + + schema: Dict[Any, Any] = {} + if include_schema_definition: + schema["$schema"] = "https://json-schema.org/draft/2020-12/schema" + if cls.__doc__ is not None: + schema["description"] = _split_docstring(cls.__doc__)[0] + schema["type"] = "object" + schema["properties"] = _get_properties_schema(cls) + + return schema + + def to_dict(self) -> Dict[Any, Any]: + res: Dict[Any, Any] = {} + cls = self.__class__ + annot = cls.__dict__.get("__annotations__", {}) + + for name in annot: + res[name] = Serializable.serialize(getattr(self, name)) + return res + + +class RenamingObjectMapper(ObjectMapper): + """ + Same as object mapper, but it uses collection wrappers from the module `renamed` to perform dynamic field renaming. 
+ + More specifically: + - it renames all properties in (nested) objects + - it does not rename keys in dictionaries + """ + + def _create_dict(self, tp: Type[Any], obj: Dict[Any, Any], object_path: str) -> Dict[Any, Any]: + if isinstance(obj, Renamed): + obj = obj.original() + return super()._create_dict(tp, obj, object_path) + + def _create_base_schema_object(self, tp: Type[Any], obj: Any, object_path: str) -> "BaseSchema": + if isinstance(obj, dict): + obj = renamed(obj) + return super()._create_base_schema_object(tp, obj, object_path) + + def object_constructor(self, obj: Any, source: Union["BaseSchema", Dict[Any, Any]], object_path: str) -> None: + if isinstance(source, dict): + source = renamed(source) + return super().object_constructor(obj, source, object_path) + + +# export as a standalone functions for simplicity compatibility +is_obj_type_valid = ObjectMapper().is_obj_type_valid +map_object = ObjectMapper().map_object + + +class ConfigSchema(BaseSchema): + """ + Same as BaseSchema, but maps with RenamingObjectMapper + """ + + _MAPPER: ObjectMapper = RenamingObjectMapper() diff --git a/python/knot_resolver/utils/modeling/base_value_type.py b/python/knot_resolver/utils/modeling/base_value_type.py new file mode 100644 index 00000000..dff4a3fe --- /dev/null +++ b/python/knot_resolver/utils/modeling/base_value_type.py @@ -0,0 +1,45 @@ +from abc import ABC, abstractmethod # pylint: disable=[no-name-in-module] +from typing import Any, Dict, Type + + +class BaseTypeABC(ABC): + @abstractmethod + def __init__(self, source_value: Any, object_path: str = "/") -> None: + pass + + @abstractmethod + def __int__(self) -> int: + raise NotImplementedError(f" return 'int()' value for {type(self).__name__} is not implemented.") + + @abstractmethod + def __str__(self) -> str: + raise NotImplementedError(f"return 'str()' value for {type(self).__name__} is not implemented.") + + @abstractmethod + def serialize(self) -> Any: + """ + Used for dumping configuration. Returns a JSON-serializable object from which the object + can be recreated again using the constructor. + + It's not necessary to return the same structure that was given as an input. It only has + to be the same semantically. + """ + raise NotImplementedError(f"{type(self).__name__}'s' 'serialize()' not implemented.") + + +class BaseValueType(BaseTypeABC): + """ + Subclasses of this class can be used as type annotations in 'DataParser'. When a value + is being parsed from a serialized format (e.g. JSON/YAML), an object will be created by + calling the constructor of the appropriate type on the field value. The only limitation + is that the value MUST NOT be `None`. + + There is no validation done on the wrapped value. The only condition is that + it can't be `None`. If you want to perform any validation during creation, + raise a `ValueError` in case of errors. + """ + + @classmethod + @abstractmethod + def json_schema(cls: Type["BaseValueType"]) -> Dict[Any, Any]: + raise NotImplementedError() diff --git a/python/knot_resolver/utils/modeling/exceptions.py b/python/knot_resolver/utils/modeling/exceptions.py new file mode 100644 index 00000000..ea057339 --- /dev/null +++ b/python/knot_resolver/utils/modeling/exceptions.py @@ -0,0 +1,63 @@ +from typing import Iterable, List + +from knot_resolver.manager.exceptions import KresManagerException + + +class DataModelingBaseException(KresManagerException): + """ + Base class for all exceptions used in modelling. 
+ """ + + +class DataParsingError(DataModelingBaseException): + pass + + +class DataDescriptionError(DataModelingBaseException): + pass + + +class DataValidationError(DataModelingBaseException): + def __init__(self, msg: str, tree_path: str, child_exceptions: "Iterable[DataValidationError]" = tuple()) -> None: + super().__init__(msg) + self._tree_path = tree_path.replace("_", "-") + self._child_exceptions = child_exceptions + + def where(self) -> str: + return self._tree_path + + def msg(self): + return f"[{self.where()}] {super().__str__()}" + + def recursive_msg(self, indentation_level: int = 0) -> str: + msg_parts: List[str] = [] + + if indentation_level == 0: + indentation_level += 1 + msg_parts.append("Configuration validation error detected:") + + INDENT = indentation_level * "\t" + msg_parts.append(f"{INDENT}{self.msg()}") + + for c in self._child_exceptions: + msg_parts.append(c.recursive_msg(indentation_level + 1)) + return "\n".join(msg_parts) + + def __str__(self) -> str: + return self.recursive_msg() + + +class AggregateDataValidationError(DataValidationError): + def __init__(self, object_path: str, child_exceptions: "Iterable[DataValidationError]") -> None: + super().__init__("error due to lower level exceptions", object_path, child_exceptions) + + def recursive_msg(self, indentation_level: int = 0) -> str: + inc = 0 + msg_parts: List[str] = [] + if indentation_level == 0: + inc = 1 + msg_parts.append("Configuration validation errors detected:") + + for c in self._child_exceptions: + msg_parts.append(c.recursive_msg(indentation_level + inc)) + return "\n".join(msg_parts) diff --git a/python/knot_resolver/utils/modeling/json_pointer.py b/python/knot_resolver/utils/modeling/json_pointer.py new file mode 100644 index 00000000..a60ba5d1 --- /dev/null +++ b/python/knot_resolver/utils/modeling/json_pointer.py @@ -0,0 +1,88 @@ +""" +Implements JSON pointer resolution based on RFC 6901: +https://www.rfc-editor.org/rfc/rfc6901 +""" + +from typing import Any, Optional, Tuple, Union + +# JSONPtrAddressable = Optional[Union[Dict[str, "JSONPtrAddressable"], List["JSONPtrAddressable"], int, float, bool, str, None]] +JSONPtrAddressable = Any # the recursive definition above is not valid :( + + +class _JSONPtr: + @staticmethod + def _decode_token(token: str) -> str: + """ + Resolves escaped characters ~ and / + """ + + # the order of the replace statements is important, do not change without + # consulting the RFC + return token.replace("~1", "/").replace("~0", "~") + + @staticmethod + def _encode_token(token: str) -> str: + return token.replace("~", "~0").replace("/", "~1") + + def __init__(self, ptr: str): + if ptr == "": + # pointer to the root + self.tokens = [] + + else: + if ptr[0] != "/": + raise SyntaxError( + f"JSON pointer '{ptr}' invalid: the first character MUST be '/' or the pointer must be empty" + ) + + ptr = ptr[1:] + self.tokens = [_JSONPtr._decode_token(tok) for tok in ptr.split("/")] + + def resolve( + self, obj: JSONPtrAddressable + ) -> Tuple[Optional[JSONPtrAddressable], JSONPtrAddressable, Union[str, int, None]]: + """ + Returns (Optional[parent], Optional[direct value], key of value in the parent object) + """ + + parent: Optional[JSONPtrAddressable] = None + current = obj + current_ptr = "" + token: Union[int, str, None] = None + + for token in self.tokens: + if current is None: + raise ValueError( + f"JSON pointer cannot reference nested non-existent object: object at ptr '{current_ptr}' already points to None, cannot nest deeper with token '{token}'" + ) + + elif 
isinstance(current, (bool, int, float, str)): + raise ValueError(f"object at '{current_ptr}' is a scalar, JSON pointer cannot point into it") + + else: + parent = current + if isinstance(current, list): + if token == "-": + current = None + else: + try: + token = int(token) + current = current[token] + except ValueError as e: + raise ValueError( + f"invalid JSON pointer: list '{current_ptr}' require numbers as keys, instead got '{token}'" + ) from e + + elif isinstance(current, dict): + current = current.get(token, None) + + current_ptr += f"/{token}" + + return parent, current, token + + +def json_ptr_resolve( + obj: JSONPtrAddressable, + ptr: str, +) -> Tuple[Optional[JSONPtrAddressable], Optional[JSONPtrAddressable], Union[str, int, None]]: + return _JSONPtr(ptr).resolve(obj) diff --git a/python/knot_resolver/utils/modeling/parsing.py b/python/knot_resolver/utils/modeling/parsing.py new file mode 100644 index 00000000..185a53a1 --- /dev/null +++ b/python/knot_resolver/utils/modeling/parsing.py @@ -0,0 +1,99 @@ +import json +from enum import Enum, auto +from typing import Any, Dict, List, Optional, Tuple, Union + +import yaml +from yaml.constructor import ConstructorError +from yaml.nodes import MappingNode + +from .exceptions import DataParsingError +from .renaming import Renamed, renamed + + +# custom hook for 'json.loads()' to detect duplicate keys in data +# source: https://stackoverflow.com/q/14902299/12858520 +def _json_raise_duplicates(pairs: List[Tuple[Any, Any]]) -> Optional[Any]: + dict_out: Dict[Any, Any] = {} + for key, val in pairs: + if key in dict_out: + raise DataParsingError(f"Duplicate attribute key detected: {key}") + dict_out[key] = val + return dict_out + + +# custom loader for 'yaml.load()' to detect duplicate keys in data +# source: https://gist.github.com/pypt/94d747fe5180851196eb +class _RaiseDuplicatesLoader(yaml.SafeLoader): + def construct_mapping(self, node: Union[MappingNode, Any], deep: bool = False) -> Dict[Any, Any]: + if not isinstance(node, MappingNode): + raise ConstructorError(None, None, f"expected a mapping node, but found {node.id}", node.start_mark) + mapping: Dict[Any, Any] = {} + for key_node, value_node in node.value: + key = self.construct_object(key_node, deep=deep) # type: ignore + # we need to check, that the key object can be used in a hash table + try: + _ = hash(key) # type: ignore + except TypeError as exc: + raise ConstructorError( + "while constructing a mapping", + node.start_mark, + f"found unacceptable key ({exc})", + key_node.start_mark, + ) from exc + + # check for duplicate keys + if key in mapping: + raise DataParsingError(f"duplicate key detected: {key_node.start_mark}") + value = self.construct_object(value_node, deep=deep) # type: ignore + mapping[key] = value + return mapping + + +class DataFormat(Enum): + YAML = auto() + JSON = auto() + + def parse_to_dict(self, text: str) -> Any: + if self is DataFormat.YAML: + # RaiseDuplicatesLoader extends yaml.SafeLoader, so this should be safe + # https://python.land/data-processing/python-yaml#PyYAML_safe_load_vs_load + return renamed(yaml.load(text, Loader=_RaiseDuplicatesLoader)) # type: ignore + elif self is DataFormat.JSON: + return renamed(json.loads(text, object_pairs_hook=_json_raise_duplicates)) + else: + raise NotImplementedError(f"Parsing of format '{self}' is not implemented") + + def dict_dump(self, data: Union[Dict[str, Any], Renamed], indent: Optional[int] = None) -> str: + if isinstance(data, Renamed): + data = data.original() + + if self is DataFormat.YAML: + return 
yaml.safe_dump(data, indent=indent) # type: ignore + elif self is DataFormat.JSON: + return json.dumps(data, indent=indent) + else: + raise NotImplementedError(f"Exporting to '{self}' format is not implemented") + + +def parse_yaml(data: str) -> Any: + return DataFormat.YAML.parse_to_dict(data) + + +def parse_json(data: str) -> Any: + return DataFormat.JSON.parse_to_dict(data) + + +def try_to_parse(data: str) -> Any: + """Attempt to parse the data as a JSON or YAML string.""" + + try: + return parse_json(data) + except json.JSONDecodeError as je: + try: + return parse_yaml(data) + except yaml.YAMLError as ye: + # We do not raise-from here because there are two possible causes + # and we may not know which one is the actual one. + raise DataParsingError( # pylint: disable=raise-missing-from + f"failed to parse data, JSON: {je}, YAML: {ye}" + ) diff --git a/python/knot_resolver/utils/modeling/query.py b/python/knot_resolver/utils/modeling/query.py new file mode 100644 index 00000000..aab0be0e --- /dev/null +++ b/python/knot_resolver/utils/modeling/query.py @@ -0,0 +1,183 @@ +import copy +from abc import ABC, abstractmethod # pylint: disable=[no-name-in-module] +from typing import Any, List, Optional, Tuple, Union + +from typing_extensions import Literal + +from knot_resolver.utils.modeling.base_schema import BaseSchema, map_object +from knot_resolver.utils.modeling.json_pointer import json_ptr_resolve + + +class PatchError(Exception): + pass + + +class Op(BaseSchema, ABC): + @abstractmethod + def eval(self, fakeroot: Any) -> Any: + """ + modifies the given fakeroot, returns a new one + """ + + def _resolve_ptr(self, fakeroot: Any, ptr: str) -> Tuple[Any, Any, Union[str, int, None]]: + # Lookup tree part based on the given JSON pointer + parent, obj, token = json_ptr_resolve(fakeroot["root"], ptr) + + # the lookup was on pure data, wrap the results in QueryTree + if parent is None: + parent = fakeroot + token = "root" + + assert token is not None + + return parent, obj, token + + +class AddOp(Op): + op: Literal["add"] + path: str + value: Any + + def eval(self, fakeroot: Any) -> Any: + parent, _obj, token = self._resolve_ptr(fakeroot, self.path) + + if isinstance(parent, dict): + parent[token] = self.value + elif isinstance(parent, list): + if token == "-": + parent.append(self.value) + else: + assert isinstance(token, int) + parent.insert(token, self.value) + else: + assert False, "never happens" + + return fakeroot + + +class RemoveOp(Op): + op: Literal["remove"] + path: str + + def eval(self, fakeroot: Any) -> Any: + parent, _obj, token = self._resolve_ptr(fakeroot, self.path) + del parent[token] + return fakeroot + + +class ReplaceOp(Op): + op: Literal["replace"] + path: str + value: str + + def eval(self, fakeroot: Any) -> Any: + parent, obj, token = self._resolve_ptr(fakeroot, self.path) + + if obj is None: + raise PatchError("the value you are trying to replace is null") + parent[token] = self.value + return fakeroot + + +class MoveOp(Op): + op: Literal["move"] + source: str + path: str + + def _source(self, source): + if "from" not in source: + raise ValueError("missing property 'from' in 'move' JSON patch operation") + return str(source["from"]) + + def eval(self, fakeroot: Any) -> Any: + if self.path.startswith(self.source): + raise PatchError("can't move value into itself") + + _parent, obj, _token = self._resolve_ptr(fakeroot, self.source) + newobj = copy.deepcopy(obj) + + fakeroot = RemoveOp({"op": "remove", "path": self.source}).eval(fakeroot) + fakeroot = AddOp({"path": 
self.path, "value": newobj, "op": "add"}).eval(fakeroot) + return fakeroot + + +class CopyOp(Op): + op: Literal["copy"] + source: str + path: str + + def _source(self, source): + if "from" not in source: + raise ValueError("missing property 'from' in 'copy' JSON patch operation") + return str(source["from"]) + + def eval(self, fakeroot: Any) -> Any: + _parent, obj, _token = self._resolve_ptr(fakeroot, self.source) + newobj = copy.deepcopy(obj) + + fakeroot = AddOp({"path": self.path, "value": newobj, "op": "add"}).eval(fakeroot) + return fakeroot + + +class TestOp(Op): + op: Literal["test"] + path: str + value: Any + + def eval(self, fakeroot: Any) -> Any: + _parent, obj, _token = self._resolve_ptr(fakeroot, self.path) + + if obj != self.value: + raise PatchError("test failed") + + return fakeroot + + +def query( + original: Any, method: Literal["get", "delete", "put", "patch"], ptr: str, payload: Any +) -> Tuple[Any, Optional[Any]]: + ######################################## + # Prepare data we will be working on + + # First of all, we consider the original data to be immutable. So we need to make a copy + # in order to freely mutate them + dataroot = copy.deepcopy(original) + + # To simplify referencing the root, create a fake root node + fakeroot = {"root": dataroot} + + ######################################### + # Handle the actual requested operation + + # get = return what the path selector picks + if method == "get": + parent, obj, token = json_ptr_resolve(fakeroot, f"/root{ptr}") + return fakeroot["root"], obj + + elif method == "delete": + fakeroot = RemoveOp({"op": "remove", "path": ptr}).eval(fakeroot) + return fakeroot["root"], None + + elif method == "put": + parent, obj, token = json_ptr_resolve(fakeroot, f"/root{ptr}") + assert parent is not None # we know this due to the fakeroot + if isinstance(parent, list) and token == "-": + parent.append(payload) + else: + parent[token] = payload + return fakeroot["root"], None + + elif method == "patch": + tp = List[Union[AddOp, RemoveOp, MoveOp, CopyOp, TestOp, ReplaceOp]] + transaction: tp = map_object(tp, payload) + + for i, op in enumerate(transaction): + try: + fakeroot = op.eval(fakeroot) + except PatchError as e: + raise ValueError(f"json patch transaction failed on step {i}") from e + + return fakeroot["root"], None + + else: + assert False, "invalid operation, never happens" diff --git a/python/knot_resolver/utils/modeling/renaming.py b/python/knot_resolver/utils/modeling/renaming.py new file mode 100644 index 00000000..2420ed04 --- /dev/null +++ b/python/knot_resolver/utils/modeling/renaming.py @@ -0,0 +1,90 @@ +""" +This module implements a standard dict and list alternatives, which can dynamically rename its keys replacing `-` with `_`. +They persist in nested data structes, meaning that if you try to obtain a dict from Renamed variant, you will actually +get RenamedDict back instead. 
+ +Usage: + +d = dict() +l = list() + +rd = renamed(d) +rl = renamed(l) + +assert isinstance(rd, Renamed) == True +assert l = rl.original() +""" + +from abc import ABC, abstractmethod # pylint: disable=[no-name-in-module] +from typing import Any, Dict, List, TypeVar + + +class Renamed(ABC): + @abstractmethod + def original(self) -> Any: + """ + Returns a data structure, which is the source without dynamic renamings + """ + + @staticmethod + def map_public_to_private(name: Any) -> Any: + if isinstance(name, str): + return name.replace("_", "-") + return name + + @staticmethod + def map_private_to_public(name: Any) -> Any: + if isinstance(name, str): + return name.replace("-", "_") + return name + + +K = TypeVar("K") +V = TypeVar("V") + + +class RenamedDict(Dict[K, V], Renamed): + def keys(self) -> Any: + keys = super().keys() + return {Renamed.map_private_to_public(key) for key in keys} + + def __getitem__(self, key: K) -> V: + key = Renamed.map_public_to_private(key) + res = super().__getitem__(key) + return renamed(res) + + def __setitem__(self, key: K, value: V) -> None: + key = Renamed.map_public_to_private(key) + return super().__setitem__(key, value) + + def __contains__(self, key: object) -> bool: + key = Renamed.map_public_to_private(key) + return super().__contains__(key) + + def items(self) -> Any: + for k, v in super().items(): + yield Renamed.map_private_to_public(k), renamed(v) + + def original(self) -> Dict[K, V]: + return dict(super().items()) + + +class RenamedList(List[V], Renamed): # type: ignore + def __getitem__(self, key: Any) -> Any: + res = super().__getitem__(key) + return renamed(res) + + def original(self) -> Any: + return list(super().__iter__()) + + +def renamed(obj: Any) -> Any: + if isinstance(obj, dict): + return RenamedDict(**obj) + elif isinstance(obj, list): + return RenamedList(obj) + else: + return obj + + +__all__ = ["renamed", "Renamed"] diff --git a/python/knot_resolver/utils/modeling/types.py b/python/knot_resolver/utils/modeling/types.py new file mode 100644 index 00000000..4ce9aecc --- /dev/null +++ b/python/knot_resolver/utils/modeling/types.py @@ -0,0 +1,105 @@ +# pylint: disable=comparison-with-callable + + +import enum +import inspect +import sys +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union + +from typing_extensions import Literal + +from .base_generic_type_wrapper import BaseGenericTypeWrapper + +NoneType = type(None) + + +def is_optional(tp: Any) -> bool: + origin = getattr(tp, "__origin__", None) + args = get_generic_type_arguments(tp) + + return origin == Union and len(args) == 2 and args[1] == NoneType # type: ignore + + +def is_dict(tp: Any) -> bool: + return getattr(tp, "__origin__", None) in (Dict, dict) + + +def is_enum(tp: Any) -> bool: + return inspect.isclass(tp) and issubclass(tp, enum.Enum) + + +def is_list(tp: Any) -> bool: + return getattr(tp, "__origin__", None) in (List, list) + + +def is_tuple(tp: Any) -> bool: + return getattr(tp, "__origin__", None) in (Tuple, tuple) + + +def is_union(tp: Any) -> bool: + """Returns true even for optional types, because they are just a Union[T, NoneType]""" + return getattr(tp, "__origin__", None) == Union # type: ignore + + +def is_literal(tp: Any) -> bool: + if sys.version_info.minor == 6: + return isinstance(tp, type(Literal)) + else: + return getattr(tp, "__origin__", None) == Literal + + +def is_generic_type_wrapper(tp: Any) -> bool: + orig = getattr(tp, "__origin__", None) + return inspect.isclass(orig) and issubclass(orig, BaseGenericTypeWrapper) + + +def 
get_generic_type_arguments(tp: Any) -> List[Any]: + default: List[Any] = [] + if sys.version_info.minor == 6 and is_literal(tp): + return getattr(tp, "__values__") + else: + return getattr(tp, "__args__", default) + + +def get_generic_type_argument(tp: Any) -> Any: + """same as function get_generic_type_arguments, but expects just one type argument""" + + args = get_generic_type_arguments(tp) + assert len(args) == 1 + return args[0] + + +def get_generic_type_wrapper_argument(tp: Type["BaseGenericTypeWrapper[Any]"]) -> Any: + assert hasattr(tp, "__origin__") + origin = getattr(tp, "__origin__") + + assert hasattr(origin, "__orig_bases__") + orig_base: List[Any] = getattr(origin, "__orig_bases__", [])[0] + + arg = get_generic_type_argument(tp) + return get_generic_type_argument(orig_base[arg]) + + +def is_none_type(tp: Any) -> bool: + return tp is None or tp == NoneType + + +def get_attr_type(obj: Any, attr_name: str) -> Any: + assert hasattr(obj, attr_name) + assert hasattr(obj, "__annotations__") + annot = getattr(type(obj), "__annotations__") + assert attr_name in annot + return annot[attr_name] + + +T = TypeVar("T") + + +def get_optional_inner_type(optional: Type[Optional[T]]) -> Type[T]: + assert is_optional(optional) + t: Type[T] = get_generic_type_arguments(optional)[0] + return t + + +def is_internal_field_name(field_name: str) -> bool: + return field_name.startswith("_") diff --git a/python/knot_resolver/utils/requests.py b/python/knot_resolver/utils/requests.py new file mode 100644 index 00000000..e52e54a3 --- /dev/null +++ b/python/knot_resolver/utils/requests.py @@ -0,0 +1,135 @@ +import errno +import socket +import sys +from http.client import HTTPConnection +from typing import Any, Optional, Union +from urllib.error import HTTPError, URLError +from urllib.parse import quote, unquote, urlparse +from urllib.request import AbstractHTTPHandler, Request, build_opener, install_opener, urlopen + +from typing_extensions import Literal + + +class SocketDesc: + def __init__(self, socket_def: str, source: str): + self.source = source + if ":" in socket_def: + # `socket_def` contains a schema, probably already URI-formatted, use directly + self.uri = socket_def + else: + # `socket_def` is probably a path, convert to URI + self.uri = f'http+unix://{quote(socket_def, safe="")}' + + while self.uri.endswith("/"): + self.uri = self.uri[:-1] + + +class Response: + def __init__(self, status: int, body: str) -> None: + self.status = status + self.body = body + + def __repr__(self) -> str: + return f"status: {self.status}\nbody:\n{self.body}" + + +def _print_conn_error(error_desc: str, url: str, socket_source: str) -> None: + host: str + try: + parsed_url = urlparse(url) + host = unquote(parsed_url.hostname or "(Unknown)") + except Exception as e: + host = f"(Invalid URL: {e})" + msg = f""" +{error_desc} +\tURL: {url} +\tHost/Path: {host} +\tSourced from: {socket_source} +Is the URL correct? +\tUnix socket would start with http+unix:// and URL encoded path. 
+\tInet sockets would start with http:// and domain or ip + """ + print(msg, file=sys.stderr) + + +def request( + socket_desc: SocketDesc, + method: Literal["GET", "POST", "HEAD", "PUT", "DELETE"], + path: str, + body: Optional[str] = None, + content_type: str = "application/json", +) -> Response: + while path.startswith("/"): + path = path[1:] + url = f"{socket_desc.uri}/{path}" + req = Request( + url, + method=method, + data=body.encode("utf8") if body is not None else None, + headers={"Content-Type": content_type}, + ) + # req.add_header("Authorization", _authorization_header) + + timeout_m = 5 # minutes + try: + with urlopen(req, timeout=timeout_m * 60) as response: + return Response(response.status, response.read().decode("utf8")) + except HTTPError as err: + return Response(err.code, err.read().decode("utf8")) + except URLError as err: + if err.errno == errno.ECONNREFUSED or isinstance(err.reason, ConnectionRefusedError): + _print_conn_error("Connection refused.", url, socket_desc.source) + elif err.errno == errno.ENOENT or isinstance(err.reason, FileNotFoundError): + _print_conn_error("No such file or directory.", url, socket_desc.source) + else: + _print_conn_error(str(err), url, socket_desc.source) + sys.exit(1) + except (TimeoutError, socket.timeout): + _print_conn_error( + f"Connection timed out after {timeout_m} minutes." + "\nIt does not mean that the operation necessarily failed." + "\nSee Knot Resolver's log for more information.", + url, + socket_desc.source, + ) + sys.exit(1) + + +# Code heavily inspired by requests-unixsocket +# https://github.com/msabramo/requests-unixsocket/blob/master/requests_unixsocket/adapters.py +class UnixHTTPConnection(HTTPConnection): + def __init__(self, unix_socket_url: str, timeout: Union[int, float] = 60): + """Create an HTTP connection to a unix domain socket + :param unix_socket_url: A URL with a scheme of 'http+unix' and the + netloc is a percent-encoded path to a unix domain socket. 
E.g.: + 'http+unix://%2Ftmp%2Fprofilesvc.sock/status/pid' + """ + super().__init__("localhost", timeout=timeout) + self.unix_socket_path = unix_socket_url + self.timeout = timeout + self.sock: Optional[socket.socket] = None + + def __del__(self): # base class does not have d'tor + if self.sock: + self.sock.close() + + def connect(self): + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.settimeout(self.timeout) + sock.connect(self.unix_socket_path) + self.sock = sock + + +class UnixHTTPHandler(AbstractHTTPHandler): + def __init__(self) -> None: + super().__init__() + + def open_(self: UnixHTTPHandler, req: Any) -> Any: + return self.do_open(UnixHTTPConnection, req) # type: ignore[arg-type] + + setattr(UnixHTTPHandler, "http+unix_open", open_) + setattr(UnixHTTPHandler, "http+unix_request", AbstractHTTPHandler.do_request_) + + +opener = build_opener(UnixHTTPHandler()) +install_opener(opener) diff --git a/python/knot_resolver/utils/systemd_notify.py b/python/knot_resolver/utils/systemd_notify.py new file mode 100644 index 00000000..44e8dee1 --- /dev/null +++ b/python/knot_resolver/utils/systemd_notify.py @@ -0,0 +1,54 @@ +import enum +import logging +import os +import socket + +logger = logging.getLogger(__name__) + + +class _Status(enum.Enum): + NOT_INITIALIZED = 1 + FUNCTIONAL = 2 + FAILED = 3 + + +_status = _Status.NOT_INITIALIZED +_socket = None + + +def systemd_notify(**values: str) -> None: + global _status + global _socket + + if _status is _Status.NOT_INITIALIZED: + socket_addr = os.getenv("NOTIFY_SOCKET") + os.unsetenv("NOTIFY_SOCKET") + if socket_addr is None: + _status = _Status.FAILED + return + if socket_addr.startswith("@"): + socket_addr = socket_addr.replace("@", "\0", 1) + + try: + _socket = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + _socket.connect(socket_addr) + _status = _Status.FUNCTIONAL + except Exception: + _socket = None + _status = _Status.FAILED + logger.warning(f"Failed to connect to $NOTIFY_SOCKET at '{socket_addr}'", exc_info=True) + return + + elif _status is _Status.FAILED: + return + + if _status is _Status.FUNCTIONAL: + assert _socket is not None + payload = "\n".join((f"{key}={value}" for key, value in values.items())) + try: + _socket.send(payload.encode("utf8")) + except Exception: + logger.warning("Failed to send notification to systemd", exc_info=True) + _status = _Status.FAILED + _socket.close() + _socket = None diff --git a/python/knot_resolver/utils/which.py b/python/knot_resolver/utils/which.py new file mode 100644 index 00000000..450102f3 --- /dev/null +++ b/python/knot_resolver/utils/which.py @@ -0,0 +1,22 @@ +import functools +import os +from pathlib import Path + + +@functools.lru_cache(maxsize=16) +def which(binary_name: str) -> Path: + """ + Given a name of an executable, search $PATH and return + the absolute path of that executable. The results of this function + are LRU cached. + + If not found, throws an RuntimeError. + """ + + possible_directories = os.get_exec_path() + for dr in possible_directories: + p = Path(dr, binary_name) + if p.exists(): + return p.absolute() + + raise RuntimeError(f"Executable {binary_name} was not found in $PATH") |
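
A minimal usage sketch (not taken from the repository) showing how the modeling
helpers above fit together; the input data and the JSON pointer are illustrative only:

    # Parsing wraps the result in a Renamed container, so '-' keys can be read as '_'.
    from knot_resolver.utils.modeling.json_pointer import json_ptr_resolve
    from knot_resolver.utils.modeling.parsing import DataFormat, try_to_parse
    from knot_resolver.utils.modeling.query import query
    from knot_resolver.utils.modeling.renaming import Renamed

    data = try_to_parse('{"some-key": {"nested-list": [1, 2, 3]}}')  # accepts JSON or YAML
    assert isinstance(data, Renamed)
    assert data["some_key"]["nested_list"][0] == 1

    # json_ptr_resolve() implements RFC 6901 and returns (parent, value, token);
    # it operates on plain dicts/lists, so unwrap the Renamed container first.
    _parent, value, token = json_ptr_resolve(data.original(), "/some-key/nested-list/1")
    assert value == 2 and token == 1

    # query() works on a deep copy, so the original document is never mutated.
    doc = {"a": {"b": 1}}
    _, got = query(doc, "get", "/a/b", None)
    changed, _ = query(doc, "put", "/a/b", 2)
    assert got == 1 and changed["a"]["b"] == 2 and doc["a"]["b"] == 1

    # dict_dump() unwraps Renamed containers before serializing.
    print(DataFormat.YAML.dict_dump(data, indent=2))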
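
And a sketch of the HTTP/UNIX-socket client helpers; the socket path, the
"example" source label and the commented-out endpoint are assumptions made for
illustration, not documented defaults:

    from knot_resolver.utils.requests import SocketDesc, request
    from knot_resolver.utils.which import which

    # which() searches $PATH (results are LRU-cached) and raises RuntimeError
    # if the executable cannot be found.
    print(which("sh"))

    # A socket_def without ':' is treated as a filesystem path and converted to
    # an http+unix:// URI with the path percent-encoded, e.g.
    # 'http+unix://%2Frun%2Fknot-resolver%2Fmanager.sock'.
    desc = SocketDesc("/run/knot-resolver/manager.sock", "example")
    print(desc.uri)

    # With a running management API this would issue a request over the socket
    # (the endpoint path here is illustrative):
    # response = request(desc, "GET", "v1/config")
    # print(response.status, response.body)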