Skip to content

Commit

Permalink
Remove scanner pool in favor of single-use scanners
Browse files: browse the repository at this point in the history
Signed-off-by: egibs <[email protected]>
  • Loading branch information
egibs committed Jan 17, 2025
1 parent 983bfae commit d343832
Show file tree
Hide file tree
Showing 6 changed files with 4 additions and 203 deletions.
10 changes: 0 additions & 10 deletions cmd/mal/mal.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,30 +255,20 @@ func main() {
maxScanners = concurrency
}

var pool *malcontent.ScannerPool
if mc.ScannerPool == nil {
pool, err = malcontent.NewScannerPool(yrs, maxScanners)
if err != nil {
returnCode = ExitInvalidRules
}
}

mc = malcontent.Config{
Concurrency: concurrency,
ExitFirstHit: exitFirstHitFlag,
ExitFirstMiss: exitFirstMissFlag,
IgnoreSelf: ignoreSelfFlag,
IgnoreTags: ignoreTags,
IncludeDataFiles: includeDataFiles,
MaxScanners: maxScanners,
MinFileRisk: minFileRisk,
MinRisk: minRisk,
OCI: ociFlag,
QuantityIncreasesRisk: quantityIncreasesRiskFlag,
Renderer: renderer,
Rules: yrs,
ScanPaths: scanPaths,
ScannerPool: pool,
Stats: statsFlag,
}

Expand Down
18 changes: 1 addition & 17 deletions pkg/action/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,6 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
yrs = c.Rules
}

var pool *malcontent.ScannerPool
if c.ScannerPool == nil {
pool, err = malcontent.NewScannerPool(yrs, c.MaxScanners)
if err != nil {
return nil, fmt.Errorf("failed to create scanner pool: %w", err)
}
c.ScannerPool = pool
}

var scanner *yarax.Scanner
scanner, err = c.ScannerPool.Get()
if err != nil {
return nil, fmt.Errorf("failed to retrieve scanner: %w", err)
}
defer c.ScannerPool.Put(scanner)

isArchive := archiveRoot != ""
mime := "<unknown>"
kind, err := programkind.File(path)
Expand All @@ -91,7 +75,7 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
return nil, err
}

