func (m *Manager) connect(conn *srpc.Conn) error { defer conn.Flush() clientChannel := make(chan *proto.ServerMessage, 4096) m.rwMutex.Lock() m.clients[clientChannel] = clientChannel m.rwMutex.Unlock() defer func() { m.rwMutex.Lock() delete(m.clients, clientChannel) m.rwMutex.Unlock() }() closeNotifyChannel := make(chan struct{}) // The client must keep the same encoder/decoder pair over the lifetime // of the connection. go m.handleClientRequests(gob.NewDecoder(conn), clientChannel, closeNotifyChannel) encoder := gob.NewEncoder(conn) for { select { case serverMessage := <-clientChannel: if err := encoder.Encode(serverMessage); err != nil { m.logger.Printf("error encoding ServerMessage: %s\n", err) return err } if len(clientChannel) < 1 { if err := conn.Flush(); err != nil { m.logger.Printf("error flushing: %s\n", err) return err } } case <-closeNotifyChannel: return nil } } }
func (t *rpcType) GetFiles(conn *srpc.Conn) error { defer conn.Flush() t.getFilesLock.Lock() defer t.getFilesLock.Unlock() encoder := gob.NewEncoder(conn) numFiles := 0 for ; ; numFiles++ { filename, err := conn.ReadString('\n') if err != nil { return err } filename = filename[:len(filename)-1] if filename == "" { break } filename = path.Join(t.rootDir, filename) if err := processFilename(conn, filename, encoder); err != nil { return err } } plural := "s" if numFiles == 1 { plural = "" } t.logger.Printf("GetFiles(): %d file%s provided\n", numFiles, plural) return nil }
func addObjects(conn *srpc.Conn, adder ObjectAdder, logger *log.Logger) error { defer conn.Flush() decoder := gob.NewDecoder(conn) encoder := gob.NewEncoder(conn) numAdded := 0 numObj := 0 for ; ; numObj++ { var request objectserver.AddObjectRequest var response objectserver.AddObjectResponse if err := decoder.Decode(&request); err != nil { if err == io.EOF || err == io.ErrUnexpectedEOF { break } return err } if request.Length < 1 { break } response.Hash, response.Added, response.Error = adder.AddObject(conn, request.Length, request.ExpectedHash) if response.Added { numAdded++ } if err := encoder.Encode(response); err != nil { return err } if response.Error != nil { logger.Printf("AddObjects(): failed, %d of %d are new objects %s", numAdded, numObj, response.Error.Error()) return nil } } logger.Printf("AddObjects(): %d of %d are new objects", numAdded, numObj) return nil }
func (t *srpcType) ListImages(conn *srpc.Conn) error { for _, name := range t.imageDataBase.ListImages() { if _, err := conn.WriteString(name + "\n"); err != nil { return err } } _, err := conn.WriteString("\n") return err }
func (t *rpcType) ConfigureSubs(conn *srpc.Conn, request dominator.ConfigureSubsRequest, reply *dominator.ConfigureSubsResponse) error { if conn.Username() == "" { t.logger.Printf("ConfigureSubs()\n") } else { t.logger.Printf("ConfigureSubs(): by %s\n", conn.Username()) } return t.herd.ConfigureSubs(sub.Configuration(request)) }
func (t *rpcType) EnableUpdates(conn *srpc.Conn, request dominator.EnableUpdatesRequest, reply *dominator.EnableUpdatesResponse) error { if conn.Username() == "" { t.logger.Printf("EnableUpdates(%s)\n", request.Reason) } else { t.logger.Printf("EnableUpdates(%s): by %s\n", request.Reason, conn.Username()) } return t.herd.EnableUpdates() }
func (t *rpcType) ClearSafetyShutoff(conn *srpc.Conn, request dominator.ClearSafetyShutoffRequest, reply *dominator.ClearSafetyShutoffResponse) error { if conn.Username() == "" { t.logger.Printf("ClearSafetyShutoff(%s)\n", request.Hostname) } else { t.logger.Printf("ClearSafetyShutoff(%s): by %s\n", request.Hostname, conn.Username()) } return t.herd.ClearSafetyShutoff(request.Hostname) }
func (t *srpcType) CheckObjects(conn *srpc.Conn, request objectserver.CheckObjectsRequest, reply *objectserver.CheckObjectsResponse) error { sizes, err := t.objectServer.CheckObjects(request.Hashes) if err != nil { _, err = conn.WriteString(err.Error() + "\n") return err } reply.ObjectSizes = sizes return nil }
func getCloseNotifier(conn *srpc.Conn) <-chan error { closeChannel := make(chan error) go func() { for { buf := make([]byte, 1) if _, err := conn.Read(buf); err != nil { closeChannel <- err return } } }() return closeChannel }
func (t *srpcType) DeleteUnreferencedObjects(conn *srpc.Conn, request imageserver.DeleteUnreferencedObjectsRequest, reply *imageserver.DeleteUnreferencedObjectsResponse) error { username := conn.Username() if username == "" { t.logger.Printf("DeleteUnreferencedObjects(%d%%, %s)\n", request.Percentage, format.FormatBytes(request.Bytes)) } else { t.logger.Printf("DeleteUnreferencedObjects(%d%%, %s) by %s\n", request.Percentage, format.FormatBytes(request.Bytes), username) } return t.imageDataBase.DeleteUnreferencedObjects(request.Percentage, request.Bytes) }
func (t *srpcType) MakeDirectory(conn *srpc.Conn, request imageserver.MakeDirectoryRequest, reply *imageserver.MakeDirectoryResponse) error { username := conn.Username() if err := t.checkMutability(); err != nil { return err } if username == "" { t.logger.Printf("MakeDirectory(%s)\n", request.DirectoryName) } else { t.logger.Printf("MakeDirectory(%s) by %s\n", request.DirectoryName, username) } return t.imageDataBase.MakeDirectory(request.DirectoryName, username) }
func (t *srpcType) ChownDirectory(conn *srpc.Conn, request imageserver.ChangeOwnerRequest, reply *imageserver.ChangeOwnerResponse) error { username := conn.Username() if username == "" { return errors.New("no username: unauthenticated connection") } if request.OwnerGroup != "" { if _, err := user.LookupGroup(request.OwnerGroup); err != nil { return err } } t.logger.Printf("ChownDirectory(%s) to: \"%s\" by %s\n", request.DirectoryName, request.OwnerGroup, username) return t.imageDataBase.ChownDirectory(request.DirectoryName, request.OwnerGroup) }
func (t *srpcType) DeleteImage(conn *srpc.Conn, request imageserver.DeleteImageRequest, reply *imageserver.DeleteImageResponse) error { username := conn.Username() if err := t.checkMutability(); err != nil { return err } if !t.imageDataBase.CheckImage(request.ImageName) { return errors.New("image does not exist") } if username == "" { t.logger.Printf("DeleteImage(%s)\n", request.ImageName) } else { t.logger.Printf("DeleteImage(%s) by %s\n", request.ImageName, username) } return t.imageDataBase.DeleteImage(request.ImageName, &username) }
func (t *rpcType) GetMdbUpdates(conn *srpc.Conn) error { encoder := gob.NewEncoder(conn) updateChannel := make(chan mdbserver.MdbUpdate, 1) t.rwMutex.Lock() t.updateChannels[conn] = updateChannel t.rwMutex.Unlock() mdbUpdate := mdbserver.MdbUpdate{MachinesToAdd: t.currentMdb.Machines} if err := encoder.Encode(mdbUpdate); err != nil { return nil } if err := conn.Flush(); err != nil { return nil } closeChannel := getCloseNotifier(conn) for { var err error select { case mdbUpdate := <-updateChannel: if err = encoder.Encode(mdbUpdate); err != nil { break } if err = conn.Flush(); err != nil { break } case err = <-closeChannel: break } if err != nil { t.rwMutex.Lock() delete(t.updateChannels, conn) t.rwMutex.Unlock() if err != io.EOF { t.logger.Println(err) return err } else { return nil } } } }
func sendClientRequests(conn *srpc.Conn, clientRequestChannel <-chan *proto.ClientRequest, closeNotifyChannel <-chan struct{}, logger *log.Logger) { encoder := gob.NewEncoder(conn) for { select { case clientRequest := <-clientRequestChannel: if err := encoder.Encode(clientRequest); err != nil { logger.Printf("error encoding client request: %s\n", err) return } if len(clientRequestChannel) < 1 { if err := conn.Flush(); err != nil { logger.Printf("error flushing: %s\n", err) return } } case <-closeNotifyChannel: return } } }
func (t *srpcType) AddObjects(conn *srpc.Conn) error { defer runtime.GC() // An opportune time to take out the garbage. defer conn.Flush() decoder := gob.NewDecoder(conn) encoder := gob.NewEncoder(conn) numAdded := 0 numObj := 0 for ; ; numObj++ { var request objectserver.AddObjectRequest var response objectserver.AddObjectResponse if err := decoder.Decode(&request); err != nil { if err == io.EOF || err == io.ErrUnexpectedEOF { break } return err } if request.Length < 1 { break } response.Hash, response.Added, response.Error = t.objectServer.AddObject(conn, request.Length, request.ExpectedHash) if response.Added { numAdded++ } if err := encoder.Encode(response); err != nil { return err } if response.Error != nil { t.logger.Printf("AddObjects(): failed, %d of %d are new objects %s", numAdded, numObj, response.Error.Error()) return nil } } t.logger.Printf("AddObjects(): %d of %d are new objects", numAdded, numObj) return nil }
func (t *rpcType) Cleanup(conn *srpc.Conn) error { defer conn.Flush() var request sub.CleanupRequest var response sub.CleanupResponse decoder := gob.NewDecoder(conn) if err := decoder.Decode(&request); err != nil { _, err = conn.WriteString(err.Error() + "\n") return err } if err := t.cleanup(request, &response); err != nil { _, err = conn.WriteString(err.Error() + "\n") return err } if _, err := conn.WriteString("\n"); err != nil { return err } return gob.NewEncoder(conn).Encode(response) }
func (t *srpcType) DeleteImage(conn *srpc.Conn) error { defer conn.Flush() var request imageserver.DeleteImageRequest var response imageserver.DeleteImageResponse decoder := gob.NewDecoder(conn) if err := decoder.Decode(&request); err != nil { _, err = conn.WriteString(err.Error() + "\n") return err } if err := t.deleteImage(request, &response); err != nil { _, err = conn.WriteString(err.Error() + "\n") return err } if _, err := conn.WriteString("\n"); err != nil { return err } return gob.NewEncoder(conn).Encode(response) }
func sendRequests(conn *srpc.Conn, filenames []string) error { for _, filename := range filenames { if _, err := conn.WriteString(filename + "\n"); err != nil { return err } } if _, err := conn.WriteString("\n"); err != nil { return err } return conn.Flush() }
func (t *rpcType) Poll(conn *srpc.Conn) error { defer conn.Flush() var request sub.PollRequest var response sub.PollResponse decoder := gob.NewDecoder(conn) if err := decoder.Decode(&request); err != nil { _, err = conn.WriteString(err.Error() + "\n") return err } if _, err := conn.WriteString("\n"); err != nil { return err } response.NetworkSpeed = t.networkReaderContext.MaximumSpeed() response.CurrentConfiguration = t.getConfiguration() t.rwLock.RLock() response.FetchInProgress = t.fetchInProgress response.UpdateInProgress = t.updateInProgress if t.lastFetchError != nil { response.LastFetchError = t.lastFetchError.Error() } if !t.updateInProgress { if t.lastUpdateError != nil { response.LastUpdateError = t.lastUpdateError.Error() } response.LastUpdateHadTriggerFailures = t.lastUpdateHadTriggerFailures } response.LastSuccessfulImageName = t.lastSuccessfulImageName t.rwLock.RUnlock() response.StartTime = startTime response.PollTime = time.Now() response.ScanCount = t.fileSystemHistory.ScanCount() response.DurationOfLastScan = t.fileSystemHistory.DurationOfLastScan() response.GenerationCount = t.fileSystemHistory.GenerationCount() fs := t.fileSystemHistory.FileSystem() if fs != nil && !request.ShortPollOnly && request.HaveGeneration != t.fileSystemHistory.GenerationCount() { response.FileSystemFollows = true } encoder := gob.NewEncoder(conn) if err := encoder.Encode(response); err != nil { return err } if response.FileSystemFollows { if err := fs.FileSystem.Encode(conn); err != nil { return err } if err := fs.ObjectCache.Encode(conn); err != nil { return err } } return nil }
func (t *rpcType) Update(conn *srpc.Conn) error { var request sub.UpdateRequest var response sub.UpdateResponse decoder := gob.NewDecoder(conn) if err := decoder.Decode(&request); err != nil { _, err = conn.WriteString(err.Error() + "\n") return err } if err := t.update(request, &response); err != nil { _, err = conn.WriteString(err.Error() + "\n") return err } if _, err := conn.WriteString("\n"); err != nil { return err } return gob.NewEncoder(conn).Encode(response) }
func (t *rpcType) Poll(conn *srpc.Conn) error { defer conn.Flush() t.pollLock.Lock() defer t.pollLock.Unlock() var request sub.PollRequest var response sub.PollResponse decoder := gob.NewDecoder(conn) if err := decoder.Decode(&request); err != nil { _, err = conn.WriteString(err.Error() + "\n") return err } if _, err := conn.WriteString("\n"); err != nil { return err } response.NetworkSpeed = t.networkReaderContext.MaximumSpeed() t.rwLock.RLock() response.FetchInProgress = t.fetchInProgress response.UpdateInProgress = t.updateInProgress response.LastUpdateHadTriggerFailures = t.lastUpdateHadTriggerFailures t.rwLock.RUnlock() response.GenerationCount = t.fileSystemHistory.GenerationCount() fs := t.fileSystemHistory.FileSystem() if fs != nil && request.HaveGeneration != t.fileSystemHistory.GenerationCount() { response.FileSystemFollows = true } encoder := gob.NewEncoder(conn) if err := encoder.Encode(response); err != nil { return err } if response.FileSystemFollows { if err := fs.FileSystem.Encode(conn); err != nil { return err } if err := fs.ObjectCache.Encode(conn); err != nil { return err } } return nil }
func (objSrv *srpcType) GetObjects(conn *srpc.Conn) error { defer conn.Flush() var request objectserver.GetObjectsRequest var response objectserver.GetObjectsResponse if request.Exclusive { exclusive.Lock() defer exclusive.Unlock() } else { exclusive.RLock() defer exclusive.RUnlock() objSrv.getSemaphore <- true defer releaseSemaphore(objSrv.getSemaphore) } decoder := gob.NewDecoder(conn) encoder := gob.NewEncoder(conn) var err error if err = decoder.Decode(&request); err != nil { response.ResponseString = err.Error() return encoder.Encode(response) } response.ObjectSizes, err = objSrv.objectServer.CheckObjects(request.Hashes) if err != nil { response.ResponseString = err.Error() return encoder.Encode(response) } // First a quick check for existence. If any objects missing, fail request. for index, hash := range request.Hashes { if response.ObjectSizes[index] < 1 { response.ResponseString = fmt.Sprintf("unknown object: %x", hash) return encoder.Encode(response) } } objectsReader, err := objSrv.objectServer.GetObjects(request.Hashes) if err != nil { response.ResponseString = err.Error() return encoder.Encode(response) } defer objectsReader.Close() if err := encoder.Encode(response); err != nil { return err } conn.Flush() for _, hash := range request.Hashes { length, reader, err := objectsReader.NextObject() if err != nil { objSrv.logger.Println(err) return err } nCopied, err := io.Copy(conn.Writer, reader) reader.Close() if err != nil { objSrv.logger.Printf("Error copying:\t%s\n", err) return err } if nCopied != int64(length) { txt := fmt.Sprintf("Expected length: %d, got: %d for: %x", length, nCopied, hash) objSrv.logger.Printf(txt) return errors.New(txt) } } objSrv.logger.Printf("GetObjects() sent: %d objects\n", len(request.Hashes)) return nil }
func (t *srpcType) GetImageUpdates(conn *srpc.Conn) error { defer conn.Flush() t.logger.Println("New image replication client connected") t.incrementNumReplicationClients(true) defer t.incrementNumReplicationClients(false) addChannel := t.imageDataBase.RegisterAddNotifier() deleteChannel := t.imageDataBase.RegisterDeleteNotifier() mkdirChannel := t.imageDataBase.RegisterMakeDirectoryNotifier() defer t.imageDataBase.UnregisterAddNotifier(addChannel) defer t.imageDataBase.UnregisterDeleteNotifier(deleteChannel) defer t.imageDataBase.UnregisterMakeDirectoryNotifier(mkdirChannel) encoder := gob.NewEncoder(conn) directories := t.imageDataBase.ListDirectories() image.SortDirectories(directories) for _, directory := range directories { imageUpdate := imageserver.ImageUpdate{ Directory: &directory, Operation: imageserver.OperationMakeDirectory, } if err := encoder.Encode(imageUpdate); err != nil { t.logger.Println(err) return err } } for _, imageName := range t.imageDataBase.ListImages() { imageUpdate := imageserver.ImageUpdate{Name: imageName} if err := encoder.Encode(imageUpdate); err != nil { t.logger.Println(err) return err } } // Signal end of initial image list. if err := encoder.Encode(imageserver.ImageUpdate{}); err != nil { t.logger.Println(err) return err } if err := conn.Flush(); err != nil { t.logger.Println(err) return err } t.logger.Println( "Finished sending initial image list to replication client") closeChannel := getCloseNotifier(conn) for { select { case imageName := <-addChannel: if err := sendUpdate(encoder, imageName, imageserver.OperationAddImage); err != nil { t.logger.Println(err) return err } case imageName := <-deleteChannel: if err := sendUpdate(encoder, imageName, imageserver.OperationDeleteImage); err != nil { t.logger.Println(err) return err } case directory := <-mkdirChannel: if err := sendMakeDirectory(encoder, directory); err != nil { t.logger.Println(err) return err } case err := <-closeChannel: if err == io.EOF { t.logger.Println("Image replication client disconnected") return nil } t.logger.Println(err) return err } if err := conn.Flush(); err != nil { t.logger.Println(err) return err } } }