otelchi/middleware.go

271 lines
7.8 KiB
Go
Raw Permalink Normal View History

2025-03-22 23:29:15 +13:00
package otelchi
import (
"net/http"
"strconv"
"sync"
"time"
"github.com/felixge/httpsnoop"
"github.com/go-chi/chi/v5"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/propagation"
otelmetric "go.opentelemetry.io/otel/metric"
semconv "go.opentelemetry.io/otel/semconv/v1.20.0"
"go.opentelemetry.io/otel/semconv/v1.20.0/httpconv"
oteltrace "go.opentelemetry.io/otel/trace"
)
// tracerName is the instrumentation scope name used for both the tracer and
// the meter that Middleware obtains from the configured providers.
const tracerName = "toastielab.dev/toastie-stuff/otelchi"
// Middleware sets up a handler to start tracing the incoming
// requests. The serverName parameter should describe the name of the
// (virtual) server handling the request.
func Middleware(serverName string, opts ...Option) func(http.Handler) http.Handler {
cfg := config{}
for _, opt := range opts {
opt.apply(&cfg)
}
if cfg.TracerProvider == nil {
cfg.TracerProvider = otel.GetTracerProvider()
}
tracer := cfg.TracerProvider.Tracer(
tracerName,
oteltrace.WithInstrumentationVersion(Version()),
)
if cfg.MeterProvider == nil {
cfg.MeterProvider = otel.GetMeterProvider()
}
meter := cfg.MeterProvider.Meter(
tracerName,
otelmetric.WithInstrumentationVersion(Version()),
)
if cfg.Propagators == nil {
cfg.Propagators = otel.GetTextMapPropagator()
}
return func(handler http.Handler) http.Handler {
return &otelware{
serverName: serverName,
tracer: tracer,
meter: meter,
recorder: newMetricsRecorder(meter),
propagators: cfg.Propagators,
handler: handler,
chiRoutes: cfg.ChiRoutes,
reqMethodInSpanName: cfg.RequestMethodInSpanName,
filters: cfg.Filters,
traceIDResponseHeaderKey: cfg.TraceIDResponseHeaderKey,
traceSampledResponseHeaderKey: cfg.TraceSampledResponseHeaderKey,
publicEndpointFn: cfg.PublicEndpointFn,
disableMeasureInflight: cfg.DisableMeasureInflight,
disableMeasureSize: cfg.DisableMeasureSize,
}
}
}
// otelware is the http.Handler produced by Middleware. It wraps the next
// handler together with the settings resolved from the middleware options.
type otelware struct {
	// serverName names the (virtual) server handling the request.
	serverName string

	// OpenTelemetry instrumentation objects.
	tracer      oteltrace.Tracer
	meter       otelmetric.Meter
	recorder    *metricsRecorder
	propagators propagation.TextMapPropagator

	// handler is the wrapped handler that serves the request.
	handler http.Handler

	// Route matching and span naming.
	chiRoutes           chi.Routes
	reqMethodInSpanName bool

	// filters: if any returns false the request is served without tracing.
	filters []Filter

	// Response header names for the trace ID and sampled flag; they are only
	// written when traceIDResponseHeaderKey is non-empty.
	traceIDResponseHeaderKey, traceSampledResponseHeaderKey string

	// publicEndpointFn reports whether a request should start a new root span.
	publicEndpointFn func(r *http.Request) bool

	// Switches for the optional in-flight and response-size metrics.
	disableMeasureInflight, disableMeasureSize bool
}
// recordingResponseWriter holds what a handler wrote through the hooked
// writer that getRRW builds. Instances are recycled through rrwPool.
type recordingResponseWriter struct {
	// writer is the hooked writer that must be passed to the next handler.
	writer http.ResponseWriter
	// written flips to true once the Write or WriteHeader hook has fired.
	written bool
	// writtenBytes is the body byte count recorded by the Write hook.
	writtenBytes int64
	// status is the response status; getRRW presets it to 200 and the
	// WriteHeader hook may override it.
	status int
}

