Source code for dataframe_schema_sync.schema_inference

import json
from collections.abc import Iterator
from email.utils import parsedate_to_datetime
from pathlib import Path
from typing import Any, ClassVar, Optional, Union

import numpy as np
import pandas as pd
import yaml  # type: ignore
from sqlalchemy import Boolean, DateTime, Float, Integer, Text
from sqlalchemy.dialects.postgresql import JSON


class SchemaInference:
    """
    Infer SQLAlchemy types from Pandas DataFrame columns and handle schema I/O.

    All functionality is exposed through static methods; the class carries no
    instance state.
    """

    # Lookup from the text form stored in schema YAML files to the SQLAlchemy
    # type it represents. Both datetime spellings resolve to a timezone-aware
    # DateTime. Every value is a type *instance* (JSON included) so that
    # isinstance() checks and str() serialization treat all entries alike.
    SQLALCHEMY_TYPE_MAP: ClassVar[dict[str, Any]] = {
        "TIMESTAMP(timezone=True)": DateTime(timezone=True),
        "DATETIME": DateTime(timezone=True),
        "INTEGER": Integer(),
        "FLOAT": Float(),
        "TEXT": Text(),
        "JSON": JSON(),
        "BOOLEAN": Boolean(),
    }
[docs] @staticmethod def safe_str_conversion(x: Any) -> str: """ Convert a value to string, returning an empty string if the value is NaN. Args: x (Any): The value to convert. Returns: str: The string representation or an empty string if x is NaN. """ return "" if pd.isna(x) else str(x)
[docs] @staticmethod def safe_json_conversion(x: Any) -> list[Any]: """ Convert a value to a JSON object if it's a string. If the value is missing (NaN), empty, or cannot be converted, return an empty array. Args: x (Any): The value to convert. Returns: list: The parsed JSON value or an empty list. """ # If the value is a string, attempt to parse it as JSON. if isinstance(x, str): try: return json.loads(x) except Exception: return [] # If the value is None, return an empty list. if x is None: return [] # For numeric types or other types, try to detect if it's missing. if pd.isna(x): return [] # If x is a numpy array, check its size. if isinstance(x, np.ndarray): if x.size == 0: return [] else: return x # If x is a list, check if it's empty. if isinstance(x, list): if len(x) == 0: return [] else: return x return x
[docs] @staticmethod def save_schema_to_yaml( dtype_map: dict[str, Any], filename: Union[str, Path], schema_name: str, mapping_key: str, sync_method: str, ) -> None: """ Save dtype_map to a YAML file, storing SQLAlchemy types as text strings. The YAML content will be nested under the provided schema_name key and further under the given mapping_key. The sync_method parameter determines how the new schema is saved: - "update": Only new columns (not already present under mapping_key) are added. - "overwrite": The mapping for mapping_key is completely replaced by the new schema. Args: dtype_map (dict): Dictionary mapping column names to SQLAlchemy types. filename (str or Path): Path to the YAML file. schema_name (str): The parent key to use in the YAML file. mapping_key (str): The key under which the schema items are stored. sync_method (str): Must be either "update" or "overwrite". Raises: ValueError: If sync_method is not one of the allowed values. """ if sync_method not in ("update", "overwrite"): raise ValueError("sync_method must be either 'update' or 'overwrite'") schema_path = Path(filename) # Serialize the dtype_map (convert SQLAlchemy types to strings) new_mapping = { col: "TEXT" if isinstance(sql_type, Text) else str(sql_type) for col, sql_type in dtype_map.items() } if schema_path.exists(): with open(filename, encoding="utf-8") as file: content: dict[str, Any] = yaml.safe_load(file) or {} else: content = {} if schema_name in content: # There is an existing schema block existing_mapping = content[schema_name].get(mapping_key, {}) if sync_method == "update": # Only add new columns that are not already defined merged_mapping = existing_mapping.copy() for col, val in new_mapping.items(): if col not in merged_mapping: merged_mapping[col] = val content[schema_name][mapping_key] = merged_mapping else: # overwrite content[schema_name][mapping_key] = new_mapping else: # No existing schema block; add one content[schema_name] = {mapping_key: new_mapping} with 
open(filename, "w", encoding="utf-8") as file: yaml.dump(content, file, sort_keys=False)
@staticmethod
def load_schema_from_yaml(
    filename: Union[str, Path],
    schema_name: str,
    mapping_key: str,
) -> dict[str, Any]:
    """
    Load a schema from a YAML file and convert the stored text back into SQLAlchemy types.

    The raw mapping is read with ``load_config_from_yaml`` and every value is
    looked up in ``SQLALCHEMY_TYPE_MAP``. Empty values and values that are not
    strings resolve to the "TEXT" entry; unknown strings fall back to ``Text()``.

    Args:
        filename (str or Path): Path to the YAML file.
        schema_name (str): The parent key under which the schema is stored.
        mapping_key (str): The key under which the schema items are stored.

    Returns:
        dict: Dictionary mapping column names to SQLAlchemy types.

    Raises:
        KeyError: If the provided schema_name or the mapping_key is not found in the YAML file.
    """
    columns_mapping = SchemaInference.load_config_from_yaml(filename, schema_name, mapping_key)
    type_map = SchemaInference.SQLALCHEMY_TYPE_MAP

    def _resolve(type_text: Any) -> Any:
        # Only non-empty strings are meaningful keys. Anything else (None, or
        # an unhashable list/dict from a malformed file) is treated as TEXT
        # rather than blowing up inside the dict lookup.
        key = type_text if isinstance(type_text, str) and type_text else "TEXT"
        return type_map.get(key, Text())

    return {col: _resolve(type_text) for col, type_text in columns_mapping.items()}
