// c_hash がおかしいなら拒否できることの検査。
func TestCallbackDenyInvalidCodeHash(t *testing.T) {
	// ////////////////////////////////
	// logutil.SetupConsole(logRoot, level.ALL)
	// defer logutil.SetupConsole(logRoot, level.OFF)
	// ////////////////////////////////

	idpServ, err := newTestIdProvider([]jwk.Key{test_idpKey})
	if err != nil {
		t.Fatal(err)
	}
	defer idpServ.close()
	idp := idpServ.info()
	page := newTestPage([]jwk.Key{test_taKey}, []idpdb.Element{idp})

	now := time.Now()
	sess := asession.New(test_sessId, now.Add(page.sessExpIn), test_reqPath, "", page.selfId, page.rediUri, test_stat, test_nonc)
	page.sessDb.Save(sess, now.Add(time.Minute))

	r, err := newCallbackRequestWithIdToken(page, idp, map[string]interface{}{"c_hash": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"})
	if err != nil {
		t.Fatal(err)
	}

	w := httptest.NewRecorder()
	page.HandleCallback(w, r)

	if w.Code != http.StatusForbidden {
		t.Error(w.Code)
		t.Fatal(http.StatusForbidden)
	}
}
// 選択サービス経由のリクエストに対応できることの検査。
func TestCallbackThroughSelector(t *testing.T) {
	// ////////////////////////////////
	// logutil.SetupConsole(logRoot, level.ALL)
	// defer logutil.SetupConsole(logRoot, level.OFF)
	// ////////////////////////////////

	idpServ, err := newTestIdProvider([]jwk.Key{test_idpKey})
	if err != nil {
		t.Fatal(err)
	}
	defer idpServ.close()
	idp := idpServ.info()
	page := newTestPage([]jwk.Key{test_taKey}, []idpdb.Element{idp})

	now := time.Now()
	sess := asession.New(test_sessId, now.Add(page.sessExpIn), test_reqPath, "", page.selfId, page.rediUri, test_stat, test_nonc)
	page.sessDb.Save(sess, now.Add(time.Minute))

	r, err := newCallbackRequestWithIdToken(page, idp, nil)
	if err != nil {
		t.Fatal(err)
	}

	s1, h1, b1, err := newTestTokenResponse(page, idp, nil)
	if err != nil {
		t.Fatal(err)
	}
	idpServ.addResponse(s1, h1, b1)
	s2, h2, b2, err := newTestAccountResponse(page, idp)
	if err != nil {
		t.Fatal(err)
	}
	idpServ.addResponse(s2, h2, b2)

	w := httptest.NewRecorder()
	page.HandleCallback(w, r)

	if w.Code != http.StatusFound {
		t.Error(w.Code)
		t.Fatal(http.StatusFound)
	} else if uri, err := url.Parse(test_reqPath); err != nil {
		t.Fatal(err)
	} else if uri2, err := url.Parse(w.HeaderMap.Get("Location")); err != nil {
		t.Fatal(err)
	} else if !reflect.DeepEqual(uri2, uri) {
		t.Error(uri2)
		t.Fatal(uri)
	}
}
// nonce がおかしいなら拒否できることの検査。
func TestCallbackDenyInvalidNonce(t *testing.T) {
	// ////////////////////////////////
	// logutil.SetupConsole(logRoot, level.ALL)
	// defer logutil.SetupConsole(logRoot, level.OFF)
	// ////////////////////////////////

	idpServ, err := newTestIdProvider([]jwk.Key{test_idpKey})
	if err != nil {
		t.Fatal(err)
	}
	defer idpServ.close()
	idp := idpServ.info()
	page := newTestPage([]jwk.Key{test_taKey}, []idpdb.Element{idp})

	now := time.Now()
	sess := asession.New(test_sessId, now.Add(page.sessExpIn), test_reqPath, idp.Id(), page.selfId, page.rediUri, test_stat, test_nonc)
	page.sessDb.Save(sess, now.Add(time.Minute))

	r, err := newCallbackRequest(page)
	if err != nil {
		t.Fatal(err)
	}

	s1, h1, b1, err := newTestTokenResponse(page, idp, map[string]interface{}{"nonce": test_nonc + "a"})
	if err != nil {
		t.Fatal(err)
	}
	idpServ.addResponse(s1, h1, b1)
	s2, h2, b2, err := newTestAccountResponse(page, idp)
	if err != nil {
		t.Fatal(err)
	}
	idpServ.addResponse(s2, h2, b2)

	w := httptest.NewRecorder()
	page.HandleCallback(w, r)

	if w.Code != http.StatusForbidden {
		t.Error(w.Code)
		t.Fatal(http.StatusForbidden)
	}
}
// state がおかしいなら拒否できることの検査。
func TestCallbackDenyInvalidState(t *testing.T) {
	// ////////////////////////////////
	// logutil.SetupConsole(logRoot, level.ALL)
	// defer logutil.SetupConsole(logRoot, level.OFF)
	// ////////////////////////////////

	idpServ, err := newTestIdProvider([]jwk.Key{test_idpKey})
	if err != nil {
		t.Fatal(err)
	}
	defer idpServ.close()
	idp := idpServ.info()
	page := newTestPage([]jwk.Key{test_taKey}, []idpdb.Element{idp})

	now := time.Now()
	sess := asession.New(test_sessId, now.Add(page.sessExpIn), test_reqPath, idp.Id(), page.selfId, page.rediUri, test_stat, test_nonc)
	page.sessDb.Save(sess, now.Add(time.Minute))

	r, err := newCallbackRequest(page)
	if err != nil {
		t.Fatal(err)
	}
	{
		q := r.URL.Query()
		q.Set("state", test_stat+"a")
		r.URL.RawQuery = q.Encode()
	}

	w := httptest.NewRecorder()
	page.HandleCallback(w, r)

	if w.Code != http.StatusForbidden {
		t.Error(w.Code)
		t.Fatal(http.StatusForbidden)
	}
}
// 正常系。
// 元のリクエストパスにリダイレクトさせることの検査。
// X-Auth-User ヘッダに iss, sub, at_tag, at_exp の入った JWT を入れることの検査。
// X-Auth-User ヘッダに追加属性を入れることの検査。
func TestCallback(t *testing.T) {
	// ////////////////////////////////
	// logutil.SetupConsole(logRoot, level.ALL)
	// defer logutil.SetupConsole(logRoot, level.OFF)
	// ////////////////////////////////

	idpServ, err := newTestIdProvider([]jwk.Key{test_idpKey})
	if err != nil {
		t.Fatal(err)
	}
	defer idpServ.close()
	idp := idpServ.info()
	page := newTestPage([]jwk.Key{test_taKey}, []idpdb.Element{idp})

	now := time.Now()
	sess := asession.New(test_sessId, now.Add(page.sessExpIn), test_reqPath, idp.Id(), page.selfId, page.rediUri, test_stat, test_nonc)
	page.sessDb.Save(sess, now.Add(time.Minute))

	r, err := newCallbackRequest(page)
	if err != nil {
		t.Fatal(err)
	}

	s1, h1, b1, err := newTestTokenResponse(page, idp, nil)
	if err != nil {
		t.Fatal(err)
	}
	req1Ch := idpServ.addResponse(s1, h1, b1)
	s2, h2, b2, err := newTestAccountResponse(page, idp)
	if err != nil {
		t.Fatal(err)
	}
	req2Ch := idpServ.addResponse(s2, h2, b2)

	w := httptest.NewRecorder()
	page.HandleCallback(w, r)

	select {
	case req := <-req1Ch:
		if contType, contType2 := "application/x-www-form-urlencoded", req.Header.Get("Content-Type"); contType2 != contType {
			t.Error(contType)
			t.Fatal(contType2)
		} else if grntType, grntType2 := "authorization_code", req.FormValue("grant_type"); grntType2 != grntType {
			t.Error(grntType)
			t.Fatal(grntType2)
		} else if cod := req.FormValue("code"); cod != test_cod {
			t.Error(cod)
			t.Fatal(test_cod)
		} else if rediUri := req.FormValue("redirect_uri"); rediUri != page.rediUri {
			t.Error(rediUri)
			t.Fatal(page.rediUri)
		} else if taId := req.FormValue("client_id"); taId != page.selfId {
			t.Error(taId)
			t.Fatal(page.selfId)
		} else if assType, assType2 := "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", req.FormValue("client_assertion_type"); assType2 != assType {
			t.Error(assType2)
			t.Fatal(assType)
		}
		ass, err := jwt.Parse([]byte(req.FormValue("client_assertion")))
		if err != nil {
			t.Fatal(err)
		} else if !ass.IsSigned() {
			t.Fatal("not signed")
		} else if err := ass.Verify([]jwk.Key{test_taKey}); err != nil {
			t.Fatal(err)
		}
		var buff struct {
			Iss string
			Sub string
			Aud audience.Audience
			Jti string
			Exp int
			Iat int
		}
		if err := json.Unmarshal(ass.RawBody(), &buff); err != nil {
			t.Fatal(err)
		} else if buff.Iss != page.selfId {
			t.Error(buff.Iss)
			t.Fatal(page.selfId)
		} else if buff.Sub != page.selfId {
			t.Error(buff.Sub)
			t.Fatal(page.selfId)
		} else if !buff.Aud[idp.TokenUri()] {
			t.Error(buff.Aud)
			t.Fatal(idp.TokenUri())
		} else if len(buff.Jti) != page.jtiLen {
			t.Error(len(buff.Jti), " "+buff.Jti)
			t.Fatal(page.jtiLen)
		} else if buff.Exp == 0 {
			t.Fatal("no exp")
		} else if buff.Iat == 0 {
			t.Fatal("no iat")
		} else if !(buff.Iat < buff.Exp) {
			t.Error("exp not after iat")
			t.Error(buff.Iat)
			t.Fatal(buff.Exp)
		}
	case <-time.After(time.Minute):
		t.Fatal("no request")
	}

	select {
	case req := <-req2Ch:
		if auth := strings.Fields(req.Header.Get("Authorization")); len(auth) != 2 {
			t.Error("not 2 fields")
			t.Fatal(auth)
		} else if auth[0] != "Bearer" {
			t.Error(auth[0])
			t.Fatal("Bearer")
		} else if auth[1] != test_tok {
			t.Error(auth[1])
			t.Fatal(test_tok)
		}
	case <-time.After(time.Minute):
		t.Fatal("no request")
	}

	if w.Code != http.StatusFound {
		t.Error(w.Code)
		t.Fatal(http.StatusFound)
	} else if uri, err := url.Parse(test_reqPath); err != nil {
		t.Fatal(err)
	} else if uri2, err := url.Parse(w.HeaderMap.Get("Location")); err != nil {
		t.Fatal(err)
	} else if !reflect.DeepEqual(uri2, uri) {
		t.Error(uri2)
		t.Fatal(uri)
	}
	var buff struct {
		Iss    string
		Sub    string
		At_tag string
		At_exp int
		Email  string
	}
	if jt, err := jwt.Parse([]byte(w.HeaderMap.Get("X-Auth-User"))); err != nil {
		t.Fatal(err)
	} else if err := json.Unmarshal(jt.RawBody(), &buff); err != nil {
		t.Fatal(err)
	} else if buff.Iss != idp.Id() {
		t.Error(buff.Iss)
		t.Fatal(idp.Id())
	} else if buff.Sub != test_acntId {
		t.Error(buff.Sub)
		t.Fatal(test_acntId)
	} else if len(buff.At_tag) != page.tokTagLen {
		t.Error(len(buff.At_tag), buff.At_tag)
		t.Fatal(page.tokTagLen)
	} else if buff.At_exp == 0 {
		t.Fatal("no at_exp")
	} else if buff.Email != test_acntEmail {
		t.Error(buff.Email)
		t.Fatal(test_acntEmail)
	}
}
示例#6
0
func (this *environment) authServe(w http.ResponseWriter, r *http.Request) error {
	req, err := parseAuthRequest(r)
	if err != nil {
		return erro.Wrap(server.NewError(http.StatusBadRequest, erro.Unwrap(err).Error(), err))
	}

	log.Debug(this.logPref, "Parsed authentication request")

	authUri := req.authUri()
	queries := authUri.Query()

	// response_type
	var idp string
	respType := map[string]bool{tagCode: true}
	rawAuthUri := authUri.Scheme + "://" + authUri.Host + authUri.Path
	if idps, err := this.idpDb.Search(map[string]string{
		tagAuthorization_endpoint: "^" + rawAuthUri + "$",
	}); err != nil {
		return erro.Wrap(err)
	} else if len(idps) == 1 {
		idp = idps[0].Id()
		log.Debug(this.logPref, "Destination is in ID provider "+idp)
	} else {
		// ID プロバイダ選択サービスか何か。
		respType[tagId_token] = true
		log.Debug(this.logPref, "Destination "+rawAuthUri+" is not ID provider")
	}
	queries.Set(tagResponse_type, request.ValueSetForm(respType))

	// scope
	if scop := request.FormValueSet(queries.Get(tagScope)); !scop[tagOpenid] {
		scop[tagOpenid] = true
		queries.Set(tagScope, request.ValueSetForm(scop))
		log.Debug(this.logPref, `Added scope "`+tagOpenid+`"`)
	}

	// client_id
	ta := queries.Get(tagClient_id)
	if ta == "" {
		ta = this.selfId
		queries.Set(tagClient_id, ta)
		log.Debug(this.logPref, "Act as default TA "+ta)
	} else {
		log.Debug(this.logPref, "Act as TA "+ta)
	}

	// redirect_uri
	rediUri := queries.Get(tagRedirect_uri)
	if rediUri == "" {
		rediUri = this.rediUri
		queries.Set(tagRedirect_uri, rediUri)
		log.Debug(this.logPref, "Use default redirect uri "+rediUri)
	} else {
		log.Debug(this.logPref, "Use redirect uri "+rediUri)
	}

	// state
	stat := this.idGen.String(this.statLen)
	queries.Set(tagState, stat)
	log.Debug(this.logPref, "Use state "+logutil.Mosaic(stat))

	// nonce
	nonc := this.idGen.String(this.noncLen)
	queries.Set(tagNonce, nonc)
	log.Debug(this.logPref, "Use nonce "+logutil.Mosaic(nonc))

	authUri.RawQuery = queries.Encode()

	sess := asession.New(
		this.idGen.String(this.sessLen),
		time.Now().Add(this.sessExpIn),
		req.path(),
		idp,
		ta,
		rediUri,
		stat,
		nonc,
	)
	if err := this.sessDb.Save(sess, sess.Expires().Add(this.sessDbExpIn-this.sessExpIn)); err != nil {
		return erro.Wrap(err)
	}
	log.Info(this.logPref, "Generated user session "+logutil.Mosaic(sess.Id()))

	http.SetCookie(w, this.newCookie(sess.Id(), sess.Expires()))
	w.Header().Add(tagCache_control, tagNo_store)
	w.Header().Add(tagPragma, tagNo_cache)

	uri := authUri.String()
	http.Redirect(w, r, uri, http.StatusFound)
	log.Info(this.logPref, "Redirect to "+uri)
	return nil
}