示例#1
0
文件: zmq4_test.go 项目: pebbe/zmq4
// TestSecurityNull checks the NULL security mechanism together with a ZAP
// handler bound to inproc://zeromq.zap.01.
//
// A REP socket driven by a reactor answers ZAP requests: it accepts the
// domain "TEST" (status 200) and rejects every other domain (status 400).
// A DEALER server/client pair is then bounced three times:
//
//  1. without a ZAP domain on the server: the bounce must succeed;
//  2. with ZAP domain "WRONG": the bounce must fail;
//  3. with ZAP domain "TEST": the bounce must succeed.
//
// All sockets are released through deferred cleanup, so an early t.Fatal
// does not leave any of them open.
func TestSecurityNull(t *testing.T) {

	//  NOTE(review): presumably gives sockets of a previous test time to
	//  settle -- confirm.
	time.Sleep(100 * time.Millisecond)

	var handler, server, client *zmq.Socket

	//  Registered first, so it runs last: by then the reactor has stopped.
	//  handler is only still non-nil here when doQuit never got to close it.
	defer func() {
		if handler != nil {
			handler.SetLinger(0)
			handler.Close()
		}
	}()

	handler, err := zmq.NewSocket(zmq.REP)
	if err != nil {
		t.Fatal("NewSocket:", err)
	}
	err = handler.Bind("inproc://zeromq.zap.01")
	if err != nil {
		t.Fatal("handler.Bind:", err)
	}

	//  doHandler answers one ZAP request. Returning an error stops the reactor.
	doHandler := func(state zmq.State) error {
		msg, err := handler.RecvMessage(0)
		if err != nil {
			return err //  Terminating
		}
		//  Frames used below: version, sequence, domain, address, identity,
		//  mechanism. Guard against short messages instead of panicking.
		if len(msg) < 6 {
			return errors.New("malformed ZAP request")
		}
		version := msg[0]
		sequence := msg[1]
		domain := msg[2]
		// address := msg[3]
		// identity := msg[4]
		mechanism := msg[5]

		if version != "1.0" {
			return errors.New("version != 1.0")
		}
		if mechanism != "NULL" {
			return errors.New("mechanism != NULL")
		}

		if domain == "TEST" {
			_, err = handler.SendMessage(version, sequence, "200", "OK", "anonymous", "")
		} else {
			_, err = handler.SendMessage(version, sequence, "400", "BAD DOMAIN", "", "")
		}
		return err
	}

	//  doQuit closes the handler socket and stops the reactor by returning
	//  an error.
	doQuit := func(i interface{}) error {
		err := handler.Close()
		handler = nil
		if err != nil {
			t.Error("handler.Close:", err)
		}
		return errors.New("Quit")
	}
	quit := make(chan interface{})
	//  done is closed once reactor.Run has returned, for whatever reason.
	done := make(chan struct{})

	reactor := zmq.NewReactor()
	reactor.AddSocket(handler, zmq.POLLIN, doHandler)
	reactor.AddChannel(quit, 0, doQuit)
	go func() {
		defer close(done)
		reactor.Run(100 * time.Millisecond)
	}()
	defer func() {
		//  Ask the reactor to quit and wait for it. If it already stopped
		//  (a handler returned an error) nobody reads from quit any more,
		//  so select on done as well to avoid blocking forever.
		select {
		case quit <- true:
			<-done
		case <-done:
			t.Error("reactor stopped before the end of the test")
		}
		close(quit)
	}()

	//  Registered last, so it runs first: release whatever is still open of
	//  the server/client pair when the test bails out early.
	defer func() {
		for _, s := range []*zmq.Socket{server, client} {
			if s != nil {
				s.SetLinger(0)
				s.Close()
			}
		}
	}()

	//  We bounce between a binding server and a connecting client
	server, err = zmq.NewSocket(zmq.DEALER)
	if err != nil {
		t.Fatal("NewSocket:", err)
	}
	client, err = zmq.NewSocket(zmq.DEALER)
	if err != nil {
		t.Fatal("NewSocket:", err)
	}

	//  We first test client/server with no ZAP domain
	//  Libzmq does not call our ZAP handler, the connect must succeed
	err = server.Bind("tcp://127.0.0.1:9683")
	if err != nil {
		t.Fatal("server.Bind:", err)
	}
	err = client.Connect("tcp://127.0.0.1:9683")
	if err != nil {
		t.Fatal("client.Connect:", err)
	}
	msg, err := bounce(server, client)
	if err != nil {
		t.Error(msg, err)
	}
	//  Errors of Unbind/Disconnect are not checked here.
	server.Unbind("tcp://127.0.0.1:9683")
	client.Disconnect("tcp://127.0.0.1:9683")

	//  Now define a ZAP domain for the server; this enables
	//  authentication. We're using the wrong domain so this test
	//  must fail.
	err = server.SetZapDomain("WRONG")
	if err != nil {
		t.Fatal("server.SetZapDomain:", err)
	}
	err = server.Bind("tcp://127.0.0.1:9687")
	if err != nil {
		t.Fatal("server.Bind:", err)
	}
	err = client.Connect("tcp://127.0.0.1:9687")
	if err != nil {
		t.Fatal("client.Connect:", err)
	}
	//  Receive time-outs keep the expected failure from blocking forever.
	err = client.SetRcvtimeo(time.Second)
	if err != nil {
		t.Fatal("client.SetRcvtimeo:", err)
	}
	err = server.SetRcvtimeo(time.Second)
	if err != nil {
		t.Fatal("server.SetRcvtimeo:", err)
	}
	_, err = bounce(server, client)
	if err == nil {
		t.Error("Expected failure, got success")
	}
	server.Unbind("tcp://127.0.0.1:9687")
	client.Disconnect("tcp://127.0.0.1:9687")

	//  Now use the right domain, the test must pass
	err = server.SetZapDomain("TEST")
	if err != nil {
		t.Fatal("server.SetZapDomain:", err)
	}
	err = server.Bind("tcp://127.0.0.1:9688")
	if err != nil {
		t.Fatal("server.Bind:", err)
	}
	err = client.Connect("tcp://127.0.0.1:9688")
	if err != nil {
		t.Fatal("client.Connect:", err)
	}
	msg, err = bounce(server, client)
	if err != nil {
		t.Error(msg, err)
	}
	server.Unbind("tcp://127.0.0.1:9688")
	client.Disconnect("tcp://127.0.0.1:9688")

	//  Regular shutdown; the variables are set to nil so the deferred
	//  cleanup does not close the sockets a second time.
	err = client.Close()
	client = nil
	if err != nil {
		t.Error("client.Close:", err)
	}
	err = server.Close()
	server = nil
	if err != nil {
		t.Error("server.Close:", err)
	}
}