Example #1
0
// TestContextBindAndValidate checks that BindAndValidate leaves bound
// parameters on the request context under ctxBoundParams, and that calling it
// a second time for the same request yields the same data and result.
func TestContextBindAndValidate(t *testing.T) {
	spec, api := petstore.NewAPI(t)
	ctx := NewContext(spec, api, nil)
	ctx.router = DefaultRouter(spec, ctx.api)

	req, _ := http.NewRequest("POST", "/pets", nil)
	req.Header.Add("Accept", "*/*")
	req.Header.Add("content-type", "text/html")

	// Nothing is bound before the first call.
	bound, found := context.GetOk(req, ctxBoundParams)
	assert.False(t, found)
	assert.Nil(t, bound)

	route, _ := ctx.RouteInfo(req)
	// NOTE(review): this is only a smoke test; binding needs much more thorough coverage.
	firstData, firstResult := ctx.BindAndValidate(req, route)
	assert.NotNil(t, firstData)
	assert.NotNil(t, firstResult)

	// After the first call the bound params are present on the context.
	bound, found = context.GetOk(req, ctxBoundParams)
	assert.True(t, found)
	assert.NotNil(t, bound)

	// A repeat call must reproduce the first call's outcome.
	secondData, secondResult := ctx.BindAndValidate(req, route)
	assert.Equal(t, firstData, secondData)
	assert.Equal(t, firstResult, secondResult)
}
Example #2
0
// TestContextInvalidResponseFormat checks that ResponseFormat returns an empty
// media type when the request's Accept header (application/x-yaml) matches
// none of the offered formats, and that such a miss is NOT stored on the
// request context under ctxResponseFormat.
func TestContextInvalidResponseFormat(t *testing.T) {
	ct := "application/x-yaml"
	other := "application/sgml"
	spec, api := petstore.NewAPI(t)
	ctx := NewContext(spec, api, nil)
	ctx.router = DefaultRouter(spec, ctx.api)

	request, err := http.NewRequest("GET", "http://localhost:8080", nil)
	if err != nil {
		t.Fatal(err)
	}
	request.Header.Set(httpkit.HeaderAccept, ct)

	// Nothing is stored on the context before negotiation.
	cached, ok := context.GetOk(request, ctxResponseFormat)
	assert.False(t, ok)
	assert.Empty(t, cached)

	// Negotiate against a format the client does not accept.
	mt := ctx.ResponseFormat(request, []string{other})
	assert.Empty(t, mt)

	// A failed negotiation must not be stored on the context.
	cached, ok = context.GetOk(request, ctxResponseFormat)
	assert.False(t, ok)
	assert.Empty(t, cached)

	// Repeating the negotiation still yields no match.
	mt = ctx.ResponseFormat(request, []string{other})
	assert.Empty(t, mt)
}
Example #3
0
// TestContextValidContentType checks that ContentType parses the request's
// Content-Type header, that the outcome is then present on the request
// context under ctxContentType, and that a repeat call returns the same type.
func TestContextValidContentType(t *testing.T) {
	const want = "application/json"
	c := NewContext(nil, nil, nil)

	req, _ := http.NewRequest("GET", "http://localhost:8080", nil)
	req.Header.Set(httpkit.HeaderContentType, want)

	// Nothing is stored before the first parse.
	_, found := context.GetOk(req, ctxContentType)
	assert.False(t, found)

	// The first call parses the header.
	got, _, err := c.ContentType(req)
	assert.NoError(t, err)
	assert.Equal(t, want, got)

	// The parse outcome is now present on the context.
	_, found = context.GetOk(req, ctxContentType)
	assert.True(t, found)

	// A second call (presumably served from the stored value) must agree.
	got, _, err = c.ContentType(req)
	assert.NoError(t, err)
	assert.Equal(t, want, got)
}
Example #4
0
// keepContextTestWriter is a no-op http.ResponseWriter used by TestKeepContext.
// It replaces a nil writer so the router can touch headers or write a response
// without panicking.
type keepContextTestWriter struct {
	header http.Header
}

func (w *keepContextTestWriter) Header() http.Header {
	if w.header == nil {
		w.header = make(http.Header)
	}
	return w.header
}

func (w *keepContextTestWriter) Write(p []byte) (int, error) { return len(p), nil }

func (w *keepContextTestWriter) WriteHeader(int) {}

