diff --git a/session.go b/session.go index a245caa..e87f1f0 100644 --- a/session.go +++ b/session.go @@ -3,6 +3,7 @@ package mockoidc import ( "errors" "strings" + "sync" "time" "github.com/golang-jwt/jwt" @@ -21,6 +22,7 @@ type Session struct { // SessionStore manages our Session objects type SessionStore struct { + sync.RWMutex Store map[string]*Session CodeQueue *CodeQueue } @@ -55,14 +57,18 @@ func (ss *SessionStore) NewSession(scope string, nonce string, user User, codeCh CodeChallenge: codeChallenge, CodeChallengeMethod: codeChallengeMethod, } + ss.Lock() ss.Store[sessionID] = session + ss.Unlock() return session, nil } // GetSessionByID looks up the Session func (ss *SessionStore) GetSessionByID(id string) (*Session, error) { + ss.RLock() session, ok := ss.Store[id] + ss.RUnlock() if !ok { return nil, errors.New("session not found") } diff --git a/session_test.go b/session_test.go index 54cd10b..ed44334 100644 --- a/session_test.go +++ b/session_test.go @@ -43,7 +43,9 @@ func TestSessionStore_NewSession(t *testing.T) { assert.NoError(t, err) assert.Equal(t, session.Scopes, []string{"openid", "email", "profile"}) assert.Equal(t, len(ss.Store), 1) + ss.RLock() assert.Equal(t, ss.Store[session.SessionID], session) + ss.RUnlock() assert.Equal(t, session.CodeChallenge, "sum") assert.Equal(t, session.CodeChallengeMethod, "S256") }