Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove scanner pool in favor of single-use scanners #765

Merged
merged 2 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions cmd/mal/mal.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,30 +255,20 @@ func main() {
maxScanners = concurrency
}

var pool *malcontent.ScannerPool
if mc.ScannerPool == nil {
pool, err = malcontent.NewScannerPool(yrs, maxScanners)
if err != nil {
returnCode = ExitInvalidRules
}
}

mc = malcontent.Config{
Concurrency: concurrency,
ExitFirstHit: exitFirstHitFlag,
ExitFirstMiss: exitFirstMissFlag,
IgnoreSelf: ignoreSelfFlag,
IgnoreTags: ignoreTags,
IncludeDataFiles: includeDataFiles,
MaxScanners: maxScanners,
MinFileRisk: minFileRisk,
MinRisk: minRisk,
OCI: ociFlag,
QuantityIncreasesRisk: quantityIncreasesRiskFlag,
Renderer: renderer,
Rules: yrs,
ScanPaths: scanPaths,
ScannerPool: pool,
Stats: statsFlag,
}

Expand Down
18 changes: 1 addition & 17 deletions pkg/action/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,6 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
yrs = c.Rules
}

var pool *malcontent.ScannerPool
if c.ScannerPool == nil {
pool, err = malcontent.NewScannerPool(yrs, c.MaxScanners)
if err != nil {
return nil, fmt.Errorf("failed to create scanner pool: %w", err)
}
c.ScannerPool = pool
}

var scanner *yarax.Scanner
scanner, err = c.ScannerPool.Get()
if err != nil {
return nil, fmt.Errorf("failed to retrieve scanner: %w", err)
}
defer c.ScannerPool.Put(scanner)

isArchive := archiveRoot != ""
mime := "<unknown>"
kind, err := programkind.File(path)
Expand All @@ -91,7 +75,7 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
return nil, err
}

