#!/usr/bin/env -S uv run --script
|
|
# vim: set filetype=python :
|
|
|
|
# Adapted from: https://enix.io/en/blog/pxe-talos/
|
|
|
|
import base64
import functools
import ipaddress
import json
import pathlib
import sys

import git
import requests
import yaml
from jinja2 import Environment, FileSystemLoader, StrictUndefined, Template
from mergedeep import Strategy, merge
from netaddr import IPAddress
|
|
|
|
# Locate the enclosing git repository by searching upward from this script's
# directory; all project paths below are derived from its working tree root.
REPO = git.Repo(sys.path[0], search_parent_directories=True)
assert REPO.working_dir is not None

ROOT = pathlib.Path(REPO.working_dir)

# Project directory layout (all relative to the repository root):
NODES = ROOT.joinpath("nodes")  # per-node yaml definition files
SCHEMATICS = ROOT.joinpath("schematics")  # schematic files resolved via the !schematic tag
RENDERED = ROOT.joinpath("rendered")  # output directory written by main()

# Jinja2 extensions shared by both environments.
EXTENSIONS = ["jinja2.ext.do"]

# Environment for machine-config patch templates, referenced via the !patch yaml tag.
# StrictUndefined makes any missing variable a hard error instead of silent empty output.
PATCHES = Environment(
    loader=FileSystemLoader(ROOT.joinpath("patches")),
    undefined=StrictUndefined,
    extensions=EXTENSIONS,
)

# Environment for the top-level output templates rendered by main().
TEMPLATES = Environment(
    loader=FileSystemLoader(ROOT.joinpath("templates")),
    undefined=StrictUndefined,
    extensions=EXTENSIONS,
)
|
|
|
|
|
|
# When we try to make a deep copy of the nodes dict it fails as the Template
|
|
# does not implement __deepcopy__, so this wrapper type facilitates that
|
|
class TemplateWrapper:
    """Make a jinja2 Template survive copy.deepcopy.

    Template does not implement __deepcopy__, so deep-copying a nodes dict
    that contains one would fail; wrapping the template in this class keeps
    the dict copyable.
    """

    def __init__(self, template: Template):
        self.template = template

    def __deepcopy__(self, memo):
        # Templates are treated as read-only here, so returning the same
        # instance is an acceptable stand-in for a true deep copy.
        return self
|
|
|
|
|
|
def render_templates(node: dict, args: dict):
    """Build a JSONEncoder subclass that expands TemplateWrapper values.

    The encoder renders each wrapped template with *args* plus the current
    *node*, then parses the rendered text as yaml so the result embeds as
    plain data; anything else falls back to the default encoder behaviour.
    """

    class TemplateEncoder(json.JSONEncoder):
        def default(self, obj):
            if not isinstance(obj, TemplateWrapper):
                return super().default(obj)
            try:
                rendered = obj.template.render(args | {"node": node})
            except Exception as e:
                # Attach the node so rendering failures are traceable
                e.add_note(f"While rendering for: {node['hostname']}")
                raise e
            # The template output is yaml; load it into plain python objects
            return yaml.safe_load(rendered)

    return TemplateEncoder
|
|
|
|
|
|
def tailscale_subnet(gateway: str, netmask: str):
    """Return the CIDR subnet containing *gateway*, e.g. "192.168.1.0/24".

    Args:
        gateway: An address inside the subnet (typically the gateway address).
        netmask: The dotted-quad netmask, e.g. "255.255.255.0".
    """
    # Stdlib ipaddress replaces the netaddr mask-and-count dance: strict=False
    # lets ip_network clear the host bits of the gateway, yielding the network
    # address with its prefix length.
    return str(ipaddress.ip_network(f"{gateway}/{netmask}", strict=False))
|
|
|
|
def load_secret(path: str):
    """Read the file at *path* and return its contents base64-encoded."""
    with open(path) as secret_file:
        contents = secret_file.read()
    return base64.b64encode(contents.encode()).decode()
|
|
|
|
@functools.cache
def get_schematic_id(schematic: str):
    """Return the schematic id the factory assigns to *schematic*.

    Results are cached per schematic body, so identical schematics only
    cost one network round-trip.
    """
    response = requests.post("https://factory.talos.dev/schematics", data=schematic)
    response.raise_for_status()
    return response.json()["id"]
|
|
|
|
|
|
def schematic_constructor(loader: yaml.SafeLoader, node: yaml.nodes.ScalarNode):
    """Yaml constructor for the !schematic tag.

    Resolves the tagged name to "<name>.yaml" under SCHEMATICS and returns
    the schematic id associated with that file's contents.

    Raises:
        yaml.MarkedYAMLError: If the schematic file cannot be read or the id
            lookup fails; the underlying error is chained as the cause.
    """
    schematic_name = loader.construct_yaml_str(node)
    try:
        schematic = SCHEMATICS.joinpath(schematic_name).with_suffix(".yaml").read_text()
        return get_schematic_id(schematic)
    except Exception as e:
        # Chain explicitly so the root cause shows up as the direct cause
        # of the yaml error instead of incidental context.
        raise yaml.MarkedYAMLError("Failed to load schematic", node.start_mark) from e
|
|
|
|
|
|
def template_constructor(environment: Environment):
    """Build a yaml constructor for the !patch tag.

    The returned constructor looks up "<name>.yaml" in *environment* and
    wraps it in a TemplateWrapper so node dicts remain deep-copyable.

    Raises (from the returned constructor):
        yaml.MarkedYAMLError: If the template cannot be loaded; the
            underlying error is chained as the cause.
    """

    def inner(loader: yaml.SafeLoader, node: yaml.nodes.ScalarNode):
        patch_name = loader.construct_scalar(node)
        try:
            template = environment.get_template(f"{patch_name}.yaml")
            return TemplateWrapper(template)
        except Exception as e:
            # Chain explicitly so the root cause shows up as the direct cause
            # of the yaml error instead of incidental context.
            raise yaml.MarkedYAMLError("Failed to load patch", node.start_mark) from e

    return inner
