'''
XML parsing and SQLite database utilities for ClinVar Build.
This module provides classes and functions for parsing large XML clinvar files
and loading them into SQLite databases. It includes utilities for database
validation, progress tracking, error formatting, and XML inspection.
'''
import gzip
import sqlite3
import logging
import warnings
import xml.etree.ElementTree as ET
from pathlib import Path
from collections import Counter
from typing import (
Any,
Iterator,
)
from clinvar_build.errors import (
is_type,
XMLValidationError,
)
from contextlib import contextmanager
from clinvar_build.errors import (
is_type,
)
from clinvar_build.utils.config_tools import (
ProgressHandler,
)
from lxml import etree
# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
LXML_ERROR = 'expected'
WARN1 = "XML file is valid (ignoring XML elements not in the supplied XSD)."
logger = logging.getLogger(__name__)
[docs]
class SQLiteParser(object):
    """
    A general SQLite parser.

    Parameters
    ----------
    config : dict[str, Any]
        A dictionary with instructions to parse an XML file to a SQLite
        database.

    Attributes
    ----------
    conn : `sqlite3.Connection` or `None`
        Active SQLite connection, or None if not connected.
    cursor : `sqlite3.Cursor` or `None`
        Cursor for executing SQL statements, or None if not connected.
    stats : `dict` [`str`, `int`]
        Statistics on parsed records.
    config : `dict` [`str`, `Any`]
        A configuration dictionary with parsing instructions.
    """
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # class attributes
    _LINE_LEN = 70  # width of the '=' separator rules written to the log
    # Shared progress logger/handler: created once on first instantiation
    # and reused by every instance (see __init__).
    _progress_logger = None
    _progress_handler = None
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
def __init__(self, config: dict[str, Any] | None = None):
    """
    Initialize the parser.

    Parameters
    ----------
    config : dict[str, Any] or None, default None
        Parsing instructions; an empty dict is used when None.
    """
    is_type(config, (type(None), dict))
    # Normalise the config and seed a zero insert-counter per table key.
    config = {} if config is None else config
    self.conn = None
    self.cursor = None
    self.stats = dict.fromkeys(config, 0)
    self.config = config
    # Create the class-wide progress logger exactly once; all instances
    # share one handler so in-place progress updates do not duplicate.
    cls = type(self)
    if cls._progress_logger is None:
        progress = logging.getLogger('parser.progress')
        progress.setLevel(logging.INFO)
        handler = ProgressHandler()
        handler.setFormatter(
            logging.Formatter(
                '%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S'
            )
        )
        progress.addHandler(handler)
        # Keep progress output away from the root logger.
        progress.propagate = False
        cls._progress_logger = progress
        cls._progress_handler = handler
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
def __repr__(self) -> str:
"""
Return unambiguous string representation.
Returns
-------
str
String representation suitable for debugging
"""
connection_status = (
f"connected to {self.conn}"
if self.conn
else "not connected"
)
return f"{type(self).__name__}({connection_status})"
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
def __str__(self) -> str:
"""
Return human-readable string representation.
Returns
-------
str
Human-readable description of the parser
"""
if self.conn:
return f"{type(self).__name__} (active connection)"
config_count = len([k for k in self.config.keys()
if not k.startswith('_')])
return (
f"{type(self).__name__}\n"
f"SQLite parser configured for {config_count} tables."
)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
def validate_database(self) -> dict[str, Any]:
    """
    Run comprehensive database validation checks.

    Performs multiple validation checks including foreign key integrity,
    table row counts, and basic statistics.

    Returns
    -------
    dict [str, Any]
        Dictionary containing validation results and statistics

    Raises
    ------
    ValueError
        If foreign key violations are found

    Examples
    --------
    >>> with parser._connection(db_path):
    ...     results = parser.validate_database()
    >>> print(results)
    {'foreign_keys': 'valid', 'total_records': 12345, ...}
    """
    results = {}
    # hoist the separator rule instead of rebuilding it six times
    separator = "=" * type(self)._LINE_LEN
    # #### STEP 1: foreign key check (raises ValueError on violations)
    logger.info(" ")
    logger.info(separator)
    logger.info("STEP 1/2: Foreign Key Integrity Check")
    logger.info(separator)
    self._validate_foreign_keys()
    results['foreign_keys'] = 'valid'
    # #### STEP 2: count records in each table
    logger.info(" ")
    logger.info(separator)
    logger.info("STEP 2/2: Counting Records per Table")
    logger.info(separator)
    table_counts = self._count_table_records()
    results['table_counts'] = table_counts
    results['total_records'] = sum(table_counts.values())
    # #### summary logging
    logger.info(" ")
    logger.info(separator)
    # was an f-string with no placeholders
    logger.info("✓ Database validation complete")
    logger.info(
        f" Total records across all tables: {results['total_records']:,}"
    )
    logger.info(separator)
    return results
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
def count_duplicates(self):
"""
Count duplicate rows in all tables based on non-id columns.
For each table in the configuration, identifies duplicate
rows where all columns (except the primary key 'id') are identical,
treating NULL values as equal.
Returns
-------
dict [str, int]
Dictionary mapping table names to number of duplicates found
Notes
-----
Uses SQLite's rowid to identify duplicates, counting rows that would
be removed (keeping the row with lowest rowid). NULL values are
treated as equal via GROUP BY.
"""
logger.info("")
logger.info("=" * type(self)._LINE_LEN)
logger.info("DUPLICATE CHECK: Counting duplicate rows")
logger.info("=" * type(self)._LINE_LEN)
# Init
duplicates_found = {}
# loop over tables
for table in [k for k in self.config.keys() if not k.startswith('_')]:
try:
self.cursor.execute(f"PRAGMA table_info({table})")
cols = ','.join(
[r[1] for r in self.cursor.fetchall() if r[1] != 'id']
)
if cols:
# count rows that would be deleted (not the keeper rows)
self.cursor.execute(f"""
SELECT COUNT(*) FROM {table}
WHERE rowid NOT IN (
SELECT MIN(rowid) FROM {table} GROUP BY {cols}
)
""")
duplicates_found[table] = self.cursor.fetchone()[0]
except sqlite3.OperationalError:
duplicates_found[table] = 0
# log only tables with duplicates
for table, count in sorted(duplicates_found.items()):
if count > 0:
logger.info(f" {table}: {count:,} duplicates")
total = sum(duplicates_found.values())
logger.info("")
logger.info(f"Found {total:,} duplicates across "
f"{sum(1 for v in duplicates_found.values() if v > 0)} tables")
logger.info("")
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def _connect(self, db_path: str | Path) -> None:
"""
Establish database connection.
Parameters
----------
db_path : str or Path
Path to SQLite database file
"""
self.conn = sqlite3.connect(db_path)
self.cursor = self.conn.cursor()
# NOTE: enable foreign key enforcement: prevents inserting records with
# non-existent parent IDs, cascades deletes to child records, and
# maintains referential integrity
self.cursor.execute("PRAGMA foreign_keys = ON")
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def _close(self) -> None:
"""Close database connection."""
if self.conn:
self.conn.commit()
self.conn.close()
self.conn = None
self.cursor = None
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@contextmanager
def _connection(self, db_path: str | Path) -> Iterator[None]:
"""
Context manager for SQLite database connection.
Establishes a connection to the SQLite database at `db_path`
and ensures the connection is properly closed after use, even
if an exception occurs during operations. Any uncommitted
transactions are rolled back on error or interrupt.
Parameters
----------
db_path : str or Path
Path to the SQLite database file
Yields
------
None
Provides a context in which the database connection is active.
`self.conn` and `self.cursor` are available for executing SQL
statements within the context
"""
self._connect(db_path)
try:
yield
except (Exception, KeyboardInterrupt):
if self.conn:
self.conn.rollback()
logger.info("Rolled back uncommitted changes")
raise
finally:
self._close()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def _validate_foreign_keys(self) -> None:
"""
Validate all foreign key constraints after parsing.
Checks every foreign key relationship in the database to ensure
referential integrity. Called after re-enabling foreign key
enforcement following parsing.
Raises
------
ValueError
If any foreign key violations are found
"""
logger.info("Running foreign key integrity check...")
self.cursor.execute("PRAGMA foreign_key_check")
violations = self.cursor.fetchall()
if violations:
logger.error(f"Found {len(violations)} foreign key violations!")
logger.error("First 10 violations:")
for i, violation in enumerate(violations[:10], 1):
# violation format: (table, rowid, parent_table, fkid)
logger.error(
f" {i}. Table '{violation[0]}' row {violation[1]} "
f"references non-existent parent in '{violation[2]}'"
)
raise ValueError(
f"Foreign key integrity check failed! "
f"Found {len(violations)} violations."
)
logger.info("✓ Foreign key integrity validated - all constraints satisfied")
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def _count_table_records(self) -> dict[str, int]:
"""
Count records in all configured tables.
Iterates through all tables in the configuration and counts
the number of records in each. Updates the parser's stats
dictionary with the counts.
Returns
-------
dict [str, int]
Dictionary mapping table names to record counts
Notes
-----
Tables that do not exist will have a count of 0. The method
also updates `self.stats` with the counts for later reporting.
"""
table_counts = {}
# filter out meta keys
config_tables = [
k for k in self.config.keys() if not k.startswith('_')
]
total_tables = len(config_tables)
# loop over tables
for idx, table_name in enumerate(config_tables, 1):
try:
logger.info(
f" [{idx:2d}/{total_tables}] Counting {table_name}..."
)
self.cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
count = self.cursor.fetchone()[0]
table_counts[table_name] = count
# add to stats as well, for final report
self.stats[table_name] = count
if count > 0:
logger.info(f" → {count:,} records")
except sqlite3.OperationalError as e:
# table might not exist
logger.warning(f" → Table does not exist: {e}")
table_counts[table_name] = 0
return table_counts
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def _format_sqlite_error(self, e: sqlite3.IntegrityError, elem: ET.Element,
                         sql: str, keys: list[str], values: tuple,
                         accession: str | int | None = None) -> str:
    """
    Format SQLite IntegrityError with detailed context for debugging.

    Classifies the error by substring (NOT NULL / FOREIGN KEY / UNIQUE /
    CHECK / generic) and appends a tailored list of likely causes and
    suggested fixes for each case.

    Parameters
    ----------
    e : `sqlite3.IntegrityError`
        The SQLite integrity error that occurred
    elem : `ET.Element`
        The XML element being parsed when error occurred
    sql : `str`
        The SQL statement that failed
    keys : `list` [`str`]
        Column names from the SQL statement
    values : `tuple`
        Values that were being inserted
    accession : `str`, `int` or `None`, default `None`
        The XML accession number

    Returns
    -------
    str
        Formatted error message with context and suggested fixes
    """
    error_msg = str(e)
    # Build detailed error message, one entry per output line
    msg_parts = [
        f"Database constraint violation while inserting {elem.tag}:",
    ]
    if accession:
        msg_parts.append(f" Accession: {accession}")
    # Add error message
    msg_parts.append(f" Error: {error_msg}")
    # Check error type and add specific details
    if "NOT NULL constraint failed" in error_msg:
        # Extract column name (sqlite reports it after the final ':')
        failed_column = (
            error_msg.split(":")[-1].strip()
            if ":" in error_msg
            else "unknown"
        )
        msg_parts.extend([
            f" Failed column: {failed_column}",
            f" SQL: {sql}",
            f" Values provided: {dict(zip(keys, values))}",
            f" Element attributes: {dict(elem.attrib)}",
            "",
            "Possible fixes:",
            " 1. Check 'attributes' config - verify xml_attr/xml_path exists",
            " 2. Make column nullable in schema (remove NOT NULL)",
            " 3. Verify xpath in 'children' config finds the element",
        ])
    elif "FOREIGN KEY constraint failed" in error_msg:
        # foreign key error: a parent row is missing or mis-keyed
        msg_parts.extend([
            f" SQL: {sql}",
            f" Values: {dict(zip(keys, values))}",
            "",
            "Parent record was not inserted or has wrong ID.",
            "Possible fixes:",
            " 1. Check 'parent_id' in config matches parent's 'returns_id'",
            " 2. Verify 'children' xpath correctly locates child elements",
            " 3. For polymorphic tables, ensure 'entity_type' is set in child spec",
            " 4. Check parent table has 'returns_id' configured",
        ])
    elif "UNIQUE constraint failed" in error_msg:
        # unique violation: extract the duplicated column list
        failed_columns = (
            error_msg.split(":")[-1].strip()
            if ":" in error_msg
            else "unknown"
        )
        msg_parts.extend([
            f" Duplicate columns: {failed_columns}",
            f" SQL: {sql}",
            f" Values: {dict(zip(keys, values))}",
            "",
            "This record already exists in the database.",
            "Possible fixes:",
            " 1. Set 'ignore_duplicates': true in config sql block",
            " 2. Check if XML contains duplicate records for this accession",
            " 3. Verify 'children' xpath is not matching same element twice",
            " 4. Clear database before re-parsing",
        ])
    elif "CHECK constraint failed" in error_msg:
        # CHECK constraint: inserted value outside the allowed range
        msg_parts.extend([
            f" SQL: {sql}",
            f" Values: {dict(zip(keys, values))}",
            f" Element attributes: {dict(elem.attrib)}",
            "",
            "A CHECK constraint was violated - value outside allowed range.",
            "Possible fixes:",
            " 1. Verify 'cast' type in attributes config (int/float/bool/str)",
            " 2. Check xml_attr/xml_path extracts the correct value",
            " 3. Review CHECK constraint in schema matches expected XML values",
            " 4. Add validation in _extract_value for this field",
        ])
    else:
        # Generic integrity error: no specific advice available
        msg_parts.extend([
            f" SQL: {sql}",
            f" Values: {dict(zip(keys, values))}",
        ])
    return "\n".join(msg_parts)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Custom TRACE logging level, below DEBUG, for very fine-grained output.
