diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index 1fe4d2132..32ed55075 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -94,6 +94,47 @@ void Connection::connect(const py::dict& attrs_before) { void Connection::disconnect() { if (_dbcHandle) { LOG("Disconnecting from database"); + + // CRITICAL FIX: Mark all child statement handles as implicitly freed + // When we free the DBC handle below, the ODBC driver will automatically free + // all child STMT handles. We need to tell the SqlHandle objects about this + // so they don't try to free the handles again during their destruction. + + // THREAD-SAFETY: Lock mutex to safely access _childStatementHandles + // This protects against concurrent allocStatementHandle() calls or GC finalizers + { + std::lock_guard lock(_childHandlesMutex); + + // First compact: remove expired weak_ptrs (they're already destroyed) + size_t originalSize = _childStatementHandles.size(); + _childStatementHandles.erase( + std::remove_if(_childStatementHandles.begin(), _childStatementHandles.end(), + [](const std::weak_ptr& wp) { return wp.expired(); }), + _childStatementHandles.end()); + + LOG("Compacted child handles: %zu -> %zu (removed %zu expired)", + originalSize, _childStatementHandles.size(), + originalSize - _childStatementHandles.size()); + + LOG("Marking %zu child statement handles as implicitly freed", + _childStatementHandles.size()); + for (auto& weakHandle : _childStatementHandles) { + if (auto handle = weakHandle.lock()) { + // SAFETY ASSERTION: Only STMT handles should be in this vector + // This is guaranteed by allocStatementHandle() which only creates STMT handles + // If this assertion fails, it indicates a serious bug in handle tracking + if (handle->type() != SQL_HANDLE_STMT) { + LOG_ERROR("CRITICAL: Non-STMT handle (type=%d) found in _childStatementHandles. " + "This will cause a handle leak!", handle->type()); + continue; // Skip marking to prevent leak + } + handle->markImplicitlyFreed(); + } + } + _childStatementHandles.clear(); + _allocationsSinceCompaction = 0; + } // Release lock before potentially slow SQLDisconnect call + SQLRETURN ret = SQLDisconnect_ptr(_dbcHandle->get()); checkError(ret); // triggers SQLFreeHandle via destructor, if last owner @@ -173,7 +214,36 @@ SqlHandlePtr Connection::allocStatementHandle() { SQLHANDLE stmt = nullptr; SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_STMT, _dbcHandle->get(), &stmt); checkError(ret); - return std::make_shared(static_cast(SQL_HANDLE_STMT), stmt); + auto stmtHandle = std::make_shared(static_cast(SQL_HANDLE_STMT), stmt); + + // THREAD-SAFETY: Lock mutex before modifying _childStatementHandles + // This protects against concurrent disconnect() or allocStatementHandle() calls, + // or GC finalizers running from different threads + { + std::lock_guard lock(_childHandlesMutex); + + // Track this child handle so we can mark it as implicitly freed when connection closes + // Use weak_ptr to avoid circular references and allow normal cleanup + _childStatementHandles.push_back(stmtHandle); + _allocationsSinceCompaction++; + + // Compact expired weak_ptrs only periodically to avoid O(n²) overhead + // This keeps allocation fast (O(1) amortized) while preventing unbounded growth + // disconnect() also compacts, so this is just for long-lived connections with many cursors + if (_allocationsSinceCompaction >= COMPACTION_INTERVAL) { + size_t originalSize = _childStatementHandles.size(); + _childStatementHandles.erase( + std::remove_if(_childStatementHandles.begin(), _childStatementHandles.end(), + [](const std::weak_ptr& wp) { return wp.expired(); }), + _childStatementHandles.end()); + _allocationsSinceCompaction = 0; + LOG("Periodic compaction: %zu -> %zu handles (removed %zu expired)", + originalSize, _childStatementHandles.size(), + originalSize - _childStatementHandles.size()); + } + } // Release lock + + return stmtHandle; } SQLRETURN Connection::setAttribute(SQLINTEGER attribute, py::object value) { @@ -308,7 +378,7 @@ bool Connection::reset() { disconnect(); return false; } - + // SQL_ATTR_RESET_CONNECTION does NOT reset the transaction isolation level. // Explicitly reset it to the default (SQL_TXN_READ_COMMITTED) to prevent // isolation level settings from leaking between pooled connection usages. @@ -320,7 +390,7 @@ bool Connection::reset() { disconnect(); return false; } - + updateLastUsed(); return true; } diff --git a/mssql_python/pybind/connection/connection.h b/mssql_python/pybind/connection/connection.h index d007106af..6c6f1e63c 100644 --- a/mssql_python/pybind/connection/connection.h +++ b/mssql_python/pybind/connection/connection.h @@ -5,10 +5,19 @@ #include "../ddbc_bindings.h" #include #include +#include // Represents a single ODBC database connection. // Manages connection handles. // Note: This class does NOT implement pooling logic directly. +// +// THREADING MODEL (per DB-API 2.0 threadsafety=1): +// - Connections should NOT be shared between threads in normal usage +// - However, _childStatementHandles is mutex-protected because: +// 1. Python GC/finalizers can run from any thread +// 2. Native code may release GIL during blocking ODBC calls +// 3. Provides safety if user accidentally shares connection +// - All accesses to _childStatementHandles are guarded by _childHandlesMutex class Connection { public: @@ -61,6 +70,22 @@ class Connection { std::chrono::steady_clock::time_point _lastUsed; std::wstring wstrStringBuffer; // wstr buffer for string attribute setting std::string strBytesBuffer; // string buffer for byte attributes setting + + // Track child statement handles to mark them as implicitly freed when connection closes + // Uses weak_ptr to avoid circular references and allow normal cleanup + // THREAD-SAFETY: All accesses must be guarded by _childHandlesMutex + std::vector> _childStatementHandles; + + // Counter for periodic compaction of expired weak_ptrs + // Compact every N allocations to avoid O(n²) overhead in hot path + // THREAD-SAFETY: Protected by _childHandlesMutex + size_t _allocationsSinceCompaction = 0; + static constexpr size_t COMPACTION_INTERVAL = 100; + + // Mutex protecting _childStatementHandles and _allocationsSinceCompaction + // Prevents data races between allocStatementHandle() and disconnect(), + // or concurrent GC finalizers running from different threads + mutable std::mutex _childHandlesMutex; }; class ConnectionHandle { diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index f49d860a8..2cf04fe0d 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -1144,6 +1144,21 @@ SQLSMALLINT SqlHandle::type() const { return _type; } +void SqlHandle::markImplicitlyFreed() { + // SAFETY: Only STMT handles should be marked as implicitly freed. + // When a DBC handle is freed, the ODBC driver automatically frees all child STMT handles. + // Other handle types (ENV, DBC, DESC) are NOT automatically freed by parents. + // Calling this on wrong handle types will cause silent handle leaks. + if (_type != SQL_HANDLE_STMT) { + // Log error but don't throw - we're likely in cleanup/destructor path + LOG_ERROR("SAFETY VIOLATION: Attempted to mark non-STMT handle as implicitly freed. " + "Handle type=%d. This will cause handle leak. Only STMT handles are " + "automatically freed by parent DBC handles.", _type); + return; // Refuse to mark - let normal free() handle it + } + _implicitly_freed = true; +} + /* * IMPORTANT: Never log in destructors - it causes segfaults. * During program exit, C++ destructors may run AFTER Python shuts down. @@ -1169,16 +1184,19 @@ void SqlHandle::free() { return; } - // Always clean up ODBC resources, regardless of Python state + // CRITICAL FIX: Check if handle was already implicitly freed by parent handle + // When Connection::disconnect() frees the DBC handle, the ODBC driver automatically + // frees all child STMT handles. We track this state to avoid double-free attempts. + // This approach avoids calling ODBC functions on potentially-freed handles, which + // would cause use-after-free errors. + if (_implicitly_freed) { + _handle = nullptr; // Just clear the pointer, don't call ODBC functions + return; + } + + // Handle is valid and not implicitly freed, proceed with normal freeing SQLFreeHandle_ptr(_type, _handle); _handle = nullptr; - - // Only log if Python is not shutting down (to avoid segfault) - if (!pythonShuttingDown) { - // Don't log during destruction - even in normal cases it can be - // problematic If logging is needed, use explicit close() methods - // instead - } } } @@ -2893,7 +2911,6 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p // Cache decimal separator to avoid repeated system calls - for (SQLSMALLINT i = 1; i <= colCount; ++i) { SQLWCHAR columnName[256]; SQLSMALLINT columnNameLen; @@ -3615,8 +3632,6 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum columnInfos[col].processedColumnSize + 1; // +1 for null terminator } - - // Performance: Build function pointer dispatch table (once per batch) // This eliminates the switch statement from the hot loop - 10,000 rows × 10 // cols reduces from 100,000 switch evaluations to just 10 switch @@ -4033,8 +4048,8 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch lobColumns.push_back(i + 1); // 1-based } } - - // Initialized to 0 for LOB path counter; overwritten by ODBC in non-LOB path; + + // Initialized to 0 for LOB path counter; overwritten by ODBC in non-LOB path; SQLULEN numRowsFetched = 0; // If we have LOBs → fall back to row-by-row fetch + SQLGetData_wrap if (!lobColumns.empty()) { @@ -4066,7 +4081,7 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch LOG("FetchMany_wrap: Error when binding columns - SQLRETURN=%d", ret); return ret; } - + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)fetchSize, 0); SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0); diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 391903ef2..fd9e7db71 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -379,9 +379,24 @@ class SqlHandle { SQLSMALLINT type() const; void free(); + // Mark this handle as implicitly freed (freed by parent handle) + // This prevents double-free attempts when the ODBC driver automatically + // frees child handles (e.g., STMT handles when DBC handle is freed) + // + // SAFETY CONSTRAINTS: + // - ONLY call this on SQL_HANDLE_STMT handles + // - ONLY call this when the parent DBC handle is about to be freed + // - Calling on other handle types (ENV, DBC, DESC) will cause HANDLE LEAKS + // - The ODBC spec only guarantees automatic freeing of STMT handles by DBC parents + // + // Current usage: Connection::disconnect() marks all tracked STMT handles + // before freeing the DBC handle. + void markImplicitlyFreed(); + private: SQLSMALLINT _type; SQLHANDLE _handle; + bool _implicitly_freed = false; // Tracks if handle was freed by parent }; using SqlHandlePtr = std::shared_ptr; diff --git a/tests/test_016_connection_invalidation_segfault.py b/tests/test_016_connection_invalidation_segfault.py new file mode 100644 index 000000000..4ae07306a --- /dev/null +++ b/tests/test_016_connection_invalidation_segfault.py @@ -0,0 +1,305 @@ +""" +Test for connection invalidation segfault scenario (Issue: Use-after-free on statement handles) + +This test reproduces the segfault that occurred in SQLAlchemy's RealReconnectTest +when connection invalidation triggered automatic freeing of child statement handles +by the ODBC driver, followed by Python GC attempting to free the same handles again. + +The fix uses state tracking where Connection marks child handles as "implicitly freed" +before disconnecting, preventing SqlHandle::free() from calling ODBC functions on +already-freed handles. + +Background: +- When Connection::disconnect() frees a DBC handle, ODBC automatically frees all child STMT handles +- Python SqlHandle objects weren't aware of this implicit freeing +- GC later tried to free those handles again via SqlHandle::free(), causing use-after-free +- Fix: Connection tracks children in _childStatementHandles vector and marks them as + implicitly freed before DBC is freed +""" + +import gc +import pytest +from mssql_python import connect, DatabaseError, OperationalError + + +def test_connection_invalidation_with_multiple_cursors(conn_str): + """ + Test connection invalidation scenario that previously caused segfaults. + + This test: + 1. Creates a connection with multiple active cursors + 2. Executes queries on those cursors to create statement handles + 3. Simulates connection invalidation by closing the connection + 4. Forces garbage collection to trigger handle cleanup + 5. Verifies no segfault occurs during the cleanup process + + Previously, this would crash because: + - Closing connection freed the DBC handle + - ODBC driver automatically freed all child STMT handles + - Python GC later tried to free those same STMT handles + - Result: use-after-free crash (segfault) + + With the fix: + - Connection marks all child handles as "implicitly freed" before closing + - SqlHandle::free() checks the flag and skips ODBC calls if true + - Result: No crash, clean shutdown + """ + # Create connection + conn = connect(conn_str) + + # Create multiple cursors with statement handles + cursors = [] + for i in range(5): + cursor = conn.cursor() + cursor.execute("SELECT 1 AS id, 'test' AS name") + cursor.fetchall() # Fetch results to complete the query + cursors.append(cursor) + + # Close connection without explicitly closing cursors first + # This simulates the invalidation scenario where connection is lost + conn.close() + + # Force garbage collection to trigger cursor cleanup + # This is where the segfault would occur without the fix + cursors = None + gc.collect() + + # If we reach here without crashing, the fix is working + assert True + + +def test_connection_invalidation_without_cursor_close(conn_str): + """ + Test that cursors are properly cleaned up when connection is closed + without explicitly closing the cursors. + + This mimics the SQLAlchemy scenario where connection pools may + invalidate connections without first closing all cursors. + """ + conn = connect(conn_str) + + # Create cursors and execute queries + cursor1 = conn.cursor() + cursor1.execute("SELECT 1") + cursor1.fetchone() + + cursor2 = conn.cursor() + cursor2.execute("SELECT 2") + cursor2.fetchone() + + cursor3 = conn.cursor() + cursor3.execute("SELECT 3") + cursor3.fetchone() + + # Close connection with active cursors + conn.close() + + # Trigger GC - should not crash + del cursor1, cursor2, cursor3 + gc.collect() + + assert True + + +def test_repeated_connection_invalidation_cycles(conn_str): + """ + Test repeated connection invalidation cycles to ensure no memory + corruption or handle leaks occur across multiple iterations. + + This stress test simulates the scenario from SQLAlchemy's + RealReconnectTest which ran multiple invalidation tests in sequence. + """ + for iteration in range(10): + # Create connection + conn = connect(conn_str) + + # Create and use cursors + for cursor_num in range(3): + cursor = conn.cursor() + cursor.execute(f"SELECT {iteration} AS iteration, {cursor_num} AS cursor_num") + result = cursor.fetchone() + assert result[0] == iteration + assert result[1] == cursor_num + + # Close connection (invalidate) + conn.close() + + # Force GC after each iteration + gc.collect() + + # Final GC to clean up any remaining references + gc.collect() + assert True + + +def test_connection_close_with_uncommitted_transaction(conn_str): + """ + Test that closing a connection with an uncommitted transaction + properly cleans up statement handles without crashing. + """ + conn = connect(conn_str) + cursor = conn.cursor() + + try: + # Start a transaction + cursor.execute("CREATE TABLE #temp_test (id INT, name VARCHAR(50))") + cursor.execute("INSERT INTO #temp_test VALUES (1, 'test')") + # Don't commit - leave transaction open + + # Close connection without commit or rollback + conn.close() + + # Trigger GC + del cursor + gc.collect() + + assert True + except Exception as e: + pytest.fail(f"Unexpected exception during connection close: {e}") + + +def test_cursor_after_connection_invalidation_raises_error(conn_str): + """ + Test that attempting to use a cursor after connection is closed + raises an appropriate error rather than crashing. + """ + conn = connect(conn_str) + cursor = conn.cursor() + cursor.execute("SELECT 1") + cursor.fetchone() + + # Close connection + conn.close() + + # Attempting to execute on cursor should raise an error, not crash + with pytest.raises((DatabaseError, OperationalError)): + cursor.execute("SELECT 2") + + # GC should not crash + del cursor + gc.collect() + + +def test_multiple_connections_concurrent_invalidation(conn_str): + """ + Test that multiple connections can be invalidated concurrently + without interfering with each other's handle cleanup. + """ + connections = [] + all_cursors = [] + + # Create multiple connections with cursors + for conn_num in range(5): + conn = connect(conn_str) + connections.append(conn) + + for cursor_num in range(3): + cursor = conn.cursor() + cursor.execute(f"SELECT {conn_num} AS conn, {cursor_num} AS cursor_num") + cursor.fetchone() + all_cursors.append(cursor) + + # Close all connections + for conn in connections: + conn.close() + + # Verify we have cursors alive (keep them referenced until after connection close) + assert len(all_cursors) == 15 # 5 connections * 3 cursors each + + # Clear references and force GC + connections = None + all_cursors = None + gc.collect() + + assert True + + +def test_connection_invalidation_with_prepared_statements(conn_str): + """ + Test connection invalidation when cursors have prepared statements. + This ensures statement handles are properly marked as implicitly freed. + """ + conn = connect(conn_str) + + # Create cursor with parameterized query (prepared statement) + cursor = conn.cursor() + cursor.execute("SELECT ? AS value", (42,)) + result = cursor.fetchone() + assert result[0] == 42 + + # Execute another parameterized query + cursor.execute("SELECT ? AS name, ? AS age", ("John", 30)) + result = cursor.fetchone() + assert result[0] == "John" + assert result[1] == 30 + + # Close connection with prepared statements + conn.close() + + # GC should handle cleanup without crash + del cursor + gc.collect() + + assert True + + +def test_verify_sqlhandle_free_method_exists(): + """ + Verify that the free method exists on SqlHandle. + The segfault fix uses markImplicitlyFreed internally in C++ (not exposed to Python). + """ + from mssql_python import ddbc_bindings + + # Verify free method exists + assert hasattr(ddbc_bindings.SqlHandle, "free"), "SqlHandle should have free method" + + +def test_connection_invalidation_with_fetchall(conn_str): + """ + Test connection invalidation when cursors have fetched all results. + This ensures all statement handle states are properly cleaned up. + """ + conn = connect(conn_str) + + cursor = conn.cursor() + cursor.execute("SELECT number FROM (VALUES (1), (2), (3), (4), (5)) AS numbers(number)") + results = cursor.fetchall() + assert len(results) == 5 + + # Close connection after fetchall + conn.close() + + # GC cleanup should work without issues + del cursor + gc.collect() + + assert True + + +def test_nested_connection_cursor_cleanup(conn_str): + """ + Test nested connection/cursor creation and cleanup pattern. + This mimics complex application patterns where connections + and cursors are created in nested scopes. + """ + + def inner_function(connection): + cursor = connection.cursor() + cursor.execute("SELECT 'inner' AS scope") + return cursor.fetchone() + + def outer_function(conn_str): + conn = connect(conn_str) + result = inner_function(conn) + conn.close() + return result + + # Run multiple times to ensure no accumulated state issues + for _ in range(5): + result = outer_function(conn_str) + assert result[0] == "inner" + gc.collect() + + # Final cleanup + gc.collect() + assert True