Skip to content

Commit

Permalink
rework session closer
Browse files Browse the repository at this point in the history
  • Loading branch information
salrashid123 committed Jun 14, 2024
1 parent 6445826 commit e77c7a9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 36 deletions.
27 changes: 12 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -321,45 +321,42 @@ Note, you can define your own policy for import too...just implement the "sessio

```golang
type Session interface {
io.Closer // read closer to the TPM
GetSession() (auth tpm2.Session, err error) // this supplies the session handle to the library
GetSession() (auth tpm2.Session, closer func() error, err error) // this supplies the session handle to the library
}
```

eg:

```golang
// for pcr sessions
type MySession struct {
type MyCustomSession struct {
rwr transport.TPM
sel []tpm2.TPMSPCRSelection
}

func NewMySession(rwr transport.TPM, sel []tpm2.TPMSPCRSelection) (MySession, error) {
return MySession{rwr, sel}, nil
func NewMyCustomSession(rwr transport.TPM, sel []tpm2.TPMSPCRSelection) (MyCustomSession, error) {
return MyCustomSession{rwr, sel}, nil
}

func (p MySession) GetSession() (auth tpm2.Session, err error) {
sess, _, err := tpm2.PolicySession(p.rwr, tpm2.TPMAlgSHA256, 16)
func (p MyCustomSession) GetSession() (auth tpm2.Session, closer func() error, err error) {

sess, closer, err := tpm2.PolicySession(p.rwr, tpm2.TPMAlgSHA256, 16)
if err != nil {
return nil, err
return nil, nil, err
}

// defineyour poicy here
// implement whatever you want here, i'm just using policypcr

_, err = tpm2.PolicyPCR{
PolicySession: sess.Handle(),
Pcrs: tpm2.TPMLPCRSelection{
PCRSelections: p.sel,
},
}.Execute(p.rwr)
if err != nil {
return nil, err
return nil, nil, err
}
return sess, nil
}

func (p MySession) Close() error {
return nil
return sess, closer, nil
}
```

Expand Down
33 changes: 12 additions & 21 deletions tpmsigner.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,13 +212,12 @@ func (s *SigningMethodTPM) Sign(signingString string, key interface{}) ([]byte,

var se tpm2.Session
if config.AuthSession != nil {
se, err = config.AuthSession.GetSession()
var closer func() error
se, closer, err = config.AuthSession.GetSession()
if err != nil {
return nil, fmt.Errorf("tpmjwt: error getting session %s", s.Alg())
}
defer func() {
_, err = (&tpm2.FlushContext{FlushHandle: se.Handle()}).Execute(rwr)
}()
defer closer()
} else {
se = tpm2.PasswordAuth(nil)
}
Expand Down Expand Up @@ -351,8 +350,7 @@ func (s *SigningMethodTPM) Verify(signingString string, signature []byte, key in
}

type Session interface {
io.Closer // read closer to the TPM
GetSession() (auth tpm2.Session, err error) // this supplies the session handle to the library
GetSession() (auth tpm2.Session, closer func() error, err error) // this supplies the session handle to the library
}

// for pcr sessions
Expand All @@ -365,10 +363,10 @@ func NewPCRSession(rwr transport.TPM, sel []tpm2.TPMSPCRSelection) (PCRSession,
return PCRSession{rwr, sel}, nil
}

func (p PCRSession) GetSession() (auth tpm2.Session, err error) {
sess, _, err := tpm2.PolicySession(p.rwr, tpm2.TPMAlgSHA256, 16)
func (p PCRSession) GetSession() (auth tpm2.Session, closer func() error, err error) {
sess, closer, err := tpm2.PolicySession(p.rwr, tpm2.TPMAlgSHA256, 16)
if err != nil {
return nil, err
return nil, nil, err
}
_, err = tpm2.PolicyPCR{
PolicySession: sess.Handle(),
Expand All @@ -377,13 +375,9 @@ func (p PCRSession) GetSession() (auth tpm2.Session, err error) {
},
}.Execute(p.rwr)
if err != nil {
return nil, err
return nil, nil, err
}
return sess, nil
}

func (p PCRSession) Close() error {
return nil
return sess, closer, nil
}

// for password sessions
Expand All @@ -396,10 +390,7 @@ func NewPasswordSession(rwr transport.TPM, password []byte) (PasswordSession, er
return PasswordSession{rwr, password}, nil
}

func (p PasswordSession) GetSession() (auth tpm2.Session, err error) {
return tpm2.PasswordAuth(p.password), nil
}

func (p PasswordSession) Close() error {
return nil
func (p PasswordSession) GetSession() (auth tpm2.Session, closer func() error, err error) {
c := func() error { return nil }
return tpm2.PasswordAuth(p.password), c, nil
}

0 comments on commit e77c7a9

Please sign in to comment.