package otelchi

import (
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/felixge/httpsnoop"
	"github.com/go-chi/chi/v5"

	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/propagation"

	otelmetric "go.opentelemetry.io/otel/metric"
	semconv "go.opentelemetry.io/otel/semconv/v1.20.0"
	"go.opentelemetry.io/otel/semconv/v1.20.0/httpconv"
	oteltrace "go.opentelemetry.io/otel/trace"
)

const (
	tracerName = "toastielab.dev/toastie-stuff/otelchi"
)

// Middleware sets up a handler to start tracing the incoming
// requests. The serverName parameter should describe the name of the
// (virtual) server handling the request.
func Middleware(serverName string, opts ...Option) func(http.Handler) http.Handler {
	cfg := config{}
	for _, opt := range opts {
		opt.apply(&cfg)
	}

	if cfg.TracerProvider == nil {
		cfg.TracerProvider = otel.GetTracerProvider()
	}
	tracer := cfg.TracerProvider.Tracer(
		tracerName,
		oteltrace.WithInstrumentationVersion(Version()),
	)

	if cfg.MeterProvider == nil {
		cfg.MeterProvider = otel.GetMeterProvider()
	}
	meter := cfg.MeterProvider.Meter(
		tracerName,
		otelmetric.WithInstrumentationVersion(Version()),
	)

	if cfg.Propagators == nil {
		cfg.Propagators = otel.GetTextMapPropagator()
	}

	return func(handler http.Handler) http.Handler {
		return &otelware{
			serverName:                    serverName,
			tracer:                        tracer,
			meter:                         meter,
			recorder:                      newMetricsRecorder(meter),
			propagators:                   cfg.Propagators,
			handler:                       handler,
			chiRoutes:                     cfg.ChiRoutes,
			reqMethodInSpanName:           cfg.RequestMethodInSpanName,
			filters:                       cfg.Filters,
			traceIDResponseHeaderKey:      cfg.TraceIDResponseHeaderKey,
			traceSampledResponseHeaderKey: cfg.TraceSampledResponseHeaderKey,
			publicEndpointFn:              cfg.PublicEndpointFn,
			disableMeasureInflight:        cfg.DisableMeasureInflight,
			disableMeasureSize:            cfg.DisableMeasureSize,
		}
	}
}

type otelware struct {
	serverName                    string
	tracer                        oteltrace.Tracer
	meter                         otelmetric.Meter
	recorder                      *metricsRecorder
	propagators                   propagation.TextMapPropagator
	handler                       http.Handler
	chiRoutes                     chi.Routes
	reqMethodInSpanName           bool
	filters                       []Filter
	traceIDResponseHeaderKey      string
	traceSampledResponseHeaderKey string
	publicEndpointFn              func(r *http.Request) bool

	disableMeasureInflight bool
	disableMeasureSize     bool
}

type recordingResponseWriter struct {
	writer       http.ResponseWriter
	written      bool
	writtenBytes int64
	status       int
}

var rrwPool = &sync.Pool{
	New: func() interface{} {
		return &recordingResponseWriter{}
	},
}

