func (h *Handler) HandleElement(el element.Element, props stream.Properties) ( []element.Element, stream.Properties) { var elems []element.Element var challenge bool switch el.Tag { case "auth": mechName := el.SelectAttrValue("mechanism", "") mech, ok := h.mechs[mechName] if !ok { elems = append(elems, element.SASLFailure.InvalidMechanism) break } data := el.Text() log.Println("Authenticating") elems, props, challenge = mech.Authenticate(data, props) if challenge { h.current = mech } case "response": if h.current == nil { el := element.SASLFailure.NotAuthorized. AddChild(element.New("text").SetText("Out of order SASL element")) elems = append(elems, el) } data := el.Text() elems, props, challenge = h.current.Authenticate(data, props) if !challenge { h.current = nil } } return elems, props }
func TestWriteElement(t *testing.T) { t.Parallel() var want, got []byte el := element.New("testing").AddAttr("xmlns", "foo:bar") want = el.WriteBytes() got = make([]byte, len(want)) read, write := net.Pipe() tcpTsp := NewTCP(write, stream.Receiving, nil, false) go func() { _, err := read.Read(got) if err != nil { t.Errorf("Received error while reading from connection: %s", err) } }() err := tcpTsp.WriteElement(el) if err != nil { t.Errorf("Unexpected error from WriteElement: %s", err) } if !reflect.DeepEqual(want, got) { t.Error("Should be able to write element to TCP stream.") t.Errorf("\nWant:%v\nGot :%v", want, got) } }
func (h *Handler) GenerateFeature(props stream.Properties) stream.Properties { if props.Status&stream.Auth != 0 { return props } mechs := element.SASLMechanisms for name := range h.mechs { mechs = mechs.AddChild(element.New("mechanism").SetText(name)) } props.Features = append(props.Features, mechs) return props }
func TestWriteElementError(t *testing.T) { t.Parallel() var want, got error want = io.ErrClosedPipe el := element.New("testing") _, pipe := net.Pipe() tcpTsp := NewTCP(pipe, stream.Receiving, nil, false) err := pipe.Close() if err != nil { t.Errorf("Unexpected error: %s", err) } got = tcpTsp.WriteElement(el) if got != want { t.Error("Should receive error from connection when writing element.") t.Errorf("\nWant:%s\nGot :%s", want, got) } }
func TestNext(t *testing.T) { t.Parallel() var want, got interface{} var err error var el element.Element pipe1, pipe2 := net.Pipe() el = element.New("testing").AddAttr("foo", "bar"). SetText("random text"). AddChild(element.New("baz-quux")) tcpTsp := NewTCP(pipe1, stream.Receiving, nil, true) // Should be able to get a token from the transport go func() { _, err := pipe2.Write(el.WriteBytes()) if err != nil { t.Errorf("An unexpected error occurred: %s", err) } }() want = el got, err = tcpTsp.Next() if err != nil { t.Errorf("An unexpected error occurred: %s", err) } if !reflect.DeepEqual(want, got) { t.Error("Should be able to get a token from the transport.") t.Errorf("\nWant:%+v\nGot :%+v", want, got) } // Stream element should return token and not attempt to read the entire stream. pipe1, pipe2 = net.Pipe() tcpTsp = NewTCP(pipe1, stream.Receiving, nil, true) go func() { _, err := pipe2.Write(stream.Header{}.WriteBytes()) if err != nil { t.Errorf("An unexpected error occurred: %s", err) } _, err = pipe2.Write([]byte("<foo/></stream:stream>")) if err != nil { t.Errorf("An unexpected error occurred: %s", err) } }() el, err = tcpTsp.Next() if err != nil { t.Errorf("An unexpected error occurred: %s", err) } if el.Space != namespace.Stream || el.Tag != "stream" { t.Error("Stream element should return token and not attempt to read the entire stream.") } got, err = tcpTsp.Next() if err != nil { t.Errorf("An unexpected error occurred: %s", err) } want = element.New("foo") if !reflect.DeepEqual(want, got) { t.Error("Stream element should return token and not attempt to read the entire stream.") t.Errorf("\nWant:%+v\nGot :%+v", want, got) } }