TRACE = 5
logging.addLevelName(TRACE, "TRACE")


def trace(self, message, *args, **kwargs):
    """
    Log *message* at the custom TRACE level.

    Parameters
    ----------
    self : logging.Logger
        The logger instance (this function is attached as a method below).
    message : str
        Format string for the log record.
    *args, **kwargs
        Passed through to `Logger._log`.
    """
    # Mirror the standard Logger.debug/info pattern: check the level
    # first so disabled trace calls are cheap.
    if self.isEnabledFor(TRACE):
        self._log(TRACE, message, args, **kwargs)


# Attach the method so every Logger gains .trace() at import time.
logging.Logger.trace = trace
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
class ViewXML(object):
    """
    Class to view and inspect XML records.

    This class loads a single XML element and provides utilities for
    inspecting and navigating its structure, particularly useful for
    large ClinVar XML files.

    Parameters
    ----------
    xml_path : `str` or `Path`
        Path to ClinVar XML file (supports .gz compression)
    tag_name : `str`
        XML tag name to search for (e.g., 'VariationArchive')
    index : `int`, default 0
        Which occurrence of the tag to load (0-indexed)

    Attributes
    ----------
    xml_path : `Path`
        Path to the source XML file
    tag_name : `str`
        The tag name that was searched for
    index : `int`
        The index of the loaded element
    element : `ET.Element` or `None`
        The loaded XML element, or None if not found

    Methods
    -------
    show_tree(max_depth=4, indent=0)
        Display element tree structure up to specified depth
    show_children(indent=0)
        Display immediate children with summary information
    find_all_paths()
        Find all unique XPath-like paths in the element tree
    count_all_tags()
        Count all descendant tags by frequency

    Examples
    --------
    >>> viewer = ViewXML('clinvar.xml.gz', 'VariationArchive', index=5)
    >>> viewer.show_tree()
    >>> paths = viewer.find_all_paths()
    """
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
def __init__(self, xml_path: str | Path, tag_name: str, index: int = 0):
    """
    Initialise ViewXML with a single XML element.

    Parameters
    ----------
    xml_path : str or Path
        Path to the XML file (``.gz`` supported).
    tag_name : str
        Tag name to search for.
    index : int, default 0
        Zero-based occurrence of the tag to load.
    """
    # validate argument types up front
    is_type(xml_path, (str, Path))
    is_type(tag_name, str)
    is_type(index, int)
    self.xml_path = Path(xml_path)
    self.tag_name = tag_name
    self.index = index
    # eagerly scan the file for the requested occurrence (None if absent)
    self.element = self._load_element()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
