Example #1
0
// main loads the proxy configuration named by the -config flag, sets up the
// package-level StatsD client, and serves HTTP requests with httpHandler until
// the listener fails.
func main() {
	// Parse command line arguments.
	configFile := flag.String("config", "", "Path to configuration file")
	flag.Parse()

	// Load configuration into package variable Config.
	if err := gcfg.ReadFileInto(&Config, *configFile); err != nil {
		log.Fatal("Could not load config file: " + err.Error())
	}

	// Instantiate the StatsD connection. Without a configured host, or when
	// dialing fails, use the no-op client so that later metric calls never
	// operate on a nil connection.
	if Config.Statsd.Host == "" {
		StatsD = g2s.Noop()
	} else {
		statter, err := g2s.Dial(Config.Statsd.Protocol, Config.Statsd.Host+":"+Config.Statsd.Port)
		if err != nil {
			log.Println("Could not connect to StatsD, metrics disabled: " + err.Error())
			StatsD = g2s.Noop()
		} else {
			StatsD = statter
		}
	}

	// Log startup.
	log.Println("go-airbrake-proxy started")

	// Fire up an HTTP server and handle it. ListenAndServe only returns on
	// failure, so report the reason instead of exiting silently.
	http.HandleFunc("/", httpHandler)
	log.Fatal(http.ListenAndServe(Config.Listen.Host+":"+Config.Listen.Port, nil))
}
// loadStatsd returns the statsd client to use for addr.
//
// When the "hailo/service/instrumentation/statsd/disabled" config value is
// true it returns a no-op statter. Otherwise it dials addr over UDP; if the
// dial fails a warning (including the underlying error) is logged and nil is
// returned, so callers must be prepared for a nil result.
func loadStatsd(addr string) g2s.Statter {
	disabled := config.AtPath("hailo", "service", "instrumentation", "statsd", "disabled").AsBool()
	if disabled {
		return g2s.Noop()
	}

	s, err := g2s.Dial("udp", addr)
	if err != nil {
		// Include the cause; the address alone is not enough to diagnose
		// resolution or socket failures.
		log.Warnf("Error initialising statsd connection to %v: %v", addr, err)
		return nil
	}

	return s
}
Example #3
0
// setupStatsd builds the statsd client described by the package config.
//
// With an empty config.Statsd.Addr it returns a no-op statter. Otherwise it
// fills in defaults on the package config (namespace "nixy.<hostname>" when
// unset, sample rate 100 when outside 1..100) and dials the address over UDP.
// On dial failure it returns a nil Statter together with the error.
func setupStatsd() (g2s.Statter, error) {
	if config.Statsd.Addr == "" {
		return g2s.Noop(), nil
	}

	if config.Statsd.Namespace == "" {
		// Best effort: if the hostname cannot be determined the namespace
		// simply becomes "nixy." with an empty suffix.
		hostname, _ := os.Hostname()
		config.Statsd.Namespace = "nixy." + hostname
	}

	if config.Statsd.SampleRate < 1 || config.Statsd.SampleRate > 100 {
		config.Statsd.SampleRate = 100
	}

	// Do not return g2s.Dial's results directly: on failure its nil concrete
	// pointer would be wrapped in a non-nil Statter interface value.
	statter, err := g2s.Dial("udp", config.Statsd.Addr)
	if err != nil {
		return nil, err
	}
	return statter, nil
}
Example #4
0
File: pprof.go Project: gonet2/auth
func init() {
	addr := DEFAULT_STATSD_HOST
	if env := os.Getenv(ENV_STATSD); env != "" {
		addr = env
	}

	s, err := g2s.Dial("udp", addr)
	if err == nil {
		_statter = s
	} else {
		_statter = g2s.Noop()
		log.Println(err)
	}

	go pprof_task()
}
Example #5
0
// TestLDAPUserCache exercises server.NewLDAPUserCache against a StubLDAPServer
// whose keys come from the local SSH agent plus the bundled testKey.
//
// The test is skipped when SSH_AUTH_SOCK is unset or the agent holds no keys,
// because the first agent key is needed to build the stub LDAP data.
func TestLDAPUserCache(t *testing.T) {
	Convey("Given an LDAP user cache connected to our server", t, func() {
		// The SSH agent stuff was moved up here so that we can use it to
		// dynamically create the LDAP result object.
		sshSock := os.Getenv("SSH_AUTH_SOCK")
		if sshSock == "" {
			t.Skip()
		}

		c, err := net.Dial("unix", sshSock)
		if err != nil {
			t.Fatal(err)
		}
		// Release the agent socket once this pass over the suite finishes.
		defer c.Close()

		sshAgent := agent.NewClient(c)
		keys, err := sshAgent.List()
		if err != nil {
			t.Fatal(err)
		}
		if len(keys) == 0 {
			// keys[0] is required below; an empty agent cannot run this test.
			t.Skip("SSH agent holds no keys")
		}

		keyValue := base64.StdEncoding.EncodeToString(keys[0].Blob)

		// Load in an additional key from the test data.
		privateKey, err := ssh.ParsePrivateKey(testKey)
		if err != nil {
			t.Fatal(err)
		}
		testPublicKey := base64.StdEncoding.EncodeToString(privateKey.PublicKey().Marshal())

		s := &StubLDAPServer{
			Keys: []string{keyValue, testPublicKey},
		}
		lc, err := server.NewLDAPUserCache(s, g2s.Noop(), "cn", "dc=testdn,dc=com")
		So(err, ShouldBeNil)
		So(lc, ShouldNotBeNil)

		Convey("It should retrieve users from LDAP", func() {
			So(lc.Users(), ShouldNotBeEmpty)
		})

		Convey("It should verify the current user positively.", func() {
			success := false

			for i := 0; i < len(keys); i++ {
				challenge := randomBytes(64)
				sig, err := sshAgent.Sign(keys[i], challenge)
				if err != nil {
					t.Fatal(err)
				}
				// Only one agent key has to match, so per-key authentication
				// errors are expected and deliberately ignored.
				verifiedUser, _ := lc.Authenticate("ericallen", challenge, sig)
				success = success || (verifiedUser != nil)
			}

			So(success, ShouldEqual, true)
		})

		Convey("When a user is requested that cannot be found in the cache", func() {
			// Use an SSH key we're guaranteed to not have.
			oldKey := s.Keys[0]
			s.Keys[0] = testPublicKey
			lc.Update()

			// Swap the key back and try verifying.
			// We should still get a result back.
			s.Keys[0] = oldKey
			success := false

			for i := 0; i < len(keys); i++ {
				challenge := randomBytes(64)
				sig, err := sshAgent.Sign(keys[i], challenge)
				if err != nil {
					t.Fatal(err)
				}
				// As above: a mismatch for an individual key is not a failure.
				verifiedUser, _ := lc.Authenticate("ericallen", challenge, sig)
				success = success || (verifiedUser != nil)
			}

			Convey("Then it should update LDAP again and find the user.", func() {
				So(success, ShouldEqual, true)
			})
		})

		Convey("When a user with multiple SSH keys assigned tries to use Hologram", func() {
			Convey("The system should allow them to use any key.", func() {
				success := false

				// Signs with the bundled test key (not the agent keys), once
				// per agent key, to prove the second stub key is accepted.
				for i := 0; i < len(keys); i++ {
					challenge := randomBytes(64)
					sig, err := privateKey.Sign(cryptrand.Reader, challenge)
					if err != nil {
						t.Fatal(err)
					}
					verifiedUser, _ := lc.Authenticate("ericallen", challenge, sig)
					success = success || (verifiedUser != nil)
				}

				So(success, ShouldEqual, true)

			})
		})

		testAuthorizedKey := string(ssh.MarshalAuthorizedKey(privateKey.PublicKey()))

		// Rebuild the cache around a stub that stores the key in
		// authorized_keys text form rather than base64 of the wire format.
		s = &StubLDAPServer{
			Keys: []string{testAuthorizedKey},
		}
		lc, err = server.NewLDAPUserCache(s, g2s.Noop(), "cn", "dc=testdn,dc=com")
		So(err, ShouldBeNil)
		So(lc, ShouldNotBeNil)

		Convey("The usercache should understand the SSH authorized_keys format", func() {
			challenge := randomBytes(64)
			sig, err := privateKey.Sign(cryptrand.Reader, challenge)
			if err != nil {
				t.Fatal(err)
			}
			verifiedUser, err := lc.Authenticate("ericallen", challenge, sig)
			So(verifiedUser, ShouldNotBeNil)
			So(err, ShouldBeNil)
		})
	})
}
Example #6
0
// main starts the roshi HTTP server. It parses the command-line flags, sets
// up statsd and Prometheus instrumentation, resolves the read strategy, repair
// strategy and hash function named by the flags (exiting on an unknown name),
// builds the farm, and serves select/insert/delete requests on -http.address.
func main() {
	var (
		redisInstances             = flag.String("redis.instances", "", "Semicolon-separated list of comma-separated lists of Redis instances")
		redisConnectTimeout        = flag.Duration("redis.connect.timeout", 3*time.Second, "Redis connect timeout")
		redisReadTimeout           = flag.Duration("redis.read.timeout", 3*time.Second, "Redis read timeout")
		redisWriteTimeout          = flag.Duration("redis.write.timeout", 3*time.Second, "Redis write timeout")
		redisMCPI                  = flag.Int("redis.mcpi", 10, "Max connections per Redis instance")
		redisHash                  = flag.String("redis.hash", "murmur3", "Redis hash function: murmur3, fnv, fnva")
		farmWriteQuorum            = flag.String("farm.write.quorum", "51%", "Write quorum, either number of clusters (2) or percentage of clusters (51%)")
		farmReadStrategy           = flag.String("farm.read.strategy", "SendAllReadAll", "Farm read strategy: SendAllReadAll, SendOneReadOne, SendAllReadFirstLinger, SendVarReadFirstLinger")
		farmReadThresholdRate      = flag.Int("farm.read.threshold.rate", 2000, "Baseline SendAll keys read per sec, additional keys are SendOne (SendVarReadFirstLinger strategy only)")
		farmReadThresholdLatency   = flag.Duration("farm.read.threshold.latency", 50*time.Millisecond, "If a SendOne read has not returned anything after this latency, it's promoted to SendAll (SendVarReadFirstLinger strategy only)")
		farmRepairStrategy         = flag.String("farm.repair.strategy", "RateLimitedRepairs", "Farm repair strategy: AllRepairs, NoRepairs, RateLimitedRepairs")
		farmRepairMaxKeysPerSecond = flag.Int("farm.repair.max.keys.per.second", 1000, "Max repaired keys per second (RateLimited repairer only)")
		maxSize                    = flag.Int("max.size", 10000, "Maximum number of events per key")
		selectGap                  = flag.Duration("select.gap", 0*time.Millisecond, "delay between pipeline read invocations when Selecting over multiple keys")
		statsdAddress              = flag.String("statsd.address", "", "Statsd address (blank to disable)")
		statsdSampleRate           = flag.Float64("statsd.sample.rate", 0.1, "Statsd sample rate for normal metrics")
		statsdBucketPrefix         = flag.String("statsd.bucket.prefix", "myservice.", "Statsd bucket key prefix, including trailing period")
		prometheusNamespace        = flag.String("prometheus.namespace", "roshiserver", "Prometheus key namespace, excluding trailing punctuation")
		prometheusMaxSummaryAge    = flag.Duration("prometheus.max.summary.age", 10*time.Second, "Prometheus max age for instantaneous histogram data")
		httpAddress                = flag.String("http.address", ":6302", "HTTP listen address")
	)
	flag.Parse()
	log.SetOutput(os.Stdout)
	log.SetFlags(log.Lmicroseconds)
	log.Printf("GOMAXPROCS %d", runtime.GOMAXPROCS(-1))

	// Instrumentation: statsd is a no-op unless an address was supplied;
	// Prometheus is always installed on the default mux under /metrics.
	statsdClient := g2s.Noop()
	if *statsdAddress != "" {
		dialed, dialErr := g2s.Dial("udp", *statsdAddress)
		if dialErr != nil {
			log.Fatal(dialErr)
		}
		statsdClient = dialed
	}
	promInstr := prometheus.New(*prometheusNamespace, *prometheusMaxSummaryAge)
	promInstr.Install("/metrics", http.DefaultServeMux)
	statsdInstr := statsd.New(statsdClient, float32(*statsdSampleRate), *statsdBucketPrefix)
	sink := instrumentation.NewMultiInstrumentation(statsdInstr, promInstr)

	// Resolve the read strategy; names are matched case-insensitively.
	var readStrategy farm.ReadStrategy
	switch name := strings.ToLower(*farmReadStrategy); name {
	case "sendallreadall":
		readStrategy = farm.SendAllReadAll
	case "sendonereadone":
		readStrategy = farm.SendOneReadOne
	case "sendallreadfirstlinger":
		readStrategy = farm.SendAllReadFirstLinger
	case "sendvarreadfirstlinger":
		readStrategy = farm.SendVarReadFirstLinger(*farmReadThresholdRate, *farmReadThresholdLatency)
	default:
		log.Fatalf("unknown read strategy %q", *farmReadStrategy)
	}
	log.Printf("using %s read strategy", *farmReadStrategy)

	// Resolve the repair strategy. This is a client-facing production
	// server, so every choice is wrapped in farm.Nonblocking.
	const repairBufferSize = 100
	var repairStrategy farm.RepairStrategy
	switch name := strings.ToLower(*farmRepairStrategy); name {
	case "allrepairs":
		repairStrategy = farm.Nonblocking(repairBufferSize, farm.AllRepairs)
	case "norepairs":
		repairStrategy = farm.Nonblocking(repairBufferSize, farm.NoRepairs)
	case "ratelimitedrepairs":
		limited := farm.RateLimited(*farmRepairMaxKeysPerSecond, farm.AllRepairs)
		repairStrategy = farm.Nonblocking(repairBufferSize, limited)
	default:
		log.Fatalf("unknown repair strategy %q", *farmRepairStrategy)
	}
	log.Printf("using %s repair strategy", *farmRepairStrategy)

	// Resolve the key hash function.
	var hashFunc func(string) uint32
	switch name := strings.ToLower(*redisHash); name {
	case "murmur3":
		hashFunc = pool.Murmur3
	case "fnv":
		hashFunc = pool.FNV
	case "fnva":
		hashFunc = pool.FNVa
	default:
		log.Fatalf("unknown hash %q", *redisHash)
	}

	// Assemble the farm from everything resolved above.
	roshiFarm, farmErr := newFarm(
		*redisInstances,
		*farmWriteQuorum,
		*redisConnectTimeout, *redisReadTimeout, *redisWriteTimeout,
		*redisMCPI,
		hashFunc,
		readStrategy,
		repairStrategy,
		*maxSize,
		*selectGap,
		sink,
	)
	if farmErr != nil {
		log.Fatal(farmErr)
	}

	// Routing: metrics and debug paths are delegated to the default mux,
	// while "/" is handled by the farm-backed handlers.
	router := pat.New()
	for _, passthrough := range []struct{ method, path string }{
		{"GET", "/metrics"},
		{"GET", "/debug"},
		{"POST", "/debug"},
	} {
		router.Add(passthrough.method, passthrough.path, http.DefaultServeMux)
	}
	router.Get("/", handleSelect(roshiFarm))
	router.Post("/", handleInsert(roshiFarm))
	router.Delete("/", handleDelete(roshiFarm))
	var handler http.Handler = router

	// Serve until the listener fails.
	log.Printf("listening on %s", *httpAddress)
	log.Fatal(http.ListenAndServe(*httpAddress, handler))
}
Example #7
0
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/gorilla/pat"
	"github.com/peterbourgon/g2s"
	"github.com/soundcloud/roshi/cluster"
	"github.com/soundcloud/roshi/common"
	"github.com/soundcloud/roshi/farm"
	"github.com/soundcloud/roshi/instrumentation/statsd"
	"github.com/soundcloud/roshi/pool"
)

