otelchi/test/cases/middleware_test.go

150 lines
4.2 KiB
Go

package otelchi_test
import (
"bufio"
"context"
"io"
"net"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/trace"
"toastielab.dev/toastie-stuff/otelchi"
)
// sc is the span context shared by the tests in this file: it is attached to
// (or injected into) the incoming request and then compared against the span
// context the handler observes. It is marked as remote and sampled.
var sc = trace.NewSpanContext(trace.SpanContextConfig{
	TraceID:    trace.TraceID{0x01},
	SpanID:     trace.SpanID{0x01},
	TraceFlags: trace.FlagsSampled,
	Remote:     true,
})
// TestPassthroughSpanFromGlobalTracer places sc on the request context and
// checks that the handler running behind the middleware sees that same span
// context.
func TestPassthroughSpanFromGlobalTracer(t *testing.T) {
	handlerRan := false

	mux := chi.NewRouter()
	mux.Use(otelchi.Middleware("foobar"))
	// The default global TracerProvider provides "pass through" spans for any
	// span context in the incoming request context.
	mux.HandleFunc("/user/{id}", func(rw http.ResponseWriter, req *http.Request) {
		handlerRan = true
		assert.Equal(t, sc, trace.SpanFromContext(req.Context()).SpanContext())
		rw.WriteHeader(http.StatusOK)
	})

	parent := trace.ContextWithRemoteSpanContext(context.Background(), sc)
	req := httptest.NewRequest(http.MethodGet, "/user/123", nil).WithContext(parent)
	mux.ServeHTTP(httptest.NewRecorder(), req)

	// All span assertions run inside the handler, so confirm it was reached.
	assert.True(t, handlerRan, "failed to run test")
}
func TestPropagationWithGlobalPropagators(t *testing.T) {
defer func(p propagation.TextMapPropagator) {
otel.SetTextMapPropagator(p)
}(otel.GetTextMapPropagator())
prop := propagation.TraceContext{}
otel.SetTextMapPropagator(prop)
r := httptest.NewRequest("GET", "/user/123", nil)
w := httptest.NewRecorder()
ctx := trace.ContextWithRemoteSpanContext(context.Background(), sc)
otel.GetTextMapPropagator().Inject(ctx, propagation.HeaderCarrier(r.Header))
var called bool
router := chi.NewRouter()
router.Use(otelchi.Middleware("foobar"))
router.HandleFunc("/user/{id}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
span := trace.SpanFromContext(r.Context())
assert.Equal(t, sc, span.SpanContext())
w.WriteHeader(http.StatusOK)
}))
router.ServeHTTP(w, r)
assert.True(t, called, "failed to run test")
}
// TestPropagationWithCustomPropagators injects sc into the request headers
// with a propagation.TraceContext propagator, passes that same propagator to
// the middleware through otelchi.WithPropagators, and checks that the handler
// observes sc.
func TestPropagationWithCustomPropagators(t *testing.T) {
	custom := propagation.TraceContext{}

	// Encode sc into the request headers with the custom propagator.
	req := httptest.NewRequest(http.MethodGet, "/user/123", nil)
	parent := trace.ContextWithRemoteSpanContext(context.Background(), sc)
	custom.Inject(parent, propagation.HeaderCarrier(req.Header))

	handlerRan := false
	mux := chi.NewRouter()
	mux.Use(otelchi.Middleware("foobar", otelchi.WithPropagators(custom)))
	mux.HandleFunc("/user/{id}", func(rw http.ResponseWriter, req *http.Request) {
		handlerRan = true
		assert.Equal(t, sc, trace.SpanFromContext(req.Context()).SpanContext())
		rw.WriteHeader(http.StatusOK)
	})

	mux.ServeHTTP(httptest.NewRecorder(), req)

	// All span assertions run inside the handler, so confirm it was reached.
	assert.True(t, handlerRan, "failed to run test")
}
// TestResponseWriterInterfaces makes sure the response writer that the
// middleware hands to the handler still exposes the optional interfaces
// implemented by the writer it wraps: http.Hijacker, http.Pusher,
// http.Flusher and io.ReaderFrom.
func TestResponseWriterInterfaces(t *testing.T) {
	var called bool
	router := chi.NewRouter()
	router.Use(otelchi.Middleware("foobar"))
	router.HandleFunc("/user/{id}", func(w http.ResponseWriter, r *http.Request) {
		called = true
		assert.Implements(t, (*http.Hijacker)(nil), w)
		assert.Implements(t, (*http.Pusher)(nil), w)
		assert.Implements(t, (*http.Flusher)(nil), w)
		assert.Implements(t, (*io.ReaderFrom)(nil), w)
		w.WriteHeader(http.StatusOK)
	})

	r := httptest.NewRequest(http.MethodGet, "/user/123", nil)
	// testResponseWriter implements every optional interface checked above, so
	// any interface missing inside the handler was dropped by the middleware.
	w := &testResponseWriter{
		writer: httptest.NewRecorder(),
	}
	router.ServeHTTP(w, r)

	// Every assertion lives inside the handler; without this guard the test
	// would pass vacuously if the handler were never invoked.
	assert.True(t, called, "failed to run test")
}
// testResponseWriter is a stub http.ResponseWriter for tests. The core
// ResponseWriter methods are forwarded to the wrapped writer, while
// http.Hijacker, http.Pusher, http.Flusher and io.ReaderFrom are implemented
// as no-ops so that a test can check those interfaces are still exposed once
// this writer has been wrapped.
type testResponseWriter struct {
	writer http.ResponseWriter
}

// Compile-time checks that the stub really satisfies every interface the
// tests rely on; a signature drift fails the build instead of surfacing as a
// misleading assertion failure at run time.
var (
	_ http.ResponseWriter = (*testResponseWriter)(nil)
	_ http.Hijacker       = (*testResponseWriter)(nil)
	_ http.Pusher         = (*testResponseWriter)(nil)
	_ http.Flusher        = (*testResponseWriter)(nil)
	_ io.ReaderFrom       = (*testResponseWriter)(nil)
)

// Header returns the header map of the wrapped writer.
func (rw *testResponseWriter) Header() http.Header {
	return rw.writer.Header()
}

// Write forwards b to the wrapped writer and returns its result.
func (rw *testResponseWriter) Write(b []byte) (int, error) {
	return rw.writer.Write(b)
}

// WriteHeader forwards statusCode to the wrapped writer.
func (rw *testResponseWriter) WriteHeader(statusCode int) {
	rw.writer.WriteHeader(statusCode)
}

// Hijack implements http.Hijacker. It is a stub: it never takes over a
// connection and always returns nil values, including a nil error.
func (rw *testResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	return nil, nil, nil
}

// Push implements http.Pusher. It is a stub that ignores its arguments and
// always returns nil.
func (rw *testResponseWriter) Push(target string, opts *http.PushOptions) error {
	return nil
}

// Flush implements http.Flusher. It is a no-op.
func (rw *testResponseWriter) Flush() {
}

// ReadFrom implements io.ReaderFrom. It is a stub: r is never read and the
// reported byte count is always 0 with a nil error.
func (rw *testResponseWriter) ReadFrom(r io.Reader) (n int64, err error) {
	return 0, nil
}