mrs, err := scanner.Scan(fc)
mrs, err := yrs.Scan(fc)
if err != nil {
logger.Debug("skipping", slog.Any("error", err))
return &malcontent.FileReport{Path: path, Error: fmt.Sprintf("scan: %v", err)}, nil
Expand Down
144 changes: 0 additions & 144 deletions pkg/malcontent/malcontent.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,9 @@ package malcontent

import (
"context"
"fmt"
"io"
"io/fs"
"runtime"
"sync"
"sync/atomic"

yarax "github.com/VirusTotal/yara-x/go"
orderedmap "github.com/wk8/go-ordered-map/v2"
Expand All @@ -33,7 +30,6 @@ type Config struct {
IgnoreSelf bool
IgnoreTags []string
IncludeDataFiles bool
MaxScanners int
MinFileRisk int
MinRisk int
OCI bool
Expand All @@ -45,7 +41,6 @@ type Config struct {
Rules *yarax.Rules
Scan bool
ScanPaths []string
ScannerPool *ScannerPool
Stats bool
TrimPrefixes []string
}
Expand Down Expand Up @@ -148,142 +143,3 @@ type CombinedReport struct {
RemovedFR *FileReport
Score float64
}

// ScannerPool manages a limited pool of YARA scanners that are all built from
// the same rule set.
type ScannerPool struct {
	// rules is the rule set each scanner in the pool is created from.
	rules *yarax.Rules

	// available is the buffered queue of idle scanners handed out by Get and
	// refilled by Put.
	available chan *yarax.Scanner

	// mu is held while Get creates a scanner on demand and for the whole of
	// Cleanup.
	mu sync.Mutex
	// scanners records the scanners created on demand so Cleanup can destroy
	// them.
	scanners []*yarax.Scanner

	// maxScanners caps how many scanners may exist; currentCount is the number
	// created so far and is read and written through sync/atomic.
	maxScanners  int32
	currentCount int32

	// closed is set by Cleanup; Get and Put consult it before touching the pool.
	closed atomic.Bool
}

// NewScannerPool creates a scanner pool that holds at most maxScanners
// scanners built from rules.
//
// A maxScanners value below 1 falls back to half of GOMAXPROCS, with a minimum
// of 1. One scanner is created eagerly and queued as available; further
// scanners are created on demand by Get. An error is returned when rules is nil
// or when the first scanner cannot be created.
func NewScannerPool(rules *yarax.Rules, maxScanners int) (*ScannerPool, error) {
	if rules == nil {
		return nil, fmt.Errorf("rules cannot be nil")
	}
	if maxScanners < 1 {
		maxScanners = max(1, runtime.GOMAXPROCS(0)/2)
	}

	scanner := yarax.NewScanner(rules)
	if scanner == nil {
		return nil, fmt.Errorf("failed to create scanner")
	}

	// #nosec G115 // ignore converting int to int32
	pool := &ScannerPool{
		rules:       rules,
		available:   make(chan *yarax.Scanner, maxScanners),
		maxScanners: int32(maxScanners),
		scanners:    make([]*yarax.Scanner, 0, maxScanners),
	}

	// Track the eager scanner the same way Get tracks the scanners it creates,
	// so Cleanup can still reach it when it is checked out rather than idle.
	pool.scanners = append(pool.scanners, scanner)
	pool.available <- scanner
	atomic.AddInt32(&pool.currentCount, 1)

	return pool, nil
}

// createScanner builds a new yarax scanner from the pool's rules and validates
// it before returning it. A scanner that fails validation is destroyed and the
// validation error is returned.
func (p *ScannerPool) createScanner() (*yarax.Scanner, error) {
	// Force a garbage collection once more than half of the allowed scanners
	// exist. NOTE(review): presumably meant to release memory held by scanners
	// that are no longer referenced — confirm.
	halfCapacity := p.maxScanners / 2
	if live := atomic.LoadInt32(&p.currentCount); live > halfCapacity {
		runtime.GC()
	}

	if p.rules == nil {
		return nil, fmt.Errorf("rules not initialized")
	}

	created := yarax.NewScanner(p.rules)
	if created == nil {
		return nil, fmt.Errorf("failed to create new scanner")
	}

	err := p.validateScanner(created)
	if err == nil {
		return created, nil
	}

	created.Destroy()
	return nil, err
}

// validateScanner checks that scanner is usable by running it once against a
// small placeholder buffer. It returns an error when scanner is nil or when
// that scan fails.
func (p *ScannerPool) validateScanner(scanner *yarax.Scanner) error {
	if scanner == nil {
		return fmt.Errorf("nil scanner")
	}

	if _, scanErr := scanner.Scan([]byte("test")); scanErr != nil {
		return fmt.Errorf("scanner validation failed: %w", scanErr)
	}

	return nil
}

// Get retrieves a scanner from the pool or creates a new one if necessary.
//
// An idle scanner is returned immediately when one is queued. Otherwise a new
// scanner is created as long as fewer than maxScanners exist; once the limit is
// reached Get blocks until a scanner is handed back through Put. An error is
// returned when the pool has been closed by Cleanup or when creating a scanner
// fails.
func (p *ScannerPool) Get() (*yarax.Scanner, error) {
	if p.closed.Load() {
		return nil, fmt.Errorf("scanner pool is closed")
	}

	// Fast path: reuse an idle scanner without taking the lock.
	select {
	case scanner, ok := <-p.available:
		if !ok {
			// Cleanup closed the channel after the check above; a bare receive
			// would hand the caller a nil scanner with a nil error.
			return nil, fmt.Errorf("scanner pool is closed")
		}
		return scanner, nil
	default:
	}

	// Nothing idle: create up to the maximum number of scanners.
	p.mu.Lock()
	if p.closed.Load() {
		// Cleanup sets closed while holding p.mu, so re-checking here prevents
		// creating a scanner after the pool has already been torn down.
		p.mu.Unlock()
		return nil, fmt.Errorf("scanner pool is closed")
	}
	if atomic.LoadInt32(&p.currentCount) < p.maxScanners {
		scanner, err := p.createScanner()
		if err != nil {
			p.mu.Unlock()
			return nil, fmt.Errorf("create scanner: %w", err)
		}
		p.scanners = append(p.scanners, scanner)
		atomic.AddInt32(&p.currentCount, 1)
		p.mu.Unlock()
		return scanner, nil
	}
	p.mu.Unlock()

	// At capacity: wait for a scanner to come back. The receive also wakes up
	// when Cleanup closes the channel, which is reported as an error.
	scanner, ok := <-p.available
	if !ok {
		return nil, fmt.Errorf("scanner pool is closed")
	}
	return scanner, nil
}

// Put returns a scanner to the pool so a later Get can reuse it.
//
// Nil scanners are ignored, as are calls made after Cleanup has closed the
// pool. Put never blocks on a full queue: a scanner that does not fit is not
// queued.
func (p *ScannerPool) Put(scanner *yarax.Scanner) {
	// Cheap early exit that avoids taking the lock for a closed pool.
	if scanner == nil || p.closed.Load() {
		return
	}

	// Cleanup closes p.available while holding p.mu. Re-checking closed and
	// sending under the same lock guarantees the channel cannot be closed
	// between the check and the send, which would panic.
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.closed.Load() {
		return
	}

	select {
	case p.available <- scanner:
	default:
		// The queue is already full, which only happens when more scanners are
		// returned than the pool handed out. Do not block while holding the
		// lock.
	}
}

// Cleanup destroys all scanners in the pool and marks the pool as closed.
//
// Only the first call has any effect. Idle scanners are drained from the queue
// and destroyed, the queue is closed, and every tracked scanner is destroyed
// as well, including ones that are currently checked out. Each scanner is
// destroyed at most once even when it is both queued and tracked.
func (p *ScannerPool) Cleanup() {
	p.mu.Lock()
	defer p.mu.Unlock()

	if p.closed.Swap(true) {
		return
	}

	// destroyed remembers which scanners were already released so that a
	// scanner present in both p.available and p.scanners is not destroyed twice.
	destroyed := make(map[*yarax.Scanner]struct{}, len(p.scanners)+1)
	destroyOnce := func(s *yarax.Scanner) {
		if s == nil {
			return
		}
		if _, done := destroyed[s]; done {
			return
		}
		destroyed[s] = struct{}{}
		s.Destroy()
	}

	// Drain without blocking: a concurrent Get may take an item between a
	// length check and the receive, which would otherwise stall Cleanup while
	// it holds the lock.
drain:
	for {
		select {
		case s := <-p.available:
			destroyOnce(s)
		default:
			break drain
		}
	}
	close(p.available)

	for _, s := range p.scanners {
		destroyOnce(s)
	}

	p.scanners = nil
	atomic.StoreInt32(&p.currentCount, 0)
}
10 changes: 0 additions & 10 deletions pkg/refresh/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ func actionRefresh(ctx context.Context) ([]TestData, error) {
c := &malcontent.Config{
Concurrency: runtime.NumCPU(),
IgnoreSelf: false,
MaxScanners: runtime.NumCPU(),
MinFileRisk: 0,
MinRisk: 0,
OCI: false,
Expand All @@ -80,15 +79,6 @@ func actionRefresh(ctx context.Context) ([]TestData, error) {
TrimPrefixes: []string{"pkg/action/"},
}

var pool *malcontent.ScannerPool
if c.ScannerPool == nil {
pool, err = malcontent.NewScannerPool(yrs, c.MaxScanners)
if err != nil {
return nil, err
}
c.ScannerPool = pool
}

testData = append(testData, TestData{
Config: c,
OutputPath: output,
Expand Down
10 changes: 0 additions & 10 deletions pkg/refresh/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ func diffRefresh(ctx context.Context, rc Config) ([]TestData, error) {
Concurrency: runtime.NumCPU(),
FileRiskChange: td.riskChange,
FileRiskIncrease: td.riskIncrease,
MaxScanners: runtime.NumCPU(),
MinFileRisk: minFileRisk,
MinRisk: minRisk,
QuantityIncreasesRisk: true,
Expand All @@ -206,15 +205,6 @@ func diffRefresh(ctx context.Context, rc Config) ([]TestData, error) {
TrimPrefixes: []string{rc.SamplesPath},
}

var pool *malcontent.ScannerPool
if c.ScannerPool == nil {
pool, err = malcontent.NewScannerPool(yrs, c.MaxScanners)
if err != nil {
return nil, err
}
c.ScannerPool = pool
}

testData = append(testData, TestData{
Config: c,
OutputPath: output,
Expand Down
15 changes: 3 additions & 12 deletions pkg/refresh/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ func newConfig(rc Config) *malcontent.Config {
return &malcontent.Config{
Concurrency: runtime.NumCPU(),
IgnoreTags: []string{"harmless"},
MaxScanners: runtime.NumCPU(),
MinFileRisk: 1,
MinRisk: 1,
QuantityIncreasesRisk: true,
Expand Down Expand Up @@ -135,15 +134,6 @@ func prepareRefresh(ctx context.Context, rc Config) ([]TestData, error) {
c.Renderer = r
c.Rules = yrs

var pool *malcontent.ScannerPool
if c.ScannerPool == nil {
pool, err = malcontent.NewScannerPool(yrs, c.Concurrency)
if err != nil {
return nil, err
}
c.ScannerPool = pool
}

if strings.HasSuffix(data, ".mdiff") || strings.HasSuffix(data, ".sdiff") {
dirPath := filepath.Dir(sample)
files, err := os.ReadDir(dirPath)
Expand Down Expand Up @@ -179,13 +169,14 @@ func prepareRefresh(ctx context.Context, rc Config) ([]TestData, error) {
}

// executeRefresh reads from a populated slice of TestData.
func executeRefresh(ctx context.Context, testData []TestData) error {
func executeRefresh(ctx context.Context, c Config, testData []TestData) error {
g, ctx := errgroup.WithContext(ctx)

var mu sync.Mutex
completed := 0
total := len(testData)

g.SetLimit(c.Concurrency)
for _, data := range testData {
g.Go(func() error {
select {
Expand Down Expand Up @@ -252,5 +243,5 @@ func Refresh(ctx context.Context, rc Config) error {
return fmt.Errorf("failed to prepare sample data refresh: %w", err)
}

return executeRefresh(ctx, testData)
return executeRefresh(ctx, rc, testData)
}
Loading