Files
metal/tools/render

183 lines
5.5 KiB
Python
Executable File

#!/usr/bin/env python3
# Adapted from: https://enix.io/en/blog/pxe-talos/
import functools
import json
import pathlib
import sys
import git
import requests
import yaml
from jinja2 import Environment, FileSystemLoader, StrictUndefined, Template
from mergedeep import merge
# Locate the enclosing git repository by searching upwards from sys.path[0]
# (the directory of this file when it is run as a script).
REPO = git.Repo(sys.path[0], search_parent_directories=True)
assert REPO.working_dir is not None
ROOT = pathlib.Path(REPO.working_dir)

# Input and output directories, all directly below the repository root.
NODES = ROOT / "nodes"
SCHEMATICS = ROOT / "schematics"
RENDERED = ROOT / "rendered"

# Jinja2 extensions enabled in both environments below.
EXTENSIONS = ["jinja2.ext.do"]

# Environment used for the templates referenced through the `!patch` YAML tag.
# StrictUndefined turns any reference to a missing variable into an error.
PATCHES = Environment(
    loader=FileSystemLoader(ROOT / "patches"),
    extensions=EXTENSIONS,
    undefined=StrictUndefined,
)

# Environment used for the final output templates rendered by main().
TEMPLATES = Environment(
    loader=FileSystemLoader(ROOT / "templates"),
    extensions=EXTENSIONS,
    undefined=StrictUndefined,
)
def render_templates(node: dict, args: dict):
    """Build a JSON encoder class that expands Jinja2 templates for one node.

    The returned ``json.JSONEncoder`` subclass handles every ``Template``
    object it meets while encoding: the template is rendered with ``args``
    overlaid by ``node`` (node keys win on conflict) and the rendered text is
    parsed as YAML, so the parsed value is what gets encoded.  Any other
    object the stock encoder cannot serialise is delegated to the base class.

    If rendering fails, a note naming ``node['hostname']`` is attached to the
    exception before it is re-raised.
    """

    class TemplateEncoder(json.JSONEncoder):
        def default(self, o):
            # Only templates get special treatment; defer everything else.
            if not isinstance(o, Template):
                return super().default(o)

            try:
                text = o.render(args | node)
            except Exception as err:
                err.add_note(f"While rendering for: {node['hostname']}")
                raise err

            # The rendered template is itself YAML; hand back the parsed value.
            return yaml.safe_load(text)

    return TemplateEncoder
@functools.cache
def get_schematic_id(schematic: str):
    """Look up the schematic id associated with a given schematic.

    The schematic text is POSTed to ``https://factory.talos.dev/schematics``
    and the ``id`` field of the JSON reply is returned.  Results are memoised
    with ``functools.cache``, so identical schematic text triggers at most one
    request per process.

    Raises:
        requests.Timeout: if the service does not answer within the timeout.
        requests.HTTPError: if the service answers with an error status.
        KeyError: if the reply carries no ``id`` field.
    """
    # requests waits forever by default; bound the wait so that a stalled
    # endpoint cannot hang the whole render.
    response = requests.post(
        "https://factory.talos.dev/schematics",
        data=schematic,
        timeout=30,
    )
    response.raise_for_status()
    return response.json()["id"]
def schematic_constructor(loader: yaml.SafeLoader, node: yaml.nodes.ScalarNode):
    """Load the specified schematic file and get the associated schematic id.

    The scalar value of ``node`` names a file inside ``SCHEMATICS``; its
    suffix is set to ``.yaml`` before reading.  The file content is passed to
    ``get_schematic_id`` and the resulting id becomes the YAML value.

    Raises:
        yaml.MarkedYAMLError: if the file cannot be read or the id lookup
            fails.  The error points at the tag's position in the YAML
            document and is chained to the underlying exception.
    """
    schematic_name = loader.construct_yaml_str(node)
    try:
        # NOTE(review): with_suffix() replaces an existing suffix, so a name
        # containing a dot (e.g. "a.b") resolves to "a.yaml" - confirm this
        # is intended.
        schematic = SCHEMATICS.joinpath(schematic_name).with_suffix(".yaml").read_text()
        return get_schematic_id(schematic)
    except Exception as err:
        # Chain explicitly so the real cause (missing file, HTTP error, ...)
        # is reported as the direct cause of the YAML error.
        raise yaml.MarkedYAMLError(
            "Failed to load schematic", node.start_mark
        ) from err
def template_constructor(environment: Environment):
    """Create a YAML constructor that turns a scalar into a Jinja2 template.

    The returned constructor reads the scalar value of the node, appends
    ``.yaml`` and looks that name up in ``environment``.  The resulting
    template object (not its rendered text) becomes the YAML value.

    The constructor raises ``yaml.MarkedYAMLError`` when the template cannot
    be loaded; the error points at the tag's position in the YAML document
    and is chained to the underlying exception.
    """

    def inner(loader: yaml.SafeLoader, node: yaml.nodes.ScalarNode):
        patch_name = loader.construct_scalar(node)
        try:
            return environment.get_template(f"{patch_name}.yaml")
        except Exception as err:
            # Chain explicitly so the original template error stays visible
            # as the direct cause.
            raise yaml.MarkedYAMLError(
                "Failed to load patch", node.start_mark
            ) from err

    return inner
