diff --git a/Lib/sqlite3/test/dbapi.py b/Lib/sqlite3/test/dbapi.py
index 39c9bf5b61..69ffc354c1 100644
--- a/Lib/sqlite3/test/dbapi.py
+++ b/Lib/sqlite3/test/dbapi.py
@@ -92,6 +92,10 @@ def test_shared_cache_deprecated(self):
sqlite.enable_shared_cache(enable)
self.assertIn("dbapi.py", cm.filename)
+ def test_complete_statement(self):
+ self.assertFalse(sqlite.complete_statement("select t"))
+ self.assertTrue(sqlite.complete_statement("create table t(t);"))
+
class ConnectionTests(unittest.TestCase):
@@ -191,6 +195,24 @@ def test_open_uri(self):
with self.assertRaises(sqlite.OperationalError):
cx.execute('insert into test(id) values(1)')
+ def test_interrupt_on_closed_db(self):
+ cx = sqlite.connect(":memory:")
+ cx.close()
+ with self.assertRaises(sqlite.ProgrammingError):
+ cx.interrupt()
+
+ def test_interrupt(self):
+ self.assertIsNone(self.cx.interrupt())
+
+ def test_drop_unused_refs(self):
+ for n in range(500):
+ cu = self.cx.execute(f"select {n}")
+ self.assertEqual(cu.fetchone()[0], n)
+
+ def test_database_keyword(self):
+ with sqlite.connect(database=":memory:") as cx:
+ self.assertEqual(type(cx), sqlite.Connection)
+
class CursorTests(unittest.TestCase):
def setUp(self):
@@ -522,6 +544,10 @@ def test_last_row_id_insert_o_r(self):
]
self.assertEqual(results, expected)
+ def test_same_query_in_multiple_cursors(self):
+ cursors = [self.cx.execute("select 1") for _ in range(3)]
+ for cu in cursors:
+ self.assertEqual(cu.fetchall(), [(1,)])
class ThreadTests(unittest.TestCase):
def setUp(self):
@@ -680,6 +706,21 @@ def run(cur, errors):
if len(errors) > 0:
self.fail("\n".join(errors))
+ def test_dont_check_same_thread(self):
+ def run(con, err):
+ try:
+ cur = con.execute("select 1")
+ except sqlite.Error:
+ err.append("multi-threading not allowed")
+
+ con = sqlite.connect(":memory:", check_same_thread=False)
+ err = []
+ t = threading.Thread(target=run, kwargs={"con": con, "err": err})
+ t.start()
+ t.join()
+ self.assertEqual(len(err), 0, "\n".join(err))
+
+
class ConstructorTests(unittest.TestCase):
def test_date(self):
d = sqlite.Date(2004, 10, 28)
diff --git a/Lib/sqlite3/test/factory.py b/Lib/sqlite3/test/factory.py
index 8764284975..7faa9ac8c1 100644
--- a/Lib/sqlite3/test/factory.py
+++ b/Lib/sqlite3/test/factory.py
@@ -123,6 +123,8 @@ def test_sqlite_row_index(self):
row[-3]
with self.assertRaises(IndexError):
row[2**1000]
+ with self.assertRaises(IndexError):
+ row[complex()] # index must be int or string
def test_sqlite_row_index_unicode(self):
self.con.row_factory = sqlite.Row
diff --git a/Lib/sqlite3/test/types.py b/Lib/sqlite3/test/types.py
index 2370dd1693..0b4a6a87b4 100644
--- a/Lib/sqlite3/test/types.py
+++ b/Lib/sqlite3/test/types.py
@@ -356,9 +356,9 @@ def test_cursor_description_cte(self):
class ObjectAdaptationTests(unittest.TestCase):
+ @staticmethod
def cast(obj):
return float(obj)
- cast = staticmethod(cast)
def setUp(self):
self.con = sqlite.connect(":memory:")
@@ -379,6 +379,43 @@ def test_caster_is_used(self):
val = self.cur.fetchone()[0]
self.assertEqual(type(val), float)
+ def test_missing_adapter(self):
+ with self.assertRaises(sqlite.ProgrammingError):
+ sqlite.adapt(1.) # No float adapter registered
+
+ def test_missing_protocol(self):
+ with self.assertRaises(sqlite.ProgrammingError):
+ sqlite.adapt(1, None)
+
+ def test_defect_proto(self):
+ class DefectProto():
+ def __adapt__(self):
+ return None
+ with self.assertRaises(sqlite.ProgrammingError):
+ sqlite.adapt(1., DefectProto)
+
+ def test_defect_self_adapt(self):
+ class DefectSelfAdapt(float):
+ def __conform__(self, _):
+ return None
+ with self.assertRaises(sqlite.ProgrammingError):
+ sqlite.adapt(DefectSelfAdapt(1.))
+
+ def test_custom_proto(self):
+ class CustomProto():
+ def __adapt__(self):
+ return "adapted"
+ self.assertEqual(sqlite.adapt(1., CustomProto), "adapted")
+
+ def test_adapt(self):
+ val = 42
+ self.assertEqual(float(val), sqlite.adapt(val))
+
+ def test_adapt_alt(self):
+ alt = "other"
+ self.assertEqual(alt, sqlite.adapt(1., None, alt))
+
+
@unittest.skipUnless(zlib, "requires zlib")
class BinaryConverterTests(unittest.TestCase):
def convert(s):
diff --git a/Lib/sqlite3/test/userfunctions.py b/Lib/sqlite3/test/userfunctions.py
index 749ea049c8..0ed4a83ec4 100644
--- a/Lib/sqlite3/test/userfunctions.py
+++ b/Lib/sqlite3/test/userfunctions.py
@@ -21,10 +21,35 @@
# misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
+import contextlib
+import functools
+import io
import unittest
import unittest.mock
import sqlite3 as sqlite
+def with_tracebacks(strings):
+ """Convenience decorator for testing callback tracebacks."""
+ strings.append('Traceback')
+
+ def decorator(func):
+ @functools.wraps(func)
+ def wrapper(self, *args, **kwargs):
+ # First, run the test with traceback enabled.
+ sqlite.enable_callback_tracebacks(True)
+ buf = io.StringIO()
+ with contextlib.redirect_stderr(buf):
+ func(self, *args, **kwargs)
+ tb = buf.getvalue()
+ for s in strings:
+ self.assertIn(s, tb)
+
+ # Then run the test with traceback disabled.
+ sqlite.enable_callback_tracebacks(False)
+ func(self, *args, **kwargs)
+ return wrapper
+ return decorator
+
def func_returntext():
return "foo"
def func_returnunicode():
@@ -227,6 +252,7 @@ def test_func_return_long_long(self):
val = cur.fetchone()[0]
self.assertEqual(val, 1<<31)
+ @with_tracebacks(['func_raiseexception', '5/0', 'ZeroDivisionError'])
def test_func_exception(self):
cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm:
@@ -364,6 +390,7 @@ def test_aggr_no_finalize(self):
val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
+ @with_tracebacks(['__init__', '5/0', 'ZeroDivisionError'])
def test_aggr_exception_in_init(self):
cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm:
@@ -371,6 +398,7 @@ def test_aggr_exception_in_init(self):
val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")
+ @with_tracebacks(['step', '5/0', 'ZeroDivisionError'])
def test_aggr_exception_in_step(self):
cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm:
@@ -378,6 +406,7 @@ def test_aggr_exception_in_step(self):
val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")
+ @with_tracebacks(['finalize', '5/0', 'ZeroDivisionError'])
def test_aggr_exception_in_finalize(self):
cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm:
@@ -479,6 +508,14 @@ def authorizer_cb(action, arg1, arg2, dbname, source):
raise ValueError
return sqlite.SQLITE_OK
+ @with_tracebacks(['authorizer_cb', 'ValueError'])
+ def test_table_access(self):
+ super().test_table_access()
+
+ @with_tracebacks(['authorizer_cb', 'ValueError'])
+ def test_column_access(self):
+ super().test_table_access()
+
class AuthorizerIllegalTypeTests(AuthorizerTests):
@staticmethod
def authorizer_cb(action, arg1, arg2, dbname, source):