예제 #1
0
// performHandshakeAndValidation completes the TLS handshake on conn and then
// validates the peer against the relay address in uri.
//
// The negotiated application protocol must be mutual and equal to
// protocol.ProtocolName. If uri carries a non-empty "id" query parameter, the
// peer must additionally present exactly one certificate, and the device ID
// computed from that certificate's raw bytes must equal the ID in the address.
// A nil return means every applicable check passed.
func performHandshakeAndValidation(conn *tls.Conn, uri *url.URL) error {
	if err := conn.Handshake(); err != nil {
		return err
	}

	state := conn.ConnectionState()
	negotiated := state.NegotiatedProtocolIsMutual && state.NegotiatedProtocol == protocol.ProtocolName
	if !negotiated {
		return fmt.Errorf("protocol negotiation error")
	}

	// Without an id in the address there is nothing to pin the peer against.
	wantStr := uri.Query().Get("id")
	if wantStr == "" {
		return nil
	}

	wantID, err := syncthingprotocol.DeviceIDFromString(wantStr)
	if err != nil {
		return fmt.Errorf("relay address contains invalid verification id: %s", err)
	}

	peerCerts := state.PeerCertificates
	if n := len(peerCerts); n != 1 {
		return fmt.Errorf("unexpected certificate count: %d", n)
	}

	if gotID := syncthingprotocol.NewDeviceID(peerCerts[0].Raw); gotID != wantID {
		return fmt.Errorf("relay id does not match. Expected %v got %v", wantID, gotID)
	}
	return nil
}
예제 #2
0
// TestManyPeers checks that two instances still synchronize when the
// receiving instance has a large number of devices configured.
//
// The receiver's configuration is fetched over the REST API, padded to 100
// devices with randomly generated IDs (each also added to the first folder)
// and posted back. The sender is then started, the test waits for the
// "default" folder to sync and finally compares directories s1 and s2.
func TestManyPeers(t *testing.T) {
	log.Println("Cleaning...")
	err := removeAll("s1", "s2", "h1/index*", "h2/index*")
	if err != nil {
		t.Fatal(err)
	}

	log.Println("Generating files...")
	err = generateFiles("s1", 200, 20, "../LICENSE")
	if err != nil {
		t.Fatal(err)
	}

	receiver := startInstance(t, 2)
	defer checkedStop(t, receiver)

	bs, err := receiver.Get("/rest/system/config")
	if err != nil {
		t.Fatal(err)
	}

	var cfg config.Configuration
	if err := json.Unmarshal(bs, &cfg); err != nil {
		t.Fatal(err)
	}

	// The loop below appends to cfg.Folders[0]; fail with a clear message
	// rather than an index-out-of-range panic if there is no folder.
	if len(cfg.Folders) == 0 {
		t.Fatal("receiver configuration contains no folders")
	}

	for len(cfg.Devices) < 100 {
		bs := make([]byte, 16)
		ReadRand(bs)
		id := protocol.NewDeviceID(bs)
		cfg.Devices = append(cfg.Devices, config.DeviceConfiguration{DeviceID: id})
		cfg.Folders[0].Devices = append(cfg.Folders[0].Devices, config.FolderDeviceConfiguration{DeviceID: id})
	}

	// Move the pristine config aside so it can be restored once the test is
	// done. If the backup cannot be made, stop before the config is changed.
	if err := osutil.Rename("h2/config.xml", "h2/config.xml.orig"); err != nil {
		t.Fatal(err)
	}
	defer func() {
		if err := osutil.Rename("h2/config.xml.orig", "h2/config.xml"); err != nil {
			log.Println("Restoring h2/config.xml:", err)
		}
	}()

	var buf bytes.Buffer
	if err := json.NewEncoder(&buf).Encode(cfg); err != nil {
		t.Fatal(err)
	}
	_, err = receiver.Post("/rest/system/config", &buf)
	if err != nil {
		t.Fatal(err)
	}

	sender := startInstance(t, 1)
	defer checkedStop(t, sender)

	rc.AwaitSync("default", sender, receiver)

	log.Println("Comparing directories...")
	err = compareDirectories("s1", "s2")
	if err != nil {
		t.Fatal(err)
	}
}
예제 #3
0
// check validates the TLS state attached to resp against the device ID this
// client expects. It returns an error when the response carries no TLS state,
// when the peer presented no certificates, or when the device ID computed
// from the first peer certificate's raw bytes does not equal c.id.
func (c *idCheckingHTTPClient) check(resp *http.Response) error {
	state := resp.TLS
	switch {
	case state == nil:
		return errors.New("security: not TLS")
	case len(state.PeerCertificates) == 0:
		return errors.New("security: no certificates")
	}

	first := state.PeerCertificates[0]
	if remote := protocol.NewDeviceID(first.Raw); !remote.Equals(c.id) {
		return errors.New("security: incorrect device id")
	}
	return nil
}
예제 #4
0
// TestRelay reports whether an invitation can be obtained from the relay at
// uri for the device ID derived from the first certificate in certs.
//
// It starts a protocol client serving in the background, then calls
// GetInvitationFromRelay up to times times, sleeping for sleep between
// attempts. It returns true on the first successful attempt, and false when
// certs holds no certificate data, when an attempt fails with an error whose
// text does not contain "Incorrect response code", or when all attempts are
// used up.
func TestRelay(uri *url.URL, certs []tls.Certificate, sleep time.Duration, times int) bool {
	// The device ID comes from certs[0].Certificate[0]; without it there is
	// nothing to test with, so report failure instead of panicking.
	if len(certs) == 0 || len(certs[0].Certificate) == 0 {
		return false
	}

	id := syncthingprotocol.NewDeviceID(certs[0].Certificate[0])
	invs := make(chan protocol.SessionInvitation, 1)
	c := NewProtocolClient(uri, certs, invs)
	go c.Serve()
	defer func() {
		// Stop the client before closing the channel it was handed, so the
		// client is not left running with a closed invitation channel (a
		// send on a closed channel panics).
		c.Stop()
		close(invs)
	}()

	for i := 0; i < times; i++ {
		_, err := GetInvitationFromRelay(uri, id, certs)
		if err == nil {
			return true
		}
		// Only the "Incorrect response code" failure is retried.
		if !strings.Contains(err.Error(), "Incorrect response code") {
			return false
		}
		time.Sleep(sleep)
	}
	return false
}
예제 #5
0
// TestManyPeers checks that two instances still synchronize when the
// receiving instance has a large number of devices configured.
//
// The receiver (instance 2) is started, its configuration is fetched over the
// REST API, padded to 100 devices with randomly generated IDs (each also
// added to the first folder) and posted back. The sender (instance 1) is then
// started and polled until it reports 100% completion for id2, after which
// directories s1 and s2 are compared.
func TestManyPeers(t *testing.T) {
	log.Println("Cleaning...")
	err := removeAll("s1", "s2", "h1/index", "h2/index")
	if err != nil {
		t.Fatal(err)
	}

	log.Println("Generating files...")
	err = generateFiles("s1", 200, 20, "../LICENSE")
	if err != nil {
		t.Fatal(err)
	}

	receiver := syncthingProcess{ // id2
		instance: "2",
		argv:     []string{"-home", "h2"},
		port:     8082,
		apiKey:   apiKey,
	}
	err = receiver.start()
	if err != nil {
		t.Fatal(err)
	}
	defer receiver.stop()

	resp, err := receiver.get("/rest/config")
	if err != nil {
		t.Fatal(err)
	}
	if resp.StatusCode != 200 {
		t.Fatalf("Code %d != 200", resp.StatusCode)
	}

	var cfg config.Configuration
	err = json.NewDecoder(resp.Body).Decode(&cfg)
	resp.Body.Close()
	if err != nil {
		t.Fatal(err)
	}

	// The loop below appends to cfg.Folders[0]; fail with a clear message
	// rather than an index-out-of-range panic if there is no folder.
	if len(cfg.Folders) == 0 {
		t.Fatal("receiver configuration contains no folders")
	}

	for len(cfg.Devices) < 100 {
		bs := make([]byte, 16)
		ReadRand(bs)
		id := protocol.NewDeviceID(bs)
		cfg.Devices = append(cfg.Devices, config.DeviceConfiguration{DeviceID: id})
		cfg.Folders[0].Devices = append(cfg.Folders[0].Devices, config.FolderDeviceConfiguration{DeviceID: id})
	}

	// Move the pristine config aside so it can be restored once the test is
	// done. If the backup cannot be made, stop before the config is changed.
	if err := osutil.Rename("h2/config.xml", "h2/config.xml.orig"); err != nil {
		t.Fatal(err)
	}
	defer func() {
		if err := osutil.Rename("h2/config.xml.orig", "h2/config.xml"); err != nil {
			log.Println("Restoring h2/config.xml:", err)
		}
	}()

	var buf bytes.Buffer
	if err := json.NewEncoder(&buf).Encode(cfg); err != nil {
		t.Fatal(err)
	}
	resp, err = receiver.post("/rest/config", &buf)
	if err != nil {
		t.Fatal(err)
	}
	if resp.StatusCode != 200 {
		t.Fatalf("Code %d != 200", resp.StatusCode)
	}
	resp.Body.Close()

	log.Println("Starting up...")
	sender := syncthingProcess{ // id1
		instance: "1",
		argv:     []string{"-home", "h1"},
		port:     8081,
		apiKey:   apiKey,
	}
	err = sender.start()
	if err != nil {
		t.Fatal(err)
	}
	defer sender.stop()

	// Poll until the sender reports the receiver as fully in sync. Timeouts
	// are retried quickly; any other error fails the test.
	for {
		comp, err := sender.peerCompletion()
		if err != nil {
			if isTimeout(err) {
				time.Sleep(250 * time.Millisecond)
				continue
			}
			t.Fatal(err)
		}
		if comp[id2] == 100 {
			// Leave the loop (rather than the test) so that the directory
			// comparison below actually runs.
			break
		}
		time.Sleep(2 * time.Second)
	}

	log.Println("Comparing directories...")
	err = compareDirectories("s1", "s2")
	if err != nil {
		t.Fatal(err)
	}
}
예제 #6
0
파일: main.go 프로젝트: beride/syncthing
// syncthingMain runs a complete Syncthing instance. It does not return
// normally: it blocks until a code arrives on the stop channel and then calls
// os.Exit with that code.
//
// In order, it: starts the main service supervisor; loads (or generates) the
// TLS certificate and derives myID from it; loads (or creates) the
// configuration; opens the database and drops data for folders that are no
// longer configured; builds the model and starts every configured folder;
// starts the GUI, discovery, optional UPnP and the connection service; and
// finally sets up usage reporting, wakeup monitoring and automatic upgrades.
func syncthingMain() {
	// Create a main service manager. We'll add things to this as we go along.
	// We want any logging it does to go through our log system.
	mainSvc := suture.New("main", suture.Spec{
		Log: func(line string) {
			if debugSuture {
				l.Debugln(line)
			}
		},
	})
	mainSvc.ServeBackground()

	// Set a log prefix similar to the ID we will have later on, or early log
	// lines look ugly.
	l.SetPrefix("[start] ")

	if auditEnabled {
		startAuditing(mainSvc)
	}

	if verbose {
		mainSvc.Add(newVerboseSvc())
	}

	// Event subscription for the API; must start early to catch the early events.
	apiSub := events.NewBufferedSubscription(events.Default.Subscribe(events.AllEvents), 1000)

	if len(os.Getenv("GOMAXPROCS")) == 0 {
		runtime.GOMAXPROCS(runtime.NumCPU())
	}

	// Ensure that we have a certificate and key.
	cert, err := tls.LoadX509KeyPair(locations[locCertFile], locations[locKeyFile])
	if err != nil {
		cert, err = newCertificate(locations[locCertFile], locations[locKeyFile], tlsDefaultCommonName)
		if err != nil {
			l.Fatalln("load cert:", err)
		}
	}

	// We reinitialize the predictable RNG with our device ID, to get a
	// sequence that is always the same but unique to this syncthing instance.
	predictableRandom.Seed(seedFromBytes(cert.Certificate[0]))

	myID = protocol.NewDeviceID(cert.Certificate[0])
	l.SetPrefix(fmt.Sprintf("[%s] ", myID.String()[:5]))

	l.Infoln(LongVersion)
	l.Infoln("My ID:", myID)

	// Emit the Starting event, now that we know who we are.

	events.Default.Log(events.Starting, map[string]string{
		"home": baseDirs["config"],
		"myID": myID.String(),
	})

	// Prepare to be able to save configuration

	cfgFile := locations[locConfigFile]

	var myName string

	// Load the configuration file, if it exists.
	// If it does not, create a template.

	if info, err := os.Stat(cfgFile); err == nil {
		if !info.Mode().IsRegular() {
			l.Fatalln("Config file is not a file?")
		}
		cfg, err = config.Load(cfgFile, myID)
		if err == nil {
			myCfg := cfg.Devices()[myID]
			if myCfg.Name == "" {
				myName, _ = os.Hostname()
			} else {
				myName = myCfg.Name
			}
		} else {
			l.Fatalln("Configuration:", err)
		}
	} else {
		l.Infoln("No config file; starting with empty defaults")
		myName, _ = os.Hostname()
		newCfg := defaultConfig(myName)
		cfg = config.Wrap(cfgFile, newCfg)
		if err := cfg.Save(); err != nil {
			l.Warnln("Failed to save config", err)
		}
		l.Infof("Edit %s to taste or use the GUI\n", cfgFile)
	}

	if cfg.Raw().OriginalVersion != config.CurrentVersion {
		l.Infoln("Archiving a copy of old config file format")
		// Archive a copy
		osutil.Rename(cfgFile, cfgFile+fmt.Sprintf(".v%d", cfg.Raw().OriginalVersion))
		// Save the new version
		if err := cfg.Save(); err != nil {
			l.Warnln("Failed to save config", err)
		}
	}

	if err := checkShortIDs(cfg); err != nil {
		l.Fatalln("Short device IDs are in conflict. Unlucky!\n  Regenerate the device ID of one of the following:\n  ", err)
	}

	if len(profiler) > 0 {
		go func() {
			l.Debugln("Starting profiler on", profiler)
			runtime.SetBlockProfileRate(1)
			err := http.ListenAndServe(profiler, nil)
			if err != nil {
				l.Fatalln(err)
			}
		}()
	}

	// The TLS configuration is used for both the listening socket and outgoing
	// connections.

	tlsCfg := &tls.Config{
		Certificates:           []tls.Certificate{cert},
		NextProtos:             []string{bepProtocolName},
		ClientAuth:             tls.RequestClientCert,
		SessionTicketsDisabled: true,
		InsecureSkipVerify:     true,
		MinVersion:             tls.VersionTLS12,
		CipherSuites: []uint16{
			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
			tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
			tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
			tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
			tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
			tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
		},
	}

	// If the read or write rate should be limited, set up a rate limiter for it.
	// This will be used on connections created in the connect and listen routines.

	opts := cfg.Options()

	if !opts.SymlinksEnabled {
		symlinks.Supported = false
	}

	protocol.PingTimeout = time.Duration(opts.PingTimeoutS) * time.Second
	protocol.PingIdleTime = time.Duration(opts.PingIdleTimeS) * time.Second

	if opts.MaxSendKbps > 0 {
		writeRateLimit = ratelimit.NewBucketWithRate(float64(1000*opts.MaxSendKbps), int64(5*1000*opts.MaxSendKbps))
	}
	if opts.MaxRecvKbps > 0 {
		readRateLimit = ratelimit.NewBucketWithRate(float64(1000*opts.MaxRecvKbps), int64(5*1000*opts.MaxRecvKbps))
	}

	if (opts.MaxRecvKbps > 0 || opts.MaxSendKbps > 0) && !opts.LimitBandwidthInLan {
		lans, _ = osutil.GetLans()
		networks := make([]string, 0, len(lans))
		for _, lan := range lans {
			networks = append(networks, lan.String())
		}
		l.Infoln("Local networks:", strings.Join(networks, ", "))
	}

	dbFile := locations[locDatabase]
	ldb, err := leveldb.OpenFile(dbFile, dbOpts())
	if err != nil && errors.IsCorrupted(err) {
		ldb, err = leveldb.RecoverFile(dbFile, dbOpts())
	}
	if err != nil {
		l.Fatalln("Cannot open database:", err, "- Is another copy of Syncthing already running?")
	}

	// Remove database entries for folders that no longer exist in the config
	folders := cfg.Folders()
	for _, folder := range db.ListFolders(ldb) {
		if _, ok := folders[folder]; !ok {
			l.Infof("Cleaning data for dropped folder %q", folder)
			db.DropFolder(ldb, folder)
		}
	}

	m := model.NewModel(cfg, myID, myName, "syncthing", Version, ldb)
	cfg.Subscribe(m)

	if t := os.Getenv("STDEADLOCKTIMEOUT"); len(t) > 0 {
		it, err := strconv.Atoi(t)
		if err == nil {
			m.StartDeadlockDetector(time.Duration(it) * time.Second)
		}
	} else if !IsRelease || IsBeta {
		m.StartDeadlockDetector(20 * 60 * time.Second)
	}

	// Clear out old indexes for other devices. Otherwise we'll start up and
	// start needing a bunch of files which are nowhere to be found. This
	// needs to be changed when we correctly do persistent indexes.
	for _, folderCfg := range cfg.Folders() {
		m.AddFolder(folderCfg)
		for _, device := range folderCfg.DeviceIDs() {
			if device == myID {
				continue
			}
			m.Index(device, folderCfg.ID, nil, 0, nil)
		}
		// Routine to pull blocks from other devices to synchronize the local
		// folder. Does not run when we are in read only (publish only) mode.
		if folderCfg.ReadOnly {
			l.Okf("Ready to synchronize %s (read only; no external updates accepted)", folderCfg.ID)
			m.StartFolderRO(folderCfg.ID)
		} else {
			l.Okf("Ready to synchronize %s (read-write)", folderCfg.ID)
			m.StartFolderRW(folderCfg.ID)
		}
	}

	mainSvc.Add(m)

	// GUI

	setupGUI(mainSvc, cfg, m, apiSub)

	// The default port we announce, possibly modified by setupUPnP next.
	// Indexing an empty ListenAddress would panic, so report it as a
	// configuration problem instead.

	if len(opts.ListenAddress) == 0 {
		l.Fatalln("Bad listen address: no listen address configured")
	}
	addr, err := net.ResolveTCPAddr("tcp", opts.ListenAddress[0])
	if err != nil {
		l.Fatalln("Bad listen address:", err)
	}

	// Start discovery

	localPort := addr.Port
	discoverer = discovery(localPort)

	// Start UPnP. The UPnP service will restart global discovery if the
	// external port changes.

	if opts.UPnPEnabled {
		upnpSvc := newUPnPSvc(cfg, localPort)
		mainSvc.Add(upnpSvc)
	}

	connectionSvc := newConnectionSvc(cfg, myID, m, tlsCfg)
	cfg.Subscribe(connectionSvc)
	mainSvc.Add(connectionSvc)

	if cpuProfile {
		f, err := os.Create(fmt.Sprintf("cpu-%d.pprof", os.Getpid()))
		if err != nil {
			log.Fatal(err)
		}
		if err := pprof.StartCPUProfile(f); err != nil {
			log.Fatal(err)
		}
		// The profile is stopped explicitly just before os.Exit below; a
		// deferred call would never run because os.Exit skips defers.
	}

	for _, device := range cfg.Devices() {
		if len(device.Name) > 0 {
			l.Infof("Device %s is %q at %v", device.DeviceID, device.Name, device.Addresses)
		}
	}

	if opts.URAccepted > 0 && opts.URAccepted < usageReportVersion {
		l.Infoln("Anonymous usage report has changed; revoking acceptance")
		opts.URAccepted = 0
		opts.URUniqueID = ""
		cfg.SetOptions(opts)
	}
	if opts.URAccepted >= usageReportVersion {
		if opts.URUniqueID == "" {
			// Previously the ID was generated from the node ID. We now need
			// to generate a new one.
			opts.URUniqueID = randomString(8)
			cfg.SetOptions(opts)
			if err := cfg.Save(); err != nil {
				l.Warnln("Failed to save config", err)
			}
		}
	}

	// The usageReportingManager registers itself to listen to configuration
	// changes, and there's nothing more we need to tell it from the outside.
	// Hence we don't keep the returned pointer.
	newUsageReportingManager(m, cfg)

	if opts.RestartOnWakeup {
		go standbyMonitor()
	}

	if opts.AutoUpgradeIntervalH > 0 {
		if noUpgrade {
			l.Infof("No automatic upgrades; STNOUPGRADE environment variable defined.")
		} else if IsRelease {
			go autoUpgrade()
		} else {
			l.Infof("No automatic upgrades; %s is not a release version.", Version)
		}
	}

	events.Default.Log(events.StartupComplete, map[string]string{
		"myID": myID.String(),
	})
	go generatePingEvents()

	cleanConfigDirectory()

	code := <-stop

	mainSvc.Stop()

	l.Okln("Exiting")

	// Flush the CPU profile before exiting; os.Exit does not run deferred
	// functions.
	if cpuProfile {
		pprof.StopCPUProfile()
	}

	os.Exit(code)
}
예제 #7
0
파일: main.go 프로젝트: beride/syncthing
// main parses the command line and dispatches to the requested mode.
//
// One-shot modes are handled here and return (or exit) without starting the
// service: -version, -generate, -upgrade-to, -upgrade / -upgrade-check and
// -reset. Otherwise Syncthing is started, either directly via syncthingMain
// (with -no-restart) or via monitorMain.
func main() {
	if runtime.GOOS == "windows" {
		// On Windows, we use a log file by default. Setting the -logfile flag
		// to "-" disables this behavior.

		flag.StringVar(&logFile, "logfile", "", "Log file name (use \"-\" for stdout)")

		// We also add an option to hide the console window
		flag.BoolVar(&noConsole, "no-console", false, "Hide console window")
	}

	flag.StringVar(&generateDir, "generate", "", "Generate key and config in specified dir, then exit")
	flag.StringVar(&guiAddress, "gui-address", guiAddress, "Override GUI address")
	flag.StringVar(&guiAuthentication, "gui-authentication", guiAuthentication, "Override GUI authentication; username:password")
	flag.StringVar(&guiAPIKey, "gui-apikey", guiAPIKey, "Override GUI API key")
	flag.StringVar(&confDir, "home", "", "Set configuration directory")
	flag.IntVar(&logFlags, "logflags", logFlags, "Select information in log line prefix")
	flag.BoolVar(&noBrowser, "no-browser", false, "Do not start browser")
	flag.BoolVar(&noRestart, "no-restart", noRestart, "Do not restart; just exit")
	flag.BoolVar(&reset, "reset", false, "Reset the database")
	flag.BoolVar(&doUpgrade, "upgrade", false, "Perform upgrade")
	flag.BoolVar(&doUpgradeCheck, "upgrade-check", false, "Check for available upgrade")
	flag.BoolVar(&showVersion, "version", false, "Show version")
	flag.StringVar(&upgradeTo, "upgrade-to", upgradeTo, "Force upgrade directly from specified URL")
	flag.BoolVar(&auditEnabled, "audit", false, "Write events to audit file")
	flag.BoolVar(&verbose, "verbose", false, "Print verbose log output")

	flag.Usage = usageFor(flag.CommandLine, usage, fmt.Sprintf(extraUsage, baseDirs["config"]))
	flag.Parse()

	if noConsole {
		osutil.HideConsole()
	}

	if confDir != "" {
		// Not set as default above because the string can be really long.
		baseDirs["config"] = confDir
	}

	if err := expandLocations(); err != nil {
		l.Fatalln(err)
	}

	if guiAssets == "" {
		guiAssets = locations[locGUIAssets]
	}

	if runtime.GOOS == "windows" {
		if logFile == "" {
			// Use the default log file location
			logFile = locations[locLogFile]
		} else if logFile == "-" {
			// Don't use a logFile
			logFile = ""
		}
	}

	if showVersion {
		fmt.Println(LongVersion)
		return
	}

	l.SetFlags(logFlags)

	if generateDir != "" {
		dir, err := osutil.ExpandTilde(generateDir)
		if err != nil {
			l.Fatalln("generate:", err)
		}

		info, err := os.Stat(dir)
		if err == nil && !info.IsDir() {
			l.Fatalln(dir, "is not a directory")
		}
		if err != nil && os.IsNotExist(err) {
			err = osutil.MkdirAll(dir, 0700)
			if err != nil {
				l.Fatalln("generate:", err)
			}
		}

		certFile, keyFile := filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem")
		cert, err := tls.LoadX509KeyPair(certFile, keyFile)
		if err == nil {
			l.Warnln("Key exists; will not overwrite.")
			l.Infoln("Device ID:", protocol.NewDeviceID(cert.Certificate[0]))
		} else {
			cert, err = newCertificate(certFile, keyFile, tlsDefaultCommonName)
			// Check the error before touching cert: on failure the
			// certificate chain may be empty and indexing it would panic.
			if err != nil {
				l.Fatalln("load cert:", err)
			}
			myID = protocol.NewDeviceID(cert.Certificate[0])
			l.Infoln("Device ID:", myID)
		}

		cfgFile := filepath.Join(dir, "config.xml")
		if _, err := os.Stat(cfgFile); err == nil {
			l.Warnln("Config exists; will not overwrite.")
			return
		}
		var myName, _ = os.Hostname()
		var newCfg = defaultConfig(myName)
		var cfg = config.Wrap(cfgFile, newCfg)
		err = cfg.Save()
		if err != nil {
			l.Warnln("Failed to save config", err)
		}

		return
	}

	if info, err := os.Stat(baseDirs["config"]); err == nil && !info.IsDir() {
		l.Fatalln("Config directory", baseDirs["config"], "is not a directory")
	}

	// Ensure that our home directory exists.
	ensureDir(baseDirs["config"], 0700)

	if upgradeTo != "" {
		err := upgrade.ToURL(upgradeTo)
		if err != nil {
			l.Fatalln("Upgrade:", err) // exits 1
		}
		l.Okln("Upgraded from", upgradeTo)
		return
	}

	if doUpgrade || doUpgradeCheck {
		rel, err := upgrade.LatestRelease(Version)
		if err != nil {
			l.Fatalln("Upgrade:", err) // exits 1
		}

		if upgrade.CompareVersions(rel.Tag, Version) <= 0 {
			l.Infof("No upgrade available (current %q >= latest %q).", Version, rel.Tag)
			os.Exit(exitNoUpgradeAvailable)
		}

		l.Infof("Upgrade available (current %q < latest %q)", Version, rel.Tag)

		if doUpgrade {
			// Use leveldb database locks to protect against concurrent upgrades
			_, err = leveldb.OpenFile(locations[locDatabase], &opt.Options{OpenFilesCacheCapacity: 100})
			if err != nil {
				// The database could not be opened; fall back to asking an
				// already running instance to upgrade itself.
				l.Infoln("Attempting upgrade through running Syncthing...")
				err = upgradeViaRest()
				if err != nil {
					l.Fatalln("Upgrade:", err)
				}
				l.Okln("Syncthing upgrading")
				return
			}

			err = upgrade.To(rel)
			if err != nil {
				l.Fatalln("Upgrade:", err) // exits 1
			}
			l.Okf("Upgraded to %q", rel.Tag)
		}

		return
	}

	if reset {
		resetDB()
		return
	}

	if noRestart {
		syncthingMain()
	} else {
		monitorMain()
	}
}
예제 #8
0
// handle consumes established TLS connections from s.conns and either hands
// them to the model or closes them.
//
// A connection is closed when the peer did not present exactly one
// certificate, when the peer's device ID is our own, when the model is
// already connected to that device, when the device is not in the
// configuration, or when the certificate name does not verify. Otherwise the
// connection is (optionally) wrapped in rate limiters, turned into a protocol
// connection and added to the model. The method returns when s.conns is
// closed.
func (s *connectionSvc) handle() {
	for conn := range s.conns {
		cs := conn.ConnectionState()

		// We should have negotiated the next level protocol "bep/1.0" as part
		// of the TLS handshake. Unfortunately this can't be a hard error,
		// because there are implementations out there that don't support
		// protocol negotiation (iOS for one...).
		if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != bepProtocolName {
			l.Infof("Peer %s did not negotiate bep/1.0", conn.RemoteAddr())
		}

		// We should have received exactly one certificate from the other
		// side. If we didn't, they don't have a device ID and we drop the
		// connection.
		certs := cs.PeerCertificates
		if cl := len(certs); cl != 1 {
			l.Infof("Got peer certificate list of length %d != 1 from %s; protocol error", cl, conn.RemoteAddr())
			conn.Close()
			continue
		}
		remoteCert := certs[0]
		remoteID := protocol.NewDeviceID(remoteCert.Raw)

		// The device ID should not be that of ourselves. It can happen
		// though, especially in the presence of NAT hairpinning, multiple
		// clients between the same NAT gateway, and global discovery.
		if remoteID == myID {
			l.Infof("Connected to myself (%s) - should not happen", remoteID)
			conn.Close()
			continue
		}

		// We should not already be connected to the other party. TODO: This
		// could use some better handling. If the old connection is dead but
		// hasn't timed out yet we may want to drop *that* connection and keep
		// this one. But in case we are two devices connecting to each other
		// in parallel we don't want to do that or we end up with no
		// connections still established...
		if s.model.ConnectedTo(remoteID) {
			l.Infof("Connected to already connected device (%s)", remoteID)
			conn.Close()
			continue
		}

		// Devices() is keyed by device ID, so look the peer up directly
		// instead of scanning every configured device.
		deviceCfg, known := s.cfg.Devices()[remoteID]
		if !known {
			if !s.cfg.IgnoredDevice(remoteID) {
				events.Default.Log(events.DeviceRejected, map[string]string{
					"device":  remoteID.String(),
					"address": conn.RemoteAddr().String(),
				})
				l.Infof("Connection from %s with unknown device ID %s", conn.RemoteAddr(), remoteID)
			} else {
				l.Infof("Connection from %s with ignored device ID %s", conn.RemoteAddr(), remoteID)
			}

			conn.Close()
			continue
		}

		// Verify the name on the certificate. By default we set it to
		// "syncthing" when generating, but the user may have replaced
		// the certificate and used another name.
		certName := deviceCfg.CertName
		if certName == "" {
			certName = tlsDefaultCommonName
		}
		if err := remoteCert.VerifyHostname(certName); err != nil {
			// Incorrect certificate name is something the user most
			// likely wants to know about, since it's an advanced
			// config. Warn instead of Info.
			l.Warnf("Bad certificate from %s (%v): %v", remoteID, conn.RemoteAddr(), err)
			conn.Close()
			continue
		}

		// If rate limiting is set, and based on the address we should
		// limit the connection, then we wrap it in a limiter.

		limit := s.shouldLimit(conn.RemoteAddr())

		wr := io.Writer(conn)
		if limit && writeRateLimit != nil {
			wr = &limitedWriter{conn, writeRateLimit}
		}

		rd := io.Reader(conn)
		if limit && readRateLimit != nil {
			rd = &limitedReader{conn, readRateLimit}
		}

		name := fmt.Sprintf("%s-%s", conn.LocalAddr(), conn.RemoteAddr())
		protoConn := protocol.NewConnection(remoteID, rd, wr, s.model, name, deviceCfg.Compression)

		l.Infof("Established secure connection to %s at %s", remoteID, name)
		if debugNet {
			l.Debugf("cipher suite: %04X in lan: %t", cs.CipherSuite, !limit)
		}

		s.model.AddConnection(conn, protoConn)
	}
}
예제 #9
0
파일: main.go 프로젝트: rumpelsepp/relaysrv
// main is the entry point of the relay server. It parses the command line,
// loads the TLS key pair from the keys directory, sets up the optional rate
// limiters and status service, starts a pool handler for every configured
// pool address and then hands control to the listener.
func main() {
	var (
		keyDir     string
		extAddress string
	)

	flag.StringVar(&listen, "listen", ":22067", "Protocol listen address")
	flag.StringVar(&keyDir, "keys", ".", "Directory where cert.pem and key.pem is stored")
	flag.DurationVar(&networkTimeout, "network-timeout", 2*time.Minute, "Timeout for network operations")
	flag.DurationVar(&pingInterval, "ping-interval", time.Minute, "How often pings are sent")
	flag.DurationVar(&messageTimeout, "message-timeout", time.Minute, "Maximum amount of time we wait for relevant messages to arrive")
	flag.IntVar(&sessionLimitBps, "per-session-rate", sessionLimitBps, "Per session rate limit, in bytes/s")
	flag.IntVar(&globalLimitBps, "global-rate", globalLimitBps, "Global rate limit, in bytes/s")
	flag.BoolVar(&debug, "debug", false, "Enable debug output")
	flag.StringVar(&statusAddr, "status-srv", ":22070", "Listen address for status service (blank to disable)")
	flag.StringVar(&poolAddrs, "pools", defaultPoolAddrs, "Comma separated list of relay pool addresses to join")

	flag.Parse()

	// No flag populates extAddress here, so it falls back to the listen
	// address.
	if extAddress == "" {
		extAddress = listen
	}

	tcpAddr, err := net.ResolveTCPAddr("tcp", extAddress)
	if err != nil {
		log.Fatal(err)
	}

	sessionAddress = tcpAddr.IP[:]
	sessionPort = uint16(tcpAddr.Port)

	certPath := filepath.Join(keyDir, "cert.pem")
	keyPath := filepath.Join(keyDir, "key.pem")
	cert, err := tls.LoadX509KeyPair(certPath, keyPath)
	if err != nil {
		log.Fatalln("Failed to load X509 key pair:", err)
	}

	suites := []uint16{
		tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
		tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
		tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
		tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
		tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
		tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
	}
	tlsCfg := &tls.Config{
		Certificates:           []tls.Certificate{cert},
		NextProtos:             []string{protocol.ProtocolName},
		ClientAuth:             tls.RequestClientCert,
		SessionTicketsDisabled: true,
		InsecureSkipVerify:     true,
		MinVersion:             tls.VersionTLS12,
		CipherSuites:           suites,
	}

	id := syncthingprotocol.NewDeviceID(cert.Certificate[0])
	if debug {
		log.Println("ID:", id)
	}

	// A limit of zero or less leaves the corresponding limiter unset.
	if sessionLimitBps > 0 {
		sessionLimiter = ratelimit.NewBucketWithRate(float64(sessionLimitBps), int64(2*sessionLimitBps))
	}
	if globalLimitBps > 0 {
		globalLimiter = ratelimit.NewBucketWithRate(float64(globalLimitBps), int64(2*globalLimitBps))
	}

	if statusAddr != "" {
		go statusService(statusAddr)
	}

	uri, err := url.Parse(fmt.Sprintf("relay://%s/?id=%s", extAddress, id))
	if err != nil {
		log.Fatalln("Failed to construct URI", err)
	}

	if poolAddrs == defaultPoolAddrs {
		banner := []string{
			"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!",
			"!!  Joining default relay pools, this relay will be available for public use. !!",
			`!!      Use the -pools="" command line option to make the relay private.      !!`,
			"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!",
		}
		for _, line := range banner {
			log.Println(line)
		}
	}

	// One pool handler per non-blank entry in the comma separated list.
	for _, pool := range strings.Split(poolAddrs, ",") {
		if pool = strings.TrimSpace(pool); pool != "" {
			go poolHandler(pool, uri)
		}
	}

	listener(listen, tlsCfg)
}
예제 #10
0
파일: main.go 프로젝트: qbit/syncthing
// listenConnect starts a TLS listener per configured listen address and the
// outgoing dialer, then processes every connection they deliver.
//
// A connection is closed when the peer did not present exactly one
// certificate, when the peer's device ID equals myID, when m is already
// connected to that device, when the certificate name does not verify, or
// when the device is not present in the configuration. Otherwise the
// connection is wrapped in the configured rate limiters, a DeviceConnected
// event is logged and the connection is added to m.
func listenConnect(myID protocol.DeviceID, m *model.Model, tlsCfg *tls.Config) {
	incoming := make(chan *tls.Conn)

	// Listen
	for _, listenAddr := range cfg.Options().ListenAddress {
		go listenTLS(incoming, listenAddr, tlsCfg)
	}

	// Connect
	go dialTLS(m, incoming, tlsCfg)

nextConn:
	for conn := range incoming {
		peerCerts := conn.ConnectionState().PeerCertificates
		if n := len(peerCerts); n != 1 {
			l.Infof("Got peer certificate list of length %d != 1 from %s; protocol error", n, conn.RemoteAddr())
			conn.Close()
			continue
		}
		peerCert := peerCerts[0]
		remoteID := protocol.NewDeviceID(peerCert.Raw)

		// Reject connections to ourselves and duplicates of a connection we
		// already have.
		switch {
		case remoteID == myID:
			l.Infof("Connected to myself (%s) - should not happen", remoteID)
			conn.Close()
			continue
		case m.ConnectedTo(remoteID):
			l.Infof("Connected to already connected device (%s)", remoteID)
			conn.Close()
			continue
		}

		for deviceID, deviceCfg := range cfg.Devices() {
			if deviceID != remoteID {
				continue
			}

			// Verify the name on the certificate. By default we set it to
			// "syncthing" when generating, but the user may have replaced
			// the certificate and used another name.
			expectedName := deviceCfg.CertName
			if expectedName == "" {
				expectedName = tlsDefaultCommonName
			}
			if err := peerCert.VerifyHostname(expectedName); err != nil {
				// Incorrect certificate name is something the user most
				// likely wants to know about, since it's an advanced
				// config. Warn instead of Info.
				l.Warnf("Bad certificate from %s (%v): %v", remoteID, conn.RemoteAddr(), err)
				conn.Close()
				continue nextConn
			}

			// If rate limiting is set, we wrap the connection in a
			// limiter.
			var wr io.Writer = conn
			if writeRateLimit != nil {
				wr = &limitedWriter{conn, writeRateLimit}
			}

			var rd io.Reader = conn
			if readRateLimit != nil {
				rd = &limitedReader{conn, readRateLimit}
			}

			connName := fmt.Sprintf("%s-%s", conn.LocalAddr(), conn.RemoteAddr())
			protoConn := protocol.NewConnection(remoteID, rd, wr, m, connName, deviceCfg.Compression)

			l.Infof("Established secure connection to %s at %s", remoteID, connName)
			if debugNet {
				l.Debugf("cipher suite %04X", conn.ConnectionState().CipherSuite)
			}
			events.Default.Log(events.DeviceConnected, map[string]string{
				"id":   remoteID.String(),
				"addr": conn.RemoteAddr().String(),
			})

			m.AddConnection(conn, protoConn)
			continue nextConn
		}

		// No configured device matched the peer: log why and drop it.
		if cfg.IgnoredDevice(remoteID) {
			l.Infof("Connection from %s with ignored device ID %s", conn.RemoteAddr(), remoteID)
		} else {
			events.Default.Log(events.DeviceRejected, map[string]string{
				"device":  remoteID.String(),
				"address": conn.RemoteAddr().String(),
			})
			l.Infof("Connection from %s with unknown device ID %s", conn.RemoteAddr(), remoteID)
		}

		conn.Close()
	}
}
예제 #11
0
파일: main.go 프로젝트: qbit/syncthing
// syncthingMain runs the actual syncthing process. It loads (or creates) the
// certificate and configuration, opens the index database, constructs the
// model, starts the GUI, the connection routines and the per-folder
// routines, and then blocks until an exit code is received on the stop
// channel, at which point it terminates the process with that code.
func syncthingMain() {
	var err error

	// Only apply our own tuning when the user has not set the corresponding
	// environment variable.
	if len(os.Getenv("GOGC")) == 0 {
		debug.SetGCPercent(25)
	}

	if len(os.Getenv("GOMAXPROCS")) == 0 {
		runtime.GOMAXPROCS(runtime.NumCPU())
	}

	events.Default.Log(events.Starting, map[string]string{"home": confDir})

	// Ensure that we have a certificate and key. If loading fails, generate
	// a new pair and try once more.
	cert, err = loadCert(confDir, "")
	if err != nil {
		newCertificate(confDir, "", tlsDefaultCommonName)
		cert, err = loadCert(confDir, "")
		if err != nil {
			l.Fatalln("load cert:", err)
		}
	}

	// We reinitialize the predictable RNG with our device ID, to get a
	// sequence that is always the same but unique to this syncthing instance.
	predictableRandom.Seed(seedFromBytes(cert.Certificate[0]))

	myID = protocol.NewDeviceID(cert.Certificate[0])
	l.SetPrefix(fmt.Sprintf("[%s] ", myID.String()[:5]))

	l.Infoln(LongVersion)
	l.Infoln("My ID:", myID)

	// Prepare to be able to save configuration

	cfgFile := filepath.Join(confDir, "config.xml")

	var myName string

	// Load the configuration file, if it exists.
	// If it does not, create a template.

	if info, err := os.Stat(cfgFile); err == nil {
		if !info.Mode().IsRegular() {
			l.Fatalln("Config file is not a file?")
		}
		cfg, err = config.Load(cfgFile, myID)
		if err == nil {
			// Fall back to the hostname when no name is configured for
			// this device.
			myCfg := cfg.Devices()[myID]
			if myCfg.Name == "" {
				myName, _ = os.Hostname()
			} else {
				myName = myCfg.Name
			}
		} else {
			l.Fatalln("Configuration:", err)
		}
	} else {
		l.Infoln("No config file; starting with empty defaults")
		myName, _ = os.Hostname()
		newCfg := defaultConfig(myName)
		cfg = config.Wrap(cfgFile, newCfg)
		cfg.Save()
		l.Infof("Edit %s to taste or use the GUI\n", cfgFile)
	}

	if cfg.Raw().OriginalVersion != config.CurrentVersion {
		l.Infoln("Archiving a copy of old config file format")
		// Archive a copy
		osutil.Rename(cfgFile, cfgFile+fmt.Sprintf(".v%d", cfg.Raw().OriginalVersion))
		// Save the new version
		cfg.Save()
	}

	if len(profiler) > 0 {
		go func() {
			l.Debugln("Starting profiler on", profiler)
			runtime.SetBlockProfileRate(1)
			err := http.ListenAndServe(profiler, nil)
			if err != nil {
				l.Fatalln(err)
			}
		}()
	}

	// The TLS configuration is used for both the listening socket and outgoing
	// connections.

	tlsCfg := &tls.Config{
		Certificates:           []tls.Certificate{cert},
		NextProtos:             []string{"bep/1.0"},
		ClientAuth:             tls.RequestClientCert,
		SessionTicketsDisabled: true,
		InsecureSkipVerify:     true,
		MinVersion:             tls.VersionTLS12,
		CipherSuites: []uint16{
			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
			tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
			tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
			tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
			tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
			tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
		},
	}

	// If the read or write rate should be limited, set up a rate limiter for it.
	// This will be used on connections created in the connect and listen routines.

	opts := cfg.Options()

	if !opts.SymlinksEnabled {
		symlinks.Supported = false
	}

	if opts.MaxSendKbps > 0 {
		writeRateLimit = ratelimit.NewBucketWithRate(float64(1000*opts.MaxSendKbps), int64(5*1000*opts.MaxSendKbps))
	}
	if opts.MaxRecvKbps > 0 {
		readRateLimit = ratelimit.NewBucketWithRate(float64(1000*opts.MaxRecvKbps), int64(5*1000*opts.MaxRecvKbps))
	}

	ldb, err := leveldb.OpenFile(filepath.Join(confDir, "index"), &opt.Options{OpenFilesCacheCapacity: 100})
	if err != nil {
		l.Fatalln("Cannot open database:", err, "- Is another copy of Syncthing already running?")
	}

	// Remove database entries for folders that no longer exist in the config
	folders := cfg.Folders()
	for _, folder := range db.ListFolders(ldb) {
		if _, ok := folders[folder]; !ok {
			l.Infof("Cleaning data for dropped folder %q", folder)
			db.DropFolder(ldb, folder)
		}
	}

	m := model.NewModel(cfg, myName, "syncthing", Version, ldb)

	sanityCheckFolders(cfg, m)

	// GUI

	setupGUI(cfg, m)

	// Clear out old indexes for other devices. Otherwise we'll start up and
	// start needing a bunch of files which are nowhere to be found. This
	// needs to be changed when we correctly do persistent indexes.
	for _, folderCfg := range cfg.Folders() {
		if folderCfg.Invalid != "" {
			continue
		}
		for _, device := range folderCfg.DeviceIDs() {
			if device == myID {
				continue
			}
			m.Index(device, folderCfg.ID, nil)
		}
	}

	// The default port we announce, possibly modified by setupUPnP next.

	addr, err := net.ResolveTCPAddr("tcp", opts.ListenAddress[0])
	if err != nil {
		l.Fatalln("Bad listen address:", err)
	}
	externalPort = addr.Port

	// UPnP
	igd = nil

	if opts.UPnPEnabled {
		setupUPnP()
	}

	// Routine to connect out to configured devices
	discoverer = discovery(externalPort)
	go listenConnect(myID, m, tlsCfg)

	for _, folder := range cfg.Folders() {
		if folder.Invalid != "" {
			continue
		}

		// Routine to pull blocks from other devices to synchronize the local
		// folder. Does not run when we are in read only (publish only) mode.
		if folder.ReadOnly {
			l.Okf("Ready to synchronize %s (read only; no external updates accepted)", folder.ID)
			m.StartFolderRO(folder.ID)
		} else {
			l.Okf("Ready to synchronize %s (read-write)", folder.ID)
			m.StartFolderRW(folder.ID)
		}
	}

	if cpuProfile {
		f, err := os.Create(fmt.Sprintf("cpu-%d.pprof", os.Getpid()))
		if err != nil {
			log.Fatal(err)
		}
		// The profile is stopped explicitly right before os.Exit at the end
		// of this function; a deferred stop would never run.
		pprof.StartCPUProfile(f)
	}

	for _, device := range cfg.Devices() {
		if len(device.Name) > 0 {
			l.Infof("Device %s is %q at %v", device.DeviceID, device.Name, device.Addresses)
		}
	}

	// An acceptance given for an older usage report version is revoked.
	if opts.URAccepted > 0 && opts.URAccepted < usageReportVersion {
		l.Infoln("Anonymous usage report has changed; revoking acceptance")
		opts.URAccepted = 0
		opts.URUniqueID = ""
		cfg.SetOptions(opts)
	}
	if opts.URAccepted >= usageReportVersion {
		if opts.URUniqueID == "" {
			// Previously the ID was generated from the node ID. We now need
			// to generate a new one.
			opts.URUniqueID = randomString(8)
			cfg.SetOptions(opts)
			cfg.Save()
		}
		go usageReportingLoop(m)
		go func() {
			time.Sleep(10 * time.Minute)
			err := sendUsageReport(m)
			if err != nil {
				l.Infoln("Usage report:", err)
			}
		}()
	}

	if opts.RestartOnWakeup {
		go standbyMonitor()
	}

	if opts.AutoUpgradeIntervalH > 0 {
		if noUpgrade {
			l.Infof("No automatic upgrades; STNOUPGRADE environment variable defined.")
		} else if IsRelease {
			go autoUpgrade()
		} else {
			l.Infof("No automatic upgrades; %s is not a release version.", Version)
		}
	}

	events.Default.Log(events.StartupComplete, nil)
	go generateEvents()

	// Block until something asks us to exit.
	code := <-stop

	// os.Exit does not run deferred functions, so the CPU profile has to be
	// stopped here for it to be written out completely.
	if cpuProfile {
		pprof.StopCPUProfile()
	}

	l.Okln("Exiting")
	os.Exit(code)
}
예제 #12
0
파일: main.go 프로젝트: qbit/syncthing
// main parses the command line and dispatches to the requested mode of
// operation: printing the version, generating a key and config (-generate),
// upgrading (-upgrade, -upgrade-check, -upgrade-to), resetting folders
// (-reset), or running syncthing itself, either directly (-no-restart) or
// under the monitor process.
func main() {
	defConfDir, err := getDefaultConfDir()
	if err != nil {
		l.Fatalln("home:", err)
	}

	if runtime.GOOS == "windows" {
		// On Windows, we use a log file by default. Setting the -logfile flag
		// to the empty string disables this behavior.

		logFile = filepath.Join(defConfDir, "syncthing.log")
		flag.StringVar(&logFile, "logfile", logFile, "Log file name (blank for stdout)")

		// We also add an option to hide the console window
		flag.BoolVar(&noConsole, "no-console", false, "Hide console window")
	}

	flag.StringVar(&generateDir, "generate", "", "Generate key and config in specified dir, then exit")
	flag.StringVar(&guiAddress, "gui-address", guiAddress, "Override GUI address")
	flag.StringVar(&guiAuthentication, "gui-authentication", guiAuthentication, "Override GUI authentication; username:password")
	flag.StringVar(&guiAPIKey, "gui-apikey", guiAPIKey, "Override GUI API key")
	flag.StringVar(&confDir, "home", "", "Set configuration directory")
	flag.IntVar(&logFlags, "logflags", logFlags, "Select information in log line prefix")
	flag.BoolVar(&noBrowser, "no-browser", false, "Do not start browser")
	flag.BoolVar(&noRestart, "no-restart", noRestart, "Do not restart; just exit")
	flag.BoolVar(&reset, "reset", false, "Prepare to resync from cluster")
	flag.BoolVar(&doUpgrade, "upgrade", false, "Perform upgrade")
	flag.BoolVar(&doUpgradeCheck, "upgrade-check", false, "Check for available upgrade")
	flag.BoolVar(&showVersion, "version", false, "Show version")
	flag.StringVar(&upgradeTo, "upgrade-to", upgradeTo, "Force upgrade directly from specified URL")

	flag.Usage = usageFor(flag.CommandLine, usage, fmt.Sprintf(extraUsage, defConfDir))
	flag.Parse()

	if noConsole {
		osutil.HideConsole()
	}

	if confDir == "" {
		// Not set as default above because the string can be really long.
		confDir = defConfDir
	}

	if confDir != defConfDir && filepath.Dir(logFile) == defConfDir {
		// The user changed the config dir with -home, but not the log file
		// location. In this case we assume they meant for the logfile to
		// still live in its default location *relative to the config dir*.
		logFile = filepath.Join(confDir, "syncthing.log")
	}

	if showVersion {
		fmt.Println(LongVersion)
		return
	}

	l.SetFlags(logFlags)

	if generateDir != "" {
		dir, err := osutil.ExpandTilde(generateDir)
		if err != nil {
			l.Fatalln("generate:", err)
		}

		// Create the target directory if it is missing; refuse to continue
		// if the path exists but is not a directory.
		info, err := os.Stat(dir)
		if err == nil && !info.IsDir() {
			l.Fatalln(dir, "is not a directory")
		}
		if err != nil && os.IsNotExist(err) {
			err = os.MkdirAll(dir, 0700)
			if err != nil {
				l.Fatalln("generate:", err)
			}
		}

		cert, err := loadCert(dir, "")
		if err == nil {
			l.Warnln("Key exists; will not overwrite.")
			l.Infoln("Device ID:", protocol.NewDeviceID(cert.Certificate[0]))
		} else {
			newCertificate(dir, "", tlsDefaultCommonName)
			cert, err = loadCert(dir, "")
			// The error must be checked before the certificate is touched;
			// indexing cert.Certificate on a failed load could panic.
			if err != nil {
				l.Fatalln("load cert:", err)
			}
			myID = protocol.NewDeviceID(cert.Certificate[0])
			l.Infoln("Device ID:", myID)
		}

		cfgFile := filepath.Join(dir, "config.xml")
		if _, err := os.Stat(cfgFile); err == nil {
			l.Warnln("Config exists; will not overwrite.")
			return
		}
		var myName, _ = os.Hostname()
		var newCfg = defaultConfig(myName)
		var cfg = config.Wrap(cfgFile, newCfg)
		err = cfg.Save()
		if err != nil {
			l.Warnln("Failed to save config", err)
		}

		return
	}

	// Plain assignment on purpose: using := here would declare a new local
	// confDir shadowing the package-level variable, and the expanded path
	// would never be seen by the code that runs after main dispatches.
	confDir, err = osutil.ExpandTilde(confDir)
	if err != nil {
		l.Fatalln("home:", err)
	}

	if info, err := os.Stat(confDir); err == nil && !info.IsDir() {
		l.Fatalln("Config directory", confDir, "is not a directory")
	}

	// Ensure that our home directory exists.
	ensureDir(confDir, 0700)

	if upgradeTo != "" {
		err := upgrade.ToURL(upgradeTo)
		if err != nil {
			l.Fatalln("Upgrade:", err) // exits 1
		}
		l.Okln("Upgraded from", upgradeTo)
		return
	}

	if doUpgrade || doUpgradeCheck {
		rel, err := upgrade.LatestRelease(IsBeta)
		if err != nil {
			l.Fatalln("Upgrade:", err) // exits 1
		}

		if upgrade.CompareVersions(rel.Tag, Version) <= 0 {
			l.Infof("No upgrade available (current %q >= latest %q).", Version, rel.Tag)
			os.Exit(exitNoUpgradeAvailable)
		}

		l.Infof("Upgrade available (current %q < latest %q)", Version, rel.Tag)

		if doUpgrade {
			// Use leveldb database locks to protect against concurrent upgrades
			_, err = leveldb.OpenFile(filepath.Join(confDir, "index"), &opt.Options{OpenFilesCacheCapacity: 100})
			if err != nil {
				l.Fatalln("Cannot upgrade, database seems to be locked. Is another copy of Syncthing already running?")
			}

			err = upgrade.To(rel)
			if err != nil {
				l.Fatalln("Upgrade:", err) // exits 1
			}
			l.Okf("Upgraded to %q", rel.Tag)
		}

		return
	}

	if reset {
		resetFolders()
		return
	}

	if noRestart {
		syncthingMain()
	} else {
		monitorMain()
	}
}
예제 #13
0
// protocolConnectionHandler serves one protocol connection. It wraps tcpConn
// in a TLS server connection using config, performs the handshake, derives
// the peer's device ID from its single certificate, and then processes
// messages until the reader goroutine reports an error.
//
// Error handling follows one pattern throughout: the connection is closed,
// which makes messageReader report an error on the errors channel, and that
// case is the only place where the handler cleans up and returns.
func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) {
	conn := tls.Server(tcpConn, config)
	err := conn.Handshake()
	if err != nil {
		if debug {
			log.Println("Protocol connection TLS handshake:", conn.RemoteAddr(), err)
		}
		conn.Close()
		return
	}

	// A protocol negotiation mismatch is only logged (in debug mode); the
	// connection is not rejected because of it.
	state := conn.ConnectionState()
	if (!state.NegotiatedProtocolIsMutual || state.NegotiatedProtocol != protocol.ProtocolName) && debug {
		log.Println("Protocol negotiation error")
	}

	certs := state.PeerCertificates
	if len(certs) != 1 {
		if debug {
			log.Println("Certificate list error")
		}
		conn.Close()
		return
	}

	id := syncthingprotocol.NewDeviceID(certs[0].Raw)

	messages := make(chan interface{})
	errors := make(chan error, 1)
	outbox := make(chan interface{})

	// Read messages from the connection and send them on the messages
	// channel. When there is an error, send it on the error channel and
	// return. Applies also when the connection gets closed, so the pattern
	// below is to close the connection on error, then wait for the error
	// signal from messageReader to exit.
	go messageReader(conn, messages, errors)

	// Stop the ticker and timer when the handler returns so they do not
	// linger after the connection is gone.
	pingTicker := time.NewTicker(pingInterval)
	defer pingTicker.Stop()
	timeoutTicker := time.NewTimer(networkTimeout)
	defer timeoutTicker.Stop()
	joined := false

	for {
		select {
		case message := <-messages:
			// Any incoming message counts as activity.
			timeoutTicker.Reset(networkTimeout)
			if debug {
				log.Printf("Message %T from %s", message, id)
			}

			switch msg := message.(type) {
			case protocol.JoinRelayRequest:
				// Check for an existing peer and register our outbox under
				// a single write lock, so that two connections presenting
				// the same ID cannot both pass the check.
				outboxesMut.Lock()
				_, ok := outboxes[id]
				if !ok {
					outboxes[id] = outbox
				}
				outboxesMut.Unlock()
				if ok {
					protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected)
					if debug {
						log.Println("Already have a peer with the same ID", id, conn.RemoteAddr())
					}
					conn.Close()
					continue
				}

				joined = true

				protocol.WriteMessage(conn, protocol.ResponseSuccess)

			case protocol.ConnectRequest:
				requestedPeer := syncthingprotocol.DeviceIDFromBytes(msg.ID)
				outboxesMut.RLock()
				peerOutbox, ok := outboxes[requestedPeer]
				outboxesMut.RUnlock()
				if !ok {
					if debug {
						log.Println(id, "is looking for", requestedPeer, "which does not exist")
					}
					protocol.WriteMessage(conn, protocol.ResponseNotFound)
					conn.Close()
					continue
				}

				ses := newSession(sessionLimiter, globalLimiter)

				go ses.Serve()

				clientInvitation := ses.GetClientInvitationMessage(requestedPeer)
				serverInvitation := ses.GetServerInvitationMessage(id)

				if err := protocol.WriteMessage(conn, clientInvitation); err != nil {
					if debug {
						log.Printf("Error sending invitation from %s to client: %s", id, err)
					}
					conn.Close()
					continue
				}

				// Hand the other half of the invitation to the requested
				// peer's handler, which writes it to that peer.
				peerOutbox <- serverInvitation

				if debug {
					log.Println("Sent invitation from", id, "to", requestedPeer)
				}
				conn.Close()

			case protocol.Pong:
				// Nothing

			default:
				if debug {
					log.Printf("Unknown message %s: %T", id, message)
				}
				protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage)
				conn.Close()
			}

		case err := <-errors:
			if debug {
				log.Printf("Closing connection %s: %s", id, err)
			}
			close(outbox)

			// Potentially closing a second time.
			conn.Close()

			// Only delete the outbox if the client is joined, as it might be
			// a lookup request coming from the same client.
			if joined {
				outboxesMut.Lock()
				delete(outboxes, id)
				outboxesMut.Unlock()
			}
			return

		case <-pingTicker.C:
			// A client that has not joined by the first ping tick is
			// disconnected.
			if !joined {
				if debug {
					log.Println(id, "didn't join within", pingInterval)
				}
				conn.Close()
				continue
			}

			if err := protocol.WriteMessage(conn, protocol.Ping{}); err != nil {
				if debug {
					log.Println(id, err)
				}
				conn.Close()
			}

		case <-timeoutTicker.C:
			// We should receive a error from the reader loop, which will cause
			// us to quit this loop.
			if debug {
				log.Printf("%s timed out", id)
			}
			conn.Close()

		case msg := <-outbox:
			if debug {
				log.Printf("Sending message %T to %s", msg, id)
			}
			if err := protocol.WriteMessage(conn, msg); err != nil {
				if debug {
					log.Println(id, err)
				}
				conn.Close()
			}
		}
	}
}
예제 #14
0
파일: main.go 프로젝트: rumpelsepp/relaysrv
// main is the entry point of the relay test client. It loads the key pair
// from the -keys directory, parses the -relay address and then runs in one
// of three modes: -join (join the relay and serve incoming invitations),
// -connect (ask the relay for a session with the given device ID) or -test
// (generic relay test). With none of these selected it exits with an error.
func main() {
	log.SetOutput(os.Stdout)
	log.SetFlags(log.LstdFlags | log.Lshortfile)

	var (
		peerArg   string
		relayAddr string
		keysDir   string
		doJoin    bool
		doTest    bool
	)

	flag.StringVar(&peerArg, "connect", "", "Device ID to which to connect to")
	flag.BoolVar(&doJoin, "join", false, "Join relay")
	flag.BoolVar(&doTest, "test", false, "Generic relay test")
	flag.StringVar(&relayAddr, "relay", "relay://127.0.0.1:22067", "Relay address")
	flag.StringVar(&keysDir, "keys", ".", "Directory where cert.pem and key.pem is stored")

	flag.Parse()

	cert, err := tls.LoadX509KeyPair(filepath.Join(keysDir, "cert.pem"), filepath.Join(keysDir, "key.pem"))
	if err != nil {
		log.Fatalln("Failed to load X509 key pair:", err)
	}

	log.Println("ID:", syncthingprotocol.NewDeviceID(cert.Certificate[0]))

	uri, err := url.Parse(relayAddr)
	if err != nil {
		log.Fatal(err)
	}

	// Lines read from standard input are delivered on this channel.
	stdin := make(chan string)
	go stdinReader(stdin)

	switch {
	case doJoin:
		log.Println("Creating client")
		relayClient := client.NewProtocolClient(uri, []tls.Certificate{cert}, nil)
		log.Println("Created client")

		go relayClient.Serve()

		pending := make(chan protocol.SessionInvitation)

		// Forward invitations to the session loop below. The send is
		// non-blocking, so an invitation that arrives while the loop is
		// busy with another session is discarded.
		go func() {
			log.Println("Starting invitation receiver")
			for inv := range relayClient.Invitations {
				select {
				case pending <- inv:
					log.Println("Received invitation", inv)
				default:
					log.Println("Discarding invitation", inv)
				}
			}
		}()

		// Serve sessions one at a time, forever.
		for {
			inv := <-pending
			conn, err := client.JoinSession(inv)
			if err != nil {
				log.Fatalln("Failed to join", err)
			}
			log.Println("Joined", conn.RemoteAddr(), conn.LocalAddr())
			connectToStdio(stdin, conn)
			log.Println("Finished", conn.RemoteAddr(), conn.LocalAddr())
		}

	case peerArg != "":
		peerID, err := syncthingprotocol.DeviceIDFromString(peerArg)
		if err != nil {
			log.Fatal(err)
		}

		inv, err := client.GetInvitationFromRelay(uri, peerID, []tls.Certificate{cert})
		if err != nil {
			log.Fatal(err)
		}
		log.Println("Received invitation", inv)

		conn, err := client.JoinSession(inv)
		if err != nil {
			log.Fatalln("Failed to join", err)
		}
		log.Println("Joined", conn.RemoteAddr(), conn.LocalAddr())
		connectToStdio(stdin, conn)
		log.Println("Finished", conn.RemoteAddr(), conn.LocalAddr())

	case doTest:
		verdict := "FAIL"
		if client.TestRelay(uri, []tls.Certificate{cert}) {
			verdict = "OK"
		}
		log.Println(verdict)

	default:
		log.Fatal("Requires either join or connect")
	}
}