Skip to content

Commit

Permalink
Merge pull request #962 from AtakanColak/fix-ppid-race-961
Browse files Browse the repository at this point in the history
Fix Windows Ppid Cache Race Condition
  • Loading branch information
Lomanic authored Oct 29, 2020
2 parents bb232c4 + 13602a3 commit f810d51
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 7 deletions.
7 changes: 6 additions & 1 deletion process/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"runtime"
"sort"
"sync"
"syscall"
"time"

Expand All @@ -26,6 +27,7 @@ type Process struct {
name string
status string
parent int32
parentMutex *sync.RWMutex // for windows ppid cache
numCtxSwitches *NumCtxSwitchesStat
uids []int32
gids []int32
Expand Down Expand Up @@ -167,7 +169,10 @@ func NewProcess(pid int32) (*Process, error) {
}

func NewProcessWithContext(ctx context.Context, pid int32) (*Process, error) {
p := &Process{Pid: pid}
p := &Process{
Pid: pid,
parentMutex: new(sync.RWMutex),
}

exists, err := PidExistsWithContext(ctx, pid)
if err != nil {
Expand Down
30 changes: 30 additions & 0 deletions process/process_race_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// +build race

package process

import (
"sync"
"testing"
)

func Test_Process_Ppid_Race(t *testing.T) {
wg := sync.WaitGroup{}
testCount := 10
p := testGetProcess()
wg.Add(testCount)
for i := 0; i < testCount; i++ {
go func(j int) {
ppid, err := p.Ppid()
wg.Done()
skipIfNotImplementedErr(t, err)
if err != nil {
t.Errorf("Ppid() failed, %v", err)
}

if j == 9 {
t.Logf("Ppid(): %d", ppid)
}
}(i)
}
wg.Wait()
}
34 changes: 28 additions & 6 deletions process/process_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,17 +237,18 @@ func PidExistsWithContext(ctx context.Context, pid int32) (bool, error) {

func (p *Process) PpidWithContext(ctx context.Context) (int32, error) {
// if cached already, return from cache
if p.parent != 0 {
return p.parent, nil
cachedPpid := p.getPpid()
if cachedPpid != 0 {
return cachedPpid, nil
}

ppid, _, _, err := getFromSnapProcess(p.Pid)
if err != nil {
return 0, err
}

// if no errors, cache it
p.parent = ppid
// no errors and not cached already, so cache it
p.setPpid(ppid)

return ppid, nil
}
Expand All @@ -258,8 +259,11 @@ func (p *Process) NameWithContext(ctx context.Context) (string, error) {
return "", fmt.Errorf("could not get Name: %s", err)
}

// if no errors, cache ppid
// if no errors and not cached already, cache ppid
p.parent = ppid
if 0 == p.getPpid() {
p.setPpid(ppid)
}

return name, nil
}
Expand Down Expand Up @@ -456,8 +460,11 @@ func (p *Process) NumThreadsWithContext(ctx context.Context) (int32, error) {
return 0, err
}

// if no errors, cache ppid
// if no errors and not cached already, cache ppid
p.parent = ppid
if 0 == p.getPpid() {
p.setPpid(ppid)
}

return ret, nil
}
Expand Down Expand Up @@ -613,6 +620,21 @@ func (p *Process) KillWithContext(ctx context.Context) error {
return process.Kill()
}

// retrieve Ppid in a thread-safe manner
func (p *Process) getPpid() int32 {
p.parentMutex.RLock()
defer p.parentMutex.RUnlock()
return p.parent
}

// cache Ppid in a thread-safe manner (WINDOWS ONLY)
// see https://psutil.readthedocs.io/en/latest/#psutil.Process.ppid
func (p *Process) setPpid(ppid int32) {
p.parentMutex.Lock()
defer p.parentMutex.Unlock()
p.parent = ppid
}

func getFromSnapProcess(pid int32) (int32, int32, string, error) {
snap, err := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, uint32(pid))
if err != nil {
Expand Down

0 comments on commit f810d51

Please sign in to comment.