func TestCapabilitiesExchangeTLS(t *testing.T) { errc := make(chan error, 1) smux := diam.NewServeMux() smux.Handle("CER", handleCER(errc, true)) srv := diamtest.NewUnstartedServer(smux, nil) tm := 100 * time.Millisecond srv.Config.ReadTimeout = tm srv.Config.WriteTimeout = tm srv.StartTLS() defer srv.Close() wait := make(chan struct{}) cmux := diam.NewServeMux() cmux.Handle("CEA", handleCEA(errc, wait)) cli, err := diam.DialTLS(srv.Addr, "", "", cmux, nil) if err != nil { t.Fatal(err) } sendCER(cli) select { case <-wait: case err := <-errc: t.Fatal(err) case <-time.After(time.Second): t.Fatal("Timed out: no CER or CEA received") } }
func TestCapabilitiesExchange(t *testing.T) { errc := make(chan error, 1) smux := diam.NewServeMux() smux.Handle("CER", handleCER(errc, false)) srv := diamtest.NewServer(smux, nil) defer srv.Close() wait := make(chan struct{}) cmux := diam.NewServeMux() cmux.Handle("CEA", handleCEA(errc, wait)) cli, err := diam.Dial(srv.Addr, cmux, nil) if err != nil { t.Fatal(err) } sendCER(cli) select { case <-wait: case err := <-errc: t.Fatal(err) case err := <-smux.ErrorReports(): t.Fatal(err) case <-time.After(time.Second): t.Fatal("Timed out: no CER or CEA received") } }
func TestClient_Handshake_RetransmitTimeout(t *testing.T) { mux := diam.NewServeMux() var retransmits uint32 mux.HandleFunc("CER", func(c diam.Conn, m *diam.Message) { // Do nothing to force timeout. atomic.AddUint32(&retransmits, 1) }) srv := diamtest.NewServer(mux, dict.Default) defer srv.Close() cli := &Client{ Handler: New(clientSettings), MaxRetransmits: 3, RetransmitInterval: time.Millisecond, AcctApplicationID: []*diam.AVP{ diam.NewAVP(avp.AcctApplicationID, avp.Mbit, 0, datatype.Unsigned32(0)), }, } _, err := cli.Dial(srv.Address) if err == nil { t.Fatal("Unexpected CER worked") } if err != ErrHandshakeTimeout { t.Fatal(err) } if n := atomic.LoadUint32(&retransmits); n != 4 { t.Fatalf("Unexpected # of retransmits. Want 4, have %d", n) } }
func TestClient_Handshake_FailedResultCode(t *testing.T) { mux := diam.NewServeMux() mux.HandleFunc("CER", func(c diam.Conn, m *diam.Message) { cer := new(smparser.CER) if _, err := cer.Parse(m); err != nil { panic(err) } a := m.Answer(diam.NoCommonApplication) a.NewAVP(avp.OriginHost, avp.Mbit, 0, clientSettings.OriginHost) a.NewAVP(avp.OriginRealm, avp.Mbit, 0, clientSettings.OriginRealm) a.AddAVP(cer.OriginStateID) a.AddAVP(cer.AcctApplicationID[0]) // The one we send below. a.WriteTo(c) }) srv := diamtest.NewServer(mux, dict.Default) defer srv.Close() cli := &Client{ Handler: New(clientSettings), AcctApplicationID: []*diam.AVP{ diam.NewAVP(avp.AcctApplicationID, avp.Mbit, 0, datatype.Unsigned32(0)), }, } _, err := cli.Dial(srv.Address) if err == nil { t.Fatal("Unexpected CER worked") } e, ok := err.(*ErrFailedResultCode) if !ok { t.Fatal(err) } if !strings.Contains(e.Error(), "failed Result-Code AVP") { t.Fatal(e.Error()) } }
// GetBridge returns the Bridge object for a given client, if it exists. // Otherwise GetBridge connects to the upstream server and set up the // bridge with the client, returning the newly created Bridge object. func GetBridge(c diam.Conn) *Bridge { liveMu.RLock() if b, ok := liveBridge[c.RemoteAddr().String()]; ok { liveMu.RUnlock() return b } liveMu.RUnlock() liveMu.Lock() defer liveMu.Unlock() b := &Bridge{ Client: make(chan *diam.Message), Server: make(chan *diam.Message), } liveBridge[c.RemoteAddr().String()] = b // Prepare for the upstream connection. mux := diam.NewServeMux() mux.HandleFunc("ALL", func(c diam.Conn, m *diam.Message) { // Forward incoming messages to the client. b.Client <- m }) // Connect to upstream server. s, err := diam.Dial(upstreamAddr, mux, nil) if err != nil { return nil } log.Printf("Creating bridge from %s to %s", c.RemoteAddr().String(), s.RemoteAddr().String()) go Pump(c, s, b.Client, b.Server) go Pump(s, c, b.Server, b.Client) return b }
func TestHandleDWR_Fail(t *testing.T) { sm := New(serverSettings) srv := diamtest.NewServer(sm, dict.Default) defer srv.Close() mc := make(chan *diam.Message, 1) mux := diam.NewServeMux() mux.HandleFunc("CEA", func(c diam.Conn, m *diam.Message) { mc <- m }) mux.HandleFunc("DWA", func(c diam.Conn, m *diam.Message) { mc <- m }) cli, err := diam.Dial(srv.Addr, mux, dict.Default) if err != nil { t.Fatal(err) } defer cli.Close() // Send CER first. m := diam.NewRequest(diam.CapabilitiesExchange, 1001, dict.Default) m.NewAVP(avp.OriginHost, avp.Mbit, 0, clientSettings.OriginHost) m.NewAVP(avp.OriginRealm, avp.Mbit, 0, clientSettings.OriginRealm) m.NewAVP(avp.HostIPAddress, avp.Mbit, 0, localhostAddress) m.NewAVP(avp.VendorID, avp.Mbit, 0, clientSettings.VendorID) m.NewAVP(avp.ProductName, 0, 0, clientSettings.ProductName) m.NewAVP(avp.OriginStateID, avp.Mbit, 0, datatype.Unsigned32(1)) m.NewAVP(avp.AcctApplicationID, avp.Mbit, 0, datatype.Unsigned32(1001)) m.NewAVP(avp.FirmwareRevision, avp.Mbit, 0, clientSettings.FirmwareRevision) _, err = m.WriteTo(cli) if err != nil { t.Fatal(err) } select { case resp := <-mc: if !testResultCode(resp, diam.Success) { t.Fatalf("Unexpected result code for CEA.\n%s", resp) } case err := <-mux.ErrorReports(): t.Fatal(err) case <-time.After(time.Second): t.Fatal("No CEA received") } // Send broken DWR (missing Origin-Host, etc). m = diam.NewRequest(diam.DeviceWatchdog, 0, dict.Default) _, err = m.WriteTo(cli) if err != nil { t.Fatal(err) } select { case err := <-sm.ErrorReports(): if err.Error != smparser.ErrMissingOriginHost { t.Fatalf("Unexpected error. Want ErrMissingOriginHost, have %#v", err.Error) } case err := <-mux.ErrorReports(): t.Fatal(err) case <-time.After(time.Second): t.Fatal("No DWA received") } }
// New creates and initializes a new StateMachine for clients or servers. func New(settings *Settings) *StateMachine { sm := &StateMachine{ cfg: settings, mux: diam.NewServeMux(), hsNotifyc: make(chan diam.Conn), } sm.mux.Handle("CER", handleCER(sm)) sm.mux.Handle("DWR", handshakeOK(handleDWR(sm))) return sm }
func TestHandleCER_VS_Auth_Fail(t *testing.T) { sm := New(serverSettings) srv := diamtest.NewServer(sm, dict.Default) defer srv.Close() mc := make(chan *diam.Message, 1) mux := diam.NewServeMux() mux.HandleFunc("CEA", func(c diam.Conn, m *diam.Message) { mc <- m }) cli, err := diam.Dial(srv.Address, mux, dict.Default) if err != nil { t.Fatal(err) } defer cli.Close() m := diam.NewRequest(diam.CapabilitiesExchange, 0, dict.Default) m.NewAVP(avp.OriginHost, avp.Mbit, 0, clientSettings.OriginHost) m.NewAVP(avp.OriginRealm, avp.Mbit, 0, clientSettings.OriginRealm) m.NewAVP(avp.HostIPAddress, avp.Mbit, 0, localhostAddress) m.NewAVP(avp.VendorID, avp.Mbit, 0, clientSettings.VendorID) m.NewAVP(avp.ProductName, 0, 0, clientSettings.ProductName) m.NewAVP(avp.OriginStateID, avp.Mbit, 0, datatype.Unsigned32(1)) m.NewAVP(avp.VendorSpecificApplicationID, avp.Mbit, 0, &diam.GroupedAVP{ AVP: []*diam.AVP{ diam.NewAVP(avp.AuthApplicationID, avp.Mbit, 0, datatype.Unsigned32(1000)), }, }) m.NewAVP(avp.FirmwareRevision, avp.Mbit, 0, clientSettings.FirmwareRevision) _, err = m.WriteTo(cli) if err != nil { t.Fatal(err) } select { case resp := <-mc: if !testResultCode(resp, diam.NoCommonApplication) { t.Fatalf("Unexpected result code.\n%s", resp) } case err := <-mux.ErrorReports(): t.Fatal(err) case <-time.After(time.Second): t.Fatal("No message received") } }
func TestClient_Handshake_FailParseCEA(t *testing.T) { mux := diam.NewServeMux() mux.HandleFunc("CER", func(c diam.Conn, m *diam.Message) { a := m.Answer(diam.Success) // Missing Origin-Host and other mandatory AVPs. a.WriteTo(c) }) srv := diamtest.NewServer(mux, dict.Default) defer srv.Close() cli := &Client{ Handler: New(clientSettings), AcctApplicationID: []*diam.AVP{ diam.NewAVP(avp.AcctApplicationID, avp.Mbit, 0, datatype.Unsigned32(0)), }, } _, err := cli.Dial(srv.Address) if err != smparser.ErrMissingOriginHost { t.Fatal(err) } }
// TestStateMachine establishes a connection with a test server and // sends a Re-Auth-Request message to ensure the handshake was // completed and that the RAR handler has context from the peer. func TestStateMachine(t *testing.T) { sm := New(serverSettings) if sm.Settings() != serverSettings { t.Fatal("Invalid settings") } srv := diamtest.NewServer(sm, dict.Default) defer srv.Close() // CER handlers are ignored by the state machine. // Using Handle instead of HandleFunc to exercise that code. sm.Handle("CER", func() diam.HandlerFunc { return func(c diam.Conn, m *diam.Message) {} }()) select { case err := <-sm.ErrorReports(): if err == nil { t.Fatal("Expecting error that didn't occur") } case <-time.After(time.Second): t.Fatal("Timed out waiting for error") } // RAR for our test. mc := make(chan *diam.Message, 1) sm.HandleFunc("RAR", func(c diam.Conn, m *diam.Message) { mc <- m }) mux := diam.NewServeMux() mux.HandleFunc("CEA", func(c diam.Conn, m *diam.Message) { mc <- m }) mux.HandleFunc("DWA", func(c diam.Conn, m *diam.Message) { mc <- m }) cli, err := diam.Dial(srv.Address, mux, dict.Default) if err != nil { t.Fatal(err) } defer cli.Close() // Send CER first, wait for CEA. m := diam.NewRequest(diam.CapabilitiesExchange, 1001, dict.Default) m.NewAVP(avp.OriginHost, avp.Mbit, 0, clientSettings.OriginHost) m.NewAVP(avp.OriginRealm, avp.Mbit, 0, clientSettings.OriginRealm) m.NewAVP(avp.HostIPAddress, avp.Mbit, 0, localhostAddress) m.NewAVP(avp.VendorID, avp.Mbit, 0, clientSettings.VendorID) m.NewAVP(avp.ProductName, 0, 0, clientSettings.ProductName) m.NewAVP(avp.OriginStateID, avp.Mbit, 0, datatype.Unsigned32(1)) m.NewAVP(avp.AcctApplicationID, avp.Mbit, 0, datatype.Unsigned32(1001)) m.NewAVP(avp.FirmwareRevision, avp.Mbit, 0, clientSettings.FirmwareRevision) _, err = m.WriteTo(cli) if err != nil { t.Fatal(err) } // Retransmit CER. _, err = m.WriteTo(cli) if err != nil { t.Fatal(err) } // Test CEA Result-Code. select { case resp := <-mc: if !testResultCode(resp, diam.Success) { t.Fatalf("Unexpected result code.\n%s", resp) } case err := <-sm.ErrorReports(): t.Fatal(err) case err := <-mux.ErrorReports(): t.Fatal(err) case <-time.After(time.Second): t.Fatal("No CEA message received") } // Send RAR. m = diam.NewRequest(diam.ReAuth, 0, dict.Default) m.NewAVP(avp.SessionID, avp.Mbit, 0, datatype.OctetString("foobar")) m.NewAVP(avp.OriginHost, avp.Mbit, 0, clientSettings.OriginHost) m.NewAVP(avp.OriginRealm, avp.Mbit, 0, clientSettings.OriginRealm) m.NewAVP(avp.AuthApplicationID, avp.Mbit, 0, datatype.Unsigned32(1002)) m.NewAVP(avp.ReAuthRequestType, avp.Mbit, 0, datatype.Unsigned32(0)) m.NewAVP(avp.UserName, avp.Mbit, 0, datatype.OctetString("test")) m.NewAVP(avp.OriginStateID, avp.Mbit, 0, datatype.Unsigned32(1)) _, err = m.WriteTo(cli) if err != nil { t.Fatal(err) } // Ensure the RAR was handled by the state machine. select { case <-mc: // All good. case err := <-sm.ErrorReports(): t.Fatal(err) case err := <-mux.ErrorReports(): t.Fatal(err) case <-time.After(time.Second): t.Fatal("No RAR message received") } // Send DWR. m = diam.NewRequest(diam.DeviceWatchdog, 0, dict.Default) m.NewAVP(avp.OriginHost, avp.Mbit, 0, clientSettings.OriginHost) m.NewAVP(avp.OriginRealm, avp.Mbit, 0, clientSettings.OriginRealm) _, err = m.WriteTo(cli) if err != nil { t.Fatal(err) } // Ensure the DWR was handled by the state machine. select { case <-mc: // All good. case err := <-sm.ErrorReports(): t.Fatal(err) case err := <-mux.ErrorReports(): t.Fatal(err) case <-time.After(time.Second): t.Fatal("No DWR message received") } }