@staticmethod
def load_config_from_yaml(
    filename: Union[str, Path],
    schema_name: str,
    mapping_key: str,
) -> dict[str, Any]:
    """
    Load the raw mapping stored under ``schema_name`` -> ``mapping_key`` in a YAML file.

    Values are returned exactly as they appear in the file; no conversion to
    SQLAlchemy types is performed.

    Args:
        filename (str or Path): Path to the YAML file.
        schema_name (str): The parent key under which the mapping is stored.
        mapping_key (str): The key under which the items are stored.

    Returns:
        dict: The mapping found under ``schema_name`` -> ``mapping_key``.

    Raises:
        KeyError: If the provided schema_name or the mapping_key is not found in the YAML file.
    """
    with open(filename, encoding="utf-8") as file:
        loaded_content = yaml.safe_load(file) or {}

    # A level that is not a mapping (e.g. a top-level list, or a scalar where
    # a block was expected) cannot contain the requested key, so it is
    # reported with the same KeyError as a missing key.
    loaded_schema = loaded_content.get(schema_name) if isinstance(loaded_content, dict) else None
    if loaded_schema is None:
        raise KeyError(f"Schema '{schema_name}' not found in {filename}")

    columns_mapping = loaded_schema.get(mapping_key) if isinstance(loaded_schema, dict) else None
    if columns_mapping is None:
        raise KeyError(f"'{mapping_key}' key not found under schema '{schema_name}' in {filename}")

    return columns_mapping