def __str__(self) -> str:
"""
Return human-readable string representation.
Returns
-------
str
Summary of the loaded element
"""
NAME = type(self).__name__
if self.element is None:
return (f"{NAME}(tag='{self.tag_name}', index={self.index}, "
f"element=None)")
n_children = len(list(self.element))
n_attrs = len(self.element.attrib)
return (f"{NAME}(tag='{self.tag_name}', index={self.index}, "
f"children={n_children}, attrs={n_attrs})")
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
def __repr__(self) -> str:
"""
Return unambiguous string representation.
Returns
-------
str
String that could recreate the object
"""
NAME = type(self).__name__
return (f"{NAME}(xml_path='{self.xml_path}', "
f"tag_name='{self.tag_name}', index={self.index})")
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
def show_children(self, indent: int = 0) -> None:
    """
    Print a one-line summary for each immediate child of the element.

    Parameters
    ----------
    indent : `int`, default 0
        Indentation level for output formatting

    Raises
    ------
    ValueError
        If no element is loaded
    """
    is_type(indent, int)
    if self.element is None:
        raise ValueError("no element loaded. check xml path and tag name")
    elem = self.element
    pad = ' ' * indent
    print(f"{pad}<{elem.tag}> has {len(list(elem))} children:")
    child_pad = ' ' * (indent + 1)
    for child in elem:
        # show at most two attributes and the first 30 chars of text
        attr_preview = dict(list(child.attrib.items())[:2])
        text = (child.text.strip()[:30]
                if child.text and child.text.strip() else "")
        print(f"{child_pad}<{child.tag}> "
              f"attrs={attr_preview} text={text!r} "
              f"children={len(list(child))}")
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
def show_tree(self, max_depth=4, indent=0):
    """
    Display element tree structure up to specified depth.

    Parameters
    ----------
    max_depth : `int`, default 4
        Maximum depth to display; must be at least 1.
    indent : `int`, default 0
        Starting indentation level; must be non-negative.

    Raises
    ------
    ValueError
        If the arguments are out of range or no element is loaded
    """
    is_type(max_depth, int)
    is_type(indent, int)
    # guard clauses: argument ranges first, then loaded-element check
    if max_depth < 1:
        raise ValueError(f"max_depth must be 1 or greater, got {max_depth}")
    if indent < 0:
        raise ValueError(f"indent must be non-negative, got {indent}")
    if self.element is None:
        raise ValueError("no element loaded. check xml path and tag name")
    # delegate to the recursive helper with a reduced depth budget
    self._show_tree_recursive(self.element, max_depth - 1, indent + 1)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
