Ejemplo n.º 1
0
// TestCapabilitiesExchange dials a test server, sends a CER through
// sendCER and waits until the client-side CEA handler signals wait.
// Errors delivered on errc (shared by the CER and CEA handlers), errors
// reported by the server mux, or a one-second timeout fail the test.
func TestCapabilitiesExchange(t *testing.T) {
	errc := make(chan error, 1)

	smux := diam.NewServeMux()
	smux.Handle("CER", handleCER(errc, false))

	srv := diamtest.NewServer(smux, nil)
	defer srv.Close()

	wait := make(chan struct{})
	cmux := diam.NewServeMux()
	cmux.Handle("CEA", handleCEA(errc, wait))

	cli, err := diam.Dial(srv.Addr, cmux, nil)
	if err != nil {
		t.Fatal(err)
	}
	// Close the client connection so it does not outlive the test.
	defer cli.Close()

	sendCER(cli)

	select {
	case <-wait:
	case err := <-errc:
		t.Fatal(err)
	case err := <-smux.ErrorReports():
		t.Fatal(err)
	case <-time.After(time.Second):
		t.Fatal("Timed out: no CER or CEA received")
	}
}
Ejemplo n.º 2
0
// GetBridge returns the Bridge object for a given client, if it exists.
// Otherwise GetBridge connects to the upstream server and set up the
// bridge with the client, returning the newly created Bridge object.
// GetBridge returns the Bridge object for a given client, if it exists.
// Otherwise GetBridge connects to the upstream server, sets up the bridge
// with the client and returns the newly created Bridge object.
//
// Bridges are keyed by the client's remote address in liveBridge, which is
// guarded by liveMu. If the upstream server cannot be dialed, the error is
// logged, nil is returned and no bridge is registered, so a later call for
// the same client will try to connect again.
func GetBridge(c diam.Conn) *Bridge {
	key := c.RemoteAddr().String()

	// Fast path: the bridge already exists.
	liveMu.RLock()
	b, ok := liveBridge[key]
	liveMu.RUnlock()
	if ok {
		return b
	}

	liveMu.Lock()
	defer liveMu.Unlock()
	// Another goroutine may have created the bridge between RUnlock and
	// Lock; re-check so the same client never gets two bridges/upstream dials.
	if existing, ok := liveBridge[key]; ok {
		return existing
	}

	b = &Bridge{
		Client: make(chan *diam.Message),
		Server: make(chan *diam.Message),
	}
	// Prepare for the upstream connection.
	mux := diam.NewServeMux()
	mux.HandleFunc("ALL", func(_ diam.Conn, m *diam.Message) {
		// Forward incoming messages to the client.
		b.Client <- m
	})
	// Connect to upstream server.
	s, err := diam.Dial(upstreamAddr, mux, nil)
	if err != nil {
		log.Printf("Cannot connect to upstream server %s: %v", upstreamAddr, err)
		return nil
	}
	// Register only after a successful dial. Registering earlier would leave
	// behind, on failure, a bridge whose unbuffered channels no Pump drains.
	liveBridge[key] = b
	log.Printf("Creating bridge from %s to %s",
		key, s.RemoteAddr().String())
	go Pump(c, s, b.Client, b.Server)
	go Pump(s, c, b.Server, b.Client)
	return b
}
Ejemplo n.º 3
0
// TestHandleDWR_Fail completes a CER/CEA handshake with the state machine
// and then sends a DWR with no AVPs at all. The state machine is expected to
// report smparser.ErrMissingOriginHost through its ErrorReports channel.
func TestHandleDWR_Fail(t *testing.T) {
	sm := New(serverSettings)
	srv := diamtest.NewServer(sm, dict.Default)
	defer srv.Close()
	mc := make(chan *diam.Message, 1)
	mux := diam.NewServeMux()
	mux.HandleFunc("CEA", func(c diam.Conn, m *diam.Message) {
		mc <- m
	})
	mux.HandleFunc("DWA", func(c diam.Conn, m *diam.Message) {
		mc <- m
	})
	cli, err := diam.Dial(srv.Addr, mux, dict.Default)
	if err != nil {
		t.Fatal(err)
	}
	defer cli.Close()
	// Send CER first.
	m := diam.NewRequest(diam.CapabilitiesExchange, 1001, dict.Default)
	m.NewAVP(avp.OriginHost, avp.Mbit, 0, clientSettings.OriginHost)
	m.NewAVP(avp.OriginRealm, avp.Mbit, 0, clientSettings.OriginRealm)
	m.NewAVP(avp.HostIPAddress, avp.Mbit, 0, localhostAddress)
	m.NewAVP(avp.VendorID, avp.Mbit, 0, clientSettings.VendorID)
	m.NewAVP(avp.ProductName, 0, 0, clientSettings.ProductName)
	m.NewAVP(avp.OriginStateID, avp.Mbit, 0, datatype.Unsigned32(1))
	m.NewAVP(avp.AcctApplicationID, avp.Mbit, 0, datatype.Unsigned32(1001))
	m.NewAVP(avp.FirmwareRevision, avp.Mbit, 0, clientSettings.FirmwareRevision)
	_, err = m.WriteTo(cli)
	if err != nil {
		t.Fatal(err)
	}
	select {
	case resp := <-mc:
		if !testResultCode(resp, diam.Success) {
			t.Fatalf("Unexpected result code for CEA.\n%s", resp)
		}
	case err := <-sm.ErrorReports():
		// The handshake itself must not produce server-side errors.
		t.Fatal(err)
	case err := <-mux.ErrorReports():
		t.Fatal(err)
	case <-time.After(time.Second):
		t.Fatal("No CEA received")
	}
	// Send broken DWR (missing Origin-Host, etc).
	m = diam.NewRequest(diam.DeviceWatchdog, 0, dict.Default)
	_, err = m.WriteTo(cli)
	if err != nil {
		t.Fatal(err)
	}
	// The state machine should reject the DWR with an error report.
	select {
	case err := <-sm.ErrorReports():
		if err.Error != smparser.ErrMissingOriginHost {
			t.Fatalf("Unexpected error. Want ErrMissingOriginHost, have %#v", err.Error)
		}
	case err := <-mux.ErrorReports():
		t.Fatal(err)
	case <-time.After(time.Second):
		t.Fatal("Timed out waiting for ErrMissingOriginHost error report")
	}
}
Ejemplo n.º 4
0
// TestHandleCER_HandshakeMetadata sends a valid CER and waits for the state
// machine's HandshakeNotify channel to deliver the connection. It then checks
// that the connection's context carries smpeer metadata whose OriginHost and
// OriginRealm match the values sent in the CER.
func TestHandleCER_HandshakeMetadata(t *testing.T) {
	sm := New(serverSettings)
	srv := diamtest.NewServer(sm, dict.Default)
	defer srv.Close()
	hsc := make(chan diam.Conn, 1)
	cli, err := diam.Dial(srv.Address, nil, dict.Default)
	if err != nil {
		t.Fatal(err)
	}
	defer cli.Close()
	// Start listening for the handshake notification before sending the CER.
	ready := make(chan struct{})
	go func() {
		close(ready)
		c := <-sm.HandshakeNotify()
		hsc <- c
	}()
	<-ready
	m := diam.NewRequest(diam.CapabilitiesExchange, 1001, dict.Default)
	m.NewAVP(avp.OriginHost, avp.Mbit, 0, clientSettings.OriginHost)
	m.NewAVP(avp.OriginRealm, avp.Mbit, 0, clientSettings.OriginRealm)
	m.NewAVP(avp.HostIPAddress, avp.Mbit, 0, localhostAddress)
	m.NewAVP(avp.VendorID, avp.Mbit, 0, clientSettings.VendorID)
	m.NewAVP(avp.ProductName, 0, 0, clientSettings.ProductName)
	m.NewAVP(avp.OriginStateID, avp.Mbit, 0, datatype.Unsigned32(1))
	m.NewAVP(avp.AcctApplicationID, avp.Mbit, 0, datatype.Unsigned32(1001))
	m.NewAVP(avp.FirmwareRevision, avp.Mbit, 0, clientSettings.FirmwareRevision)
	_, err = m.WriteTo(cli)
	if err != nil {
		t.Fatal(err)
	}
	// Bound the wait: without a timeout a missing notification would hang
	// the test until the global `go test` deadline.
	select {
	case c := <-hsc:
		ctx := c.Context()
		meta, ok := smpeer.FromContext(ctx)
		if !ok {
			t.Fatal("Handshake ok but no context/metadata found")
		}
		if meta.OriginHost != clientSettings.OriginHost {
			t.Fatalf("Unexpected OriginHost. Want %q, have %q",
				clientSettings.OriginHost, meta.OriginHost)
		}
		if meta.OriginRealm != clientSettings.OriginRealm {
			t.Fatalf("Unexpected OriginRealm. Want %q, have %q",
				clientSettings.OriginRealm, meta.OriginRealm)
		}
	case err := <-sm.ErrorReports():
		t.Fatal(err)
	case <-time.After(time.Second):
		t.Fatal("No handshake notification received")
	}
}
Ejemplo n.º 5
0
// TestHandleCER_VS_Auth_Fail sends a CER whose only application advertisement
// is a Vendor-Specific-Application-Id grouping Auth-Application-Id 1000, and
// expects the resulting CEA to carry the diam.NoCommonApplication result code.
func TestHandleCER_VS_Auth_Fail(t *testing.T) {
	sm := New(serverSettings)
	srv := diamtest.NewServer(sm, dict.Default)
	defer srv.Close()

	ceaCh := make(chan *diam.Message, 1)
	clientMux := diam.NewServeMux()
	clientMux.HandleFunc("CEA", func(_ diam.Conn, msg *diam.Message) {
		ceaCh <- msg
	})

	conn, err := diam.Dial(srv.Address, clientMux, dict.Default)
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	// Grouped Vendor-Specific-Application-Id payload.
	vsAppID := &diam.GroupedAVP{
		AVP: []*diam.AVP{
			diam.NewAVP(avp.AuthApplicationID, avp.Mbit, 0, datatype.Unsigned32(1000)),
		},
	}

	cer := diam.NewRequest(diam.CapabilitiesExchange, 0, dict.Default)
	cer.NewAVP(avp.OriginHost, avp.Mbit, 0, clientSettings.OriginHost)
	cer.NewAVP(avp.OriginRealm, avp.Mbit, 0, clientSettings.OriginRealm)
	cer.NewAVP(avp.HostIPAddress, avp.Mbit, 0, localhostAddress)
	cer.NewAVP(avp.VendorID, avp.Mbit, 0, clientSettings.VendorID)
	cer.NewAVP(avp.ProductName, 0, 0, clientSettings.ProductName)
	cer.NewAVP(avp.OriginStateID, avp.Mbit, 0, datatype.Unsigned32(1))
	cer.NewAVP(avp.VendorSpecificApplicationID, avp.Mbit, 0, vsAppID)
	cer.NewAVP(avp.FirmwareRevision, avp.Mbit, 0, clientSettings.FirmwareRevision)

	if _, err = cer.WriteTo(conn); err != nil {
		t.Fatal(err)
	}

	select {
	case cea := <-ceaCh:
		if !testResultCode(cea, diam.NoCommonApplication) {
			t.Fatalf("Unexpected result code.\n%s", cea)
		}
	case reportErr := <-clientMux.ErrorReports():
		t.Fatal(reportErr)
	case <-time.After(time.Second):
		t.Fatal("No message received")
	}
}
Ejemplo n.º 6
0
// TestStateMachine establishes a connection with a test server and
// sends a Re-Auth-Request message to ensure the handshake was
// completed and that the RAR handler has context from the peer.
// TestStateMachine establishes a connection with a test server and
// sends a Re-Auth-Request message to ensure the handshake was
// completed and that the RAR handler has context from the peer.
//
// It also checks that registering a CER handler on the state machine
// produces an error report, that retransmitting the CER still yields a
// successful CEA, and that a DWR is answered with a DWA.
func TestStateMachine(t *testing.T) {
	sm := New(serverSettings)
	if sm.Settings() != serverSettings {
		t.Fatal("Invalid settings")
	}
	srv := diamtest.NewServer(sm, dict.Default)
	defer srv.Close()
	// CER handlers are ignored by the state machine.
	// Using Handle instead of HandleFunc to exercise that code.
	sm.Handle("CER", func() diam.HandlerFunc {
		return func(c diam.Conn, m *diam.Message) {}
	}())
	select {
	case err := <-sm.ErrorReports():
		if err == nil {
			t.Fatal("Expecting error that didn't occur")
		}
	case <-time.After(time.Second):
		t.Fatal("Timed out waiting for error")
	}
	// One channel per message type, so each wait below can only be satisfied
	// by the message it is actually checking for. ceaCh has room for two
	// CEAs in case the retransmitted CER is answered as well; that way a
	// handler never blocks on a full channel.
	ceaCh := make(chan *diam.Message, 2)
	rarCh := make(chan *diam.Message, 1)
	dwaCh := make(chan *diam.Message, 1)
	// RAR for our test.
	sm.HandleFunc("RAR", func(c diam.Conn, m *diam.Message) {
		rarCh <- m
	})
	mux := diam.NewServeMux()
	mux.HandleFunc("CEA", func(c diam.Conn, m *diam.Message) {
		ceaCh <- m
	})
	mux.HandleFunc("DWA", func(c diam.Conn, m *diam.Message) {
		dwaCh <- m
	})
	cli, err := diam.Dial(srv.Address, mux, dict.Default)
	if err != nil {
		t.Fatal(err)
	}
	defer cli.Close()
	// Send CER first, wait for CEA.
	m := diam.NewRequest(diam.CapabilitiesExchange, 1001, dict.Default)
	m.NewAVP(avp.OriginHost, avp.Mbit, 0, clientSettings.OriginHost)
	m.NewAVP(avp.OriginRealm, avp.Mbit, 0, clientSettings.OriginRealm)
	m.NewAVP(avp.HostIPAddress, avp.Mbit, 0, localhostAddress)
	m.NewAVP(avp.VendorID, avp.Mbit, 0, clientSettings.VendorID)
	m.NewAVP(avp.ProductName, 0, 0, clientSettings.ProductName)
	m.NewAVP(avp.OriginStateID, avp.Mbit, 0, datatype.Unsigned32(1))
	m.NewAVP(avp.AcctApplicationID, avp.Mbit, 0, datatype.Unsigned32(1001))
	m.NewAVP(avp.FirmwareRevision, avp.Mbit, 0, clientSettings.FirmwareRevision)
	_, err = m.WriteTo(cli)
	if err != nil {
		t.Fatal(err)
	}
	// Retransmit CER.
	_, err = m.WriteTo(cli)
	if err != nil {
		t.Fatal(err)
	}
	// Test CEA Result-Code.
	select {
	case resp := <-ceaCh:
		if !testResultCode(resp, diam.Success) {
			t.Fatalf("Unexpected result code.\n%s", resp)
		}
	case err := <-sm.ErrorReports():
		t.Fatal(err)
	case err := <-mux.ErrorReports():
		t.Fatal(err)
	case <-time.After(time.Second):
		t.Fatal("No CEA message received")
	}
	// Send RAR.
	m = diam.NewRequest(diam.ReAuth, 0, dict.Default)
	m.NewAVP(avp.SessionID, avp.Mbit, 0, datatype.OctetString("foobar"))
	m.NewAVP(avp.OriginHost, avp.Mbit, 0, clientSettings.OriginHost)
	m.NewAVP(avp.OriginRealm, avp.Mbit, 0, clientSettings.OriginRealm)
	m.NewAVP(avp.AuthApplicationID, avp.Mbit, 0, datatype.Unsigned32(1002))
	m.NewAVP(avp.ReAuthRequestType, avp.Mbit, 0, datatype.Unsigned32(0))
	m.NewAVP(avp.UserName, avp.Mbit, 0, datatype.OctetString("test"))
	m.NewAVP(avp.OriginStateID, avp.Mbit, 0, datatype.Unsigned32(1))
	_, err = m.WriteTo(cli)
	if err != nil {
		t.Fatal(err)
	}
	// Ensure the RAR was handled by the state machine.
	select {
	case <-rarCh:
		// All good.
	case err := <-sm.ErrorReports():
		t.Fatal(err)
	case err := <-mux.ErrorReports():
		t.Fatal(err)
	case <-time.After(time.Second):
		t.Fatal("No RAR message received")
	}
	// Send DWR.
	m = diam.NewRequest(diam.DeviceWatchdog, 0, dict.Default)
	m.NewAVP(avp.OriginHost, avp.Mbit, 0, clientSettings.OriginHost)
	m.NewAVP(avp.OriginRealm, avp.Mbit, 0, clientSettings.OriginRealm)
	_, err = m.WriteTo(cli)
	if err != nil {
		t.Fatal(err)
	}
	// Ensure the DWR was handled by the state machine (a DWA came back).
	select {
	case <-dwaCh:
		// All good.
	case err := <-sm.ErrorReports():
		t.Fatal(err)
	case err := <-mux.ErrorReports():
		t.Fatal(err)
	case <-time.After(time.Second):
		t.Fatal("No DWA message received")
	}
}
Ejemplo n.º 7
0
// Dial calls the address set as ip:port, performs a handshake and optionally
// start a watchdog goroutine in background.
// Dial connects to addr (given as ip:port) using the client's Handler and
// Dict. The actual connection attempt is delegated to cli.dial, which
// according to the original documentation performs the handshake and
// optionally starts a background watchdog goroutine.
func (cli *Client) Dial(addr string) (diam.Conn, error) {
	connect := func() (diam.Conn, error) {
		return diam.Dial(addr, cli.Handler, cli.Dict)
	}
	return cli.dial(connect)
}