249 lines
6.8 KiB
Python
249 lines
6.8 KiB
Python
"""Update IANA command handler."""
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import sqlite3
|
|
from pathlib import Path
|
|
from urllib.error import URLError
|
|
from urllib.request import urlopen
|
|
|
|
from ..iana_parser import (
|
|
extract_updated_date,
|
|
find_registry,
|
|
get_table_name_from_filename,
|
|
parse_xml_with_namespace_support,
|
|
)
|
|
from ..iana_validator import (
|
|
ValidationError,
|
|
normalize_header,
|
|
validate_headers,
|
|
validate_registry_data,
|
|
)
|
|
from ..output import print_error
|
|
|
|
# Module-scoped logger; records are emitted under this module's dotted name.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def fetch_xml_from_url(url: str, timeout: int = 30) -> str:
    """Download ``url`` and return its body decoded as UTF-8 text.

    Args:
        url: Address of the XML document to download
        timeout: Timeout in seconds, forwarded to ``urlopen``

    Returns:
        The full response body as a string

    Raises:
        URLError: If the URL cannot be fetched (``HTTPError``, a subclass,
            for HTTP error statuses)
        UnicodeDecodeError: If the body is not valid UTF-8
    """
    logger.info(f"Fetching {url}")

    # Read the raw bytes while the connection is open; decoding happens
    # after the response has been closed by the context manager.
    with urlopen(url, timeout=timeout) as response:
        payload = response.read()

    return payload.decode("utf-8")
|
|
|
|
|
|
def calculate_diff(
    old_rows: list[tuple],
    new_rows: list[tuple],
    pk_index: int = 0,
) -> dict[str, list]:
    """Compare two row sets keyed by one column and report what changed.

    Rows are matched on the value at position ``pk_index``. If a key occurs
    more than once in an input, the last row with that key is the one
    compared.

    Args:
        old_rows: Rows currently stored in the database
        new_rows: Rows freshly extracted from the XML source
        pk_index: Position of the primary-key value inside each row

    Returns:
        Dict with three lists of primary keys:
        ``'added'`` (only in new rows, in new-row order),
        ``'deleted'`` (only in old rows, in old-row order) and
        ``'modified'`` (in both but with differing rows, in new-row order).
    """

    def _index(rows: list[tuple]) -> dict:
        # Map each key to its row; later duplicates overwrite earlier ones.
        by_key = {}
        for row in rows:
            by_key[row[pk_index]] = row
        return by_key

    before = _index(old_rows)
    after = _index(new_rows)

    added = []
    modified = []
    for key, row in after.items():
        if key not in before:
            added.append(key)
        elif before[key] != row:
            modified.append(key)

    deleted = [key for key in before if key not in after]

    return {"added": added, "deleted": deleted, "modified": modified}
|
|
|
|
|
|
def _parse_xml_string(xml_content: str):
    """Parse in-memory XML with the path-based ``parse_xml_with_namespace_support``.

    The parser is called with a file path, so the content is spilled to a
    temporary file. The file is always removed, including when writing or
    parsing fails.

    Returns:
        The value returned by ``parse_xml_with_namespace_support``; callers
        unpack it as ``(root, ns)``.
    """
    import tempfile

    tmp_file = tempfile.NamedTemporaryFile(
        mode="w", suffix=".xml", delete=False, encoding="utf-8"
    )
    tmp_name = tmp_file.name
    try:
        # Close (and so flush) the file before handing its path to the parser.
        with tmp_file:
            tmp_file.write(xml_content)
        return parse_xml_with_namespace_support(tmp_name)
    finally:
        Path(tmp_name).unlink(missing_ok=True)


