Skip to content

Commit

Permalink
revert changes in HTTPContext
Browse files Browse the repository at this point in the history
  • Loading branch information
Samu Tamminen committed Jan 27, 2022
1 parent 7d6a087 commit 5cb513a
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 69 deletions.
10 changes: 0 additions & 10 deletions pkg/context/contexttest/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package contexttest

import (
"crypto/tls"
"io"
"net/http"

Expand Down Expand Up @@ -48,7 +47,6 @@ type MockedHTTPRequest struct {
MockedSetBody func(io.Reader)
MockedStd func() *http.Request
MockedSize func() uint64
MockedTLS func() *tls.ConnectionState
}

// RealIP mocks the RealIP function of HTTPRequest
Expand Down Expand Up @@ -212,11 +210,3 @@ func (r *MockedHTTPRequest) Size() uint64 {
}
return 0
}

// TLS mocks the TLS function of HTTPRequest
func (r *MockedHTTPRequest) TLS() *tls.ConnectionState {
if r.MockedTLS != nil {
return r.MockedTLS()
}
return nil
}
3 changes: 0 additions & 3 deletions pkg/context/httpcontext.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package context

import (
stdcontext "context"
"crypto/tls"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -106,8 +105,6 @@ type (
Std() *http.Request

Size() uint64 // bytes

TLS() *tls.ConnectionState
}

// HTTPResponse is all operations for HTTP response.
Expand Down
5 changes: 0 additions & 5 deletions pkg/context/httprequest.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package context

import (
"crypto/tls"
"io"
"net/http"

Expand Down Expand Up @@ -164,10 +163,6 @@ func (r *httpRequest) Size() uint64 {
return uint64(r.metaSize + r.bodyCount)
}

func (r *httpRequest) TLS() *tls.ConnectionState {
return r.std.TLS
}

func (r *httpRequest) finish() {
// NOTE: We don't use this line in case of large flow attack.
// io.Copy(io.Discard, r.std.Body)
Expand Down
88 changes: 47 additions & 41 deletions pkg/filter/certextractor/certextractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,47 +90,53 @@ func (ce *CertExtractor) Handle(ctx httpcontext.HTTPContext) string {
// CertExtractor extracts given field from TLS certificates and sets it to request headers.
func (ce *CertExtractor) handle(ctx httpcontext.HTTPContext) string {
r := ctx.Request()
if connectionState := r.TLS(); connectionState != nil {
certs := connectionState.PeerCertificates
if certs != nil && len(certs) > 0 {
n := int16(len(certs))
// positive ce.spec.CertIndex from the beginning, negative from the end
index := (n + ce.spec.CertIndex) % n
cert := certs[index]

var target pkix.Name
if ce.spec.Target == "subject" {
target = cert.Subject
} else {
target = cert.Issuer
}

var result []string
switch ce.spec.Field {
case "Country":
result = target.Country
case "Organization":
result = target.Organization
case "OrganizationalUnit":
result = target.OrganizationalUnit
case "Locality":
result = target.Locality
case "Province":
result = target.Province
case "StreetAddress":
result = target.StreetAddress
case "PostalCode":
result = target.PostalCode
case "SerialNumber":
result = append(result, target.SerialNumber)
case "CommonName":
result = append(result, target.CommonName)
}
for _, res := range result {
if res != "" {
r.Header().Add(ce.headerKey, res)
}
}
connectionState := r.Std().TLS
if connectionState == nil {
return ""
}

certs := connectionState.PeerCertificates
if certs == nil || len(certs) < 1 {
return ""
}

n := int16(len(certs))
// positive ce.spec.CertIndex from the beginning, negative from the end
relativeIndex := ce.spec.CertIndex % n
index := (n + relativeIndex) % n
cert := certs[index]

var target pkix.Name
if ce.spec.Target == "subject" {
target = cert.Subject
} else {
target = cert.Issuer
}

var result []string
switch ce.spec.Field {
case "Country":
result = target.Country
case "Organization":
result = target.Organization
case "OrganizationalUnit":
result = target.OrganizationalUnit
case "Locality":
result = target.Locality
case "Province":
result = target.Province
case "StreetAddress":
result = target.StreetAddress
case "PostalCode":
result = target.PostalCode
case "SerialNumber":
result = append(result, target.SerialNumber)
case "CommonName":
result = append(result, target.CommonName)
}
for _, res := range result {
if res != "" {
r.Header().Add(ce.headerKey, res)
}
}
return ""
Expand Down
27 changes: 17 additions & 10 deletions pkg/filter/certextractor/certextractor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,10 @@ field: "CommonName"
func prepareCtxAndHeader(connState *tls.ConnectionState) (*contexttest.MockedHTTPContext, http.Header) {
ctx := &contexttest.MockedHTTPContext{}
header := http.Header{}
ctx.MockedRequest.MockedTLS = func() *tls.ConnectionState {
return connState
stdr := &http.Request{}
stdr.TLS = connState
ctx.MockedRequest.MockedStd = func() *http.Request {
return stdr
}
ctx.MockedRequest.MockedHeader = func() *httpheader.HTTPHeader {
return httpheader.New(header)
Expand Down Expand Up @@ -176,13 +178,6 @@ func TestHandle(t *testing.T) {
assert.Equal([]string{"1", "2", "3", "4", "5", "6", "7", "8"}, header["Key"])
})
t.Run("multiple certs", func(t *testing.T) {
yamlConfig := `
kind: "CertExtractor"
name: "cn-extractor"
certIndex: -2 # second last certificate
target: "subject"
field: "Province"
`
for _, val := range []string{"second", "third", "fourth"} {
peerCertificates = append(peerCertificates, &x509.Certificate{
Subject: pkix.Name{
Expand All @@ -193,10 +188,22 @@ field: "Province"
}
connState := &tls.ConnectionState{PeerCertificates: peerCertificates}
ctx, header := prepareCtxAndHeader(connState)

yamlConfig := `
kind: "CertExtractor"
name: "cn-extractor"
certIndex: -2 # second last certificate
target: "subject"
field: "Province"
`
ce, _ := createCertExtractor(yamlConfig, nil, nil)
assert.Equal("", ce.Handle(ctx))
assert.Equal("third", header.Get("tls-subject-province"))

ctx, header = prepareCtxAndHeader(connState)
yamlConfig2 := strings.ReplaceAll(yamlConfig, "certIndex: -2", "certIndex: -15")
ce, _ = createCertExtractor(yamlConfig2, nil, nil)
assert.Equal("", ce.Handle(ctx))
assert.Equal("second", header.Get("tls-subject-province"))
})
})
}

0 comments on commit 5cb513a

Please sign in to comment.