// TestKeepContext tests that the request context is cleared, or kept, at the
// end of a request depending on the router's KeepContext setting.
func TestKeepContext(t *testing.T) {
	func1 := func(w http.ResponseWriter, r *http.Request) {}

	r := NewRouter()
	r.HandleFunc("/", func1).Name("func1")

	// newRequest builds a request carrying the context value "t" = 1.
	newRequest := func() *http.Request {
		req, err := http.NewRequest("GET", "http://localhost/", nil)
		if err != nil {
			t.Fatal(err)
		}
		context.Set(req, "t", 1)
		return req
	}
	res := &keepContextTestWriter{}

	// Default: the router clears the request's context when it finishes.
	req := newRequest()
	r.ServeHTTP(res, req)
	if _, ok := context.GetOk(req, "t"); ok {
		t.Error("Context should have been cleared at end of request")
	}

	// With KeepContext set, the context must survive the request.
	r.KeepContext = true
	req = newRequest()
	r.ServeHTTP(res, req)
	if _, ok := context.GetOk(req, "t"); !ok {
		t.Error("Context should NOT have been cleared at end of request")
	}
}
Example #5
0
// currentServicePipeID returns the service and pipe IDs stored on the request
// context under serviceIDKey and pipeIDKey. An ID that is missing, or that is
// stored with a type other than string, is returned as "" instead of panicking.
func currentServicePipeID(r *http.Request) (string, string) {
	// context.Get yields nil for a missing key, and the comma-ok assertion on
	// nil (or on a non-string) gives the zero value "".
	serviceID, _ := context.Get(r, serviceIDKey).(string)
	pipeID, _ := context.Get(r, pipeIDKey).(string)
	return serviceID, pipeID
}
Example #6
0
// GetFeatures builds a payload from the request context: the value stored
// under "features" (emitted as "features") and the value stored under
// "api_key" (emitted as "identifier"), plus the current UTC time formatted as
// RFC 3339 with nanoseconds under "timestamp". Context keys that are not set
// are omitted from the payload.
func GetFeatures(r *http.Request) map[string]interface{} {
	payload := make(map[string]interface{})

	// Each pair is {context key, output key}.
	for _, pair := range [...][2]string{
		{"features", "features"},
		{"api_key", "identifier"},
	} {
		if v, ok := context.GetOk(r, pair[0]); ok {
			payload[pair[1]] = v
		}
	}

	payload["timestamp"] = time.Now().UTC().Format(time.RFC3339Nano)
	return payload
}
Example #7
0
// GetCallID gets the current call ID (if any) from the request context.
// The boolean is false when no ID is stored under callID or the stored value
// is not an int64; the ID is then 0.
func GetCallID(r *http.Request) (int64, bool) {
	id, ok := context.Get(r, callID).(int64)
	return id, ok
}
Example #8
0
// Authorize authenticates the request against the route's security schemes.
//
// A route without authenticators needs no principal and yields (nil, nil).
// A principal already stored on the request context under
// ctxSecurityPrincipal is returned unchanged. Otherwise the authenticators are
// tried in map-iteration order (unspecified in Go). The first one that applies
// and returns a non-nil principal without error wins: its principal and the
// scheme's scopes are stored on the request context and the principal is
// returned. If none succeeds, the error of the last failing authenticator
// visited is returned, or an Unauthenticated("invalid credentials") error when
// no authenticator reported an error.
func (c *Context) Authorize(request *http.Request, route *MatchedRoute) (interface{}, error) {
	if len(route.Authenticators) == 0 {
		return nil, nil
	}
	if principal, cached := context.GetOk(request, ctxSecurityPrincipal); cached {
		return principal, nil
	}

	var failure error
	for scheme, auth := range route.Authenticators {
		applies, principal, err := auth.Authenticate(&security.ScopedAuthRequest{
			Request:        request,
			RequiredScopes: route.Scopes[scheme],
		})
		if err != nil {
			// Remember the failure but keep trying the remaining schemes.
			failure = err
			continue
		}
		if !applies || principal == nil {
			continue
		}
		context.Set(request, ctxSecurityPrincipal, principal)
		context.Set(request, ctxSecurityScopes, route.Scopes[scheme])
		return principal, nil
	}

	if failure == nil {
		return nil, errors.Unauthenticated("invalid credentials")
	}
	return nil, failure
}
Example #9
0
// contextGet returns the value stored on the request context under key, or an
// error naming the key when nothing is stored there.
func contextGet(r *http.Request, key string) (interface{}, error) {
	val, ok := context.GetOk(r, key)
	if !ok {
		return nil, errors.Errorf("no value exists in the context for key %q", key)
	}
	return val, nil
}
Example #10
0
// userContext returns the *account.Account stored on the request context under
// userAccountKey. The boolean is false when nothing is stored there or the
// stored value is not an *account.Account.
func userContext(r *http.Request) (*account.Account, bool) {
	acct, ok := context.Get(r, userAccountKey).(*account.Account)
	return acct, ok
}
Example #11
0
// HandleSnapshots serves, via ServeJSON, the snapshot returned by
// hubspider.FindSnapshotByTime. The lookup time comes from the "date" query
// parameter parsed with timeLayout. The current time is used when the
// parameter is absent or does not parse.
//
// Responses:
//   - 500 if no *mgo.Database is stored on the request context under "DB";
//   - 404 if the lookup returns mgo.ErrNotFound;
//   - 500 on any other lookup error.
func HandleSnapshots(w http.ResponseWriter, r *http.Request) {
	// A checked assertion turns a missing or mistyped "DB" value into a 500
	// instead of a panic.
	db, ok := context.Get(r, "DB").(*mgo.Database)
	if !ok {
		RespondWithError(w, r, errors.New("Could'nt obtain database"), http.StatusInternalServerError)
		return
	}

	when := time.Now()
	if vals, ok := r.URL.Query()["date"]; ok {
		if parsed, err := time.Parse(timeLayout, vals[0]); err == nil {
			when = parsed
		}
	}

	snapshot, err := hubspider.FindSnapshotByTime(db, when)
	switch {
	case err == mgo.ErrNotFound:
		RespondWithError(w, r, err, http.StatusNotFound)
		return
	case err != nil:
		RespondWithError(w, r, err, http.StatusInternalServerError)
		return
	}
	snapshot.ServeJSON(w, r)
}
Example #12
0
// ServeHTTP opens the "./middleware-test.db" sqlite3 database and stores the
// *sql.DB on the request context under consts.DB_KEY. It then calls the next
// handler, if any. The handle is closed when ServeHTTP returns, i.e. after the
// rest of the chain has run.
//
// It responds 500 with a small JSON error body if consts.DB_KEY is already set
// on the request or if sql.Open fails.
func (h *DBConn) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	const errJSON = `{"Error":"Internal Server Error"}`

	// fail logs its arguments and writes a 500 with the JSON error body.
	// http.Error is not used because it labels the body text/plain.
	fail := func(logArgs ...interface{}) {
		logger.Error.Println(logArgs...)
		w.Header().Set("Content-Type", "application/json")
		w.Header().Set("X-Content-Type-Options", "nosniff")
		w.WriteHeader(http.StatusInternalServerError)
		// Nothing useful can be done if writing to the client fails.
		_, _ = w.Write([]byte(errJSON + "\n"))
	}

	logger.Info.Println("Setting up the database connection")
	if _, exists := context.GetOk(r, consts.DB_KEY); exists {
		fail("DBConnector middleware error. DB_KEY already set.")
		return
	}

	// https://github.com/mattn/go-sqlite3/blob/master/_example/simple/simple.go
	// sql.Open may only validate its arguments; connections are made lazily.
	db, err := sql.Open("sqlite3", "./middleware-test.db")
	if err != nil {
		fail("DBConnector database connection problem. Check Logs.", err)
		return
	}
	// Close the database connection once the middleware chain is complete.
	defer db.Close()

	context.Set(r, consts.DB_KEY, db)
	if h.next != nil {
		h.next.ServeHTTP(w, r)
	}
}
Example #13
0
func main() {
	middle := interpose.New()

	// Clear gorilla/context values for each request once it has been served.
	middle.Use(context.ClearHandler)

	// Create a middleware that stores a global visit counter on each request.
	// The counter increments until the server is shut down. A single goroutine
	// owns the count and hands successive values out over an unbuffered
	// channel, so concurrent requests never race on it.
	middle.Use(func() func(http.Handler) http.Handler {
		visits := make(chan int)
		go func() {
			// Runs for the life of the process.
			for n := 1; ; n++ {
				visits <- n
			}
		}()

		return func(next http.Handler) http.Handler {
			return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
				context.Set(req, CountKey, <-visits)
				next.ServeHTTP(w, req)
			})
		}
	}())

	// Apply the router.
	router := mux.NewRouter()
	router.HandleFunc("/test/{user}", func(w http.ResponseWriter, req *http.Request) {
		c, ok := context.GetOk(req, CountKey)
		if !ok {
			fmt.Println("Context not ok")
		}

		fmt.Fprintf(w, "Hi %s, this is visit #%d to the site since the server was last rebooted.", mux.Vars(req)["user"], c)
	})
	middle.UseHandler(router)

	// ListenAndServe always returns a non-nil error; report it instead of
	// exiting silently.
	if err := http.ListenAndServe(":3001", middle); err != nil {
		fmt.Println(err)
	}
}
Example #14
0
// GetPost returns the Post bound from POST data and stored on the request
// context under "post". It returns an error, rather than panicking, when the
// value is missing or is not a Post.
func GetPost(r *http.Request) (Post, error) {
	rv, ok := context.GetOk(r, "post")
	if !ok {
		return Post{}, errors.New("context not set")
	}
	post, ok := rv.(Post)
	if !ok {
		return Post{}, errors.New("context value is not a Post")
	}
	return post, nil
}
Example #15
0
// HandleWith constructs a handler and registers it on the mux router for the
// given method and path.
//
// It does the following to each request:
//
// Request body JSON:
//   - If blank, fills the body with {}.
//   - Validates it against the schema provided
//     (default: dozy.PresetJsonSchemaBlank).
//
// Query strings:
//   - Converts all query strings, except authtoken, to a
//     map[string][]string.
//
// Rate limiting:
//   - Rate-limits requests that carry an authtoken in a URL query
//     string called "authtoken".
//   - Because only requests with "authtoken" are rate-limited, it is
//     especially important to use the .Query() function provided by
//     gorilla/mux as much as possible.
//
// Other:
//   - Writes the data specified by the user's handler function,
//     ensuring that only allowed HTTP response codes are used and that
//     no response is sent on error.
//   - Sets the "Content-Type" header to "application/json".
//   - Manages the Location header on a dozy.StatusPostOk return value.
//   - Ensures that the Content-Type sent by the client is correct.
//
// This method **must** be used with a dozy.ServeMux()-handled mux router.
// dozy.ServeMux() attaches some necessary data values to the request, which
// HandleWith() and HandleSpecialCost() need in order to work properly.
// However, you can use a non-dozy.HandleWith() method with ServeMux(), losing
// only the functionality added by HandleWith() itself.
//
// Panics:
//   - At registration time, if method is GET or DELETE and a body schema
//     other than PresetJsonSchemaBlank was set.
//   - While serving a request, if h.customCost is set but no non-nil
//     *Settings is stored on the request context under
//     dozySettingsKey("settings").
func (h handlerBuilder) HandleWith(method, path string, userHandler UserHandler) {
	method = strings.ToUpper(method)
	if (method == "GET" || method == "DELETE") && h.bodySchema != PresetJsonSchemaBlank {
		panic("Method type " + method + " should not have a body set!")
	}

	builtHandler := func(rw http.ResponseWriter, req *http.Request) {
		cost := h.cost
		// NOTE(review): when customCost is set the cost comes from the
		// ServeMux settings' DefaultCost rather than h.cost; confirm this is
		// the intended meaning of customCost.
		if h.customCost {
			// Missing settings mean the router was not set up through
			// dozy.ServeMux, which is a programmer error, hence the panics.
			rawSettings, found := context.GetOk(req, dozySettingsKey("settings"))
			if !found {
				panic("rawSettings is *not* okay!")
			}
			settings, isSettings := rawSettings.(*Settings)
			if !isSettings || settings == nil {
				panic("settings is *not* okay!")
			}
			cost = settings.DefaultCost
		}

		handle(h.bodySchema, userHandler, cost, rw, req)
	}

	route := h.muxRouter.HandleFunc(path, builtHandler).Methods(method)
	if len(h.queries) != 0 {
		route.Queries(h.queries...)
	}
}
Example #16
0
// Value returns the value stored in Gorilla's context package for this
// Context's request and key. When Gorilla has nothing stored for that key, the
// lookup is delegated to the parent Context.
func (ctx *ctxWrapper) Value(key interface{}) interface{} {
	val, found := gorillactx.GetOk(ctx.req, key)
	if !found {
		return ctx.Context.Value(key)
	}
	return val
}
Example #17
0
func main() {
	mw := interpose.New()

	// Clear gorilla/context values once each request has been served.
	mw.Use(context.ClearHandler)

	// Set a random integer every time someone loads the page.
	setCount := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			n := rand.Int()
			fmt.Println("Setting ctx count to:", n)
			context.Set(req, CountKey, n)
			next.ServeHTTP(w, req)
		})
	}
	mw.Use(setCount)

	// Greet the user named in the path along with the stored count.
	greet := func(w http.ResponseWriter, req *http.Request) {
		count, ok := context.GetOk(req, CountKey)
		if !ok {
			fmt.Println("Get not ok")
		}
		fmt.Fprintf(w, "Welcome to the home page, %s!\nCount:%d", mux.Vars(req)["user"], count)
	}

	// Apply the router.
	router := mux.NewRouter()
	router.HandleFunc("/{user}", greet)
	mw.UseHandler(router)

	// Launch and permit graceful shutdown, allowing up to 10 seconds for
	// existing connections to end.
	graceful.Run(":3001", 10*time.Second, mw)
}
Example #18
0
// iterateAddHeaders iterates over kv and adds each key/value pair as a header
// on the request.
//
// A value that contains TYK_META_LABEL is treated as a reference to session
// metadata. The label is removed and the remainder is looked up in the
// session's MetaData map. The header is added only when that lookup yields a
// string. Such references are skipped when no SessionState is stored on the
// request, and also when MetaData is nil or not a map[string]interface{}.
// Any other value is added to the request headers unchanged.
func (t *TransformHeaders) iterateAddHeaders(kv map[string]string, r *http.Request) {
	// Get session data; a missing or mistyped value counts as "no session".
	thisSessionState, found := context.Get(r, SessionData).(SessionState)

	// Iterate and manage key array injection
	for nKey, nVal := range kv {
		if !strings.Contains(nVal, TYK_META_LABEL) {
			r.Header.Add(nKey, nVal)
			continue
		}

		// Using meta_data key
		log.Debug("Meta data key in use")
		if !found {
			continue
		}
		if thisSessionState.MetaData == nil {
			log.Debug("Meta data object is nil! Skipping.")
			continue
		}
		// Checked assertions keep unexpected metadata shapes from panicking
		// the request.
		metaData, ok := thisSessionState.MetaData.(map[string]interface{})
		if !ok {
			log.Warning("Session Meta Data is not a map, skipping header: ", nKey)
			continue
		}

		metaKey := strings.Replace(nVal, TYK_META_LABEL, "", 1)
		tempVal, ok := metaData[metaKey]
		if !ok {
			log.Warning("Session Meta Data not found for key in map: ", metaKey)
			continue
		}
		strVal, ok := tempVal.(string)
		if !ok {
			log.Warning("Session Meta Data value is not a string for key: ", metaKey)
			continue
		}
		r.Header.Add(nKey, strVal)
	}
}
Example #19
0
// GetSettings returns the Vertigo settings stored on the request context under
// "settings". It returns an error, rather than panicking, when the value is
// missing or is not a Vertigo.
func GetSettings(r *http.Request) (Vertigo, error) {
	rv, ok := context.GetOk(r, "settings")
	if !ok {
		return Vertigo{}, errors.New("context not set")
	}
	settings, ok := rv.(Vertigo)
	if !ok {
		return Vertigo{}, errors.New("context value is not a Vertigo")
	}
	return settings, nil
}
Example #20
0
// GetSearch returns the Search bound from POST data and stored on the request
// context under "search". It returns an error, rather than panicking, when the
// value is missing or is not a Search.
func GetSearch(r *http.Request) (Search, error) {
	rv, ok := context.GetOk(r, "search")
	if !ok {
		return Search{}, errors.New("context not set")
	}
	search, ok := rv.(Search)
	if !ok {
		return Search{}, errors.New("context value is not a Search")
	}
	return search, nil
}
Example #21
0
// GetUser returns the User stored on the request context under "user". It
// returns an error, rather than panicking, when the value is missing or is not
// a User.
func GetUser(r *http.Request) (User, error) {
	rv, ok := context.GetOk(r, "user")
	if !ok {
		return User{}, errors.New("context not set")
	}
	user, ok := rv.(User)
	if !ok {
		return User{}, errors.New("context value is not a User")
	}
	return user, nil
}
Example #22
0
// GetCurrentUser returns the *account.User stored on the request context under
// CurrentUser. It returns errors.ErrLoginRequired when no user is stored, when
// the stored value is not an *account.User, or when it is a nil pointer.
//
// It does not itself re-validate the user (for example, session expiry); any
// such check must happen elsewhere.
func GetCurrentUser(r *http.Request) (*account.User, error) {
	user, ok := context.Get(r, CurrentUser).(*account.User)
	if !ok || user == nil {
		return nil, errors.ErrLoginRequired
	}
	return user, nil
}
// CheckRequest validates the JWT that the jwt middleware stored on the request
// context under "user" against the sessions known to the database.
//
// Return values:
//   - (false, -1) when no token is stored.
//   - (false, 0) when the token is not a *jwt.Token, or its "id", "c" or "t"
//     claims are missing, non-string or (for "id") not a base-10 int64.
//   - Otherwise (valid, id), where valid is db.CheckSession's verdict on the
//     session built from the user id, the SHA-1 of the "c" claim and the
//     "t" claim.
func (db *DB) CheckRequest(request *http.Request) (bool, int64) {
	// Get the token parsed by the jwt middleware
	userToken, present := context.GetOk(request, "user")
	if !present {
		return false, -1
	}

	// Checked assertions keep a malformed token or claim from panicking the
	// handler; every such case is reported as an invalid request.
	token, ok := userToken.(*jwt.Token)
	if !ok || token == nil {
		return false, 0
	}
	claims := token.Claims

	// Parse the user id
	rawID, ok := claims["id"].(string)
	if !ok {
		return false, 0
	}
	id, err := strconv.ParseInt(rawID, 10, 64)
	if err != nil {
		return false, 0
	}

	code, ok := claims["c"].(string)
	if !ok {
		return false, 0
	}
	sessionType, ok := claims["t"].(string)
	if !ok {
		return false, 0
	}

	// Calculate the hash of the random code inside the jwt
	codeHash := sha1.Sum([]byte(code))

	session := Session{
		UserId:      id,
		CodeHash:    codeHash[:],
		SessionType: sessionType,
	}
	return db.CheckSession(&session), id
}
Example #24
0
// GetRequestError gets an error from the request context.
// It returns nil and false when nothing is stored under ErrRequestKey, or when
// the stored value is not an error (including a stored nil, which previously
// caused a panic). Otherwise it returns the error and true.
//
// The (error, bool) result order is kept for compatibility with existing
// callers, even though Go convention puts the error last.
func GetRequestError(r *http.Request) (error, bool) {
	err, ok := context.Get(r, ErrRequestKey).(error)
	return err, ok
}
// ProcessRequest will run any checks on the request on the way through the
// system; return an error to have the chain fail.
//
// It rejects the request with 403 in three cases:
//   - no SessionState is stored on the request context under SessionData;
//   - the session is marked inactive;
//   - k.Spec.AuthManager reports the key as expired.
//
// Otherwise it returns (nil, 200).
func (k *KeyExpired) ProcessRequest(w http.ResponseWriter, r *http.Request, configuration interface{}) (error, int) {
	// A missing or mistyped session value is treated as "no session" rather
	// than panicking.
	thisSessionState, ok := context.Get(r, SessionData).(SessionState)
	if !ok {
		return errors.New("Session state is missing or unset! Please make sure that auth headers are properly applied"), 403
	}

	if thisSessionState.IsInactive {
		k.reportRejectedKey(r, "Attempted access from inactive key.")
		return errors.New("Key is inactive, please renew"), 403
	}

	if k.Spec.AuthManager.IsKeyExpired(&thisSessionState) {
		k.reportRejectedKey(r, "Attempted access from expired key.")
		return errors.New("Key has expired, please renew"), 403
	}

	return nil, 200
}

