#!/usr/bin/env -S uv run --script
# vim: set filetype=python :
# Adapted from: https://enix.io/en/blog/pxe-talos/
"""Render Talos PXE configuration from the per-node YAML tree.

Every file under ``nodes/`` whose name does not start with ``_`` describes one
node (its hostname is the file stem). ``_defaults.yaml`` files are inherited by
all nodes in the same directory and below. Each node is validated into a
:class:`Node`, which renders its ``!patch`` templates against the node's own
data. Finally every template under ``templates/`` is rendered into
``rendered/``.
"""
import base64
import functools
import pathlib
import sys
from typing import Annotated, Any, List, Literal

import git
import requests
import yaml
from jinja2 import Environment, FileSystemLoader, StrictUndefined, Template
from mergedeep import Strategy, merge
from netaddr import IPAddress
from pydantic import (
    BaseModel,
    BeforeValidator,
    ConfigDict,
    HttpUrl,
    IPvAnyAddress,
    ValidationInfo,
)
from pydantic_extra_types.semantic_version import SemanticVersion

from models import Model as TalosModel

REPO = git.Repo(sys.path[0], search_parent_directories=True)
assert REPO.working_dir is not None
ROOT = pathlib.Path(REPO.working_dir)

NODES = ROOT.joinpath("nodes")
SCHEMATICS = ROOT.joinpath("schematics")
RENDERED = ROOT.joinpath("rendered")

# Seconds to wait for the Talos image factory before giving up, so an
# unreachable factory fails the run instead of hanging it indefinitely.
FACTORY_TIMEOUT = 30

EXTENSIONS = ["jinja2.ext.do"]
PATCHES = Environment(
    loader=FileSystemLoader(ROOT.joinpath("patches")),
    undefined=StrictUndefined,
    extensions=EXTENSIONS,
)
TEMPLATES = Environment(
    loader=FileSystemLoader(ROOT.joinpath("templates")),
    undefined=StrictUndefined,
    extensions=EXTENSIONS,
)


class ServerConfig(BaseModel):
    """The ``server`` section of the merged config.yaml/secrets.yaml."""

    model_config = ConfigDict(strict=True, extra="forbid")

    tftpIp: IPvAnyAddress
    httpUrl: HttpUrl


class TailscaleConfig(BaseModel):
    """The ``tailscale`` section of the merged config.yaml/secrets.yaml."""

    model_config = ConfigDict(strict=True, extra="forbid")

    loginServer: HttpUrl
    authKey: str


class Config(BaseModel):
    """Top-level settings built from config.yaml overlaid with secrets.yaml."""

    model_config = ConfigDict(strict=True, extra="forbid")

    server: ServerConfig
    tailscale: TailscaleConfig


class Cluster(BaseModel):
    """Cluster settings; every node carries a copy in its ``cluster`` field."""

    model_config = ConfigDict(strict=True, extra="forbid")

    name: str
    production: bool
    controlPlaneIp: IPvAnyAddress
    # TODO: Path
    secretsFile: str
    sopsKeyFile: str


# When we try to make a deep copy of the nodes dict it fails as the Template
# does not implement __deepcopy__, so this wrapper type facilitates that
class TemplateWrapper:
    def __init__(self, template: Template):
        self.template = template

    def __deepcopy__(self, memo):
        # NOTE: This is not a true deepcopy, but since we know we won't modify
        # the template this is fine.
        return self


def render_patch(wrapper: Any, info: ValidationInfo):
    """Render a ``!patch`` template for the node being validated and parse it.

    Used as a pydantic ``BeforeValidator`` on ``Node.patches`` and
    ``Node.patchesControlPlane``. The template receives the validation context
    (``template_args`` from ``main``) plus ``node``: the node fields that have
    already been validated.

    Raises:
        ValueError: ``wrapper`` is not a ``TemplateWrapper`` (the list entry was
            not tagged ``!patch``). Pydantic reports this as a ValidationError
            on the offending field.
    """
    if not isinstance(wrapper, TemplateWrapper):
        raise ValueError("Expected TemplateWrapper")
    args = (info.context or {}) | {"node": info.data}
    try:
        rendered = wrapper.template.render(args)
    except Exception as e:
        # info.data only contains fields that validated successfully, so the
        # hostname may be absent; don't replace the real error with a KeyError.
        hostname = args["node"].get("hostname", "<unknown>")
        e.add_note(f"While rendering for: {hostname}")
        raise
    # Parse the rendered yaml
    return yaml.safe_load(rendered)