// stats is the package-level statsd client, initialised to g2s.Noop().
var stats = g2s.Noop()

// log is the package-level logger. It writes to standard output with an empty
// prefix and the Lmicroseconds flag of the package imported as logpkg.
var log = logpkg.New(os.Stdout, "", logpkg.Lmicroseconds)

func main() {
	var (
		redisInstances             = flag.String("redis.instances", "", "Semicolon-separated list of comma-separated lists of Redis instances")
		redisConnectTimeout        = flag.Duration("redis.connect.timeout", 3*time.Second, "Redis connect timeout")
		redisReadTimeout           = flag.Duration("redis.read.timeout", 3*time.Second, "Redis read timeout")
		redisWriteTimeout          = flag.Duration("redis.write.timeout", 3*time.Second, "Redis write timeout")
		redisMCPI                  = flag.Int("redis.mcpi", 10, "Max connections per Redis instance")
		redisHash                  = flag.String("redis.hash", "murmur3", "Redis hash function: murmur3, fnv, fnva")
		farmWriteQuorum            = flag.String("farm.write.quorum", "51%", "Write quorum, either number of clusters (2) or percentage of clusters (51%)")
		farmReadStrategy           = flag.String("farm.read.strategy", "SendAllReadAll", "Farm read strategy: SendAllReadAll, SendOneReadOne, SendAllReadFirstLinger, SendVarReadFirstLinger")
		farmReadThresholdRate      = flag.Int("farm.read.threshold.rate", 2000, "Baseline SendAll keys read per sec, additional keys are SendOne (SendVarReadFirstLinger strategy only)")
		farmReadThresholdLatency   = flag.Duration("farm.read.threshold.latency", 50*time.Millisecond, "If a SendOne read has not returned anything after this latency, it's promoted to SendAll (SendVarReadFirstLinger strategy only)")
Example #8
0
// main starts the Hologram server.
//
// It loads the JSON config file, overlays any command-line flags, connects to
// statsd (optional) and LDAP (required), builds the user cache and request
// handler, starts listening on config.Listen, and then blocks handling
// signals:
//
//	SIGINT/SIGTERM  shut down
//	SIGUSR1/SIGUSR2 enable/disable debug logging
//	SIGHUP          force a reload of the LDAP user cache
//
// The user cache is also reloaded every config.CacheTimeout seconds.
// Unrecoverable setup errors are logged and the process exits with status 1.
func main() {
	// Cache refresh interval, in seconds, used when neither the flag nor the
	// config file supplies a usable value.
	const defaultCacheTimeout = 3600

	// Parse command-line flags for this system.
	var (
		listenAddress    = flag.String("addr", "", "Address to listen to incoming requests on.")
		ldapAddress      = flag.String("ldapAddr", "", "Address to connect to LDAP.")
		ldapBindDN       = flag.String("ldapBindDN", "", "LDAP DN to bind to for login.")
		ldapInsecure     = flag.Bool("insecureLDAP", false, "INSECURE: Don't use TLS for LDAP connection.")
		ldapBindPassword = flag.String("ldapBindPassword", "", "LDAP password for bind.")
		statsdHost       = flag.String("statsHost", "", "Address to send statsd metrics to.")
		iamAccount       = flag.String("iamaccount", "", "AWS Account ID for generating IAM Role ARNs")
		enableLDAPRoles  = flag.Bool("ldaproles", false, "Enable role support using LDAP directory.")
		roleAttribute    = flag.String("roleattribute", "", "Group attribute to get role from.")
		defaultRole      = flag.String("role", "", "AWS role to assume by default.")
		configFile       = flag.String("conf", "/etc/hologram/server.json", "Config file to load.")
		cacheTimeout     = flag.Int("cachetime", defaultCacheTimeout, "Time in seconds after which to refresh LDAP user cache.")
		debugMode        = flag.Bool("debug", false, "Enable debug mode.")
		config           Config
	)

	flag.Parse()

	// Enable debug log output if the user requested it.
	if *debugMode {
		log.DebugMode(true)
		log.Debug("Enabling debug log output. Use sparingly.")
	}

	// Parse in options from the given config file.
	log.Debug("Loading configuration from %s", *configFile)
	configContents, configErr := ioutil.ReadFile(*configFile)
	if configErr != nil {
		log.Errorf("Could not read from config file. The error was: %s", configErr.Error())
		os.Exit(1)
	}

	configParseErr := json.Unmarshal(configContents, &config)
	if configParseErr != nil {
		log.Errorf("Error in parsing config file: %s", configParseErr.Error())
		os.Exit(1)
	}

	// Merge in command flag options. A flag only overrides the config file
	// when it was set to something other than its default.
	if *ldapAddress != "" {
		config.LDAP.Host = *ldapAddress
	}

	if *ldapInsecure {
		config.LDAP.InsecureLDAP = true
	}

	if *ldapBindDN != "" {
		config.LDAP.Bind.DN = *ldapBindDN
	}

	if *ldapBindPassword != "" {
		config.LDAP.Bind.Password = *ldapBindPassword
	}

	if *statsdHost != "" {
		config.Stats = *statsdHost
	}

	if *iamAccount != "" {
		config.AWS.Account = *iamAccount
	}

	if *listenAddress != "" {
		config.Listen = *listenAddress
	}

	if *defaultRole != "" {
		config.AWS.DefaultRole = *defaultRole
	}

	if *enableLDAPRoles {
		config.LDAP.EnableLDAPRoles = true
	}

	if *roleAttribute != "" {
		config.LDAP.RoleAttribute = *roleAttribute
	}

	if *cacheTimeout != defaultCacheTimeout {
		config.CacheTimeout = *cacheTimeout
	}

	// The config file may omit the cache timeout (leaving it 0) or a caller
	// may pass a non-positive flag value; time.NewTicker panics on a
	// non-positive interval, so fall back to the default.
	if config.CacheTimeout <= 0 {
		config.CacheTimeout = defaultCacheTimeout
	}

	var stats g2s.Statter
	var statsErr error

	if config.LDAP.UserAttr == "" {
		config.LDAP.UserAttr = "cn"
	}

	// Metrics are optional: any problem reaching statsd degrades to a no-op
	// client rather than stopping the server.
	if config.Stats == "" {
		log.Debug("No statsd server specified; no metrics will be emitted by this program.")
		stats = g2s.Noop()
	} else {
		stats, statsErr = g2s.Dial("udp", config.Stats)
		if statsErr != nil {
			log.Errorf("Error connecting to statsd: %s. No metrics will be emitted by this program.", statsErr.Error())
			stats = g2s.Noop()
		} else {
			log.Debug("This program will emit metrics to %s", config.Stats)
		}
	}

	// Setup the server state machine that responds to requests.
	auth, err := aws.GetAuth(os.Getenv("HOLOGRAM_AWSKEY"), os.Getenv("HOLOGRAM_AWSSECRET"), "", time.Now())
	if err != nil {
		log.Errorf("Error getting instance credentials: %s", err.Error())
		os.Exit(1)
	}

	stsConnection := sts.New(auth, aws.Regions["us-east-1"])
	credentialsService := server.NewDirectSessionTokenService(config.AWS.Account, stsConnection)

	var ldapServer *ldap.Conn

	// Connect to the LDAP server using TLS or not depending on the config
	if config.LDAP.InsecureLDAP {
		log.Debug("Connecting to LDAP at server %s (NOT using TLS).", config.LDAP.Host)
		ldapServer, err = ldap.Dial("tcp", config.LDAP.Host)
		if err != nil {
			log.Errorf("Could not dial LDAP! %s", err.Error())
			os.Exit(1)
		}
	} else {
		// NOTE(review): InsecureSkipVerify disables certificate verification,
		// so this TLS connection does not authenticate the LDAP server. Left
		// unchanged because existing deployments may rely on it; consider
		// making verification configurable.
		tlsConfig := &tls.Config{
			InsecureSkipVerify: true,
		}

		log.Debug("Connecting to LDAP at server %s.", config.LDAP.Host)
		ldapServer, err = ldap.DialTLS("tcp", config.LDAP.Host, tlsConfig)
		if err != nil {
			log.Errorf("Could not dial LDAP! %s", err.Error())
			os.Exit(1)
		}
	}

	if bindErr := ldapServer.Bind(config.LDAP.Bind.DN, config.LDAP.Bind.Password); bindErr != nil {
		log.Errorf("Could not bind to LDAP! %s", bindErr.Error())
		os.Exit(1)
	}

	ldapCache, err := server.NewLDAPUserCache(ldapServer, stats, config.LDAP.UserAttr, config.LDAP.BaseDN, config.LDAP.EnableLDAPRoles, config.LDAP.RoleAttribute)
	if err != nil {
		log.Errorf("Top-level error in LDAPUserCache layer: %s", err.Error())
		os.Exit(1)
	}

	serverHandler := server.New(ldapCache, credentialsService, config.AWS.DefaultRole, stats, ldapServer, config.LDAP.UserAttr, config.LDAP.BaseDN, config.LDAP.EnableLDAPRoles)
	server, err := remote.NewServer(config.Listen, serverHandler.HandleConnection)
	if err != nil {
		// Without this check a failed listen would report the server as
		// online and later call Close on a server that was never created.
		log.Errorf("Could not start listening on %s: %s", config.Listen, err.Error())
		os.Exit(1)
	}

	// Wait for a signal from the OS to shutdown. signal.Notify never blocks
	// when delivering, so each channel needs a buffer of one or a signal that
	// arrives while the loop is busy would be dropped.
	terminate := make(chan os.Signal, 1)
	signal.Notify(terminate, syscall.SIGINT, syscall.SIGTERM)

	// SIGUSR1 and SIGUSR2 should make Hologram enable and disable debug logging,
	// respectively.
	debugEnable := make(chan os.Signal, 1)
	debugDisable := make(chan os.Signal, 1)
	signal.Notify(debugEnable, syscall.SIGUSR1)
	signal.Notify(debugDisable, syscall.SIGUSR2)

	// SIGHUP should make Hologram server reload its cache of user information
	// from LDAP.
	reloadCacheSigHup := make(chan os.Signal, 1)
	signal.Notify(reloadCacheSigHup, syscall.SIGHUP)

	// Reload the cache based on time set in configuration
	cacheTimeoutTicker := time.NewTicker(time.Duration(config.CacheTimeout) * time.Second)
	defer cacheTimeoutTicker.Stop()

	log.Info("Hologram server is online, waiting for termination.")

WaitForTermination:
	for {
		select {
		case <-terminate:
			break WaitForTermination
		case <-debugEnable:
			log.Info("Enabling debug mode.")
			log.DebugMode(true)
		case <-debugDisable:
			log.Info("Disabling debug mode.")
			log.DebugMode(false)
		case <-reloadCacheSigHup:
			log.Info("Force-reloading user cache.")
			ldapCache.Update()
		case <-cacheTimeoutTicker.C:
			log.Info("Cache timeout. Reloading user cache.")
			ldapCache.Update()
		}
	}

	log.Info("Caught signal; shutting down now.")
	server.Close()
}
Example #9
0
// main runs the roshi walker. It parses flags, sets up statsd and Prometheus
// instrumentation, builds the Redis clusters and a farm over them, and then
// repeatedly scans the keyspace and walks it through the farm, rate-limited
// to -max.keys.per.second. With -once it stops after a single pass.
func main() {
	var (
		redisInstances          = flag.String("redis.instances", "", "Semicolon-separated list of comma-separated lists of Redis instances")
		redisConnectTimeout     = flag.Duration("redis.connect.timeout", 3*time.Second, "Redis connect timeout")
		redisReadTimeout        = flag.Duration("redis.read.timeout", 3*time.Second, "Redis read timeout")
		redisWriteTimeout       = flag.Duration("redis.write.timeout", 3*time.Second, "Redis write timeout")
		redisMCPI               = flag.Int("redis.mcpi", 2, "Max connections per Redis instance")
		redisHash               = flag.String("redis.hash", "murmur3", "Redis hash function: murmur3, fnv, fnva")
		selectGap               = flag.Duration("select.gap", 0*time.Millisecond, "delay between pipeline read invocations when Selecting over multiple keys")
		maxSize                 = flag.Int("max.size", 10000, "Maximum number of events per key")
		batchSize               = flag.Int("batch.size", 100, "keys to select per request")
		maxKeysPerSecond        = flag.Int64("max.keys.per.second", 1000, "max keys per second to walk")
		scanLogInterval         = flag.Duration("scan.log.interval", 5*time.Second, "how often to report scan rates in log")
		once                    = flag.Bool("once", false, "walk entire keyspace once and exit (default false, walk forever)")
		statsdAddress           = flag.String("statsd.address", "", "Statsd address (blank to disable)")
		statsdSampleRate        = flag.Float64("statsd.sample.rate", 0.1, "Statsd sample rate for normal metrics")
		statsdBucketPrefix      = flag.String("statsd.bucket.prefix", "myservice.", "Statsd bucket key prefix, including trailing period")
		prometheusNamespace     = flag.String("prometheus.namespace", "roshiwalker", "Prometheus key namespace, excluding trailing punctuation")
		prometheusMaxSummaryAge = flag.Duration("prometheus.max.summary.age", 10*time.Second, "Prometheus max age for instantaneous histogram data")
		httpAddress             = flag.String("http.address", ":6060", "HTTP listen address (profiling/metrics endpoints only)")
	)
	flag.Parse()
	log.SetOutput(os.Stdout)
	log.SetFlags(log.Lmicroseconds)

	// Validate integer arguments. The rate is used as a divisor below, so it
	// must be strictly positive (a zero batch size would otherwise let a zero
	// rate through the second check).
	if *maxKeysPerSecond <= 0 {
		log.Fatal("max keys per second should be positive")
	}
	if *maxKeysPerSecond < int64(*batchSize) {
		log.Fatal("max keys per second should be bigger than batch size")
	}

	// Set up instrumentation.
	statter := g2s.Noop()
	if *statsdAddress != "" {
		var err error
		statter, err = g2s.Dial("udp", *statsdAddress)
		if err != nil {
			log.Fatal(err)
		}
	}
	prometheusInstr := prometheus.New(*prometheusNamespace, *prometheusMaxSummaryAge)
	prometheusInstr.Install("/metrics", http.DefaultServeMux)
	instr := instrumentation.NewMultiInstrumentation(
		statsd.New(statter, float32(*statsdSampleRate), *statsdBucketPrefix),
		prometheusInstr,
	)

	// Parse hash function.
	var hashFunc func(string) uint32
	switch strings.ToLower(*redisHash) {
	case "murmur3":
		hashFunc = pool.Murmur3
	case "fnv":
		hashFunc = pool.FNV
	case "fnva":
		hashFunc = pool.FNVa
	default:
		log.Fatalf("unknown hash %q", *redisHash)
	}

	// Set up the clusters.
	clusters, err := farm.ParseFarmString(
		*redisInstances,
		*redisConnectTimeout, *redisReadTimeout, *redisWriteTimeout,
		*redisMCPI,
		hashFunc,
		*maxSize,
		*selectGap,
		instr,
	)
	if err != nil {
		log.Fatal(err)
	}

	// HTTP server for profiling.
	go func() { log.Print(http.ListenAndServe(*httpAddress, nil)) }()

	// Set up our rate limiter. Remember: it's per-key, not per-request.
	// The interval is one second divided by the rate; dividing the Duration
	// avoids the integer truncation of 1/rate, which is 0 for any rate > 1.
	var (
		freq   = time.Second / time.Duration(*maxKeysPerSecond)
		bucket = tb.NewBucket(*maxKeysPerSecond, freq)
	)

	// Build the farm.
	var (
		readStrategy   = farm.SendAllReadAll
		repairStrategy = farm.AllRepairs // blocking
		writeQuorum    = len(clusters)   // 100%
		dst            = farm.New(clusters, writeQuorum, readStrategy, repairStrategy, instr)
	)

	// Perform the walk. The deferred log only fires when the loop exits,
	// i.e. when -once is set.
	defer func(t time.Time) { log.Printf("total walk complete, %s", time.Since(t)) }(time.Now())
	for {
		src := scan(clusters, *batchSize, *scanLogInterval) // new key set
		walkOnce(dst, bucket, src, *maxSize, instr)
		if *once {
			break
		}
	}
}
Example #10
0
// TestServerStateMachine drives a server built with server.New over an
// in-memory pipe and checks its replies to three kinds of request: a ping, an
// AssumeRole request followed by a challenge response (for both a passing and
// a failing authenticator), and a request to add an SSH key.
//
// The server's HandleConnection runs in its own goroutine; the test writes
// protocol messages to the shared connection and reads the replies back, so
// the order of Write and Read calls below is significant.
func TestServerStateMachine(t *testing.T) {
	// This silly thing is needed for equality testing for the LDAP dummy.
	neededModifyRequest := ldap.NewModifyRequest("something")
	neededModifyRequest.Add("sshPublicKey", []string{"test"})

	Convey("Given a state machine setup with a null logger", t, func() {
		// Test doubles: an authenticator that yields a fixed user, and an
		// LDAP dummy that starts with no SSH keys. Note that the local name
		// ldap shadows the ldap package inside this closure.
		authenticator := &DummyAuthenticator{&server.User{Username: "******"}}
		ldap := &DummyLDAP{
			username: "******",
			password: "******",
			sshKeys:  []string{},
			req:      neededModifyRequest,
		}
		testServer := server.New(authenticator, &dummyCredentials{}, "default", g2s.Noop(), ldap, "cn", "dc=testdn,dc=com", false)
		r, w := io.Pipe()

		// Both ends of the pipe are wrapped in a single message connection
		// that the test and the server goroutine share.
		testConnection := protocol.NewMessageConnection(ReadWriter(r, w))
		go testServer.HandleConnection(testConnection)
		Convey("When a ping message comes in", func() {
			testPing := &protocol.Message{Ping: &protocol.Ping{}}
			testConnection.Write(testPing)
			Convey("Then the server should respond with a pong response.", func() {
				// The assertion only requires that the message read back
				// carries a non-nil Ping payload.
				recvMsg, recvErr := testConnection.Read()
				So(recvErr, ShouldBeNil)
				So(recvMsg.GetPing(), ShouldNotBeNil)
			})
		})

		Convey("After an AssumeRequest", func() {
			role := "testrole"

			msg := &protocol.Message{
				ServerRequest: &protocol.ServerRequest{
					AssumeRole: &protocol.AssumeRole{
						Role: &role,
					},
				},
			}

			testConnection.Write(msg)

			// msg is reused here: it now holds the reply, which both nested
			// cases below read the challenge from.
			msg, err := testConnection.Read()
			if err != nil {
				t.Fatal(err)
			}

			Convey("it should challenge, then send credentials on success", func() {
				challenge := msg.GetServerResponse().GetChallenge().GetChallenge()

				So(len(challenge), ShouldEqual, 64)

				// The signature bytes are arbitrary; the outcome is decided
				// by the dummy authenticator, not by real verification.
				format := "test"
				sig := []byte("ssss")

				challengeResponseMsg := &protocol.Message{
					ServerRequest: &protocol.ServerRequest{
						ChallengeResponse: &protocol.SSHChallengeResponse{
							Format:    &format,
							Signature: sig,
						},
					},
				}

				testConnection.Write(challengeResponseMsg)

				credsMsg, err := testConnection.Read()
				if err != nil {
					t.Fatal(err)
				}

				So(credsMsg, ShouldNotBeNil)
				So(credsMsg.GetServerResponse(), ShouldNotBeNil)
				So(credsMsg.GetServerResponse().GetCredentials(), ShouldNotBeNil)

				creds := credsMsg.GetServerResponse().GetCredentials()
				So(creds.GetAccessKeyId(), ShouldEqual, "access_key")
				So(creds.GetSecretAccessKey(), ShouldEqual, "secret")
				So(creds.GetAccessToken(), ShouldEqual, "token")
				So(creds.GetExpiration(), ShouldBeGreaterThanOrEqualTo, time.Now().Unix())
			})

			Convey("it should then send failure message on failed key verification", func() {
				// Clearing the dummy's user is what makes verification fail
				// in this case.
				authenticator.user = nil

				challenge := msg.GetServerResponse().GetChallenge().GetChallenge()

				So(len(challenge), ShouldEqual, 64)

				format := "test"
				sig := []byte("ssss")

				challengeResponseMsg := &protocol.Message{
					ServerRequest: &protocol.ServerRequest{
						ChallengeResponse: &protocol.SSHChallengeResponse{
							Format:    &format,
							Signature: sig,
						},
					},
				}

				testConnection.Write(challengeResponseMsg)

				credsMsg, err := testConnection.Read()
				if err != nil {
					t.Fatal(err)
				}

				So(credsMsg, ShouldNotBeNil)
				So(credsMsg.GetServerResponse(), ShouldNotBeNil)
				So(credsMsg.GetServerResponse().GetVerificationFailure(), ShouldNotBeNil)
			})
		})

		Convey("When a request to add an SSH key comes in", func() {
			user := "******"
			password := "******"
			sshKey := "test"
			testMessage := &protocol.Message{
				ServerRequest: &protocol.ServerRequest{
					AddSSHkey: &protocol.AddSSHKey{
						Username:     &user,
						Passwordhash: &password,
						Sshkeybytes:  &sshKey,
					},
				},
			}

			testConnection.Write(testMessage)
			Convey("If this request is valid", func() {
				msg, err := testConnection.Read()
				if err != nil {
					t.Fatal(err)
				}

				if msg.GetSuccess() == nil {
					t.Fail()
				}
				Convey("It should add the SSH key to the user.", func() {
					So(ldap.sshKeys[0], ShouldEqual, sshKey)
					Convey("If the user tries to add the same SSH key", func() {
						// The reply to this second request is not read; only
						// the dummy's stored key count is checked.
						testConnection.Write(testMessage)
						Convey("It should not insert the same key twice.", func() {
							So(len(ldap.sshKeys), ShouldEqual, 1)
						})
					})
				})
			})
		})
	})
}
Example #11
0
// main runs the roshi walker. It parses flags, sets up statsd
// instrumentation, builds the Redis clusters and a farm over them, and then
// repeatedly scans the keyspace and walks it through the farm, rate-limited
// to -max.keys.per.second. With -once it stops after a single pass and logs
// the elapsed time.
func main() {
	var (
		redisInstances      = flag.String("redis.instances", "", "Semicolon-separated list of comma-separated lists of Redis instances")
		redisConnectTimeout = flag.Duration("redis.connect.timeout", 3*time.Second, "Redis connect timeout")
		redisReadTimeout    = flag.Duration("redis.read.timeout", 3*time.Second, "Redis read timeout")
		redisWriteTimeout   = flag.Duration("redis.write.timeout", 3*time.Second, "Redis write timeout")
		redisMCPI           = flag.Int("redis.mcpi", 2, "Max connections per Redis instance")
		redisHash           = flag.String("redis.hash", "murmur3", "Redis hash function: murmur3, fnv, fnva")
		maxSize             = flag.Int("max.size", 10000, "Maximum number of events per key")
		batchSize           = flag.Int("batch.size", 100, "keys to select per request")
		maxKeysPerSecond    = flag.Int64("max.keys.per.second", 1000, "max keys per second to walk")
		scanLogInterval     = flag.Duration("scan.log.interval", 5*time.Second, "how often to report scan rates in log")
		once                = flag.Bool("once", false, "walk entire keyspace once and exit (default false, walk forever)")
		statsdAddress       = flag.String("statsd.address", "", "Statsd address (blank to disable)")
		statsdSampleRate    = flag.Float64("statsd.sample.rate", 0.1, "Statsd sample rate for normal metrics")
		statsdBucketPrefix  = flag.String("statsd.bucket.prefix", "myservice.", "Statsd bucket key prefix, including trailing period")
		httpAddress         = flag.String("http.address", ":6060", "HTTP listen address (profiling endpoints only)")
	)
	flag.Parse()
	log.SetFlags(log.Lmicroseconds)

	// Validate integer arguments. The rate is used as a divisor below, so it
	// must be strictly positive (a zero batch size would otherwise let a zero
	// rate through the second check).
	if *maxKeysPerSecond <= 0 {
		log.Fatal("max keys per second should be positive")
	}
	if *maxKeysPerSecond < int64(*batchSize) {
		log.Fatal("max keys per second should be bigger than batch size")
	}

	// Set up statsd instrumentation, if it's specified.
	stats := g2s.Noop()
	if *statsdAddress != "" {
		var err error
		stats, err = g2s.Dial("udp", *statsdAddress)
		if err != nil {
			log.Fatal(err)
		}
	}
	instr := statsd.New(stats, float32(*statsdSampleRate), *statsdBucketPrefix)

	// Parse hash function.
	var hashFunc func(string) uint32
	switch strings.ToLower(*redisHash) {
	case "murmur3":
		hashFunc = pool.Murmur3
	case "fnv":
		hashFunc = pool.FNV
	case "fnva":
		hashFunc = pool.FNVa
	default:
		log.Fatalf("unknown hash '%s'", *redisHash)
	}

	// Set up the clusters.
	clusters, err := makeClusters(
		*redisInstances,
		*redisConnectTimeout, *redisReadTimeout, *redisWriteTimeout,
		*redisMCPI,
		hashFunc,
		*maxSize,
		instr,
	)
	if err != nil {
		log.Fatal(err)
	}

	// HTTP server for profiling
	go func() { log.Print(http.ListenAndServe(*httpAddress, nil)) }()

	// Set up our rate limiter. Remember: it's per-key, not per-request.
	// The interval is one second divided by the rate; dividing the Duration
	// avoids the integer truncation of 1/rate, which is 0 for any rate > 1.
	freq := time.Second / time.Duration(*maxKeysPerSecond)
	bucket := tb.NewBucket(*maxKeysPerSecond, freq)

	// Build the farm. The write quorum is every cluster.
	readStrategy := farm.SendAllReadAll
	repairStrategy := farm.AllRepairs // blocking
	dst := farm.New(clusters, len(clusters), readStrategy, repairStrategy, instr)

	// Perform the walk
	begin := time.Now()
	for {
		src := scan(clusters, *batchSize, *scanLogInterval) // new key set
		walkOnce(dst, bucket, src, *maxSize, instr)
		if *once {
			break
		}
	}
	log.Printf("walk complete in %s", time.Since(begin))
}