From e77c7a94a5659318648e8ab42a4d39cbeb1c7222 Mon Sep 17 00:00:00 2001 From: salrashid123 Date: Fri, 14 Jun 2024 07:41:05 -0400 Subject: [PATCH] rework session closer --- README.md | 27 ++++++++++++--------------- tpmsigner.go | 33 ++++++++++++--------------------- 2 files changed, 24 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 70e5274..a26bb81 100644 --- a/README.md +++ b/README.md @@ -321,8 +321,7 @@ Note, you can define your own policy for import too...just implement the "sessio ```golang type Session interface { - io.Closer // read closer to the TPM - GetSession() (auth tpm2.Session, err error) // this supplies the session handle to the library + GetSession() (auth tpm2.Session, closer func() error, err error) // this supplies the session handle to the library } ``` @@ -330,22 +329,24 @@ eg: ```golang // for pcr sessions -type MySession struct { +type MyCustomSession struct { rwr transport.TPM sel []tpm2.TPMSPCRSelection } -func NewMySession(rwr transport.TPM, sel []tpm2.TPMSPCRSelection) (MySession, error) { - return MySession{rwr, sel}, nil +func NewMyCustomSession(rwr transport.TPM, sel []tpm2.TPMSPCRSelection) (MyCustomSession, error) { + return MyCustomSession{rwr, sel}, nil } -func (p MySession) GetSession() (auth tpm2.Session, err error) { - sess, _, err := tpm2.PolicySession(p.rwr, tpm2.TPMAlgSHA256, 16) +func (p MyCustomSession) GetSession() (auth tpm2.Session, closer func() error, err error) { + + sess, closer, err := tpm2.PolicySession(p.rwr, tpm2.TPMAlgSHA256, 16) if err != nil { - return nil, err + return nil, nil, err } - // defineyour poicy here + // implement whatever you want here, i'm just using policypcr + _, err = tpm2.PolicyPCR{ PolicySession: sess.Handle(), Pcrs: tpm2.TPMLPCRSelection{ @@ -353,13 +354,9 @@ func (p MySession) GetSession() (auth tpm2.Session, err error) { }, }.Execute(p.rwr) if err != nil { - return nil, err + return nil, nil, err } - return sess, nil -} - -func (p MySession) Close() error { - return nil + return sess, closer, nil } ``` diff --git a/tpmsigner.go b/tpmsigner.go index e6be978..d5742e3 100644 --- a/tpmsigner.go +++ b/tpmsigner.go @@ -212,13 +212,12 @@ func (s *SigningMethodTPM) Sign(signingString string, key interface{}) ([]byte, var se tpm2.Session if config.AuthSession != nil { - se, err = config.AuthSession.GetSession() + var closer func() error + se, closer, err = config.AuthSession.GetSession() if err != nil { return nil, fmt.Errorf("tpmjwt: error getting session %s", s.Alg()) } - defer func() { - _, err = (&tpm2.FlushContext{FlushHandle: se.Handle()}).Execute(rwr) - }() + defer closer() } else { se = tpm2.PasswordAuth(nil) } @@ -351,8 +350,7 @@ func (s *SigningMethodTPM) Verify(signingString string, signature []byte, key in } type Session interface { - io.Closer // read closer to the TPM - GetSession() (auth tpm2.Session, err error) // this supplies the session handle to the library + GetSession() (auth tpm2.Session, closer func() error, err error) // this supplies the session handle to the library } // for pcr sessions @@ -365,10 +363,10 @@ func NewPCRSession(rwr transport.TPM, sel []tpm2.TPMSPCRSelection) (PCRSession, return PCRSession{rwr, sel}, nil } -func (p PCRSession) GetSession() (auth tpm2.Session, err error) { - sess, _, err := tpm2.PolicySession(p.rwr, tpm2.TPMAlgSHA256, 16) +func (p PCRSession) GetSession() (auth tpm2.Session, closer func() error, err error) { + sess, closer, err := tpm2.PolicySession(p.rwr, tpm2.TPMAlgSHA256, 16) if err != nil { - return nil, err + return nil, nil, err } _, err = tpm2.PolicyPCR{ PolicySession: sess.Handle(), @@ -377,13 +375,9 @@ func (p PCRSession) GetSession() (auth tpm2.Session, err error) { }, }.Execute(p.rwr) if err != nil { - return nil, err + return nil, nil, err } - return sess, nil -} - -func (p PCRSession) Close() error { - return nil + return sess, closer, nil } // for password sessions @@ -396,10 +390,7 @@ func NewPasswordSession(rwr transport.TPM, password []byte) (PasswordSession, er return PasswordSession{rwr, password}, nil } -func (p PasswordSession) GetSession() (auth tpm2.Session, err error) { - return tpm2.PasswordAuth(p.password), nil -} - -func (p PasswordSession) Close() error { - return nil +func (p PasswordSession) GetSession() (auth tpm2.Session, closer func() error, err error) { + c := func() error { return nil } + return tpm2.PasswordAuth(p.password), c, nil }