func TestCapabilitiesExchange(t *testing.T) { errc := make(chan error, 1) smux := diam.NewServeMux() smux.Handle("CER", handleCER(errc, false)) srv := diamtest.NewServer(smux, nil) defer srv.Close() wait := make(chan struct{}) cmux := diam.NewServeMux() cmux.Handle("CEA", handleCEA(errc, wait)) cli, err := diam.Dial(srv.Addr, cmux, nil) if err != nil { t.Fatal(err) } sendCER(cli) select { case <-wait: case err := <-errc: t.Fatal(err) case err := <-smux.ErrorReports(): t.Fatal(err) case <-time.After(time.Second): t.Fatal("Timed out: no CER or CEA received") } }
// GetBridge returns the Bridge object for a given client, if it exists. // Otherwise GetBridge connects to the upstream server and set up the // bridge with the client, returning the newly created Bridge object. func GetBridge(c diam.Conn) *Bridge { liveMu.RLock() if b, ok := liveBridge[c.RemoteAddr().String()]; ok { liveMu.RUnlock() return b } liveMu.RUnlock() liveMu.Lock() defer liveMu.Unlock() b := &Bridge{ Client: make(chan *diam.Message), Server: make(chan *diam.Message), } liveBridge[c.RemoteAddr().String()] = b // Prepare for the upstream connection. mux := diam.NewServeMux() mux.HandleFunc("ALL", func(c diam.Conn, m *diam.Message) { // Forward incoming messages to the client. b.Client <- m }) // Connect to upstream server. s, err := diam.Dial(upstreamAddr, mux, nil) if err != nil { return nil } log.Printf("Creating bridge from %s to %s", c.RemoteAddr().String(), s.RemoteAddr().String()) go Pump(c, s, b.Client, b.Server) go Pump(s, c, b.Server, b.Client) return b }
func TestHandleDWR_Fail(t *testing.T) { sm := New(serverSettings) srv := diamtest.NewServer(sm, dict.Default) defer srv.Close() mc := make(chan *diam.Message, 1) mux := diam.NewServeMux() mux.HandleFunc("CEA", func(c diam.Conn, m *diam.Message) { mc <- m }) mux.HandleFunc("DWA", func(c diam.Conn, m *diam.Message) { mc <- m }) cli, err := diam.Dial(srv.Addr, mux, dict.Default) if err != nil { t.Fatal(err) } defer cli.Close() // Send CER first. m := diam.NewRequest(diam.CapabilitiesExchange, 1001, dict.Default) m.NewAVP(avp.OriginHost, avp.Mbit, 0, clientSettings.OriginHost) m.NewAVP(avp.OriginRealm, avp.Mbit, 0, clientSettings.OriginRealm) m.NewAVP(avp.HostIPAddress, avp.Mbit, 0, localhostAddress) m.NewAVP(avp.VendorID, avp.Mbit, 0, clientSettings.VendorID) m.NewAVP(avp.ProductName, 0, 0, clientSettings.ProductName) m.NewAVP(avp.OriginStateID, avp.Mbit, 0, datatype.Unsigned32(1)) m.NewAVP(avp.AcctApplicationID, avp.Mbit, 0, datatype.Unsigned32(1001)) m.NewAVP(avp.FirmwareRevision, avp.Mbit, 0, clientSettings.FirmwareRevision) _, err = m.WriteTo(cli) if err != nil { t.Fatal(err) } select { case resp := <-mc: if !testResultCode(resp, diam.Success) { t.Fatalf("Unexpected result code for CEA.\n%s", resp) } case err := <-mux.ErrorReports(): t.Fatal(err) case <-time.After(time.Second): t.Fatal("No CEA received") } // Send broken DWR (missing Origin-Host, etc). m = diam.NewRequest(diam.DeviceWatchdog, 0, dict.Default) _, err = m.WriteTo(cli) if err != nil { t.Fatal(err) } select { case err := <-sm.ErrorReports(): if err.Error != smparser.ErrMissingOriginHost { t.Fatalf("Unexpected error. Want ErrMissingOriginHost, have %#v", err.Error) } case err := <-mux.ErrorReports(): t.Fatal(err) case <-time.After(time.Second): t.Fatal("No DWA received") } }
func TestHandleCER_HandshakeMetadata(t *testing.T) { sm := New(serverSettings) srv := diamtest.NewServer(sm, dict.Default) defer srv.Close() hsc := make(chan diam.Conn, 1) cli, err := diam.Dial(srv.Address, nil, dict.Default) if err != nil { t.Fatal(err) } defer cli.Close() ready := make(chan struct{}) go func() { close(ready) c := <-sm.HandshakeNotify() hsc <- c }() <-ready m := diam.NewRequest(diam.CapabilitiesExchange, 1001, dict.Default) m.NewAVP(avp.OriginHost, avp.Mbit, 0, clientSettings.OriginHost) m.NewAVP(avp.OriginRealm, avp.Mbit, 0, clientSettings.OriginRealm) m.NewAVP(avp.HostIPAddress, avp.Mbit, 0, localhostAddress) m.NewAVP(avp.VendorID, avp.Mbit, 0, clientSettings.VendorID) m.NewAVP(avp.ProductName, 0, 0, clientSettings.ProductName) m.NewAVP(avp.OriginStateID, avp.Mbit, 0, datatype.Unsigned32(1)) m.NewAVP(avp.AcctApplicationID, avp.Mbit, 0, datatype.Unsigned32(1001)) m.NewAVP(avp.FirmwareRevision, avp.Mbit, 0, clientSettings.FirmwareRevision) _, err = m.WriteTo(cli) if err != nil { t.Fatal(err) } select { case c := <-hsc: ctx := c.Context() meta, ok := smpeer.FromContext(ctx) if !ok { t.Fatal("Handshake ok but no context/metadata found") } if meta.OriginHost != clientSettings.OriginHost { t.Fatalf("Unexpected OriginHost. Want %q, have %q", clientSettings.OriginHost, meta.OriginHost) } if meta.OriginRealm != clientSettings.OriginRealm { t.Fatalf("Unexpected OriginRealm. Want %q, have %q", clientSettings.OriginRealm, meta.OriginRealm) } } }
func TestHandleCER_VS_Auth_Fail(t *testing.T) { sm := New(serverSettings) srv := diamtest.NewServer(sm, dict.Default) defer srv.Close() mc := make(chan *diam.Message, 1) mux := diam.NewServeMux() mux.HandleFunc("CEA", func(c diam.Conn, m *diam.Message) { mc <- m }) cli, err := diam.Dial(srv.Address, mux, dict.Default) if err != nil { t.Fatal(err) } defer cli.Close() m := diam.NewRequest(diam.CapabilitiesExchange, 0, dict.Default) m.NewAVP(avp.OriginHost, avp.Mbit, 0, clientSettings.OriginHost) m.NewAVP(avp.OriginRealm, avp.Mbit, 0, clientSettings.OriginRealm) m.NewAVP(avp.HostIPAddress, avp.Mbit, 0, localhostAddress) m.NewAVP(avp.VendorID, avp.Mbit, 0, clientSettings.VendorID) m.NewAVP(avp.ProductName, 0, 0, clientSettings.ProductName) m.NewAVP(avp.OriginStateID, avp.Mbit, 0, datatype.Unsigned32(1)) m.NewAVP(avp.VendorSpecificApplicationID, avp.Mbit, 0, &diam.GroupedAVP{ AVP: []*diam.AVP{ diam.NewAVP(avp.AuthApplicationID, avp.Mbit, 0, datatype.Unsigned32(1000)), }, }) m.NewAVP(avp.FirmwareRevision, avp.Mbit, 0, clientSettings.FirmwareRevision) _, err = m.WriteTo(cli) if err != nil { t.Fatal(err) } select { case resp := <-mc: if !testResultCode(resp, diam.NoCommonApplication) { t.Fatalf("Unexpected result code.\n%s", resp) } case err := <-mux.ErrorReports(): t.Fatal(err) case <-time.After(time.Second): t.Fatal("No message received") } }
// TestStateMachine establishes a connection with a test server and // sends a Re-Auth-Request message to ensure the handshake was // completed and that the RAR handler has context from the peer. func TestStateMachine(t *testing.T) { sm := New(serverSettings) if sm.Settings() != serverSettings { t.Fatal("Invalid settings") } srv := diamtest.NewServer(sm, dict.Default) defer srv.Close() // CER handlers are ignored by the state machine. // Using Handle instead of HandleFunc to exercise that code. sm.Handle("CER", func() diam.HandlerFunc { return func(c diam.Conn, m *diam.Message) {} }()) select { case err := <-sm.ErrorReports(): if err == nil { t.Fatal("Expecting error that didn't occur") } case <-time.After(time.Second): t.Fatal("Timed out waiting for error") } // RAR for our test. mc := make(chan *diam.Message, 1) sm.HandleFunc("RAR", func(c diam.Conn, m *diam.Message) { mc <- m }) mux := diam.NewServeMux() mux.HandleFunc("CEA", func(c diam.Conn, m *diam.Message) { mc <- m }) mux.HandleFunc("DWA", func(c diam.Conn, m *diam.Message) { mc <- m }) cli, err := diam.Dial(srv.Address, mux, dict.Default) if err != nil { t.Fatal(err) } defer cli.Close() // Send CER first, wait for CEA. m := diam.NewRequest(diam.CapabilitiesExchange, 1001, dict.Default) m.NewAVP(avp.OriginHost, avp.Mbit, 0, clientSettings.OriginHost) m.NewAVP(avp.OriginRealm, avp.Mbit, 0, clientSettings.OriginRealm) m.NewAVP(avp.HostIPAddress, avp.Mbit, 0, localhostAddress) m.NewAVP(avp.VendorID, avp.Mbit, 0, clientSettings.VendorID) m.NewAVP(avp.ProductName, 0, 0, clientSettings.ProductName) m.NewAVP(avp.OriginStateID, avp.Mbit, 0, datatype.Unsigned32(1)) m.NewAVP(avp.AcctApplicationID, avp.Mbit, 0, datatype.Unsigned32(1001)) m.NewAVP(avp.FirmwareRevision, avp.Mbit, 0, clientSettings.FirmwareRevision) _, err = m.WriteTo(cli) if err != nil { t.Fatal(err) } // Retransmit CER. _, err = m.WriteTo(cli) if err != nil { t.Fatal(err) } // Test CEA Result-Code. select { case resp := <-mc: if !testResultCode(resp, diam.Success) { t.Fatalf("Unexpected result code.\n%s", resp) } case err := <-sm.ErrorReports(): t.Fatal(err) case err := <-mux.ErrorReports(): t.Fatal(err) case <-time.After(time.Second): t.Fatal("No CEA message received") } // Send RAR. m = diam.NewRequest(diam.ReAuth, 0, dict.Default) m.NewAVP(avp.SessionID, avp.Mbit, 0, datatype.OctetString("foobar")) m.NewAVP(avp.OriginHost, avp.Mbit, 0, clientSettings.OriginHost) m.NewAVP(avp.OriginRealm, avp.Mbit, 0, clientSettings.OriginRealm) m.NewAVP(avp.AuthApplicationID, avp.Mbit, 0, datatype.Unsigned32(1002)) m.NewAVP(avp.ReAuthRequestType, avp.Mbit, 0, datatype.Unsigned32(0)) m.NewAVP(avp.UserName, avp.Mbit, 0, datatype.OctetString("test")) m.NewAVP(avp.OriginStateID, avp.Mbit, 0, datatype.Unsigned32(1)) _, err = m.WriteTo(cli) if err != nil { t.Fatal(err) } // Ensure the RAR was handled by the state machine. select { case <-mc: // All good. case err := <-sm.ErrorReports(): t.Fatal(err) case err := <-mux.ErrorReports(): t.Fatal(err) case <-time.After(time.Second): t.Fatal("No RAR message received") } // Send DWR. m = diam.NewRequest(diam.DeviceWatchdog, 0, dict.Default) m.NewAVP(avp.OriginHost, avp.Mbit, 0, clientSettings.OriginHost) m.NewAVP(avp.OriginRealm, avp.Mbit, 0, clientSettings.OriginRealm) _, err = m.WriteTo(cli) if err != nil { t.Fatal(err) } // Ensure the DWR was handled by the state machine. select { case <-mc: // All good. case err := <-sm.ErrorReports(): t.Fatal(err) case err := <-mux.ErrorReports(): t.Fatal(err) case <-time.After(time.Second): t.Fatal("No DWR message received") } }
// Dial calls the address set as ip:port, performs a handshake and optionally // start a watchdog goroutine in background. func (cli *Client) Dial(addr string) (diam.Conn, error) { return cli.dial(func() (diam.Conn, error) { return diam.Dial(addr, cli.Handler, cli.Dict) }) }