|
|
|
|
|
|
def realpath_constructor(directory: pathlib.Path):
    """Build a yaml constructor for the !realpath tag.

    The returned constructor resolves the tagged path relative to
    *directory* and returns the absolute path as a string; the path must
    exist (resolve is called with strict=True).

    Raises (from the returned constructor):
        yaml.MarkedYAMLError: If the path cannot be resolved; the underlying
            error is chained as the cause.
    """

    def inner(loader: yaml.SafeLoader, node: yaml.nodes.ScalarNode):
        try:
            realpath = directory.joinpath(loader.construct_scalar(node)).resolve(
                strict=True
            )
            return str(realpath)
        except Exception as e:
            # Chain explicitly so the root cause shows up as the direct cause
            # of the yaml error instead of incidental context.
            raise yaml.MarkedYAMLError("Failed to get real path", node.start_mark) from e

    return inner
|
|
|
|
|
|
def get_loader(directory: pathlib.Path):
    """Build a yaml loader with this project's custom tag constructors.

    A fresh SafeLoader subclass is created per call. The original code
    registered the constructors on yaml.SafeLoader itself, which (a) leaked
    the custom tags into every later yaml.safe_load in the process and
    (b) let each call silently overwrite !realpath's captured *directory*
    globally. A per-call subclass keeps the registration isolated.
    """

    class Loader(yaml.SafeLoader):
        pass

    Loader.add_constructor("!realpath", realpath_constructor(directory))
    Loader.add_constructor("!schematic", schematic_constructor)
    Loader.add_constructor("!patch", template_constructor(PATCHES))

    return Loader
|
|
|
|
|
|
@functools.cache
def get_defaults(directory: pathlib.Path, root: pathlib.Path):
    """Compute the defaults from the provided directory and parents."""
    try:
        with open(directory.joinpath("_defaults.yaml")) as fyaml:
            defaults = yaml.load(fyaml, Loader=get_loader(directory))
    except OSError:
        # No _defaults.yaml at this level; it contributes nothing.
        defaults = {}

    # Stop recursion when reaching root directory
    if directory == root:
        return defaults

    # Parent defaults are merged first so entries at this level override them.
    return merge(
        {},
        get_defaults(directory.parent, root),
        defaults,
        strategy=Strategy.TYPESAFE_REPLACE,
    )
|
|
|
|
|
|
def walk_files(root: pathlib.Path):
|
|
"""Get all files that do not start with and underscore"""
|
|
for dirpath, _dirnames, filenames in root.walk():
|
|
for fn in filenames:
|
|
if not fn.startswith("_"):
|
|
yield dirpath.joinpath(fn)
|
|
|
|
|
|
def main():
    """Load node definitions, resolve their templates, and render all outputs."""
    with open(ROOT.joinpath("config.yaml")) as fyaml:
        config = yaml.safe_load(fyaml)

    # Overlay secrets.yaml onto the base config in place (secret values
    # replace config values on conflict).
    with open(ROOT.joinpath("secrets.yaml")) as fyaml:
        merge(config, yaml.safe_load(fyaml), strategy=Strategy.TYPESAFE_REPLACE)

    # Base context available to every template, including helper callables.
    template_args = {
        "config": config,
        "root": ROOT,
        "helper": {"tailscale_subnet": tailscale_subnet, "load_secret": load_secret},
    }

    # Load each node file, layering the _defaults.yaml chain (via get_defaults)
    # underneath the node's own values.
    nodes = []
    for fullname in walk_files(NODES):
        # e.g. nodes/clusterA/host1.yaml -> "clusterA/host1"
        filename = str(fullname.relative_to(NODES).parent) + "/" + fullname.stem

        with open(fullname) as fyaml:
            yml_data = yaml.load(fyaml, Loader=get_loader(fullname.parent))
        yml_data = merge(
            {},
            get_defaults(fullname.parent, NODES),
            yml_data,
            strategy=Strategy.TYPESAFE_REPLACE,
        )
        # The hostname is derived from the node file's own name.
        yml_data["hostname"] = fullname.stem
        yml_data["filename"] = filename
        nodes.append(yml_data)

    # Quick and dirty way to resolve all the templates using a custom encoder:
    # the json round-trip invokes render_templates' encoder on every
    # TemplateWrapper, replacing it with its rendered + parsed yaml value.
    nodes = list(
        map(
            lambda node: json.loads(
                json.dumps(node, cls=render_templates(node, template_args))
            ),
            nodes,
        )
    )

    # HACK: We can't hash a dict, so we first convert it to json, then use set
    # to get all the unique entries, and then convert it back
    # NOTE: This assumes that all nodes in the cluster use the same definition for the cluster
    clusters = list(
        json.loads(cluster)
        for cluster in set(json.dumps(node["cluster"]) for node in nodes)
    )

    # NOTE: dict |= mutates template_args in place; the per-node encoders
    # above have already run, so they are unaffected.
    template_args |= {"nodes": nodes, "clusters": clusters}

    # Render every template into RENDERED, mirroring the template's name.
    RENDERED.mkdir(exist_ok=True)
    for template_name in TEMPLATES.list_templates():
        template = TEMPLATES.get_template(template_name)

        rendered = template.render(template_args)
        with open(RENDERED.joinpath(template_name), "w") as f:
            f.write(rendered)
|
|
|
|
|
|
# Entry point when executed as a script (e.g. via the uv shebang above).
if __name__ == "__main__":
    main()
|