Source code for aequilibrae.utils.db_utils

import sqlite3
from dataclasses import dataclass
from os import PathLike
from pathlib import Path
from sqlite3 import Connection, connect
from typing import Union

import pandas as pd
from aequilibrae import logger


[docs] class AequilibraEConnection(sqlite3.Connection): """ This custom factory class intends to solve the issue of premature commits when trying to use manual transaction control. After ``manual_transaction`` is called, context manager enters and exits are tracked via their depth, the ``sqlite3.Connection`` is placed into manual transaction control and a transaction is started. If another transaction is already in progress an RuntimeError is raised. When exiting with depth == 0, the normal context manager enter and exit is called. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.__manual_transaction: bool = False self.__depth: int = 0 self.__isolation_level = self.isolation_level
[docs] def manual_transaction(self): if self.__manual_transaction: raise RuntimeError( "cannot start a manual transaction while another manual transaction is already in progress" ) elif self.in_transaction: raise RuntimeError("cannot start a manual transaction while in another transaction") logger.debug("Manual transaction control enabled") self.__depth = 0 self.__manual_transaction = True self.__isolation_level = self.isolation_level self.isolation_level = None self.execute("BEGIN") return self
def __enter__(self): if self.__manual_transaction: self.__depth += 1 return super().__enter__() if self.__depth == 1 else self else: return super().__enter__() def __exit__(self, exc_type, exc_value, traceback): if self.__manual_transaction: self.__depth -= 1 if self.__depth <= 0: self.__manual_transaction = False res = super().__exit__(exc_type, exc_value, traceback) self.isolation_level = self.__isolation_level return res else: return super().__exit__(exc_type, exc_value, traceback)
[docs] def list_tables_in_db(conn: Connection): sql = "SELECT name FROM sqlite_master WHERE type ='table'" table_list = sorted([x[0].lower() for x in conn.execute(sql).fetchall() if "idx_" not in x[0].lower()]) return table_list
[docs] def safe_connect(filepath: PathLike, missing_ok=False): if Path(filepath).exists() or missing_ok or str(filepath) == ":memory:": return connect(filepath, factory=AequilibraEConnection) raise FileNotFoundError(f"Attempting to open non-existent SQLite database: {filepath}")
[docs] class commit_and_close: """A context manager for sqlite connections which closes and commits.""" def __init__(self, db: Union[str, Path, Connection], commit: bool = True, missing_ok: bool = False, spatial=False): """ :Arguments: **db** (:obj:`Union[str, Path, Connection]`): The database (filename or connection) to be managed **commit** (:obj:`bool`): Boolean indicating if a commit/rollback should be attempted on closing **missing_ok** (:obj:`bool`): Boolean indicating that the db is not expected to exist yet """ from aequilibrae.utils.spatialite_utils import connect_spatialite, load_spatialite_extension if spatial: if isinstance(db, Connection): load_spatialite_extension(db) self.conn = db elif not isinstance(db, (str, PathLike)): raise Exception("You must provide a database path to connect to spatialite") else: self.conn = connect_spatialite(db, missing_ok) elif isinstance(db, (str, PathLike)): self.conn = safe_connect(db, missing_ok) else: self.conn = db self.commit = commit def __enter__(self): return self.conn def __exit__(self, err_typ, err_value, traceback): if self.commit: if err_typ is None: self.conn.commit() else: self.conn.rollback() self.conn.close()
[docs] def read_and_close(filepath, spatial=False): """A context manager for sqlite connections (alias for `commit_and_close(db,commit=False))`.""" return commit_and_close(filepath, commit=False, spatial=spatial)
[docs] def read_sql(sql, filepath, **kwargs): with read_and_close(filepath) as conn: return pd.read_sql(sql, conn, **kwargs)
[docs] def has_table(conn, table_name): sql = f"SELECT name FROM sqlite_master WHERE type='table' AND name like '{table_name}';" return len(conn.execute(sql).fetchall()) > 0
[docs] @dataclass class ColumnDef: idx: int name: str type: str not_null: bool default: str is_pk: bool
[docs] def get_schema(conn, table_name): rv = [ColumnDef(*e) for e in conn.execute(f"PRAGMA table_info({table_name});").fetchall()] return {e.name: e for e in rv}
[docs] def list_columns(conn, table_name): return list(get_schema(conn, table_name).keys())
[docs] def has_column(conn, table_name, col_name): return col_name in get_schema(conn, table_name)
[docs] def add_column_unless_exists(conn, table_name, col_name, col_type, constraints=None): if not has_column(conn, table_name, col_name): add_column(conn, table_name, col_name, col_type, constraints)
[docs] def add_column(conn, table_name, col_name, col_type, constraints=None): sql = f"ALTER TABLE {table_name} ADD {col_name} {col_type} {constraints};" conn.execute(sql)