def realpath_constructor(directory: pathlib.Path):
    """Create a YAML constructor that resolves a path relative to ``directory``.

    The returned constructor joins the scalar value of the node onto
    ``directory`` and resolves it strictly (the target must exist).  The
    resolved path is returned as a string.

    The constructor raises ``yaml.MarkedYAMLError`` when the path cannot be
    resolved; the error points at the tag's position in the YAML document and
    is chained to the underlying exception.
    """

    def inner(loader: yaml.SafeLoader, node: yaml.nodes.ScalarNode):
        try:
            relative = loader.construct_scalar(node)
            # strict=True makes resolve() fail for paths that do not exist.
            return str(directory.joinpath(relative).resolve(strict=True))
        except Exception as err:
            # Chain explicitly so the underlying filesystem error stays
            # visible as the direct cause.
            raise yaml.MarkedYAMLError(
                "Failed to get real path", node.start_mark
            ) from err

    return inner
def get_loader(directory: pathlib.Path):
    """Return a YAML loader class with the special constructors added.

    The loader understands three extra tags:

    * ``!realpath``  - path resolved relative to ``directory``
    * ``!schematic`` - schematic file name replaced by its schematic id
    * ``!patch``     - patch name replaced by its template from ``PATCHES``

    A fresh ``yaml.SafeLoader`` subclass is created on every call.
    Registering the constructors on ``yaml.SafeLoader`` itself would modify
    that class globally: the tags would leak into every later
    ``yaml.safe_load`` call, and each call here would overwrite the
    ``!realpath`` directory for all previously returned loaders.
    """

    class Loader(yaml.SafeLoader):
        """SafeLoader extended with the !realpath, !schematic and !patch tags."""

    Loader.add_constructor("!realpath", realpath_constructor(directory))
    Loader.add_constructor("!schematic", schematic_constructor)
    Loader.add_constructor("!patch", template_constructor(PATCHES))
    return Loader
@functools.cache
def get_defaults(directory: pathlib.Path, root: pathlib.Path):
    """Compute the defaults from the provided directory and its parents.

    ``_defaults.yaml`` is read from ``directory`` and from every parent up to
    and including ``root``.  The mappings are merged shallowly with ``|``, so
    values from deeper directories override those from directories closer to
    ``root``.  A directory without a readable ``_defaults.yaml`` (or with an
    empty one) contributes nothing.

    Results are memoised with ``functools.cache``; callers should treat the
    returned mapping as read-only.
    """
    try:
        with open(directory.joinpath("_defaults.yaml")) as fyaml:
            # An empty file parses to None, which would break the `|` merge
            # below; treat it the same as a missing file.
            yml_data = yaml.load(fyaml, Loader=get_loader(directory)) or {}
    except OSError:
        yml_data = {}

    # Stop recursion when reaching root directory
    if directory == root:
        return yml_data
    return get_defaults(directory.parent, root) | yml_data
def walk_files(root: pathlib.Path):
"""Get all files that do not start with and underscore"""
for dirpath, _dirnames, filenames in root.walk():
for fn in filenames:
if not fn.startswith("_"):
yield dirpath.joinpath(fn)
def main():
    """Render every template in ``TEMPLATES`` into the ``RENDERED`` directory.

    Steps:

    1. Load ``config.yaml`` and deep-merge ``secrets.yaml`` into it.
    2. Load every node file below ``NODES`` (files starting with ``_`` are
       skipped), layer it over the directory defaults and record its
       ``hostname`` and ``filename``.
    3. Resolve the templates embedded in each node through a JSON round trip
       with the encoder from ``render_templates``.
    4. Collect the distinct ``cluster`` mappings of all nodes.
    5. Render each template with ``config``, ``root``, ``nodes`` and
       ``clusters`` and write the result under ``RENDERED``.
    """
    with open(ROOT.joinpath("config.yaml")) as fyaml:
        config = yaml.safe_load(fyaml)
    with open(ROOT.joinpath("secrets.yaml")) as fyaml:
        merge(config, yaml.safe_load(fyaml))

    template_args = {"config": config, "root": ROOT}

    nodes = []
    for fullname in walk_files(NODES):
        filename = str(fullname.relative_to(NODES).parent) + "/" + fullname.stem
        with open(fullname) as fyaml:
            # An empty node file parses to None; treat it as an empty mapping
            # so the directory defaults still apply instead of crashing.
            yml_data = yaml.load(fyaml, Loader=get_loader(fullname.parent)) or {}
        yml_data = get_defaults(fullname.parent, NODES) | yml_data
        yml_data["hostname"] = fullname.stem
        yml_data["filename"] = filename
        nodes.append(yml_data)

    # Quick and dirty way to resolve all the templates using a custom encoder
    nodes = [
        json.loads(json.dumps(node, cls=render_templates(node, template_args)))
        for node in nodes
    ]

    # Get all clusters
    # NOTE: This assumes that all nodes in the cluster use the same definition for the cluster
    # Deduplicate through a dict keyed on the cluster's items rather than a
    # set: this keeps first-seen order and the original key order, so the
    # rendered output does not change from run to run with hash randomisation.
    unique_clusters = {}
    for node in nodes:
        cluster = node["cluster"]
        unique_clusters.setdefault(frozenset(cluster.items()), dict(cluster))
    clusters = list(unique_clusters.values())

    template_args |= {"nodes": nodes, "clusters": clusters}

    RENDERED.mkdir(exist_ok=True)
    for template_name in TEMPLATES.list_templates():
        template = TEMPLATES.get_template(template_name)
        rendered = template.render(template_args)
        target = RENDERED.joinpath(template_name)
        # Template names may contain sub-directories; make sure they exist.
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, "w") as f:
            f.write(rendered)


if __name__ == "__main__":
    main()