func main() { var exitCode = 0 defer func() { if exitCode != 0 { os.Exit(exitCode) } }() defer rglog.Flush() logutil.InitConsole(logRoot) param, err := parseParameters(os.Args...) if err != nil { log.Err(erro.Unwrap(err)) log.Debug(erro.Wrap(err)) exitCode = 1 return } logutil.SetupConsole(logRoot, param.consLv) if err := logutil.Setup(logRoot, param.logType, param.logLv, param); err != nil { log.Err(erro.Unwrap(err)) log.Debug(erro.Wrap(err)) exitCode = 1 return } if err := serve(param); err != nil { log.Err(erro.Unwrap(err)) log.Debug(erro.Wrap(err)) exitCode = 1 return } log.Info("Shut down") }
func (this *environment) callbackServe(w http.ResponseWriter, r *http.Request) error { if this.sessId == "" { return erro.Wrap(server.NewError(http.StatusBadRequest, "no session ", nil)) } sess, err := this.sessDb.Get(this.sessId) if err != nil { return erro.Wrap(err) } else if sess == nil { return erro.Wrap(server.NewError(http.StatusBadRequest, "declared user session is not exist", nil)) } this.sess = sess log.Debug(this.logPref, "Declared user session is exist") savedDate := sess.Date() sess.Invalidate() if ok, err := this.sessDb.Replace(sess, savedDate); err != nil { return erro.Wrap(err) } else if !ok { return erro.Wrap(server.NewError(http.StatusBadRequest, "reused user session", nil)) } req, err := parseCallbackRequest(r) if err != nil { return erro.Wrap(server.NewError(http.StatusBadRequest, erro.Unwrap(err).Error(), err)) } log.Debug(this.logPref, "Parsed callback request") if req.state() != sess.State() { return erro.Wrap(server.NewError(http.StatusForbidden, "invalid state", nil)) } var idp idpdb.Element var attrs1 map[string]interface{} if sess.IdProvider() != "" { idp, err = this.idpDb.Get(sess.IdProvider()) if err != nil { return erro.Wrap(err) } else if idp == nil { return erro.Wrap(server.NewError(http.StatusBadRequest, "ID provider "+sess.IdProvider()+" is not exist", nil)) } log.Debug(this.logPref, "ID provider "+idp.Id()+" is exist") } else { idTok, err := parseIdToken(req.idToken()) if err != nil { return erro.Wrap(server.NewError(http.StatusBadRequest, erro.Unwrap(err).Error(), err)) } idp, err = this.idpDb.Get(idTok.idProvider()) if err != nil { return erro.Wrap(err) } else if idp == nil { return erro.Wrap(server.NewError(http.StatusBadRequest, "ID provider "+idTok.idProvider()+" is not exist", nil)) } log.Debug(this.logPref, "ID provider "+idp.Id()+" is exist") if idTok.nonce() != sess.Nonce() { return erro.Wrap(server.NewError(http.StatusForbidden, "invalid nonce", nil)) } else if err := idTok.verify(idp.Keys()); err != nil { return erro.Wrap(server.NewError(http.StatusForbidden, erro.Unwrap(err).Error(), err)) } else if err := idTok.verifyCodeHash(req.code()); err != nil { return erro.Wrap(server.NewError(http.StatusForbidden, erro.Unwrap(err).Error(), err)) } attrs1 = idTok.attributes() log.Debug(this.logPref, "ID token is OK") } // アクセストークンを取得する。 tok, idTok, err := this.getAccessToken(req, idp) if err != nil { return erro.Wrap(err) } // アカウント情報を取得する。 attrs2, err := this.getAccountInfo(req, tok, idp) if err != nil { return erro.Wrap(err) } // アカウント情報をまとめる。 jt := jwt.New() jt.SetHeader(tagAlg, tagNone) for _, m := range []map[string]interface{}{attrs1, idTok.attributes(), attrs2} { for k, v := range m { jt.SetClaim(k, v) } } jt.SetClaim(tagAt_tag, tok.Tag()) jt.SetClaim(tagAt_exp, tok.Expires().Unix()) buff, err := jt.Encode() if err != nil { return erro.Wrap(err) } // フロントエンドのためにセッション期限を延長する。 now := time.Now() http.SetCookie(w, this.newCookie(sess.Id(), now.Add(-time.Second))) http.SetCookie(w, this.newFrontCookie(this.idGen.String(this.fsessLen), now.Add(this.fsessExpIn))) log.Info(this.logPref, "Upgrade user session to frontend session") // フロントエンドが使うので保存しなくて良い。 w.Header().Add(tagX_auth_user, string(buff)) w.Header().Add(tagCache_control, tagNo_store) w.Header().Add(tagPragma, tagNo_cache) http.Redirect(w, r, sess.Path(), http.StatusFound) log.Info(this.logPref, "Redirect to "+sess.Path()) return nil }
// 認可コードを使って、ID プロバイダからアクセストークンを取得する。 func (this *environment) getAccessToken(req *callbackRequest, idp idpdb.Element) (*token.Element, *idToken, error) { keys, err := this.keyDb.Get() if err != nil { return nil, nil, erro.Wrap(err) } queries := url.Values{} // grant_type queries.Set(tagGrant_type, tagAuthorization_code) // code queries.Set(tagCode, req.code()) // redirect_uri queries.Set(tagRedirect_uri, this.sess.RedirectUri()) // client_id queries.Set(tagClient_id, this.sess.Ta()) // client_assertion_type queries.Set(tagClient_assertion_type, cliAssTypeJwt_bearer) // client_assertion ass := jwt.New() now := time.Now() ass.SetHeader(tagAlg, this.sigAlg) ass.SetClaim(tagIss, this.sess.Ta()) ass.SetClaim(tagSub, this.sess.Ta()) ass.SetClaim(tagAud, idp.TokenUri()) ass.SetClaim(tagJti, this.idGen.String(this.jtiLen)) ass.SetClaim(tagExp, now.Add(this.jtiExpIn).Unix()) ass.SetClaim(tagIat, now.Unix()) if err := ass.Sign(keys); err != nil { return nil, nil, erro.Wrap(err) } buff, err := ass.Encode() if err != nil { return nil, nil, erro.Wrap(err) } queries.Set(tagClient_assertion, string(buff)) tokReq, err := http.NewRequest("POST", idp.TokenUri(), strings.NewReader(queries.Encode())) if err != nil { return nil, nil, erro.Wrap(err) } tokReq.Header.Set(tagContent_type, contTypeForm) server.LogRequest(level.DEBUG, tokReq, this.debug, this.logPref) resp, err := http.DefaultClient.Do(tokReq) if err != nil { return nil, nil, erro.Wrap(err) } defer resp.Body.Close() server.LogResponse(level.DEBUG, resp, this.debug, this.logPref) log.Info(this.logPref, "Got token response from "+idp.Id()) tokResp, err := parseTokenResponse(resp) if err != nil { return nil, nil, erro.Wrap(server.NewError(http.StatusForbidden, "cannot get access token", nil)) } tok := token.New(tokResp.token(), this.idGen.String(this.tokTagLen), tokResp.expires(), idp.Id(), tokResp.scope()) log.Info(this.logPref, "Got access token "+logutil.Mosaic(tok.Id())) if err := this.tokDb.Save(tok, time.Now().Add(this.tokDbExpIn)); err != nil { return nil, nil, erro.Wrap(err) } log.Info(this.logPref, "Saved access token "+logutil.Mosaic(tok.Id())) idTok, err := parseIdToken(tokResp.idToken()) if err != nil { return nil, nil, erro.Wrap(server.NewError(http.StatusForbidden, erro.Unwrap(err).Error(), err)) } else if idTok.nonce() != this.sess.Nonce() { return nil, nil, erro.Wrap(server.NewError(http.StatusForbidden, "invalid nonce", nil)) } else if err := idTok.verify(idp.Keys()); err != nil { return nil, nil, erro.Wrap(server.NewError(http.StatusForbidden, erro.Unwrap(err).Error(), err)) } else if idTok.tokenHash() != nil { if err := idTok.verifyTokenHash(tok.Id()); err != nil { return nil, nil, erro.Wrap(server.NewError(http.StatusForbidden, erro.Unwrap(err).Error(), err)) } } log.Info(this.logPref, "ID token is OK") return tok, idTok, nil }
func (this *environment) getInfo(isMain bool, idp idpdb.Element, codTok *codeToken) (frTa string, tok *token.Element, tagToAttrs map[string]map[string]interface{}, err error) { params := map[string]interface{}{} // grant_type params[tagGrant_type] = tagCooperation_code // code params[tagCode] = codTok.code() // claims if isMain { // TODO 受け取り方を考えないと。 } // user_claims // TODO 受け取り方を考えないと。 // client_assertion_type params[tagClient_assertion_type] = cliAssTypeJwt_bearer // client_assertion keys, err := this.keyDb.Get() if err != nil { return "", nil, nil, erro.Wrap(err) } ass, err := makeAssertion(this.handler, keys, idp.CoopToUri()) if err != nil { return "", nil, nil, erro.Wrap(err) } params[tagClient_assertion] = string(ass) data, err := json.Marshal(params) if err != nil { return "", nil, nil, erro.Wrap(err) } r, err := http.NewRequest("POST", idp.CoopToUri(), bytes.NewReader(data)) r.Header.Set(tagContent_type, contTypeJson) log.Debug(this.logPref, "Made main cooperation-to request") server.LogRequest(level.DEBUG, r, this.debug, this.logPref) resp, err := this.conn.Do(r) if err != nil { return "", nil, nil, erro.Wrap(err) } defer resp.Body.Close() server.LogResponse(level.DEBUG, resp, this.debug, this.logPref) if resp.StatusCode != http.StatusOK { var buff struct { Error string Error_description string } if err := json.NewDecoder(resp.Body).Decode(&buff); err != nil { return "", nil, nil, erro.Wrap(err) } return "", nil, nil, erro.Wrap(idperr.New(buff.Error, buff.Error_description, resp.StatusCode, nil)) } coopResp, err := parseCoopResponse(resp) if err != nil { return "", nil, nil, erro.Wrap(idperr.New(idperr.Access_denied, erro.Unwrap(err).Error(), http.StatusForbidden, err)) } idsTok, err := parseIdsToken(coopResp.idsToken()) if err != nil { return "", nil, nil, erro.Wrap(idperr.New(idperr.Access_denied, erro.Unwrap(err).Error(), http.StatusForbidden, err)) } else if err := idsTok.verify(idp.Keys()); err != nil { return "", nil, nil, erro.Wrap(idperr.New(idperr.Access_denied, erro.Unwrap(err).Error(), http.StatusForbidden, err)) } tagToAttrs = map[string]map[string]interface{}{} for acntTag := range codTok.accountTags() { attrs := idsTok.attributes()[acntTag] if attrs == nil { return "", nil, nil, erro.Wrap(idperr.New(idperr.Access_denied, "cannot get sub account tagged by "+acntTag, http.StatusForbidden, nil)) } tagToAttrs[acntTag] = attrs } if isMain { attrs := idsTok.attributes()[codTok.accountTag()] if attrs == nil { return "", nil, nil, erro.Wrap(idperr.New(idperr.Access_denied, "cannot get main account tagged by "+codTok.accountTag(), http.StatusForbidden, nil)) } tagToAttrs[codTok.accountTag()] = attrs if coopResp.token() == "" { return "", nil, nil, erro.Wrap(idperr.New(idperr.Access_denied, "cannot get token", http.StatusForbidden, nil)) } now := time.Now() tok = token.New(coopResp.token(), this.idGen.String(this.tokTagLen), now.Add(coopResp.expiresIn()), idsTok.idProvider(), coopResp.scope()) log.Info(this.logPref, "Got access token "+logutil.Mosaic(tok.Id())) if err := this.tokDb.Save(tok, now.Add(this.tokDbExpIn)); err != nil { return "", nil, nil, erro.Wrap(err) } log.Info(this.logPref, "Saved access token "+logutil.Mosaic(tok.Id())) } return idsTok.fromTa(), tok, tagToAttrs, nil }
func (this *environment) serve(w http.ResponseWriter, r *http.Request) error { req, err := parseRequest(r) if err != nil { return erro.Wrap(idperr.New(idperr.Invalid_request, erro.Unwrap(err).Error(), http.StatusBadRequest, err)) } var acntTag string tags := map[string]bool{} var refHash string type idpUnit struct { idp idpdb.Element codTok *codeToken } units := []*idpUnit{} for _, rawCodTok := range req.codeTokens() { codTok, err := parseCodeToken(rawCodTok) if err != nil { return erro.Wrap(idperr.New(idperr.Invalid_request, erro.Unwrap(err).Error(), http.StatusBadRequest, err)) } else if !codTok.audience()[this.selfId] { return erro.Wrap(idperr.New(idperr.Invalid_request, "invalid audience", http.StatusBadRequest, nil)) } else if codTok.referralHash() == "" && len(req.codeTokens()) > 1 { return erro.Wrap(idperr.New(idperr.Invalid_request, "no referral hash", http.StatusBadRequest, nil)) } else if codTok.accountTag() != "" { if acntTag != "" { return erro.Wrap(idperr.New(idperr.Invalid_request, "two main token", http.StatusBadRequest, nil)) } acntTag = codTok.accountTag() log.Debug(this.logPref, "Main account tag is "+acntTag) if codTok.fromTa() == "" { return erro.Wrap(idperr.New(idperr.Invalid_request, "no from-TA in main token", http.StatusBadRequest, nil)) } log.Debug(this.logPref, "From-TA is "+codTok.fromTa()) } else if len(codTok.accountTags()) == 0 { return erro.Wrap(idperr.New(idperr.Invalid_request, "no account tags in sub token", http.StatusBadRequest, nil)) } for tag := range codTok.accountTags() { if tags[tag] { return erro.Wrap(idperr.New(idperr.Invalid_request, "tag "+tag+" overlaps", http.StatusBadRequest, nil)) } tags[tag] = true log.Debug(this.logPref, "Account tag is "+tag) } if codTok.referralHash() != "" { if refHash == "" { refHash = codTok.referralHash() } else if codTok.referralHash() != refHash { return erro.Wrap(idperr.New(idperr.Invalid_request, "invalid referral hash", http.StatusBadRequest, nil)) } } var idp idpdb.Element if idp, err = this.idpDb.Get(codTok.idProvider()); err != nil { return erro.Wrap(err) } else if idp == nil { return erro.Wrap(idperr.New(idperr.Invalid_request, "ID provider "+codTok.idProvider()+" is not exist", http.StatusBadRequest, nil)) } log.Debug(this.logPref, "ID provider "+idp.Id()+" is exist") if err := codTok.verify(idp.Keys()); err != nil { return erro.Wrap(idperr.New(idperr.Invalid_request, erro.Unwrap(err).Error(), http.StatusBadRequest, err)) } log.Debug(this.logPref, "Verified cooperation code") units = append(units, &idpUnit{idp, codTok}) } if acntTag == "" { return erro.Wrap(idperr.New(idperr.Invalid_request, "no main account tag", http.StatusBadRequest, nil)) } log.Debug(this.logPref, "Cooperation codes are OK") var tok *token.Element var mainAttrs map[string]interface{} tagToAttrs := map[string]map[string]interface{}{} var frTa string for _, unit := range units { var tToA map[string]map[string]interface{} var fT string if unit.codTok.accountTag() != "" { fT, tok, tToA, err = this.getInfoFromMainIdProvider(unit.idp, unit.codTok) if err != nil { return erro.Wrap(err) } log.Debug(this.logPref, "Got account info from main ID provider "+unit.idp.Id()) } else { fT, tToA, err = this.getInfoFromSubIdProvider(unit.idp, unit.codTok) if err != nil { return erro.Wrap(err) } log.Debug(this.logPref, "Got account info from sub ID provider "+unit.idp.Id()) } for tag, attrs := range tToA { if tag == acntTag { attrs[tagIss] = unit.idp.Id() attrs[tagAt_tag] = tok.Tag() attrs[tagAt_exp] = tok.Expires().Unix() mainAttrs = attrs } else { attrs[tagIss] = unit.idp.Id() tagToAttrs[tag] = attrs } } if frTa == "" { frTa = fT } else if frTa != fT { return erro.Wrap(idperr.New(idperr.Invalid_request, "two from-TA ID", http.StatusBadRequest, nil)) } } log.Debug(this.logPref, "Got all account info") jt := jwt.New() jt.SetHeader(tagAlg, tagNone) for k, v := range mainAttrs { jt.SetClaim(k, v) } mainInfo, err := jt.Encode() if err != nil { return erro.Wrap(err) } var relInfo []byte sessFlag := true if len(tagToAttrs) > 0 { jt = jwt.New() jt.SetHeader(tagAlg, tagNone) for acntTag, attrs := range tagToAttrs { jt.SetClaim(acntTag, attrs) if sessFlag { for k := range attrs { if !sessionEnable(k) { sessFlag = false } } } } relInfo, err = jt.Encode() if err != nil { return erro.Wrap(err) } } w.Header().Set(tagX_auth_user, string(mainInfo)) w.Header().Set(tagX_auth_user_tag, acntTag) w.Header().Set(tagX_auth_from_id, frTa) if relInfo != nil { w.Header().Set(tagX_auth_users, string(relInfo)) } if sessFlag { sessId := this.idGen.String(this.sessLen) http.SetCookie(w, this.handler.newCookie(sessId, tok.Expires())) log.Debug(this.logPref, "Report session "+logutil.Mosaic(sessId)) } return nil }
func (this *environment) authServe(w http.ResponseWriter, r *http.Request) error { req, err := parseAuthRequest(r) if err != nil { return erro.Wrap(server.NewError(http.StatusBadRequest, erro.Unwrap(err).Error(), err)) } log.Debug(this.logPref, "Parsed authentication request") authUri := req.authUri() queries := authUri.Query() // response_type var idp string respType := map[string]bool{tagCode: true} rawAuthUri := authUri.Scheme + "://" + authUri.Host + authUri.Path if idps, err := this.idpDb.Search(map[string]string{ tagAuthorization_endpoint: "^" + rawAuthUri + "$", }); err != nil { return erro.Wrap(err) } else if len(idps) == 1 { idp = idps[0].Id() log.Debug(this.logPref, "Destination is in ID provider "+idp) } else { // ID プロバイダ選択サービスか何か。 respType[tagId_token] = true log.Debug(this.logPref, "Destination "+rawAuthUri+" is not ID provider") } queries.Set(tagResponse_type, request.ValueSetForm(respType)) // scope if scop := request.FormValueSet(queries.Get(tagScope)); !scop[tagOpenid] { scop[tagOpenid] = true queries.Set(tagScope, request.ValueSetForm(scop)) log.Debug(this.logPref, `Added scope "`+tagOpenid+`"`) } // client_id ta := queries.Get(tagClient_id) if ta == "" { ta = this.selfId queries.Set(tagClient_id, ta) log.Debug(this.logPref, "Act as default TA "+ta) } else { log.Debug(this.logPref, "Act as TA "+ta) } // redirect_uri rediUri := queries.Get(tagRedirect_uri) if rediUri == "" { rediUri = this.rediUri queries.Set(tagRedirect_uri, rediUri) log.Debug(this.logPref, "Use default redirect uri "+rediUri) } else { log.Debug(this.logPref, "Use redirect uri "+rediUri) } // state stat := this.idGen.String(this.statLen) queries.Set(tagState, stat) log.Debug(this.logPref, "Use state "+logutil.Mosaic(stat)) // nonce nonc := this.idGen.String(this.noncLen) queries.Set(tagNonce, nonc) log.Debug(this.logPref, "Use nonce "+logutil.Mosaic(nonc)) authUri.RawQuery = queries.Encode() sess := asession.New( this.idGen.String(this.sessLen), time.Now().Add(this.sessExpIn), req.path(), idp, ta, rediUri, stat, nonc, ) if err := this.sessDb.Save(sess, sess.Expires().Add(this.sessDbExpIn-this.sessExpIn)); err != nil { return erro.Wrap(err) } log.Info(this.logPref, "Generated user session "+logutil.Mosaic(sess.Id())) http.SetCookie(w, this.newCookie(sess.Id(), sess.Expires())) w.Header().Add(tagCache_control, tagNo_store) w.Header().Add(tagPragma, tagNo_cache) uri := authUri.String() http.Redirect(w, r, uri, http.StatusFound) log.Info(this.logPref, "Redirect to "+uri) return nil }