[docs] @staticmethod def detect_and_convert_datetime(series: pd.Series) -> tuple[pd.Series, bool]: """ Detects datetime format in a Pandas Series and converts it to datetime64[ns, UTC]. Supports: - ISO 8601 formats (YYYY-MM-DDTHH:MM:SS.sssZ) - RFC 2822 (email/HTTP format) - Standard datetime format (YYYY-MM-DD HH:MM:SS.sss) Returns the converted series and a boolean indicating success. """ # Use only non-null values for detection. non_null = series.dropna() if non_null.empty: return pd.Series(pd.NaT, index=series.index, dtype="datetime64[ns, UTC]"), False # Check if series looks like a standard datetime (YYYY-MM-DD HH:MM:SS.sss) # This format check comes first to catch your specific format if isinstance(non_null.iloc[0], str) and len(non_null.iloc[0]) >= 19: sample = non_null.iloc[0] if ( len(sample) >= 19 and sample[4] == "-" and sample[7] == "-" and sample[10] == " " and sample[13] == ":" and sample[16] == ":" ): parsed_dates = pd.to_datetime(non_null, format="%Y-%m-%d %H:%M:%S.%f", errors="coerce", utc=True) if parsed_dates.notna().all(): # Apply conversion to the full series result = pd.Series(pd.NaT, index=series.index, dtype="datetime64[ns, UTC]") result[series.notna()] = pd.to_datetime( series[series.notna()], format="%Y-%m-%d %H:%M:%S.%f", errors="coerce", utc=True ) return result, True # Check only the non-null values. 
if parsed_dates.notna().all(): result = pd.Series(pd.NaT, index=series.index, dtype="datetime64[ns, UTC]") result[series.notna()] = pd.to_datetime(series[series.notna()], errors="coerce", utc=True) return result, True # Attempt RFC 2822 conversion parsed_dates = series.apply( lambda x: ( parsedate_to_datetime(x).astimezone(pd.Timestamp.utc) if pd.notnull(x) and isinstance(x, str) else pd.NaT ) ) if parsed_dates[series.notna()].notna().all(): return parsed_dates, True parsed_dates = series.apply( lambda date_str: ( pd.NaT if date_str == "0001-01-01T00:00:00Z" else pd.to_datetime(date_str, errors="coerce", utc=True) ) ) if parsed_dates[series.apply(lambda x: x != "0001-01-01T00:00:00Z")].notna().all(): return parsed_dates, True # If not all non-null values can be converted, do not treat the series as datetime. return series, False
@staticmethod
def infer_sqlalchemy_type(
    series: pd.Series, infer_dates: bool = True, date_columns: Optional[Union[str, list[str]]] = None
) -> tuple[Any, pd.Series]:
    """
    Given a pandas Series, determine the best matching SQLAlchemy type.

    Checks run in this order, and the first match wins: all-null (Text),
    existing datetime dtype, JSON-like first value (dict/list), boolean,
    datetime strings, numeric (Integer when every value is whole, otherwise
    Float), a last datetime attempt, and finally Text.

    Note that in the boolean branch missing values become ``False``: they are
    stringified first and therefore do not match any "true" token.

    Args:
        series (pd.Series): The column to analyze.
        infer_dates (bool): If True, attempts to infer datetime columns.
        date_columns (str or list of str): Column names that should always be
            tried as dates, matched against ``series.name``.

    Returns:
        tuple: (SQLAlchemy type, transformed Pandas Series)
    """
    if date_columns is None:
        date_cols = []
    elif isinstance(date_columns, str):
        date_cols = [date_columns]
    else:
        date_cols = date_columns
    is_date_column = bool(date_cols) and series.name in date_cols

    non_null = series.dropna()
    if non_null.empty:
        # Nothing to infer from: default to Text, with NaN rendered as "".
        return Text(), series.apply(SchemaInference.safe_str_conversion)

    # --- Already a datetime dtype ---
    if pd.api.types.is_datetime64_any_dtype(series):
        if series.dt.tz is None:
            # Naive timestamps are made timezone-aware (UTC).
            return DateTime(timezone=True), pd.to_datetime(series, utc=True)
        return DateTime(timezone=True), series

    sample = non_null.iloc[0]

    # --- JSON-like objects (decided by the first non-null value only) ---
    if isinstance(sample, (dict, list)):
        return JSON(), series

    # --- Boolean ---
    if pd.api.types.is_bool_dtype(series):
        return Boolean(), series

    true_vals = {"true", "t"}
    bool_tokens = true_vals | {"false", "f"}
    if non_null.astype(str).str.lower().str.strip().isin(bool_tokens).all():
        normalized = series.astype(str).str.lower().str.strip()
        return Boolean(), normalized.isin(true_vals)

    # --- Datetime, before numeric ---
    # Detection is deterministic for a given series, so it is attempted at
    # most once; a failed early attempt is not repeated further down.
    datetime_attempted = False
    if infer_dates or is_date_column:
        # Cheap shape test for "YYYY-MM-DD HH:MM:SS" / "YYYY-MM-DDTHH:MM:SS".
        looks_like_timestamp = (
            isinstance(sample, str)
            and len(sample) >= 19
            and sample[4] == "-"
            and sample[7] == "-"
            and sample[10] in (" ", "T")
            and sample[13] == ":"
            and sample[16] == ":"
        )
        # Explicitly requested date columns are tried even without the shape.
        if looks_like_timestamp or is_date_column:
            datetime_attempted = True
            converted_series, is_datetime = SchemaInference.detect_and_convert_datetime(series)
            if is_datetime:
                return DateTime(timezone=True), converted_series

    # --- Numeric ---
    try:
        numeric_series = pd.to_numeric(series, errors="raise")
    except (ValueError, TypeError):
        # Not numeric; fall through to the remaining checks.
        numeric_series = None

    if numeric_series is not None:
        observed = numeric_series.dropna()
        # Whole-valued data (including floats such as 3.0) is stored as a
        # nullable integer so that missing values survive the cast.
        if (observed % 1 == 0).all():
            return Integer(), numeric_series.astype("Int64")
        return Float(), numeric_series.astype(float)

    # --- Final datetime attempt, for values the shape test did not flag ---
    if infer_dates and not is_date_column and not datetime_attempted:
        converted_series, is_datetime = SchemaInference.detect_and_convert_datetime(series)
        if is_datetime:
            return DateTime(timezone=True), converted_series

    # --- Fallback to text ---
    return Text(), series.apply(SchemaInference.safe_str_conversion)
