# Source code for clinvar_build.utils.parser_tools

'''
XML parsing and SQLite database utilities for ClinVar Build.

This module provides classes and functions for parsing large XML clinvar files
and loading them into SQLite databases. It includes utilities for database
validation, progress tracking, error formatting, and XML inspection.
'''
import gzip
import sqlite3
import logging
import warnings
import xml.etree.ElementTree as ET
from pathlib import Path
from collections import Counter
from typing import (
    Any,
    Iterator,
)
from clinvar_build.errors import (
    is_type,
    XMLValidationError,
)
from contextlib import contextmanager
from clinvar_build.errors import (
    is_type,
)
from clinvar_build.utils.config_tools import (
    ProgressHandler,
)
from lxml import etree

# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@

# Substring that lxml includes in schema-validation errors for elements not
# declared in the XSD; validate_xml() filters on it when strict=False.
LXML_ERROR = 'expected'
# Warning issued when non-strict validation passes only after ignoring
# unknown-element errors (see validate_xml).
WARN1 = "XML file is valid (ignoring XML elements not in the supplied XSD)."

# Module-level logger; level/handlers are set up by configure_logging().
logger = logging.getLogger(__name__)

[docs] class SQLiteParser(object): """ A general SQLite parser. Parameters ---------- config : dict[str, Any] A dictionary with instructions to parser an XML file to a SQLite database. Attributes ---------- conn : `sqlite3.Connection` or `None` Active SQLite connection, or None if not connected. cursor : `sqlite3.Cursor` or `None` Cursor for executing SQL statements, or None if not connected. stats : `dict` [`str`, `int`] Statistics on parsed records. config : `dict` [`str`, `any`] A configuration dictionary with parsing instructions. """ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # class attributes _LINE_LEN = 70 _progress_logger = None _progress_handler = None # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs] def __init__(self, config:dict[str, Any] | None = None): """Initialize parser.""" is_type(config, (type(None),dict)) # #### recording the table names # inserted rows config = config if config is not None else {} stats = {k: 0 for k in config.keys()} # #### Set to self self.conn = None self.cursor = None self.stats = stats self.config = config # #### Set up progress logger for in-place updates (once per class) if type(self)._progress_logger is None: type(self)._progress_logger = logging.getLogger('parser.progress') type(self)._progress_logger.setLevel(logging.INFO) type(self)._progress_handler = ProgressHandler() type(self)._progress_handler.setFormatter( logging.Formatter( '%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S' ) ) type(self)._progress_logger.addHandler( type(self)._progress_handler ) # Dont propagate to root logger type(self)._progress_logger.propagate = False
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs] def __repr__(self) -> str: """ Return unambiguous string representation. Returns ------- str String representation suitable for debugging """ connection_status = ( f"connected to {self.conn}" if self.conn else "not connected" ) return f"{type(self).__name__}({connection_status})"
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs] def __str__(self) -> str: """ Return human-readable string representation. Returns ------- str Human-readable description of the parser """ if self.conn: return f"{type(self).__name__} (active connection)" config_count = len([k for k in self.config.keys() if not k.startswith('_')]) return ( f"{type(self).__name__}\n" f"SQLite parser configured for {config_count} tables." )
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs] def validate_database(self) -> dict[str, Any]: """ Run comprehensive database validation checks. Performs multiple validation checks including foreign key integrity, table row counts, and basic statistics. Returns ------- dict [str, Any] Dictionary containing validation results and statistics Raises ------ ValueError If foreign key violations are found Examples -------- >>> with parser._connection(db_path): ... results = parser.validate_database() >>> print(results) {'foreign_keys': 'valid', 'total_records': 12345, ...} """ # empty results results = {} # foreign key check logger.info(" ") logger.info("=" * type(self)._LINE_LEN) logger.info("STEP 1/2: Foreign Key Integrity Check") logger.info("=" * type(self)._LINE_LEN) self._validate_foreign_keys() results['foreign_keys'] = 'valid' # count records in each table logger.info(" ") logger.info("=" * type(self)._LINE_LEN) logger.info("STEP 2/2: Counting Records per Table") logger.info("=" * type(self)._LINE_LEN) table_counts = self._count_table_records() # adding to results results['table_counts'] = table_counts results['total_records'] = sum(table_counts.values()) # logging logger.info(" ") logger.info("=" * type(self)._LINE_LEN) logger.info(f"✓ Database validation complete") logger.info( f" Total records across all tables: {results['total_records']:,}" ) logger.info("=" * type(self)._LINE_LEN) return results
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs] def count_duplicates(self): """ Count duplicate rows in all tables based on non-id columns. For each table in the configuration, identifies duplicate rows where all columns (except the primary key 'id') are identical, treating NULL values as equal. Returns ------- dict [str, int] Dictionary mapping table names to number of duplicates found Notes ----- Uses SQLite's rowid to identify duplicates, counting rows that would be removed (keeping the row with lowest rowid). NULL values are treated as equal via GROUP BY. """ logger.info("") logger.info("=" * type(self)._LINE_LEN) logger.info("DUPLICATE CHECK: Counting duplicate rows") logger.info("=" * type(self)._LINE_LEN) # Init duplicates_found = {} # loop over tables for table in [k for k in self.config.keys() if not k.startswith('_')]: try: self.cursor.execute(f"PRAGMA table_info({table})") cols = ','.join( [r[1] for r in self.cursor.fetchall() if r[1] != 'id'] ) if cols: # count rows that would be deleted (not the keeper rows) self.cursor.execute(f""" SELECT COUNT(*) FROM {table} WHERE rowid NOT IN ( SELECT MIN(rowid) FROM {table} GROUP BY {cols} ) """) duplicates_found[table] = self.cursor.fetchone()[0] except sqlite3.OperationalError: duplicates_found[table] = 0 # log only tables with duplicates for table, count in sorted(duplicates_found.items()): if count > 0: logger.info(f" {table}: {count:,} duplicates") total = sum(duplicates_found.values()) logger.info("") logger.info(f"Found {total:,} duplicates across " f"{sum(1 for v in duplicates_found.values() if v > 0)} tables") logger.info("")
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ def _connect(self, db_path: str | Path) -> None: """ Establish database connection. Parameters ---------- db_path : str or Path Path to SQLite database file """ self.conn = sqlite3.connect(db_path) self.cursor = self.conn.cursor() # NOTE: enable foreign key enforcement: prevents inserting records with # non-existent parent IDs, cascades deletes to child records, and # maintains referential integrity self.cursor.execute("PRAGMA foreign_keys = ON") # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ def _close(self) -> None: """Close database connection.""" if self.conn: self.conn.commit() self.conn.close() self.conn = None self.cursor = None # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @contextmanager def _connection(self, db_path: str | Path) -> Iterator[None]: """ Context manager for SQLite database connection. Establishes a connection to the SQLite database at `db_path` and ensures the connection is properly closed after use, even if an exception occurs during operations. Any uncommitted transactions are rolled back on error or interrupt. Parameters ---------- db_path : str or Path Path to the SQLite database file Yields ------ None Provides a context in which the database connection is active. `self.conn` and `self.cursor` are available for executing SQL statements within the context """ self._connect(db_path) try: yield except (Exception, KeyboardInterrupt): if self.conn: self.conn.rollback() logger.info("Rolled back uncommitted changes") raise finally: self._close() # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ def _validate_foreign_keys(self) -> None: """ Validate all foreign key constraints after parsing. Checks every foreign key relationship in the database to ensure referential integrity. Called after re-enabling foreign key enforcement following parsing. 
Raises ------ ValueError If any foreign key violations are found """ logger.info("Running foreign key integrity check...") self.cursor.execute("PRAGMA foreign_key_check") violations = self.cursor.fetchall() if violations: logger.error(f"Found {len(violations)} foreign key violations!") logger.error("First 10 violations:") for i, violation in enumerate(violations[:10], 1): # violation format: (table, rowid, parent_table, fkid) logger.error( f" {i}. Table '{violation[0]}' row {violation[1]} " f"references non-existent parent in '{violation[2]}'" ) raise ValueError( f"Foreign key integrity check failed! " f"Found {len(violations)} violations." ) logger.info("✓ Foreign key integrity validated - all constraints satisfied") # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ def _count_table_records(self) -> dict[str, int]: """ Count records in all configured tables. Iterates through all tables in the configuration and counts the number of records in each. Updates the parser's stats dictionary with the counts. Returns ------- dict [str, int] Dictionary mapping table names to record counts Notes ----- Tables that do not exist will have a count of 0. The method also updates `self.stats` with the counts for later reporting. """ table_counts = {} # filter out meta keys config_tables = [ k for k in self.config.keys() if not k.startswith('_') ] total_tables = len(config_tables) # loop over tables for idx, table_name in enumerate(config_tables, 1): try: logger.info( f" [{idx:2d}/{total_tables}] Counting {table_name}..." 
) self.cursor.execute(f"SELECT COUNT(*) FROM {table_name}") count = self.cursor.fetchone()[0] table_counts[table_name] = count # add to stats as well, for final report self.stats[table_name] = count if count > 0: logger.info(f" → {count:,} records") except sqlite3.OperationalError as e: # table might not exist logger.warning(f" → Table does not exist: {e}") table_counts[table_name] = 0 return table_counts # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ def _format_sqlite_error(self, e: sqlite3.IntegrityError, elem: ET.Element, sql: str, keys: list[str], values: tuple, accession: str | int | None = None) -> str: """ Format SQLite IntegrityError with detailed context for debugging. Parameters ---------- e : `sqlite3.IntegrityError` The SQLite integrity error that occurred elem : `ET.Element` The XML element being parsed when error occurred sql : `str` The SQL statement that failed keys : `list` [`str`] Column names from the SQL statement values : `tuple` Values that were being inserted accession : `int` or `None`, default `None` The XML accession number Returns ------- str Formatted error message with context and suggested fixes """ error_msg = str(e) # Build detailed error message msg_parts = [ f"Database constraint violation while inserting {elem.tag}:", ] if accession: msg_parts.append(f" Accession: {accession}") # Add error message msg_parts.append(f" Error: {error_msg}") # Check error type and add specific details if "NOT NULL constraint failed" in error_msg: # Extract column name failed_column =( error_msg.split(":")[-1].strip() if ":" in error_msg else "unknown" ) msg_parts.extend([ f" Failed column: {failed_column}", f" SQL: {sql}", f" Values provided: {dict(zip(keys, values))}", f" Element attributes: {dict(elem.attrib)}", "", "Possible fixes:", " 1. Check 'attributes' config - verify xml_attr/xml_path exists", " 2. Make column nullable in schema (remove NOT NULL)", " 3. 
Verify xpath in 'children' config finds the element", ]) elif "FOREIGN KEY constraint failed" in error_msg: # foreign key error msg_parts.extend([ f" SQL: {sql}", f" Values: {dict(zip(keys, values))}", "", "Parent record was not inserted or has wrong ID.", "Possible fixes:", " 1. Check 'parent_id' in config matches parent's 'returns_id'", " 2. Verify 'children' xpath correctly locates child elements", " 3. For polymorphic tables, ensure 'entity_type' is set in child spec", " 4. Check parent table has 'returns_id' configured", ]) elif "UNIQUE constraint failed" in error_msg: # unique failed_columns = ( error_msg.split(":")[-1].strip() if ":" in error_msg else "unknown" ) msg_parts.extend([ f" Duplicate columns: {failed_columns}", f" SQL: {sql}", f" Values: {dict(zip(keys, values))}", "", "This record already exists in the database.", "Possible fixes:", " 1. Set 'ignore_duplicates': true in config sql block", " 2. Check if XML contains duplicate records for this accession", " 3. Verify 'children' xpath is not matching same element twice", " 4. Clear database before re-parsing", ]) elif "CHECK constraint failed" in error_msg: # constraint msg_parts.extend([ f" SQL: {sql}", f" Values: {dict(zip(keys, values))}", f" Element attributes: {dict(elem.attrib)}", "", "A CHECK constraint was violated - value outside allowed range.", "Possible fixes:", " 1. Verify 'cast' type in attributes config (int/float/bool/str)", " 2. Check xml_attr/xml_path extracts the correct value", " 3. Review CHECK constraint in schema matches expected XML values", " 4. Add validation in _extract_value for this field", ]) else: # Generic integrity error msg_parts.extend([ f" SQL: {sql}", f" Values: {dict(zip(keys, values))}", ]) return "\n".join(msg_parts)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Custom TRACE level (numerically below DEBUG), registered at import time.
TRACE = 5
logging.addLevelName(TRACE, "TRACE")


