// c_hash がおかしいなら拒否できることの検査。 func TestCallbackDenyInvalidCodeHash(t *testing.T) { // //////////////////////////////// // logutil.SetupConsole(logRoot, level.ALL) // defer logutil.SetupConsole(logRoot, level.OFF) // //////////////////////////////// idpServ, err := newTestIdProvider([]jwk.Key{test_idpKey}) if err != nil { t.Fatal(err) } defer idpServ.close() idp := idpServ.info() page := newTestPage([]jwk.Key{test_taKey}, []idpdb.Element{idp}) now := time.Now() sess := asession.New(test_sessId, now.Add(page.sessExpIn), test_reqPath, "", page.selfId, page.rediUri, test_stat, test_nonc) page.sessDb.Save(sess, now.Add(time.Minute)) r, err := newCallbackRequestWithIdToken(page, idp, map[string]interface{}{"c_hash": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"}) if err != nil { t.Fatal(err) } w := httptest.NewRecorder() page.HandleCallback(w, r) if w.Code != http.StatusForbidden { t.Error(w.Code) t.Fatal(http.StatusForbidden) } }
// 選択サービス経由のリクエストに対応できることの検査。 func TestCallbackThroughSelector(t *testing.T) { // //////////////////////////////// // logutil.SetupConsole(logRoot, level.ALL) // defer logutil.SetupConsole(logRoot, level.OFF) // //////////////////////////////// idpServ, err := newTestIdProvider([]jwk.Key{test_idpKey}) if err != nil { t.Fatal(err) } defer idpServ.close() idp := idpServ.info() page := newTestPage([]jwk.Key{test_taKey}, []idpdb.Element{idp}) now := time.Now() sess := asession.New(test_sessId, now.Add(page.sessExpIn), test_reqPath, "", page.selfId, page.rediUri, test_stat, test_nonc) page.sessDb.Save(sess, now.Add(time.Minute)) r, err := newCallbackRequestWithIdToken(page, idp, nil) if err != nil { t.Fatal(err) } s1, h1, b1, err := newTestTokenResponse(page, idp, nil) if err != nil { t.Fatal(err) } idpServ.addResponse(s1, h1, b1) s2, h2, b2, err := newTestAccountResponse(page, idp) if err != nil { t.Fatal(err) } idpServ.addResponse(s2, h2, b2) w := httptest.NewRecorder() page.HandleCallback(w, r) if w.Code != http.StatusFound { t.Error(w.Code) t.Fatal(http.StatusFound) } else if uri, err := url.Parse(test_reqPath); err != nil { t.Fatal(err) } else if uri2, err := url.Parse(w.HeaderMap.Get("Location")); err != nil { t.Fatal(err) } else if !reflect.DeepEqual(uri2, uri) { t.Error(uri2) t.Fatal(uri) } }
// nonce がおかしいなら拒否できることの検査。 func TestCallbackDenyInvalidNonce(t *testing.T) { // //////////////////////////////// // logutil.SetupConsole(logRoot, level.ALL) // defer logutil.SetupConsole(logRoot, level.OFF) // //////////////////////////////// idpServ, err := newTestIdProvider([]jwk.Key{test_idpKey}) if err != nil { t.Fatal(err) } defer idpServ.close() idp := idpServ.info() page := newTestPage([]jwk.Key{test_taKey}, []idpdb.Element{idp}) now := time.Now() sess := asession.New(test_sessId, now.Add(page.sessExpIn), test_reqPath, idp.Id(), page.selfId, page.rediUri, test_stat, test_nonc) page.sessDb.Save(sess, now.Add(time.Minute)) r, err := newCallbackRequest(page) if err != nil { t.Fatal(err) } s1, h1, b1, err := newTestTokenResponse(page, idp, map[string]interface{}{"nonce": test_nonc + "a"}) if err != nil { t.Fatal(err) } idpServ.addResponse(s1, h1, b1) s2, h2, b2, err := newTestAccountResponse(page, idp) if err != nil { t.Fatal(err) } idpServ.addResponse(s2, h2, b2) w := httptest.NewRecorder() page.HandleCallback(w, r) if w.Code != http.StatusForbidden { t.Error(w.Code) t.Fatal(http.StatusForbidden) } }
// state がおかしいなら拒否できることの検査。 func TestCallbackDenyInvalidState(t *testing.T) { // //////////////////////////////// // logutil.SetupConsole(logRoot, level.ALL) // defer logutil.SetupConsole(logRoot, level.OFF) // //////////////////////////////// idpServ, err := newTestIdProvider([]jwk.Key{test_idpKey}) if err != nil { t.Fatal(err) } defer idpServ.close() idp := idpServ.info() page := newTestPage([]jwk.Key{test_taKey}, []idpdb.Element{idp}) now := time.Now() sess := asession.New(test_sessId, now.Add(page.sessExpIn), test_reqPath, idp.Id(), page.selfId, page.rediUri, test_stat, test_nonc) page.sessDb.Save(sess, now.Add(time.Minute)) r, err := newCallbackRequest(page) if err != nil { t.Fatal(err) } { q := r.URL.Query() q.Set("state", test_stat+"a") r.URL.RawQuery = q.Encode() } w := httptest.NewRecorder() page.HandleCallback(w, r) if w.Code != http.StatusForbidden { t.Error(w.Code) t.Fatal(http.StatusForbidden) } }
// 正常系。 // 元のリクエストパスにリダイレクトさせることの検査。 // X-Auth-User ヘッダに iss, sub, at_tag, at_exp の入った JWT を入れることの検査。 // X-Auth-User ヘッダに追加属性を入れることの検査。 func TestCallback(t *testing.T) { // //////////////////////////////// // logutil.SetupConsole(logRoot, level.ALL) // defer logutil.SetupConsole(logRoot, level.OFF) // //////////////////////////////// idpServ, err := newTestIdProvider([]jwk.Key{test_idpKey}) if err != nil { t.Fatal(err) } defer idpServ.close() idp := idpServ.info() page := newTestPage([]jwk.Key{test_taKey}, []idpdb.Element{idp}) now := time.Now() sess := asession.New(test_sessId, now.Add(page.sessExpIn), test_reqPath, idp.Id(), page.selfId, page.rediUri, test_stat, test_nonc) page.sessDb.Save(sess, now.Add(time.Minute)) r, err := newCallbackRequest(page) if err != nil { t.Fatal(err) } s1, h1, b1, err := newTestTokenResponse(page, idp, nil) if err != nil { t.Fatal(err) } req1Ch := idpServ.addResponse(s1, h1, b1) s2, h2, b2, err := newTestAccountResponse(page, idp) if err != nil { t.Fatal(err) } req2Ch := idpServ.addResponse(s2, h2, b2) w := httptest.NewRecorder() page.HandleCallback(w, r) select { case req := <-req1Ch: if contType, contType2 := "application/x-www-form-urlencoded", req.Header.Get("Content-Type"); contType2 != contType { t.Error(contType) t.Fatal(contType2) } else if grntType, grntType2 := "authorization_code", req.FormValue("grant_type"); grntType2 != grntType { t.Error(grntType) t.Fatal(grntType2) } else if cod := req.FormValue("code"); cod != test_cod { t.Error(cod) t.Fatal(test_cod) } else if rediUri := req.FormValue("redirect_uri"); rediUri != page.rediUri { t.Error(rediUri) t.Fatal(page.rediUri) } else if taId := req.FormValue("client_id"); taId != page.selfId { t.Error(taId) t.Fatal(page.selfId) } else if assType, assType2 := "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", req.FormValue("client_assertion_type"); assType2 != assType { t.Error(assType2) t.Fatal(assType) } ass, err := jwt.Parse([]byte(req.FormValue("client_assertion"))) if err != nil { t.Fatal(err) } else if !ass.IsSigned() { t.Fatal("not signed") } else if err := ass.Verify([]jwk.Key{test_taKey}); err != nil { t.Fatal(err) } var buff struct { Iss string Sub string Aud audience.Audience Jti string Exp int Iat int } if err := json.Unmarshal(ass.RawBody(), &buff); err != nil { t.Fatal(err) } else if buff.Iss != page.selfId { t.Error(buff.Iss) t.Fatal(page.selfId) } else if buff.Sub != page.selfId { t.Error(buff.Sub) t.Fatal(page.selfId) } else if !buff.Aud[idp.TokenUri()] { t.Error(buff.Aud) t.Fatal(idp.TokenUri()) } else if len(buff.Jti) != page.jtiLen { t.Error(len(buff.Jti), " "+buff.Jti) t.Fatal(page.jtiLen) } else if buff.Exp == 0 { t.Fatal("no exp") } else if buff.Iat == 0 { t.Fatal("no iat") } else if !(buff.Iat < buff.Exp) { t.Error("exp not after iat") t.Error(buff.Iat) t.Fatal(buff.Exp) } case <-time.After(time.Minute): t.Fatal("no request") } select { case req := <-req2Ch: if auth := strings.Fields(req.Header.Get("Authorization")); len(auth) != 2 { t.Error("not 2 fields") t.Fatal(auth) } else if auth[0] != "Bearer" { t.Error(auth[0]) t.Fatal("Bearer") } else if auth[1] != test_tok { t.Error(auth[1]) t.Fatal(test_tok) } case <-time.After(time.Minute): t.Fatal("no request") } if w.Code != http.StatusFound { t.Error(w.Code) t.Fatal(http.StatusFound) } else if uri, err := url.Parse(test_reqPath); err != nil { t.Fatal(err) } else if uri2, err := url.Parse(w.HeaderMap.Get("Location")); err != nil { t.Fatal(err) } else if !reflect.DeepEqual(uri2, uri) { t.Error(uri2) t.Fatal(uri) } var buff struct { Iss string Sub string At_tag string At_exp int Email string } if jt, err := jwt.Parse([]byte(w.HeaderMap.Get("X-Auth-User"))); err != nil { t.Fatal(err) } else if err := json.Unmarshal(jt.RawBody(), &buff); err != nil { t.Fatal(err) } else if buff.Iss != idp.Id() { t.Error(buff.Iss) t.Fatal(idp.Id()) } else if buff.Sub != test_acntId { t.Error(buff.Sub) t.Fatal(test_acntId) } else if len(buff.At_tag) != page.tokTagLen { t.Error(len(buff.At_tag), buff.At_tag) t.Fatal(page.tokTagLen) } else if buff.At_exp == 0 { t.Fatal("no at_exp") } else if buff.Email != test_acntEmail { t.Error(buff.Email) t.Fatal(test_acntEmail) } }
func (this *environment) authServe(w http.ResponseWriter, r *http.Request) error { req, err := parseAuthRequest(r) if err != nil { return erro.Wrap(server.NewError(http.StatusBadRequest, erro.Unwrap(err).Error(), err)) } log.Debug(this.logPref, "Parsed authentication request") authUri := req.authUri() queries := authUri.Query() // response_type var idp string respType := map[string]bool{tagCode: true} rawAuthUri := authUri.Scheme + "://" + authUri.Host + authUri.Path if idps, err := this.idpDb.Search(map[string]string{ tagAuthorization_endpoint: "^" + rawAuthUri + "$", }); err != nil { return erro.Wrap(err) } else if len(idps) == 1 { idp = idps[0].Id() log.Debug(this.logPref, "Destination is in ID provider "+idp) } else { // ID プロバイダ選択サービスか何か。 respType[tagId_token] = true log.Debug(this.logPref, "Destination "+rawAuthUri+" is not ID provider") } queries.Set(tagResponse_type, request.ValueSetForm(respType)) // scope if scop := request.FormValueSet(queries.Get(tagScope)); !scop[tagOpenid] { scop[tagOpenid] = true queries.Set(tagScope, request.ValueSetForm(scop)) log.Debug(this.logPref, `Added scope "`+tagOpenid+`"`) } // client_id ta := queries.Get(tagClient_id) if ta == "" { ta = this.selfId queries.Set(tagClient_id, ta) log.Debug(this.logPref, "Act as default TA "+ta) } else { log.Debug(this.logPref, "Act as TA "+ta) } // redirect_uri rediUri := queries.Get(tagRedirect_uri) if rediUri == "" { rediUri = this.rediUri queries.Set(tagRedirect_uri, rediUri) log.Debug(this.logPref, "Use default redirect uri "+rediUri) } else { log.Debug(this.logPref, "Use redirect uri "+rediUri) } // state stat := this.idGen.String(this.statLen) queries.Set(tagState, stat) log.Debug(this.logPref, "Use state "+logutil.Mosaic(stat)) // nonce nonc := this.idGen.String(this.noncLen) queries.Set(tagNonce, nonc) log.Debug(this.logPref, "Use nonce "+logutil.Mosaic(nonc)) authUri.RawQuery = queries.Encode() sess := asession.New( this.idGen.String(this.sessLen), time.Now().Add(this.sessExpIn), req.path(), idp, ta, rediUri, stat, nonc, ) if err := this.sessDb.Save(sess, sess.Expires().Add(this.sessDbExpIn-this.sessExpIn)); err != nil { return erro.Wrap(err) } log.Info(this.logPref, "Generated user session "+logutil.Mosaic(sess.Id())) http.SetCookie(w, this.newCookie(sess.Id(), sess.Expires())) w.Header().Add(tagCache_control, tagNo_store) w.Header().Add(tagPragma, tagNo_cache) uri := authUri.String() http.Redirect(w, r, uri, http.StatusFound) log.Info(this.logPref, "Redirect to "+uri) return nil }