feature: IANA update
This commit is contained in:
248
src/sslysze_scan/commands/update_iana.py
Normal file
248
src/sslysze_scan/commands/update_iana.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""Update IANA command handler."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from urllib.error import URLError
|
||||
from urllib.request import urlopen
|
||||
|
||||
from ..iana_parser import (
|
||||
extract_updated_date,
|
||||
find_registry,
|
||||
get_table_name_from_filename,
|
||||
parse_xml_with_namespace_support,
|
||||
)
|
||||
from ..iana_validator import (
|
||||
ValidationError,
|
||||
normalize_header,
|
||||
validate_headers,
|
||||
validate_registry_data,
|
||||
)
|
||||
from ..output import print_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def fetch_xml_from_url(url: str, timeout: int = 30) -> str:
    """Fetch XML content from URL.

    Args:
        url: URL to fetch; must use the http or https scheme
        timeout: Timeout in seconds

    Returns:
        XML content as string

    Raises:
        URLError: If the URL uses an unsupported scheme or cannot be fetched

    """
    from urllib.parse import urlparse

    # Reject non-HTTP schemes (e.g. file://) so a malformed or tampered
    # config entry cannot make urlopen read local files. URLError is an
    # OSError subclass, so existing callers catching (URLError, OSError)
    # keep working unchanged.
    scheme = urlparse(url).scheme
    if scheme not in ("http", "https"):
        raise URLError(f"unsupported URL scheme {scheme!r} for {url}")

    logger.info("Fetching %s", url)
    # Context manager guarantees the connection is closed even if
    # reading or decoding fails.
    with urlopen(url, timeout=timeout) as response:
        return response.read().decode("utf-8")
|
||||
|
||||
|
||||
def calculate_diff(
    old_rows: list[tuple],
    new_rows: list[tuple],
    pk_index: int = 0,
) -> dict[str, list]:
    """Calculate diff between old and new data.

    Rows are matched on their primary-key column; keys are reported in
    the order the corresponding rows appear in the inputs.

    Args:
        old_rows: Existing rows from DB
        new_rows: New rows from XML
        pk_index: Index of primary key column

    Returns:
        Dict with 'added', 'deleted', 'modified' lists of primary keys

    """
    previous = {}
    for row in old_rows:
        previous[row[pk_index]] = row
    current = {}
    for row in new_rows:
        current[row[pk_index]] = row

    added = []
    modified = []
    for key, row in current.items():
        if key not in previous:
            added.append(key)
        elif previous[key] != row:
            modified.append(key)
    deleted = [key for key in previous if key not in current]

    return {"added": added, "deleted": deleted, "modified": modified}
|
||||
|
||||
|
||||
def process_registry_with_validation(
    xml_content: str,
    registry_id: str,
    table_name: str,
    headers: list[str],
    db_conn: sqlite3.Connection,
    skip_min_rows_check: bool = False,
) -> tuple[int, dict[str, list]]:
    """Process registry with validation and diff calculation.

    Args:
        xml_content: XML content as string
        registry_id: Registry ID to extract
        table_name: Database table name
        headers: List of column headers
        db_conn: Database connection
        skip_min_rows_check: Skip minimum rows validation (for tests)

    Returns:
        Tuple of (row_count, diff_dict)

    Raises:
        ValidationError: If validation fails
        ValueError: If registry not found

    """
    import tempfile

    from ..iana_parser import extract_field_value, is_unassigned

    # The parser works on file paths, so spill the XML to a temp file and
    # guarantee its removal even if parsing fails.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".xml", delete=False, encoding="utf-8"
    ) as handle:
        handle.write(xml_content)
        xml_path = handle.name

    try:
        root, ns = parse_xml_with_namespace_support(xml_path)
    finally:
        Path(xml_path).unlink()

    validate_headers(table_name, headers, db_conn)

    registry = find_registry(root, registry_id, ns)

    # Namespaced documents need the prefix-qualified tag to match records.
    records = (
        registry.findall("iana:record", ns) if ns else registry.findall("record")
    )

    parsed_rows = []
    for record in records:
        # "Unassigned" placeholder entries carry no data worth storing.
        if is_unassigned(record, ns):
            continue
        parsed_rows.append(
            {
                normalize_header(column): extract_field_value(record, column, ns)
                for column in headers
            }
        )

    validate_registry_data(table_name, parsed_rows, skip_min_rows_check)

    # Dicts preserve insertion order, so values line up with `headers`.
    rows = [tuple(entry.values()) for entry in parsed_rows]

    cursor = db_conn.cursor()
    old_rows = cursor.execute(f"SELECT * FROM {table_name}").fetchall()
    diff = calculate_diff(old_rows, rows)

    # Full replace: wipe the table and re-insert every row from the XML.
    placeholders = ",".join(["?"] * len(headers))
    cursor.execute(f"DELETE FROM {table_name}")
    cursor.executemany(f"INSERT INTO {table_name} VALUES ({placeholders})", rows)

    return len(rows), diff
