Skip to content

Commit

Permalink
stage
Browse files Browse the repository at this point in the history
Signed-off-by: Ruoqing He <[email protected]>
  • Loading branch information
RuoqingHe committed Mar 1, 2025
1 parent d871f25 commit c9b17a9
Show file tree
Hide file tree
Showing 5 changed files with 268 additions and 95 deletions.
5 changes: 5 additions & 0 deletions scripts/lib/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Supported architectures (arch used in kernel)
SUPPORT_ARCHS = ["arm64", "x86_64", "riscv"]

# Map arch used in linux kernel to arch understandable for Rust
MAP_RUST_ARCH = {"arm64": "aarch64", "x86_64": "x86_64", "riscv": "riscv64"}
41 changes: 38 additions & 3 deletions scripts/lib/kernel_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,42 @@
import requests
import subprocess
import tempfile
from lib import SUPPORT_ARCHS

KERNEL_ORG_CDN = "https://cdn.kernel.org/pub/linux/kernel"


def prepare_source(args):
check_kernel_version(args.version)

# Create `temp_dir` under `/tmp`
temp_dir = create_temp_dir(args.version)

# Download kernel tarball from https://cdn.kernel.org/
tarball = download_kernel(args.version, temp_dir)

# Extract kernel source
src_dir = extract_kernel(tarball, temp_dir)

# If arch is not provided, install headers for all supported archs
if args.arch is None:
for arch in SUPPORT_ARCHS:
installed_header_path = install_headers(
src_dir=src_dir,
arch=arch,
install_path=args.install_path,
)
else:
installed_header_path = install_headers(
src_dir=src_dir,
arch=args.arch,
install_path=args.install_path,
)

print(f"\nSuccessfully installed kernel headers to {installed_header_path}")
return installed_header_path


