Skip to content

Commit 394ac3a

Browse files
authored
Fix cache deadlock - update query cache so even the initiating scan is a setRequest subscriber. Closes #586
1 parent be638c9 commit 394ac3a

17 files changed

+865
-416
lines changed

grpc/quals.go

+7-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"bytes"
55
"encoding/json"
66
"fmt"
7-
"golang.org/x/exp/maps"
87
"strings"
98
"time"
109

@@ -45,7 +44,13 @@ func QualMapToLogLine(qualMap map[string]*proto.Quals) string {
4544
if len(qualMap) == 0 {
4645
return "NONE"
4746
}
48-
return strings.Join(maps.Keys(qualMap), ",")
47+
var line strings.Builder
48+
for column, quals := range qualMap {
49+
for _, q := range quals.Quals {
50+
line.WriteString(fmt.Sprintf("%s %s %s, ", column, q.Operator, q.Value.String()))
51+
}
52+
}
53+
return line.String()
4954
}
5055

5156
func QualMapsEqual(l map[string]*proto.Quals, r map[string]*proto.Quals) bool {

plugin/plugin.go

+50-40
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package plugin
22

33
import (
44
"context"
5-
"errors"
65
"fmt"
76
"log"
87
"os"
@@ -21,7 +20,6 @@ import (
2120
"github.com/hashicorp/go-hclog"
2221
"github.com/turbot/go-kit/helpers"
2322
connectionmanager "github.com/turbot/steampipe-plugin-sdk/v5/connection"
24-
"github.com/turbot/steampipe-plugin-sdk/v5/error_helpers"
2523
"github.com/turbot/steampipe-plugin-sdk/v5/grpc"
2624
"github.com/turbot/steampipe-plugin-sdk/v5/grpc/proto"
2725
"github.com/turbot/steampipe-plugin-sdk/v5/logging"
@@ -264,28 +262,24 @@ func (p *Plugin) ConnectionSchemaChanged(connection *Connection) error {
264262
return nil
265263
}
266264

267-
func (p *Plugin) executeForConnection(ctx context.Context, req *proto.ExecuteRequest, connectionName string, outputChan chan *proto.ExecuteResponse) (err error) {
265+
func (p *Plugin) executeForConnection(streamContext context.Context, req *proto.ExecuteRequest, connectionName string, outputChan chan *proto.ExecuteResponse, logger hclog.Logger) (err error) {
268266
const rowBufferSize = 10
269267
var rowChan = make(chan *proto.Row, rowBufferSize)
270268

271269
executeData := req.ExecuteConnectionData[connectionName]
272270

273271
// build callId for this connection (this is necessary as the plugin Execute call may be for an aggregator connection)
274272
connectionCallId := p.getConnectionCallId(req.CallId, connectionName)
275-
// when done, remove call id from map
276-
defer p.clearCallId(connectionCallId)
277273

278274
log.Printf("[INFO] executeForConnection callId: %s, connectionCallId: %s, connection: %s table: %s cols: %s", req.CallId, connectionCallId, connectionName, req.Table, strings.Join(req.QueryContext.Columns, ","))
279275

280276
defer func() {
281-
log.Printf("[TRACE] executeForConnection DEFER (%s) ", connectionCallId)
282277
if r := recover(); r != nil {
283-
log.Printf("[WARN] Execute recover from panic: callId: %s table: %s error: %v", connectionCallId, req.Table, r)
278+
log.Printf("[WARN] executeForConnection recover from panic: callId: %s table: %s error: %v", connectionCallId, req.Table, r)
284279
err = helpers.ToError(r)
285280
return
286281
}
287-
288-
log.Printf("[TRACE] Execute complete callId: %s table: %s ", connectionCallId, req.Table)
282+
log.Printf("[INFO] executeForConnection COMPLETE callId: %s, connectionCallId: %s, connection: %s table: %s cols: %s ", req.CallId, connectionCallId, connectionName, req.Table, strings.Join(req.QueryContext.Columns, ","))
289283
}()
290284

291285
// the connection property must be set already
@@ -312,6 +306,16 @@ func (p *Plugin) executeForConnection(ctx context.Context, req *proto.ExecuteReq
312306
log.Printf("[INFO] caching is disabled for table %s", table.Name)
313307
}
314308
}
309+
310+
// if cache NOT disabled, create a fresh context for this scan
311+
ctx := streamContext
312+
var cancel context.CancelFunc
313+
if cacheEnabled {
314+
// get a fresh context which includes telemetry data and logger
315+
ctx, cancel = context.WithCancel(context.Background())
316+
}
317+
ctx = p.buildExecuteContext(ctx, req, logger)
318+
315319
logging.LogTime("Start execute")
316320

317321
queryContext := NewQueryContext(req.QueryContext, limitParam, cacheEnabled, cacheTTL, table)
@@ -336,6 +340,10 @@ func (p *Plugin) executeForConnection(ctx context.Context, req *proto.ExecuteReq
336340
return err
337341
}
338342

343+
// set the cancel func on the query data
344+
// (this is only used if the cache is enabled - if a set request has no subscribers)
345+
queryData.cancel = cancel
346+
339347
// get the matrix item
340348
log.Printf("[TRACE] GetMatrixItem")
341349
var matrixItem []map[string]any
@@ -362,23 +370,25 @@ func (p *Plugin) executeForConnection(ctx context.Context, req *proto.ExecuteReq
362370
ConnectionName: connectionName,
363371
TtlSeconds: queryContext.CacheTTL,
364372
CallId: connectionCallId,
373+
StreamContext: streamContext,
365374
}
366375
// can we satisfy this request from the cache?
367376
if cacheEnabled {
368377
log.Printf("[INFO] cacheEnabled, trying cache get (%s)", connectionCallId)
369378

370379
// create a function to increment cachedRowsFetched and stream a row
371-
streamRowFunc := func(row *proto.Row) {
380+
streamUncachedRowFunc := queryData.streamRow
381+
streamCachedRowFunc := func(row *proto.Row) {
372382
// if row is not nil (indicating completion), increment cachedRowsFetched
373383
if row != nil {
374384
atomic.AddInt64(&queryData.queryStatus.cachedRowsFetched, 1)
375385
}
376-
queryData.streamRow(row)
386+
streamUncachedRowFunc(row)
377387
}
378388

379389
start := time.Now()
380390
// try to fetch this data from the query cache
381-
cacheErr := p.queryCache.Get(ctx, cacheRequest, streamRowFunc)
391+
cacheErr := p.queryCache.Get(ctx, cacheRequest, streamUncachedRowFunc, streamCachedRowFunc)
382392
if cacheErr == nil {
383393
// so we got a cached result - stream it out
384394
log.Printf("[INFO] queryCacheGet returned CACHE HIT (%s)", connectionCallId)
@@ -393,45 +403,40 @@ func (p *Plugin) executeForConnection(ctx context.Context, req *proto.ExecuteReq
393403
}
394404

395405
// so the cache call failed, with either a cache-miss or other error
396-
if query_cache.IsCacheMiss(cacheErr) {
397-
log.Printf("[TRACE] cache MISS")
398-
} else if errors.Is(cacheErr, error_helpers.QueryError{}) {
399-
// if this is a QueryError, this means the pending item we were waiting for failed
400-
// > we also fail
406+
if !query_cache.IsCacheMiss(cacheErr) {
407+
log.Printf("[WARN] queryCacheGet returned err %s", cacheErr.Error())
401408
return cacheErr
402-
} else {
403-
// otherwise just log the cache error
404-
log.Printf("[TRACE] queryCacheGet returned err %s", cacheErr.Error())
405409
}
406-
410+
// otherwise just log the cache miss error
407411
log.Printf("[INFO] queryCacheGet returned CACHE MISS (%s)", connectionCallId)
408412
} else {
409-
log.Printf("[INFO] Cache DISABLED connectionCallId: %s", connectionCallId)
413+
log.Printf("[INFO] Cache DISABLED (%s)", connectionCallId)
410414
}
411415

416+
// so we need to fetch the data
417+
412418
// asynchronously fetch items
413-
log.Printf("[TRACE] calling fetchItems, table: %s, matrixItem: %v, limit: %d, connectionCallId: %s\"", table.Name, queryData.Matrix, limit, connectionCallId)
419+
log.Printf("[INFO] calling fetchItems, table: %s, matrixItem: %v, limit: %d (%s)", table.Name, queryData.Matrix, limit, connectionCallId)
414420
if err := table.fetchItems(ctx, queryData); err != nil {
415421
log.Printf("[WARN] fetchItems returned an error, table: %s, error: %v", table.Name, err)
416422
return err
417423

418424
}
419-
logging.LogTime("Calling build Rows")
420-
421-
log.Printf("[TRACE] buildRowsAsync connectionCallId: %s", connectionCallId)
422425

423426
// asynchronously build rows
427+
logging.LogTime("Calling build Rows")
428+
log.Printf("[TRACE] buildRowsAsync (%s)", connectionCallId)
429+
424430
// channel used by streamRows when it receives an error to tell buildRowsAsync to stop
425431
doneChan := make(chan bool)
426432
queryData.buildRowsAsync(ctx, rowChan, doneChan)
427433

428-
log.Printf("[TRACE] streamRows connectionCallId: %s", connectionCallId)
429-
434+
// stream rows either into cache (if enabled) or back across GRPC (if not)
430435
logging.LogTime("Calling streamRows")
431436

432-
// stream rows across GRPC
433437
err = queryData.streamRows(ctx, rowChan, doneChan)
434438
if err != nil {
439+
log.Printf("[WARN] queryData.streamRows returned error: %s", err.Error())
435440
return err
436441
}
437442

@@ -569,42 +574,47 @@ func (p *Plugin) buildConnectionSchemaMap() map[string]*grpc.PluginSchema {
569574
return res
570575
}
571576

572-
func (p *Plugin) getConnectionCallId(callId string, connectionName string) string {
573-
// add connection name onto call id
574-
connectionCallId := grpc.BuildConnectionCallId(callId, connectionName)
577+
// ensure callId is unique for this plugin instance - important as it is used to key set requests
578+
func (p *Plugin) getUniqueCallId(callId string) string {
575579
// store as orig as we may mutate connectionCallId to dedupe
576-
orig := connectionCallId
580+
orig := callId
577581
// check if it is unique - this is crucial as it is used to key `set requests` in the query cache
578582
idx := 0
579583
p.callIdLookupMut.RLock()
580584
for {
581-
if _, callIdExists := p.callIdLookup[connectionCallId]; !callIdExists {
585+
if _, callIdExists := p.callIdLookup[callId]; !callIdExists {
582586
// release read lock and get a write lock
583587
p.callIdLookupMut.RUnlock()
584588
p.callIdLookupMut.Lock()
585589

586590
// recheck as there is a race condition to acquire a write lock
587-
if _, callIdExists := p.callIdLookup[connectionCallId]; !callIdExists {
591+
if _, callIdExists := p.callIdLookup[callId]; !callIdExists {
588592
// store in map
589-
p.callIdLookup[connectionCallId] = struct{}{}
593+
p.callIdLookup[callId] = struct{}{}
590594
p.callIdLookupMut.Unlock()
591-
return connectionCallId
595+
return callId
592596
}
593597

594598
// someone must have got in there before us - downgrade lock again
595599
p.callIdLookupMut.Unlock()
596600
p.callIdLookupMut.RLock()
597601
}
598602
// so the id exists already - add a suffix
599-
log.Printf("[WARN] getConnectionCallId duplicate call id %s - adding suffix", connectionCallId)
600-
connectionCallId = fmt.Sprintf("%s%d", orig, idx)
603+
log.Printf("[WARN] getUniqueCallId duplicate call id %s - adding suffix", callId)
604+
callId = fmt.Sprintf("%s%d", orig, idx)
601605
idx++
602606

603607
}
604608
p.callIdLookupMut.RUnlock()
605-
return connectionCallId
609+
return callId
610+
}
611+
612+
func (p *Plugin) getConnectionCallId(callId string, connectionName string) string {
613+
// add connection name onto call id
614+
return grpc.BuildConnectionCallId(callId, connectionName)
606615
}
607616

617+
// remove callId from callIdLookup
608618
func (p *Plugin) clearCallId(connectionCallId string) {
609619
p.callIdLookupMut.Lock()
610620
delete(p.callIdLookup, connectionCallId)

plugin/plugin_grpc.go

+12-9
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,19 @@ func (p *Plugin) getSchema(connectionName string) (*grpc.PluginSchema, error) {
178178
//
179179
// This is the handler function for the execute GRPC function.
180180
func (p *Plugin) execute(req *proto.ExecuteRequest, stream proto.WrapperPlugin_ExecuteServer) (err error) {
181+
ctx := stream.Context()
181182
// add CallId to logs for the execute call
182183
logger := p.Logger.Named(req.CallId)
183184
log.SetOutput(logger.StandardWriter(&hclog.StandardLoggerOptions{InferLevels: true}))
184185
log.SetPrefix("")
185186
log.SetFlags(0)
186187

187-
log.Printf("[INFO] Plugin execute table: %s (%s)", req.Table, req.CallId)
188+
// dedupe the call id
189+
req.CallId = p.getUniqueCallId(req.CallId)
190+
// when done, remove call id from map
191+
defer p.clearCallId(req.CallId)
192+
193+
log.Printf("[INFO] Plugin execute table: %s quals: %s (%s)", req.Table, grpc.QualMapToLogLine(req.QueryContext.Quals), req.CallId)
188194
defer log.Printf("[INFO] Plugin execute complete (%s)", req.CallId)
189195

190196
// limit the plugin memory
@@ -194,11 +200,8 @@ func (p *Plugin) execute(req *proto.ExecuteRequest, stream proto.WrapperPlugin_E
194200

195201
outputChan := make(chan *proto.ExecuteResponse, len(req.ExecuteConnectionData))
196202
errorChan := make(chan error, len(req.ExecuteConnectionData))
197-
//doneChan := make(chan bool)
198-
var outputWg sync.WaitGroup
199203

200-
// get a context which includes telemetry data and logger
201-
ctx := p.buildExecuteContext(stream.Context(), req, logger)
204+
var outputWg sync.WaitGroup
202205

203206
// control how many connections are executed in parallel
204207
maxConcurrentConnections := getMaxConcurrentConnections()
@@ -234,10 +237,9 @@ func (p *Plugin) execute(req *proto.ExecuteRequest, stream proto.WrapperPlugin_E
234237
}
235238
defer sem.Release(1)
236239

237-
if err := p.executeForConnection(ctx, req, c, outputChan); err != nil {
238-
if !error_helpers.IsContextCancelledError(err) {
239-
log.Printf("[WARN] executeForConnection %s returned error %s", c, err.Error())
240-
}
240+
// execute the scan for this connection
241+
if err := p.executeForConnection(ctx, req, c, outputChan, logger); err != nil {
242+
log.Printf("[WARN] executeForConnection %s returned error %s, writing to CHAN", c, err.Error())
241243
errorChan <- err
242244
}
243245
log.Printf("[TRACE] executeForConnection %s returned", c)
@@ -280,6 +282,7 @@ func (p *Plugin) execute(req *proto.ExecuteRequest, stream proto.WrapperPlugin_E
280282
}
281283
}
282284

285+
log.Printf("[INFO] Plugin execute table: %s closing error chan and output chan (%s)", req.Table, req.CallId)
283286
close(outputChan)
284287
close(errorChan)
285288

0 commit comments

Comments
 (0)