#!/usr/bin/env -S uv run --script
# vim: set filetype=python :
# Adapted from: https://enix.io/en/blog/pxe-talos/
"""Render Talos configuration from per-node YAML definitions.

Every file under ``nodes/`` whose name does not start with ``_`` describes one
node (its hostname is the file stem).  It inherits top-level keys from the
``_defaults.yaml`` files of its own directory and every parent directory up to
``nodes/``.  Templates referenced from node files via ``!patch`` are rendered
per node.  Then each template in ``templates/`` is rendered with the resolved
nodes, their distinct clusters, and ``config.yaml`` deep-merged with
``secrets.yaml``.  The output is written to the same relative path under
``rendered/``.
"""

import functools
import json
import pathlib
import sys

import git
import requests
import yaml
from jinja2 import Environment, FileSystemLoader, StrictUndefined, Template
from mergedeep import merge
from netaddr import IPAddress

REPO = git.Repo(sys.path[0], search_parent_directories=True)
assert REPO.working_dir is not None
ROOT = pathlib.Path(REPO.working_dir)

NODES = ROOT.joinpath("nodes")
SCHEMATICS = ROOT.joinpath("schematics")
RENDERED = ROOT.joinpath("rendered")

SCHEMATIC_FACTORY_URL = "https://factory.talos.dev/schematics"
# Seconds to wait for the schematic factory before giving up (requests has no
# default timeout and would otherwise block forever on a stalled connection).
HTTP_TIMEOUT = 30

EXTENSIONS = ["jinja2.ext.do"]

PATCHES = Environment(
    loader=FileSystemLoader(ROOT.joinpath("patches")),
    undefined=StrictUndefined,
    extensions=EXTENSIONS,
)
TEMPLATES = Environment(
    loader=FileSystemLoader(ROOT.joinpath("templates")),
    undefined=StrictUndefined,
    extensions=EXTENSIONS,
)


def render_templates(node: dict, args: dict) -> type[json.JSONEncoder]:
    """Return a JSON encoder class that renders Jinja templates found in *node*.

    ``!patch`` tags load as :class:`jinja2.Template` objects.  When the encoder
    meets one, it renders it with *args* plus ``node`` and substitutes the YAML
    parsed from the rendered text.  That YAML is parsed with plain
    ``yaml.safe_load``, so the project's custom tags are not available there.
    """
    context = args | {"node": node}

    class Inner(json.JSONEncoder):
        def default(self, o):
            if isinstance(o, Template):
                try:
                    rendered = o.render(context)
                except Exception as e:
                    e.add_note(f"While rendering for: {node['hostname']}")
                    raise
                # Parse the rendered yaml
                return yaml.safe_load(rendered)
            return super().default(o)

    return Inner


def tailscale_subnet(gateway: str, netmask: str) -> str:
    """Return the CIDR network containing *gateway*.

    For example, ``("192.168.1.1", "255.255.255.0")`` gives ``"192.168.1.0/24"``.
    This function is exposed to templates as ``helper.tailscale_subnet``.
    """
    mask = IPAddress(netmask)
    network = IPAddress(gateway) & mask
    return f"{network}/{mask.netmask_bits()}"


@functools.cache
def get_schematic_id(schematic: str) -> str:
    """Look up the schematic id that factory.talos.dev assigns to *schematic*.

    Results are cached, so each distinct schematic text is posted at most once
    per run.

    Raises:
        requests.HTTPError: the factory answered with an error status.
        requests.Timeout: the factory did not answer within HTTP_TIMEOUT.
    """
    r = requests.post(SCHEMATIC_FACTORY_URL, data=schematic, timeout=HTTP_TIMEOUT)
    r.raise_for_status()
    return r.json()["id"]


def schematic_constructor(loader: yaml.SafeLoader, node: yaml.nodes.ScalarNode) -> str:
    """Construct the ``!schematic NAME`` tag.

    Reads ``schematics/NAME.yaml`` and returns the associated schematic id.
    """
    schematic_name = loader.construct_scalar(node)
    try:
        path = SCHEMATICS.joinpath(schematic_name).with_suffix(".yaml")
        schematic = path.read_text(encoding="utf-8")
        return get_schematic_id(schematic)
    except Exception as err:
        raise yaml.MarkedYAMLError("Failed to load schematic", node.start_mark) from err


def template_constructor(environment: Environment):
    """Return a constructor for ``!patch NAME``.

    The constructor loads ``NAME.yaml`` from *environment* as an unrendered
    Jinja template.
    """

    def inner(loader: yaml.SafeLoader, node: yaml.nodes.ScalarNode) -> Template:
        patch_name = loader.construct_scalar(node)
        try:
            return environment.get_template(f"{patch_name}.yaml")
        except Exception as err:
            raise yaml.MarkedYAMLError("Failed to load patch", node.start_mark) from err

    return inner


def realpath_constructor(directory: pathlib.Path):
    """Return a constructor for ``!realpath PATH``.

    PATH is resolved relative to *directory*.  It must exist, because
    ``resolve(strict=True)`` raises otherwise.
    """

    def inner(loader: yaml.SafeLoader, node: yaml.nodes.ScalarNode) -> str:
        try:
            realpath = directory.joinpath(loader.construct_scalar(node)).resolve(
                strict=True
            )
        except Exception as err:
            raise yaml.MarkedYAMLError("Failed to get real path", node.start_mark) from err
        return str(realpath)

    return inner