@staticmethod
def convert_dataframe(
    df: pd.DataFrame,
    infer_dates: bool = True,
    date_columns: Optional[Union[str, list[str]]] = None,
    case: str = "snake",
    truncate_limit: int = 55,
) -> "SchemaConversionResult":
    """
    Infer the SQLAlchemy type for each column and convert the DataFrame accordingly.

    Empty strings are treated as missing, columns with no values at all are
    dropped, and column names are cleaned. The input DataFrame itself is not
    modified; all work happens on a copy.

    Args:
        df (pd.DataFrame): The input DataFrame.
        infer_dates (bool): If True, attempts to infer datetime columns.
        date_columns (str or list of str): Columns that should always be parsed
            as dates. Either the original or the cleaned column name is accepted.
        case (str): The naming convention to apply. Default is "snake".
        truncate_limit (int): The maximum length of the column names. Default is 55.

    Returns:
        SchemaConversionResult: Object containing the converted DataFrame, the
        column -> SQLAlchemy type map, and the original -> cleaned column name
        mapping (for the columns that were kept).
    """
    # replace()/dropna() without inplace return new frames, so the caller's
    # DataFrame keeps its empty strings and its all-empty columns.
    working = df.replace("", pd.NA).dropna(axis=1, how="all")

    original_columns = working.columns.tolist()
    cleaned_df = SchemaInference.clean_dataframe_names(
        working.copy(), case=case, truncate_limit=truncate_limit
    )
    renamed_columns_mapping = dict(zip(original_columns, cleaned_df.columns))

    if date_columns is None:
        requested_dates = []
    elif isinstance(date_columns, str):
        requested_dates = [date_columns]
    else:
        requested_dates = list(date_columns)
    # Inference below sees cleaned names only, so translate any original
    # labels; names that are already cleaned pass through unchanged.
    date_cols = [renamed_columns_mapping.get(name, name) for name in requested_dates]

    schema_map = {}
    for col in cleaned_df.columns:
        sql_type, converted = SchemaInference.infer_sqlalchemy_type(
            cleaned_df[col], infer_dates=infer_dates or (col in date_cols), date_columns=date_cols
        )
        schema_map[col] = sql_type

        # Coerce the column to the pandas representation of its inferred type.
        if isinstance(sql_type, DateTime):
            converted = pd.to_datetime(converted, errors="coerce", utc=True)
        elif isinstance(sql_type, Integer):
            converted = pd.to_numeric(converted, errors="coerce").astype("Int64")
        elif isinstance(sql_type, Float):
            converted = pd.to_numeric(converted, errors="coerce").astype(float)
        elif isinstance(sql_type, Boolean):
            converted = converted.astype(bool)
        elif isinstance(sql_type, JSON):
            converted = converted.apply(SchemaInference.safe_json_conversion)
        else:
            converted = converted.apply(SchemaInference.safe_str_conversion)
        cleaned_df[col] = converted

    return SchemaConversionResult(
        df=cleaned_df, schema_map=schema_map, renamed_columns_mapping=renamed_columns_mapping
    )
