func ProxyChannels(logger lager.Logger, conn ssh.Conn, channels <-chan ssh.NewChannel) { logger = logger.Session("proxy-channels") logger.Info("started") defer logger.Info("completed") defer conn.Close() for newChannel := range channels { logger.Info("new-channel", lager.Data{ "channelType": newChannel.ChannelType(), "extraData": newChannel.ExtraData(), }) targetChan, targetReqs, err := conn.OpenChannel(newChannel.ChannelType(), newChannel.ExtraData()) if err != nil { logger.Error("failed-to-open-channel", err) if openErr, ok := err.(*ssh.OpenChannelError); ok { newChannel.Reject(openErr.Reason, openErr.Message) } else { newChannel.Reject(ssh.ConnectionFailed, err.Error()) } continue } sourceChan, sourceReqs, err := newChannel.Accept() if err != nil { targetChan.Close() continue } toTargetLogger := logger.Session("to-target") toSourceLogger := logger.Session("to-source") go func() { helpers.Copy(toTargetLogger, nil, targetChan, sourceChan) targetChan.CloseWrite() }() go func() { helpers.Copy(toSourceLogger, nil, sourceChan, targetChan) sourceChan.CloseWrite() }() go ProxyRequests(toTargetLogger, newChannel.ChannelType(), sourceReqs, targetChan) go ProxyRequests(toSourceLogger, newChannel.ChannelType(), targetReqs, sourceChan) } }
func (sess *session) runWithPty(command *exec.Cmd) error { logger := sess.logger.Session("run-with-pty") ptyMaster, ptySlave, err := pty.Open() if err != nil { logger.Error("failed-to-open-pty", err) return err } sess.ptyMaster = ptyMaster defer ptySlave.Close() command.Stdout = ptySlave command.Stdin = ptySlave command.Stderr = ptySlave command.SysProcAttr = &syscall.SysProcAttr{ Setctty: true, Setsid: true, } setTerminalAttributes(logger, ptyMaster, sess.ptyRequest.Modelist) setWindowSize(logger, ptyMaster, sess.ptyRequest.Columns, sess.ptyRequest.Rows) sess.wg.Add(1) go helpers.Copy(logger.Session("to-pty"), nil, ptyMaster, sess.channel) go func() { helpers.Copy(logger.Session("from-pty"), &sess.wg, sess.channel, ptyMaster) sess.channel.CloseWrite() }() err = sess.runner.Start(command) if err == nil { sess.keepaliveStopCh = make(chan struct{}) go sess.keepalive(command, sess.keepaliveStopCh) } return err }
logger = lagertest.NewTestLogger("test") }) Describe("Copy", func() { var reader io.Reader var fakeWriter *fake_io.FakeWriter var wg *sync.WaitGroup BeforeEach(func() { reader = strings.NewReader("message") fakeWriter = &fake_io.FakeWriter{} wg = nil }) JustBeforeEach(func() { helpers.Copy(logger, wg, fakeWriter, reader) }) It("copies from source to target", func() { Expect(fakeWriter.WriteCallCount()).To(Equal(1)) Expect(string(fakeWriter.WriteArgsForCall(0))).To(Equal("message")) }) Context("when a wait group is provided", func() { BeforeEach(func() { wg = &sync.WaitGroup{} wg.Add(1) }) It("calls done before returning", func() { wg.Wait()