Exemplo n.º 1
0
Arquivo: http.go Projeto: set321go/dex
// handleAuthFunc returns an http.HandlerFunc that processes OAuth2
// authorization-code requests. Only GET is accepted.
//
// Flow:
//   - An "error" query parameter kills the session whose key is passed in
//     "state" and renders the login page again.
//   - A missing or unknown "connector_id" renders the login page so the user
//     can choose one of idpcs.
//   - Otherwise the request is parsed and validated: the client must exist and
//     have registered redirect URLs, the redirect URL must be valid for it,
//     response_type must be "code", and the scope must include "openid".
//     A session is then created and the user is redirected either to the
//     local registration page (register=1 with a LocalConnector) or to the
//     connector's login URL.
//
// The "register=1" query parameter is honoured only when registrationEnabled
// is true.
func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.Template, registrationEnabled bool) http.HandlerFunc {
	idx := makeConnectorMap(idpcs)
	return func(w http.ResponseWriter, r *http.Request) {
		if r.Method != "GET" {
			w.Header().Set("Allow", "GET")
			phttp.WriteError(w, http.StatusMethodNotAllowed, "GET only acceptable method")
			return
		}

		q := r.URL.Query()
		register := q.Get("register") == "1" && registrationEnabled

		// An "error" parameter aborts the flow: drop the session identified by
		// "state" and show the login page again.
		if e := q.Get("error"); e != "" {
			sessionKey := q.Get("state")
			if err := srv.KillSession(sessionKey); err != nil {
				log.Errorf("Failed killing sessionKey %q: %v", sessionKey, err)
			}
			renderLoginPage(w, r, srv, idpcs, register, tpl)
			return
		}

		connectorID := q.Get("connector_id")
		idpc, ok := idx[connectorID]
		if !ok {
			renderLoginPage(w, r, srv, idpcs, register, tpl)
			return
		}

		// acr is returned alongside a non-nil error so that acr.State can still
		// be echoed back in the error response.
		acr, err := oauth2.ParseAuthCodeRequest(q)
		if err != nil {
			log.Errorf("Invalid auth request: %v", err)
			writeAuthError(w, err, acr.State)
			return
		}

		cm, err := srv.ClientMetadata(acr.ClientID)
		if err != nil {
			log.Errorf("Failed fetching client %q from repo: %v", acr.ClientID, err)
			writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
			return
		}
		if cm == nil {
			log.Errorf("Client %q not found", acr.ClientID)
			writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
			return
		}

		if len(cm.RedirectURLs) == 0 {
			log.Errorf("Client %q has no redirect URLs", acr.ClientID)
			writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
			return
		}

		redirectURL, err := client.ValidRedirectURL(acr.RedirectURL, cm.RedirectURLs)
		if err != nil {
			switch err {
			case client.ErrorCantChooseRedirectURL:
				log.Errorf("Request must provide redirect URL as client %q has registered many", acr.ClientID)
			case client.ErrorInvalidRedirectURL:
				log.Errorf("Request provided unregistered redirect URL: %s", acr.RedirectURL)
			case client.ErrorNoValidRedirectURLs:
				log.Errorf("There are no registered URLs for the requested client: %s", acr.RedirectURL)
			default:
				// Any other error must stop the flow; continuing would use a
				// zero-value redirect URL.
				log.Errorf("Failed validating redirect URL for client %q: %v", acr.ClientID, err)
				writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
				return
			}
			writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
			return
		}

		if acr.ResponseType != oauth2.ResponseTypeCode {
			log.Errorf("unexpected ResponseType: %v: ", acr.ResponseType)
			redirectAuthError(w, oauth2.NewError(oauth2.ErrorUnsupportedResponseType), acr.State, redirectURL)
			return
		}

		// Check scopes: "openid" is mandatory. All requested scopes, including
		// "offline_access", are passed to the session unchanged.
		//
		// According to the spec, for offline_access scope, the client must use
		// a response_type value that would result in an Authorization Code.
		// oauth2.ResponseTypeCode is currently the only supported response type
		// and was checked above, so no further check is needed here.
		//
		// TODO(yifan): Verify that 'consent' should be in 'prompt'.
		foundOpenIDScope := false
		for _, scope := range acr.Scope {
			if scope == "openid" {
				foundOpenIDScope = true
				break
			}
		}
		if !foundOpenIDScope {
			log.Errorf("Invalid auth request: missing 'openid' in 'scope'")
			writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
			return
		}

		nonce := q.Get("nonce")

		key, err := srv.NewSession(connectorID, acr.ClientID, acr.State, redirectURL, nonce, register, acr.Scope)
		if err != nil {
			log.Errorf("Error creating new session: %v: ", err)
			redirectAuthError(w, err, acr.State, redirectURL)
			return
		}

		// Registration through the local connector is handled by this server's
		// own registration page, keyed by the new session key.
		if register {
			if _, isLocal := idpc.(*connector.LocalConnector); isLocal {
				v := url.Values{}
				v.Set("code", key)
				w.Header().Set("Location", httpPathRegister+"?"+v.Encode())
				w.WriteHeader(http.StatusFound)
				return
			}
		}

		// NOTE(review): registration requests are sent prompt "select_account"
		// only; confirm whether "consent" should also be requested for them.
		var p string
		if shouldReprompt(r) || register {
			p = "select_account"
		}
		lu, err := idpc.LoginURL(key, p)
		if err != nil {
			log.Errorf("Connector.LoginURL failed: %v", err)
			redirectAuthError(w, err, acr.State, redirectURL)
			return
		}

		http.SetCookie(w, createLastSeenCookie())
		w.Header().Set("Location", lu)
		w.WriteHeader(http.StatusFound)
	}
}
Exemplo n.º 2
0
Arquivo: http.go Projeto: Tecsisa/dex
// handleAuthFunc returns an http.HandlerFunc that processes OAuth2
// authorization-code requests. Only GET is accepted.
//
// Flow:
//   - An "error" query parameter kills the session whose key is passed in
//     "state" and renders the login page again.
//   - A missing or unknown "connector_id" renders the login page so the user
//     can choose one of idpcs.
//   - Otherwise the request is parsed and validated: the client must exist,
//     the redirect URL must be valid for it, response_type must be "code", and
//     the scopes must pass validateScopes. A session is then created and the
//     user is redirected either to the local registration page (register=1
//     with a LocalConnector; path joined onto baseURL.Path) or to the
//     connector's login URL.
//
// The "register=1" query parameter is honoured only when registrationEnabled
// is true.
func handleAuthFunc(srv OIDCServer, baseURL url.URL, idpcs []connector.Connector, tpl *template.Template, registrationEnabled bool) http.HandlerFunc {
	idx := makeConnectorMap(idpcs)
	return func(w http.ResponseWriter, r *http.Request) {
		if r.Method != "GET" {
			w.Header().Set("Allow", "GET")
			phttp.WriteError(w, http.StatusMethodNotAllowed, "GET only acceptable method")
			return
		}

		q := r.URL.Query()
		register := q.Get("register") == "1" && registrationEnabled

		// An "error" parameter aborts the flow: drop the session identified by
		// "state" and show the login page again.
		if e := q.Get("error"); e != "" {
			sessionKey := q.Get("state")
			if err := srv.KillSession(sessionKey); err != nil {
				log.Errorf("Failed killing sessionKey %q: %v", sessionKey, err)
			}
			renderLoginPage(w, r, srv, idpcs, register, tpl)
			return
		}

		connectorID := q.Get("connector_id")
		idpc, ok := idx[connectorID]
		if !ok {
			renderLoginPage(w, r, srv, idpcs, register, tpl)
			return
		}

		// acr is returned alongside a non-nil error so that acr.State can still
		// be echoed back in the error response.
		acr, err := oauth2.ParseAuthCodeRequest(q)
		if err != nil {
			log.Errorf("Invalid auth request: %v", err)
			writeAuthError(w, err, acr.State)
			return
		}

		// The not-found check must come before the generic error check;
		// otherwise an unknown client is reported as a server error instead of
		// an invalid request.
		cli, err := srv.Client(acr.ClientID)
		if err == client.ErrorNotFound {
			log.Errorf("Client %q not found", acr.ClientID)
			writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
			return
		}
		if err != nil {
			log.Errorf("Failed fetching client %q from repo: %v", acr.ClientID, err)
			writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
			return
		}

		redirectURL, err := cli.ValidRedirectURL(acr.RedirectURL)
		if err != nil {
			switch err {
			case client.ErrorCantChooseRedirectURL:
				log.Errorf("Request must provide redirect URL as client %q has registered many", acr.ClientID)
			case client.ErrorInvalidRedirectURL:
				log.Errorf("Request provided unregistered redirect URL: %s", acr.RedirectURL)
			case client.ErrorNoValidRedirectURLs:
				log.Errorf("There are no registered URLs for the requested client: %s", acr.RedirectURL)
			default:
				// Any other error must stop the flow; continuing would use a
				// zero-value redirect URL.
				log.Errorf("Failed validating redirect URL for client %q: %v", acr.ClientID, err)
				writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
				return
			}
			writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
			return
		}

		if acr.ResponseType != oauth2.ResponseTypeCode {
			log.Errorf("unexpected ResponseType: %v: ", acr.ResponseType)
			redirectAuthError(w, oauth2.NewError(oauth2.ErrorUnsupportedResponseType), acr.State, redirectURL)
			return
		}

		// Check scopes.
		if scopeErr := validateScopes(srv, acr.ClientID, acr.Scope); scopeErr != nil {
			log.Error(scopeErr)
			writeAuthError(w, scopeErr, acr.State)
			return
		}

		nonce := q.Get("nonce")

		key, err := srv.NewSession(connectorID, acr.ClientID, acr.State, redirectURL, nonce, register, acr.Scope)
		if err != nil {
			log.Errorf("Error creating new session: %v: ", err)
			redirectAuthError(w, err, acr.State, redirectURL)
			return
		}

		// Registration through the local connector is handled by this server's
		// own registration page (under baseURL.Path), keyed by the session key.
		if register {
			if _, isLocal := idpc.(*connector.LocalConnector); isLocal {
				v := url.Values{}
				v.Set("code", key)
				ru := path.Join(baseURL.Path, httpPathRegister) + "?" + v.Encode()
				w.Header().Set("Location", ru)
				w.WriteHeader(http.StatusFound)
				return
			}
		}

		// NOTE(review): registration requests are sent prompt "select_account"
		// only; confirm whether "consent" should also be requested for them.
		var p string
		if shouldReprompt(r) || register {
			p = "select_account"
		}
		lu, err := idpc.LoginURL(key, p)
		if err != nil {
			log.Errorf("Connector.LoginURL failed: %v", err)
			redirectAuthError(w, err, acr.State, redirectURL)
			return
		}

		http.SetCookie(w, createLastSeenCookie())
		w.Header().Set("Location", lu)
		w.WriteHeader(http.StatusFound)
	}
}