def check_kernel_version(version):
"""
Validate if the input kernel version exists in remote. Supports both X.Y
Expand Down Expand Up @@ -96,14 +128,17 @@ def extract_kernel(tarball_path, temp_dir):


def install_headers(src_dir, arch, install_path):
parent_dir = os.path.dirname(src_dir)
# If install_path is not provided, install to parent directory of src_dir to
# prevent messing up with extracted kernel source code
if install_path is None:
install_path = os.path.join(parent_dir, f"{arch}_headers")
install_path = os.path.dirname(src_dir)

try:
os.makedirs(install_path, exist_ok=True)

abs_install_path = os.path.abspath(install_path)
abs_install_path = os.path.abspath(
os.path.join(install_path, f"{arch}_headers")
)
print(f"Installing to {abs_install_path}")
result = subprocess.run(
[
Expand Down
144 changes: 144 additions & 0 deletions scripts/lib/kvm_bindings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import re
import os
import subprocess
from pathlib import Path
from lib.kernel_source import prepare_source
from lib import MAP_RUST_ARCH, SUPPORT_ARCHS


KVM_BINDINGS_DIR = "kvm-bindings/src/"


def generate_kvm_bindings(args):
installed_header_path = prepare_source(args)

# If arch is not provided, install headers for all supported archs
if args.arch is None:
for arch in SUPPORT_ARCHS:
generate_bindings(
installed_header_path, arch, args.attribute, args.output_path
)
else:
generate_bindings(
installed_header_path, args.arch, args.attribute, args.output_path
)


def generate_bindings(
installed_header_path: str, arch: str, attribute: str, output_path: str
):
"""
Generate bindings with source directory support
:param src_dir: Root source directory containing include/ and kvm-bindings/
:param arch: Target architecture (e.g. aarch64, riscv64, x86_64)
:param attribute: Attribute template for custom structs
:raises RuntimeError: If any generation step fails
"""
try:
# Validate source directory structure
arch_headers = os.path.join(installed_header_path, f"{arch}_headers")
kvm_header = Path(os.path.join(arch_headers, f"include/linux/kvm.h"))
if not kvm_header.is_file():
raise FileNotFoundError(f"KVM header missing at {kvm_header}")

arch = MAP_RUST_ARCH[arch]
structs = capture_serde(arch)
if not structs:
raise RuntimeError(
f"No structs found for {arch}, you need to invoke this command under rustvmm/kvm repo root"
)

# Step 2: Build bindgen command with dynamic paths
base_cmd = [
"bindgen",
os.path.abspath(kvm_header), # Use absolute path to header
"--impl-debug",
"--impl-partialeq",
"--with-derive-default",
"--with-derive-partialeq",
]

# Add custom attributes for each struct
for struct in structs:
base_cmd += ["--with-attribute-custom-struct", f"{struct}={attribute}"]

# Add include paths relative to source directory
base_cmd += ["--", f"-I{arch_headers}/include"] # Use absolute include path

# Step 3: Execute command with error handling
print(f"\nGenerating bindings for {arch}...")
bindings = subprocess.run(
base_cmd, check=True, capture_output=True, text=True, encoding="utf-8"
).stdout

print("Successfully generated bindings")

# Generate architecture-specific filename
output_file_path = f"{output_path}/{arch}/bindings.rs"

print(f"Generating to: {output_file_path}")

except subprocess.CalledProcessError as e:
err_msg = f"Bindgen failed (code {e.returncode})"
raise RuntimeError(err_msg) from e
except Exception as e:
raise RuntimeError(f"Generation failed: {str(e)}") from e

try:
with open(output_file_path, "w") as f:
f.write(bindings)

# Format with rustfmt
subprocess.run(["rustfmt", output_file_path], check=True)
print(f"Generation succeeded: {output_file_path}")
except subprocess.CalledProcessError:
raise RuntimeError("rustfmt formatting failed")
except IOError as e:
raise RuntimeError(f"File write error: {str(e)}")


def capture_serde(arch: str) -> list[str]:
"""
Parse serde implementations for specified architecture
:param arch: Architecture name (e.g. aarch64, riscv64, x86_64)
:return: List of found struct names
:raises FileNotFoundError: When target file is missing
:raises ValueError: When serde_impls block is not found
"""
# Build target file path
target_path = Path(f"{KVM_BINDINGS_DIR}/{arch}/serialize.rs")

# Validate file existence
if not target_path.is_file():
raise FileNotFoundError(
f"Serialization file not found for {arch}: {target_path}"
)

print(f"Extracting serde structs of {arch} from: {target_path}")

# Read file content
content = target_path.read_text(encoding="utf-8")

# Multi-line regex pattern to find serde_impls block
pattern = re.compile(
r"serde_impls!\s*\{\s*(?P<struct>.*?)\s*\}", re.DOTALL | re.MULTILINE
)

# Extract struct list from matched block
match = pattern.search(content)
if not match:
raise ValueError(f"No serde_impls! block found in {target_path}")

struct_list = match.group("struct")

structs = []
for line in struct_list.splitlines():
# Split and clean individual words
for word in line.split():
clean_word = word.strip().rstrip(",")
if clean_word:
structs.append(clean_word)

return structs
79 changes: 52 additions & 27 deletions scripts/lib/syscall.py → scripts/lib/seccompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,39 @@
# SPDX-License-Identifier: Apache-2.0

import subprocess
import os
import re
from lib.kernel_source import prepare_source
from lib import MAP_RUST_ARCH, SUPPORT_ARCHS
from pathlib import Path

SECCOMPILER_SYSCALL_DIR = "src/syscall_table"

def generate_syscall_table(file_path):
"""Generate syscall table from specified header file"""
try:
with open(file_path, "r") as f:
syscalls = []
pattern = re.compile(r"^#define __NR_(\w+)\s+(\d+)")

for line in f:
line = line.strip()
if line.startswith("#define __NR_"):
match = pattern.match(line)
if match:
name = match.group(1)
num = int(match.group(2))
syscalls.append((name, num))
def generate_seccompiler(args):
installed_header_path = prepare_source(args)

# Sort alphabetically by syscall name
syscalls.sort(key=lambda x: x[0])
syscall_list = [f'("{name}", {num}),' for name, num in syscalls]
return " ".join(syscall_list)
# If arch is not provided, install headers for all supported archs
if args.arch is None:
for arch in SUPPORT_ARCHS:
generate_rust_code(installed_header_path, arch, args.output_path)
else:
generate_rust_code(installed_header_path, args.arch, args.output_path)

except FileNotFoundError:
raise RuntimeError(f"Header file not found: {file_path}")
except Exception as e:
raise RuntimeError(f"File processing failed: {str(e)}")

def generate_rust_code(installed_header_path: str, arch: str, output_path: str):
# Generate syscall table
arch_headers = os.path.join(installed_header_path, f"{arch}_headers")
syscall_header = Path(os.path.join(arch_headers, f"include/asm/unistd_64.h"))
if not syscall_header.is_file():
raise FileNotFoundError(f"syscall headers missing at {syscall_header}")
syscalls = generate_syscall_table(syscall_header)

arch = MAP_RUST_ARCH[arch]
output_file_path = f"{output_path}/{arch}.rs"

def generate_rust_code(syscalls, output_path):
"""Generate Rust code and format with rustfmt"""
print(f"Generating to: {output_path}")
print(f"Generating to: {output_file_path}")
code = f"""use std::collections::HashMap;
pub(crate) fn make_syscall_table() -> HashMap<&'static str, i64> {{
vec![
Expand All @@ -43,13 +43,38 @@ def generate_rust_code(syscalls, output_path):
}}
"""
try:
with open(output_path, "w") as f:
with open(output_file_path, "w") as f:
f.write(code)

# Format with rustfmt
subprocess.run(["rustfmt", output_path], check=True)
print(f"Generation succeeded: {output_path}")
subprocess.run(["rustfmt", output_file_path], check=True)
print(f"Generation succeeded: {output_file_path}")
except subprocess.CalledProcessError:
raise RuntimeError("rustfmt formatting failed")
except IOError as e:
raise RuntimeError(f"File write error: {str(e)}")


def generate_syscall_table(syscall_header_path: str):
"""Generate syscall table from specified header file"""
try:
with open(syscall_header_path, "r") as f:
syscalls = []
pattern = re.compile(r"^#define __NR_(\w+)\s+(\d+)")

for line in f:
line = line.strip()
if line.startswith("#define __NR_"):
match = pattern.match(line)
if match:
name = match.group(1)
num = int(match.group(2))
syscalls.append((name, num))

# Sort alphabetically by syscall name
syscalls.sort(key=lambda x: x[0])
syscall_list = [f'("{name}", {num}),' for name, num in syscalls]
return " ".join(syscall_list)

except Exception as e:
raise RuntimeError(f"File processing failed: {str(e)}")
Loading

0 comments on commit c9b17a9

Please sign in to comment.