loginsrv

Unnamed repository; edit this file 'description' to name the repository.
git clone git@jamesshield.xyz:repos/loginsrv.git
Log | Files | Refs | README | LICENSE

commit bc4066a1ae81ed819732cdb7a85db2a6fd01c04c
parent 7730ef1eb824217d831f4ac65f6494a006e31c0d
Author: Sebastian Mancke <s.mancke@tarent.de>
Date:   Tue,  2 May 2017 12:48:10 +0200

Merge pull request #1 from tarent/oauth

Implemented login by Oauth
Diffstat:
MREADME.md | 10++++++++++
Mcaddy/setup.go | 42++++++++++++++++--------------------------
Mcaddy/setup_test.go | 19+++++++++----------
Mhtpasswd/backend.go | 10++++++----
Mhtpasswd/backend_test.go | 6+++---
Mlogin/backend.go | 8++++++--
Mlogin/config.go | 196+++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------------
Mlogin/config_test.go | 106++++++++++++++++++-------------------------------------------------------------
Mlogin/handler.go | 172++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------------
Mlogin/handler_test.go | 106++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------------
Mlogin/login_form.go | 170+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------
Alogin/login_form_test.go | 12++++++++++++
Mlogin/provider.go | 15+++++++++++++++
Mlogin/provider_description.go | 11++---------
Mlogin/simple_backend.go | 14+++++++-------
Mlogin/simple_backend_test.go | 10++++------
Dlogin/user_info.go | 11-----------
Mmain_test.go | 7++++---
Amodel/user_info.go | 15+++++++++++++++
Aoauth2/github.go | 64++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Aoauth2/github_test.go | 59+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Aoauth2/manager.go | 145+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Aoauth2/manager_test.go | 256+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Aoauth2/oauth.go | 162+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Aoauth2/oauth_test.go | 202+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Aoauth2/provider.go | 50++++++++++++++++++++++++++++++++++++++++++++++++++
Aoauth2/provider_test.go | 16++++++++++++++++
Mosiam/backend.go | 10+++++-----
Mosiam/backend_test.go | 6+++---
Mosiam/setup.go | 10++++++++--
30 files changed, 1583 insertions(+), 337 deletions(-)