// reportRejectedKey handles a rejected key in three steps:
//   - logs the rejection;
//   - fires an EVENT_KeyExpired event in a new goroutine;
//   - records a KeyFailure in the health check.
//
// message is used both as the log message and as the event message.
func (k *KeyExpired) reportRejectedKey(r *http.Request, message string) {
	// A missing or non-string auth header value is reported as "" instead of
	// panicking.
	authHeaderValue, _ := context.Get(r, AuthHeaderValue).(string)
	log.WithFields(logrus.Fields{
		"path":   r.URL.Path,
		"origin": GetIPFromRequest(r),
		"key":    authHeaderValue,
	}).Info(message)

	// Fire a key expired event; every rejection path includes the originating
	// request in the event.
	go k.TykMiddleware.FireEvent(EVENT_KeyExpired,
		EVENT_KeyExpiredMeta{
			EventMetaDefault: EventMetaDefault{Message: message, OriginatingRequest: EncodeRequestToEvent(r)},
			Path:             r.URL.Path,
			Origin:           GetIPFromRequest(r),
			Key:              authHeaderValue,
		})

	// Report in health check
	ReportHealthCheckValue(k.Spec.Health, KeyFailure, "-1")
}
Example #26
0
// GetCurrentUser returns the user string stored on the request context under
// COOKIE_USER_ARG. It returns "UNKNOWN" when nothing is stored there or the
// stored value is not a string.
func GetCurrentUser(r *http.Request) string {
	if user, ok := context.Get(r, COOKIE_USER_ARG).(string); ok {
		return user
	}
	return "UNKNOWN"
}
Example #27
0
File: wrapper.go Project: gourd/kit
// Value returns the value stored in Gorilla's context package for this
// Context's request and key. reqKey itself is never looked up in Gorilla's
// store. reqKey, and any key for which Gorilla has no value, is delegated to
// the parent Context.
func (ctx *wrapper) Value(key interface{}) interface{} {
	if key != reqKey {
		if val, ok := gcontext.GetOk(HTTPRequest(ctx.Context), key); ok {
			return val
		}
	}
	return ctx.Context.Value(key)
}
Example #28
0
// Paginate wraps fn in an http.HandlerFunc that computes pagination values.
//
// fn receives a copy of the request's query parameters, in which the "duuid"
// and "domain" values stored on the request context (when they are strings)
// replace any query values of the same name.
//
// The page size comes from the "limit" parameter: default 15, capped at 50,
// and reset to the default when it is below 1. The offset is the page size
// times the zero-based "page" parameter (default 0); negative pages are
// treated as 0.
func Paginate(fn func(http.ResponseWriter, *http.Request, url.Values, int, int)) http.HandlerFunc {
	const (
		defaultLimit = 15
		maxLimit     = 50
	)
	return func(w http.ResponseWriter, r *http.Request) {
		// r.URL.Query() returns a fresh copy, so overriding entries does not modify r.
		params := r.URL.Query()
		for _, key := range []string{"duuid", "domain"} {
			// Checked assertion: a non-string context value is ignored
			// instead of panicking.
			if v, ok := context.Get(r, key).(string); ok {
				params[key] = []string{v}
			}
		}

		limit := getInt(params.Get("limit"), defaultLimit)
		switch {
		case limit > maxLimit:
			limit = maxLimit
		case limit < 1:
			// A zero or negative page size is meaningless; use the default.
			limit = defaultLimit
		}

		// A negative page would produce a negative offset.
		page := getInt(params.Get("page"), 0)
		if page < 0 {
			page = 0
		}
		fn(w, r, params, limit, limit*page)
	}
}
Example #29
0
// getInstance returns the instance ID stored on the request context under
// instanceID. It returns 0 when nothing is stored there or the stored value is
// not an int. When the ID is absent and logIfNotPresent is true, a warning
// naming the request URI is logged.
func getInstance(r *http.Request, logIfNotPresent bool) (id int) {
	v, present := context.GetOk(r, instanceID)
	if !present {
		if logIfNotPresent {
			log.Printf("warn: no instanceID set for request %q (is the app base handler wrapped with InstantiateApp and are clients sending an X-Track-View header?)", r.RequestURI)
		}
		return 0
	}
	id, _ = v.(int)
	return id
}
Example #30
0
// Legs returns the legs for the request, using gorilla/context as a cache.
// On the first call for a request the legs are parsed with
// originfuncs.Parse and stored on the request context under &legsKey. Later
// calls return the stored value. If you aren't yourself using gorilla/context,
// remember to set it up to clear context data for exiting requests.
func Legs(req *http.Request) []originfuncs.Leg {
	var legs interface{}
	if cached, found := context.GetOk(req, &legsKey); found {
		legs = cached
	} else {
		legs = originfuncs.Parse(req)
		context.Set(req, &legsKey, legs)
	}
	return legs.([]originfuncs.Leg)
}