diff --git a/google/cloud/spanner/CMakeLists.txt b/google/cloud/spanner/CMakeLists.txt index aefd54430f93d..a80fdcf01525c 100644 --- a/google/cloud/spanner/CMakeLists.txt +++ b/google/cloud/spanner/CMakeLists.txt @@ -162,6 +162,7 @@ add_library( internal/range_from_pagination.h internal/retry_loop.cc internal/retry_loop.h + internal/session.cc internal/session.h internal/session_pool.cc internal/session_pool.h diff --git a/google/cloud/spanner/internal/connection_impl.cc b/google/cloud/spanner/internal/connection_impl.cc index 04d5addae8d3d..dfb6f5fb66244 100644 --- a/google/cloud/spanner/internal/connection_impl.cc +++ b/google/cloud/spanner/internal/connection_impl.cc @@ -106,11 +106,10 @@ ConnectionImpl::ConnectionImpl(Database db, std::shared_ptr stub, std::unique_ptr retry_policy, std::unique_ptr backoff_policy) : db_(std::move(db)), - stub_(std::move(stub)), retry_policy_prototype_(std::move(retry_policy)), backoff_policy_prototype_(std::move(backoff_policy)), session_pool_(std::make_shared( - db_, stub_, retry_policy_prototype_->clone(), + db_, std::move(stub), retry_policy_prototype_->clone(), backoff_policy_prototype_->clone())) {} RowStream ConnectionImpl::Read(ReadParams params) { @@ -328,10 +327,9 @@ RowStream ConnectionImpl::ReadImpl(SessionHolder& session, request.set_partition_token(*std::move(params.partition_token)); } - auto const& stub = stub_; // Capture a copy of `stub` to ensure the `shared_ptr<>` remains valid through - // the lifetime of the lambda. Note that the local variable `stub` is a - // reference to avoid increasing refcounts twice, but the capture is by value. + // the lifetime of the lambda. + auto stub = session_pool_->GetStub(*session); auto factory = [stub, request](std::string const& resume_token) mutable { request.set_resume_token(resume_token); auto context = google::cloud::internal::make_unique(); @@ -379,12 +377,13 @@ StatusOr> ConnectionImpl::PartitionReadImpl( *request.mutable_key_set() = internal::ToProto(params.keys); *request.mutable_partition_options() = internal::ToProto(partition_options); + auto stub = session_pool_->GetStub(*session); auto response = internal::RetryLoop( retry_policy_prototype_->clone(), backoff_policy_prototype_->clone(), true, - [this](grpc::ClientContext& context, - spanner_proto::PartitionReadRequest const& request) { - return stub_->PartitionRead(context, request); + [&stub](grpc::ClientContext& context, + spanner_proto::PartitionReadRequest const& request) { + return stub->PartitionRead(context, request); }, request, __func__); if (!response.ok()) { @@ -452,11 +451,10 @@ ResultType ConnectionImpl::CommonQueryImpl( if (!prepare_status.ok()) { return MakeStatusOnlyResult(std::move(prepare_status)); } - // Capture a copy of of these member variables to ensure the `shared_ptr<>` - // remains valid through the lifetime of the lambda. Note that the local - // variables are a reference to avoid increasing refcounts twice, but the - // capture is by value. - auto const& stub = stub_; + // Capture a copy of of these to ensure the `shared_ptr<>` remains valid + // through the lifetime of the lambda. Note that the local variables are a + // reference to avoid increasing refcounts twice, but the capture is by value. + auto stub = session_pool_->GetStub(*session); auto const& retry_policy = retry_policy_prototype_; auto const& backoff_policy = backoff_policy_prototype_; @@ -512,11 +510,10 @@ StatusOr ConnectionImpl::CommonDmlImpl( if (!prepare_status.ok()) { return prepare_status; } - // Capture a copy of of these member variables to ensure the `shared_ptr<>` - // remains valid through the lifetime of the lambda. Note that the local - // variables are a reference to avoid increasing refcounts twice, but the - // capture is by value. - auto const& stub = stub_; + // Capture a copy of of these to ensure the `shared_ptr<>` remains valid + // through the lifetime of the lambda. Note that the local variables are a + // reference to avoid increasing refcounts twice, but the capture is by value. + auto stub = session_pool_->GetStub(*session); auto const& retry_policy = retry_policy_prototype_; auto const& backoff_policy = backoff_policy_prototype_; @@ -586,12 +583,13 @@ StatusOr> ConnectionImpl::PartitionQueryImpl( std::move(*sql_statement.mutable_param_types()); *request.mutable_partition_options() = internal::ToProto(partition_options); + auto stub = session_pool_->GetStub(*session); auto response = internal::RetryLoop( retry_policy_prototype_->clone(), backoff_policy_prototype_->clone(), true, - [this](grpc::ClientContext& context, - spanner_proto::PartitionQueryRequest const& request) { - return stub_->PartitionQuery(context, request); + [&stub](grpc::ClientContext& context, + spanner_proto::PartitionQueryRequest const& request) { + return stub->PartitionQuery(context, request); }, request, __func__); if (!response.ok()) { @@ -628,12 +626,13 @@ StatusOr ConnectionImpl::ExecuteBatchDmlImpl( *request.add_statements() = internal::ToProto(std::move(sql)); } + auto stub = session_pool_->GetStub(*session); auto response = internal::RetryLoop( retry_policy_prototype_->clone(), backoff_policy_prototype_->clone(), true, - [this](grpc::ClientContext& context, - spanner_proto::ExecuteBatchDmlRequest const& request) { - return stub_->ExecuteBatchDml(context, request); + [&stub](grpc::ClientContext& context, + spanner_proto::ExecuteBatchDmlRequest const& request) { + return stub->ExecuteBatchDml(context, request); }, request, __func__); if (!response) { @@ -665,12 +664,13 @@ StatusOr ConnectionImpl::ExecutePartitionedDmlImpl( begin_request.set_session(session->session_name()); *begin_request.mutable_options()->mutable_partitioned_dml() = spanner_proto::TransactionOptions_PartitionedDml(); + auto stub = session_pool_->GetStub(*session); auto begin_response = internal::RetryLoop( retry_policy_prototype_->clone(), backoff_policy_prototype_->clone(), true, - [this](grpc::ClientContext& context, - spanner_proto::BeginTransactionRequest const& request) { - return stub_->BeginTransaction(context, request); + [&stub](grpc::ClientContext& context, + spanner_proto::BeginTransactionRequest const& request) { + return stub->BeginTransaction(context, request); }, begin_request, __func__); if (!begin_response) { @@ -690,9 +690,9 @@ StatusOr ConnectionImpl::ExecutePartitionedDmlImpl( auto response = internal::RetryLoop( retry_policy_prototype_->clone(), backoff_policy_prototype_->clone(), true, - [this](grpc::ClientContext& context, - spanner_proto::ExecuteSqlRequest const& request) { - return stub_->ExecuteSql(context, request); + [&stub](grpc::ClientContext& context, + spanner_proto::ExecuteSqlRequest const& request) { + return stub->ExecuteSql(context, request); }, request, __func__); if (!response) { @@ -729,12 +729,13 @@ StatusOr ConnectionImpl::CommitImpl( is_idempotent = true; } + auto stub = session_pool_->GetStub(*session); auto response = internal::RetryLoop( retry_policy_prototype_->clone(), backoff_policy_prototype_->clone(), is_idempotent, - [this](grpc::ClientContext& context, - spanner_proto::CommitRequest const& request) { - return stub_->Commit(context, request); + [&stub](grpc::ClientContext& context, + spanner_proto::CommitRequest const& request) { + return stub->Commit(context, request); }, request, __func__); if (!response) { @@ -764,12 +765,13 @@ Status ConnectionImpl::RollbackImpl(SessionHolder& session, spanner_proto::RollbackRequest request; request.set_session(session->session_name()); request.set_transaction_id(s.id()); + auto stub = session_pool_->GetStub(*session); return internal::RetryLoop( retry_policy_prototype_->clone(), backoff_policy_prototype_->clone(), true, - [this](grpc::ClientContext& context, - spanner_proto::RollbackRequest const& request) { - return stub_->Rollback(context, request); + [&stub](grpc::ClientContext& context, + spanner_proto::RollbackRequest const& request) { + return stub->Rollback(context, request); }, request, __func__); } diff --git a/google/cloud/spanner/internal/connection_impl.h b/google/cloud/spanner/internal/connection_impl.h index 3840f03ae7d59..4e1bf6ed60a19 100644 --- a/google/cloud/spanner/internal/connection_impl.h +++ b/google/cloud/spanner/internal/connection_impl.h @@ -159,7 +159,6 @@ class ConnectionImpl : public Connection { google::spanner::v1::ExecuteSqlRequest::QueryMode query_mode); Database db_; - std::shared_ptr stub_; std::shared_ptr retry_policy_prototype_; std::shared_ptr backoff_policy_prototype_; std::shared_ptr session_pool_; diff --git a/google/cloud/spanner/internal/session.cc b/google/cloud/spanner/internal/session.cc new file mode 100644 index 0000000000000..ba385cdaf146b --- /dev/null +++ b/google/cloud/spanner/internal/session.cc @@ -0,0 +1,32 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "google/cloud/spanner/internal/session.h" + +namespace google { +namespace cloud { +namespace spanner { +inline namespace SPANNER_CLIENT_NS { +namespace internal { + +SessionHolder MakeDissociatedSessionHolder(std::string session_name) { + return SessionHolder(new Session(std::move(session_name), /*stub=*/nullptr), + std::default_delete()); +} + +} // namespace internal +} // namespace SPANNER_CLIENT_NS +} // namespace spanner +} // namespace cloud +} // namespace google diff --git a/google/cloud/spanner/internal/session.h b/google/cloud/spanner/internal/session.h index 6040ae1178fda..219b460dae293 100644 --- a/google/cloud/spanner/internal/session.h +++ b/google/cloud/spanner/internal/session.h @@ -15,6 +15,7 @@ #ifndef GOOGLE_CLOUD_CPP_SPANNER_GOOGLE_CLOUD_SPANNER_INTERNAL_SESSION_H_ #define GOOGLE_CLOUD_CPP_SPANNER_GOOGLE_CLOUD_SPANNER_INTERNAL_SESSION_H_ +#include "google/cloud/spanner/internal/spanner_stub.h" #include "google/cloud/spanner/version.h" #include #include @@ -32,8 +33,8 @@ namespace internal { */ class Session { public: - Session(std::string session_name) noexcept - : session_name_(std::move(session_name)) {} + Session(std::string session_name, std::shared_ptr stub) noexcept + : session_name_(std::move(session_name)), stub_(std::move(stub)) {} // Not copyable or moveable. Session(Session const&) = delete; @@ -44,7 +45,11 @@ class Session { std::string const& session_name() const { return session_name_; } private: - std::string session_name_; + friend class SessionPool; // for access to stub() + std::shared_ptr stub() const { return stub_; } + + std::string const session_name_; + std::shared_ptr const stub_; }; /** @@ -60,11 +65,7 @@ using SessionHolder = std::unique_ptr>; * like partitioned operations where the `Session` may be used on multiple * machines and should not be returned to the pool. */ -template -SessionHolder MakeDissociatedSessionHolder(Args&&... args) { - return SessionHolder(new Session(std::forward(args)...), - std::default_delete()); -} +SessionHolder MakeDissociatedSessionHolder(std::string session_name); } // namespace internal } // namespace SPANNER_CLIENT_NS diff --git a/google/cloud/spanner/internal/session_pool.cc b/google/cloud/spanner/internal/session_pool.cc index 4421d5abd7841..90bc838003008 100644 --- a/google/cloud/spanner/internal/session_pool.cc +++ b/google/cloud/spanner/internal/session_pool.cc @@ -141,6 +141,15 @@ StatusOr SessionPool::Allocate(bool dissociate_from_pool) { } } +std::shared_ptr SessionPool::GetStub(Session const& session) { + std::shared_ptr stub = session.stub(); + if (stub) return stub; + + // Sessions that were created for partitioned Reads/Queries do not have + // their own stub, so return one to use. + return stub_; +} + void SessionPool::Release(Session* session) { std::unique_lock lk(mu_); bool notify = sessions_.empty(); @@ -172,7 +181,7 @@ StatusOr>> SessionPool::CreateSessions( sessions.reserve(response->session_size()); for (auto& session : *response->mutable_session()) { sessions.push_back(google::cloud::internal::make_unique( - std::move(*session.mutable_name()))); + std::move(*session.mutable_name()), stub_)); } return {std::move(sessions)}; } diff --git a/google/cloud/spanner/internal/session_pool.h b/google/cloud/spanner/internal/session_pool.h index 8ea2ae1491198..5f1a00a2af106 100644 --- a/google/cloud/spanner/internal/session_pool.h +++ b/google/cloud/spanner/internal/session_pool.h @@ -122,6 +122,11 @@ class SessionPool : public std::enable_shared_from_this { */ StatusOr Allocate(bool dissociate_from_pool = false); + /** + * Return a `SpannerStub` to be used when making calls using `session`. + */ + std::shared_ptr GetStub(Session const& session); + private: /** * Release session back to the pool. diff --git a/google/cloud/spanner/internal/session_pool_test.cc b/google/cloud/spanner/internal/session_pool_test.cc index 1c679d94f2d23..8bed4dc0e0967 100644 --- a/google/cloud/spanner/internal/session_pool_test.cc +++ b/google/cloud/spanner/internal/session_pool_test.cc @@ -84,6 +84,7 @@ TEST(SessionPool, Allocate) { auto session = pool->Allocate(); ASSERT_STATUS_OK(session); EXPECT_EQ((*session)->session_name(), "session1"); + EXPECT_EQ(pool->GetStub(**session), mock); } TEST(SessionPool, CreateError) { @@ -233,6 +234,15 @@ TEST(SessionPool, MaxSessionsBlockUntilRelease) { t.join(); } +TEST(SessionPool, GetStubForStublessSession) { + auto mock = std::make_shared(); + auto db = Database("project", "instance", "database"); + auto pool = MakeSessionPool(db, mock); + // ensure we get a stub even if we didn't allocate from the pool. + auto session = MakeDissociatedSessionHolder("session_id"); + EXPECT_EQ(pool->GetStub(*session), mock); +} + } // namespace } // namespace internal } // namespace SPANNER_CLIENT_NS diff --git a/google/cloud/spanner/spanner_client.bzl b/google/cloud/spanner/spanner_client.bzl index 8382c687199d1..c0c5fb2975037 100644 --- a/google/cloud/spanner/spanner_client.bzl +++ b/google/cloud/spanner/spanner_client.bzl @@ -110,6 +110,7 @@ spanner_client_srcs = [ "internal/partial_result_set_resume.cc", "internal/partial_result_set_source.cc", "internal/retry_loop.cc", + "internal/session.cc", "internal/session_pool.cc", "internal/spanner_stub.cc", "internal/time.cc",