@staticmethod
def clean_dataframe_names(df: pd.DataFrame, case: str = "snake", truncate_limit: int = 55) -> pd.DataFrame:
    """
    Clean the column names and index names of a DataFrame using pyjanitor's clean_names.

    Works with both a regular Index and a MultiIndex; unnamed index levels
    stay unnamed.

    Args:
        df (pd.DataFrame): The DataFrame whose column and index names should be cleaned.
        case (str): The naming convention to apply. Default is "snake".
        truncate_limit (int): The maximum length of the column names. Default is 55.

    Returns:
        pd.DataFrame: A new DataFrame with cleaned column and index names.

    Raises:
        ImportError: If pyjanitor is not installed.
    """
    try:
        # Imported for its side effect: it makes ``clean_names`` available on
        # DataFrame objects.
        import janitor  # noqa: F401
    except ImportError as exc:
        raise ImportError(
            "pyjanitor is required to clean DataFrame names; install it with `pip install pyjanitor`."
        ) from exc

    def _clean_label(label: object) -> object:
        # Apply exactly the same rules as for columns by making the label the
        # only column of a throwaway frame and reading the cleaned name back.
        probe = pd.DataFrame({label: [0]})
        return probe.clean_names(case_type=case, truncate_limit=truncate_limit).columns[0]

    cleaned_df = df.clean_names(case_type=case, truncate_limit=truncate_limit)

    # ``names`` has one entry for a plain Index and one per level for a
    # MultiIndex, so a single code path covers both.
    index_names = list(df.index.names)
    if any(name is not None for name in index_names):
        new_names = [None if name is None else _clean_label(name) for name in index_names]
        # set_names returns a new Index, so an Index object that might be
        # shared with the input frame is never renamed in place.
        cleaned_df.index = cleaned_df.index.set_names(new_names)

    return cleaned_df
class SchemaConversionResult:
    """Container for the outcome of converting a DataFrame to a typed schema."""

    def __init__(self, df: pd.DataFrame, schema_map: dict[str, Any], renamed_columns_mapping: dict[str, str]):
        """
        Store the three products of a schema conversion.

        Args:
            df (pd.DataFrame): The converted DataFrame (exposed as ``dataframe``)
            schema_map (dict): Dictionary mapping columns to SQLAlchemy types
            renamed_columns_mapping (dict): Dictionary mapping original column names to new column names
        """
        self.dataframe, self.schema_map, self.renamed_columns_mapping = (
            df,
            schema_map,
            renamed_columns_mapping,
        )

    def __iter__(self) -> Iterator[Any]:
        """
        Iterate over the stored values so the result can be tuple-unpacked.

        Yields the DataFrame, the schema map and the renamed-columns mapping,
        in that order.
        """
        ordered_parts = (self.dataframe, self.schema_map, self.renamed_columns_mapping)
        return iter(ordered_parts)