示例#1
0
文件: handler.go 项目: skriptble/nine
// HandleElement processes an incoming SASL element and returns the elements to
// send back together with the (possibly updated) stream properties.
//
// An "auth" element selects a mechanism from h.mechs by its "mechanism"
// attribute and passes the element's text to it. An unknown mechanism yields an
// InvalidMechanism failure. If the mechanism reports a challenge, it is stored
// in h.current so that later "response" elements are routed to it.
//
// A "response" element continues the exchange held in h.current. It yields a
// NotAuthorized failure when no exchange is in progress. Once the mechanism
// stops reporting a challenge, h.current is cleared.
//
// Elements with any other tag produce no output and leave props unchanged.
func (h *Handler) HandleElement(el element.Element, props stream.Properties) (
	[]element.Element, stream.Properties) {
	var elems []element.Element
	var challenge bool
	switch el.Tag {
	case "auth":
		mechName := el.SelectAttrValue("mechanism", "")
		mech, ok := h.mechs[mechName]
		if !ok {
			elems = append(elems, element.SASLFailure.InvalidMechanism)
			break
		}
		data := el.Text()
		log.Println("Authenticating")
		elems, props, challenge = mech.Authenticate(data, props)
		// NOTE(review): h.current is only assigned when a challenge is issued,
		// so a previously stored mechanism survives an "auth" that completes
		// immediately. Confirm whether it should be reset here.
		if challenge {
			h.current = mech
		}
	case "response":
		if h.current == nil {
			// A response with no exchange in progress is out of order.
			// Report it and stop: continuing would call Authenticate on a
			// nil mechanism and also overwrite the failure element.
			failure := element.SASLFailure.NotAuthorized.
				AddChild(element.New("text").SetText("Out of order SASL element"))
			elems = append(elems, failure)
			break
		}
		data := el.Text()
		elems, props, challenge = h.current.Authenticate(data, props)
		if !challenge {
			// No further challenge: the exchange is over, so forget the
			// mechanism and reject any stray responses that follow.
			h.current = nil
		}
	}
	return elems, props
}
示例#2
0
// TestWriteElement checks that WriteElement writes exactly the serialized form
// of an element to the transport's underlying connection.
func TestWriteElement(t *testing.T) {
	t.Parallel()

	el := element.New("testing").AddAttr("xmlns", "foo:bar")
	want := el.WriteBytes()
	got := make([]byte, len(want))

	read, write := net.Pipe()
	defer read.Close()
	tcpTsp := NewTCP(write, stream.Receiving, nil, false)

	// net.Pipe is synchronous, so the peer end must be read concurrently.
	// The reader reports its result on readDone. Receiving from readDone
	// below surfaces the error on the test goroutine and ensures the reader
	// has finished filling got before got is compared.
	readDone := make(chan error, 1)
	go func() {
		// ReadFull tolerates WriteElement splitting its output across
		// several writes; a single Read would not.
		_, err := io.ReadFull(read, got)
		readDone <- err
	}()

	if err := tcpTsp.WriteElement(el); err != nil {
		t.Errorf("Unexpected error from WriteElement: %s", err)
	}
	// Close the write end so the reader gets EOF instead of blocking forever
	// if WriteElement failed or wrote fewer bytes than expected. The Close
	// error is irrelevant here; the reader only needs to be unblocked.
	_ = write.Close()

	if err := <-readDone; err != nil {
		t.Errorf("Received error while reading from connection: %s", err)
	}

	if !reflect.DeepEqual(want, got) {
		t.Error("Should be able to write element to TCP stream.")
		t.Errorf("\nWant:%v\nGot :%v", want, got)
	}
}
示例#3
0
文件: handler.go 项目: skriptble/nine
// GenerateFeature advertises the handler's SASL mechanisms as a stream feature.
//
// When props.Status does not have the stream.Auth bit set, it builds the
// element.SASLMechanisms feature with one "mechanism" child per key of
// h.mechs and appends it to props.Features. Otherwise props is returned
// unchanged. Children follow Go map iteration order, which is unspecified.
func (h *Handler) GenerateFeature(props stream.Properties) stream.Properties {
	alreadyAuthed := props.Status&stream.Auth != 0
	if !alreadyAuthed {
		feature := element.SASLMechanisms
		for mechName := range h.mechs {
			child := element.New("mechanism").SetText(mechName)
			feature = feature.AddChild(child)
		}
		props.Features = append(props.Features, feature)
	}
	return props
}
示例#4
0
// TestWriteElementError checks that writing an element through a transport
// whose connection has already been closed returns io.ErrClosedPipe.
func TestWriteElementError(t *testing.T) {
	t.Parallel()

	el := element.New("testing")
	_, conn := net.Pipe()
	tcpTsp := NewTCP(conn, stream.Receiving, nil, false)
	if closeErr := conn.Close(); closeErr != nil {
		t.Errorf("Unexpected error: %s", closeErr)
	}

	want := io.ErrClosedPipe
	if got := tcpTsp.WriteElement(el); got != want {
		t.Error("Should receive error from connection when writing element.")
		t.Errorf("\nWant:%s\nGot :%s", want, got)
	}
}
示例#5
0
// TestNext checks two behaviors of Next:
//   - it decodes a complete element written by the peer;
//   - a stream header is returned on its own, and the following child element
//     is returned by the next call, rather than Next reading the whole stream.
func TestNext(t *testing.T) {
	t.Parallel()

	// Should be able to get a token from the transport
	local, remote := net.Pipe()
	sent := element.New("testing").AddAttr("foo", "bar").
		SetText("random text").
		AddChild(element.New("baz-quux"))
	tcpTsp := NewTCP(local, stream.Receiving, nil, true)

	go func() {
		if _, werr := remote.Write(sent.WriteBytes()); werr != nil {
			t.Errorf("An unexpected error occurred: %s", werr)
		}
	}()

	received, err := tcpTsp.Next()
	if err != nil {
		t.Errorf("An unexpected error occurred: %s", err)
	}
	if !reflect.DeepEqual(sent, received) {
		t.Error("Should be able to get a token from the transport.")
		t.Errorf("\nWant:%+v\nGot :%+v", sent, received)
	}

	// Stream element should return token and not attempt to read the entire stream.
	hdrLocal, hdrRemote := net.Pipe()
	hdrTsp := NewTCP(hdrLocal, stream.Receiving, nil, true)
	go func() {
		if _, werr := hdrRemote.Write(stream.Header{}.WriteBytes()); werr != nil {
			t.Errorf("An unexpected error occurred: %s", werr)
		}
		if _, werr := hdrRemote.Write([]byte("<foo/></stream:stream>")); werr != nil {
			t.Errorf("An unexpected error occurred: %s", werr)
		}
	}()

	header, err := hdrTsp.Next()
	if err != nil {
		t.Errorf("An unexpected error occurred: %s", err)
	}
	isStreamHeader := header.Space == namespace.Stream && header.Tag == "stream"
	if !isStreamHeader {
		t.Error("Stream element should return token and not attempt to read the entire stream.")
	}

	child, err := hdrTsp.Next()
	if err != nil {
		t.Errorf("An unexpected error occurred: %s", err)
	}
	if expected := element.New("foo"); !reflect.DeepEqual(expected, child) {
		t.Error("Stream element should return token and not attempt to read the entire stream.")
		t.Errorf("\nWant:%+v\nGot :%+v", expected, child)
	}
}