« Back to Index

Go: basic middleware abstraction

View original Gist on GitHub

Tags: #go #middleware

1. middleware.go

// THIS IS THE REDESIGNED VERSION (see 3. middleware.go for my original approach)

package middleware

import (
	"net/http"
)

// Decorator is a middleware function: it receives the next handler in the
// chain and returns a handler that wraps it.
type Decorator func(http.Handler) http.Handler

// Pipeline is an ordered list of middleware that can be wrapped around an
// http.Handler, e.g. to add logging, metrics and tracing.
type Pipeline struct {
	middleware []Decorator
}

// New returns a middleware Pipeline holding the given decorators.
// The first decorator ends up outermost, i.e. it sees the request first.
func New(middleware ...Decorator) *Pipeline {
	p := &Pipeline{}
	p.middleware = append(p.middleware, middleware...)
	return p
}

// Decorate wraps next in the pipeline's middleware and returns the resulting
// handler. With an empty pipeline, next is returned unchanged.
//
// The chain is built once, here, rather than on every request. Building it
// inside the returned handler would re-wrap the already wrapped handler on
// each request (so middleware would run once more per request served) and
// would let concurrent requests race on the shared variable.
func (p *Pipeline) Decorate(next http.Handler) http.Handler {
	// Wrap back-to-front so the first decorator is the outermost layer.
	for i := len(p.middleware) - 1; i >= 0; i-- {
		next = p.middleware[i](next)
	}
	return next
}

2. main.go

// THIS IS THE REDESIGNED VERSION (see 4. main.go for my original approach)

// Package main is the entry point for the application.
package main

import (
	"context"
	"flag"
	"fmt"
	"log/slog"
	"net/http"
	"os"
	"os/signal"
	"runtime"
	"syscall"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	"golang.org/x/sync/errgroup"

	"github.com/org/repo/internal/config"
	h "github.com/org/repo/internal/http"
	"github.com/org/repo/internal/logging"
	"github.com/org/repo/internal/metrics"
	"github.com/org/repo/internal/middleware"
	"github.com/org/repo/internal/traces"
)

// cfgPath is the -cfg flag: the path of the config file to load.
var cfgPath = flag.String("cfg", "", "Path to config")

// metricsHost is the -metrics-host flag: the host the metrics listener binds to.
var metricsHost = flag.String("metrics-host", "127.0.0.1", "Host to use for metrics & pprof")

// metricsPort is the -metrics-port flag: the port the metrics listener binds to.
var metricsPort = flag.Int("metrics-port", 8447, "Port to expose for metrics & pprof") //nolint:gomnd

// ImageTag is the container image version and is displayed in the healthcheck.
// It is overridden at run/build time (see ../../Makefile).
var ImageTag = "dev"

// main registers and parses the command-line flags, then runs the application.
// The process exits with status 1 when app reports an error.
func main() {
	build.AddVersionFlag()
	flag.Parse()

	// NOTE: The real work lives in `app()` so that its `defer` calls run
	// before os.Exit terminates the process.
	err := app()
	if err == nil {
		return
	}
	os.Exit(1)
}

