diff --git a/Doc/library/sqlite3.rst b/Doc/library/sqlite3.rst index d1f7a6f120620b..410e1443ebb189 100644 --- a/Doc/library/sqlite3.rst +++ b/Doc/library/sqlite3.rst @@ -521,6 +521,43 @@ Connection Objects f.write('%s\n' % line) + .. method:: backup(filename, *, pages=0, progress=None, name="main") + + This method makes a backup of a SQLite database into the mandatory argument + *filename*, even while it's being accessed by other clients, or concurrently by + the same connection. + + By default, or when *pages* is either ``0`` or a negative integer, the entire + database is copied in a single step; otherwise the method performs a loop + copying up to the specified *pages* at a time. + + If *progress* is specified, it must either be ``None`` or a callable object that + will be executed at each iteration with three integer arguments, respectively + the *status* of the last iteration, the *remaining* number of pages still to be + copied and the *total* number of pages. + + The *name* argument specifies the database name that will be copied: it must be + a string containing either ``"main"``, the default, to indicate the main + database, ``"temp"`` to indicate the temporary database or the name specified + after the ``AS`` keyword in an ``ATTACH`` statement for an attached database. + + Example:: + + # Copy an existing database into another file + import sqlite3 + + def progress(status, remaining, total): + print(f"Copied {total-remaining} of {total} pages...") + + con = sqlite3.connect('existing_db.db') + con.backup('copy_of_existing_db.db', 1, progress) + + .. note:: This is available only when the underlying SQLite library is at + version 3.6.11 or higher. + + .. versionadded:: 3.7 + + .. _sqlite3-cursor-objects: Cursor Objects diff --git a/Lib/sqlite3/test/backup.py b/Lib/sqlite3/test/backup.py new file mode 100644 index 00000000000000..f121b538b721c5 --- /dev/null +++ b/Lib/sqlite3/test/backup.py @@ -0,0 +1,135 @@ +import os +import sqlite3 as sqlite +from tempfile import NamedTemporaryFile +import unittest + +@unittest.skipIf(sqlite.sqlite_version_info < (3, 6, 11), "Backup API not supported") +class BackupTests(unittest.TestCase): + def setUp(self): + cx = self.cx = sqlite.connect(":memory:") + cx.execute('CREATE TABLE foo (key INTEGER)') + cx.executemany('INSERT INTO foo (key) VALUES (?)', [(3,), (4,)]) + cx.commit() + + def tearDown(self): + self.cx.close() + + def testBackup(self, bckfn): + cx = sqlite.connect(bckfn) + result = cx.execute("SELECT key FROM foo ORDER BY key").fetchall() + self.assertEqual(result[0][0], 3) + self.assertEqual(result[1][0], 4) + + def CheckKeywordOnlyArgs(self): + with self.assertRaises(TypeError): + self.cx.backup('foo', 1) + + def CheckSimple(self): + with NamedTemporaryFile(suffix='.sqlite') as bckfn: + self.cx.backup(bckfn.name) + self.testBackup(bckfn.name) + + def CheckProgress(self): + journal = [] + + def progress(status, remaining, total): + journal.append(status) + + with NamedTemporaryFile(suffix='.sqlite') as bckfn: + self.cx.backup(bckfn.name, pages=1, progress=progress) + self.testBackup(bckfn.name) + + self.assertEqual(len(journal), 2) + self.assertEqual(journal[0], sqlite.SQLITE_OK) + self.assertEqual(journal[1], sqlite.SQLITE_DONE) + + def CheckProgressAllPagesAtOnce_0(self): + journal = [] + + def progress(status, remaining, total): + journal.append(remaining) + + with NamedTemporaryFile(suffix='.sqlite') as bckfn: + self.cx.backup(bckfn.name, progress=progress) + self.testBackup(bckfn.name) + + self.assertEqual(len(journal), 1) + self.assertEqual(journal[0], 0) + + def CheckProgressAllPagesAtOnce_1(self): + journal = [] + + def progress(status, remaining, total): + journal.append(remaining) + + with NamedTemporaryFile(suffix='.sqlite') as bckfn: + self.cx.backup(bckfn.name, pages=-1, progress=progress) + self.testBackup(bckfn.name) + + self.assertEqual(len(journal), 1) + self.assertEqual(journal[0], 0) + + def CheckNonCallableProgress(self): + with NamedTemporaryFile(suffix='.sqlite') as bckfn: + with self.assertRaises(TypeError) as err: + self.cx.backup(bckfn.name, pages=1, progress='bar') + self.assertEqual(str(err.exception), 'progress argument must be a callable') + + def CheckModifyingProgress(self): + journal = [] + + def progress(status, remaining, total): + if not journal: + self.cx.execute('INSERT INTO foo (key) VALUES (?)', (remaining+1000,)) + self.cx.commit() + journal.append(remaining) + + with NamedTemporaryFile(suffix='.sqlite') as bckfn: + self.cx.backup(bckfn.name, pages=1, progress=progress) + self.testBackup(bckfn.name) + + cx = sqlite.connect(bckfn.name) + result = cx.execute("SELECT key FROM foo" + " WHERE key >= 1000" + " ORDER BY key").fetchall() + self.assertEqual(result[0][0], 1001) + + self.assertEqual(len(journal), 3) + self.assertEqual(journal[0], 1) + self.assertEqual(journal[1], 1) + self.assertEqual(journal[2], 0) + + def CheckFailingProgress(self): + def progress(status, remaining, total): + raise SystemError('nearly out of space') + + with NamedTemporaryFile(suffix='.sqlite', delete=False) as bckfn: + with self.assertRaises(SystemError) as err: + self.cx.backup(bckfn.name, progress=progress) + self.assertEqual(str(err.exception), 'nearly out of space') + self.assertFalse(os.path.exists(bckfn.name)) + + def CheckDatabaseSourceName(self): + with NamedTemporaryFile(suffix='.sqlite', delete=False) as bckfn: + self.cx.backup(bckfn.name, name='main') + self.cx.backup(bckfn.name, name='temp') + with self.assertRaises(sqlite.OperationalError): + self.cx.backup(bckfn.name, name='non-existing') + self.assertFalse(os.path.exists(bckfn.name)) + self.cx.execute("ATTACH DATABASE ':memory:' AS attached_db") + self.cx.execute('CREATE TABLE attached_db.foo (key INTEGER)') + self.cx.executemany('INSERT INTO attached_db.foo (key) VALUES (?)', [(3,), (4,)]) + self.cx.commit() + with NamedTemporaryFile(suffix='.sqlite') as bckfn: + self.cx.backup(bckfn.name, name='attached_db') + self.testBackup(bckfn.name) + +def suite(): + return unittest.TestSuite(unittest.makeSuite(BackupTests, "Check")) + +def test(): + runner = unittest.TextTestRunner() + runner.run(suite()) + +if __name__ == "__main__": + test() diff --git a/Lib/test/test_sqlite.py b/Lib/test/test_sqlite.py index adfcd9994575b3..9564da35193f1f 100644 --- a/Lib/test/test_sqlite.py +++ b/Lib/test/test_sqlite.py @@ -7,7 +7,7 @@ import sqlite3 from sqlite3.test import (dbapi, types, userfunctions, factory, transactions, hooks, regression, - dump) + dump, backup) def load_tests(*args): if test.support.verbose: @@ -18,7 +18,8 @@ def load_tests(*args): userfunctions.suite(), factory.suite(), transactions.suite(), hooks.suite(), regression.suite(), - dump.suite()]) + dump.suite(), + backup.suite()]) if __name__ == "__main__": unittest.main() diff --git a/Misc/NEWS b/Misc/NEWS index 4f19e75aeaa8fc..9cd90f7fd99ff0 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -10,6 +10,9 @@ What's New in Python 3.7.0 alpha 1? Core and Builtins ----------------- +- bpo-27645: sqlite3.Connection now exposes a backup() method, if the underlying SQLite + library is at version 3.6.11 or higher. Patch by Lele Gaifax. + - bpo-28598: Support __rmod__ for subclasses of str being called before str.__mod__. Patch by Martijn Pieters. diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index 37b45f330b3493..c0935f715bddc3 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -21,6 +21,12 @@ * 3. This notice may not be removed or altered from any source distribution. */ +#ifdef HAVE_UNISTD_H +#include +#else +extern int unlink(const char *); +#endif + #include "cache.h" #include "module.h" #include "structmember.h" @@ -41,6 +47,10 @@ #endif #endif +#if SQLITE_VERSION_NUMBER >= 3006011 +#define HAVE_BACKUP_API +#endif + _Py_IDENTIFIER(cursor); static const char * const begin_statements[] = { @@ -1477,6 +1487,112 @@ pysqlite_connection_iterdump(pysqlite_Connection* self, PyObject* args) return retval; } +#ifdef HAVE_BACKUP_API +static PyObject * +pysqlite_connection_backup(pysqlite_Connection* self, PyObject* args, PyObject* kwds) +{ + char* filename; + int pages = -1; + PyObject* progress = Py_None; + char* name = "main"; + PyObject* retval = NULL; + int rc; + int cberr = 0; + sqlite3 *bckconn; + sqlite3_backup *bckhandle; + static char *keywords[] = {"filename", "pages", "progress", "name", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "s|$iOs:backup", keywords, + &filename, &pages, &progress, &name)) { + goto finally; + } + + if (progress != Py_None && !PyCallable_Check(progress)) { + PyErr_SetString(PyExc_TypeError, "progress argument must be a callable"); + goto finally; + } + + if (pages == 0) { + pages = -1; + } + + Py_BEGIN_ALLOW_THREADS + rc = sqlite3_open(filename, &bckconn); + Py_END_ALLOW_THREADS + + if (rc != SQLITE_OK) { + goto finally; + } + + Py_BEGIN_ALLOW_THREADS + bckhandle = sqlite3_backup_init(bckconn, "main", self->db, name); + Py_END_ALLOW_THREADS + + if (bckhandle) { + do { + Py_BEGIN_ALLOW_THREADS + rc = sqlite3_backup_step(bckhandle, pages); + Py_END_ALLOW_THREADS + + if (progress != Py_None) { + if (!PyObject_CallFunction(progress, "iii", rc, + sqlite3_backup_remaining(bckhandle), + sqlite3_backup_pagecount(bckhandle))) { + /* User's callback raised an error: interrupt the loop and + propagate it. */ + cberr = 1; + rc = -1; + } + } + + /* Sleep for 250ms if there are still further pages to copy and + the engine could not make any progress */ + if (rc == SQLITE_BUSY || rc == SQLITE_LOCKED) { + Py_BEGIN_ALLOW_THREADS + sqlite3_sleep(250); + Py_END_ALLOW_THREADS + } + } while (rc == SQLITE_OK || rc == SQLITE_BUSY || rc == SQLITE_LOCKED); + + Py_BEGIN_ALLOW_THREADS + rc = sqlite3_backup_finish(bckhandle); + Py_END_ALLOW_THREADS + } else { + rc = _pysqlite_seterror(bckconn, NULL); + } + + if (cberr == 0 && rc != SQLITE_OK) { + /* We cannot use _pysqlite_seterror() here because the backup APIs do + not set the error status on the connection object, but rather on + the backup handle. */ + if (rc == SQLITE_NOMEM) { + (void)PyErr_NoMemory(); + } else { + PyErr_SetString(pysqlite_OperationalError, sqlite3_errstr(rc)); + } + } + + Py_BEGIN_ALLOW_THREADS + sqlite3_close(bckconn); + Py_END_ALLOW_THREADS + + if (cberr == 0 && rc == SQLITE_OK) { + Py_INCREF(Py_None); + retval = Py_None; + } else { + /* Remove the probably incomplete/invalid backup */ + if (unlink(filename) < 0) { + /* FIXME: this should probably be chained to the outstanding + exception */ + return PyErr_SetFromErrno(PyExc_OSError); + } + } + +finally: + return retval; +} +#endif + static PyObject * pysqlite_connection_create_collation(pysqlite_Connection* self, PyObject* args) { @@ -1649,6 +1765,10 @@ static PyMethodDef connection_methods[] = { PyDoc_STR("Abort any pending database operation. Non-standard.")}, {"iterdump", (PyCFunction)pysqlite_connection_iterdump, METH_NOARGS, PyDoc_STR("Returns iterator to the dump of the database in an SQL text format. Non-standard.")}, + #ifdef HAVE_BACKUP_API + {"backup", (PyCFunction)pysqlite_connection_backup, METH_VARARGS | METH_KEYWORDS, + PyDoc_STR("Makes a backup of the database. Non-standard.")}, + #endif {"__enter__", (PyCFunction)pysqlite_connection_enter, METH_NOARGS, PyDoc_STR("For context manager. Non-standard.")}, {"__exit__", (PyCFunction)pysqlite_connection_exit, METH_VARARGS, diff --git a/Modules/_sqlite/module.c b/Modules/_sqlite/module.c index 72c3a7f34fca0e..07eeaa2a696a10 100644 --- a/Modules/_sqlite/module.c +++ b/Modules/_sqlite/module.c @@ -315,6 +315,9 @@ static const IntConstantPair _int_constants[] = { #endif #if SQLITE_VERSION_NUMBER >= 3008003 {"SQLITE_RECURSIVE", SQLITE_RECURSIVE}, +#endif +#if SQLITE_VERSION_NUMBER >= 3006011 + {"SQLITE_DONE", SQLITE_DONE}, #endif {(char*)NULL, 0} };