def get_loader(directory: pathlib.Path) -> type[yaml.SafeLoader]:
    """Return a SafeLoader subclass that understands the project's custom tags.

    The tags are ``!realpath`` (relative to *directory*), ``!schematic`` and
    ``!patch``.  A fresh subclass is built on every call.  Registering the
    constructors on ``yaml.SafeLoader`` itself would leak them into every
    ``yaml.safe_load`` in the process, and would leave ``!realpath`` bound to
    whichever directory was registered last.
    """

    class TagLoader(yaml.SafeLoader):
        pass

    # add_constructor copies the inherited constructor table into the subclass
    # before adding to it, so yaml.SafeLoader itself is left untouched.
    TagLoader.add_constructor("!realpath", realpath_constructor(directory))
    TagLoader.add_constructor("!schematic", schematic_constructor)
    TagLoader.add_constructor("!patch", template_constructor(PATCHES))
    return TagLoader


@functools.cache
def get_defaults(directory: pathlib.Path, root: pathlib.Path) -> dict:
    """Compute the defaults for *directory* from it and its parents up to *root*.

    Each level's ``_defaults.yaml`` is merged shallowly, so deeper directories
    override top-level keys from their parents.  A missing or empty
    ``_defaults.yaml`` contributes nothing.  The result is cached: callers must
    not mutate it.
    """
    try:
        with open(directory.joinpath("_defaults.yaml"), encoding="utf-8") as fyaml:
            # An empty file loads as None; treat it as "no defaults".
            yml_data = yaml.load(fyaml, Loader=get_loader(directory)) or {}
    except FileNotFoundError:
        yml_data = {}

    # Stop recursion at the root directory.  Also stop at the filesystem root,
    # so a directory outside *root* cannot recurse forever.
    if directory == root or directory == directory.parent:
        return yml_data
    return get_defaults(directory.parent, root) | yml_data


def walk_files(root: pathlib.Path):
    """Yield every file below *root* whose name does not start with an underscore.

    Directories and files are visited in sorted order, so the generated output
    is reproducible across runs and filesystems.
    """
    for dirpath, dirnames, filenames in root.walk():
        # Sorting in place steers the remaining top-down walk.
        dirnames.sort()
        for fn in sorted(filenames):
            if not fn.startswith("_"):
                yield dirpath.joinpath(fn)


def _load_config() -> dict:
    """Load ``config.yaml`` and deep-merge ``secrets.yaml`` on top of it."""
    with open(ROOT.joinpath("config.yaml"), encoding="utf-8") as fyaml:
        config = yaml.safe_load(fyaml) or {}
    with open(ROOT.joinpath("secrets.yaml"), encoding="utf-8") as fyaml:
        merge(config, yaml.safe_load(fyaml) or {})
    return config


def _load_node(path: pathlib.Path) -> dict:
    """Load one node file, apply inherited defaults and add hostname/filename."""
    with open(path, encoding="utf-8") as fyaml:
        yml_data = yaml.load(fyaml, Loader=get_loader(path.parent)) or {}
    node = get_defaults(path.parent, NODES) | yml_data
    node["hostname"] = path.stem
    # Path below nodes/ without the extension, e.g. "site/host".
    # A file directly in nodes/ gives "./host".
    node["filename"] = f"{path.relative_to(NODES).parent}/{path.stem}"
    return node


def _resolve_templates(node: dict, args: dict) -> dict:
    """Return a JSON-compatible deep copy of *node* with every ``!patch`` rendered."""
    # Quick and dirty way to resolve all the templates using a custom encoder:
    # json.dumps visits every value and calls the encoder for Template objects.
    return json.loads(json.dumps(node, cls=render_templates(node, args)))


def _unique_clusters(nodes: list[dict]) -> list[dict]:
    """Return the distinct ``cluster`` definitions of *nodes*, in first-seen order."""
    # NOTE: This assumes that all nodes in the cluster use the same definition
    # for the cluster.  Canonical JSON serves as the dedup key because it stays
    # hashable even when a cluster holds lists/dicts.  An insertion-ordered dict
    # (rather than a set) keeps the order stable from run to run.
    keys = dict.fromkeys(json.dumps(node["cluster"], sort_keys=True) for node in nodes)
    return [json.loads(key) for key in keys]


def _write_rendered(template_args: dict) -> None:
    """Render every template in ``templates/`` to the same relative path in ``rendered/``."""
    RENDERED.mkdir(exist_ok=True)
    for template_name in TEMPLATES.list_templates():
        output = RENDERED.joinpath(template_name)
        # Templates may live in sub-directories of templates/.
        output.parent.mkdir(parents=True, exist_ok=True)
        rendered = TEMPLATES.get_template(template_name).render(template_args)
        output.write_text(rendered, encoding="utf-8")


def main():
    template_args = {
        "config": _load_config(),
        "root": ROOT,
        "helper": {"tailscale_subnet": tailscale_subnet},
    }

    nodes = [_load_node(path) for path in walk_files(NODES)]
    nodes = [_resolve_templates(node, template_args) for node in nodes]

    template_args |= {"nodes": nodes, "clusters": _unique_clusters(nodes)}
    _write_rendered(template_args)


if __name__ == "__main__":
    main()