diff --git a/README.md b/README.md
index 169541eff05..7feed7d72cd 100644
--- a/README.md
+++ b/README.md
@@ -50,20 +50,20 @@ If you are using Maven without the BOM, add this to your dependencies:
 If you are using Gradle 5.x or later, add this to your dependencies:
 
 ```Groovy
-implementation platform('com.google.cloud:libraries-bom:26.24.0')
+implementation platform('com.google.cloud:libraries-bom:26.26.0')
 
 implementation 'com.google.cloud:google-cloud-spanner'
 ```
 If you are using Gradle without BOM, add this to your dependencies:
 
 ```Groovy
-implementation 'com.google.cloud:google-cloud-spanner:6.49.0'
+implementation 'com.google.cloud:google-cloud-spanner:6.52.1'
 ```
 
 If you are using SBT, add this to your dependencies:
 
 ```Scala
-libraryDependencies += "com.google.cloud" % "google-cloud-spanner" % "6.49.0"
+libraryDependencies += "com.google.cloud" % "google-cloud-spanner" % "6.52.1"
 ```
 <!-- {x-version-update-end} -->
 
@@ -432,7 +432,7 @@ Java is a registered trademark of Oracle and/or its affiliates.
 [kokoro-badge-link-5]: http://storage.googleapis.com/cloud-devrel-public/java/badges/java-spanner/java11.html
 [stability-image]: https://img.shields.io/badge/stability-stable-green
 [maven-version-image]: https://img.shields.io/maven-central/v/com.google.cloud/google-cloud-spanner.svg
-[maven-version-link]: https://central.sonatype.com/artifact/com.google.cloud/google-cloud-spanner/6.49.0
+[maven-version-link]: https://central.sonatype.com/artifact/com.google.cloud/google-cloud-spanner/6.52.1
 [authentication]: https://github.com/googleapis/google-cloud-java#authentication
 [auth-scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
 [predefined-iam-roles]: https://cloud.google.com/iam/docs/understanding-roles#predefined_roles
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java
index d5b1abe0b5b..a0b25cb64c0 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java
@@ -38,6 +38,7 @@
 import com.google.cloud.spanner.SessionImpl.SessionTransaction;
 import com.google.cloud.spanner.spi.v1.SpannerRpc;
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
 import com.google.common.util.concurrent.MoreExecutors;
 import com.google.protobuf.ByteString;
 import com.google.spanner.v1.BeginTransactionRequest;
@@ -72,6 +73,7 @@ abstract static class Builder<B extends Builder<?, T>, T extends AbstractReadCon
     private int defaultPrefetchChunks = SpannerOptions.Builder.DEFAULT_PREFETCH_CHUNKS;
     private QueryOptions defaultQueryOptions = SpannerOptions.Builder.DEFAULT_QUERY_OPTIONS;
     private ExecutorProvider executorProvider;
+    private Clock clock = new Clock();
 
     Builder() {}
 
@@ -110,6 +112,11 @@ B setExecutorProvider(ExecutorProvider executorProvider) {
       return self();
     }
 
+    B setClock(Clock clock) {
+      this.clock = Preconditions.checkNotNull(clock);
+      return self();
+    }
+
     abstract T build();
   }
 