def find_all_paths(self) -> list[str]:
    """
    Find all unique XPath-like paths in the loaded element tree.

    Returns
    -------
    list of str
        Sorted list of all unique paths found in the tree

    Raises
    ------
    ValueError
        If no element is loaded

    Examples
    --------
    >>> viewer = ViewXML('clinvar.xml.gz', 'VariationArchive')
    >>> paths = viewer.find_all_paths()
    >>> print(paths[:5])
    """
    if self.element is None:
        raise ValueError("no element loaded. check xml path and tag name")
    collected: set[str] = set()
    self._find_paths_recursive(self.element, "", collected)
    return sorted(collected)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def _load_element(self) -> ET.Element | None:
"""
Load the specified element from XML file.
Returns
-------
ET.Element or None
The requested XML element, or None if not found
"""
# handle gzipped files
try:
fh = gzip.open(self.xml_path, 'rt', encoding='utf-8')
except OSError:
fh = open(self.xml_path, 'r', encoding='utf-8')
# Iterate across the file
try:
# get full records
context = ET.iterparse(fh, events=('end',))
count = 0
for _, elem in context:
if elem.tag == self.tag_name:
if count == self.index:
return elem
# not the correct record so increasing the count
count += 1
elem.clear()
return None
finally:
fh.close()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def _find_paths_recursive(self, elem: ET.Element, current_path: str,
paths: set) -> None:
"""
Recursive helper for find_all_paths.
Parameters
----------
elem : ET.Element
Current element being processed
current_path : str
Path accumulated so far
paths : set
Set to collect unique paths
"""
current_path = f"{current_path}/{elem.tag}"
paths.add(current_path)
for child in elem:
self._find_paths_recursive(child, current_path, paths)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def _show_tree_recursive(self, elem: ET.Element, max_depth: int,
indent: int) -> None:
"""
Recursive helper for show_tree.
Parameters
----------
elem : ET.Element
Current element to display
max_depth : int
Remaining depth to display
indent : int
Current indentation level
"""
# Build info string
n_children = len(list(elem))
attr_count = len(elem.attrib)
text = (
elem.text.strip()[:20]
if elem.text and elem.text.strip()
else ""
)
info = f"<{elem.tag}>"
if attr_count:
info += f" [{attr_count} attrs]"
if text:
info += f' "{text}..."'
if n_children:
info += f" ({n_children} children)"
print(f"{' ' * indent}{info}")
if max_depth < 1:
return
for child in elem:
self._show_tree_recursive(child, max_depth - 1, indent + 1)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
@contextmanager
def open_xml_file(xml_path: str | Path,
xsd_path: str | Path | None = None,
strict: bool = True,
verbose: bool = True):
"""
Helper function to open an XML file, handling compression based on
extension.
Automatically closes the file handle when exiting the context, even
if an exception occurs.
Parameters
----------
xml_path : `str`
The path to the file.
xsd_path : `str`, `Path`, or `None`, default `None`
Optional path to XSD schema file for validation. If provided,
validates XML before yielding the file handle.
strict : `bool`, default `True`
If False, ignores elements in the XML that are not in the XSD.
Only used when xsd_path is provided.
verbose : `bool`, default `True`
If True, warns about validation issues when strict=False.
Only used when xsd_path is provided.
Yields
------
file-like object
An open file handle for the compressed or uncompressed XML.
Raises
------
FileNotFoundError
If xml_path or xsd_path does not exist
IOError
If file cannot be opened
XMLValidationError
If XSD validation fails (when xsd_path is provided)
"""
# Validate XML against XSD if provided
if xsd_path is not None:
validate_xml(xml_path, xsd_path, strict=strict, verbose=verbose)
file_handle = None
try:
# Check if the file is compressed
if str(xml_path).lower().endswith('.gz'):
# 'rb' mode for reading binary compressed data
file_handle = gzip.open(xml_path, 'rb')
else:
# 'r' mode for reading uncompressed data
file_handle = open(xml_path, 'r', encoding='utf-8')
yield file_handle
finally:
# close if there is a an error
if file_handle is not None:
file_handle.close()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
def validate_xml(xml_path: str | Path, xsd_path: str | Path,
                 strict: bool = True,
                 verbose: bool = True,
                 ) -> etree._ElementTree:
    """
    Validate an XML file against an XSD schema.

    Parameters
    ----------
    xml_path : `str` or `Path`
        Path to the XML file.
    xsd_path : `str` or `Path`
        Path to the XSD file.
    strict : `bool`, default `True`
        If False, ignores elements in the XML that are not in the XSD.
    verbose : `bool`, default `True`
        If True (and strict is False), warn when validation succeeded
        only after ignoring elements missing from the XSD.

    Returns
    -------
    etree._ElementTree
        The parsed XML document.

    Raises
    ------
    XMLValidationError
        Raised if the XSD and XML are incompatible.
    """
    # #### Check input
    is_type(strict, bool)
    is_type(verbose, bool)
    is_type(xml_path, (str, Path))
    is_type(xsd_path, (str, Path))
    # #### Load the XSD schema
    with open(xsd_path, 'rb') as xsd:
        schema_root = etree.XML(xsd.read())
    schema = etree.XMLSchema(schema_root)
    # #### Parse the XML file; retry with a huge_tree parser for large
    # or deeply nested documents rejected by the default parser.
    try:
        with open(xml_path, 'rb') as xml:
            xml_doc = etree.parse(xml)
    except etree.XMLSyntaxError:
        parser = etree.XMLParser(resolve_entities=False, huge_tree=True)
        with open(xml_path, 'rb') as xml:
            xml_doc = etree.parse(xml, parser=parser)
    # #### Compare XML against XSD
    is_valid = schema.validate(xml_doc)
    if not is_valid and not strict:
        # Non-strict: drop errors about XML elements absent from the XSD
        # (lxml phrases these with the word 'expected').
        filtered_errs = [err for err in schema.error_log
                         if LXML_ERROR not in err.message]
        if filtered_errs:
            # other kinds of validation errors remain: fail
            raise XMLValidationError([err.message for err in filtered_errs])
        # only ignorable errors were present: optionally warn
        if verbose and len(schema.error_log) > 0:
            warnings.warn(WARN1)
    elif not is_valid:
        # strict mode: surface every validation error
        raise XMLValidationError([err.message for err in schema.error_log])
    return xml_doc