예제 #1
0
// walkthroughService builds the REST resource for Walkthrough entities.
//
// Creating, updating and deleting walkthroughs requires userLoggedInMiddleware;
// listing and loading are delegated to walkthroughEntityResourceHelper.
// Besides the standard entity endpoints it registers:
//
//   - POST /api/reindexwalkthroughs: purges the search index and re-indexes all
//     walkthroughs in a background goroutine, answering 202 Accepted right away
//     (or 503 if a reindex is already running). Wrapped in
//     ab.RestrictPrivateAddressMiddleware.
//   - GET /api/mysites: lists the distinct first-step "arg0" values (start
//     sites) of the current user's published walkthroughs.
//
// Event hooks enforce ownership (a new walkthrough must belong to the current
// user; updates and deletes of someone else's walkthrough require an admin),
// keep the search index in sync after writes, and write a DBLog entry when a
// walkthrough is recorded. baseurl is concatenated directly with
// "walkthrough/<uuid>" in that log entry, so it is expected to end with "/".
func walkthroughService(ec *ab.EntityController, search *search.SearchService, baseurl string) ab.Service {
	h := &walkthroughEntityResourceHelper{
		controller: ec,
	}

	res := ab.EntityResource(ec, &Walkthrough{}, ab.EntityResourceConfig{
		PostMiddlewares:      []func(http.Handler) http.Handler{userLoggedInMiddleware},
		PutMiddlewares:       []func(http.Handler) http.Handler{userLoggedInMiddleware},
		DeleteMiddlewares:    []func(http.Handler) http.Handler{userLoggedInMiddleware},
		EntityResourceLister: h,
		EntityResourceLoader: h,
	})

	res.ExtraEndpoints = func(srv *ab.Server) error {
		// reindexing is true while a background reindex goroutine is running.
		// It is only read or written while holding reindexingMutex.
		reindexing := false
		var reindexingMutex sync.Mutex

		srv.Post("/api/reindexwalkthroughs", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Fetch the DB before claiming the flag, so a failure here cannot
			// leave reindexing stuck at true.
			db := ab.GetDB(r)

			// Check and set the flag under a single lock. Checking under a
			// read lock and setting under a separate write lock lets two
			// concurrent requests both start a purge/reindex run.
			reindexingMutex.Lock()
			if reindexing {
				reindexingMutex.Unlock()
				ab.Fail(http.StatusServiceUnavailable, errors.New("reindexing is in progress"))
				return
			}
			reindexing = true
			reindexingMutex.Unlock()

			go func() {
				// Clear the flag however the reindex ends.
				defer func() {
					reindexingMutex.Lock()
					reindexing = false
					reindexingMutex.Unlock()
				}()
				err := search.PurgeIndex()
				if err != nil {
					log.Println(err)
					return
				}

				wts, err := LoadAllActualWalkthroughs(db, ec, 0, 0)
				if err != nil {
					log.Println(err)
					return
				}

				for _, wt := range wts {
					err = search.IndexEntity("walkthrough", wt)
					if err != nil {
						log.Println(err)
						return
					}
				}
			}()

			ab.Render(r).SetCode(http.StatusAccepted)
		}), ab.RestrictPrivateAddressMiddleware())

		srv.Get("/api/mysites", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			db := ab.GetDB(r)
			uid := ab.GetSession(r)["uid"]

			rows, err := db.Query("SELECT DISTINCT steps->0->'arg0' AS site FROM walkthrough WHERE uid = $1 AND published ORDER BY site", uid)
			ab.MaybeFail(http.StatusInternalServerError, err)
			defer rows.Close()

			// Non-nil so an empty result renders as [] rather than null.
			sites := []string{}

			for rows.Next() {
				var site sql.NullString
				err = rows.Scan(&site)
				ab.MaybeFail(http.StatusInternalServerError, err)
				if !site.Valid {
					continue
				}

				siteName := site.String
				// The `->` operator yields a JSON value, so string sites come
				// back wrapped in double quotes. Strip them only when present:
				// unconditional slicing panics on values shorter than 2 bytes.
				if n := len(siteName); n >= 2 && siteName[0] == '"' && siteName[n-1] == '"' {
					siteName = siteName[1 : n-1]
				}

				sites = append(sites, siteName)
			}
			// rows.Next returns false on iteration errors as well as at the end.
			ab.MaybeFail(http.StatusInternalServerError, rows.Err())

			ab.Render(r).JSON(sites)
		}), userLoggedInMiddleware)

		return nil
	}

	res.AddPostEvent(ab.ResourceEventCallback{
		BeforeCallback: func(r *http.Request, d ab.Resource) {
			wt := d.(*Walkthrough)
			uid := UserDelegate.CurrentUser(r)
			// Default the owner to the current user; nobody may create a
			// walkthrough on behalf of another user.
			if wt.UID == "" {
				wt.UID = uid
			}
			if wt.UID != uid {
				ab.Fail(http.StatusBadRequest, errors.New("invalid user id"))
			}

			wt.Updated = time.Now()
			wt.Revision = ""
			wt.UUID = ""
		},
		AfterCallback: func(r *http.Request, d ab.Resource) {
			db := ab.GetDB(r)
			wt := d.(*Walkthrough)
			// Indexing failure should not abort the notification below.
			if err := search.IndexEntity("walkthrough", wt); err != nil {
				log.Println(err)
			}
			userEntity, err := ec.Load(db, "user", wt.UID)
			if err != nil {
				log.Println(err)
				return
			}
			user := userEntity.(*User)
			startURL := ""
			if len(wt.Steps) > 0 && wt.Steps[0].Command == "open" {
				startURL = wt.Steps[0].Arg0
			}
			message := fmt.Sprintf("%s has recorded a Walkthrough (<%s|%s>) on %s",
				user.Mail,
				baseurl+"walkthrough/"+wt.UUID,
				html.EscapeString(wt.Name),
				html.EscapeString(startURL),
			)
			DBLog(db, ec, "walkthroughrecord", message)
		},
	})

	res.AddPutEvent(ab.ResourceEventCallback{
		BeforeCallback: func(r *http.Request, d ab.Resource) {
			db := ab.GetDB(r)
			wt := d.(*Walkthrough)
			uid := UserDelegate.CurrentUser(r)
			currentUserEntity, err := ec.Load(db, "user", uid)
			ab.MaybeFail(http.StatusBadRequest, err)
			currentUser := currentUserEntity.(*User)
			// Only admins may submit a walkthrough owned by someone else...
			if wt.UID != uid {
				if !currentUser.Admin {
					ab.Fail(http.StatusForbidden, nil)
				}
			}

			previousRevision, err := LoadActualRevision(db, ec, wt.UUID)
			ab.MaybeFail(http.StatusBadRequest, err)
			if previousRevision == nil {
				ab.Fail(http.StatusNotFound, nil)
			}

			// ...or overwrite a walkthrough whose stored owner is someone else.
			if previousRevision.UID != uid && !currentUser.Admin {
				ab.Fail(http.StatusForbidden, nil)
			}

			wt.Updated = time.Now()
			wt.Revision = ""
		},
		AfterCallback: func(r *http.Request, d ab.Resource) {
			if err := search.IndexEntity("walkthrough", d.(*Walkthrough)); err != nil {
				log.Println(err)
			}
		},
	})

	res.AddDeleteEvent(ab.ResourceEventCallback{
		InsideCallback: func(r *http.Request, d ab.Resource) {
			db := ab.GetDB(r)
			uid := UserDelegate.CurrentUser(r)
			wt := d.(*Walkthrough)
			currentUserEntity, err := ec.Load(db, "user", uid)
			ab.MaybeFail(http.StatusBadRequest, err)
			currentUser := currentUserEntity.(*User)
			// Only the owner or an admin may delete a walkthrough.
			if wt.UID != uid {
				if !currentUser.Admin {
					ab.Fail(http.StatusForbidden, nil)
				}
			}
		},
	})

	return res
}
예제 #2
0
// Start registers the frontend routes, entity types and services of the
// Walkhub server, then starts serving on addr via StartHTTPS.
//
// It returns an error before serving if no authentication provider is
// enabled (neither s.PWAuth nor non-empty Google credentials). When both
// certfile and keyfile are non-empty they are used for TLS. Otherwise, if the
// "letsencrypthost" config value is set, EnableLetsEncrypt is called for that
// host. Access to /metrics defaults to ab.RestrictPrivateAddressMiddleware and
// can be overridden with the comma-separated "metricsaddresses" config value.
func (s *WalkhubServer) Start(addr string, certfile string, keyfile string) error {
	// Client-side routes that all serve the single-page frontend.
	frontendPaths := []string{
		"/",
		"/connect",
		"/record",
		"/walkthrough/:uuid",
		"/search",
		"/embedcode",
		"/helpcenterlist",
		"/profile/:uuid",
	}
	for _, path := range append(frontendPaths, s.CustomPaths...) {
		s.GetF(path, handleIndex)
	}

	ec := ab.NewEntityController(s.GetDBConnection())
	ec.
		Add(&User{}, userEntityDelegate{}).
		Add(&Walkthrough{}, walkthroughEntityDelegate{}).
		Add(&Screening{}, nil).
		Add(&EmbedLog{}, nil).
		Add(&Log{}, nil)

	if mailchimpClient := createMailchimpClient(s.cfg, s.Logger); mailchimpClient != nil {
		ec.AddInsertEvent(mailchimpClient)
	}

	s.Options("/*path", corsPreflightHandler(s.BaseURL, s.HTTPOrigin))

	s.Use(corsMiddleware(s.BaseURL, s.HTTPOrigin))

	UserDelegate.DB = s.GetDBConnection()

	authProviders := []auth.AuthProvider{}
	if s.PWAuth {
		smtpAuth := smtp.PlainAuth(s.AuthCreds.SMTP.Identity, s.AuthCreds.SMTP.Username, s.AuthCreds.SMTP.Password, s.AuthCreds.SMTP.Host)
		delegate := auth.NewPasswordAuthSMTPEmailSenderDelegate(s.AuthCreds.SMTP.Addr, smtpAuth, s.BaseURL)
		delegate.From = s.AuthCreds.SMTP.From
		delegate.RegistrationEmailTemplate = regMailTemplate
		delegate.LostPasswordEmailTemplate = lostpwMailTemplate
		pwauth := auth.NewPasswordAuthProvider(ec, NewPasswordDelegate(s.GetDBConnection(), ec), delegate)
		authProviders = append(authProviders, pwauth)
	}
	if !s.AuthCreds.Google.Empty() {
		gauth := google.NewGoogleAuthProvider(ec, s.AuthCreds.Google, &GoogleUserDelegate{})
		authProviders = append(authProviders, gauth)
	}
	if len(authProviders) == 0 {
		return errors.New("no authentication providers are enabled")
	}
	authsvc := auth.NewService(s.BaseURL, UserDelegate, s.GetDBConnection(), authProviders...)
	s.RegisterService(authsvc)

	s.RegisterService(userService(ec))

	searchsvc := search.NewSearchService(s.GetDBConnection(), nil)
	searchsvc.AddDelegate("walkthrough", &walkhubSearchDelegate{
		db:         s.GetDBConnection(),
		controller: ec,
	})
	s.RegisterService(searchsvc)

	s.RegisterService(walkthroughService(ec, searchsvc, s.BaseURL))

	s.RegisterService(embedlogService(ec))

	s.RegisterService(logService(ec, s.BaseURL))

	metricsRestrictAddressMiddleware := ab.RestrictPrivateAddressMiddleware()
	if addresses := s.cfg.GetString("metricsaddresses"); addresses != "" {
		// Trim each entry and drop blanks, so "a, b" or a trailing comma does
		// not hand " b" or "" verbatim to ab.RestrictAddressMiddleware.
		addresslist := []string{}
		for _, a := range strings.Split(addresses, ",") {
			if a = strings.TrimSpace(a); a != "" {
				addresslist = append(addresslist, a)
			}
		}
		// Keep the private-address default if the setting held no usable entry.
		if len(addresslist) > 0 {
			s.Logger.User().Printf("access to metrics from: %v\n", addresslist)
			metricsRestrictAddressMiddleware = ab.RestrictAddressMiddleware(addresslist...)
		}
	}
	s.Get("/metrics", stdprometheus.Handler(), metricsRestrictAddressMiddleware)

	siteinfoBaseURLs := []string{s.BaseURL}
	if s.HTTPOrigin != "" {
		siteinfoBaseURLs = append(siteinfoBaseURLs, s.HTTPOrigin)
	}

	s.RegisterService(NewSiteinfoService(siteinfoBaseURLs...))

	s.RegisterService(screeningService(ec))

	if certfile != "" && keyfile != "" {
		s.setupHTTPS()
		if s.TLSConfig == nil {
			s.TLSConfig = &tls.Config{}
		}

		// NOTE(review): tls.Config.ServerName expects a host name, but
		// BaseURL is passed to services as a URL prefix. Confirm this value
		// is what is intended here.
		if s.TLSConfig.ServerName == "" {
			s.TLSConfig.ServerName = s.BaseURL
		}
	} else if host := s.cfg.GetString("letsencrypthost"); host != "" {
		s.setupHTTPS()
		s.EnableLetsEncrypt("", host)
	}

	return s.StartHTTPS(addr, certfile, keyfile)
}