@@ -392,6 +399,8 @@ void initTransaction() {
   private final int defaultPrefetchChunks;
   private final QueryOptions defaultQueryOptions;
 
+  private final Clock clock;
+
   @GuardedBy("lock")
   private boolean isValid = true;
 
@@ -416,6 +425,7 @@ void initTransaction() {
     this.defaultQueryOptions = builder.defaultQueryOptions;
     this.span = builder.span;
     this.executorProvider = builder.executorProvider;
+    this.clock = builder.clock;
   }
 
   @Override
@@ -689,6 +699,7 @@ CloseableIterator<PartialResultSet> startStream(@Nullable ByteString resumeToken
             SpannerRpc.StreamingCall call =
                 rpc.executeQuery(
                     request.build(), stream.consumer(), session.getOptions(), isRouteToLeader());
+            session.markUsed(clock.instant());
             call.request(prefetchChunks);
             stream.setCall(call, request.getTransaction().hasBegin());
             return stream;
@@ -826,6 +837,7 @@ CloseableIterator<PartialResultSet> startStream(@Nullable ByteString resumeToken
             SpannerRpc.StreamingCall call =
                 rpc.read(
                     builder.build(), stream.consumer(), session.getOptions(), isRouteToLeader());
+            session.markUsed(clock.instant());
             call.request(prefetchChunks);
             stream.setCall(call, /* withBeginTransaction = */ builder.getTransaction().hasBegin());
             return stream;
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Clock.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Clock.java
new file mode 100644
index 00000000000..bb3507eeb48
--- /dev/null
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Clock.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *       http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.spanner;
+
+import org.threeten.bp.Instant;
+
+/**
+ * Wrapper around current time so that we can fake it in tests. TODO(user): Replace with Java 8
+ * Clock.
+ */
+class Clock {
+  Instant instant() {
+    return Instant.now();
+  }
+}
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java
index 002a00134f6..0e763dbc93d 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java
@@ -53,6 +53,7 @@
 import java.util.Map;
 import java.util.concurrent.ExecutionException;
 import javax.annotation.Nullable;
+import org.threeten.bp.Instant;
 
 /**
  * Implementation of {@link Session}. Sessions are managed internally by the client library, and
@@ -98,12 +99,14 @@ interface SessionTransaction {
   ByteString readyTransactionId;
   private final Map<SpannerRpc.Option, ?> options;
   private Span currentSpan;
+  private volatile Instant lastUseTime;
 
   SessionImpl(SpannerImpl spanner, String name, Map<SpannerRpc.Option, ?> options) {
     this.spanner = spanner;
     this.options = options;
     this.name = checkNotNull(name);
     this.databaseId = SessionId.of(name).getDatabaseId();
+    this.lastUseTime = Instant.now();
   }
 
   @Override
@@ -123,6 +126,14 @@ Span getCurrentSpan() {
     return currentSpan;
   }
 
+  Instant getLastUseTime() {
+    return lastUseTime;
+  }
+
+  void markUsed(Instant instant) {
+    lastUseTime = instant;
+  }
+
   @Override
   public long executePartitionedUpdate(Statement stmt, UpdateOption... options) {
     setActive(null);
@@ -385,6 +396,9 @@ ApiFuture<ByteString> beginTransactionAsync(Options transactionOptions, boolean
   }
 
   TransactionContextImpl newTransaction(Options options) {
+    // A clock instance is passed in {@code SessionPoolOptions} in order to allow mocking via tests.
+    final Clock poolMaintainerClock =
+        spanner.getOptions().getSessionPoolOptions().getPoolMaintainerClock();
     return TransactionContextImpl.newBuilder()
         .setSession(this)
         .setOptions(options)
@@ -396,6 +410,7 @@ TransactionContextImpl newTransaction(Options options) {
         .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks())
         .setSpan(currentSpan)
         .setExecutorProvider(spanner.getAsyncExecutorProvider())
+        .setClock(poolMaintainerClock == null ? new Clock() : poolMaintainerClock)
         .build();
   }
 
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java
index ca61da80583..54a0a292cd8 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java
@@ -144,16 +144,6 @@ void maybeWaitOnMinSessions() {
     }
   }
 
-  /**
-   * Wrapper around current time so that we can fake it in tests. TODO(user): Replace with Java 8
-   * Clock.
-   */
-  static class Clock {
-    Instant instant() {
-      return Instant.now();
-    }
-  }
-
   private abstract static class CachedResultSetSupplier implements Supplier<ResultSet> {
     private ResultSet cached;
 
@@ -1370,7 +1360,6 @@ PooledSession get(final boolean eligibleForLongRunning) {
 
   class PooledSession implements Session {
     @VisibleForTesting SessionImpl delegate;
-    private volatile Instant lastUseTime;
     private volatile SpannerException lastException;
     private volatile boolean allowReplacing = true;
 
@@ -1409,7 +1398,9 @@ class PooledSession implements Session {
     private PooledSession(SessionImpl delegate) {
       this.delegate = delegate;
       this.state = SessionState.AVAILABLE;
-      this.lastUseTime = clock.instant();
+
+      // initialise the lastUseTime field for each session.
+      this.markUsed();
     }
 
     int getChannel() {
@@ -1631,7 +1622,7 @@ private void markClosing() {
     }
 
     void markUsed() {
-      lastUseTime = clock.instant();
+      delegate.markUsed(clock.instant());
     }
 
     @Override
@@ -1827,7 +1818,7 @@ private void removeIdleSessions(Instant currTime) {
         Iterator<PooledSession> iterator = sessions.descendingIterator();
         while (iterator.hasNext()) {
           PooledSession session = iterator.next();
-          if (session.lastUseTime.isBefore(minLastUseTime)) {
+          if (session.delegate.getLastUseTime().isBefore(minLastUseTime)) {
             if (session.state != SessionState.CLOSING) {
               boolean isRemoved = removeFromPool(session);
               if (isRemoved) {
@@ -1929,7 +1920,8 @@ private void removeLongRunningSessions(
             // collection is populated only when the get() method in {@code PooledSessionFuture} is
             // called.
             final PooledSession session = sessionFuture.get();
-            final Duration durationFromLastUse = Duration.between(session.lastUseTime, currentTime);
+            final Duration durationFromLastUse =
+                Duration.between(session.delegate.getLastUseTime(), currentTime);
             if (!session.eligibleForLongRunning
                 && durationFromLastUse.compareTo(
                         inactiveTransactionRemovalOptions.getIdleTimeThreshold())
@@ -2327,7 +2319,7 @@ private PooledSession findSessionToKeepAlive(
         && (numChecked + numAlreadyChecked)
             < (options.getMinSessions() + options.getMaxIdleSessions() - numSessionsInUse)) {
       PooledSession session = iterator.next();
-      if (session.lastUseTime.isBefore(keepAliveThreshold)) {
+      if (session.delegate.getLastUseTime().isBefore(keepAliveThreshold)) {
         iterator.remove();
         return session;
       }
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolOptions.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolOptions.java
index e29767abab3..cbea1495368 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolOptions.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolOptions.java
@@ -16,7 +16,6 @@
 
 package com.google.cloud.spanner;
 
-import com.google.cloud.spanner.SessionPool.Clock;
 import com.google.cloud.spanner.SessionPool.Position;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java
index ef937e993bd..21c74a400f0 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java
@@ -85,12 +85,19 @@ class TransactionRunnerImpl implements SessionTransaction, TransactionRunner {
   @VisibleForTesting
   static class TransactionContextImpl extends AbstractReadContext implements TransactionContext {
     static class Builder extends AbstractReadContext.Builder<Builder, TransactionContextImpl> {
+
+      private Clock clock = new Clock();
       private ByteString transactionId;
       private Options options;
       private boolean trackTransactionStarter;
 
       private Builder() {}
 
+      Builder setClock(Clock clock) {
+        this.clock = Preconditions.checkNotNull(clock);
+        return self();
+      }
+
       Builder setTransactionId(ByteString transactionId) {
         this.transactionId = transactionId;
         return self();
@@ -189,6 +196,7 @@ public void removeListener(Runnable listener) {
     volatile ByteString transactionId;
 
     private CommitResponse commitResponse;
+    private final Clock clock;
 
     private TransactionContextImpl(Builder builder) {
       super(builder);
@@ -196,6 +204,7 @@ private TransactionContextImpl(Builder builder) {
       this.trackTransactionStarter = builder.trackTransactionStarter;
       this.options = builder.options;
       this.finishedAsyncOperations.set(null);
+      this.clock = builder.clock;
     }
 
     @Override
@@ -389,6 +398,7 @@ public void run() {
               tracer.spanBuilderWithExplicitParent(SpannerImpl.COMMIT, span).startSpan();
           final ApiFuture<com.google.spanner.v1.CommitResponse> commitFuture =
               rpc.commitAsync(commitRequest, session.getOptions());
+          session.markUsed(clock.instant());
           commitFuture.addListener(
               tracer.withSpan(
                   opSpan,
@@ -463,12 +473,15 @@ ApiFuture<Empty> rollbackAsync() {
       // is still in flight. That transaction will then automatically be terminated by the server.
       if (transactionId != null) {
         span.addAnnotation("Starting Rollback");
-        return rpc.rollbackAsync(
-            RollbackRequest.newBuilder()
-                .setSession(session.getName())
-                .setTransactionId(transactionId)
-                .build(),
-            session.getOptions());
+        ApiFuture<Empty> apiFuture =
+            rpc.rollbackAsync(
+                RollbackRequest.newBuilder()
+                    .setSession(session.getName())
+                    .setTransactionId(transactionId)
+                    .build(),
+                session.getOptions());
+        session.markUsed(clock.instant());
+        return apiFuture;
       } else {
         return ApiFutures.immediateFuture(Empty.getDefaultInstance());
       }
@@ -723,6 +736,7 @@ private ResultSet internalExecuteUpdate(
       try {
         com.google.spanner.v1.ResultSet resultSet =
             rpc.executeQuery(builder.build(), session.getOptions(), isRouteToLeader());
+        session.markUsed(clock.instant());
         if (resultSet.getMetadata().hasTransaction()) {
           onTransactionMetadata(
               resultSet.getMetadata().getTransaction(), builder.getTransaction().hasBegin());
@@ -753,6 +767,7 @@ public ApiFuture<Long> executeUpdateAsync(Statement statement, UpdateOption... o
         // commit.
         increaseAsyncOperations();
         resultSet = rpc.executeQueryAsync(builder.build(), session.getOptions(), isRouteToLeader());
+        session.markUsed(clock.instant());
       } catch (Throwable t) {
         decreaseAsyncOperations();
         throw t;
@@ -824,6 +839,7 @@ public long[] batchUpdate(Iterable<Statement> statements, UpdateOption... option
       try {
         com.google.spanner.v1.ExecuteBatchDmlResponse response =
             rpc.executeBatchDml(builder.build(), session.getOptions());
+        session.markUsed(clock.instant());
         long[] results = new long[response.getResultSetsCount()];
         for (int i = 0; i < response.getResultSetsCount(); ++i) {
           results[i] = response.getResultSets(i).getStats().getRowCountExact();
@@ -863,6 +879,7 @@ public ApiFuture<long[]> batchUpdateAsync(
         // commit.
         increaseAsyncOperations();
         response = rpc.executeBatchDmlAsync(builder.build(), session.getOptions());
+        session.markUsed(clock.instant());
       } catch (Throwable t) {
         decreaseAsyncOperations();
         throw t;
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BaseSessionPoolTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BaseSessionPoolTest.java
index 7f8cf5cc1b0..d36d32bda2a 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BaseSessionPoolTest.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BaseSessionPoolTest.java
@@ -23,9 +23,10 @@
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.when;
 
+import com.google.api.core.ApiFuture;
 import com.google.api.core.ApiFutures;
 import com.google.cloud.grpc.GrpcTransportOptions.ExecutorFactory;
-import com.google.cloud.spanner.SessionPool.Clock;
+import com.google.cloud.spanner.Options.TransactionOption;
 import com.google.cloud.spanner.spi.v1.SpannerRpc.Option;
 import com.google.protobuf.Empty;
 import java.util.HashMap;
@@ -35,7 +36,6 @@
 import java.util.concurrent.ScheduledThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
-import org.threeten.bp.Instant;
 
 abstract class BaseSessionPoolTest {
   ScheduledExecutorService mockExecutor;
@@ -84,19 +84,47 @@ SessionImpl mockSession() {
     return session;
   }
 
+  SessionImpl buildMockSession(ReadContext context) {
+    SpannerImpl spanner = mock(SpannerImpl.class);
+    Map options = new HashMap<>();
+    options.put(Option.CHANNEL_HINT, channelHint.getAndIncrement());
+    final SessionImpl session =
+        new SessionImpl(
+            spanner,
+            "projects/dummy/instances/dummy/databases/dummy/sessions/session" + sessionIndex,
+            options) {
+          @Override
+          public ReadContext singleUse(TimestampBound bound) {
+            // The below stubs are added so that we can mock keep-alive.
+            return context;
+          }
+
+          @Override
+          public ApiFuture<Empty> asyncClose() {
+            return ApiFutures.immediateFuture(Empty.getDefaultInstance());
+          }
+
+          @Override
+          public CommitResponse writeAtLeastOnceWithOptions(
+              Iterable<Mutation> mutations, TransactionOption... transactionOptions)
+              throws SpannerException {
+            return new CommitResponse(com.google.spanner.v1.CommitResponse.getDefaultInstance());
+          }
+
+          @Override
+          public CommitResponse writeWithOptions(
+              Iterable<Mutation> mutations, TransactionOption... options) throws SpannerException {
+            return new CommitResponse(com.google.spanner.v1.CommitResponse.getDefaultInstance());
+          }
+        };
+    sessionIndex++;
+    return session;
+  }
+
   void runMaintenanceLoop(FakeClock clock, SessionPool pool, long numCycles) {
     for (int i = 0; i < numCycles; i++) {
       pool.poolMaintainer.maintainPool();
       clock.currentTimeMillis += pool.poolMaintainer.loopFrequency;
     }
   }
-
-  static class FakeClock extends Clock {
-    volatile long currentTimeMillis;
-
-    @Override
-    public Instant instant() {
-      return Instant.ofEpochMilli(currentTimeMillis);
-    }
-  }
 }
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java
index 029ea471cdb..aea8a4dcb64 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java
@@ -53,7 +53,6 @@
 import com.google.cloud.spanner.Options.RpcPriority;
 import com.google.cloud.spanner.Options.TransactionOption;
 import com.google.cloud.spanner.ReadContext.QueryAnalyzeMode;
-import com.google.cloud.spanner.SessionPool.Clock;
 import com.google.cloud.spanner.SessionPool.PooledSessionFuture;
 import com.google.cloud.spanner.SessionPoolOptions.ActionOnInactiveTransaction;
 import com.google.cloud.spanner.SessionPoolOptions.InactiveTransactionRemovalOptions;
@@ -64,6 +63,7 @@
 import com.google.cloud.spanner.connection.RandomResultSetGenerator;
 import com.google.common.base.Stopwatch;
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
 import com.google.common.io.BaseEncoding;
 import com.google.common.util.concurrent.SettableFuture;
 import com.google.protobuf.AbstractMessage;
@@ -536,6 +536,757 @@ public void testPoolMaintainer_whenPDMLFollowedByInactiveTransaction_removeSessi
     assertTrue(client.pool.getNumberOfSessionsInPool() <= client.pool.totalSessions());
   }
 
+  @Test
+  public void
+      testPoolMaintainer_whenLongRunningReadsUsingTransactionRunner_retainSessionForTransaction()
+          throws Exception {
+    FakeClock poolMaintainerClock = new FakeClock();
+    InactiveTransactionRemovalOptions inactiveTransactionRemovalOptions =
+        InactiveTransactionRemovalOptions.newBuilder()
+            .setIdleTimeThreshold(
+                Duration.ofSeconds(
+                    3L)) // any session not used for more than 3s will be long-running
+            .setActionOnInactiveTransaction(ActionOnInactiveTransaction.CLOSE)
+            .setExecutionFrequency(Duration.ofSeconds(1)) // execute thread every 1s
+            .build();
+    SessionPoolOptions sessionPoolOptions =
+        SessionPoolOptions.newBuilder()
+            .setMinSessions(1)
+            .setMaxSessions(1) // to ensure there is 1 session and pool is 100% utilized
+            .setInactiveTransactionRemovalOptions(inactiveTransactionRemovalOptions)
+            .setLoopFrequency(1000L) // main thread runs every 1s
+            .setPoolMaintainerClock(poolMaintainerClock)
+            .build();
+    spanner =
+        SpannerOptions.newBuilder()
+            .setProjectId(TEST_PROJECT)
+            .setDatabaseRole(TEST_DATABASE_ROLE)
+            .setChannelProvider(channelProvider)
+            .setCredentials(NoCredentials.getInstance())
+            .setSessionPoolOption(sessionPoolOptions)
+            .build()
+            .getService();
+    DatabaseClientImpl client =
+        (DatabaseClientImpl)
+            spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE));
+    Instant initialExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    TransactionRunner runner = client.readWriteTransaction();
+    runner.run(
+        transaction -> {
+          try (ResultSet resultSet =
+              transaction.read(
+                  READ_TABLE_NAME,
+                  KeySet.singleKey(Key.of(1L)),
+                  READ_COLUMN_NAMES,
+                  Options.priority(RpcPriority.HIGH))) {
+            while (resultSet.next()) {}
+          }
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(1050).toMillis();
+
+          try (ResultSet resultSet =
+              transaction.read(
+                  READ_TABLE_NAME,
+                  KeySet.singleKey(Key.of(1L)),
+                  READ_COLUMN_NAMES,
+                  Options.priority(RpcPriority.HIGH))) {
+            while (resultSet.next()) {}
+          }
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(2050).toMillis();
+
+          // force trigger pool maintainer to check for long-running sessions
+          client.pool.poolMaintainer.maintainPool();
+
+          return null;
+        });
+
+    Instant endExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    assertNotEquals(
+        endExecutionTime,
+        initialExecutionTime); // if session clean up task runs then these timings won't match
+    assertEquals(0, client.pool.numLeakedSessionsRemoved());
+  }
+
+  @Test
+  public void
+      testPoolMaintainer_whenLongRunningQueriesUsingTransactionRunner_retainSessionForTransaction()
+          throws Exception {
+    FakeClock poolMaintainerClock = new FakeClock();
+    InactiveTransactionRemovalOptions inactiveTransactionRemovalOptions =
+        InactiveTransactionRemovalOptions.newBuilder()
+            .setIdleTimeThreshold(
+                Duration.ofSeconds(
+                    3L)) // any session not used for more than 3s will be long-running
+            .setActionOnInactiveTransaction(ActionOnInactiveTransaction.CLOSE)
+            .setExecutionFrequency(Duration.ofSeconds(1)) // execute thread every 1s
+            .build();
+    SessionPoolOptions sessionPoolOptions =
+        SessionPoolOptions.newBuilder()
+            .setMinSessions(1)
+            .setMaxSessions(1) // to ensure there is 1 session and pool is 100% utilized
+            .setInactiveTransactionRemovalOptions(inactiveTransactionRemovalOptions)
+            .setLoopFrequency(1000L) // main thread runs every 1s
+            .setPoolMaintainerClock(poolMaintainerClock)
+            .build();
+    spanner =
+        SpannerOptions.newBuilder()
+            .setProjectId(TEST_PROJECT)
+            .setDatabaseRole(TEST_DATABASE_ROLE)
+            .setChannelProvider(channelProvider)
+            .setCredentials(NoCredentials.getInstance())
+            .setSessionPoolOption(sessionPoolOptions)
+            .build()
+            .getService();
+    DatabaseClientImpl client =
+        (DatabaseClientImpl)
+            spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE));
+    Instant initialExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    TransactionRunner runner = client.readWriteTransaction();
+    runner.run(
+        transaction -> {
+          try (ResultSet resultSet = transaction.executeQuery(SELECT1)) {
+            while (resultSet.next()) {}
+          }
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(1050).toMillis();
+
+          try (ResultSet resultSet = transaction.executeQuery(SELECT1)) {
+            while (resultSet.next()) {}
+          }
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(2050).toMillis();
+
+          // force trigger pool maintainer to check for long-running sessions
+          client.pool.poolMaintainer.maintainPool();
+
+          return null;
+        });
+
+    Instant endExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    assertNotEquals(
+        endExecutionTime,
+        initialExecutionTime); // if session clean up task runs then these timings won't match
+    assertEquals(0, client.pool.numLeakedSessionsRemoved());
+  }
+
+  @Test
+  public void
+      testPoolMaintainer_whenLongRunningUpdatesUsingTransactionManager_retainSessionForTransaction()
+          throws Exception {
+    FakeClock poolMaintainerClock = new FakeClock();
+    InactiveTransactionRemovalOptions inactiveTransactionRemovalOptions =
+        InactiveTransactionRemovalOptions.newBuilder()
+            .setIdleTimeThreshold(
+                Duration.ofSeconds(
+                    3L)) // any session not used for more than 3s will be long-running
+            .setActionOnInactiveTransaction(ActionOnInactiveTransaction.CLOSE)
+            .setExecutionFrequency(Duration.ofSeconds(1)) // execute thread every 1s
+            .build();
+    SessionPoolOptions sessionPoolOptions =
+        SessionPoolOptions.newBuilder()
+            .setMinSessions(1)
+            .setMaxSessions(1) // to ensure there is 1 session and pool is 100% utilized
+            .setInactiveTransactionRemovalOptions(inactiveTransactionRemovalOptions)
+            .setLoopFrequency(1000L) // main thread runs every 1s
+            .setPoolMaintainerClock(poolMaintainerClock)
+            .build();
+    spanner =
+        SpannerOptions.newBuilder()
+            .setProjectId(TEST_PROJECT)
+            .setDatabaseRole(TEST_DATABASE_ROLE)
+            .setChannelProvider(channelProvider)
+            .setCredentials(NoCredentials.getInstance())
+            .setSessionPoolOption(sessionPoolOptions)
+            .build()
+            .getService();
+    DatabaseClientImpl client =
+        (DatabaseClientImpl)
+            spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE));
+    Instant initialExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    try (TransactionManager manager = client.transactionManager()) {
+      TransactionContext transaction = manager.begin();
+      while (true) {
+        try {
+          transaction.executeUpdate(UPDATE_STATEMENT);
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(1050).toMillis();
+
+          transaction.executeUpdate(UPDATE_STATEMENT);
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(2050).toMillis();
+
+          // force trigger pool maintainer to check for long-running sessions
+          client.pool.poolMaintainer.maintainPool();
+
+          manager.commit();
+          assertNotNull(manager.getCommitTimestamp());
+          break;
+        } catch (AbortedException e) {
+          transaction = manager.resetForRetry();
+        }
+      }
+    }
+    Instant endExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    assertNotEquals(
+        endExecutionTime,
+        initialExecutionTime); // if session clean up task runs then these timings won't match
+    assertEquals(0, client.pool.numLeakedSessionsRemoved());
+    assertTrue(client.pool.getNumberOfSessionsInPool() <= client.pool.totalSessions());
+  }
+
+  @Test
+  public void
+      testPoolMaintainer_whenLongRunningReadsUsingTransactionManager_retainSessionForTransaction() {
+    FakeClock poolMaintainerClock = new FakeClock();
+    InactiveTransactionRemovalOptions inactiveTransactionRemovalOptions =
+        InactiveTransactionRemovalOptions.newBuilder()
+            .setIdleTimeThreshold(
+                Duration.ofSeconds(
+                    3L)) // any session not used for more than 3s will be long-running
+            .setActionOnInactiveTransaction(ActionOnInactiveTransaction.CLOSE)
+            .setExecutionFrequency(Duration.ofSeconds(1)) // execute thread every 1s
+            .build();
+    SessionPoolOptions sessionPoolOptions =
+        SessionPoolOptions.newBuilder()
+            .setMinSessions(1)
+            .setMaxSessions(1) // to ensure there is 1 session and pool is 100% utilized
+            .setInactiveTransactionRemovalOptions(inactiveTransactionRemovalOptions)
+            .setLoopFrequency(1000L) // main thread runs every 1s
+            .setPoolMaintainerClock(poolMaintainerClock)
+            .build();
+    spanner =
+        SpannerOptions.newBuilder()
+            .setProjectId(TEST_PROJECT)
+            .setDatabaseRole(TEST_DATABASE_ROLE)
+            .setChannelProvider(channelProvider)
+            .setCredentials(NoCredentials.getInstance())
+            .setSessionPoolOption(sessionPoolOptions)
+            .build()
+            .getService();
+    DatabaseClientImpl client =
+        (DatabaseClientImpl)
+            spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE));
+    Instant initialExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    try (TransactionManager manager = client.transactionManager()) {
+      TransactionContext transaction = manager.begin();
+      while (true) {
+        try {
+          try (ResultSet resultSet =
+              transaction.read(
+                  READ_TABLE_NAME,
+                  KeySet.singleKey(Key.of(1L)),
+                  READ_COLUMN_NAMES,
+                  Options.priority(RpcPriority.HIGH))) {
+
+            while (resultSet.next()) {}
+          }
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(1050).toMillis();
+
+          try (ResultSet resultSet =
+              transaction.read(
+                  READ_TABLE_NAME,
+                  KeySet.singleKey(Key.of(1L)),
+                  READ_COLUMN_NAMES,
+                  Options.priority(RpcPriority.HIGH))) {
+
+            while (resultSet.next()) {}
+          }
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(2050).toMillis();
+
+          // force trigger pool maintainer to check for long-running sessions
+          client.pool.poolMaintainer.maintainPool();
+
+          manager.commit();
+          assertNotNull(manager.getCommitTimestamp());
+          break;
+        } catch (AbortedException e) {
+          transaction = manager.resetForRetry();
+        }
+      }
+    }
+    Instant endExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    assertNotEquals(
+        endExecutionTime,
+        initialExecutionTime); // if session clean up task runs then these timings won't match
+    assertEquals(0, client.pool.numLeakedSessionsRemoved());
+    assertTrue(client.pool.getNumberOfSessionsInPool() <= client.pool.totalSessions());
+  }
+
+  @Test
+  public void
+      testPoolMaintainer_whenLongRunningReadRowUsingTransactionManager_retainSessionForTransaction() {
+    FakeClock poolMaintainerClock = new FakeClock();
+    InactiveTransactionRemovalOptions inactiveTransactionRemovalOptions =
+        InactiveTransactionRemovalOptions.newBuilder()
+            .setIdleTimeThreshold(
+                Duration.ofSeconds(
+                    3L)) // any session not used for more than 3s will be long-running
+            .setActionOnInactiveTransaction(ActionOnInactiveTransaction.CLOSE)
+            .setExecutionFrequency(Duration.ofSeconds(1)) // execute thread every 1s
+            .build();
+    SessionPoolOptions sessionPoolOptions =
+        SessionPoolOptions.newBuilder()
+            .setMinSessions(1)
+            .setMaxSessions(1) // to ensure there is 1 session and pool is 100% utilized
+            .setInactiveTransactionRemovalOptions(inactiveTransactionRemovalOptions)
+            .setLoopFrequency(1000L) // main thread runs every 1s
+            .setPoolMaintainerClock(poolMaintainerClock)
+            .build();
+    spanner =
+        SpannerOptions.newBuilder()
+            .setProjectId(TEST_PROJECT)
+            .setDatabaseRole(TEST_DATABASE_ROLE)
+            .setChannelProvider(channelProvider)
+            .setCredentials(NoCredentials.getInstance())
+            .setSessionPoolOption(sessionPoolOptions)
+            .build()
+            .getService();
+    DatabaseClientImpl client =
+        (DatabaseClientImpl)
+            spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE));
+    Instant initialExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    try (TransactionManager manager = client.transactionManager()) {
+      TransactionContext transaction = manager.begin();
+      while (true) {
+        try {
+          transaction.readRow(READ_TABLE_NAME, Key.of(1L), READ_COLUMN_NAMES);
+
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(1050).toMillis();
+
+          transaction.readRow(READ_TABLE_NAME, Key.of(1L), READ_COLUMN_NAMES);
+
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(2050).toMillis();
+
+          // force trigger pool maintainer to check for long-running sessions
+          client.pool.poolMaintainer.maintainPool();
+
+          manager.commit();
+          assertNotNull(manager.getCommitTimestamp());
+          break;
+        } catch (AbortedException e) {
+          transaction = manager.resetForRetry();
+        }
+      }
+    }
+    Instant endExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    assertNotEquals(
+        endExecutionTime,
+        initialExecutionTime); // if session clean up task runs then these timings won't match
+    assertEquals(0, client.pool.numLeakedSessionsRemoved());
+    assertTrue(client.pool.getNumberOfSessionsInPool() <= client.pool.totalSessions());
+  }
+
+  @Test
+  public void
+      testPoolMaintainer_whenLongRunningAnalyzeUpdateStatementUsingTransactionManager_retainSessionForTransaction() {
+    FakeClock poolMaintainerClock = new FakeClock();
+    InactiveTransactionRemovalOptions inactiveTransactionRemovalOptions =
+        InactiveTransactionRemovalOptions.newBuilder()
+            .setIdleTimeThreshold(
+                Duration.ofSeconds(
+                    3L)) // any session not used for more than 3s will be long-running
+            .setActionOnInactiveTransaction(ActionOnInactiveTransaction.CLOSE)
+            .setExecutionFrequency(Duration.ofSeconds(1)) // execute thread every 1s
+            .build();
+    SessionPoolOptions sessionPoolOptions =
+        SessionPoolOptions.newBuilder()
+            .setMinSessions(1)
+            .setMaxSessions(1) // to ensure there is 1 session and pool is 100% utilized
+            .setInactiveTransactionRemovalOptions(inactiveTransactionRemovalOptions)
+            .setLoopFrequency(1000L) // main thread runs every 1s
+            .setPoolMaintainerClock(poolMaintainerClock)
+            .build();
+    spanner =
+        SpannerOptions.newBuilder()
+            .setProjectId(TEST_PROJECT)
+            .setDatabaseRole(TEST_DATABASE_ROLE)
+            .setChannelProvider(channelProvider)
+            .setCredentials(NoCredentials.getInstance())
+            .setSessionPoolOption(sessionPoolOptions)
+            .build()
+            .getService();
+    DatabaseClientImpl client =
+        (DatabaseClientImpl)
+            spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE));
+    Instant initialExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    try (TransactionManager manager = client.transactionManager()) {
+      TransactionContext transaction = manager.begin();
+      while (true) {
+        try {
+          try (ResultSet resultSet =
+              transaction.analyzeUpdateStatement(UPDATE_STATEMENT, QueryAnalyzeMode.PROFILE); ) {
+            while (resultSet.next()) {}
+          }
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(1050).toMillis();
+
+          try (ResultSet resultSet =
+              transaction.analyzeUpdateStatement(UPDATE_STATEMENT, QueryAnalyzeMode.PROFILE); ) {
+            while (resultSet.next()) {}
+          }
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(2050).toMillis();
+
+          // force trigger pool maintainer to check for long-running sessions
+          client.pool.poolMaintainer.maintainPool();
+
+          manager.commit();
+          assertNotNull(manager.getCommitTimestamp());
+          break;
+        } catch (AbortedException e) {
+          transaction = manager.resetForRetry();
+        }
+      }
+    }
+    Instant endExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    assertNotEquals(
+        endExecutionTime,
+        initialExecutionTime); // if session clean up task runs then these timings won't match
+    assertEquals(0, client.pool.numLeakedSessionsRemoved());
+    assertTrue(client.pool.getNumberOfSessionsInPool() <= client.pool.totalSessions());
+  }
+
+  @Test
+  public void
+      testPoolMaintainer_whenLongRunningBatchUpdatesUsingTransactionManager_retainSessionForTransaction() {
+    FakeClock poolMaintainerClock = new FakeClock();
+    InactiveTransactionRemovalOptions inactiveTransactionRemovalOptions =
+        InactiveTransactionRemovalOptions.newBuilder()
+            .setIdleTimeThreshold(
+                Duration.ofSeconds(
+                    3L)) // any session not used for more than 3s will be long-running
+            .setActionOnInactiveTransaction(ActionOnInactiveTransaction.CLOSE)
+            .setExecutionFrequency(Duration.ofSeconds(1)) // execute thread every 1s
+            .build();
+    SessionPoolOptions sessionPoolOptions =
+        SessionPoolOptions.newBuilder()
+            .setMinSessions(1)
+            .setMaxSessions(1) // to ensure there is 1 session and pool is 100% utilized
+            .setInactiveTransactionRemovalOptions(inactiveTransactionRemovalOptions)
+            .setLoopFrequency(1000L) // main thread runs every 1s
+            .setPoolMaintainerClock(poolMaintainerClock)
+            .build();
+    spanner =
+        SpannerOptions.newBuilder()
+            .setProjectId(TEST_PROJECT)
+            .setDatabaseRole(TEST_DATABASE_ROLE)
+            .setChannelProvider(channelProvider)
+            .setCredentials(NoCredentials.getInstance())
+            .setSessionPoolOption(sessionPoolOptions)
+            .build()
+            .getService();
+    DatabaseClientImpl client =
+        (DatabaseClientImpl)
+            spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE));
+    Instant initialExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    try (TransactionManager manager = client.transactionManager()) {
+      TransactionContext transaction = manager.begin();
+      while (true) {
+        try {
+          transaction.batchUpdate(Lists.newArrayList(UPDATE_STATEMENT));
+
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(1050).toMillis();
+
+          transaction.batchUpdate(Lists.newArrayList(UPDATE_STATEMENT));
+
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(2050).toMillis();
+
+          // force trigger pool maintainer to check for long-running sessions
+          client.pool.poolMaintainer.maintainPool();
+
+          manager.commit();
+          assertNotNull(manager.getCommitTimestamp());
+          break;
+        } catch (AbortedException e) {
+          transaction = manager.resetForRetry();
+        }
+      }
+    }
+    Instant endExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    assertNotEquals(
+        endExecutionTime,
+        initialExecutionTime); // if session clean up task runs then these timings won't match
+    assertEquals(0, client.pool.numLeakedSessionsRemoved());
+    assertTrue(client.pool.getNumberOfSessionsInPool() <= client.pool.totalSessions());
+  }
+
+  @Test
+  public void
+      testPoolMaintainer_whenLongRunningBatchUpdatesAsyncUsingTransactionManager_retainSessionForTransaction() {
+    FakeClock poolMaintainerClock = new FakeClock();
+    InactiveTransactionRemovalOptions inactiveTransactionRemovalOptions =
+        InactiveTransactionRemovalOptions.newBuilder()
+            .setIdleTimeThreshold(
+                Duration.ofSeconds(
+                    3L)) // any session not used for more than 3s will be long-running
+            .setActionOnInactiveTransaction(ActionOnInactiveTransaction.CLOSE)
+            .setExecutionFrequency(Duration.ofSeconds(1)) // execute thread every 1s
+            .build();
+    SessionPoolOptions sessionPoolOptions =
+        SessionPoolOptions.newBuilder()
+            .setMinSessions(1)
+            .setMaxSessions(1) // to ensure there is 1 session and pool is 100% utilized
+            .setInactiveTransactionRemovalOptions(inactiveTransactionRemovalOptions)
+            .setLoopFrequency(1000L) // main thread runs every 1s
+            .setPoolMaintainerClock(poolMaintainerClock)
+            .build();
+    spanner =
+        SpannerOptions.newBuilder()
+            .setProjectId(TEST_PROJECT)
+            .setDatabaseRole(TEST_DATABASE_ROLE)
+            .setChannelProvider(channelProvider)
+            .setCredentials(NoCredentials.getInstance())
+            .setSessionPoolOption(sessionPoolOptions)
+            .build()
+            .getService();
+    DatabaseClientImpl client =
+        (DatabaseClientImpl)
+            spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE));
+    Instant initialExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    try (TransactionManager manager = client.transactionManager()) {
+      TransactionContext transaction = manager.begin();
+      while (true) {
+        try {
+          transaction.batchUpdateAsync(Lists.newArrayList(UPDATE_STATEMENT));
+
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(1050).toMillis();
+
+          transaction.batchUpdateAsync(Lists.newArrayList(UPDATE_STATEMENT));
+
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(2050).toMillis();
+
+          // force trigger pool maintainer to check for long-running sessions
+          client.pool.poolMaintainer.maintainPool();
+
+          manager.commit();
+          assertNotNull(manager.getCommitTimestamp());
+          break;
+        } catch (AbortedException e) {
+          transaction = manager.resetForRetry();
+        }
+      }
+    }
+    Instant endExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    assertNotEquals(
+        endExecutionTime,
+        initialExecutionTime); // if session clean up task runs then these timings won't match
+    assertEquals(0, client.pool.numLeakedSessionsRemoved());
+    assertTrue(client.pool.getNumberOfSessionsInPool() <= client.pool.totalSessions());
+  }
+
+  @Test
+  public void
+      testPoolMaintainer_whenLongRunningExecuteQueryUsingTransactionManager_retainSessionForTransaction() {
+    FakeClock poolMaintainerClock = new FakeClock();
+    InactiveTransactionRemovalOptions inactiveTransactionRemovalOptions =
+        InactiveTransactionRemovalOptions.newBuilder()
+            .setIdleTimeThreshold(
+                Duration.ofSeconds(
+                    3L)) // any session not used for more than 3s will be long-running
+            .setActionOnInactiveTransaction(ActionOnInactiveTransaction.CLOSE)
+            .setExecutionFrequency(Duration.ofSeconds(1)) // execute thread every 1s
+            .build();
+    SessionPoolOptions sessionPoolOptions =
+        SessionPoolOptions.newBuilder()
+            .setMinSessions(1)
+            .setMaxSessions(1) // to ensure there is 1 session and pool is 100% utilized
+            .setInactiveTransactionRemovalOptions(inactiveTransactionRemovalOptions)
+            .setLoopFrequency(1000L) // main thread runs every 1s
+            .setPoolMaintainerClock(poolMaintainerClock)
+            .build();
+    spanner =
+        SpannerOptions.newBuilder()
+            .setProjectId(TEST_PROJECT)
+            .setDatabaseRole(TEST_DATABASE_ROLE)
+            .setChannelProvider(channelProvider)
+            .setCredentials(NoCredentials.getInstance())
+            .setSessionPoolOption(sessionPoolOptions)
+            .build()
+            .getService();
+    DatabaseClientImpl client =
+        (DatabaseClientImpl)
+            spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE));
+    Instant initialExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    try (TransactionManager manager = client.transactionManager()) {
+      TransactionContext transaction = manager.begin();
+      while (true) {
+        try {
+          try (ResultSet resultSet = transaction.executeQuery(SELECT1)) {
+            while (resultSet.next()) {}
+          }
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(1050).toMillis();
+
+          try (ResultSet resultSet = transaction.executeQuery(SELECT1)) {
+            while (resultSet.next()) {}
+          }
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(2050).toMillis();
+
+          // force trigger pool maintainer to check for long-running sessions
+          client.pool.poolMaintainer.maintainPool();
+
+          manager.commit();
+          assertNotNull(manager.getCommitTimestamp());
+          break;
+        } catch (AbortedException e) {
+          transaction = manager.resetForRetry();
+        }
+      }
+    }
+    Instant endExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    assertNotEquals(
+        endExecutionTime,
+        initialExecutionTime); // if session clean up task runs then these timings won't match
+    assertEquals(0, client.pool.numLeakedSessionsRemoved());
+    assertTrue(client.pool.getNumberOfSessionsInPool() <= client.pool.totalSessions());
+  }
+
+  @Test
+  public void
+      testPoolMaintainer_whenLongRunningExecuteQueryAsyncUsingTransactionManager_retainSessionForTransaction() {
+    FakeClock poolMaintainerClock = new FakeClock();
+    InactiveTransactionRemovalOptions inactiveTransactionRemovalOptions =
+        InactiveTransactionRemovalOptions.newBuilder()
+            .setIdleTimeThreshold(
+                Duration.ofSeconds(
+                    3L)) // any session not used for more than 3s will be long-running
+            .setActionOnInactiveTransaction(ActionOnInactiveTransaction.CLOSE)
+            .setExecutionFrequency(Duration.ofSeconds(1)) // execute thread every 1s
+            .build();
+    SessionPoolOptions sessionPoolOptions =
+        SessionPoolOptions.newBuilder()
+            .setMinSessions(1)
+            .setMaxSessions(1) // to ensure there is 1 session and pool is 100% utilized
+            .setInactiveTransactionRemovalOptions(inactiveTransactionRemovalOptions)
+            .setLoopFrequency(1000L) // main thread runs every 1s
+            .setPoolMaintainerClock(poolMaintainerClock)
+            .build();
+    spanner =
+        SpannerOptions.newBuilder()
+            .setProjectId(TEST_PROJECT)
+            .setDatabaseRole(TEST_DATABASE_ROLE)
+            .setChannelProvider(channelProvider)
+            .setCredentials(NoCredentials.getInstance())
+            .setSessionPoolOption(sessionPoolOptions)
+            .build()
+            .getService();
+    DatabaseClientImpl client =
+        (DatabaseClientImpl)
+            spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE));
+    Instant initialExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    try (TransactionManager manager = client.transactionManager()) {
+      TransactionContext transaction = manager.begin();
+      while (true) {
+        try {
+          try (ResultSet resultSet = transaction.executeQueryAsync(SELECT1)) {
+            while (resultSet.next()) {}
+          }
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(1050).toMillis();
+
+          try (ResultSet resultSet = transaction.executeQueryAsync(SELECT1)) {
+            while (resultSet.next()) {}
+          }
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(2050).toMillis();
+
+          // force trigger pool maintainer to check for long-running sessions
+          client.pool.poolMaintainer.maintainPool();
+
+          manager.commit();
+          assertNotNull(manager.getCommitTimestamp());
+          break;
+        } catch (AbortedException e) {
+          transaction = manager.resetForRetry();
+        }
+      }
+    }
+    Instant endExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    assertNotEquals(
+        endExecutionTime,
+        initialExecutionTime); // if session clean up task runs then these timings won't match
+    assertEquals(0, client.pool.numLeakedSessionsRemoved());
+    assertTrue(client.pool.getNumberOfSessionsInPool() <= client.pool.totalSessions());
+  }
+
+  @Test
+  public void
+      testPoolMaintainer_whenLongRunningAnalyzeQueryUsingTransactionManager_retainSessionForTransaction() {
+    FakeClock poolMaintainerClock = new FakeClock();
+    InactiveTransactionRemovalOptions inactiveTransactionRemovalOptions =
+        InactiveTransactionRemovalOptions.newBuilder()
+            .setIdleTimeThreshold(
+                Duration.ofSeconds(
+                    3L)) // any session not used for more than 3s will be long-running
+            .setActionOnInactiveTransaction(ActionOnInactiveTransaction.CLOSE)
+            .setExecutionFrequency(Duration.ofSeconds(1)) // execute thread every 1s
+            .build();
+    SessionPoolOptions sessionPoolOptions =
+        SessionPoolOptions.newBuilder()
+            .setMinSessions(1)
+            .setMaxSessions(1) // to ensure there is 1 session and pool is 100% utilized
+            .setInactiveTransactionRemovalOptions(inactiveTransactionRemovalOptions)
+            .setLoopFrequency(1000L) // main thread runs every 1s
+            .setPoolMaintainerClock(poolMaintainerClock)
+            .build();
+    spanner =
+        SpannerOptions.newBuilder()
+            .setProjectId(TEST_PROJECT)
+            .setDatabaseRole(TEST_DATABASE_ROLE)
+            .setChannelProvider(channelProvider)
+            .setCredentials(NoCredentials.getInstance())
+            .setSessionPoolOption(sessionPoolOptions)
+            .build()
+            .getService();
+    DatabaseClientImpl client =
+        (DatabaseClientImpl)
+            spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE));
+    Instant initialExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    try (TransactionManager manager = client.transactionManager()) {
+      TransactionContext transaction = manager.begin();
+      while (true) {
+        try {
+          try (ResultSet resultSet = transaction.analyzeQuery(SELECT1, QueryAnalyzeMode.PROFILE)) {
+            while (resultSet.next()) {}
+          }
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(1050).toMillis();
+
+          try (ResultSet resultSet = transaction.analyzeQuery(SELECT1, QueryAnalyzeMode.PROFILE)) {
+            while (resultSet.next()) {}
+          }
+          poolMaintainerClock.currentTimeMillis += Duration.ofMillis(2050).toMillis();
+
+          // force trigger pool maintainer to check for long-running sessions
+          client.pool.poolMaintainer.maintainPool();
+
+          manager.commit();
+          assertNotNull(manager.getCommitTimestamp());
+          break;
+        } catch (AbortedException e) {
+          transaction = manager.resetForRetry();
+        }
+      }
+    }
+    Instant endExecutionTime = client.pool.poolMaintainer.lastExecutionTime;
+
+    assertNotEquals(
+        endExecutionTime,
+        initialExecutionTime); // if session clean up task runs then these timings won't match
+    assertEquals(0, client.pool.numLeakedSessionsRemoved());
+    assertTrue(client.pool.getNumberOfSessionsInPool() <= client.pool.totalSessions());
+  }
+
   @Test
   public void testWrite() {
     DatabaseClient client =
@@ -3552,13 +4303,4 @@ static void assertAsString(ImmutableList<String> expected, ResultSet resultSet,
         expected.stream().collect(Collectors.joining(",", "[", "]")),
         resultSet.getValue(col).getAsString());
   }
-
-  static class FakeClock extends Clock {
-    volatile long currentTimeMillis;
-
-    @Override
-    public Instant instant() {
-      return Instant.ofEpochMilli(currentTimeMillis);
-    }
-  }
 }
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/FakeClock.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/FakeClock.java
new file mode 100644
index 00000000000..4bee3cc18e1
--- /dev/null
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/FakeClock.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *       http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.cloud.spanner;
+
+import org.threeten.bp.Instant;
+
+/**
+ * Class which allows to mock {@link Clock} in unit tests and return custom time values within the
+ * tests.
+ */
+class FakeClock extends Clock {
+  volatile long currentTimeMillis;
+
+  @Override
+  public Instant instant() {
+    return Instant.ofEpochMilli(currentTimeMillis);
+  }
+}
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolMaintainerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolMaintainerTest.java
index 2f7b14fdadf..217e214c832 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolMaintainerTest.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolMaintainerTest.java
@@ -84,7 +84,9 @@ private void setupMockSessionCreation() {
                     SessionConsumerImpl consumer =
                         invocation.getArgument(2, SessionConsumerImpl.class);
                     for (int i = 0; i < sessionCount; i++) {
-                      consumer.onSessionReady(setupMockSession(mockSession()));
+                      ReadContext mockContext = mock(ReadContext.class);
+                      consumer.onSessionReady(
+                          setupMockSession(buildMockSession(mockContext), mockContext));
                     }
                   });
               return null;
@@ -94,10 +96,8 @@ private void setupMockSessionCreation() {
             Mockito.anyInt(), Mockito.anyBoolean(), any(SessionConsumer.class));
   }
 
-  private SessionImpl setupMockSession(final SessionImpl session) {
-    ReadContext mockContext = mock(ReadContext.class);
+  private SessionImpl setupMockSession(final SessionImpl session, final ReadContext mockContext) {
     final ResultSet mockResult = mock(ResultSet.class);
-    when(session.singleUse(any(TimestampBound.class))).thenReturn(mockContext);
     when(mockContext.executeQuery(any(Statement.class)))
         .thenAnswer(
             invocation -> {
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolStressTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolStressTest.java
index c9ba9b360af..8a3a67b2de4 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolStressTest.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolStressTest.java
@@ -22,6 +22,7 @@
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
+import com.google.api.core.ApiFuture;
 import com.google.api.core.ApiFutures;
 import com.google.cloud.spanner.SessionClient.SessionConsumer;
 import com.google.cloud.spanner.SessionPool.PooledSessionFuture;
@@ -29,6 +30,7 @@
 import com.google.cloud.spanner.SessionPool.SessionConsumerImpl;
 import com.google.cloud.spanner.SessionPoolOptions.ActionOnInactiveTransaction;
 import com.google.cloud.spanner.SessionPoolOptions.InactiveTransactionRemovalOptions;
+import com.google.cloud.spanner.spi.v1.SpannerRpc.Option;
 import com.google.common.util.concurrent.Uninterruptibles;
 import com.google.protobuf.ByteString;
 import com.google.protobuf.Empty;
@@ -89,6 +91,7 @@ public static Collection<Object[]> data() {
   }
 
   private void setupSpanner(DatabaseId db) {
+    ReadContext context = mock(ReadContext.class);
     mockSpanner = mock(SpannerImpl.class);
     spannerOptions = mock(SpannerOptions.class);
     when(spannerOptions.getNumChannels()).thenReturn(4);
@@ -105,8 +108,8 @@ private void setupSpanner(DatabaseId db) {
                     for (int s = 0; s < sessionCount; s++) {
                       SessionImpl session;
                       synchronized (lock) {
-                        session = mockSession();
-                        setupSession(session);
+                        session = getMockedSession(context);
+                        setupSession(session, context);
                         sessions.put(session.getName(), false);
                         if (sessions.size() > maxAliveSessions) {
                           maxAliveSessions = sessions.size();
@@ -124,10 +127,60 @@ private void setupSpanner(DatabaseId db) {
             Mockito.anyInt(), Mockito.anyBoolean(), Mockito.any(SessionConsumer.class));
   }
 
-  private void setupSession(final SessionImpl session) {
-    ReadContext mockContext = mock(ReadContext.class);
+  SessionImpl getMockedSession(ReadContext context) {
+    SpannerImpl spanner = mock(SpannerImpl.class);
+    Map options = new HashMap<>();
+    options.put(Option.CHANNEL_HINT, channelHint.getAndIncrement());
+    final SessionImpl session =
+        new SessionImpl(
+            spanner,
+            "projects/dummy/instances/dummy/databases/dummy/sessions/session" + sessionIndex,
+            options) {
+          @Override
+          public ReadContext singleUse(TimestampBound bound) {
+            // The below stubs are added so that we can mock keep-alive.
+            return context;
+          }
+
+          @Override
+          public ApiFuture<Empty> asyncClose() {
+            synchronized (lock) {
+              if (expiredSessions.contains(this.getName())) {
+                return ApiFutures.immediateFailedFuture(
+                    SpannerExceptionFactoryTest.newSessionNotFoundException(this.getName()));
+              }
+              if (sessions.remove(this.getName()) == null) {
+                setFailed(closedSessions.get(this.getName()));
+              }
+              closedSessions.put(this.getName(), new Exception("Session closed at:"));
+              if (sessions.size() < minSessionsWhenSessionClosed) {
+                minSessionsWhenSessionClosed = sessions.size();
+              }
+            }
+            return ApiFutures.immediateFuture(Empty.getDefaultInstance());
+          }
+
+          @Override
+          public void prepareReadWriteTransaction() {
+            if (random.nextInt(100) < 10) {
+              expireSession(this);
+              throw SpannerExceptionFactoryTest.newSessionNotFoundException(this.getName());
+            }
+            String name = this.getName();
+            synchronized (lock) {
+              if (sessions.put(name, true)) {
+                setFailed();
+              }
+              this.readyTransactionId = ByteString.copyFromUtf8("foo");
+            }
+          }
+        };
+    sessionIndex++;
+    return session;
+  }
+
+  private void setupSession(final SessionImpl session, final ReadContext mockContext) {
     final ResultSet mockResult = mock(ResultSet.class);
-    when(session.singleUse(any(TimestampBound.class))).thenReturn(mockContext);
     when(mockContext.executeQuery(any(Statement.class)))
         .thenAnswer(
             invocation -> {
@@ -135,44 +188,6 @@ private void setupSession(final SessionImpl session) {
               return mockResult;
             });
     when(mockResult.next()).thenReturn(true);
-    doAnswer(
-            invocation -> {
-              synchronized (lock) {
-                if (expiredSessions.contains(session.getName())) {
-                  return ApiFutures.immediateFailedFuture(
-                      SpannerExceptionFactoryTest.newSessionNotFoundException(session.getName()));
-                }
-                if (sessions.remove(session.getName()) == null) {
-                  setFailed(closedSessions.get(session.getName()));
-                }
-                closedSessions.put(session.getName(), new Exception("Session closed at:"));
-                if (sessions.size() < minSessionsWhenSessionClosed) {
-                  minSessionsWhenSessionClosed = sessions.size();
-                }
-              }
-              return ApiFutures.immediateFuture(Empty.getDefaultInstance());
-            })
-        .when(session)
-        .asyncClose();
-
-    doAnswer(
-            invocation -> {
-              if (random.nextInt(100) < 10) {
-                expireSession(session);
-                throw SpannerExceptionFactoryTest.newSessionNotFoundException(session.getName());
-              }
-              String name = session.getName();
-              synchronized (lock) {
-                if (sessions.put(name, true)) {
-                  setFailed();
-                }
-                session.readyTransactionId = ByteString.copyFromUtf8("foo");
-              }
-              return null;
-            })
-        .when(session)
-        .prepareReadWriteTransaction();
-    when(session.hasReadyTransaction()).thenCallRealMethod();
   }
 
   private void expireSession(Session session) {
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java
index b8cade68cc0..65c7c7c03c1 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java
@@ -51,7 +51,6 @@
 import com.google.cloud.spanner.MetricRegistryTestUtils.PointWithFunction;
 import com.google.cloud.spanner.ReadContext.QueryAnalyzeMode;
 import com.google.cloud.spanner.SessionClient.SessionConsumer;
-import com.google.cloud.spanner.SessionPool.Clock;
 import com.google.cloud.spanner.SessionPool.PooledSession;
 import com.google.cloud.spanner.SessionPool.PooledSessionFuture;
 import com.google.cloud.spanner.SessionPool.Position;
@@ -641,6 +640,7 @@ public void failOnPoolExhaustion() {
 
   @Test
   public void idleSessionCleanup() throws Exception {
+    ReadContext context = mock(ReadContext.class);
     options =
         SessionPoolOptions.newBuilder()
             .setMinSessions(1)
@@ -648,9 +648,9 @@ public void idleSessionCleanup() throws Exception {
             .setIncStep(1)
             .setMaxIdleSessions(0)
             .build();
-    SessionImpl session1 = mockSession();
-    SessionImpl session2 = mockSession();
-    SessionImpl session3 = mockSession();
+    SessionImpl session1 = buildMockSession(context);
+    SessionImpl session2 = buildMockSession(context);
+    SessionImpl session3 = buildMockSession(context);
     final LinkedList<SessionImpl> sessions =
         new LinkedList<>(Arrays.asList(session1, session2, session3));
     doAnswer(
@@ -665,11 +665,12 @@ public void idleSessionCleanup() throws Exception {
             })
         .when(sessionClient)
         .asyncBatchCreateSessions(Mockito.eq(1), Mockito.anyBoolean(), any(SessionConsumer.class));
-    for (SessionImpl session : sessions) {
-      mockKeepAlive(session);
-    }
+
     FakeClock clock = new FakeClock();
     clock.currentTimeMillis = System.currentTimeMillis();
+
+    mockKeepAlive(context);
+
     pool = createPool(clock);
     // Make sure pool has been initialized
     pool.getSession().close();
@@ -1034,9 +1035,10 @@ public void longRunningTransactionsCleanup_whenException_doNothing() throws Exce
   }
 
   private void setupForLongRunningTransactionsCleanup() {
-    SessionImpl session1 = mockSession();
-    SessionImpl session2 = mockSession();
-    SessionImpl session3 = mockSession();
+    ReadContext context = mock(ReadContext.class);
+    SessionImpl session1 = buildMockSession(context);
+    SessionImpl session2 = buildMockSession(context);
+    SessionImpl session3 = buildMockSession(context);
 
     final LinkedList<SessionImpl> sessions =
         new LinkedList<>(Arrays.asList(session1, session2, session3));
@@ -1053,16 +1055,20 @@ private void setupForLongRunningTransactionsCleanup() {
         .when(sessionClient)
         .asyncBatchCreateSessions(Mockito.eq(1), Mockito.anyBoolean(), any(SessionConsumer.class));
 
-    for (SessionImpl session : sessions) {
-      mockKeepAlive(session);
-    }
+    mockKeepAlive(context);
   }
 
   @Test
   public void keepAlive() throws Exception {
+    ReadContext context = mock(ReadContext.class);
     options = SessionPoolOptions.newBuilder().setMinSessions(2).setMaxSessions(3).build();
-    final SessionImpl session = mockSession();
-    mockKeepAlive(session);
+    final SessionImpl mockSession1 = buildMockSession(context);
+    final SessionImpl mockSession2 = buildMockSession(context);
+    final SessionImpl mockSession3 = buildMockSession(context);
+    final LinkedList<SessionImpl> sessions =
+        new LinkedList<>(Arrays.asList(mockSession1, mockSession2, mockSession3));
+
+    mockKeepAlive(context);
     // This is cheating as we are returning the same session each but it makes the verification
     // easier.
     doAnswer(
@@ -1073,7 +1079,7 @@ public void keepAlive() throws Exception {
                     SessionConsumerImpl consumer =
                         invocation.getArgument(2, SessionConsumerImpl.class);
                     for (int i = 0; i < sessionCount; i++) {
-                      consumer.onSessionReady(session);
+                      consumer.onSessionReady(sessions.pop());
                     }
                   });
               return null;
@@ -1090,9 +1096,9 @@ public void keepAlive() throws Exception {
     session1.close();
     session2.close();
     runMaintenanceLoop(clock, pool, pool.poolMaintainer.numKeepAliveCycles);
-    verify(session, never()).singleUse(any(TimestampBound.class));
+    verify(context, never()).executeQuery(any(Statement.class));
     runMaintenanceLoop(clock, pool, pool.poolMaintainer.numKeepAliveCycles);
-    verify(session, times(2)).singleUse(any(TimestampBound.class));
+    verify(context, times(2)).executeQuery(Statement.newBuilder("SELECT 1").build());
     clock.currentTimeMillis +=
         clock.currentTimeMillis + (options.getKeepAliveIntervalMinutes() + 5) * 60 * 1000;
     session1 = pool.getSession();
@@ -1100,8 +1106,8 @@ public void keepAlive() throws Exception {
     session1.close();
     runMaintenanceLoop(clock, pool, pool.poolMaintainer.numKeepAliveCycles);
     // The session pool only keeps MinSessions + MaxIdleSessions alive.
-    verify(session, times(options.getMinSessions() + options.getMaxIdleSessions()))
-        .singleUse(any(TimestampBound.class));
+    verify(context, times(options.getMinSessions() + options.getMaxIdleSessions()))
+        .executeQuery(Statement.newBuilder("SELECT 1").build());
     pool.closeAsync(new SpannerImpl.ClosedException()).get(5L, TimeUnit.SECONDS);
   }
 
@@ -1799,6 +1805,12 @@ public void testWaitOnMinSessionsThrowsExceptionWhenTimeoutIsReached() {
     pool.maybeWaitOnMinSessions();
   }
 
+  private void mockKeepAlive(ReadContext context) {
+    ResultSet resultSet = mock(ResultSet.class);
+    when(resultSet.next()).thenReturn(true, false);
+    when(context.executeQuery(any(Statement.class))).thenReturn(resultSet);
+  }
+
   private void mockKeepAlive(Session session) {
     ReadContext context = mock(ReadContext.class);
     ResultSet resultSet = mock(ResultSet.class);
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java
index e13b7b75093..efc57fb480b 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java
@@ -273,10 +273,15 @@ public void batchDmlFailedPrecondition() {
   @Test
   public void inlineBegin() {
     SpannerImpl spanner = mock(SpannerImpl.class);
+    SpannerOptions options = mock(SpannerOptions.class);
+
     when(spanner.getRpc()).thenReturn(rpc);
     when(spanner.getDefaultQueryOptions(Mockito.any(DatabaseId.class)))
         .thenReturn(QueryOptions.getDefaultInstance());
-    when(spanner.getOptions()).thenReturn(mock(SpannerOptions.class));
+    when(spanner.getOptions()).thenReturn(options);
+    SessionPoolOptions sessionPoolOptions = SessionPoolOptions.newBuilder().build();
+    when(options.getSessionPoolOptions()).thenReturn(sessionPoolOptions);
+
     SessionImpl session =
         new SessionImpl(
             spanner, "projects/p/instances/i/databases/d/sessions/s", Collections.EMPTY_MAP) {