// rrwPool recycles recordingResponseWriter values across requests.
var rrwPool = &sync.Pool{
	New: func() interface{} {
		return new(recordingResponseWriter)
	},
}
// getRRW takes a recordingResponseWriter from rrwPool, resets it and wraps
// writer with httpsnoop hooks that record the response status and the number
// of body bytes written. The hooked writer is stored in rrw.writer and is the
// one that must be passed to the next handler. Return the value with putRRW
// once the handler has finished.
func getRRW(writer http.ResponseWriter) *recordingResponseWriter {
	rrw := rrwPool.Get().(*recordingResponseWriter)
	rrw.written = false
	rrw.writtenBytes = 0
	// Default status when the handler never calls WriteHeader.
	rrw.status = http.StatusOK
	rrw.writer = httpsnoop.Wrap(writer, httpsnoop.Hooks{
		// NOTE(review): if the underlying writer implements io.ReaderFrom,
		// data copied through ReadFrom may bypass this hook; consider adding
		// a ReadFrom hook (confirm against the httpsnoop docs).
		Write: func(next httpsnoop.WriteFunc) httpsnoop.WriteFunc {
			return func(b []byte) (int, error) {
				// Once the body has started, a later WriteHeader must not
				// overwrite the recorded status.
				rrw.written = true
				n, err := next(b)
				// Count every chunk (not just the first), and only the bytes
				// the underlying writer actually accepted.
				rrw.writtenBytes += int64(n)
				return n, err
			}
		},
		WriteHeader: func(next httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {
			return func(statusCode int) {
				// Record only the first status set before any body write.
				if !rrw.written {
					rrw.written = true
					rrw.status = statusCode
				}
				next(statusCode)
			}
		},
	})
	return rrw
}
// putRRW hands rrw back to rrwPool. Its writer reference is cleared first so
// the pooled value no longer points at the finished request's writer.
func putRRW(rrw *recordingResponseWriter) {
	defer rrwPool.Put(rrw)
	rrw.writer = nil
}
// ServeHTTP implements the http.Handler interface. It does the actual
// tracing of the request.
//
// If any configured filter returns false, the request goes straight to the
// wrapped handler with no span and no metrics. Every other request:
//   - gets its incoming trace context extracted with the configured propagators,
//   - gets a server span (a new root span linked to a valid remote parent when
//     publicEndpointFn reports a public endpoint),
//   - optionally updates the in-flight metric and writes trace response headers,
//   - is served through a recording response writer, after which the span's
//     route, name, status code and status are finalised and the duration and
//     (optionally) response-size metrics are recorded.
func (ow *otelware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// go through all filters if any; a filter returning false skips tracing
	// and executes the next handler directly
	for _, filter := range ow.filters {
		if !filter(r) {
			ow.handler.ServeHTTP(w, r)
			return
		}
	}

	// extract tracing header using propagator
	ctx := ow.propagators.Extract(r.Context(), propagation.HeaderCarrier(r.Header))

	// create span, based on specification, we need to set already known attributes
	// when creating the span, the only thing missing here is HTTP route pattern since
	// in go-chi/chi route pattern could only be extracted once the request is executed
	// check here for details:
	//
	// https://github.com/go-chi/chi/issues/150#issuecomment-278850733
	//
	// if we have access to chi routes, we could extract the route pattern beforehand.
	spanName := ""
	routePattern := ""
	spanAttributes := httpconv.ServerRequest(ow.serverName, r)
	if ow.chiRoutes != nil {
		rctx := chi.NewRouteContext()
		if ow.chiRoutes.Match(rctx, r.Method, r.URL.Path) {
			routePattern = rctx.RoutePattern()
			spanName = addPrefixToSpanName(ow.reqMethodInSpanName, r.Method, routePattern)
			spanAttributes = append(spanAttributes, semconv.HTTPRoute(routePattern))
		}
	}

	// define span start options
	spanOpts := []oteltrace.SpanStartOption{
		oteltrace.WithAttributes(spanAttributes...),
		oteltrace.WithSpanKind(oteltrace.SpanKindServer),
	}
	if ow.publicEndpointFn != nil && ow.publicEndpointFn(r) {
		// mark span as the root span
		spanOpts = append(spanOpts, oteltrace.WithNewRoot())
		// Link the incoming span context to the new root span, but only when
		// it is valid (a broken propagation can yield an invalid one) and
		// comes from another service (remote).
		spanCtx := oteltrace.SpanContextFromContext(ctx)
		if spanCtx.IsValid() && spanCtx.IsRemote() {
			spanOpts = append(
				spanOpts,
				oteltrace.WithLinks(oteltrace.Link{
					SpanContext: spanCtx,
				}),
			)
		}
	}

	// Metric properties. Before the handler runs, the route pattern is only
	// known if chiRoutes matched; otherwise fall back to the raw URL path.
	props := httpReqProperties{
		Service: ow.serverName,
		ID:      routePattern,
		Method:  r.Method,
	}
	if routePattern == "" {
		props.ID = r.URL.Path
	}
	if !ow.disableMeasureInflight {
		ow.recorder.RecordRequestsInflight(ctx, props, 1)
		// Deferred arguments are evaluated here, so the decrement uses this
		// copy of props even if props.ID is refined after the handler runs.
		defer ow.recorder.RecordRequestsInflight(ctx, props, -1)
	}

	// start span
	ctx, span := ow.tracer.Start(ctx, spanName, spanOpts...)
	defer span.End()

	// put trace_id to response header only when `WithTraceIDResponseHeader` is used
	if len(ow.traceIDResponseHeaderKey) > 0 && span.SpanContext().HasTraceID() {
		w.Header().Add(ow.traceIDResponseHeaderKey, span.SpanContext().TraceID().String())
		// never emit a header with an empty field name
		if len(ow.traceSampledResponseHeaderKey) > 0 {
			w.Header().Add(ow.traceSampledResponseHeaderKey, strconv.FormatBool(span.SpanContext().IsSampled()))
		}
	}

	// get recording response writer
	rrw := getRRW(w)
	defer putRRW(rrw)

	// execute next http handler
	r = r.WithContext(ctx)
	start := time.Now()
	ow.handler.ServeHTTP(rrw.writer, r)
	duration := time.Since(start)

	// If the route pattern could not be determined during span creation,
	// resolve it now that the request has been executed. This happens BEFORE
	// recording metrics so they can use it too.
	if len(routePattern) == 0 {
		// chi.RouteContext may be nil when no chi router put a route context
		// on the request (e.g. the middleware wraps a non-chi handler).
		if rctx := chi.RouteContext(r.Context()); rctx != nil {
			routePattern = rctx.RoutePattern()
		}
		span.SetAttributes(semconv.HTTPRoute(routePattern))
		spanName = addPrefixToSpanName(ow.reqMethodInSpanName, r.Method, routePattern)
		span.SetName(spanName)
		// Prefer the route template over the raw path for duration/size.
		// props.ID is presumably a metric attribute in metricsRecorder
		// (confirm); raw paths such as /users/123 would create one value per
		// distinct URL, whereas /users/{id} stays bounded.
		if routePattern != "" {
			props.ID = routePattern
		}
	}

	// record duration & response size metrics
	props.Code = rrw.status
	ow.recorder.RecordRequestDuration(ctx, props, duration)
	if !ow.disableMeasureSize {
		ow.recorder.RecordResponseSize(ctx, props, rrw.writtenBytes)
	}

	// set status code attribute
	span.SetAttributes(semconv.HTTPStatusCode(rrw.status))
	// set span status
	span.SetStatus(httpconv.ServerStatus(rrw.status))
}
// addPrefixToSpanName turns a route pattern into a span name.
//
// An empty spanName becomes "/". In chi v5.0.8 the root route is returned as
// an empty string (see https://github.com/go-chi/chi/blob/v5.0.8/context.go#L126).
// When shouldAdd is true the result is prefix, a single space and the name,
// e.g. "GET /users/{id}"; otherwise the name is returned as is.
func addPrefixToSpanName(shouldAdd bool, prefix, spanName string) string {
	if spanName == "" {
		spanName = "/"
	}
	// spanName is non-empty from here on, so no length check is needed.
	if !shouldAdd {
		return spanName
	}
	return prefix + " " + spanName
}