|
||||
|
||||
|
||||
def handle_update_iana_command(args: argparse.Namespace) -> int:
    """Handle the update-iana subcommand.

    Fetches the configured IANA registry XML documents, validates each
    registry, and replaces the corresponding database tables inside a
    single transaction (all-or-nothing).

    Args:
        args: Parsed arguments; must provide a ``database`` path attribute

    Returns:
        Exit code (0 for success, 1 for error)

    """
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    db_path = args.database

    if not Path(db_path).exists():
        print_error(f"Database not found: {db_path}")
        return 1

    # Configuration ships in the package's data directory, one level above
    # this commands module.
    script_dir = Path(__file__).parent.parent
    config_path = script_dir / "data" / "iana_parse.json"

    logger.info("Loading configuration from %s", config_path)

    try:
        with config_path.open(encoding="utf-8") as f:
            config = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError, OSError) as e:
        print_error(f"Error loading configuration: {e}")
        return 1

    try:
        conn = sqlite3.connect(str(db_path))
    except sqlite3.Error as e:
        print_error(f"Error opening database: {e}")
        return 1

    logger.info("Starting IANA registry update")

    try:
        # One transaction for every registry: either all tables update or
        # none do.
        conn.execute("BEGIN TRANSACTION")

        total_registries = 0
        total_rows = 0

        for url, registries in config.items():
            try:
                xml_content = fetch_xml_from_url(url)
            except (URLError, OSError) as e:
                print_error(f"Failed to fetch {url}: {e}")
                conn.rollback()
                # No explicit close: returning from inside `try` still runs
                # the `finally` below, which is the single close point.
                return 1

            xml_date = extract_updated_date(xml_content)
            logger.info("XML data date: %s", xml_date)

            for registry_id, output_filename, headers in registries:
                table_name = get_table_name_from_filename(output_filename)

                try:
                    row_count, diff = process_registry_with_validation(
                        xml_content, registry_id, table_name, headers, conn
                    )

                    logger.info(
                        "%s: %s rows (%s added, %s modified, %s deleted)",
                        table_name,
                        row_count,
                        len(diff["added"]),
                        len(diff["modified"]),
                        len(diff["deleted"]),
                    )

                    total_registries += 1
                    total_rows += row_count

                except (ValidationError, ValueError) as e:
                    print_error(
                        f"Validation failed for {table_name}: {e}\n"
                        f"IANA data structure may have changed. "
                        f"Please open an issue at the project repository."
                    )
                    conn.rollback()
                    return 1

        conn.commit()
        logger.info(
            "Successfully updated %s registries (%s total rows)",
            total_registries,
            total_rows,
        )

    except sqlite3.Error as e:
        print_error(f"Database error: {e}")
        conn.rollback()
        return 1

    finally:
        # Single close point for every exit path (success, early returns,
        # and database errors alike); closing twice was redundant before.
        conn.close()

    return 0
|
||||
Reference in New Issue
Block a user