class Node(BaseModel):
    """A single machine, built from its file under nodes/ plus inherited defaults."""

    model_config = ConfigDict(strict=True, extra="forbid")

    schematicId: str
    arch: Literal["amd64"]
    talosVersion: SemanticVersion
    kubernetesVersion: SemanticVersion
    kernelArgs: List[str]
    extraKernelArgs: List[str]
    dns: List[IPvAnyAddress]
    # TODO: Validation
    ntp: str
    install: bool
    advertiseRoutes: bool
    serial: str
    interface: str
    ip: IPvAnyAddress
    netmask: IPvAnyAddress
    gateway: IPvAnyAddress
    # TODO: Extra validation
    installDisk: str
    autoInstall: bool
    cluster: Cluster
    hostname: str
    filename: str
    type: Literal["controlplane", "worker"]
    # Fields are validated in declaration order, so these patch templates can
    # see every field above through ``node`` (see render_patch).
    patches: List[Annotated[TalosModel, BeforeValidator(render_patch)]]
    patchesControlPlane: List[Annotated[TalosModel, BeforeValidator(render_patch)]]


def tailscale_subnet(gateway: IPvAnyAddress, netmask: IPvAnyAddress):
    """Return the network containing ``gateway`` in CIDR form, e.g. ``10.0.0.0/24``."""
    prefix_len = IPAddress(netmask.exploded).netmask_bits()
    network = IPAddress(gateway.exploded) & IPAddress(netmask.exploded)
    return f"{network}/{prefix_len}"


def load_secret(path: str):
    """Return the contents of the file at ``path`` base64-encoded, as ``str``.

    The file is read as raw bytes so it is encoded verbatim. Text mode would
    rewrite CRLF line endings and reject non-UTF-8 content.
    """
    return base64.b64encode(pathlib.Path(path).read_bytes()).decode()


def model_dump_json(model: BaseModel):
    """Serialize ``model`` to JSON, omitting fields whose value is ``None``."""
    return model.model_dump_json(exclude_none=True)


@functools.cache
def get_schematic_id(schematic: str):
    """Lookup the schematic id associated with a given schematic.

    Posts the schematic to the Talos image factory. The result is cached, so
    each distinct schematic is only posted once per run.

    Raises:
        requests.HTTPError: the factory returned an error status.
        requests.Timeout: no response within ``FACTORY_TIMEOUT`` seconds.
    """
    r = requests.post(
        "https://factory.talos.dev/schematics",
        data=schematic,
        timeout=FACTORY_TIMEOUT,
    )
    r.raise_for_status()
    return r.json()["id"]


def schematic_constructor(loader: yaml.SafeLoader, node: yaml.nodes.ScalarNode):
    """YAML ``!schematic <name>``: load the schematic file and return its id.

    The file read is ``schematics/<name>`` with its suffix replaced by
    ``.yaml``.
    """
    schematic_name = loader.construct_yaml_str(node)
    try:
        schematic = (
            SCHEMATICS.joinpath(schematic_name)
            .with_suffix(".yaml")
            .read_text(encoding="utf-8")
        )
        return get_schematic_id(schematic)
    except Exception as e:
        raise yaml.MarkedYAMLError("Failed to load schematic", node.start_mark) from e


def template_constructor(environment: Environment):
    """Build the YAML ``!patch <name>`` constructor for ``environment``.

    The constructor returns the ``<name>.yaml`` template from ``environment``,
    wrapped in a ``TemplateWrapper``. Rendering happens later, in render_patch.
    """

    def inner(loader: yaml.SafeLoader, node: yaml.nodes.ScalarNode):
        patch_name = loader.construct_scalar(node)
        try:
            template = environment.get_template(f"{patch_name}.yaml")
            return TemplateWrapper(template)
        except Exception as e:
            raise yaml.MarkedYAMLError("Failed to load patch", node.start_mark) from e

    return inner


def realpath_constructor(directory: pathlib.Path):
    """Build the YAML ``!realpath <path>`` constructor.

    The constructor resolves ``path`` relative to ``directory`` to an absolute
    path string. The path must exist.
    """

    def inner(loader: yaml.SafeLoader, node: yaml.nodes.ScalarNode):
        try:
            realpath = directory.joinpath(loader.construct_scalar(node)).resolve(
                strict=True
            )
            return str(realpath)
        except Exception as e:
            raise yaml.MarkedYAMLError(
                "Failed to get real path", node.start_mark
            ) from e

    return inner


