# Source code for ligo.skymap.util.sqlite

#
# Copyright (C) 2018-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Tools for reading and writing SQLite databases."""
import copyreg
import sqlite3
# Keep a reference to the builtin ``open``: this module defines its own
# ``open`` function below, which shadows the builtin name at module scope.
_open = open


def _open_a(string):
    """Connect to the SQLite database at `string` using sqlite3's defaults.

    The connection is created with ``check_same_thread=False``, so sqlite3
    does not restrict its use to the thread that created it.
    """
    connection = sqlite3.connect(string, check_same_thread=False)
    return connection


def _open_r(string):
    """Connect read-only to the SQLite database at `string`.

    The database is opened through a ``file:`` URI with ``mode=ro``. The path
    is percent-encoded before it is embedded in the URI so that characters
    with a special meaning in URIs (``?``, ``#``, ``%``) are treated as part
    of the filename instead of being parsed as a query string, a fragment, or
    an escape sequence.

    The connection is created with ``check_same_thread=False``.
    """
    from urllib.parse import quote

    # ``str()`` mirrors the original ``'{}'.format(string)`` so that path-like
    # objects are still accepted.
    uri = 'file:{}?mode=ro'.format(quote(str(string)))
    return sqlite3.connect(uri, check_same_thread=False, uri=True)


def _open_w(string):
    """Clobber the file at `string`, then connect to it as a fresh database.

    The file is first opened in binary write mode and closed again straight
    away, which creates it or truncates it to zero length, so no previous
    contents survive. The connection is created with
    ``check_same_thread=False``.
    """
    # ``_open`` is the builtin ``open``; the handle is only needed for the
    # truncation side effect.
    _open(string, 'wb').close()
    connection = sqlite3.connect(string, check_same_thread=False)
    return connection


# Dispatch table from single-letter mode flag to the function that opens a
# connection in that mode.
_openers = {
    'a': _open_a,
    'r': _open_r,
    'w': _open_w,
}


def open(string, mode):
    """Open an SQLite database with an `open`-style mode flag.

    Parameters
    ----------
    string : str
        Path of the SQLite database file
    mode : {'r', 'w', 'a'}
        Access mode: read only, clobber and overwrite, or modify in place.

    Returns
    -------
    connection : `sqlite3.Connection`

    Raises
    ------
    ValueError
        If the filename is invalid (e.g. ``/dev/stdin``), or if the requested
        mode is invalid
    OSError
        If the database could not be opened in the specified mode

    Examples
    --------
    >>> import tempfile
    >>> import os
    >>> with tempfile.TemporaryDirectory() as d:
    ...     open(os.path.join(d, 'test.sqlite'), 'w')
    ...
    <sqlite3.Connection object at 0x...>

    >>> with tempfile.TemporaryDirectory() as d:
    ...     open(os.path.join(d, 'test.sqlite'), 'r')
    ...
    Traceback (most recent call last):
      ...
    OSError: Failed to open database ...

    >>> open('/dev/stdin', 'r')
    Traceback (most recent call last):
      ...
    ValueError: Cannot open stdin/stdout as an SQLite database

    >>> open('test.sqlite', 'x')
    Traceback (most recent call last):
      ...
    ValueError: Invalid mode "x". Must be one of "arw".

    """
    if string in {'-', '/dev/stdin', '/dev/stdout'}:
        raise ValueError('Cannot open stdin/stdout as an SQLite database')
    try:
        opener = _openers[mode]
    except KeyError:
        # The KeyError is only an artifact of the table lookup; suppress it so
        # the traceback shows just the user-facing error.
        raise ValueError('Invalid mode "{}". Must be one of "{}".'.format(
            mode, ''.join(sorted(_openers.keys())))) from None
    try:
        return opener(string)
    except (OSError, sqlite3.Error) as e:
        # Normalize filesystem and SQLite failures to OSError, keeping the
        # underlying error attached as the explicit cause.
        raise OSError(
            'Failed to open database {}: {}'.format(string, e)) from e
def get_filename(connection):
    r"""Get the name of the file associated with an SQLite connection.

    Parameters
    ----------
    connection : `sqlite3.Connection`
        The database connection

    Returns
    -------
    str
        The name of the file that contains the SQLite database

    Raises
    ------
    RuntimeError
        If the connection does not have exactly one attached database

    Examples
    --------
    >>> import tempfile
    >>> import os
    >>> with tempfile.TemporaryDirectory() as d:
    ...     with sqlite3.connect(os.path.join(d, 'test.sqlite')) as db:
    ...         print(get_filename(db))
    ...
    /.../test.sqlite

    >>> with tempfile.TemporaryDirectory() as d:
    ...     with sqlite3.connect(os.path.join(d, 'test1.sqlite')) as db1, \
    ...          sqlite3.connect(os.path.join(d, 'test2.sqlite')) as db2:
    ...         filename = get_filename(db1)
    ...         db2.execute('ATTACH DATABASE "{}" AS db2'.format(filename))
    ...         print(get_filename(db2))
    ...
    Traceback (most recent call last):
      ...
    RuntimeError: Expected exactly one attached database

    """
    rows = connection.execute('pragma database_list').fetchall()
    # Check the row count explicitly instead of relying on a failed tuple
    # unpacking, so the RuntimeError carries no unrelated ValueError context.
    if len(rows) != 1:
        raise RuntimeError('Expected exactly one attached database')
    # Each row is unpacked as three fields; only the last one (the file name)
    # is used.
    _, _, filename = rows[0]
    return filename
def pickle_sqlite3_connection(obj):
    """Reduce an SQLite connection to a form that can be pickled.

    Only the database file name is carried across, not the connection state.
    Unpickling calls ``_open_a`` on the file name reported by
    `get_filename`, which creates a new connection to the same file.

    Parameters
    ----------
    obj : `sqlite3.Connection`
        The connection to reduce

    Returns
    -------
    tuple
        A ``(callable, args)`` pair in the form expected by `copyreg`

    Raises
    ------
    RuntimeError
        Propagated from `get_filename` if the connection does not have
        exactly one attached database
    """
    return _open_a, (get_filename(obj),)


# Register the reduction function so that sqlite3.Connection objects can be
# pickled.
copyreg.pickle(sqlite3.Connection, pickle_sqlite3_connection)