// app loads the configuration, sets up observability and routing, and runs the
// API server (plus the optional metrics server) until a shutdown signal
// arrives or one of the servers stops unexpectedly.
//
// It returns nil after a clean, signal-triggered shutdown. It returns an error
// when the config cannot be loaded or validated, tracing cannot be set up, a
// server fails to listen/serve, or a graceful shutdown fails.
func app() error {
	ctx := context.Background()
	logger := logging.NewLogger()

	cfg, err := config.Load(*cfgPath)
	if err != nil {
		err = fmt.Errorf("unable to load config: %w", err)
		logger.LogAttrs(ctx, slog.LevelError, "config_load", slog.Any("error", err))
		// Stop here: everything below reads cfg.
		return err
	}
	if err := cfg.Validate(); err != nil {
		err = fmt.Errorf("invalid config: %w", err)
		logger.LogAttrs(ctx, slog.LevelError, "config_validate", slog.Any("error", err))
		return err
	}

	// validate secret is available
	// WARNING(review): this writes the secrets to stdout. Remove or redact it
	// before running anywhere the output is collected.
	fmt.Printf("foo secret: %#v\n", cfg.Secrets)

	// setup observability
	registry := prometheus.NewRegistry()
	metric := metrics.New(registry)
	state := "enabled"
	if cfg.OtelDisabled {
		state = "disabled"
	}
	logger.LogAttrs(ctx, slog.LevelInfo, "trace exporting "+state)
	traceShutdown, tracer, err := traces.New(ctx, cfg)
	if err != nil {
		err = fmt.Errorf("unable to start tracing: %w", err)
		logger.LogAttrs(ctx, slog.LevelError, "trace create", slog.Any("error", err))
		// traces.New returns a nil shutdown func on error, so the deferred
		// call below must not be registered.
		return err
	}
	defer func() {
		// NOTE: We use a separate context for the shutdown.
		// This is so that the graceful shutdown of the main HTTP server doesn't
		// end up cancelling the traceShutdown context before it has a chance to
		// complete its shutdown process.
		if err := traceShutdown(context.WithoutCancel(ctx)); err != nil {
			logger.LogAttrs(ctx, slog.LevelError, "trace shutdown", slog.Any("error", err))
		}
	}()

	// NOTE: gctx is cancelled as soon as any goroutine in the group returns an
	// error. The shutdown goroutine watches it so that one server failing also
	// stops the other and g.Wait() can return.
	g, gctx := errgroup.WithContext(ctx)
	timeout := time.Duration(10) * time.Second //nolint:gomnd

	var (
		metricsAddr  string
		metricServer *http.Server
	)
	if *metricsPort != 0 {
		metricsAddr = fmt.Sprintf("%s:%d", *metricsHost, *metricsPort)
	}

	// NOTE: The logger is finalised before any goroutine starts so that no
	// goroutine reads the variable while it is being reassigned.
	logger = logger.With(
		slog.Int("gomaxprocs", runtime.GOMAXPROCS(0)),
		slog.String("build_info", build.Info.String()),
		slog.String("cfg", *cfgPath),
		slog.String("cfg_addr", cfg.Addr),
		slog.String("metrics_addr", metricsAddr),
	)
	logger.LogAttrs(ctx, slog.LevelInfo, "service starting")

	// Setup separate /metrics endpoint for Prometheus scraping.
	// Needs to be separate listener as metrics service runs on its own port.
	if *metricsPort != 0 {
		metricsMux := http.NewServeMux()
		metricsMux.Handle("/metrics", metric.Handler())
		// NOTE: The server is built here, not inside the goroutine, so the
		// shutdown goroutine never reads metricServer while it is being set.
		metricServer = &http.Server{
			Addr:              metricsAddr,
			Handler:           metricsMux,
			ReadHeaderTimeout: timeout, // avoid slowloris https://app.deepsource.com/directory/analyzers/go/issues/GO-S2114
		}
		g.Go(func() error {
			// ListenAndServe returns http.ErrServerClosed (unwrapped) after a
			// Shutdown; that is the expected outcome, not a failure.
			if err := metricServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
				err := fmt.Errorf("metrics server failed to listen and serve requests: %w", err)
				logger.LogAttrs(ctx, slog.LevelError, "error handling metric requests",
					slog.Any("error", err),
				)
				return err
			}
			return nil
		})
	}

	// Routing
	mux := http.NewServeMux()

	// NOTE: We have separate pipelines.
	// One for the healthcheck and another for all other endpoints.
	// This is because we want to avoid tracing the healthcheck endpoint.
	// For some apps, tracing the healthcheck can result in poor performance.

	pipelineHealth := middleware.New(logging.WithRoutePattern(mux, logger), metric.WithHealthCheckMetrics, h.WithHeaders(cfg))
	healthHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		fmt.Fprintf(w, `{"status": "success", "image_tag": "%s"}`, ImageTag)
	})
	mux.Handle("/healthcheck", pipelineHealth.Decorate(healthHandler))

	// NOTE: The order of the middleware is deliberate.
	// WithSpan updates the request context with a Span ID.
	// WithRoutePattern checks for the Span ID and attaches it to the logger.
	// WithMetrics, WithRouteTag and WithParamAttrs all check for the route pattern added by WithRoutePattern.

	pipeline := middleware.New(traces.WithSpan(tracer), logging.WithRoutePattern(mux, logger), metric.WithMetrics, traces.WithRouteTag, traces.WithParamAttrs, h.WithHeaders(cfg))
	exampleHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		fmt.Fprint(w, `{"endpoint": "example"}`)
	})
	mux.Handle("/example/{id}", pipeline.Decorate(exampleHandler))

	// API Server
	apiServer := &http.Server{
		Addr:              cfg.Addr,
		Handler:           mux,
		ReadHeaderTimeout: timeout, // avoid slowloris https://app.deepsource.com/directory/analyzers/go/issues/GO-S2114
	}
	g.Go(func() error {
		// See the metrics server above for why http.ErrServerClosed is ignored.
		if err := apiServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			err := fmt.Errorf("api server failed to listen and serve requests: %w", err)
			logger.LogAttrs(ctx, slog.LevelError, "error handling API requests",
				slog.Any("error", err),
			)
			return err
		}
		return nil
	})

	// Graceful shutdown
	g.Go(func() error {
		quit := make(chan os.Signal, 1)
		signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
		defer signal.Stop(quit)

		select {
		case signalType := <-quit:
			logger.LogAttrs(ctx, slog.LevelWarn, "http server graceful shutdown",
				slog.String("signal", signalType.String()),
			)
		case <-gctx.Done():
			// A server stopped on its own. Its error is already held by the
			// errgroup; all that is left is to stop whatever is still running.
		}

		// Both servers are always asked to stop, even if the first shutdown
		// fails; the first failure is the one reported.
		var shutdownErr error

		// NOTE: Shutdown is given a separate context from the API server.
		if err := apiServer.Shutdown(context.Background()); err != nil { //nolint:contextcheck
			err := fmt.Errorf("unable to gracefully stop API server: %w", err)
			logger.LogAttrs(ctx, slog.LevelError, "error shutting down API server",
				slog.Any("error", err),
			)
			shutdownErr = err
		}
		if metricServer != nil {
			if err := metricServer.Shutdown(context.Background()); err != nil { //nolint:contextcheck
				err := fmt.Errorf("unable to gracefully stop metrics server: %w", err)
				logger.LogAttrs(ctx, slog.LevelError, "error shutting down metrics server",
					slog.Any("error", err),
				)
				if shutdownErr == nil {
					shutdownErr = err
				}
			}
		}
		return shutdownErr
	})

	return g.Wait()
}

3. middleware.go

package middleware

import (
	"net/http"
)

// Decorator wraps an http.Handler and returns the wrapped handler.
type Decorator func(http.Handler) http.Handler

// Pipeline registers handlers on a ServeMux, wrapping each of them in a shared
// set of middleware (e.g. logging, metrics and tracing).
type Pipeline struct {
	middleware []Decorator
	mux        *http.ServeMux
}

// New returns a middleware Pipeline that registers its handlers on mux.
func New(mux *http.ServeMux) *Pipeline {
	p := new(Pipeline)
	p.mux = mux
	return p
}

// Use appends one or more middleware to the pipeline. They apply to every
// handler registered with Handle afterwards.
func (p *Pipeline) Use(middleware ...Decorator) {
	p.middleware = append(p.middleware, middleware...)
}

