// Opens a connection to target (e.g. "foo.example.com:20081"), // sends msg followed by \r\n. // If keep_open == false, the connection is closed, otherwise it is // returned together with the corresponding security.Context. // The connection will be secured according to // the config settings. If a certificate is configured, the connection // will use TLS (and the key argument will be ignored). Otherwise, key // will be used to GosaEncrypt() the message before sending it over // a non-TLS connection. // If an error occurs, it is logged and nil is returned even if keep_open. func SendLnTo(target, msg, key string, keep_open bool) (net.Conn, *Context) { conn, err := net.Dial("tcp", target) if err != nil { util.Log(0, "ERROR! Could not connect to %v: %v\n", target, err) return nil, nil } if !keep_open { defer conn.Close() } // enable keep alive to avoid connections hanging forever in case of routing issues etc. err = conn.(*net.TCPConn).SetKeepAlive(true) if err != nil { util.Log(0, "ERROR! SetKeepAlive: %v", err) // This is not fatal => Don't abort send attempt } if config.TLSClientConfig != nil { conn.SetDeadline(time.Now().Add(config.TimeoutTLS)) // don't allow stalling on STARTTLS _, err = util.WriteAll(conn, starttls) if err != nil { util.Log(0, "ERROR! [SECURITY] Could not send STARTTLS to %v: %v\n", target, err) conn.Close() // even if keep_open return nil, nil } var no_deadline time.Time conn.SetDeadline(no_deadline) conn = tls.Client(conn, config.TLSClientConfig) } else { msg = GosaEncrypt(msg, key) } context := ContextFor(conn) if context == nil { conn.Close() // even if keep_open return nil, nil } err = util.SendLn(conn, msg, config.Timeout) if err != nil { util.Log(0, "ERROR! [SECURITY] While sending message to %v: %v\n", target, err) conn.Close() // even if keep_open return nil, nil } if keep_open { return conn, context } return nil, nil }
func write_via_buffer(w io.Writer, buf []byte, s string, n *int64, err *error) bool { i := 0 for i < len(s) { cnt := copy(buf, s[i:]) i += cnt nn, ee := util.WriteAll(w, buf[0:cnt]) *n += int64(nn) if ee != nil { *err = ee return false } } return true }
func writeimpl(stack_ *[]*asn1.CookStackElement, location string, overwrite bool) error { stack := *stack_ if len(stack) < 2 { return fmt.Errorf("%vwrite() called on stack with fewer than 2 elements", location) } data1, ok1 := stack[len(stack)-1].Value.([]byte) data2, ok2 := stack[len(stack)-2].Value.([]byte) file2, ok3 := stack[len(stack)-1].Value.(string) file1, ok4 := stack[len(stack)-2].Value.(string) if !((ok1 && ok4) || (ok2 && ok3) || (ok3 && ok4)) { return fmt.Errorf("%vwrite() called, but top 2 elements of stack are not a byte-array or string and a file name", location) } if ok2 { data1, file1 = data2, file2 } if ok3 && ok4 { data1, file1 = []byte(file1), file2 } flag := os.O_WRONLY | os.O_CREATE if overwrite { flag = flag | os.O_TRUNC } else { flag = flag | os.O_EXCL } f, err := os.OpenFile(file1, flag, 0644) if err == nil { defer f.Close() _, err = util.WriteAll(f, data1) } if err != nil { if overwrite || !os.IsExist(err) { return fmt.Errorf("%vwrite() error: %v", location, err) } } // Result value is the byte array. We need a result because cook() expects one. *stack_ = append(stack[0:len(stack)-2], &asn1.CookStackElement{Value: data1}) return nil }
// Performs a TFTP get for path at host (which may optionally include a port; // if it doesn't, port 69 is used). If timeout != 0 and any individual read // operation (not the whole get()!) takes longer than that time, get() will // return an error. All the read data is written into w. w is NOT closed! // NOTE: path is usually a relative path that does not start with "/". func Get(host, path string, w io.Writer, timeout time.Duration) error { blocksize := 512 buf := make([]byte, blocksize) if strings.Index(host, ":") < 0 { host = host + ":69" } remote_addr, err := net.ResolveUDPAddr("udp4", host) if err != nil { return err } local_addr, err := net.ResolveUDPAddr("udp4", ":0") if err != nil { return err } udp_conn, err := net.ListenUDP("udp4", local_addr) if err != nil { return err } defer udp_conn.Close() local_addr = udp_conn.LocalAddr().(*net.UDPAddr) n, remote_addr, err := writeReadUDP(udp_conn, remote_addr, []byte("\000\001"+path+"\000octet\000"), buf, min_wait_retry, max_wait_retry, timeout) if err != nil { return err } raddr := remote_addr ack := []byte{0, 4, 0, 0} blockid := 1 for { if buf[0] == 0 && buf[1] == 3 { // DATA if buf[2] != byte(blockid>>8) || buf[3] != byte(blockid&0xff) { if buf[2] == byte((blockid-1)>>8) && buf[3] == byte((blockid-1)&0xff) { // DATA is retransmission. Probably because ACK has been lost. => Ignore. We'll resend ACK further below } else { return fmt.Errorf("TFTP packet with incorrect sequence number") } } else { // correct blockid => Next packet _, err = util.WriteAll(w, buf[4:n]) if err != nil { return err } ack[2] = byte(blockid >> 8) ack[3] = byte(blockid & 0xff) if n < blocksize { // was the received packet the last one? _, err = udp_conn.WriteToUDP(ack, remote_addr) // send ACK and ... break //... 
stop } blockid++ } for { n, raddr, err = writeReadUDP(udp_conn, remote_addr, ack, buf, min_wait_retry, max_wait_retry, timeout) if raddr != nil && raddr.Port != remote_addr.Port { continue } // verify sender if err != nil { return err } break } } else { // not a DATA packet return fmt.Errorf("Unexpected TFTP packet. Expected DATA, got %#v...", string(buf[0:n])) } } return nil }
// Util_test runs the unit tests for the package github.com/mbenkmann/golib/util.
//
// Results are reported through check(). The test needs the `hostname`
// command, listens on 127.0.0.1:39390 and takes a while because of deliberate
// sleeps and timeouts (several 5 second waits). The order of the steps
// matters: the listener is shared and each Accept() pairs with the goroutine
// started just before it.
func Util_test() {
	fmt.Printf("\n==== util ===\n\n")

	// util.Resolve(addr, replacement): numeric addresses pass through, IPv6
	// literals get brackets, loopback names/addresses map to 127.0.0.1, and a
	// non-empty 2nd argument replaces the loopback host while keeping the port.
	addr, err := util.Resolve("1.2.3.4", "")
	check(err, nil)
	check(addr, "1.2.3.4")

	addr, err = util.Resolve("1.2.3.4:5", "")
	check(err, nil)
	check(addr, "1.2.3.4:5")

	addr, err = util.Resolve("::1:5", "")
	check(err, nil)
	check(addr, "[::1:5]")

	addr, err = util.Resolve("localhost:65535", "")
	check(err, nil)
	check(addr, "127.0.0.1:65535")

	addr, err = util.Resolve("localhost", "")
	check(err, nil)
	check(addr, "127.0.0.1")

	addr, err = util.Resolve("::1", "")
	check(err, nil)
	check(addr, "127.0.0.1")

	addr, err = util.Resolve("[::1]", "")
	check(err, nil)
	check(addr, "127.0.0.1")

	addr, err = util.Resolve("[::1]:12345", "")
	check(err, nil)
	check(addr, "127.0.0.1:12345")

	addr, err = util.Resolve("localhost:65535", "foo")
	check(err, nil)
	check(addr, "foo:65535")

	addr, err = util.Resolve("localhost", "foo")
	check(err, nil)
	check(addr, "foo")

	addr, err = util.Resolve("::1", "foo")
	check(err, nil)
	check(addr, "foo")

	addr, err = util.Resolve("[::1]", "foo")
	check(err, nil)
	check(addr, "foo")

	addr, err = util.Resolve("[::1]:12345", "foo")
	check(err, nil)
	check(addr, "foo:12345")

	// Unresolvable input: error mentions "no such host", input is returned unchanged.
	addr, err = util.Resolve("", "")
	check(hasWords(err, "no", "such", "host"), "")
	check(addr, "")

	addr, err = util.Resolve(":10", "")
	check(hasWords(err, "no", "such", "host"), "")
	check(addr, ":10")

	check(util.WaitForDNS(3*time.Second), true)

	// Resolving the local hostname must yield one of the addresses reported
	// by `hostname -I`. If none matches, ip stays "" and the check fails.
	h, _ := exec.Command("hostname").CombinedOutput()
	hostname := strings.TrimSpace(string(h))
	ipp, _ := exec.Command("hostname", "-I").CombinedOutput()
	ips := strings.Fields(strings.TrimSpace(string(ipp)))
	addr, err = util.Resolve(hostname+":234", config.IP)
	check(err, nil)
	ip := ""
	for _, ip2 := range ips {
		if addr == ip2+":234" {
			ip = ip2
		}
	}
	check(addr, ip+":234")

	testLogging()

	// util.WriteAll against fake writers defined elsewhere in this test
	// package. The crappy ones must eventually accept all 80 bytes. The
	// stalled/broken ones stop after 16 bytes; WriteAll must then report
	// io.ErrShortWrite or io.ErrClosedPipe respectively.
	buf := make([]byte, 80)
	for i := range buf {
		buf[i] = byte(util_test_rng.Intn(26) + 'a')
	}

	crap1 := &crappyConnection1{}
	n, err := util.WriteAll(crap1, buf)
	check(string(*crap1), string(buf))
	check(n, len(buf))
	check(err, nil)

	crap2 := &crappyConnection2{}
	n, err = util.WriteAll(crap2, buf)
	check(string(*crap2), string(buf))
	check(n, len(buf))
	check(err, nil)

	stalled1 := &stalledConnection1{}
	n, err = util.WriteAll(stalled1, buf)
	check(string(*stalled1), string(buf[0:16]))
	check(n, 16)
	check(err, io.ErrShortWrite)

	stalled2 := &stalledConnection2{}
	n, err = util.WriteAll(stalled2, buf)
	check(string(*stalled2), string(buf[0:16]))
	check(n, 16)
	check(err, io.ErrShortWrite)

	broken := &brokenConnection{}
	n, err = util.WriteAll(broken, buf)
	check(string(*broken), string(buf[0:16]))
	check(n, 16)
	check(err, io.ErrClosedPipe)

	// util.WithPanicHandler must run the function, recover from its panic and
	// log something. From here on all log output goes into buffy.
	panicker := func() {
		foobar = "bar"
		panic("foo")
	}
	var buffy bytes.Buffer
	util.LoggersSuspend()
	util.LoggerAdd(&buffy)
	defer util.LoggersRestore()
	util.WithPanicHandler(panicker)
	time.Sleep(200 * time.Millisecond) // make sure log message is written out
	check(foobar, "bar")
	check(len(buffy.String()) > 10, true)

	listener, err := net.Listen("tcp", "127.0.0.1:39390")
	if err != nil {
		panic(err)
	}

	// util.SendLnTo to a peer that stops reading: it must give up after its
	// 5 second timeout and log an ERROR.
	// NOTE(review): this goroutine keeps running (10s sleep) after the check.
	go func() {
		r, err := listener.Accept()
		if err != nil {
			panic(err)
		}
		buf := make([]byte, 1)
		r.Read(buf)
		time.Sleep(10 * time.Second)
		r.Read(buf)
	}()

	long := make([]byte, 10000000)
	longstr := string(long)
	buffy.Reset()
	t0 := time.Now()
	util.SendLnTo("127.0.0.1:39390", longstr, 5*time.Second)
	duration := time.Since(t0)
	check(duration > 4*time.Second && duration < 6*time.Second, true)
	time.Sleep(200 * time.Millisecond) // make sure log message is written out
	check(strings.Contains(buffy.String(), "ERROR"), true)

	// util.SendLnTo to a peer that reads everything: must finish quickly
	// without logging anything.
	go func() {
		conn, err := listener.Accept()
		if err != nil {
			panic(err)
		}
		ioutil.ReadAll(conn)
	}()

	long = make([]byte, 10000000)
	longstr = string(long)
	buffy.Reset()
	t0 = time.Now()
	util.SendLnTo("127.0.0.1:39390", longstr, 5*time.Second)
	duration = time.Since(t0)
	check(duration < 2*time.Second, true)
	time.Sleep(200 * time.Millisecond) // make sure log message is written out
	check(buffy.String(), "")

	// Test that ReadLn() times out properly
	go func() {
		_, err := net.Dial("tcp", "127.0.0.1:39390")
		if err != nil {
			panic(err)
		}
	}()
	conn, err := listener.Accept()
	if err != nil {
		panic(err)
	}
	t0 = time.Now()
	st, err := util.ReadLn(conn, 5*time.Second)
	duration = time.Since(t0)
	check(duration > 4*time.Second && duration < 6*time.Second, true)
	check(st, "")
	check(hasWords(err, "timeout"), "")

	// Test that ReadLn() returns io.EOF if last line not terminated by \n
	go func() {
		conn, err := net.Dial("tcp", "127.0.0.1:39390")
		if err != nil {
			panic(err)
		}
		conn.Write([]byte("foo\r"))
		conn.Close()
	}()
	conn, err = listener.Accept()
	if err != nil {
		panic(err)
	}
	st, err = util.ReadLn(conn, 5*time.Second)
	check(err, io.EOF)
	check(st, "foo")

	go func() {
		conn, err := net.Dial("tcp", "127.0.0.1:39390")
		if err != nil {
			panic(err)
		}
		conn.Write([]byte("\r\r\n\rfo\ro\nbar\r\nfoxtrott"))
		conn.Close()
	}()
	conn, err = listener.Accept()
	if err != nil {
		panic(err)
	}
	// Test proper trimming of multiple \r
	st, err = util.ReadLn(conn, 0)
	check(err, nil)
	check(st, "")
	// Test that the empty first line has actually been read
	// and that the next ReadLn() reads the 2nd line
	// Also test that negative timeouts work the same as timeout==0
	// Also test that \r is not trimmed at start and within line.
	st, err = util.ReadLn(conn, -1*time.Second)
	check(err, nil)
	check(st, "\rfo\ro")
	// Check 3rd line
	st, err = util.ReadLn(conn, 0)
	check(err, nil)
	check(st, "bar")
	// Check 4th line and io.EOF error
	st, err = util.ReadLn(conn, 0)
	check(err, io.EOF)
	check(st, "foxtrott")

	// Test that delayed reads work with timeout==0
	go func() {
		conn, err := net.Dial("tcp", "127.0.0.1:39390")
		if err != nil {
			panic(err)
		}
		time.Sleep(1 * time.Second)
		_, err = conn.Write([]byte("foo\r\n"))
		if err != nil {
			panic(err)
		}
		time.Sleep(2 * time.Second)
	}()
	conn, err = listener.Accept()
	if err != nil {
		panic(err)
	}
	t0 = time.Now()
	st, err = util.ReadLn(conn, time.Duration(0))
	duration = time.Since(t0)
	check(duration < 2*time.Second, true)
	check(duration > 800*time.Millisecond, true)
	check(err, nil)
	check(st, "foo")

	// util.Counter(13): two goroutines each draw 100 values. Each half must be
	// sorted. Merged, they must be exactly 13..212 without duplicates.
	counter := util.Counter(13)
	var b1 UintArray = make([]uint64, 100)
	var b2 UintArray = make([]uint64, 100)
	done := make(chan bool)
	fill := func(b UintArray) {
		for i := 0; i < 100; i++ {
			b[i] = <-counter
			time.Sleep(1 * time.Millisecond)
		}
		done <- true
	}
	go fill(b1)
	go fill(b2)
	<-done
	<-done
	check(sort.IsSorted(&b1), true)
	check(sort.IsSorted(&b2), true)
	// Merge b1 and b2 into b3; a duplicate value is reported as a failure.
	var b3 UintArray = make([]uint64, 200)
	i := 0
	j := 0
	k := 0
	for i < 100 || j < 100 {
		if i == 100 {
			b3[k] = b2[j]
			j++
			k++
			continue
		}
		if j == 100 {
			b3[k] = b1[i]
			i++
			k++
			continue
		}
		if b1[i] == b2[j] {
			check(b1[i] != b2[j], true)
			break
		}
		if b1[i] < b2[j] {
			b3[k] = b1[i]
			i++
		} else {
			b3[k] = b2[j]
			j++
		}
		k++
	}
	// one_streak stays true only if, at every position i < 100, b1[i] or b2[i]
	// equals 13+i (e.g. one goroutine got 13..112 in a row).
	one_streak := true
	b5 := make([]uint64, 200)
	for i := 0; i < 200; i++ {
		if i < 100 && b1[i] != uint64(13+i) && b2[i] != uint64(13+i) {
			one_streak = false
		}
		b5[i] = uint64(13 + i)
	}
	// NOTE(review): b3 is a UintArray, b5 a []uint64 — this relies on check()
	// treating them as equal; confirm against check()'s implementation.
	check(b3, b5)
	check(one_streak, false) // Check whether goroutines were actually executed concurrently rather than in sequence

	// util.LogFile: expected contents show that writes after Close() reopen the
	// file, and that writes after the file was renamed away go to a new file at
	// fpath (appending to what another writer put there).
	tempdir, err := ioutil.TempDir("", "util-test-")
	if err != nil {
		panic(err)
	}
	defer os.RemoveAll(tempdir)
	fpath := tempdir + "/foo.log"
	logfile := util.LogFile(fpath)
	check(logfile.Close(), nil)
	n, err = util.WriteAll(logfile, []byte("Test"))
	check(err, nil)
	check(n, 4)
	check(logfile.Close(), nil)
	n, err = util.WriteAll(logfile, []byte("12"))
	check(err, nil)
	check(n, 2)
	n, err = util.WriteAll(logfile, []byte("3"))
	check(err, nil)
	check(n, 1)
	check(os.Rename(fpath, fpath+".old"), nil)
	n, err = util.WriteAll(logfile, []byte("Fo"))
	check(err, nil)
	check(n, 2)
	f2, _ := os.OpenFile(fpath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644)
	f2.Write([]byte("o"))
	f2.Close()
	n, err = util.WriteAll(logfile, []byte("bar"))
	check(err, nil)
	check(n, 3)
	check(logfile.Close(), nil)
	data, err := ioutil.ReadFile(fpath)
	check(err, nil)
	if err == nil {
		check(string(data), "Foobar")
	}
	data, err = ioutil.ReadFile(fpath + ".old")
	check(err, nil)
	if err == nil {
		check(string(data), "Test123")
	}

	// util.MakeTimestamp formats the wall-clock fields as YYYYMMDDhhmmss,
	// regardless of the time zone.
	test_time := time.Date(2013, time.January, 20, 14, 7, 21, 0, time.Local)
	check(util.MakeTimestamp(test_time), "20130120140721")
	test_time = time.Date(2013, time.January, 20, 14, 7, 21, 0, time.UTC)
	check(util.MakeTimestamp(test_time), "20130120140721")
	test_time = time.Date(2013, time.January, 20, 14, 7, 21, 0, time.FixedZone("Fooistan", 45678))
	check(util.MakeTimestamp(test_time), "20130120140721")

	// util.ParseTimestamp: invalid input yields time.Unix(0, 0) and logs an
	// ERROR; valid input is interpreted in time.Local; round trips are stable.
	illegal := time.Unix(0, 0)
	buffy.Reset()
	check(util.ParseTimestamp(""), illegal)
	time.Sleep(200 * time.Millisecond) // make sure log message is written out
	check(strings.Contains(buffy.String(), "ERROR"), true)
	buffy.Reset()
	check(util.ParseTimestamp("20139910101010"), illegal)
	time.Sleep(200 * time.Millisecond) // make sure log message is written out
	check(strings.Contains(buffy.String(), "ERROR"), true)
	check(util.ParseTimestamp("20131110121314"), time.Date(2013, time.November, 10, 12, 13, 14, 0, time.Local))
	check(util.MakeTimestamp(util.ParseTimestamp(util.MakeTimestamp(test_time))), util.MakeTimestamp(test_time))
	test_time = test_time.Add(2400 * time.Hour)
	check(util.MakeTimestamp(util.ParseTimestamp(util.MakeTimestamp(test_time))), util.MakeTimestamp(test_time))
	test_time = test_time.Add(2400 * time.Hour)
	check(util.MakeTimestamp(util.ParseTimestamp(util.MakeTimestamp(test_time))), util.MakeTimestamp(test_time))
	test_time = test_time.Add(2400 * time.Hour)
	check(util.MakeTimestamp(util.ParseTimestamp(util.MakeTimestamp(test_time))), util.MakeTimestamp(test_time))
	test_time = test_time.Add(2400 * time.Hour)
	check(util.MakeTimestamp(util.ParseTimestamp(util.MakeTimestamp(test_time))), util.MakeTimestamp(test_time))
	diff := time.Since(util.ParseTimestamp(util.MakeTimestamp(time.Now())))
	if diff < time.Second {
		diff = 0
	}
	check(diff, time.Duration(0))

	// util.WaitUntil: times in the past return immediately; a future time
	// blocks until (roughly) that time.
	t0 = time.Now()
	util.WaitUntil(t0.Add(-10 * time.Second))
	util.WaitUntil(t0.Add(-100 * time.Minute))
	dur := time.Now().Sub(t0)
	if dur < 1*time.Second {
		dur = 0
	}
	// NOTE(review): compares a time.Duration with untyped 0 (an int once passed
	// as interface{}); elsewhere time.Duration(0) is used — confirm check()
	// treats these as equal.
	check(dur, 0)
	t0 = time.Now()
	util.WaitUntil(t0.Add(1200 * time.Millisecond))
	dur = time.Now().Sub(t0)
	if dur >= 1200*time.Millisecond && dur <= 1300*time.Millisecond {
		dur = 1200 * time.Millisecond
	}
	check(dur, 1200*time.Millisecond)
	// Dates far in the past must not make WaitUntil block.
	// NOTE(review): mess is written by the goroutine and read here without
	// synchronization (data race under -race).
	mess := "WaitUntil(Jesus first birthday) takes forever"
	go func() {
		util.WaitUntil(time.Date(1, time.December, 25, 0, 0, 0, 0, time.UTC))
		mess = ""
	}()
	time.Sleep(100 * time.Millisecond)
	check(mess, "")
	// NOTE(review): the label says 1000-11-10 but the date used is 1000-10-11.
	mess = "WaitUntil(1000-11-10 00:00:00) takes forever"
	go func() {
		util.WaitUntil(time.Date(1000, time.October, 11, 0, 0, 0, 0, time.UTC))
		mess = ""
	}()
	time.Sleep(100 * time.Millisecond)
	check(mess, "")

	testBase64()
}
// Handles one or more messages received over conn. Each message is a single // line terminated by \n. The message may be encrypted as by security.GosaEncrypt(). func handle_request(tcpconn *net.TCPConn) { defer tcpconn.Close() defer atomic.AddInt32(&ActiveConnections, -1) // defer util.Log(2, "DEBUG! Connection to %v closed", tcpconn.RemoteAddr()) // util.Log(2, "DEBUG! Connection from %v", tcpconn.RemoteAddr()) var err error err = tcpconn.SetKeepAlive(true) if err != nil { util.Log(0, "ERROR! SetKeepAlive: %v", err) } var buf bytes.Buffer defer buf.Reset() readbuf := make([]byte, 4096) var conn net.Conn conn = tcpconn n := 1 if config.TLSServerConfig != nil { // If TLS is required, we need to see a STARTTLS before the timeout. // If TLS is optional we need to accept idle connections for backwards compatibility if config.TLSRequired { conn.SetDeadline(time.Now().Add(config.TimeoutTLS)) } for i := range starttls { n, err = conn.Read(readbuf[0:1]) if n == 0 { if i != 0 { // Do not log an error for a port scan that just opens a connection and closes it immediately util.Log(0, "ERROR! Read error while looking for STARTTLS from %v: %v", conn.RemoteAddr(), err) } return } buf.Write(readbuf[0:1]) if readbuf[0] == '\r' && starttls[i] == '\n' { // Read the \n that must follow \r (we don't support lone CR line endings) conn.Read(readbuf[0:1]) // ignore error. It will pop up again further down the line. } if readbuf[0] != starttls[i] { if config.TLSRequired { util.Log(0, "ERROR! No STARTTLS from %v, but TLS is required", conn.RemoteAddr()) util.WriteAll(conn, []byte(message.ErrorReply("STARTTLS is required to connect"))) return } break } if readbuf[0] == '\n' { buf.Reset() // purge STARTTLS\n from buffer conn = tls.Server(conn, config.TLSServerConfig) } } } context := security.ContextFor(conn) if context == nil { return } for n != 0 { //util.Log(2, "DEBUG! Receiving from %v", conn.RemoteAddr()) n, err = conn.Read(readbuf) if err != nil && err != io.EOF { util.Log(0, "ERROR! 
Read: %v", err) } if err == io.EOF { util.Log(2, "DEBUG! Connection closed by %v", conn.RemoteAddr()) } if n == 0 && err == nil { util.Log(0, "ERROR! Read 0 bytes but no error reported") } // Find complete lines terminated by '\n' and process them. for start := 0; ; { eol := start for ; eol < n; eol++ { if readbuf[eol] == '\n' { break } } // no \n found, append to buf and continue reading if eol == n { buf.Write(readbuf[start:n]) break } // append to rest of line to buffered contents buf.Write(readbuf[start:eol]) start = eol + 1 buf.TrimSpace() // process the message and get a reply (if applicable) if buf.Len() > 0 { // ignore empty lines request_start := time.Now() reply, disconnect := message.ProcessEncryptedMessage(&buf, context) buf.Reset() request_time := time.Since(request_start) RequestProcessingTimes.Push(request_time) request_time -= RequestProcessingTimes.Next().(time.Duration) atomic.AddInt64(&message.RequestProcessingTime, int64(request_time)) if reply.Len() > 0 { util.Log(2, "DEBUG! Sending %v bytes reply to %v", reply.Len(), conn.RemoteAddr()) var deadline time.Time // zero value means "no deadline" if config.Timeout >= 0 { deadline = time.Now().Add(config.Timeout) } conn.SetWriteDeadline(deadline) _, err := util.WriteAll(conn, reply.Bytes()) if err != nil { util.Log(0, "ERROR! WriteAll: %v", err) } reply.Reset() util.WriteAll(conn, []byte{'\r', '\n'}) } if disconnect { util.Log(1, "INFO! Forcing disconnect of %v because of error", conn.RemoteAddr()) return } if Shutdown { util.Log(1, "INFO! Forcing disconnect of %v because of go-susi shutdown", conn.RemoteAddr()) return } } } } if buf.Len() != 0 { util.Log(0, "ERROR! Incomplete message from %v (i.e. not terminated by \"\\n\") of %v bytes: %v", conn.RemoteAddr(), buf.Len(), buf.String()) } }