def trace(self, message, *args, **kwargs):
    """Log *message* at the custom TRACE level for very detailed output."""
    if not self.isEnabledFor(TRACE):
        return
    self._log(TRACE, message, args, **kwargs)


# add the trace method to Logger class at import time
logging.Logger.trace = trace
def configure_logging(verbosity: int) -> None:
    """
    Configure logging based on verbosity level.

    The root logger is (re)configured, which affects every logger in the
    application through inheritance.

    Parameters
    ----------
    verbosity : int
        Number of -v flags. 0 = WARNING, 1 = INFO, 2 = DEBUG, 3 = TRACE
    """
    # map the -v count onto a level; any other value (including negatives
    # or >= 3) falls through to the custom TRACE level
    named_levels = {
        0: logging.WARNING,
        1: logging.INFO,
        2: logging.DEBUG,
    }
    if verbosity in named_levels:
        log_level = named_levels[verbosity]
    else:
        log_level = TRACE

    # reconfigure the root logger
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)

    # drop any previously attached handlers before installing ours
    for old_handler in list(root_logger.handlers):
        root_logger.removeHandler(old_handler)

    # single stream handler with a timestamped format
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(log_level)
    stream_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    )
    root_logger.addHandler(stream_handler)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs] class ViewXML(object): """ Class to view and inspect XML records. This class loads a single XML element and provides utilities for inspecting and navigating its structure, particularly useful for large ClinVar XML files. Parameters ---------- xml_path : `str` or `Path` Path to ClinVar XML file (supports .gz compression) tag_name : `str` XML tag name to search for (e.g., 'VariationArchive') index : `int`, default 0 Which occurrence of the tag to load (0-indexed) Attributes ---------- xml_path : `Path` Path to the source XML file tag_name : `str` The tag name that was searched for index : `int` The index of the loaded element element : `ET.Element` or `None` The loaded XML element, or None if not found Methods ------- show_tree(max_depth=4, indent=0) Display element tree structure up to specified depth show_children(indent=0) Display immediate children with summary information find_all_paths() Find all unique XPath-like paths in the element tree count_all_tags() Count all descendant tags by frequency Examples -------- >>> viewer = ViewXML('clinvar.xml.gz', 'VariationArchive', index=5) >>> viewer.show_tree() >>> paths = viewer.find_all_paths() """ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs] def __init__(self, xml_path: str | Path, tag_name: str, index: int = 0): """ Initialise ViewXML with a single XML element. """ is_type(xml_path, (str, Path)) is_type(tag_name, str) is_type(index, int) self.xml_path = Path(xml_path) self.tag_name = tag_name self.index = index self.element = self._load_element()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs] def __str__(self) -> str: """ Return human-readable string representation. Returns ------- str Summary of the loaded element """ NAME = type(self).__name__ if self.element is None: return (f"{NAME}(tag='{self.tag_name}', index={self.index}, " f"element=None)") n_children = len(list(self.element)) n_attrs = len(self.element.attrib) return (f"{NAME}(tag='{self.tag_name}', index={self.index}, " f"children={n_children}, attrs={n_attrs})")
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs] def __repr__(self) -> str: """ Return unambiguous string representation. Returns ------- str String that could recreate the object """ NAME = type(self).__name__ return (f"{NAME}(xml_path='{self.xml_path}', " f"tag_name='{self.tag_name}', index={self.index})")
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs] def show_children(self, indent: int = 0) -> None: """ Display immediate children of loaded element with summary info. Parameters ---------- indent : `int`, default 0 Indentation level for output formatting Raises ------ ValueError If no element is loaded """ is_type(indent, int) if self.element is None: raise ValueError("no element loaded. check xml path and tag name") # get the element tags and loop over children elem = self.element print(f"{' ' * indent}<{elem.tag}> has " f"{len(list(elem))} children:") for child in elem: attr_preview = dict(list(child.attrib.items())[:2]) text = (child.text.strip()[:30] if child.text and child.text.strip() else "") n_children = len(list(child)) print(f"{' ' * (indent+1)}<{child.tag}> " f"attrs={attr_preview} text={text!r} " f"children={n_children}")
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs] def count_all_tags(self) -> dict[str, int]: """ Count all descendant tags in loaded element tree. Returns ------- dict of {str: int} Dictionary mapping tag names to their occurrence counts, ordered by frequency (most common first) Raises ------ ValueError If no element is loaded """ if self.element is None: raise ValueError("no element loaded. check xml path and tag " "name") counts = Counter() for descendant in self.element.iter(): counts[descendant.tag] += 1 return dict(counts.most_common())
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs] def show_tree(self, max_depth=4, indent=0): """" Display element tree structure up to specified depth. Parameters ---------- max_depth : `int`, default 4 Maximum depth to display indent : `int`, default 0 Starting indentation level Raises ------ ValueError If no element is loaded """ is_type(max_depth, int) is_type(indent, int) if max_depth < 1: raise ValueError(f"max_depth must be 1 or greater, got " f"{max_depth}") if indent < 0: raise ValueError(f"indent must be non-negative, got {indent}") if self.element is None: raise ValueError("no element loaded. check xml path and tag " "name") # recurse self._show_tree_recursive(self.element, max_depth - 1, indent + 1)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs] def find_all_paths(self) -> list[str]: """ Find all unique XPath-like paths in the loaded element tree. Returns ------- list of str Sorted list of all unique paths found in the tree Raises ------ ValueError If no element is loaded Examples -------- >>> viewer = ViewXML('clinvar.xml.gz', 'VariationArchive') >>> paths = viewer.find_all_paths() >>> print(paths[:5]) """ if self.element is None: raise ValueError("no element loaded. check xml path and tag name") paths = set() self._find_paths_recursive(self.element, "", paths) return sorted(paths)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ def _load_element(self) -> ET.Element | None: """ Load the specified element from XML file. Returns ------- ET.Element or None The requested XML element, or None if not found """ # handle gzipped files try: fh = gzip.open(self.xml_path, 'rt', encoding='utf-8') except OSError: fh = open(self.xml_path, 'r', encoding='utf-8') # Iterate across the file try: # get full records context = ET.iterparse(fh, events=('end',)) count = 0 for _, elem in context: if elem.tag == self.tag_name: if count == self.index: return elem # not the correct record so increasing the count count += 1 elem.clear() return None finally: fh.close() # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ def _find_paths_recursive(self, elem: ET.Element, current_path: str, paths: set) -> None: """ Recursive helper for find_all_paths. Parameters ---------- elem : ET.Element Current element being processed current_path : str Path accumulated so far paths : set Set to collect unique paths """ current_path = f"{current_path}/{elem.tag}" paths.add(current_path) for child in elem: self._find_paths_recursive(child, current_path, paths) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ def _show_tree_recursive(self, elem: ET.Element, max_depth: int, indent: int) -> None: """ Recursive helper for show_tree. 
Parameters ---------- elem : ET.Element Current element to display max_depth : int Remaining depth to display indent : int Current indentation level """ # Build info string n_children = len(list(elem)) attr_count = len(elem.attrib) text = ( elem.text.strip()[:20] if elem.text and elem.text.strip() else "" ) info = f"<{elem.tag}>" if attr_count: info += f" [{attr_count} attrs]" if text: info += f' "{text}..."' if n_children: info += f" ({n_children} children)" print(f"{' ' * indent}{info}") if max_depth < 1: return for child in elem: self._show_tree_recursive(child, max_depth - 1, indent + 1)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs] @contextmanager def open_xml_file(xml_path: str | Path, xsd_path: str | Path | None = None, strict: bool = True, verbose: bool = True): """ Helper function to open an XML file, handling compression based on extension. Automatically closes the file handle when exiting the context, even if an exception occurs. Parameters ---------- xml_path : `str` The path to the file. xsd_path : `str`, `Path`, or `None`, default `None` Optional path to XSD schema file for validation. If provided, validates XML before yielding the file handle. strict : `bool`, default `True` If False, ignores elements in the XML that are not in the XSD. Only used when xsd_path is provided. verbose : `bool`, default `True` If True, warns about validation issues when strict=False. Only used when xsd_path is provided. Yields ------ file-like object An open file handle for the compressed or uncompressed XML. Raises ------ FileNotFoundError If xml_path or xsd_path does not exist IOError If file cannot be opened XMLValidationError If XSD validation fails (when xsd_path is provided) """ # Validate XML against XSD if provided if xsd_path is not None: validate_xml(xml_path, xsd_path, strict=strict, verbose=verbose) file_handle = None try: # Check if the file is compressed if str(xml_path).lower().endswith('.gz'): # 'rb' mode for reading binary compressed data file_handle = gzip.open(xml_path, 'rb') else: # 'r' mode for reading uncompressed data file_handle = open(xml_path, 'r', encoding='utf-8') yield file_handle finally: # close if there is a an error if file_handle is not None: file_handle.close()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def validate_xml(xml_path: str | Path,
                 xsd_path: str | Path,
                 strict: bool = True,
                 verbose: bool = True,
                 ) -> etree._ElementTree:
    """
    Validates an XML file against an XSD schema.

    Parameters
    ----------
    xml_path : `str` or `Path`
        Path to the XML file.
    xsd_path : `str` or `Path`
        Path to the XSD file.
    strict : `bool`, default `True`
        If False, ignores elements in the XML that are not in the XSD.
    verbose : `bool`, default `True`
        If True, warns (WARN1) when non-strict validation only passed
        after ignoring unknown-element errors.

    Returns
    -------
    etree._ElementTree
        The parsed XML document.

    Raises
    ------
    XMLValidationError
        Raised if the XSD and XML are incompatible.
    """
    # #### Check input
    is_type(strict, bool)
    is_type(verbose, bool)
    is_type(xml_path, (str, Path))
    is_type(xsd_path, (str, Path))

    # #### Load the XSD schema
    with open(xsd_path, 'rb') as xsd:
        schema_root = etree.XML(xsd.read())
    schema = etree.XMLSchema(schema_root)

    # #### Parse the XML file; retry with a huge_tree parser for documents
    # that overflow lxml's default depth/size limits
    try:
        with open(xml_path, 'rb') as xml:
            xml_doc = etree.parse(xml)
    except etree.XMLSyntaxError:
        parser = etree.XMLParser(resolve_entities=False, huge_tree=True)
        with open(xml_path, 'rb') as xml:
            xml_doc = etree.parse(xml, parser=parser)

    # #### Compare XML against XSD
    is_valid = schema.validate(xml_doc)

    # #### Do we want to raise errors when XML elements are absent in the XSD
    if not is_valid and not strict:
        # if strict is False remove errors relating to missing XSD elements
        # (lxml reports those with LXML_ERROR in the message)
        filtered_errs = [err for err in schema.error_log
                         if LXML_ERROR not in err.message]
        # check if there are any other type of errors
        is_valid = len(filtered_errs) == 0
        if is_valid:
            # Were there any errors to begin with?
            if verbose and len(schema.error_log) > 0:
                warnings.warn(WARN1)
        else:
            raise XMLValidationError(
                [err.message for err in filtered_errs]
            )
    elif not is_valid:
        raise XMLValidationError(
            [err.message for err in schema.error_log]
        )

    # return xml_doc
    return xml_doc