// Handle registers handler for path on the pipeline's mux.
//
// The handler is wrapped first in the route-specific middleware passed here
// and then in the pipeline-wide middleware added via Use, so a request passes
// through the pipeline-wide middleware before the route-specific ones. Within
// each list the first entry is the outermost.
func (p *Pipeline) Handle(path string, handler http.Handler, middleware ...Decorator) {
	wrapped := handler
	// Inner layers first: route-specific, then pipeline-wide.
	for _, chain := range [][]Decorator{middleware, p.middleware} {
		for i := len(chain); i > 0; i-- {
			wrapped = chain[i-1](wrapped)
		}
	}
	p.mux.Handle(path, wrapped)
}

4. main.go

// Package main is the entry point for the application.
package main

import (
	"context"
	"flag"
	"fmt"
	"log/slog"
	"net/http"
	"os"
	"os/signal"
	"runtime"
	"syscall"
	"time"

	"github.com/prometheus/client_golang/prometheus"

	"github.com/org/repo/internal/config"
	"github.com/org/repo/internal/logging"
	"github.com/org/repo/internal/metrics"
	"github.com/org/repo/internal/middleware"
	"github.com/org/repo/internal/traces"
)

// cfgPath is the -cfg flag: the path of the config file to load.
var cfgPath = flag.String("cfg", "", "Path to config")

// metricsHost is the -metrics-host flag: the host the metrics listener binds to.
var metricsHost = flag.String("metrics-host", "127.0.0.1", "Host to use for metrics & pprof")

// metricsPort is the -metrics-port flag: the port the metrics listener binds to.
var metricsPort = flag.Int("metrics-port", 8447, "Port to expose for metrics & pprof") //nolint:gomnd

// ImageTag is the container image version and is displayed in the healthcheck.
// It is overridden at run/build time (see ../../Makefile).
var ImageTag = "dev"

// main registers and parses the command-line flags, then runs the application.
// The process exits with status 1 when app reports an error.
func main() {
	build.AddVersionFlag()
	flag.Parse()

	// NOTE: The real work lives in `app()` so that its `defer` calls run
	// before os.Exit terminates the process.
	err := app()
	if err == nil {
		return
	}
	os.Exit(1)
}