func getRRW(writer http.ResponseWriter) *recordingResponseWriter {
	rrw := rrwPool.Get().(*recordingResponseWriter)
	rrw.written = false
	rrw.writtenBytes = 0
	rrw.status = http.StatusOK
	rrw.writer = httpsnoop.Wrap(writer, httpsnoop.Hooks{
		Write: func(next httpsnoop.WriteFunc) httpsnoop.WriteFunc {
			return func(b []byte) (int, error) {
				if !rrw.written {
					rrw.written = true
					rrw.writtenBytes += int64(len(b))
				}
				return next(b)
			}
		},
		WriteHeader: func(next httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {
			return func(statusCode int) {
				if !rrw.written {
					rrw.written = true
					rrw.status = statusCode
				}
				next(statusCode)
			}
		},
	})
	return rrw
}

func putRRW(rrw *recordingResponseWriter) {
	rrw.writer = nil
	rrwPool.Put(rrw)
}

// ServeHTTP implements the http.Handler interface. It does the actual
// tracing of the request.
func (ow *otelware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// go through all filters if any
	for _, filter := range ow.filters {
		// if there is a filter that returns false, we skip tracing
		// and execute next handler
		if !filter(r) {
			ow.handler.ServeHTTP(w, r)
			return
		}
	}

	// extract tracing header using propagator
	ctx := ow.propagators.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
	// create span, based on specification, we need to set already known attributes
	// when creating the span, the only thing missing here is HTTP route pattern since
	// in go-chi/chi route pattern could only be extracted once the request is executed
	// check here for details:
	//
	// https://github.com/go-chi/chi/issues/150#issuecomment-278850733
	//
	// if we have access to chi routes, we could extract the route pattern beforehand.
	spanName := ""
	routePattern := ""
	spanAttributes := httpconv.ServerRequest(ow.serverName, r)

	if ow.chiRoutes != nil {
		rctx := chi.NewRouteContext()
		if ow.chiRoutes.Match(rctx, r.Method, r.URL.Path) {
			routePattern = rctx.RoutePattern()
			spanName = addPrefixToSpanName(ow.reqMethodInSpanName, r.Method, routePattern)
			spanAttributes = append(spanAttributes, semconv.HTTPRoute(routePattern))
		}
	}

	// define span start options
	spanOpts := []oteltrace.SpanStartOption{
		oteltrace.WithAttributes(spanAttributes...),
		oteltrace.WithSpanKind(oteltrace.SpanKindServer),
	}

	if ow.publicEndpointFn != nil && ow.publicEndpointFn(r) {
		// mark span as the root span
		spanOpts = append(spanOpts, oteltrace.WithNewRoot())

		// linking incoming span context to the root span, we need to
		// ensure if the incoming span context is valid (because it is
		// possible for us to receive invalid span context due to various
		// reason such as bug or context propagation error) and it is
		// coming from another service (remote) before linking it to the
		// root span
		spanCtx := oteltrace.SpanContextFromContext(ctx)
		if spanCtx.IsValid() && spanCtx.IsRemote() {
			spanOpts = append(
				spanOpts,
				oteltrace.WithLinks(oteltrace.Link{
					SpanContext: spanCtx,
				}),
			)
		}
	}

	props := httpReqProperties{
		Service: ow.serverName,
		ID:      routePattern,
		Method:  r.Method,
	}
	if routePattern == "" {
		props.ID = r.URL.Path
	}

	if !ow.disableMeasureInflight {
		ow.recorder.RecordRequestsInflight(ctx, props, 1)
		defer ow.recorder.RecordRequestsInflight(ctx, props, -1)
	}

	// start span
	ctx, span := ow.tracer.Start(ctx, spanName, spanOpts...)
	defer span.End()

	// put trace_id to response header only when `WithTraceIDResponseHeader` is used
	if len(ow.traceIDResponseHeaderKey) > 0 && span.SpanContext().HasTraceID() {
		w.Header().Add(ow.traceIDResponseHeaderKey, span.SpanContext().TraceID().String())
		w.Header().Add(ow.traceSampledResponseHeaderKey, strconv.FormatBool(span.SpanContext().IsSampled()))
	}

	// get recording response writer
	rrw := getRRW(w)
	defer putRRW(rrw)

	// execute next http handler
	r = r.WithContext(ctx)
	start := time.Now()
	ow.handler.ServeHTTP(rrw.writer, r)
	duration := time.Since(start)

	props.Code = rrw.status
	ow.recorder.RecordRequestDuration(ctx, props, duration)

	if !ow.disableMeasureSize {
		ow.recorder.RecordResponseSize(ctx, props, rrw.writtenBytes)
	}

	// set span name & http route attribute if route pattern cannot be determined
	// during span creation
	if len(routePattern) == 0 {
		routePattern = chi.RouteContext(r.Context()).RoutePattern()
		span.SetAttributes(semconv.HTTPRoute(routePattern))

		spanName = addPrefixToSpanName(ow.reqMethodInSpanName, r.Method, routePattern)
		span.SetName(spanName)
	}

	// set status code attribute
	span.SetAttributes(semconv.HTTPStatusCode(rrw.status))

	// set span status
	span.SetStatus(httpconv.ServerStatus(rrw.status))
}

func addPrefixToSpanName(shouldAdd bool, prefix, spanName string) string {
	// in chi v5.0.8, the root route will be returned has an empty string
	// (see https://github.com/go-chi/chi/blob/v5.0.8/context.go#L126)
	if spanName == "" {
		spanName = "/"
	}

	if shouldAdd && len(spanName) > 0 {
		spanName = prefix + " " + spanName
	}
	return spanName
}