From 4cc16353774eb058b6bff4ac90fe4d4ffd13057d Mon Sep 17 00:00:00 2001 From: parmaster Date: Fri, 11 Oct 2024 17:33:05 +0300 Subject: [PATCH] Checking the number of open connections from local_net_address only --- tds_test.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tds_test.go b/tds_test.go index 386dab5d..497f7a12 100644 --- a/tds_test.go +++ b/tds_test.go @@ -700,7 +700,7 @@ func TestLeakedConnections(t *testing.T) { badParams.Database = "unknown_db" // Connecting with good credentials should not fail - goodConn, err := sql.Open("mssql", goodParams.URL().String()) + goodConn, err := sql.Open("sqlserver", goodParams.URL().String()) if err != nil { t.Fatal("Open connection failed:", err.Error()) } @@ -708,16 +708,23 @@ func TestLeakedConnections(t *testing.T) { if err != nil { t.Fatal("Ping with good credentials should not fail, but got error:", err.Error()) } - // Remember the number of open connections, excluding the current one + + var localNetAddr string + err = goodConn.QueryRow("SELECT local_net_address FROM sys.dm_exec_connections WHERE session_id=@@SPID").Scan(&localNetAddr) + if err != nil { + t.Fatal("cannot scan local_net_address value", err) + } + + // Remember the number of open connections from local_net_address, excluding the current one var openConnections int - err = goodConn.QueryRow("SELECT COUNT(*) AS openConnections FROM sys.dm_exec_connections WHERE session_id!=@@SPID").Scan(&openConnections) + err = goodConn.QueryRow("SELECT COUNT(*) AS openConnections FROM sys.dm_exec_connections WHERE session_id!=@@SPID AND local_net_address=@p1", localNetAddr).Scan(&openConnections) if err != nil { t.Fatal("cannot scan value", err) } // Open 10 connections to the unknown database, all should be closed immediately for i := 0; i < 10; i++ { - conn, err := sql.Open("mssql", badParams.URL().String()) + conn, err := sql.Open("sqlserver", badParams.URL().String()) if err != nil { // should not fail here t.Fatal("sql.Open failed:", err.Error()) @@ -731,7 +738,7 @@ func TestLeakedConnections(t *testing.T) { // Check if the number of open connections is the same as before var newOpenConnections int - err = goodConn.QueryRow("SELECT COUNT(*) AS openConnections FROM sys.dm_exec_connections WHERE session_id!=@@SPID").Scan(&newOpenConnections) + err = goodConn.QueryRow("SELECT COUNT(*) AS openConnections FROM sys.dm_exec_connections WHERE session_id!=@@SPID AND local_net_address=@p1", localNetAddr).Scan(&newOpenConnections) if err != nil { t.Fatal("cannot scan value", err) }