mrs, err := scanner.Scan(fc)
mrs, err := yrs.Scan(fc)
if err != nil {
logger.Debug("skipping", slog.Any("error", err))
return &malcontent.FileReport{Path: path, Error: fmt.Sprintf("scan: %v", err)}, nil
Expand Down
144 changes: 0 additions & 144 deletions pkg/malcontent/malcontent.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,9 @@ package malcontent

import (
"context"
"fmt"
"io"
"io/fs"
"runtime"
"sync"
"sync/atomic"

yarax "github.com/VirusTotal/yara-x/go"
orderedmap "github.com/wk8/go-ordered-map/v2"
Expand All @@ -33,7 +30,6 @@ type Config struct {
IgnoreSelf bool
IgnoreTags []string
IncludeDataFiles bool
MaxScanners int
MinFileRisk int
MinRisk int
OCI bool
Expand All @@ -45,7 +41,6 @@ type Config struct {
Rules *yarax.Rules
Scan bool
ScanPaths []string
ScannerPool *ScannerPool
Stats bool
TrimPrefixes []string
}
Expand Down Expand Up @@ -148,142 +143,3 @@ type CombinedReport struct {
RemovedFR *FileReport
Score float64
}

// ScannerPool manages a limited pool of YARA scanners.
//
// Scanners are handed out with Get, returned with Put, and destroyed with
// Cleanup. The pool contains a mutex, so it is shared by pointer.
type ScannerPool struct {
// mu is held by Get while it creates a scanner and by Cleanup for its whole run.
mu sync.Mutex
// rules is the compiled rule set every scanner in the pool is built from.
rules *yarax.Rules
// scanners records the scanners created on demand by Get so Cleanup can destroy them.
scanners []*yarax.Scanner
// available is a buffered channel (capacity maxScanners) of idle scanners:
// Get receives from it and Put sends to it.
available chan *yarax.Scanner
// maxScanners is the upper bound that Get enforces on currentCount.
maxScanners int32
// currentCount is the number of scanners created so far; accessed with sync/atomic.
currentCount int32
// closed is set by Cleanup; Get and Put check it before touching the pool.
closed atomic.Bool
}

// NewScannerPool returns a pool that holds at most maxScanners scanners built
// from rules.
//
// A nil rules value is rejected with an error. When maxScanners is below 1 the
// limit falls back to half of GOMAXPROCS, with a floor of 1. One scanner is
// created up front and queued as available; an error is returned if that
// scanner cannot be created.
func NewScannerPool(rules *yarax.Rules, maxScanners int) (*ScannerPool, error) {
	if rules == nil {
		return nil, fmt.Errorf("rules cannot be nil")
	}

	limit := maxScanners
	if limit < 1 {
		limit = max(1, runtime.GOMAXPROCS(0)/2)
	}

	// The zero value of the closed flag already means "open", so it is not
	// set explicitly.
	// #nosec G115 // ignore converting int to int32
	p := &ScannerPool{
		rules:       rules,
		scanners:    make([]*yarax.Scanner, 0, limit),
		available:   make(chan *yarax.Scanner, limit),
		maxScanners: int32(limit),
	}

	first := yarax.NewScanner(rules)
	if first == nil {
		return nil, fmt.Errorf("failed to create scanner")
	}

	// Seed the pool with the first scanner and account for it.
	p.available <- first
	atomic.AddInt32(&p.currentCount, 1)

	return p, nil
}

// createScanner builds one more scanner from the pool's rules and validates it
// before handing it back.
//
// A garbage collection is forced first once more than half of the allowed
// scanners already exist. A scanner that fails validation is destroyed and the
// validation error is returned unchanged.
func (p *ScannerPool) createScanner() (*yarax.Scanner, error) {
	if existing := atomic.LoadInt32(&p.currentCount); existing > p.maxScanners/2 {
		runtime.GC()
	}

	if p.rules == nil {
		return nil, fmt.Errorf("rules not initialized")
	}

	fresh := yarax.NewScanner(p.rules)
	if fresh == nil {
		return nil, fmt.Errorf("failed to create new scanner")
	}

	err := p.validateScanner(fresh)
	if err == nil {
		return fresh, nil
	}

	// Do not hand out (or leak) a scanner that cannot scan.
	fresh.Destroy()
	return nil, err
}

// validateScanner checks that scanner can complete a scan by running it against
// a small fixed buffer. It returns an error for a nil scanner or when the scan
// itself fails; the scan results are discarded.
func (p *ScannerPool) validateScanner(scanner *yarax.Scanner) error {
	if scanner == nil {
		return fmt.Errorf("nil scanner")
	}
	if _, err := scanner.Scan([]byte("test")); err != nil {
		return fmt.Errorf("scanner validation failed: %w", err)
	}
	return nil
}

// Get retrieves a scanner from the pool or creates a new one if necessary.
//
// An idle scanner is reused when one is available. Otherwise a new scanner is
// created as long as fewer than maxScanners exist; once the limit is reached
// Get blocks until a scanner is returned with Put.
//
// Get returns an error when the pool is closed (including when it is closed
// while Get is waiting) or when a new scanner cannot be created. It never
// returns a nil scanner together with a nil error.
func (p *ScannerPool) Get() (*yarax.Scanner, error) {
	if p.closed.Load() {
		return nil, fmt.Errorf("scanner pool is closed")
	}

	// Fast path: reuse an idle scanner without taking the lock.
	select {
	case scanner, ok := <-p.available:
		if !ok {
			// Cleanup closed the channel after the check above.
			return nil, fmt.Errorf("scanner pool is closed")
		}
		return scanner, nil
	default:
	}

	// Slow path: create a scanner if the pool is still below its limit.
	p.mu.Lock()
	// Cleanup sets the closed flag while holding mu, so re-check under the
	// lock to avoid creating a scanner nobody will ever destroy.
	if p.closed.Load() {
		p.mu.Unlock()
		return nil, fmt.Errorf("scanner pool is closed")
	}
	if atomic.LoadInt32(&p.currentCount) < p.maxScanners {
		scanner, err := p.createScanner()
		if err != nil {
			p.mu.Unlock()
			return nil, fmt.Errorf("create scanner: %w", err)
		}
		p.scanners = append(p.scanners, scanner)
		atomic.AddInt32(&p.currentCount, 1)
		p.mu.Unlock()
		return scanner, nil
	}
	p.mu.Unlock()

	// The pool is at capacity: wait for a scanner to be returned. A closed
	// channel yields ok == false rather than a usable scanner.
	scanner, ok := <-p.available
	if !ok {
		return nil, fmt.Errorf("scanner pool is closed")
	}
	return scanner, nil
}

// Put returns a scanner to the pool.
//
// Nil scanners are ignored, as are scanners returned after the pool has been
// closed. Put takes the pool mutex so that its closed check and channel send
// cannot interleave with Cleanup, which closes the channel while holding the
// same mutex; without this a late Put could panic by sending on a closed
// channel.
func (p *ScannerPool) Put(scanner *yarax.Scanner) {
	if scanner == nil {
		return
	}

	p.mu.Lock()
	defer p.mu.Unlock()

	if p.closed.Load() {
		return
	}

	select {
	case p.available <- scanner:
	default:
		// The buffer holds maxScanners entries, so it can only be full if a
		// scanner was returned more than once or did not come from this pool.
		// Drop it instead of blocking while holding the lock.
	}
}

// Cleanup destroys all scanners in the pool.
//
// It marks the pool closed, destroys every idle scanner and every scanner that
// was created on demand, closes the available channel, and resets the scanner
// count. Only the first call does any work; later calls return immediately.
//
// NOTE(review): scanners created on demand are destroyed even if a caller
// still has them checked out, so callers should finish scanning before
// calling Cleanup.
func (p *ScannerPool) Cleanup() {
	p.mu.Lock()
	defer p.mu.Unlock()

	if p.closed.Swap(true) {
		return
	}

	// A scanner created on demand can be both idle in the channel and recorded
	// in p.scanners; remember what was destroyed so each one is destroyed once.
	destroyed := make(map[*yarax.Scanner]struct{}, len(p.scanners)+1)
	destroy := func(s *yarax.Scanner) {
		if s == nil {
			return
		}
		if _, done := destroyed[s]; done {
			return
		}
		destroyed[s] = struct{}{}
		s.Destroy()
	}

	// Drain without blocking: a concurrent Get may take the last idle scanner
	// between a length check and a receive, which would otherwise leave
	// Cleanup waiting forever while holding the lock.
drain:
	for {
		select {
		case s := <-p.available:
			destroy(s)
		default:
			break drain
		}
	}
	close(p.available)

	for _, s := range p.scanners {
		destroy(s)
	}

	p.scanners = nil
	atomic.StoreInt32(&p.currentCount, 0)
}
10 changes: 0 additions & 10 deletions pkg/refresh/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ func actionRefresh(ctx context.Context) ([]TestData, error) {
c := &malcontent.Config{
Concurrency: runtime.NumCPU(),
IgnoreSelf: false,
MaxScanners: runtime.NumCPU(),
MinFileRisk: 0,
MinRisk: 0,
OCI: false,
Expand All @@ -80,15 +79,6 @@ func actionRefresh(ctx context.Context) ([]TestData, error) {
TrimPrefixes: []string{"pkg/action/"},
}

var pool *malcontent.ScannerPool
if c.ScannerPool == nil {
pool, err = malcontent.NewScannerPool(yrs, c.MaxScanners)
if err != nil {
return nil, err
}
c.ScannerPool = pool
}

testData = append(testData, TestData{
Config: c,
OutputPath: output,
Expand Down
10 changes: 0 additions & 10 deletions pkg/refresh/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ func diffRefresh(ctx context.Context, rc Config) ([]TestData, error) {
Concurrency: runtime.NumCPU(),
FileRiskChange: td.riskChange,
FileRiskIncrease: td.riskIncrease,
MaxScanners: runtime.NumCPU(),
MinFileRisk: minFileRisk,
MinRisk: minRisk,
QuantityIncreasesRisk: true,
Expand All @@ -206,15 +205,6 @@ func diffRefresh(ctx context.Context, rc Config) ([]TestData, error) {
TrimPrefixes: []string{rc.SamplesPath},
}

var pool *malcontent.ScannerPool
if c.ScannerPool == nil {
pool, err = malcontent.NewScannerPool(yrs, c.MaxScanners)
if err != nil {
return nil, err
}
c.ScannerPool = pool
}

testData = append(testData, TestData{
Config: c,
OutputPath: output,
Expand Down
15 changes: 3 additions & 12 deletions pkg/refresh/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ func newConfig(rc Config) *malcontent.Config {
return &malcontent.Config{
Concurrency: runtime.NumCPU(),
IgnoreTags: []string{"harmless"},
MaxScanners: runtime.NumCPU(),
MinFileRisk: 1,
MinRisk: 1,
QuantityIncreasesRisk: true,
Expand Down Expand Up @@ -135,15 +134,6 @@ func prepareRefresh(ctx context.Context, rc Config) ([]TestData, error) {
c.Renderer = r
c.Rules = yrs

var pool *malcontent.ScannerPool
if c.ScannerPool == nil {
pool, err = malcontent.NewScannerPool(yrs, c.Concurrency)
if err != nil {
return nil, err
}
c.ScannerPool = pool
}

if strings.HasSuffix(data, ".mdiff") || strings.HasSuffix(data, ".sdiff") {
dirPath := filepath.Dir(sample)
files, err := os.ReadDir(dirPath)
Expand Down Expand Up @@ -179,13 +169,14 @@ func prepareRefresh(ctx context.Context, rc Config) ([]TestData, error) {
}

// executeRefresh reads from a populated slice of TestData.
func executeRefresh(ctx context.Context, testData []TestData) error {
func executeRefresh(ctx context.Context, c Config, testData []TestData) error {
g, ctx := errgroup.WithContext(ctx)

var mu sync.Mutex
completed := 0
total := len(testData)

g.SetLimit(c.Concurrency)
for _, data := range testData {
g.Go(func() error {
select {
Expand Down Expand Up @@ -252,5 +243,5 @@ func Refresh(ctx context.Context, rc Config) error {
return fmt.Errorf("failed to prepare sample data refresh: %w", err)
}

return executeRefresh(ctx, testData)
return executeRefresh(ctx, rc, testData)
}

0 comments on commit d343832

Please sign in to comment.