// app loads the configuration, sets up observability and routing, and runs the
// HTTP server (plus the optional metrics server) until a shutdown signal
// arrives or one of the servers stops unexpectedly.
//
// It returns nil after a signal-triggered shutdown. It returns an error when
// the config cannot be loaded or validated, tracing cannot be set up, or a
// server fails to listen/serve. Errors from the graceful shutdown itself are
// logged but not returned.
func app() error {
	ctx := context.Background()
	logger := logging.NewLogger()

	cfg, err := config.Load(*cfgPath)
	if err != nil {
		err = fmt.Errorf("unable to load config: %w", err)
		logger.LogAttrs(ctx, slog.LevelError, "config_load", slog.Any("error", err))
		// Stop here: everything below reads cfg.
		return err
	}
	if err := cfg.Validate(); err != nil {
		err = fmt.Errorf("invalid config: %w", err)
		logger.LogAttrs(ctx, slog.LevelError, "config_validate", slog.Any("error", err))
		return err
	}

	// validate secret is available
	// WARNING(review): this writes the secrets to stdout. Remove or redact it
	// before running anywhere the output is collected.
	fmt.Printf("foo secret: %#v\n", cfg.Secrets)

	// setup observability
	registry := prometheus.NewRegistry()
	metric := metrics.New(registry)
	state := "enabled"
	if cfg.OtelDisabled {
		state = "disabled"
	}
	logger.LogAttrs(ctx, slog.LevelInfo, "trace exporting "+state)
	traceShutdown, tracer, err := traces.New(ctx, cfg)
	if err != nil {
		err = fmt.Errorf("unable to start tracing: %w", err)
		logger.LogAttrs(ctx, slog.LevelError, "trace create", slog.Any("error", err))
		// traces.New returns a nil shutdown func on error, so the deferred
		// call below must not be registered.
		return err
	}
	defer func() {
		// NOTE: We use a separate context for the shutdown.
		// This is so that the graceful shutdown of the main HTTP server doesn't
		// end up cancelling the traceShutdown context before it has a chance to
		// complete its shutdown process.
		if err := traceShutdown(context.WithoutCancel(ctx)); err != nil {
			logger.LogAttrs(ctx, slog.LevelError, "trace shutdown", slog.Any("error", err))
		}
	}()

	// NOTE: One slot per server, so neither goroutine can block forever on its
	// send if both servers fail.
	serverError := make(chan error, 2) //nolint:gomnd
	timeout := time.Duration(10) * time.Second //nolint:gomnd

	var (
		metricsAddr   string
		metricsServer *http.Server
	)
	if *metricsPort != 0 {
		metricsAddr = fmt.Sprintf("%s:%d", *metricsHost, *metricsPort)
	}

	// NOTE: The logger is finalised before any goroutine starts so that no
	// goroutine reads the variable while it is being reassigned.
	logger = logger.With(
		slog.Int("gomaxprocs", runtime.GOMAXPROCS(0)),
		slog.String("build_info", build.Info.String()),
		slog.String("cfg", *cfgPath),
		slog.String("cfg_addr", cfg.Addr),
		slog.String("metrics_addr", metricsAddr),
	)

	// Setup separate /metrics endpoint for Prometheus scraping.
	// Needs to be separate listener as metrics service runs on its own port.
	if *metricsPort != 0 {
		metricsMux := http.NewServeMux()
		metricsMux.Handle("/metrics", metric.Handler())
		// NOTE: The server is kept in a variable visible to the shutdown code
		// below so that it can be stopped together with the main server.
		metricsServer = &http.Server{
			Addr:              metricsAddr,
			Handler:           metricsMux,
			ReadHeaderTimeout: timeout, // avoid slowloris https://app.deepsource.com/directory/analyzers/go/issues/GO-S2114
		}
		go func() {
			// ListenAndServe returns http.ErrServerClosed (unwrapped) after a
			// Shutdown; that is the expected outcome, not a failure.
			if err := metricsServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
				logger.LogAttrs(ctx, slog.LevelError, "error handling metric requests",
					slog.Any("error", err),
				)
				serverError <- err
			}
		}()

		select {
		case err := <-serverError:
			return err
		case <-time.After(1 * time.Second):
			// no-op: allow logic to flow to setting up routing/http server
		}
	}

	logger.LogAttrs(ctx, slog.LevelInfo, "service starting")

	// Routing
	mux := http.NewServeMux()

	// NOTE: We have separate pipelines.
	// One for the healthcheck and another for all other endpoints.
	// This is because we want to avoid tracing the healthcheck endpoint.
	// For some apps, tracing the healthcheck can result in poor performance.

	pipelineHealth := middleware.New(mux)
	pipelineHealth.Use(logging.WithRoutePattern(mux, logger), metric.WithHealthCheckMetrics)
	pipelineHealth.Handle("/healthcheck", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(fmt.Sprintf(`{"status": "success", "image_tag": "%s"}`, ImageTag)))
	}))

	// NOTE: The order of the middleware is deliberate.
	// WithSpan updates the request context with a Span ID.
	// WithRoutePattern checks for the Span ID and attaches it to the logger.
	// WithMetrics, WithRouteTag and WithParamAttrs all check for the route pattern added by WithRoutePattern.

	pipeline := middleware.New(mux)
	pipeline.Use(traces.WithSpan(tracer), logging.WithRoutePattern(mux, logger), metric.WithMetrics, traces.WithRouteTag, traces.WithParamAttrs)
	pipeline.Handle("/example/{id}", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"endpoint": "example"}`))
	}))

	// HTTP Server
	server := &http.Server{
		Addr:              cfg.Addr,
		Handler:           mux,
		ReadHeaderTimeout: timeout, // avoid slowloris https://app.deepsource.com/directory/analyzers/go/issues/GO-S2114
	}
	go func() {
		// See the metrics server above for why http.ErrServerClosed is ignored.
		if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			logger.LogAttrs(ctx, slog.LevelError, "error handling server requests",
				slog.Any("error", err),
			)
			serverError <- err
		}
	}()

	select {
	case err := <-serverError:
		return err
	case <-time.After(1 * time.Second):
		// no-op: allow logic to flow to setting up graceful shutdown
	}

	// Graceful shutdown
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(quit)

	// Wait for either a shutdown signal or a server failing after start-up.
	var runErr error
	select {
	case signalType := <-quit:
		logger.LogAttrs(ctx, slog.LevelWarn, "http server graceful shutdown",
			slog.String("signal", signalType.String()),
		)
	case runErr = <-serverError:
		// Already logged by the failing goroutine; stop whatever is still running.
	}

	// NOTE: Shutdown is given a separate context from the API server.
	if err := server.Shutdown(context.Background()); err != nil {
		logger.LogAttrs(ctx, slog.LevelError, "error shutting down server",
			slog.Any("error", fmt.Errorf("unable to gracefully stop server: %w", err)),
		)
	}
	if metricsServer != nil {
		if err := metricsServer.Shutdown(context.Background()); err != nil {
			logger.LogAttrs(ctx, slog.LevelError, "error shutting down metrics server",
				slog.Any("error", fmt.Errorf("unable to gracefully stop metrics server: %w", err)),
			)
		}
	}

	return runErr
}

5. logging.go

package logging

import (
	"context"
	"io"
	"log"
	"log/slog"
	"net/http"
	"os"
	"path/filepath"
	"regexp"
	"strconv"
	"strings"

	h "github.com/org/repo/internal/http"
	"github.com/org/repo/internal/traces"
)

var (
	// Level is the shared, adjustable log level; call its Set method to
	// change the output level at runtime. Its zero value is [slog.LevelInfo].
	Level = &slog.LevelVar{}

	// routePattern captures the name found between curly brackets,
	// e.g. "id" in "/example/{id}".
	routePattern = regexp.MustCompile(`\{([^{}]+)\}`)
)

// Logger is the subset of log/slog functionality this package exposes.
//
// NOTE: With() here is not (*slog.Logger).With: it returns a Logger, whereas
// the standard library method returns a *slog.Logger.
type Logger interface {
	// _private is unexported so that types outside this package cannot
	// satisfy the interface.
	_private()

	// With returns a Logger derived from the receiver using args.
	With(args ...any) Logger

	// Enabled mirrors (*slog.Logger).Enabled.
	Enabled(ctx context.Context, level slog.Level) bool

	// LogAttrs mirrors (*slog.Logger).LogAttrs.
	LogAttrs(ctx context.Context, level slog.Level, msg string, attrs ...slog.Attr)
}

// NewLogger returns a [Logger] that writes to stdout at the package-wide [Level].
//
// NOTE(review): this comment previously said stderr while the code has always
// passed os.Stdout; the comment now matches the code. Switch the writer if
// stderr was the real intent.
func NewLogger() Logger {
	return NewLoggerWithOutputLevel(os.Stdout, Level)
}

// NewLoggerWithOutput returns a [Logger] that writes JSON to w using the
// default handler options and the default attributes.
func NewLoggerWithOutput(w io.Writer) Logger {
	handler := slog.NewJSONHandler(w, defaultOptions()).WithAttrs(defaultAttrs())
	return (*logger)(slog.New(handler))
}

// NewLoggerWithOutputLevel returns a [Logger] that writes JSON to w, uses l as
// its level and carries the default attributes.
func NewLoggerWithOutputLevel(w io.Writer, l slog.Leveler) Logger {
	handlerOpts := defaultOptions()
	handlerOpts.Level = l

	handler := slog.NewJSONHandler(w, handlerOpts).WithAttrs(defaultAttrs())
	return (*logger)(slog.New(handler))
}

// NewBareLoggerWithOutputLevel returns a [Logger] that writes JSON to w and
// uses l as its [slog.Leveler].
// Unlike the other constructors it does not attach the default attributes.
func NewBareLoggerWithOutputLevel(w io.Writer, l slog.Leveler) Logger {
	handlerOpts := defaultOptions()
	handlerOpts.Level = l

	handler := slog.NewJSONHandler(w, handlerOpts)
	return (*logger)(slog.New(handler))
}

// private is an empty, unexported marker type.
//
// NOTE(review): nothing in the visible source references this type; the Logger
// interface is sealed by its unexported _private() method. Confirm whether
// this declaration is still needed.
//
// nolint:revive
//
//lint:ignore U1000 Prevents any other package from implementing this interface
type private struct{} //nolint:unused

// IMPORTANT: logger is a defined type whose underlying type is slog.Logger,
// which avoids a double-pointer dereference.
// Every method converts the *logger back to a *slog.Logger (see std) to reach
// the standard library implementation.
// With() must additionally convert the result back to a Logger compatible type.
type logger slog.Logger

// std returns l viewed as the *slog.Logger it was converted from.
func (l *logger) std() *slog.Logger {
	return (*slog.Logger)(l)
}

// _private seals the Logger interface to this package.
func (*logger) _private() {}

// Enabled reports whether the underlying slog.Logger emits records at level.
func (l *logger) Enabled(ctx context.Context, level slog.Level) bool {
	return l.std().Enabled(ctx, level)
}

// LogAttrs forwards the record to the underlying slog.Logger.
func (l *logger) LogAttrs(ctx context.Context, level slog.Level, msg string, attrs ...slog.Attr) {
	l.std().LogAttrs(ctx, level, msg, attrs...)
}

// With returns a Logger built from the underlying slog.Logger's With(args...).
func (l *logger) With(args ...any) Logger {
	derived := l.std().With(args...)
	return (*logger)(derived)
}

// Adapt returns a [log.Logger] that writes through l's handler at the given
// level, for use with packages that are not yet compatible with [log/slog].
func Adapt(l Logger, level slog.Level) *log.Logger {
	// Only *logger can implement Logger (because of _private()), so the
	// assertion is not checked.
	impl := l.(*logger) //nolint:revive,forcetypeassert
	handler := (*slog.Logger)(impl).Handler()
	return slog.NewLogLogger(handler, level)
}

func defaultOptions() *slog.HandlerOptions {
	return &slog.HandlerOptions{
		AddSource:   true,
		ReplaceAttr: slogReplaceAttr,
		Level:       Level,
	}
}

// defaultAttrs returns the attributes attached to every non-bare logger: an
// "app" group holding the project name and version from the build info.
func defaultAttrs() []slog.Attr {
	app := slog.Group("app",
		slog.String("name", build.Info.Project),
		slog.String("version", build.Info.Version),
	)
	return []slog.Attr{app}
}

// NullLogger returns a [Logger] that discards everything logged to it.
func NullLogger() Logger {
	// NOTE: The level is set above every level currently defined. The
	// intent, unlike passing nil for the opts argument, is for the logger to
	// not even bother generating a message that will just be discarded.
	// An additional gap of 4 was used as it aligns with Go's original design.
	// https://github.com/golang/go/blob/1e95fc7/src/log/slog/level.go#L34-L42
	opts := &slog.HandlerOptions{Level: slog.LevelError + 4} //nolint:gomnd
	handler := slog.NewTextHandler(io.Discard, opts)
	return (*logger)(slog.New(handler))
}

// slogReplaceAttr adjusts the log output.
//
// For top-level keys only (not keys within groups):
//   - Changes the time field value to the UTC time zone
//   - Replaces the msg key with event, and omits it when empty
//   - Renames the source key to caller and truncates the file name to the
//     part after the example-app directory
//
// For keys at any level:
//   - Omits the error field when it is nil
//   - Formats duration and delay values in microseconds as xxxxµs
//
// Time and duration rewriting only applies when the value really is a
// time/duration: slog's Value.Time and Value.Duration panic on any other
// kind, and callers are free to log e.g. slog.Int("delay", 5).
//
// See https://pkg.go.dev/log/slog#HandlerOptions.ReplaceAttr
// N.B: TextHandler manages quoting attribute values as necessary.
func slogReplaceAttr(groups []string, a slog.Attr) slog.Attr {
	// Limit application of these rules only to top-level keys
	if len(groups) == 0 {
		switch a.Key {
		case slog.TimeKey:
			// Set time zone to UTC
			if a.Value.Kind() == slog.KindTime {
				a.Value = slog.TimeValue(a.Value.Time().UTC())
			}
			return a
		case slog.MessageKey:
			// Use event as the default MessageKey, remove if empty
			a.Key = "event"
			if a.Value.String() == "" {
				return slog.Attr{}
			}
			return a
		case slog.SourceKey:
			// Display a 'partial' path.
			// Avoids ambiguity when multiple files have the same name across packages.
			// e.g. billing.go appears under 'global', 'billing' and 'server' packages.
			if source, ok := a.Value.Any().(*slog.Source); ok {
				a.Key = "caller"
				if _, after, ok := strings.Cut(source.File, "example-app"+string(filepath.Separator)); ok {
					source.File = after
				}
			}
			return a
		}
	}

	// Remove error key=value when error is nil
	if a.Equal(slog.Any("error", error(nil))) {
		return slog.Attr{}
	}

	// Present durations and delays as xxxxµs
	switch a.Key {
	case "dur", "delay", "p95", "previous_p95", "remaining", "max_wait":
		if a.Value.Kind() == slog.KindDuration {
			a.Value = slog.StringValue(strconv.FormatInt(a.Value.Duration().Microseconds(), 10) + "µs")
		}
	}

	return a
}

// WithRoutePattern returns middleware that resolves the request's route
// pattern from mux, stores it in the request context and logs a
// "handler_called" event.
//
// The event carries a "request" group holding the pattern and the value of
// every {name} segment in it. When the context holds a trace ID, it is added
// to that event as "trace_id".
func WithRoutePattern(mux *http.ServeMux, logger Logger) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := h.ContextWithRoutePattern(mux, r)
			route, _ := h.RoutePatternFromContext(ctx) //nolint:contextcheck
			attrs := []any{
				slog.String("pattern", route),
			}
			segments := routePattern.FindAllStringSubmatch(route, -1)
			for _, seg := range segments {
				if len(seg) == 2 { //nolint:gomnd
					key := seg[1]
					attrs = append(attrs, slog.String(key, r.PathValue(key)))
				}
			}
			// NOTE: The trace ID goes on a per-request logger. Assigning to the
			// captured logger would be shared by every request served by this
			// handler: trace IDs would pile up on it and concurrent requests
			// would race on the variable.
			reqLogger := logger
			if traceID := traces.GetTraceID(ctx); traceID != "" { //nolint:contextcheck
				reqLogger = logger.With(slog.String("trace_id", traceID))
			}
			reqLogger.LogAttrs(r.Context(), slog.LevelInfo, "handler_called", slog.Group("request", attrs...))
			next.ServeHTTP(w, r.WithContext(ctx)) //nolint:contextcheck
		})
	}
}

6. http.go

package http

import (
	"context"
	"net/http"
)

// routeContextKey is a custom type for the key used in context.Context.
type routeContextKey struct{}

// RoutePatternContextKey is the key under which a route pattern is stored in a
// context.Context.
var RoutePatternContextKey routeContextKey

// ContextWithRoutePattern returns a copy of the request's context holding the
// pattern mux reports for r. The stored pattern is the empty string when mux
// has no match for the request.
func ContextWithRoutePattern(mux *http.ServeMux, r *http.Request) context.Context {
	_, pattern := mux.Handler(r)
	parent := r.Context()
	return context.WithValue(parent, RoutePatternContextKey, pattern)
}

// RoutePatternFromContext returns the route pattern stored in ctx.
// The boolean is false when ctx holds no value under RoutePatternContextKey or
// the value is not a string.
func RoutePatternFromContext(ctx context.Context) (string, bool) {
	pattern, isString := ctx.Value(RoutePatternContextKey).(string)
	if !isString {
		return "", false
	}
	return pattern, true
}

7. metrics.go

package metrics

import (
	"context"
	"net/http"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/collectors"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"github.com/prometheus/client_golang/prometheus/promhttp"

	h "github.com/org/repo/internal/http"
	"github.com/org/repo/internal/traces"
)

// namespace is the Namespace set on every metric declared in this file.
const namespace = "fastly"

// subsystem is the Subsystem set on the request, health check and custom
// counter metrics.
const subsystem = "domainr_api"

// New creates the Metrics for registry and registers, in this order: the build
// info gauge, the uptime gauge, the request metrics, a process collector and a
// Go collector. The MustRegister calls panic if a registration fails.
func New(registry *prometheus.Registry) *Metrics {
	m := new(Metrics)
	m.registry = registry

	registerAndSetBuildInfo(registry)
	registerUptime(registry)
	m.registerRequestMetrics()

	processCollector := collectors.NewProcessCollector(collectors.ProcessCollectorOpts{Namespace: namespace})
	registry.MustRegister(processCollector)

	goCollector := collectors.NewGoCollector()
	registry.MustRegister(goCollector)

	return m
}

// Metrics contains all the Prometheus metrics tracked by Domainr API.
type Metrics struct {
	// registry is where every metric below is registered and what Handler serves.
	registry          *prometheus.Registry
	// healthChecksTotal counts health check requests (labels: code, route).
	healthChecksTotal *prometheus.CounterVec
	// requestsTotal counts HTTP requests (labels: code, route).
	requestsTotal     *prometheus.CounterVec
	// requestDuration tracks HTTP request latencies (labels: code, route).
	requestDuration   *prometheus.HistogramVec
	// requestSize tracks HTTP request sizes (labels: code, route).
	requestSize       *prometheus.SummaryVec
	// responseSize tracks HTTP response sizes (labels: code, route).
	responseSize      *prometheus.SummaryVec
}

// Handler returns the HTTP handler that serves the metrics held in m's
// registry, with OpenMetrics enabled.
func (m *Metrics) Handler() http.Handler {
	opts := promhttp.HandlerOpts{EnableOpenMetrics: true}
	return promhttp.HandlerFor(m.registry, opts)
}

// RegisterCounter creates a counter with the given name and help text under
// the package namespace and subsystem, registers it on m's registry and
// returns it. MustRegister panics if the registration fails.
func (m *Metrics) RegisterCounter(name, help string) prometheus.Counter {
	opts := prometheus.CounterOpts{
		Namespace: namespace,
		Subsystem: subsystem,
		Name:      name,
		Help:      help,
	}
	counter := prometheus.NewCounter(opts)
	m.registry.MustRegister(counter)
	return counter
}

// WithMetrics is middleware that records the request count, duration, request
// size and response size of next, labelled with the route pattern found in the
// request context.
func (m *Metrics) WithMetrics(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// NOTE: The route label is always curried, with an empty value when
		// the context holds no pattern. The promhttp instrumentation panics
		// on any uncurried label other than "code" and "method", so leaving
		// "route" out would panic on such a request.
		route, _ := h.RoutePatternFromContext(r.Context())
		label := prometheus.Labels{"route": route}
		exemplars := promhttp.WithExemplarFromContext(getTraceID)

		// Wraps the provided http.Handler to observe the request result with the provided metrics.
		promhttp.InstrumentHandlerCounter(
			m.requestsTotal.MustCurryWith(label),
			promhttp.InstrumentHandlerDuration(
				m.requestDuration.MustCurryWith(label),
				promhttp.InstrumentHandlerRequestSize(
					m.requestSize.MustCurryWith(label),
					promhttp.InstrumentHandlerResponseSize(
						m.responseSize.MustCurryWith(label),
						next,
					),
				),
				exemplars,
			),
			exemplars,
		).ServeHTTP(w, r)
	})
}

// WithHealthCheckMetrics is middleware that counts the requests handled by
// next in the health check counter, labelled with the route pattern found in
// the request context.
func (m *Metrics) WithHealthCheckMetrics(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// NOTE: The route label is always curried, with an empty value when
		// the context holds no pattern. The promhttp instrumentation panics
		// on any uncurried label other than "code" and "method", so leaving
		// "route" out would panic on such a request.
		route, _ := h.RoutePatternFromContext(r.Context())
		label := prometheus.Labels{"route": route}

		// Wraps the provided http.Handler to observe the request result with the provided metrics.
		promhttp.InstrumentHandlerCounter(
			m.healthChecksTotal.MustCurryWith(label),
			next,
		).ServeHTTP(w, r)
	})
}

// registerRequestMetrics creates the health check and request metrics on m's
// registry and stores them on m. Every vector is partitioned by the "code" and
// "route" labels.
func (m *Metrics) registerRequestMetrics() {
	factory := promauto.With(m.registry)

	healthCheckOpts := prometheus.CounterOpts{
		Namespace: namespace,
		Subsystem: subsystem,
		Name:      "health_checks",
		Help:      "Number of health check requests received",
	}
	m.healthChecksTotal = factory.NewCounterVec(healthCheckOpts, []string{"code", "route"})

	requestsOpts := prometheus.CounterOpts{
		Namespace: namespace,
		Subsystem: subsystem,
		Name:      "requests_total",
		Help:      "Tracks the number of HTTP requests.",
	}
	m.requestsTotal = factory.NewCounterVec(requestsOpts, []string{"code", "route"})

	durationOpts := prometheus.HistogramOpts{
		Namespace: namespace,
		Subsystem: subsystem,
		Name:      "request_duration_seconds",
		Help:      "Tracks the latencies for HTTP requests.",
		Buckets:   prometheus.DefBuckets,
	}
	m.requestDuration = factory.NewHistogramVec(durationOpts, []string{"code", "route"})

	requestSizeOpts := prometheus.SummaryOpts{
		Namespace: namespace,
		Subsystem: subsystem,
		Name:      "request_size_bytes",
		Help:      "Tracks the size of HTTP requests.",
	}
	m.requestSize = factory.NewSummaryVec(requestSizeOpts, []string{"code", "route"})

	responseSizeOpts := prometheus.SummaryOpts{
		Namespace: namespace,
		Subsystem: subsystem,
		Name:      "response_size_bytes",
		Help:      "Tracks the size of HTTP responses.",
	}
	m.responseSize = factory.NewSummaryVec(responseSizeOpts, []string{"code", "route"})
}

// getTraceID returns a "traceID" label holding the trace ID reported by
// traces.GetTraceID for ctx, or nil when that ID is empty. It is passed to
// promhttp.WithExemplarFromContext.
func getTraceID(ctx context.Context) prometheus.Labels {
	traceID := traces.GetTraceID(ctx)
	if traceID == "" {
		return nil
	}
	return prometheus.Labels{"traceID": traceID}
}

// registerAndSetBuildInfo registers a build_info gauge on registry and sets it
// to 1 for the version (version-buildnumber), git revision and Go version
// taken from the build info. MustRegister panics if the registration fails.
func registerAndSetBuildInfo(registry *prometheus.Registry) {
	info := build.Info

	opts := prometheus.GaugeOpts{
		Namespace: namespace,
		Name:      "build_info",
		Help:      "Build information for domainr-api.",
	}
	buildInfo := prometheus.NewGaugeVec(opts, []string{"version", "revision", "goversion"})
	registry.MustRegister(buildInfo)

	fullVersion := info.Version + "-" + info.BuildNumber
	buildInfo.WithLabelValues(fullVersion, info.GitRevision, info.GoVersion).Set(1)
}

// registerUptime registers a gauge on registry that reports the seconds
// elapsed since this function was called. MustRegister panics if the
// registration fails.
func registerUptime(registry *prometheus.Registry) {
	startedAt := time.Now()
	secondsSinceStart := func() float64 {
		return time.Since(startedAt).Seconds()
	}

	opts := prometheus.GaugeOpts{
		Namespace: namespace,
		Name:      "uptime_duration_seconds",
		Help:      "Process uptime",
	}
	registry.MustRegister(prometheus.NewGaugeFunc(opts, secondsSinceStart))
}

8. traces.go

package traces

import (
	"context"
	"fmt"
	"net/http"
	"regexp"

	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
	"go.opentelemetry.io/otel/propagation"
	"go.opentelemetry.io/otel/sdk/resource"
	sdktrace "go.opentelemetry.io/otel/sdk/trace"
	semconv "go.opentelemetry.io/otel/semconv/v1.24.0"
	"go.opentelemetry.io/otel/trace"

	"github.com/org/repo/internal/config"
	h "github.com/org/repo/internal/http"
)

// serviceName is used as the service name resource attribute and as the
// tracer name.
const serviceName = "example-app"

// version is the instrumentation version reported by the tracer New returns.
const version = "0.1.0"

// routePattern captures the name found between curly brackets,
// e.g. "id" in "/example/{id}".
var routePattern = regexp.MustCompile(`\{([^{}]+)\}`)

// newResource returns the resource attached to the tracer provider: the
// service name plus a service version of the form version-buildnumber taken
// from the build info.
func newResource() *resource.Resource {
	info := build.Info
	serviceVersion := info.Version + "-" + info.BuildNumber

	return resource.NewWithAttributes(
		semconv.SchemaURL,
		semconv.ServiceName(serviceName),
		semconv.ServiceVersionKey.String(serviceVersion),
	)
}

// New creates a tracer provider and installs it, together with a
// TraceContext+Baggage propagator, as the OpenTelemetry globals.
//
// When cfg.OtelDisabled is set the provider never samples; otherwise it
// batches spans to an OTLP/HTTP exporter.
//
// It returns the provider's Shutdown function and a tracer named after the
// service. If the exporter cannot be created it returns a nil shutdown
// function, a nil tracer and the error.
func New(ctx context.Context, cfg *config.Config) (func(context.Context) error, trace.Tracer, error) {
	var providerOpts []sdktrace.TracerProviderOption
	if cfg.OtelDisabled {
		providerOpts = append(providerOpts,
			sdktrace.WithResource(newResource()),
			sdktrace.WithSampler(sdktrace.NeverSample()),
		)
	} else {
		exporter, err := otlptracehttp.New(ctx)
		if err != nil {
			return nil, nil, fmt.Errorf("creating OTLP trace exporter: %w", err)
		}
		providerOpts = append(providerOpts,
			sdktrace.WithBatcher(exporter),
			sdktrace.WithResource(newResource()),
		)
	}
	provider := sdktrace.NewTracerProvider(providerOpts...)

	otel.SetTracerProvider(provider)
	tracer := provider.Tracer(serviceName, trace.WithInstrumentationVersion(version))

	propagator := propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{})
	otel.SetTextMapPropagator(propagator)

	return provider.Shutdown, tracer, nil
}

// NewHandler wraps handler in an otelhttp handler for the "server" operation,
// with read/write message events enabled and otelReqFilter as the request
// filter.
func NewHandler(handler http.Handler) http.Handler {
	messageEvents := otelhttp.WithMessageEvents(otelhttp.ReadEvents, otelhttp.WriteEvents)
	requestFilter := otelhttp.WithFilter(otelReqFilter)
	return otelhttp.NewHandler(handler, "server", messageEvents, requestFilter)
}

// GetTraceID returns the trace ID of the span in ctx when that span has a
// valid trace ID and is sampled. Otherwise it returns the empty string.
func GetTraceID(ctx context.Context) string {
	sc := trace.SpanFromContext(ctx).SpanContext()
	// NOTE: TraceID().String() is never empty (an unset ID renders as all
	// zeros), so comparing it with "" cannot detect a missing trace ID.
	// HasTraceID is the check for that.
	if !sc.HasTraceID() || !sc.IsSampled() {
		return ""
	}
	return sc.TraceID().String()
}

// Tracer returns a tracer named after the service from the global tracer
// provider.
func Tracer() trace.Tracer {
	provider := otel.GetTracerProvider()
	return provider.Tracer(serviceName)
}

// Start begins an internal-kind span called name using Tracer() and returns it
// together with a context.Context containing it.
func Start(ctx context.Context, name string) (context.Context, trace.Span) {
	internalKind := trace.WithSpanKind(trace.SpanKindInternal)
	return Tracer().Start(ctx, name, internalKind)
}

// otelReqFilter reports whether req should be traced: it returns false for the
// exact path "/healthcheck" and true for everything else.
func otelReqFilter(req *http.Request) bool {
	isHealthcheck := req.URL.Path == "/healthcheck"
	return !isHealthcheck
}

// WithSpan returns middleware that starts a "handler_called" span on tracer
// for every request, passes the span's context on to next and ends the span
// once next returns.
func WithSpan(tracer trace.Tracer) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		traced := func(w http.ResponseWriter, r *http.Request) {
			spanCtx, span := tracer.Start(r.Context(), "handler_called")
			defer span.End()

			next.ServeHTTP(w, r.WithContext(spanCtx))
		}
		return http.HandlerFunc(traced)
	}
}

// WithRouteTag is middleware that wraps next with otelhttp.WithRouteTag, using
// the route pattern stored in the request context as the route name.
// Before calling the wrapped handler it also runs the global propagator's
// Extract over the request headers and uses the resulting context.
func WithRouteTag(next http.Handler) http.Handler {
	tagged := func(w http.ResponseWriter, r *http.Request) {
		reqCtx := r.Context()
		route, _ := h.RoutePatternFromContext(reqCtx) //nolint:contextcheck

		carrier := propagation.HeaderCarrier(r.Header)
		ctx := otel.GetTextMapPropagator().Extract(reqCtx, carrier)

		otelhttp.WithRouteTag(route, next).ServeHTTP(w, r.WithContext(ctx))
	}
	return http.HandlerFunc(tagged)
}

// WithParamAttrs is middleware that, for every {name} segment of the route
// pattern stored in the request context, sets an attribute on the context's
// span named after the segment and holding the request's path value for it.
func WithParamAttrs(next http.Handler) http.Handler {
	annotated := func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		route, _ := h.RoutePatternFromContext(ctx) //nolint:contextcheck
		span := trace.SpanFromContext(ctx)

		for _, match := range routePattern.FindAllStringSubmatch(route, -1) {
			// Each match is the full "{name}" text plus the captured name.
			if len(match) != 2 { //nolint:gomnd
				continue
			}
			name := match[1]
			span.SetAttributes(attribute.String(name, r.PathValue(name)))
		}

		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(annotated)
}