def get_loader(directory: pathlib.Path):
    """Return a yaml SafeLoader with the project's custom tags registered.

    ``!realpath`` paths are resolved relative to ``directory``.

    A fresh subclass is created on every call. Calling ``add_constructor`` on
    ``yaml.SafeLoader`` itself would register the tags process-wide (including
    for every ``yaml.safe_load``), and each call would overwrite the previous
    ``!realpath`` base directory.
    """

    class _ProjectLoader(yaml.SafeLoader):
        pass

    _ProjectLoader.add_constructor("!realpath", realpath_constructor(directory))
    _ProjectLoader.add_constructor("!schematic", schematic_constructor)
    _ProjectLoader.add_constructor("!patch", template_constructor(PATCHES))
    return _ProjectLoader


@functools.cache
def get_defaults(directory: pathlib.Path, root: pathlib.Path):
    """Compute the defaults from the provided directory and parents.

    Merges every ``_defaults.yaml`` from ``root`` down to ``directory``. Deeper
    files override their parents (mergedeep ``TYPESAFE_REPLACE``). A missing or
    empty ``_defaults.yaml`` contributes nothing.

    The result is cached and shared between callers. Merge it into a fresh dict
    rather than mutating it.

    Raises:
        ValueError: ``directory`` is not ``root`` or one of its descendants.
    """
    try:
        with open(directory.joinpath("_defaults.yaml"), encoding="utf-8") as fyaml:
            # An empty file loads as None; treat it like a missing one.
            yml_data = yaml.load(fyaml, Loader=get_loader(directory)) or {}
    except OSError:
        yml_data = {}

    # Stop recursion when reaching root directory
    if directory == root:
        return yml_data
    if directory.parent == directory:
        # Reached the filesystem root without meeting ``root``.
        raise ValueError(f"{directory} is not inside {root}")
    return merge(
        {},
        get_defaults(directory.parent, root),
        yml_data,
        strategy=Strategy.TYPESAFE_REPLACE,
    )


def walk_files(root: pathlib.Path):
    """Yield every file under ``root`` whose name does not start with an underscore.

    Directories and files are visited in sorted order. The node list, and
    therefore the rendered output, then stays the same on every run regardless
    of filesystem enumeration order.
    """
    for dirpath, dirnames, filenames in root.walk():
        # Top-down Path.walk honours in-place reordering of dirnames.
        dirnames.sort()
        for fn in sorted(filenames):
            if not fn.startswith("_"):
                yield dirpath.joinpath(fn)


def main():
    """Validate every node under nodes/ and render templates/ into rendered/."""
    with open(ROOT.joinpath("config.yaml"), encoding="utf-8") as fyaml:
        config = yaml.safe_load(fyaml)
    # Values from secrets.yaml override/extend those from config.yaml.
    with open(ROOT.joinpath("secrets.yaml"), encoding="utf-8") as fyaml:
        merge(config, yaml.safe_load(fyaml), strategy=Strategy.TYPESAFE_REPLACE)
    config = Config(**config)

    template_args = {
        "config": config,
        "root": ROOT,
        "helper": {
            "tailscale_subnet": tailscale_subnet,
            "load_secret": load_secret,
            "model_dump_json": model_dump_json,
        },
    }

    nodes: List[Node] = []
    for fullname in walk_files(NODES):
        # e.g. "site/host" for nodes/site/host.yaml; "./host" for nodes/host.yaml
        filename = str(fullname.relative_to(NODES).parent) + "/" + fullname.stem
        with open(fullname, encoding="utf-8") as fyaml:
            yml_data = yaml.load(fyaml, Loader=get_loader(fullname.parent)) or {}
        # Node-specific values override the inherited defaults.
        yml_data = merge(
            {},
            get_defaults(fullname.parent, NODES),
            yml_data,
            strategy=Strategy.TYPESAFE_REPLACE,
        )
        yml_data["hostname"] = fullname.stem
        yml_data["filename"] = filename
        nodes.append(Node.model_validate(yml_data, context=template_args))

    # Deduplicate the clusters referenced by the nodes. Models are not hashable,
    # so key on their JSON. dict.fromkeys keeps first-seen order, whereas a set
    # would iterate in a hash-seed-dependent order that changes between runs.
    # NOTE: This assumes that all nodes in the cluster use the same definition for the cluster
    unique_clusters = dict.fromkeys(node.cluster.model_dump_json() for node in nodes)
    clusters = [Cluster.model_validate_json(cluster) for cluster in unique_clusters]

    template_args |= {"nodes": nodes, "clusters": clusters}

    RENDERED.mkdir(exist_ok=True)
    for template_name in TEMPLATES.list_templates():
        output = RENDERED.joinpath(template_name)
        # Templates may live in subdirectories; mirror them under rendered/.
        output.parent.mkdir(parents=True, exist_ok=True)
        rendered = TEMPLATES.get_template(template_name).render(template_args)
        output.write_text(rendered, encoding="utf-8")


if __name__ == "__main__":
    main()