def process_registry_with_validation(
    xml_content: str,
    registry_id: str,
    table_name: str,
    headers: list[str],
    db_conn: sqlite3.Connection,
    skip_min_rows_check: bool = False,
) -> tuple[int, dict[str, list]]:
    """Replace the contents of ``table_name`` with one registry from ``xml_content``.

    Records flagged by ``is_unassigned`` are skipped. The extracted rows are
    validated before the table is touched. The table is then cleared and
    re-filled. This function does not commit; the caller owns the
    transaction.

    Args:
        xml_content: XML content as string
        registry_id: Registry ID to extract
        table_name: Database table name
        headers: List of column headers, in table column order
        db_conn: Database connection
        skip_min_rows_check: Skip minimum rows validation (for tests)

    Returns:
        Tuple of (row_count, diff_dict), where diff_dict is the result of
        ``calculate_diff`` between the previous and the new table contents.

    Raises:
        ValidationError: If header or data validation fails
        ValueError: If the registry is not found
        sqlite3.Error: On database failures
    """
    from ..iana_parser import extract_field_value, is_unassigned

    root, ns = _parse_xml_string(xml_content)

    validate_headers(table_name, headers, db_conn)

    registry = find_registry(root, registry_id, ns)
    if registry is None:
        # Guard so a missing registry surfaces as the documented ValueError
        # rather than an AttributeError on ``findall``.
        raise ValueError(f"Registry {registry_id!r} not found")

    if ns:
        records = registry.findall("iana:record", ns)
    else:
        records = registry.findall("record")

    # Normalised keys depend only on the headers, so compute them once.
    keys = [normalize_header(header) for header in headers]

    rows: list[tuple] = []
    rows_dict: list[dict] = []
    for record in records:
        if is_unassigned(record, ns):
            continue
        # Build the INSERT tuple directly in header order, so its width
        # always matches the placeholders even if two headers normalise to
        # the same key.
        values = tuple(extract_field_value(record, header, ns) for header in headers)
        rows.append(values)
        rows_dict.append(dict(zip(keys, values)))

    validate_registry_data(table_name, rows_dict, skip_min_rows_check)

    # NOTE(review): old rows come back with SQLite's stored types, while new
    # values are whatever extract_field_value returns. If a column has numeric
    # affinity, unchanged rows may be reported as modified. Confirm the schema.
    old_rows = db_conn.execute(f"SELECT * FROM {table_name}").fetchall()
    diff = calculate_diff(old_rows, rows)

    # Positional INSERT: assumes the table's column order matches ``headers``
    # (presumably enforced by validate_headers; verify).
    placeholders = ",".join(["?"] * len(headers))
    db_conn.execute(f"DELETE FROM {table_name}")
    db_conn.executemany(f"INSERT INTO {table_name} VALUES ({placeholders})", rows)

    return len(rows), diff
|
|
|
|
|
|
def _update_registries(
    conn: sqlite3.Connection, config: dict
) -> "tuple[int, int] | None":
    """Refresh every registry listed in ``config`` on ``conn`` without committing.

    Args:
        conn: Database connection with a transaction already begun
        config: Mapping of source URL -> list of
            ``[registry_id, output_filename, headers]`` entries

    Returns:
        Tuple of (registries_updated, total_rows) on success, or None after
        an error has been printed. On None, the caller must roll back.

    Raises:
        sqlite3.Error: Propagated from database operations
    """
    total_registries = 0
    total_rows = 0

    for url, registries in config.items():
        try:
            xml_content = fetch_xml_from_url(url)
        except (URLError, OSError, UnicodeDecodeError) as e:
            # UnicodeDecodeError: the body is decoded as UTF-8, and a bad
            # payload should be reported cleanly instead of crashing.
            print_error(f"Failed to fetch {url}: {e}")
            return None

        xml_date = extract_updated_date(xml_content)
        logger.info("XML data date: %s", xml_date)

        for registry_id, output_filename, headers in registries:
            table_name = get_table_name_from_filename(output_filename)

            try:
                row_count, diff = process_registry_with_validation(
                    xml_content, registry_id, table_name, headers, conn
                )
            except (ValidationError, ValueError) as e:
                print_error(
                    f"Validation failed for {table_name}: {e}\n"
                    "IANA data structure may have changed. "
                    "Please open an issue at the project repository."
                )
                return None

            logger.info(
                "%s: %s rows (%d added, %d modified, %d deleted)",
                table_name,
                row_count,
                len(diff["added"]),
                len(diff["modified"]),
                len(diff["deleted"]),
            )
            total_registries += 1
            total_rows += row_count

    return total_registries, total_rows


def handle_update_iana_command(args: argparse.Namespace) -> int:
    """Handle the update-iana subcommand.

    Loads ``data/iana_parse.json`` from the package directory, downloads each
    configured IANA XML document, and rewrites the mapped tables in the
    existing SQLite database ``args.database``. All tables are updated in one
    transaction: any fetch, validation or database failure rolls back every
    change.

    Args:
        args: Parsed arguments; ``args.database`` is the database path

    Returns:
        Exit code (0 for success, 1 for error)
    """
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    db_path = args.database

    # Checked up front because sqlite3.connect would silently create a new,
    # empty database file instead of failing.
    if not Path(db_path).exists():
        print_error(f"Database not found: {db_path}")
        return 1

    config_path = Path(__file__).parent.parent / "data" / "iana_parse.json"
    logger.info("Loading configuration from %s", config_path)

    try:
        with config_path.open(encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        # OSError covers a missing or unreadable file; ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError.
        print_error(f"Error loading configuration: {e}")
        return 1

    try:
        conn = sqlite3.connect(str(db_path))
    except sqlite3.Error as e:
        print_error(f"Error opening database: {e}")
        return 1

    logger.info("Starting IANA registry update")

    try:
        conn.execute("BEGIN TRANSACTION")

        totals = _update_registries(conn, config)
        if totals is None:
            # Error already reported; the rollback stays inside this try so a
            # failing rollback is still handled as a database error below.
            conn.rollback()
            return 1

        conn.commit()
        total_registries, total_rows = totals
        logger.info(
            "Successfully updated %s registries (%s total rows)",
            total_registries,
            total_rows,
        )
    except sqlite3.Error as e:
        print_error(f"Database error: {e}")
        conn.rollback()
        return 1
    finally:
        # Single close point for every exit path.
        conn.close()

    return 0
|