diff --git a/README.md b/README.md @@ -81,6 +81,10 @@ Returns a simple bootstrap styled login form. The returned html follows the ui composition conventions from (lib-compose)[https://github.com/tarent/lib-compose], so it can be embedded into an existing layout. +### GET /login/<provider> + +Starts the Oauth Web Flow with the configured provider. E.g. `GET /login/github` redirects to the github login form. + ### POST /login Does the login and returns the JWT. Depending on the content-type, and parameters a classical JSON-Rest or a redirect can be performed. @@ -108,6 +112,12 @@ Does the login and returns the JWT. Depending on the content-type, and parameter Hint: The status `401 Unauthorized` is not used as a return code to not conflict with an Http BasicAuth Authentication. +### DELETE /login + +Deletes the JWT Cookie. + +For simple usage in web applications, this can also be called by `GET|POST /login?logout=true` + #### Example: Default is to return the token as Content-Type application/jwt within the body. ``` diff --git a/caddy/setup.go b/caddy/setup.go @@ -1,21 +1,21 @@ package caddy import ( + "flag" "fmt" "github.com/mholt/caddy" "github.com/mholt/caddy/caddyhttp/httpserver" "github.com/tarent/lib-compose/logging" "github.com/tarent/loginsrv/login" "os" - "strconv" ) func init() { - caddy.RegisterPlugin("loginsrv", caddy.Plugin{ + caddy.RegisterPlugin("login", caddy.Plugin{ ServerType: "http", Action: setup, }) - httpserver.RegisterDevDirective("loginsrv", "jwt") + httpserver.RegisterDevDirective("login", "jwt") } // setup configures a new loginsrv instance. @@ -43,26 +43,29 @@ func setup(c *caddy.Controller) error { } else { os.Setenv("JWT_SECRET", config.JwtSecret) } - - loginHandler, err := login.NewHandler(&config) + fmt.Printf("config %+v\n", config) + loginHandler, err := login.NewHandler(config) if err != nil { return err } httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler { - return NewCaddyHandler(next, args[0], loginHandler, &config) + return NewCaddyHandler(next, args[0], loginHandler, config) }) } return nil } -func parseConfig(c *caddy.Controller) (login.Config, error) { - cfg := login.DefaultConfig +func parseConfig(c *caddy.Controller) (*login.Config, error) { + cfg := login.DefaultConfig() cfg.Host = "" cfg.Port = "" cfg.LogLevel = "" + fs := flag.NewFlagSet("loginsrv-config", flag.ContinueOnError) + cfg.ConfigureFlagSet(fs) + for c.NextBlock() { name := c.Val() args := c.RemainingArgs() @@ -71,25 +74,12 @@ func parseConfig(c *caddy.Controller) (login.Config, error) { } value := args[0] - switch name { - case "success-url": - cfg.SuccessUrl = value - case "cookie-name": - cfg.CookieName = value - case "cookie-http-only": - b, err := strconv.ParseBool(value) - if err != nil { - return cfg, fmt.Errorf("error parsing bool value %v: %v (%v:%v)", name, value, c.File(), c.Line()) - } - cfg.CookieHttpOnly = b - case "backend": - err := (&cfg.Backends).Set(value) - if err != nil { - return cfg, fmt.Errorf("error parsing backend configuration %v: %v (%v:%v)", name, value, c.File(), c.Line()) - } - default: - return cfg, fmt.Errorf("Unknown option within loginsrv: %v (%v:%v)", name, c.File(), c.Line()) + f := fs.Lookup(name) + if f == nil { + c.ArgErr() + continue } + f.Value.Set(value) } return cfg, nil diff --git a/caddy/setup_test.go b/caddy/setup_test.go @@ -30,12 +30,12 @@ func TestSetup(t *testing.T) { SuccessUrl: "/", CookieName: "jwt_token", CookieHttpOnly: true, - Backends: login.BackendOptions{ - map[string]string{ - "provider": "simple", - "bob": "secret", + Backends: login.Options{ + "simple": map[string]string{ + "bob": "secret", }, }, + Oauth: login.Options{}, }}, { input: `loginsrv / { @@ -52,18 +52,17 @@ func TestSetup(t *testing.T) { SuccessUrl: "successurl", CookieName: "cookiename", CookieHttpOnly: false, - Backends: login.BackendOptions{ - map[string]string{ - "provider": "simple", - "bob": "secret", + Backends: login.Options{ + "simple": map[string]string{ + "bob": "secret", }, - map[string]string{ - "provider": "osiam", + "osiam": map[string]string{ "endpoint": "http://localhost:8080", "clientId": "example-client", "clientSecret": "secret", }, }, + Oauth: login.Options{}, }}, // error cases {input: "loginsrv {\n}", shouldErr: true}, diff --git a/htpasswd/backend.go b/htpasswd/backend.go @@ -3,6 +3,7 @@ package htpasswd import ( "errors" "github.com/tarent/loginsrv/login" + "github.com/tarent/loginsrv/model" ) const ProviderName = "htpasswd" @@ -10,7 +11,8 @@ const ProviderName = "htpasswd" func init() { login.RegisterProvider( &login.ProviderDescription{ - Name: ProviderName, + Name: ProviderName, + HelpText: "Htpasswd login backend opts: file=/path/to/pwdfile", }, BackendFactory) } @@ -35,10 +37,10 @@ func NewBackend(filename string) (*Backend, error) { }, err } -func (sb *Backend) Authenticate(username, password string) (bool, login.UserInfo, error) { +func (sb *Backend) Authenticate(username, password string) (bool, model.UserInfo, error) { authenticated, err := sb.auth.Authenticate(username, password) if authenticated && err == nil { - return authenticated, login.UserInfo{Username: username}, err + return authenticated, model.UserInfo{Sub: username}, err } - return false, login.UserInfo{}, err + return false, model.UserInfo{}, err } diff --git a/htpasswd/backend_test.go b/htpasswd/backend_test.go @@ -37,16 +37,16 @@ func TestSimpleBackend_Authenticate(t *testing.T) { authenticated, userInfo, err := backend.Authenticate("bob-bcrypt", "secret") assert.True(t, authenticated) - assert.Equal(t, "bob-bcrypt", userInfo.Username) + assert.Equal(t, "bob-bcrypt", userInfo.Sub) assert.NoError(t, err) authenticated, userInfo, err = backend.Authenticate("bob-bcrypt", "fooo") assert.False(t, authenticated) - assert.Equal(t, "", userInfo.Username) + assert.Equal(t, "", userInfo.Sub) assert.NoError(t, err) authenticated, userInfo, err = backend.Authenticate("", "") assert.False(t, authenticated) - assert.Equal(t, "", userInfo.Username) + assert.Equal(t, "", userInfo.Sub) assert.NoError(t, err) } diff --git a/login/backend.go b/login/backend.go @@ -1,9 +1,13 @@ package login +import ( + "github.com/tarent/loginsrv/model" +) + type Backend interface { // Authenticate checks the username/password against the backend. - // On success it returns true ans a UserInfo object which has at least the username set. + // On success it returns true and a UserInfo object which has at least the username set. // If the credentials do not match, false is returned. // The error parameter is nil, unless a communication error with the backend occured. - Authenticate(username, password string) (bool, UserInfo, error) + Authenticate(username, password string) (bool, model.UserInfo, error) } diff --git a/login/config.go b/login/config.go @@ -1,41 +1,126 @@ package login import ( + "errors" "flag" "fmt" - "github.com/caarlos0/env" + "github.com/tarent/loginsrv/oauth2" "math/rand" "os" "strings" "time" ) -var DefaultConfig Config +var jwtDefaultSecret string func init() { rand.Seed(time.Now().UTC().UnixNano()) - DefaultConfig = Config{ + jwtDefaultSecret = randStringBytes(32) +} + +func DefaultConfig() *Config { + return &Config{ Host: "localhost", Port: "6789", LogLevel: "info", - JwtSecret: randStringBytes(32), + JwtSecret: jwtDefaultSecret, SuccessUrl: "/", CookieName: "jwt_token", CookieHttpOnly: true, - Backends: BackendOptions{}, + Backends: Options{}, + Oauth: Options{}, } } +const envPrefix = "LOGINSRV_" + type Config struct { - Host string `env:"LOGINSRV_HOST"` - Port string `env:"LOGINSRV_PORT"` - LogLevel string `env:"LOGINSRV_LOG_LEVEL"` - TextLogging bool `env:"LOGINSRV_TEXT_LOGGING"` - JwtSecret string `env:"LOGINSRV_JWT_SECRET"` - SuccessUrl string `env:"LOGINSRV_SUCCESS_URL"` - CookieName string `env:"LOGINSRV_COOKIE_NAME"` - CookieHttpOnly bool `env:"LOGINSRV_COOKIE_HTTP_ONLY"` - Backends BackendOptions + Host string + Port string + LogLevel string + TextLogging bool + JwtSecret string + SuccessUrl string + CookieName string + CookieHttpOnly bool + Backends Options + Oauth Options +} + +// Options is the configuration structure for oauth and backend provider +// key is the providername, value is a options map. +type Options map[string]map[string]string + +// addOauthOpts adds the options for a provider in the form of key=value,key=value,.. +func (c *Config) addOauthOpts(providerName, optsKvList string) error { + opts, err := parseOptions(optsKvList) + if err != nil { + return err + } + + c.Oauth[providerName] = opts + return nil +} + +// addBackendOpts adds the options for a provider in the form of key=value,key=value,.. +func (c *Config) addBackendOpts(providerName, optsKvList string) error { + opts, err := parseOptions(optsKvList) + if err != nil { + return err + } + + c.Backends[providerName] = opts + return nil +} + +// ConfigureFlagSet adds all flags to the supplied flag set +func (c *Config) ConfigureFlagSet(f *flag.FlagSet) { + f.StringVar(&c.Host, "host", c.Host, "The host to listen on") + f.StringVar(&c.Port, "port", c.Port, "The port to listen on") + f.StringVar(&c.LogLevel, "log-level", c.LogLevel, "The log level") + f.BoolVar(&c.TextLogging, "text-logging", c.TextLogging, "Log in text format instead of json") + f.StringVar(&c.JwtSecret, "jwt-secret", "random key", "The secret to sign the jwt token") + f.StringVar(&c.CookieName, "cookie-name", c.CookieName, "The name of the jwt cookie") + f.BoolVar(&c.CookieHttpOnly, "cookie-http-only", c.CookieHttpOnly, "Set the cookie with the http only flag") + f.StringVar(&c.SuccessUrl, "success-url", c.SuccessUrl, "The url to redirect after login") + + // the -backends is deprecated, but we support it for backwards compatibility + deprecatedBackends := setFunc(func(optsKvList string) error { + opts, err := parseOptions(optsKvList) + if err != nil { + return err + } + pName, ok := opts["provider"] + if !ok { + return errors.New("missing provder name provider=..") + } + delete(opts, "provider") + c.Backends[pName] = opts + return nil + }) + f.Var(deprecatedBackends, "backend", "Deprecated, please use the explicit flags") + + // One option for each oauth provider + for _, pName := range oauth2.ProviderList() { + func(pName string) { + setter := setFunc(func(optsKvList string) error { + return c.addOauthOpts(pName, optsKvList) + }) + f.Var(setter, pName, "Oauth config in the form: client_id=..,client_secret=..[,scope=..,][redirect_uri=..]") + }(pName) + } + + // One option for each backend provider + for _, pName := range ProviderList() { + func(pName string) { + setter := setFunc(func(optsKvList string) error { + fmt.Printf("set %v\n", pName) + return c.addBackendOpts(pName, optsKvList) + }) + desc, _ := GetProviderDescription(pName) + f.Var(setter, pName, desc.HelpText) + }(pName) + } } func ReadConfig() *Config { @@ -46,32 +131,19 @@ func ReadConfig() *Config { } return c } -func readConfig(f *flag.FlagSet, args []string) (*Config, error) { - config := DefaultConfig - err := env.Parse(&config) - if err != nil { - return nil, err - } +func readConfig(f *flag.FlagSet, args []string) (*Config, error) { + config := DefaultConfig() + config.ConfigureFlagSet(f) - for _, v := range os.Environ() { - pair := strings.SplitN(v, "=", 2) - if len(pair) == 2 && strings.HasPrefix(pair[0], "LOGINSRV_BACKEND") { - (&config.Backends).Set(pair[1]) + // prefer environment settings + f.VisitAll(func(f *flag.Flag) { + if val, isPresent := os.LookupEnv(envName(f.Name)); isPresent { + f.Value.Set(val) } - } + }) - f.StringVar(&config.Host, "host", config.Host, "The host to listen on") - f.StringVar(&config.Port, "port", config.Port, "The port to listen on") - f.StringVar(&config.LogLevel, "log-level", config.LogLevel, "The log level") - f.BoolVar(&config.TextLogging, "text-logging", config.TextLogging, "Log in text format instead of json") - f.StringVar(&config.JwtSecret, "jwt-secret", "random key", "The secret to sign the jwt token") - f.StringVar(&config.CookieName, "cookie-name", config.CookieName, "The name of the jwt cookie") - f.BoolVar(&config.CookieHttpOnly, "cookie-http-only", config.CookieHttpOnly, "Set the cookie with the http only flag") - f.StringVar(&config.SuccessUrl, "success-url", config.SuccessUrl, "The url to redirect after login") - f.Var(&config.Backends, "backend", "Backend configuration in form 'provider=name,key=val,key=...', can be declared multiple times") - - err = f.Parse(args) + err := f.Parse(args) if err != nil { return nil, err } @@ -79,52 +151,48 @@ func readConfig(f *flag.FlagSet, args []string) (*Config, error) { if config.JwtSecret == "random key" { if s, set := os.LookupEnv("LOGINSRV_JWT_SECRET"); set { config.JwtSecret = s - } else { - config.JwtSecret = DefaultConfig.JwtSecret + config.JwtSecret = jwtDefaultSecret } } - return &config, err + return config, err } -func parseBackendOptions(b string) (map[string]string, error) { +const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + +func randStringBytes(n int) string { + b := make([]byte, n) + for i := range b { + b[i] = letterBytes[rand.Intn(len(letterBytes))] + } + return string(b) +} + +func envName(flagName string) string { + return envPrefix + strings.Replace(strings.ToUpper(flagName), "-", "_", -1) +} + +func parseOptions(b string) (map[string]string, error) { opts := map[string]string{} pairs := strings.Split(b, ",") for _, p := range pairs { pair := strings.SplitN(p, "=", 2) if len(pair) != 2 { - return nil, fmt.Errorf("provider configuration has to be in form 'provider=name,key1=value1,key2=..', but was %v", p) + return nil, fmt.Errorf("provider configuration has to be in form 'key1=value1,key2=..', but was %v", p) } opts[pair[0]] = pair[1] } - if _, exist := opts["provider"]; !exist { - return nil, fmt.Errorf("no provider name specified in %v", b) - } return opts, nil } -const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - -func randStringBytes(n int) string { - b := make([]byte, n) - for i := range b { - b[i] = letterBytes[rand.Intn(len(letterBytes))] - } - return string(b) -} +// Helper type to wrap a function closure with the Value interface +type setFunc func(optsKvList string) error -type BackendOptions []map[string]string - -func (bo *BackendOptions) String() string { - return fmt.Sprintf("%v", *bo) +func (f setFunc) Set(value string) error { + return f(value) } -func (bo *BackendOptions) Set(value string) error { - optionMap, err := parseBackendOptions(value) - if err != nil { - return err - } - *bo = append(*bo, optionMap) - return nil +func (f setFunc) String() string { + return "setFunc" } diff --git a/login/config_test.go b/login/config_test.go @@ -2,7 +2,6 @@ package login import ( "flag" - "fmt" "github.com/stretchr/testify/assert" "os" "testing" @@ -12,7 +11,11 @@ func TestConfig_ReadConfigDefaults(t *testing.T) { originalArgs := os.Args defer func() { os.Args = originalArgs }() - assert.Equal(t, &DefaultConfig, ReadConfig()) + defaultConfig := DefaultConfig() + gotConfig := ReadConfig() + defaultConfig.JwtSecret = "random" + gotConfig.JwtSecret = "random" + assert.Equal(t, defaultConfig, gotConfig) } func TestConfig_ReadConfig(t *testing.T) { @@ -27,6 +30,7 @@ func TestConfig_ReadConfig(t *testing.T) { "--cookie-http-only=false", "--backend=provider=simple", "--backend=provider=foo", + "--github=client_id=foo,client_secret=bar", } expected := &Config{ @@ -38,12 +42,14 @@ func TestConfig_ReadConfig(t *testing.T) { SuccessUrl: "successurl", CookieName: "cookiename", CookieHttpOnly: false, - Backends: BackendOptions{ - map[string]string{ - "provider": "simple", - }, - map[string]string{ - "provider": "foo", + Backends: Options{ + "simple": map[string]string{}, + "foo": map[string]string{}, + }, + Oauth: Options{ + "github": map[string]string{ + "client_id": "foo", + "client_secret": "bar", }, }, } @@ -62,9 +68,8 @@ func TestConfig_ReadConfigFromEnv(t *testing.T) { assert.NoError(t, os.Setenv("LOGINSRV_SUCCESS_URL", "successurl")) assert.NoError(t, os.Setenv("LOGINSRV_COOKIE_NAME", "cookiename")) assert.NoError(t, os.Setenv("LOGINSRV_COOKIE_HTTP_ONLY", "false")) - assert.NoError(t, os.Setenv("LOGINSRV_BACKEND", "provider=simple,foo=bar")) - assert.NoError(t, os.Setenv("LOGINSRV_BACKEND_FOO", "provider=foo")) - assert.NoError(t, os.Setenv("LOGINSRV_BACKEND_BAR", "provider=bar")) + assert.NoError(t, os.Setenv("LOGINSRV_SIMPLE", "foo=bar")) + assert.NoError(t, os.Setenv("LOGINSRV_GITHUB", "client_id=foo,client_secret=bar")) expected := &Config{ Host: "host", @@ -75,16 +80,15 @@ func TestConfig_ReadConfigFromEnv(t *testing.T) { SuccessUrl: "successurl", CookieName: "cookiename", CookieHttpOnly: false, - Backends: BackendOptions{ - map[string]string{ - "provider": "simple", - "foo": "bar", + Backends: Options{ + "simple": map[string]string{ + "foo": "bar", }, - map[string]string{ - "provider": "foo", - }, - map[string]string{ - "provider": "bar", + }, + Oauth: Options{ + "github": map[string]string{ + "client_id": "foo", + "client_secret": "bar", }, }, } @@ -93,65 +97,3 @@ func TestConfig_ReadConfigFromEnv(t *testing.T) { assert.NoError(t, err) assert.Equal(t, expected, cfg) } - -func TestConfig_ParseBackendOptions(t *testing.T) { - testCases := []struct { - input []string - expected BackendOptions - expectError bool - }{ - { - []string{}, - BackendOptions{}, - false, - }, - { - []string{"name=p1,key1=value1,key2=value2"}, - BackendOptions{}, - true, // no provider name specified - }, - { - []string{ - "provider=simple,name=p1,key1=value1,key2=value2", - "provider=simple,name=p2,key3=value3,key4=value4", - }, - BackendOptions{ - map[string]string{ - "provider": "simple", - "name": "p1", - "key1": "value1", - "key2": "value2", - }, - map[string]string{ - "provider": "simple", - "name": "p2", - "key3": "value3", - "key4": "value4", - }, - }, - false, - }, - { - []string{"foo"}, - BackendOptions{}, - true, - }, - } - for i, test := range testCases { - t.Run(fmt.Sprintf("test %v", i), func(t *testing.T) { - options := &BackendOptions{} - for _, input := range test.input { - err := options.Set(input) - if test.expectError { - assert.Error(t, err) - } else { - if err != nil { - assert.NoError(t, err) - continue - } - } - } - assert.Equal(t, test.expected, *options) - }) - } -} diff --git a/login/handler.go b/login/handler.go @@ -6,9 +6,12 @@ import ( "fmt" "github.com/dgrijalva/jwt-go" "github.com/tarent/lib-compose/logging" + "github.com/tarent/loginsrv/model" + "github.com/tarent/loginsrv/oauth2" "io/ioutil" "net/http" "strings" + "time" ) const contentTypeHtml = "text/html; charset=utf-8" @@ -17,54 +20,86 @@ const contentTypePlain = "text/plain" type Handler struct { backends []Backend + oauth *oauth2.Manager config *Config } // NewHandler creates a login handler based on the supplied configuration. func NewHandler(config *Config) (*Handler, error) { + if len(config.Backends) == 0 && len(config.Oauth) == 0 { + return nil, errors.New("No login backends or oauth provider configured!") + } + backends := []Backend{} - for _, opt := range config.Backends { - p, exist := GetProvider(opt["provider"]) + for pName, opts := range config.Backends { + p, exist := GetProvider(pName) if !exist { - return nil, fmt.Errorf("No such provider: %v", opt["provider"]) + return nil, fmt.Errorf("No such provider: %v", pName) } - b, err := p(opt) + b, err := p(opts) if err != nil { return nil, err } backends = append(backends, b) } - if len(backends) == 0 { - return nil, errors.New("No login backends configured!") + + oauth := oauth2.NewManager() + for providerName, opts := range config.Oauth { + err := oauth.AddConfig(providerName, opts) + if err != nil { + return nil, err + } } + return &Handler{ backends: backends, config: config, + oauth: oauth, }, nil } -func (h *Handler) authenticate(username, password string) (bool, UserInfo, error) { - for _, b := range h.backends { - authenticated, userInfo, err := b.Authenticate(username, password) - if err != nil { - return false, UserInfo{}, err - } - if authenticated { - return authenticated, userInfo, nil - } +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + + _, err := h.oauth.GetConfigFromRequest(r) + if err == nil { + h.handleOauth(w, r) + return } - return false, UserInfo{}, nil + + h.handleLogin(w, r) + return } -func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if !strings.HasSuffix(r.URL.Path, "/login") { - w.WriteHeader(404) - fmt.Fprintf(w, "404 Ressource not found") +func (h *Handler) handleOauth(w http.ResponseWriter, r *http.Request) { + startedFlow, authenticated, userInfo, err := h.oauth.Handle(w, r) + + if startedFlow { + // the oauth flow started return } + if err != nil { + logging.Application(r.Header).WithError(err).Error() + h.respondError(w, r) + return + } + + if authenticated { + logging.Application(r.Header). + WithField("username", userInfo.Sub).Info("sucessfully authenticated") + h.respondAuthenticated(w, r, userInfo) + return + } + logging.Application(r.Header). + WithField("username", userInfo.Sub).Info("failed authentication") + + h.respondAuthFailure(w, r) + return +} + +func (h *Handler) handleLogin(w http.ResponseWriter, r *http.Request) { contentType := r.Header.Get("Content-Type") - if !(r.Method == "GET" || + if !(r.Method == "GET" || r.Method == "DELETE" || (r.Method == "POST" && (strings.HasPrefix(contentType, "application/json") || strings.HasPrefix(contentType, "application/x-www-form-urlencoded") || @@ -74,11 +109,24 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } r.ParseForm() + if r.Method == "DELETE" || r.FormValue("logout") == "true" { + h.deleteToken(w) + writeLoginForm(w, + loginFormData{ + Path: r.URL.Path, + Config: h.config, + }) + return + } + if r.Method == "GET" { + userInfo, valid := h.getToken(r) writeLoginForm(w, - map[string]interface{}{ - "path": r.URL.Path, - "config": h.config, + loginFormData{ + Path: r.URL.Path, + Config: h.config, + Authenticated: valid, + UserInfo: userInfo, }) return } @@ -110,7 +158,18 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (h *Handler) respondAuthenticated(w http.ResponseWriter, r *http.Request, userInfo UserInfo) { +func (h *Handler) deleteToken(w http.ResponseWriter) { + cookie := &http.Cookie{ + Name: h.config.CookieName, + Value: "delete", + HttpOnly: true, + Expires: time.Unix(0, 0), + Path: "/", + } + http.SetCookie(w, cookie) +} + +func (h *Handler) respondAuthenticated(w http.ResponseWriter, r *http.Request, userInfo jwt.Claims) { token, err := h.createToken(userInfo) if err != nil { logging.Application(r.Header).WithError(err).Error() @@ -118,9 +177,13 @@ func (h *Handler) respondAuthenticated(w http.ResponseWriter, r *http.Request, u return } if wantHtml(r) { - // TODO: set livetime - cookie := &http.Cookie{Name: h.config.CookieName, Value: token, HttpOnly: true} + cookie := &http.Cookie{ + Name: h.config.CookieName, + Value: token, + HttpOnly: true, + Path: "/", + } http.SetCookie(w, cookie) w.Header().Set("Location", h.config.SuccessUrl) w.WriteHeader(303) @@ -132,22 +195,37 @@ func (h *Handler) respondAuthenticated(w http.ResponseWriter, r *http.Request, u fmt.Fprintf(w, "%s", token) } -func (h *Handler) createToken(userInfo UserInfo) (string, error) { +func (h *Handler) createToken(userInfo jwt.Claims) (string, error) { token := jwt.NewWithClaims(jwt.SigningMethodHS512, userInfo) return token.SignedString([]byte(h.config.JwtSecret)) } +func (h *Handler) getToken(r *http.Request) (userInfo model.UserInfo, valid bool) { + c, err := r.Cookie(h.config.CookieName) + if err != nil { + return model.UserInfo{}, false + } + + token, err := jwt.ParseWithClaims(c.Value, &model.UserInfo{}, func(*jwt.Token) (interface{}, error) { + return []byte(h.config.JwtSecret), nil + }) + if err != nil { + return model.UserInfo{}, false + } + + u, v := token.Claims.(*model.UserInfo) + return *u, v +} + func (h *Handler) respondError(w http.ResponseWriter, r *http.Request) { if wantHtml(r) { - w.Header().Set("Content-Type", contentTypeHtml) - w.WriteHeader(500) username, _, _ := getCredentials(r) writeLoginForm(w, - map[string]interface{}{ - "path": r.URL.Path, - "error": true, - "config": h.config, - "username": username, + loginFormData{ + Path: r.URL.Path, + Error: true, + Config: h.config, + UserInfo: model.UserInfo{Sub: username}, }) return } @@ -167,12 +245,11 @@ func (h *Handler) respondAuthFailure(w http.ResponseWriter, r *http.Request) { w.WriteHeader(403) username, _, _ := getCredentials(r) writeLoginForm(w, - map[string]interface{}{ - "path": r.URL.Path, - "failure": true, - "config": h.config, - - "username": username, + loginFormData{ + Path: r.URL.Path, + Failure: true, + Config: h.config, + UserInfo: model.UserInfo{Sub: username}, }) return } @@ -200,3 +277,16 @@ func getCredentials(r *http.Request) (string, string, error) { } return r.PostForm.Get("username"), r.PostForm.Get("password"), nil } + +func (h *Handler) authenticate(username, password string) (bool, model.UserInfo, error) { + for _, b := range h.backends { + authenticated, userInfo, err := b.Authenticate(username, password) + if err != nil { + return false, model.UserInfo{}, err + } + if authenticated { + return authenticated, userInfo, nil + } + } + return false, model.UserInfo{}, nil +} diff --git a/login/handler_test.go b/login/handler_test.go @@ -5,6 +5,8 @@ import ( "fmt" "github.com/dgrijalva/jwt-go" "github.com/stretchr/testify/assert" + "github.com/tarent/loginsrv/model" + "github.com/tarent/loginsrv/oauth2" "net/http" "net/http/httptest" "strings" @@ -24,14 +26,14 @@ func TestHandler_NewFromConfig(t *testing.T) { expectError bool }{ { - &Config{Backends: BackendOptions{map[string]string{"provider": "simple", "bob": "secret"}}}, + &Config{Backends: Options{"simple": map[string]string{"bob": "secret"}}}, 1, false, }, // error cases { // init error because no users are provided - &Config{Backends: BackendOptions{map[string]string{"provider": "simple"}}}, + &Config{Backends: Options{"simple": map[string]string{}}}, 1, true, }, @@ -41,12 +43,7 @@ func TestHandler_NewFromConfig(t *testing.T) { true, }, { - &Config{Backends: BackendOptions{map[string]string{"foo": ""}}}, - 1, - true, - }, - { - &Config{Backends: BackendOptions{map[string]string{"provider": "simpleFoo", "bob": "secret"}}}, + &Config{Backends: Options{"simpleFoo": map[string]string{"bob": "secret"}}}, 1, true, }, @@ -64,17 +61,11 @@ func TestHandler_NewFromConfig(t *testing.T) { } } -func TestHandler_404(t *testing.T) { - recorder := call(req("GET", "/foo", "")) - assert.Equal(t, recorder.Code, 404) -} - func TestHandler_LoginForm(t *testing.T) { recorder := call(req("GET", "/context/login", "")) assert.Equal(t, recorder.Code, 200) - assert.Contains(t, recorder.Body.String(), "form") - assert.Contains(t, recorder.Body.String(), `method="POST"`) - assert.Contains(t, recorder.Body.String(), `action="/context/login"`) + assert.Contains(t, recorder.Body.String(), `class="container`) + assert.Equal(t, "no-cache, no-store, must-revalidate", recorder.Header().Get("Cache-Control")) } func TestHandler_HEAD(t *testing.T) { @@ -113,16 +104,34 @@ func TestHandler_LoginWeb(t *testing.T) { claims, err := tokenAsMap(strings.SplitN(headerParts[1], ";", 2)[0]) assert.NoError(t, err) assert.Equal(t, map[string]interface{}{"sub": "bob"}, claims) + assert.Contains(t, headerParts[1]+";", "Path=/;") // show the login form again after authentication failed recorder = call(req("POST", "/context/login", "username=bob&password=FOOBAR", TypeForm, AcceptHtml)) assert.Equal(t, 403, recorder.Code) - assert.Contains(t, recorder.Body.String(), "form") - assert.Contains(t, recorder.Body.String(), `method="POST"`) - assert.Contains(t, recorder.Body.String(), `action="/context/login"`) + assert.Contains(t, recorder.Body.String(), `class="container"`) assert.Equal(t, recorder.Header().Get("Set-Cookie"), "") } +func TestHandler_Logout(t *testing.T) { + // DELETE + recorder := call(req("DELETE", "/context/login", "")) + assert.Equal(t, 200, recorder.Code) + assert.Contains(t, recorder.Header().Get("Set-Cookie"), "jwt_token=delete; Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT;") + + // GET + param + recorder = call(req("GET", "/context/login?logout=true", "")) + assert.Equal(t, 200, recorder.Code) + assert.Contains(t, recorder.Header().Get("Set-Cookie"), "jwt_token=delete; Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT;") + + // POST + param + recorder = call(req("POST", "/context/login", "logout=true", TypeForm)) + assert.Equal(t, 200, recorder.Code) + assert.Contains(t, recorder.Header().Get("Set-Cookie"), "jwt_token=delete; Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT;") + + assert.Equal(t, "no-cache, no-store, must-revalidate", recorder.Header().Get("Cache-Control")) +} + func TestHandler_LoginError(t *testing.T) { h := testHandlerWithError() @@ -142,16 +151,62 @@ func TestHandler_LoginError(t *testing.T) { assert.Equal(t, 500, recorder.Code) assert.Contains(t, recorder.Header().Get("Content-Type"), "text/html") - assert.Contains(t, recorder.Body.String(), "form") + assert.Contains(t, recorder.Body.String(), `class="container"`) assert.Contains(t, recorder.Body.String(), "Internal Error") } +func TestHandler_getToken_Valid(t *testing.T) { + h := testHandler() + input := model.UserInfo{Sub: "marvin"} + token, err := h.createToken(input) + assert.NoError(t, err) + r := &http.Request{ + Header: http.Header{"Cookie": {h.config.CookieName + "=" + token + ";"}}, + } + userInfo, valid := h.getToken(r) + assert.True(t, valid) + assert.Equal(t, input, userInfo) +} + +func TestHandler_getToken_InvalidSecret(t *testing.T) { + h := testHandler() + input := model.UserInfo{Sub: "marvin"} + token, err := h.createToken(input) + assert.NoError(t, err) + r := &http.Request{ + Header: http.Header{"Cookie": {h.config.CookieName + "=" + token + ";"}}, + } + + // modify secret + h.config.JwtSecret = "foobar" + + _, valid := h.getToken(r) + assert.False(t, valid) +} + +func TestHandler_getToken_InvalidToken(t *testing.T) { + h := testHandler() + r := &http.Request{ + Header: http.Header{"Cookie": {h.config.CookieName + "=asdcsadcsadc"}}, + } + + _, valid := h.getToken(r) + assert.False(t, valid) +} + +func TestHandler_getToken_InvalidNoToken(t *testing.T) { + h := testHandler() + _, valid := h.getToken(&http.Request{}) + assert.False(t, valid) +} + func testHandler() *Handler { return &Handler{ backends: []Backend{ NewSimpleBackend(map[string]string{"bob": "secret"}), }, - config: &DefaultConfig, + oauth: oauth2.NewManager(), + config: DefaultConfig(), } } @@ -160,7 +215,8 @@ func testHandlerWithError() *Handler { backends: []Backend{ errorTestBackend("test error"), }, - config: &DefaultConfig, + oauth: oauth2.NewManager(), + config: DefaultConfig(), } } @@ -185,7 +241,7 @@ func req(method string, url string, body string, header ...string) *http.Request func tokenAsMap(tokenString string) (map[string]interface{}, error) { token, err := jwt.Parse(tokenString, func(*jwt.Token) (interface{}, error) { - return []byte(DefaultConfig.JwtSecret), nil + return []byte(DefaultConfig().JwtSecret), nil }) if err != nil { return nil, err @@ -200,6 +256,6 @@ func tokenAsMap(tokenString string) (map[string]interface{}, error) { type errorTestBackend string -func (h errorTestBackend) Authenticate(username, password string) (bool, UserInfo, error) { - return false, UserInfo{}, errors.New(string(h)) +func (h errorTestBackend) Authenticate(username, password string) (bool, model.UserInfo, error) { + return false, model.UserInfo{}, errors.New(string(h)) } diff --git a/login/login_form.go b/login/login_form.go @@ -1,54 +1,156 @@ package login import ( + "bytes" + "github.com/tarent/lib-compose/logging" + "github.com/tarent/loginsrv/model" "html/template" - "io" + "net/http" + "strings" ) const loginForm = `<!DOCTYPE html> <html> <head> <link uic-remove rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.5/css/bootstrap.min.css"> + <link uic-remove rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/bootstrap-social/5.1.1/bootstrap-social.min.css"> + <link uic-remove rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.css"> <style> - .vertical-offset-100{ - padding-top:100px; - } + .vertical-offset-100{ + padding-top:100px; + } + .login-or-container { + text-align: center; + margin: 0; + margin-bottom: 10px; + clear: both; + color: #6a737c; + font-variant: small-caps; + } + .login-or-hr { + margin-bottom: 0; + position: relative; + top: 28px; + height: 0; + border: 0; + border-top: 1px solid #e4e6e8; + } + .login-or { + display: inline-block; + position: relative; + padding: 10px; + background-color: #FFF; + } + .login-picture { + width: 120px; + height: 120px; + border-radius: 3px; + } </style> </head> <body> <uic-fragment name="content"> -<div class="container"> - <div class="row vertical-offset-100"> - <div class="col-md-4 col-md-offset-4"> - <div class="panel panel-default"> - <div class="panel-heading"> - <h3 class="panel-title">Please sign in</h3> - {{ if .error}}Internal Error. Please try again later{{end}} - {{ if .failure}}Wrong credentials{{end}} - </div> - <div class="panel-body"> - <form accept-charset="UTF-8" role="form" method="POST" action="{{.path}}"> - <fieldset> - <div class="form-group"> - <input class="form-control" placeholder="Username" name="username" value="{{.username}}" type="text"> - </div> - <div class="form-group"> - <input class="form-control" placeholder="Password" name="password" type="password" value=""> - </div> - <input class="btn btn-lg btn-success btn-block" type="submit" value="Login"> - </fieldset> - </form> - </div> - </div> - </div> + <div class="container"> + <div class="row vertical-offset-100"> + <div class="col-md-4 col-md-offset-4"> + + {{ if .Error}} + <div class="alert alert-danger" role="alert"> + <strong>Internal Error. </strong> Please try again later. + </div> + {{end}} + + {{ if .Authenticated}} + {{with .UserInfo}} + <h1>Welcome {{.Sub}}!</h1> + <br/> + {{if .Picture}}<img class="login-picture" src="{{.Picture}}?s=120">{{end}} + {{if .Name}}<h3>{{.Name}}</h3>{{end}} + <br/> + <a class="btn btn-md btn-primary" href="login?logout=true">Logout</a> + {{end}} + {{else}} + + {{ range $providerName, $opts := .Config.Oauth }} + <a class="btn btn-block btn-lg btn-social btn-{{ $providerName }}" href="login/{{ $providerName }}"> + <span class="fa fa-{{ $providerName }}"></span> Sign in with {{ $providerName | ucfirst }} + </a> + {{end}} + + {{if and (not (eq (len .Config.Backends) 0)) (not (eq (len .Config.Oauth) 0))}} + <div class="login-or-container"> + <hr class="login-or-hr"> + <div class="login-or lead">or</div> + </div> + {{end}} + + {{if not (eq (len .Config.Backends) 0) }} + <div class="panel panel-default"> + <div class="panel-heading"> + <div class="panel-title"> + <h4>Sign in</h4> + {{ if .Failure}}<div class="alert alert-warning" role="alert">Invalid credentials</div>{{end}} + </div> + </div> + <div class="panel-body"> + <form accept-charset="UTF-8" role="form" method="POST" action="{{.Path}}"> + <fieldset> + <div class="form-group"> + <input class="form-control" placeholder="Username" name="username" value="{{.UserInfo.Sub}}" type="text"> + </div> + <div class="form-group"> + <input class="form-control" placeholder="Password" name="password" type="password" value=""> + </div> + <input class="btn btn-lg btn-success btn-block" type="submit" value="Login"> + </fieldset> + </form> + </div> + </div> + {{end}} + {{end}} + </div> </div> -</div> + </div> </uic-fragment> </body> -</html> -` +</html>` + +type loginFormData struct { + Path string + Error bool + Failure bool + Config *Config + Authenticated bool + UserInfo model.UserInfo +} + +func writeLoginForm(w http.ResponseWriter, params loginFormData) { + funcMap := template.FuncMap{ + "ucfirst": ucfirst, + } + t := template.Must(template.New("loginForm").Funcs(funcMap).Parse(loginForm)) + b := bytes.NewBuffer(nil) + err := t.Execute(b, params) + if err != nil { + logging.Logger.WithError(err).Error() + w.WriteHeader(500) + w.Write([]byte(`Internal Server Error`)) + return + } + + w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") + w.Header().Set("Content-Type", contentTypeHtml) + if params.Error { + w.WriteHeader(500) + } + + w.Write(b.Bytes()) +} + +func ucfirst(in string) string { + if in == "" { + return "" + } -func writeLoginForm(w io.Writer, params map[string]interface{}) { - t := template.Must(template.New("loginForm").Parse(loginForm)) - t.Execute(w, params) + return strings.ToUpper(in[0:1]) + in[1:] } diff --git a/login/login_form_test.go b/login/login_form_test.go @@ -0,0 +1,12 @@ +package login + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func Test_ucfirst(t *testing.T) { + assert.Equal(t, "", ucfirst("")) + assert.Equal(t, "A", ucfirst("a")) + assert.Equal(t, "Abc def", ucfirst("abc def")) +} diff --git a/login/provider.go b/login/provider.go @@ -18,3 +18,18 @@ func GetProvider(providerName string) (Provider, bool) { p, exist := provider[providerName] return p, exist } + +// GetProvider returns the metainfo for a provider +func GetProviderDescription(providerName string) (*ProviderDescription, bool) { + p, exist := providerDescription[providerName] + return p, exist +} + +// ProviderList returns the names of all registered provider +func ProviderList() []string { + list := make([]string, 0, len(provider)) + for k, _ := range provider { + list = append(list, k) + } + return list +} diff --git a/login/provider_description.go b/login/provider_description.go @@ -4,13 +4,6 @@ type ProviderDescription struct { // the name of the provider Name string - // the config options, which the provider supports - options []ProviderOption -} - -type ProviderOption struct { - Name string - Description string - Default string - Required string + // the text for the commandline optio + HelpText string } diff --git a/login/simple_backend.go b/login/simple_backend.go @@ -2,6 +2,7 @@ package login import ( "errors" + "github.com/tarent/loginsrv/model" ) const SimpleProviderName = "simple" @@ -9,7 +10,8 @@ const SimpleProviderName = "simple" func init() { RegisterProvider( &ProviderDescription{ - Name: SimpleProviderName, + Name: SimpleProviderName, + HelpText: "Simple login backend opts: user1=password,user2=password,..", }, SimpleBackendFactory) } @@ -17,9 +19,7 @@ func init() { func SimpleBackendFactory(config map[string]string) (Backend, error) { userPassword := map[string]string{} for k, v := range config { - if k != "provider" && k != "name" { - userPassword[k] = v - } + userPassword[k] = v } if len(userPassword) == 0 { return nil, errors.New("no users provided for simple backend") @@ -39,9 +39,9 @@ func NewSimpleBackend(userPassword map[string]string) *SimpleBackend { } } -func (sb *SimpleBackend) Authenticate(username, password string) (bool, UserInfo, error) { +func (sb *SimpleBackend) Authenticate(username, password string) (bool, model.UserInfo, error) { if p, exist := sb.userPassword[username]; exist && p == password { - return true, UserInfo{Username: username}, nil + return true, model.UserInfo{Sub: username}, nil } - return false, UserInfo{}, nil + return false, model.UserInfo{}, nil } diff --git a/login/simple_backend_test.go b/login/simple_backend_test.go @@ -11,9 +11,7 @@ func TestSetup(t *testing.T) { assert.NotNil(t, p) backend, err := p(map[string]string{ - "provider": "simple", - "name": "myFooProvider", - "bob": "secret", + "bob": "secret", }) assert.NoError(t, err) @@ -31,16 +29,16 @@ func TestSimpleBackend_Authenticate(t *testing.T) { authenticated, userInfo, err := backend.Authenticate("bob", "secret") assert.True(t, authenticated) - assert.Equal(t, "bob", userInfo.Username) + assert.Equal(t, "bob", userInfo.Sub) assert.NoError(t, err) authenticated, userInfo, err = backend.Authenticate("bob", "fooo") assert.False(t, authenticated) - assert.Equal(t, "", userInfo.Username) + assert.Equal(t, "", userInfo.Sub) assert.NoError(t, err) authenticated, userInfo, err = backend.Authenticate("", "") assert.False(t, authenticated) - assert.Equal(t, "", userInfo.Username) + assert.Equal(t, "", userInfo.Sub) assert.NoError(t, err) } diff --git a/login/user_info.go b/login/user_info.go @@ -1,11 +0,0 @@ -package login - -type UserInfo struct { - Username string `json:"sub"` -} - -// this interface implementation -// lets us use the user info as Claim for jwt-go -func (u UserInfo) Valid() error { - return nil -} diff --git a/main_test.go b/main_test.go @@ -3,7 +3,6 @@ package main import ( "github.com/dgrijalva/jwt-go" "github.com/stretchr/testify/assert" - "github.com/tarent/loginsrv/login" "io/ioutil" "net/http" "os" @@ -13,9 +12,11 @@ import ( ) func Test_BasicEndToEnd(t *testing.T) { + originalArgs := os.Args - os.Args = []string{"loginsrv", "-host=localhost", "-port=3000", "-backend=provider=simple,bob=secret"} + secret := "theSecret" + os.Args = []string{"loginsrv", "-jwt-secret", secret, "-host=localhost", "-port=3000", "-backend=provider=simple,bob=secret"} defer func() { os.Args = originalArgs }() go main() @@ -37,7 +38,7 @@ func Test_BasicEndToEnd(t *testing.T) { assert.NoError(t, err) token, err := jwt.Parse(string(b), func(*jwt.Token) (interface{}, error) { - return []byte(login.DefaultConfig.JwtSecret), nil + return []byte(secret), nil }) assert.NoError(t, err) diff --git a/model/user_info.go b/model/user_info.go @@ -0,0 +1,15 @@ +package model + +type UserInfo struct { + Sub string `json:"sub"` + Picture string `json:"picture,omitempty"` + Name string `json:"name,omitempty"` + Email string `json:"email,omitempty"` + Origin string `json:"origin,omitempty"` +} + +// this interface implementation +// lets us use the user info as Claim for jwt-go +func (u UserInfo) Valid() error { + return nil +} diff --git a/oauth2/github.go b/oauth2/github.go @@ -0,0 +1,64 @@ +package oauth2 + +import ( + "encoding/json" + "fmt" + "github.com/tarent/loginsrv/model" + "io/ioutil" + "net/http" + "strings" +) + +var githubApi = "https://api.github.com" + +func init() { + RegisterProvider(providerGithub) +} + +type GithubUser struct { + Login string `json:"login,omitempty"` + AvatarURL string `json:"avatar_url,omitempty"` + Name string `json:"name,omitempty"` + Email string `json:"email,omitempty"` +} + +var providerGithub = Provider{ + Name: "github", + AuthURL: "https://github.com/login/oauth/authorize", + TokenURL: "https://github.com/login/oauth/access_token", + GetUserInfo: func(token TokenInfo) (model.UserInfo, string, error) { + gu := GithubUser{} + url := fmt.Sprintf("%v/user?access_token=%v", githubApi, token.AccessToken) + fmt.Println("url: ", url) + resp, err := http.Get(url) + if err != nil { + return model.UserInfo{}, "", err + } + + if !strings.Contains(resp.Header.Get("Content-Type"), "application/json") { + return model.UserInfo{}, "", fmt.Errorf("wrong content-type on github get user info: %v", resp.Header.Get("Content-Type")) + } + + if resp.StatusCode != 200 { + return model.UserInfo{}, "", fmt.Errorf("got http status %v on github get user info", resp.StatusCode) + } + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return model.UserInfo{}, "", fmt.Errorf("error reading github get user info: %v", err) + } + + err = json.Unmarshal(b, &gu) + if err != nil { + return model.UserInfo{}, "", fmt.Errorf("error parsing github get user info: %v", err) + } + + return model.UserInfo{ + Sub: gu.Login, + Picture: gu.AvatarURL, + Name: gu.Name, + Email: gu.Email, + Origin: "github", + }, string(b), nil + }, +} diff --git a/oauth2/github_test.go b/oauth2/github_test.go @@ -0,0 +1,59 @@ +package oauth2 + +import ( + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" +) + +var githubTestUserResponse = `{ + "login": "octocat", + "id": 1, + "avatar_url": "https://github.com/images/error/octocat_happy.gif", + "gravatar_id": "", + "url": "https://api.github.com/users/octocat", + "html_url": "https://github.com/octocat", + "followers_url": "https://api.github.com/users/octocat/followers", + "following_url": "https://api.github.com/users/octocat/following{/other_user}", + "gists_url": "https://api.github.com/users/octocat/gists{/gist_id}", + "starred_url": "https://api.github.com/users/octocat/starred{/owner}{/repo}", + "subscriptions_url": "https://api.github.com/users/octocat/subscriptions", + "organizations_url": "https://api.github.com/users/octocat/orgs", + "repos_url": "https://api.github.com/users/octocat/repos", + "events_url": "https://api.github.com/users/octocat/events{/privacy}", + "received_events_url": "https://api.github.com/users/octocat/received_events", + "type": "User", + "site_admin": false, + "name": "monalisa octocat", + "company": "GitHub", + "blog": "https://github.com/blog", + "location": "San Francisco", + "email": "octocat@github.com", + "hireable": false, + "bio": "There once was...", + "public_repos": 2, + "public_gists": 1, + "followers": 20, + "following": 0, + "created_at": "2008-01-14T04:33:35Z", + "updated_at": "2008-01-14T04:33:35Z" +}` + +func Test_Github_getUserInfo(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "secret", r.FormValue("access_token")) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write([]byte(githubTestUserResponse)) + })) + defer server.Close() + + githubApi = server.URL + + u, rawJson, err := providerGithub.GetUserInfo(TokenInfo{AccessToken: "secret"}) + assert.NoError(t, err) + assert.Equal(t, "octocat", u.Sub) + assert.Equal(t, "octocat@github.com", u.Email) + assert.Equal(t, "monalisa octocat", u.Name) + assert.Equal(t, githubTestUserResponse, rawJson) +} diff --git a/oauth2/manager.go b/oauth2/manager.go @@ -0,0 +1,145 @@ +package oauth2 + +import ( + "fmt" + "github.com/tarent/loginsrv/model" + "net/http" + "net/url" + "strings" +) + +// The manager has the responsibility to handle the user user requests in an oauth flow. +// It has to pick the right configuration and start the oauth redirecting. +type Manager struct { + configs map[string]Config + startFlow func(cfg Config, w http.ResponseWriter) + authenticate func(cfg Config, r *http.Request) (TokenInfo, error) +} + +// NewManager creates a new Manager +func NewManager() *Manager { + return &Manager{ + configs: map[string]Config{}, + startFlow: StartFlow, + authenticate: Authenticate, + } +} + +// Handle is managing the oauth flow. +// Dependent on the code parameter of the url, the oauth flow is started or +// the call is interpreted as the redirect callback and the token exchange is done. +// Return parameters: +// startedFlow - true, if this was the initial call to start the oauth flow +// authenticated - if the authentication was successful or not +// userInfo - the user info from the provider in case of a succesful authentication +// err - an error +func (manager *Manager) Handle(w http.ResponseWriter, r *http.Request) ( + startedFlow bool, + authenticated bool, + userInfo model.UserInfo, + err error) { + + if r.FormValue("error") != "" { + return false, false, model.UserInfo{}, fmt.Errorf("error: %v", r.FormValue("error")) + } + + cfg, err := manager.GetConfigFromRequest(r) + if err != nil { + return false, false, model.UserInfo{}, err + } + + if r.FormValue("code") != "" { + tokenInfo, err := manager.authenticate(cfg, r) + if err != nil { + return false, false, model.UserInfo{}, err + } + + userInfo, _, err := cfg.Provider.GetUserInfo(tokenInfo) + if err != nil { + return false, false, model.UserInfo{}, err + } + return false, true, userInfo, err + } + + manager.startFlow(cfg, w) + return true, false, model.UserInfo{}, nil +} + +func (manager *Manager) GetConfigFromRequest(r *http.Request) (Config, error) { + configName := manager.getConfigNameFromPath(r.URL.Path) + cfg, exist := manager.configs[configName] + if !exist { + return Config{}, fmt.Errorf("no oauth configuration for %v", configName) + } + + if cfg.RedirectURI == "" { + cfg.RedirectURI = redirectUriFromRequest(r) + } + + return cfg, nil +} + +func (manager *Manager) getConfigNameFromPath(path string) string { + parts := strings.Split(path, "/") + return parts[len(parts)-1] +} + +// Add a configuration for a provider +func (manager *Manager) AddConfig(providerName string, opts map[string]string) error { + p, exist := GetProvider(providerName) + if !exist { + return fmt.Errorf("no provider for name %v", providerName) + } + + cfg := Config{ + Provider: p, + AuthURL: p.AuthURL, + TokenURL: p.TokenURL, + } + + if clientId, exist := opts["client_id"]; !exist { + return fmt.Errorf("missing parameter client_id") + } else { + cfg.ClientID = clientId + } + + if clientSecret, exist := opts["client_secret"]; !exist { + return fmt.Errorf("missing parameter client_secret") + } else { + cfg.ClientSecret = clientSecret + } + + if scope, exist := opts["scope"]; exist { + cfg.Scope = scope + } + + if redirectURI, exist := opts["redirect_uri"]; exist { + cfg.RedirectURI = redirectURI + } + + manager.configs[providerName] = cfg + return nil +} + +func redirectUriFromRequest(r *http.Request) string { + u := url.URL{} + u.Path = r.URL.Path + + if ffh := r.Header.Get("X-Forwarded-Host"); ffh == "" { + u.Host = r.Host + } else { + u.Host = ffh + } + + if ffp := r.Header.Get("X-Forwarded-Proto"); ffp == "" { + if r.TLS != nil { + u.Scheme = "https" + } else { + u.Scheme = "http" + } + } else { + u.Scheme = ffp + } + + return u.String() +} diff --git a/oauth2/manager_test.go b/oauth2/manager_test.go @@ -0,0 +1,256 @@ +package oauth2 + +import ( + "crypto/tls" + "errors" + "github.com/stretchr/testify/assert" + "github.com/tarent/loginsrv/model" + "net/http" + "net/http/httptest" + "testing" +) + +func Test_Manager_Positive_Flow(t *testing.T) { + var startFlowCalled, authenticateCalled, getUserInfoCalled bool + var startFlowReceivedConfig, authenticateReceivedConfig Config + expectedToken := TokenInfo{AccessToken: "the-access-token"} + + exampleProvider := Provider{ + Name: "example", + AuthURL: "https://example.com/login/oauth/authorize", + TokenURL: "https://example.com/login/oauth/access_token", + GetUserInfo: func(token TokenInfo) (model.UserInfo, string, error) { + getUserInfoCalled = true + assert.Equal(t, token, expectedToken) + return model.UserInfo{ + Sub: "the-username", + }, "", nil + }, + } + RegisterProvider(exampleProvider) + defer UnRegisterProvider(exampleProvider.Name) + + expectedConfig := Config{ + ClientID: "client42", + ClientSecret: "secret", + AuthURL: exampleProvider.AuthURL, + TokenURL: exampleProvider.TokenURL, + RedirectURI: "http://localhost", + Scope: "email other", + Provider: exampleProvider, + } + + m := NewManager() + m.AddConfig(exampleProvider.Name, map[string]string{ + "client_id": expectedConfig.ClientID, + "client_secret": expectedConfig.ClientSecret, + "scope": expectedConfig.Scope, + "redirect_uri": expectedConfig.RedirectURI, + }) + + m.startFlow = func(cfg Config, w http.ResponseWriter) { + startFlowCalled = true + startFlowReceivedConfig = cfg + } + + m.authenticate = func(cfg Config, r *http.Request) (TokenInfo, error) { + authenticateCalled = true + authenticateReceivedConfig = cfg + return expectedToken, nil + } + + // start flow + r, _ := http.NewRequest("GET", "http://example.com/login/"+exampleProvider.Name, nil) + + startedFlow, authenticated, userInfo, err := m.Handle(httptest.NewRecorder(), r) + assert.NoError(t, err) + assert.True(t, startedFlow) + assert.False(t, authenticated) + assert.Equal(t, model.UserInfo{}, userInfo) + + assert.True(t, startFlowCalled) + assert.False(t, authenticateCalled) + + assertEqualConfig(t, expectedConfig, startFlowReceivedConfig) + + // callback + r, _ = http.NewRequest("GET", "http://example.com/login/"+exampleProvider.Name+"?code=xyz", nil) + + startedFlow, authenticated, userInfo, err = m.Handle(httptest.NewRecorder(), r) + assert.NoError(t, err) + assert.False(t, startedFlow) + assert.True(t, authenticated) + assert.Equal(t, model.UserInfo{Sub: "the-username"}, userInfo) + assert.True(t, authenticateCalled) + assertEqualConfig(t, expectedConfig, authenticateReceivedConfig) + + assert.True(t, getUserInfoCalled) +} + +func Test_Manager_NoAauthOnWrongCode(t *testing.T) { + var authenticateCalled, getUserInfoCalled bool + + exampleProvider := Provider{ + Name: "example", + AuthURL: "https://example.com/login/oauth/authorize", + TokenURL: "https://example.com/login/oauth/access_token", + GetUserInfo: func(token TokenInfo) (model.UserInfo, string, error) { + getUserInfoCalled = true + return model.UserInfo{}, "", nil + }, + } + RegisterProvider(exampleProvider) + defer UnRegisterProvider(exampleProvider.Name) + + m := NewManager() + m.AddConfig(exampleProvider.Name, map[string]string{ + "client_id": "foo", + "client_secret": "bar", + }) + + m.authenticate = func(cfg Config, r *http.Request) (TokenInfo, error) { + authenticateCalled = true + return TokenInfo{}, errors.New("code not valid") + } + + // callback + r, _ := http.NewRequest("GET", "http://example.com/login/"+exampleProvider.Name+"?code=xyz", nil) + + startedFlow, authenticated, userInfo, err := m.Handle(httptest.NewRecorder(), r) + assert.EqualError(t, err, "code not valid") + assert.False(t, startedFlow) + assert.False(t, authenticated) + assert.Equal(t, model.UserInfo{}, userInfo) + assert.True(t, authenticateCalled) + assert.False(t, getUserInfoCalled) +} + +func Test_Manager_getConfig_ErrorCase(t *testing.T) { + r, _ := http.NewRequest("GET", "http://example.com/login", nil) + + m := NewManager() + m.AddConfig("github", map[string]string{ + "client_id": "foo", + "client_secret": "bar", + }) + + _, err := m.GetConfigFromRequest(r) + assert.EqualError(t, err, "no oauth configuration for login") +} + +func Test_Manager_AddConfig_ErrorCases(t *testing.T) { + m := NewManager() + + assert.NoError(t, + m.AddConfig("github", map[string]string{ + "client_id": "foo", + "client_secret": "bar", + })) + + assert.EqualError(t, + m.AddConfig("FOOOO", map[string]string{ + "client_id": "foo", + "client_secret": "bar", + }), + "no provider for name FOOOO", + ) + + assert.EqualError(t, + m.AddConfig("github", map[string]string{ + "client_secret": "bar", + }), + "missing parameter client_id", + ) + + assert.EqualError(t, + m.AddConfig("github", map[string]string{ + "client_id": "foo", + }), + "missing parameter client_secret", + ) + +} + +func Test_Manager_redirectUriFromRequest(t *testing.T) { + tests := []struct { + url string + tls bool + header http.Header + expected string + }{ + { + "http://example.com/login/github", + false, + http.Header{}, + "http://example.com/login/github", + }, + { + "http://localhost/login/github", + false, + http.Header{ + "X-Forwarded-Host": {"example.com"}, + }, + "http://example.com/login/github", + }, + { + "http://localhost/login/github", + true, + http.Header{ + "X-Forwarded-Host": {"example.com"}, + }, + "https://example.com/login/github", + }, + { + "http://localhost/login/github", + false, + http.Header{ + "X-Forwarded-Host": {"example.com"}, + "X-Forwarded-Proto": {"https"}, + }, + "https://example.com/login/github", + }, + } + for _, test := range tests { + t.Run(test.url, func(t *testing.T) { + r, _ := http.NewRequest("GET", test.url, nil) + r.Header = test.header + if test.tls { + r.TLS = &tls.ConnectionState{} + } + uri := redirectUriFromRequest(r) + assert.Equal(t, test.expected, uri) + }) + } +} + +func Test_Manager_RedirectURI_Generation(t *testing.T) { + var startFlowReceivedConfig Config + + m := NewManager() + m.AddConfig("github", map[string]string{ + "client_id": "foo", + "client_secret": "bar", + "scope": "bazz", + }) + + m.startFlow = func(cfg Config, w http.ResponseWriter) { + startFlowReceivedConfig = cfg + } + + callUrl := "http://example.com/login/github" + r, _ := http.NewRequest("GET", callUrl, nil) + + _, _, _, err := m.Handle(httptest.NewRecorder(), r) + assert.NoError(t, err) + assert.Equal(t, callUrl, startFlowReceivedConfig.RedirectURI) +} + +func assertEqualConfig(t *testing.T, c1, c2 Config) { + assert.Equal(t, c1.AuthURL, c2.AuthURL) + assert.Equal(t, c1.ClientID, c2.ClientID) + assert.Equal(t, c1.ClientSecret, c2.ClientSecret) + assert.Equal(t, c1.Scope, c2.Scope) + assert.Equal(t, c1.RedirectURI, c2.RedirectURI) + assert.Equal(t, c1.TokenURL, c2.TokenURL) + assert.Equal(t, c1.Provider.Name, c2.Provider.Name) +} diff --git a/oauth2/oauth.go b/oauth2/oauth.go @@ -0,0 +1,162 @@ +package oauth2 + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "math/rand" + "net/http" + "net/url" + "strings" + "time" +) + +func init() { + rand.Seed(time.Now().UTC().UnixNano()) +} + +// Config describes a typical 3-legged OAuth2 flow, with both the +// client application information and the server's endpoint URLs. +type Config struct { + // ClientID is the application's ID. + ClientID string + + // ClientSecret is the application's secret. + ClientSecret string + + // The oauth authentication url to redirect to + AuthURL string + + // The url for token exchange + TokenURL string + + // RedirectURL is the URL to redirect users going through + // the OAuth flow, after the resource owner's URLs. + RedirectURI string + + // Scope specifies optional requested permissions, this is a *space* separated list. + Scope string + + // The oauth provider + Provider Provider +} + +// Token represents the crendentials used to authorize +// the requests to access protected resources on the OAuth 2.0 +// provider's backend. +type TokenInfo struct { + // AccessToken is the token that authorizes and authenticates + // the requests. + AccessToken string `json:"access_token"` + + // TokenType is the type of token. + TokenType string `json:"token_type,omitempty"` + + // The scopes for this tolen + Scope string `json:"scope,omitempty"` +} + +// JsonError represents an oauth error response in json form. +type JsonError struct { + Error string `json:"error"` +} + +const stateCookieName = "oauthState" +const defaultTimeout = 5 * time.Second + +// Starts the flow by redirecting the user to the login provider. +// A state parameter to protect against cross-site request forgery attacks is randomly generated and stored in a cookie +func StartFlow(cfg Config, w http.ResponseWriter) { + values := make(url.Values) + values.Set("client_id", cfg.ClientID) + values.Set("scope", cfg.Scope) + values.Set("redirect_uri", cfg.RedirectURI) + values.Set("response_type", "code") + + // set and store the state param + values.Set("state", randStringBytes(15)) + http.SetCookie(w, &http.Cookie{ + Name: stateCookieName, + MaxAge: 60 * 10, // 10 minutes + Value: values.Get("state"), + HttpOnly: true, + }) + + targetUrl := cfg.AuthURL + "?" + values.Encode() + w.Header().Set("Location", targetUrl) + w.WriteHeader(http.StatusFound) +} + +// Check the authentication after coming back from the oauth flow. +// Verify the state parameter againt the state cookie from the request. +func Authenticate(cfg Config, r *http.Request) (TokenInfo, error) { + if r.FormValue("error") != "" { + return TokenInfo{}, fmt.Errorf("error: %v", r.FormValue("error")) + } + + state := r.FormValue("state") + stateCookie, err := r.Cookie(stateCookieName) + if err != nil || stateCookie.Value != state { + return TokenInfo{}, fmt.Errorf("error: oauth state param could not be verified") + } + + code := r.FormValue("code") + if code == "" { + return TokenInfo{}, fmt.Errorf("error: no auth code provided") + } + return getAccessToken(cfg, state, code) +} + +func getAccessToken(cfg Config, state, code string) (TokenInfo, error) { + values := url.Values{} + values.Set("client_id", cfg.ClientID) + values.Set("client_secret", cfg.ClientSecret) + values.Set("code", code) + + r, _ := http.NewRequest("POST", cfg.TokenURL, strings.NewReader(values.Encode())) + cntx, cancel := context.WithTimeout(context.Background(), defaultTimeout) + defer cancel() + r.WithContext(cntx) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + r.Header.Set("Accept", "application/json") + resp, err := http.DefaultClient.Do(r) + if err != nil { + return TokenInfo{}, err + } + + if resp.StatusCode != 200 { + return TokenInfo{}, fmt.Errorf("error: expected http status 200 on token exchange, but got %v", resp.StatusCode) + } + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return TokenInfo{}, fmt.Errorf("error reading token exchange response: %q", err) + } + + jsonError := JsonError{} + json.Unmarshal(body, &jsonError) + if jsonError.Error != "" { + return TokenInfo{}, fmt.Errorf("error: got %q on token exchange", jsonError.Error) + } + + tokenInfo := TokenInfo{} + err = json.Unmarshal(body, &tokenInfo) + if err != nil { + return TokenInfo{}, fmt.Errorf("error on parsing oauth token: %v", err) + } + + if tokenInfo.AccessToken == "" { + return TokenInfo{}, fmt.Errorf("error: no access_token on token exchange") + } + return tokenInfo, nil +} + +const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + +func randStringBytes(n int) string { + b := make([]byte, n) + for i := range b { + b[i] = letterBytes[rand.Intn(len(letterBytes))] + } + return string(b) +} diff --git a/oauth2/oauth_test.go b/oauth2/oauth_test.go @@ -0,0 +1,202 @@ +package oauth2 + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +var testConfig = Config{ + ClientID: "client42", + ClientSecret: "secret", + AuthURL: "http://auth-provider/auth", + TokenURL: "http://auth-provider/token", + RedirectURI: "http://localhost/callback", + Scope: "email other", +} + +func Test_StartFlow(t *testing.T) { + resp := httptest.NewRecorder() + StartFlow(testConfig, resp) + + assert.Equal(t, http.StatusFound, resp.Code) + + // assert that we received a state cookie + cHeader := strings.Split(resp.Header().Get("Set-Cookie"), ";")[0] + assert.Equal(t, stateCookieName, strings.Split(cHeader, "=")[0]) + state := strings.Split(cHeader, "=")[1] + + expectedLocation := fmt.Sprintf("%v?client_id=%v&redirect_uri=%v&response_type=code&scope=%v&state=%v", + testConfig.AuthURL, + testConfig.ClientID, + url.QueryEscape(testConfig.RedirectURI), + "email+other", + state, + ) + + assert.Equal(t, expectedLocation, resp.Header().Get("Location")) +} + +func Test_Authenticate(t *testing.T) { + // mock a server for token exchange + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type")) + assert.Equal(t, "application/json", r.Header.Get("Accept")) + + body, _ := ioutil.ReadAll(r.Body) + assert.Equal(t, "client_id=client42&client_secret=secret&code=theCode", string(body)) + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":"e72e16c7e42f292c6912e7710c838347ae178b4a", "scope":"repo gist", "token_type":"bearer"}`)) + })) + defer server.Close() + + testConfigCopy := testConfig + testConfigCopy.TokenURL = server.URL + + request, _ := http.NewRequest("GET", testConfig.RedirectURI, nil) + request.Header.Set("Cookie", "oauthState=theState") + request.URL, _ = url.Parse("http://localhost/callback?code=theCode&state=theState") + + tokenInfo, err := Authenticate(testConfigCopy, request) + + assert.NoError(t, err) + assert.Equal(t, "e72e16c7e42f292c6912e7710c838347ae178b4a", tokenInfo.AccessToken) + assert.Equal(t, "repo gist", tokenInfo.Scope) + assert.Equal(t, "bearer", tokenInfo.TokenType) +} + +func Test_Authenticate_CodeExchangeError(t *testing.T) { + var testReturnCode int + testResponseJson := `{"error":"bad_verification_code","error_description":"The code passed is incorrect or expired.","error_uri":"https://developer.github.com/v3/oauth/#bad-verification-code"}` + // mock a server for token exchange + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(testReturnCode) + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(testResponseJson)) + })) + defer server.Close() + + testConfigCopy := testConfig + testConfigCopy.TokenURL = server.URL + + request, _ := http.NewRequest("GET", testConfig.RedirectURI, nil) + request.Header.Set("Cookie", "oauthState=theState") + request.URL, _ = url.Parse("http://localhost/callback?code=theCode&state=theState") + + testReturnCode = 500 + tokenInfo, err := Authenticate(testConfigCopy, request) + assert.Error(t, err) + assert.EqualError(t, err, "error: expected http status 200 on token exchange, but got 500") + assert.Equal(t, "", tokenInfo.AccessToken) + + testReturnCode = 200 + tokenInfo, err = Authenticate(testConfigCopy, request) + assert.Error(t, err) + assert.EqualError(t, err, `error: got "bad_verification_code" on token exchange`) + assert.Equal(t, "", tokenInfo.AccessToken) + + testReturnCode = 200 + testResponseJson = `{"foo": "bar"}` + tokenInfo, err = Authenticate(testConfigCopy, request) + assert.Error(t, err) + assert.EqualError(t, err, `error: no access_token on token exchange`) + assert.Equal(t, "", tokenInfo.AccessToken) + +} + +func Test_Authentication_ProviderError(t *testing.T) { + request, _ := http.NewRequest("GET", testConfig.RedirectURI, nil) + request.URL, _ = url.Parse("http://localhost/callback?error=provider_login_error") + + _, err := Authenticate(testConfig, request) + + assert.Error(t, err) + assert.Equal(t, "error: provider_login_error", err.Error()) +} + +func Test_Authentication_StateError(t *testing.T) { + request, _ := http.NewRequest("GET", testConfig.RedirectURI, nil) + request.Header.Set("Cookie", "oauthState=XXXXXXX") + request.URL, _ = url.Parse("http://localhost/callback?code=theCode&state=theState") + + _, err := Authenticate(testConfig, request) + + assert.Error(t, err) + assert.Equal(t, "error: oauth state param could not be verified", err.Error()) +} + +func Test_Authentication_NoCodeError(t *testing.T) { + request, _ := http.NewRequest("GET", testConfig.RedirectURI, nil) + request.Header.Set("Cookie", "oauthState=theState") + request.URL, _ = url.Parse("http://localhost/callback?state=theState") + + _, err := Authenticate(testConfig, request) + + assert.Error(t, err) + assert.Equal(t, "error: no auth code provided", err.Error()) +} + +func Test_Authentication_Provider500(t *testing.T) { + // mock a server for token exchange + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + })) + defer server.Close() + + testConfigCopy := testConfig + testConfigCopy.TokenURL = server.URL + + request, _ := http.NewRequest("GET", testConfig.RedirectURI, nil) + request.Header.Set("Cookie", "oauthState=theState") + request.URL, _ = url.Parse("http://localhost/callback?code=theCode&state=theState") + + _, err := Authenticate(testConfigCopy, request) + + assert.Error(t, err) + assert.Equal(t, "error: expected http status 200 on token exchange, but got 500", err.Error()) +} + +func Test_Authentication_ProviderNetworkError(t *testing.T) { + + testConfigCopy := testConfig + testConfigCopy.TokenURL = "http://localhost:12345678" + + request, _ := http.NewRequest("GET", testConfig.RedirectURI, nil) + request.Header.Set("Cookie", "oauthState=theState") + request.URL, _ = url.Parse("http://localhost/callback?code=theCode&state=theState") + + _, err := Authenticate(testConfigCopy, request) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid port") +} + +func Test_Authentication_TokenParseError(t *testing.T) { + // mock a server for token exchange + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_t`)) + + })) + defer server.Close() + + testConfigCopy := testConfig + testConfigCopy.TokenURL = server.URL + + request, _ := http.NewRequest("GET", testConfig.RedirectURI, nil) + request.Header.Set("Cookie", "oauthState=theState") + request.URL, _ = url.Parse("http://localhost/callback?code=theCode&state=theState") + + _, err := Authenticate(testConfigCopy, request) + + assert.Error(t, err) + assert.Equal(t, "error on parsing oauth token: unexpected end of JSON input", err.Error()) +} diff --git a/oauth2/provider.go b/oauth2/provider.go @@ -0,0 +1,50 @@ +package oauth2 + +import ( + "github.com/tarent/loginsrv/model" +) + +// Oauth provider configuration +type Provider struct { + // The name to access the provider in the configuration + Name string + + // The oauth authentication url to redirect to + AuthURL string + + // The url for token exchange + TokenURL string + + // GetUserInfo is a provider specific Implementation + // for fetching the user information. + // Possible keys in the returned map are: + // username, email, name + GetUserInfo func(token TokenInfo) (u model.UserInfo, rawUserJson string, err error) +} + +var provider = map[string]Provider{} + +// Register an Oauth provider +func RegisterProvider(p Provider) { + provider[p.Name] = p +} + +// Unregister an Oauth provider +func UnRegisterProvider(name string) { + delete(provider, name) +} + +// Return a provider +func GetProvider(providerName string) (Provider, bool) { + p, exist := provider[providerName] + return p, exist +} + +// ProviderList returns the names of all registered provider +func ProviderList() []string { + list := make([]string, 0, len(provider)) + for k, _ := range provider { + list = append(list, k) + } + return list +} diff --git a/oauth2/provider_test.go b/oauth2/provider_test.go @@ -0,0 +1,16 @@ +package oauth2 + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func Test_ProviderRegistration(t *testing.T) { + github, exist := GetProvider("github") + assert.NotNil(t, github) + assert.True(t, exist) + + list := ProviderList() + assert.Equal(t, 1, len(list)) + assert.Equal(t, "github", list[0]) +} diff --git a/osiam/backend.go b/osiam/backend.go @@ -3,7 +3,7 @@ package osiam import ( "errors" "fmt" - "github.com/tarent/loginsrv/login" + "github.com/tarent/loginsrv/model" "net/url" ) @@ -29,13 +29,13 @@ func NewBackend(endpoint, clientId, clientSecret string) (*Backend, error) { }, nil } -func (b *Backend) Authenticate(username, password string) (bool, login.UserInfo, error) { +func (b *Backend) Authenticate(username, password string) (bool, model.UserInfo, error) { authenticated, _, err := b.client.GetTokenByPassword(username, password) if !authenticated || err != nil { - return authenticated, login.UserInfo{}, err + return authenticated, model.UserInfo{}, err } - userInfo := login.UserInfo{ - Username: username, + userInfo := model.UserInfo{ + Sub: username, } return true, userInfo, nil } diff --git a/osiam/backend_test.go b/osiam/backend_test.go @@ -2,7 +2,7 @@ package osiam import ( "github.com/stretchr/testify/assert" - "github.com/tarent/loginsrv/login" + "github.com/tarent/loginsrv/model" "net/http" "net/http/httptest" "testing" @@ -20,8 +20,8 @@ func TestBackend_Authenticate(t *testing.T) { assert.NoError(t, err) assert.True(t, authenticated) assert.Equal(t, - login.UserInfo{ - Username: "admin", + model.UserInfo{ + Sub: "admin", }, userInfo) diff --git a/osiam/setup.go b/osiam/setup.go @@ -1,6 +1,7 @@ package osiam import ( + "github.com/tarent/lib-compose/logging" "github.com/tarent/loginsrv/login" ) @@ -9,9 +10,14 @@ const OsiamProviderName = "osiam" func init() { login.RegisterProvider( &login.ProviderDescription{ - Name: OsiamProviderName, + Name: OsiamProviderName, + HelpText: "Osiam login backend opts: endpoint=..,client_id=..,client_secret=..", }, func(config map[string]string) (login.Backend, error) { - return NewBackend(config["endpoint"], config["clientId"], config["clientSecret"]) + if config["clientId"] != "" { + logging.Logger.Warn("DEPRECATED: please use 'client_id' and 'client_secret' in future.") + return NewBackend(config["endpoint"], config["clientId"], config["clientSecret"]) + } + return NewBackend(config["endpoint"], config["client_id"], config["client_secret"]) }) }