Skip to main content

Middleware

Middleware lets you add functionality that runs before or after your service methods - like logging, authentication, or rate limiting. You write it once, and it applies to all (or some) of your methods automatically.

What Is Middleware?

Think of middleware like security guards at a building entrance. Every visitor (request) passes through them before reaching their destination (your method). The guard can:
  1. Check credentials - Is this person allowed in?
  2. Log the visit - Record who came and when
  3. Limit access - Only let 10 people in per minute
  4. Add information - Give them a visitor badge (add data to context)
Request → Middleware → Your Method → Middleware → Response
              ↓                          ↓
         (before)                    (after)

Two Types of Middleware

Contract supports two levels of middleware:
| Type | Level | Use For |
|------|-------|---------|
| Custom Invokers | Method calls | Logging, auth checks, metrics |
| HTTP Middleware | HTTP requests | CORS, request IDs, panic recovery |

Custom Invokers - Method-Level Middleware

The most common way to add middleware is by creating a custom invoker that wraps method calls.

Basic Structure

import contract "github.com/go-mizu/mizu/contract/v2"

// A custom invoker wraps the default invoker
// LoggingInvoker decorates another TransportInvoker and writes a line to the
// standard logger before and after every method call.
type LoggingInvoker struct {
    inner contract.TransportInvoker // the wrapped (usually default) invoker
}

// Invoke logs that the method is starting, delegates to the wrapped invoker,
// then logs how long the call took. The wrapped invoker's result and error are
// returned unchanged.
func (l *LoggingInvoker) Invoke(ctx context.Context, method *contract.Method, input []byte) (any, error) {
    // BEFORE: runs before your method.
    began := time.Now()
    log.Printf("Starting: %s", method.Name)

    // Delegate to the real method.
    out, callErr := l.inner.Invoke(ctx, method, input)

    // AFTER: runs once your method has returned.
    elapsed := time.Since(began)
    log.Printf("Finished: %s (took %v)", method.Name, elapsed)

    return out, callErr
}

Using Your Custom Invoker

Pass your invoker when mounting transports:
import (
    contract "github.com/go-mizu/mizu/contract/v2"
    "github.com/go-mizu/mizu/contract/v2/transport/mcp"
    "github.com/go-mizu/mizu/contract/v2/transport/jsonrpc"
    "github.com/go-mizu/mizu/contract/v2/transport/rest"
)

// Create the custom invoker
loggingInvoker := &LoggingInvoker{
    inner: contract.DefaultInvoker(svc),  // Wrap the default
}

// Use with REST
rest.Mount(app.Router, svc, rest.WithInvoker(loggingInvoker))

// Use with MCP
mcp.Mount(app.Router, "/mcp", svc, mcp.WithInvoker(loggingInvoker))

// Use with JSON-RPC
jsonrpc.Mount(app.Router, "/rpc", svc, jsonrpc.WithInvoker(loggingInvoker))

Common Middleware Examples

Logging

Log every method call with timing:
// LoggingInvoker wraps another TransportInvoker and emits one structured log
// record per method call. A nil logger falls back to slog.Default(), so a
// literal such as &LoggingInvoker{inner: x} is safe to use.
type LoggingInvoker struct {
    inner  contract.TransportInvoker
    logger *slog.Logger // optional; nil means slog.Default()
}

// Invoke calls the wrapped invoker, then logs the method name, the duration in
// milliseconds and whether the call returned an error. The wrapped call's
// result and error are returned unchanged.
func (l *LoggingInvoker) Invoke(ctx context.Context, method *contract.Method, input []byte) (any, error) {
    start := time.Now()

    // Call the method
    result, err := l.inner.Invoke(ctx, method, input)

    logger := l.logger
    if logger == nil {
        // Constructions like &LoggingInvoker{inner: inner} (as in the
        // withLogging helper) leave logger nil; without this guard the
        // Info call below would panic.
        logger = slog.Default()
    }

    // Log the result
    logger.Info("method called",
        "method", method.Name,
        "duration_ms", time.Since(start).Milliseconds(),
        "success", err == nil,
    )

    return result, err
}

// Usage
loggingInvoker := &LoggingInvoker{
    inner:  contract.DefaultInvoker(svc),
    logger: slog.Default(),
}

Authentication

Check if the user is logged in:
// AuthInvoker blocks calls to protected methods when the request context
// carries no user.
type AuthInvoker struct {
    inner contract.TransportInvoker
}

// Invoke checks, via requiresAuth, whether the method is protected. Protected
// methods need a non-nil user from UserFromContext. Without one, Invoke returns
// an unauthenticated error and never calls the wrapped invoker.
func (a *AuthInvoker) Invoke(ctx context.Context, method *contract.Method, input []byte) (any, error) {
    // The user was placed in ctx by HTTP middleware earlier in the stack.
    current := UserFromContext(ctx)

    needsLogin := requiresAuth(method.Name)
    if needsLogin && current == nil {
        return nil, contract.ErrUnauthenticated("please log in first")
    }

    // Public method, or an authenticated caller: proceed.
    return a.inner.Invoke(ctx, method, input)
}

// requiresAuth reports whether methodName needs an authenticated user.
// Every method is protected except the public ones listed in the switch.
// A switch avoids building a fresh map on every call for this fixed,
// tiny membership test.
func requiresAuth(methodName string) bool {
    switch methodName {
    case "Health", "Login", "Signup": // public methods that don't need auth
        return false
    default:
        return true
    }
}

Rate Limiting

Prevent too many requests:
import "golang.org/x/time/rate"

// RateLimitInvoker rejects calls once its limiter has no tokens left. There is
// one limiter per invoker, so every method and caller going through the same
// RateLimitInvoker shares the same budget.
type RateLimitInvoker struct {
    inner   contract.TransportInvoker
    limiter *rate.Limiter
}

// NewRateLimitInvoker wraps inner with a limiter that refills at
// requestsPerSecond tokens per second and allows bursts of the same size.
func NewRateLimitInvoker(inner contract.TransportInvoker, requestsPerSecond int) *RateLimitInvoker {
    perSecond := rate.Limit(requestsPerSecond)
    lim := rate.NewLimiter(perSecond, requestsPerSecond)
    return &RateLimitInvoker{inner: inner, limiter: lim}
}

// Invoke takes one token without waiting. When none is available it fails
// fast with a resource-exhausted error instead of calling the wrapped invoker.
func (r *RateLimitInvoker) Invoke(ctx context.Context, method *contract.Method, input []byte) (any, error) {
    if allowed := r.limiter.Allow(); !allowed {
        return nil, contract.ErrResourceExhausted("too many requests, please slow down")
    }
    return r.inner.Invoke(ctx, method, input)
}

Metrics (Prometheus)

Track method calls for monitoring:
import "github.com/prometheus/client_golang/prometheus"

// MetricsInvoker records, for every call that passes through it:
//   - a request counter labelled by method and "success"/"error" status, and
//   - a latency histogram labelled by method.
type MetricsInvoker struct {
    inner    contract.TransportInvoker
    requests *prometheus.CounterVec
    duration *prometheus.HistogramVec
}

// NewMetricsInvoker wraps inner with Prometheus instrumentation registered in
// the default registry. If an earlier MetricsInvoker already registered the
// same collectors (e.g. one invoker per transport, or the chain being rebuilt),
// those existing collectors are reused. prometheus.MustRegister would instead
// panic on the second call.
func NewMetricsInvoker(inner contract.TransportInvoker) *MetricsInvoker {
    requests := registerOrReuse(prometheus.NewCounterVec(
        prometheus.CounterOpts{
            Name: "api_requests_total",
            Help: "Total API requests",
        },
        []string{"method", "status"},
    ))

    duration := registerOrReuse(prometheus.NewHistogramVec(
        prometheus.HistogramOpts{
            Name:    "api_request_duration_seconds",
            Help:    "API request duration",
            Buckets: prometheus.DefBuckets,
        },
        []string{"method"},
    ))

    return &MetricsInvoker{
        inner:    inner,
        requests: requests,
        duration: duration,
    }
}

// registerOrReuse registers c with the default Prometheus registry and returns
// it. If an identical collector of the same type is already registered, it
// returns that collector instead. Any other registration error (such as a
// conflicting definition) panics, matching prometheus.MustRegister.
func registerOrReuse[C prometheus.Collector](c C) C {
    err := prometheus.Register(c)
    if err == nil {
        return c
    }
    if are, ok := err.(prometheus.AlreadyRegisteredError); ok {
        if existing, ok := are.ExistingCollector.(C); ok {
            return existing
        }
    }
    panic(err)
}

// Invoke times the wrapped call, then increments the request counter and
// observes the duration. The wrapped call's result and error pass through
// unchanged.
func (m *MetricsInvoker) Invoke(ctx context.Context, method *contract.Method, input []byte) (any, error) {
    start := time.Now()

    result, err := m.inner.Invoke(ctx, method, input)

    // Record metrics
    status := "success"
    if err != nil {
        status = "error"
    }
    m.requests.WithLabelValues(method.Name, status).Inc()
    m.duration.WithLabelValues(method.Name).Observe(time.Since(start).Seconds())

    return result, err
}

Distributed Tracing

Add trace spans for debugging. Besides `trace`, this example uses `codes.Error`, so also import `go.opentelemetry.io/otel/codes`:
import "go.opentelemetry.io/otel/trace"

// TracingInvoker starts a span named after the method for every call and marks
// that span as failed when the call returns an error.
type TracingInvoker struct {
    inner  contract.TransportInvoker
    tracer trace.Tracer
}

// Invoke runs the wrapped invoker inside a new span. The context returned by
// tracer.Start is passed down, so the method sees the new span in its ctx.
func (t *TracingInvoker) Invoke(ctx context.Context, method *contract.Method, input []byte) (any, error) {
    spanCtx, span := t.tracer.Start(ctx, method.Name)
    defer span.End()

    result, err := t.inner.Invoke(spanCtx, method, input)
    if err == nil {
        return result, nil
    }

    // Attach the error to the span and flag the span as failed.
    span.RecordError(err)
    span.SetStatus(codes.Error, err.Error())
    return result, err
}

Chaining Multiple Middleware

You often want multiple middleware together. Chain them by wrapping one inside another:
// ChainInvokers wraps base so that wrappers[0] becomes the outermost layer.
// ChainInvokers(b, w1, w2) returns w1(w2(b)), so a call passes through w1,
// then w2, and finally b.
func ChainInvokers(
    base contract.TransportInvoker,
    wrappers ...func(contract.TransportInvoker) contract.TransportInvoker,
) contract.TransportInvoker {
    chained := base
    // Walk the wrappers back to front so the first one ends up outermost.
    for n := len(wrappers); n > 0; n-- {
        chained = wrappers[n-1](chained)
    }
    return chained
}

// Adapters with the signature ChainInvokers expects. Each one wraps the next
// invoker in the chain with exactly one concern.

// withLogging adds a LoggingInvoker layer.
// NOTE(review): if your LoggingInvoker has a logger field, this literal leaves
// it at its zero value. Set it here or use a constructor.
func withLogging(inner contract.TransportInvoker) contract.TransportInvoker {
    layer := &LoggingInvoker{inner: inner}
    return layer
}

// withMetrics adds a Prometheus MetricsInvoker layer.
func withMetrics(inner contract.TransportInvoker) contract.TransportInvoker {
    layer := NewMetricsInvoker(inner)
    return layer
}

// withAuth adds an AuthInvoker layer.
func withAuth(inner contract.TransportInvoker) contract.TransportInvoker {
    layer := &AuthInvoker{inner: inner}
    return layer
}

// withRateLimit adds a RateLimitInvoker layer allowing 100 requests per second.
func withRateLimit(inner contract.TransportInvoker) contract.TransportInvoker {
    const perSecond = 100 // refill rate; NewRateLimitInvoker also uses it as the burst size
    return NewRateLimitInvoker(inner, perSecond)
}

// Chain them together. ChainInvokers makes the first wrapper the outermost
// layer, so a request passes through them top to bottom.
invoker := ChainInvokers(
    contract.DefaultInvoker(svc),
    withLogging,   // 1st: outermost, sees every request
    withMetrics,   // 2nd
    withRateLimit, // 3rd
    withAuth,      // 4th: innermost, right before your method
)

// Hand the chained invoker to a transport.
mcp.Mount(mux, "/mcp", svc, mcp.WithInvoker(invoker))

Order Matters!

The order of middleware is important. Think about it like layers of an onion:
Request enters →
    [1. Logging starts]
        [2. Metrics starts]
            [3. Rate limit check]
                [4. Auth check]
                    [Your Method]
                [4. Auth done]
            [3. Rate limit done]
        [2. Metrics records]
    [1. Logging finishes]
← Response exits
Good ordering:
  1. Logging first - logs everything, including rejected requests
  2. Metrics second - tracks all requests
  3. Rate limiting third - rejects before expensive checks
  4. Authentication fourth - before business logic

HTTP Middleware

For HTTP-level concerns (like CORS or request IDs), use standard Go HTTP middleware:

CORS (Cross-Origin Requests)

Allow browsers from other domains to call your API:
// withCORS sets permissive CORS headers on every response. Preflight (OPTIONS)
// requests are answered directly with 200 OK; next is not called for them.
func withCORS(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        h := w.Header()
        h.Set("Access-Control-Allow-Origin", "*")
        h.Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
        h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")

        if r.Method != http.MethodOptions {
            next.ServeHTTP(w, r)
            return
        }

        // Preflight: the headers set above are the whole answer.
        w.WriteHeader(http.StatusOK)
    })
}

// Usage
handler := mcp.NewHandler(svc)
mux.Handle("/mcp", withCORS(handler))

Request ID

Add a unique ID to every request for debugging:
import "github.com/google/uuid"

// contextKey is an unexported key type for context values, so keys set here
// cannot collide with keys defined by other packages.
type contextKey string

// requestIDKey is the context key under which the request ID string is stored.
const requestIDKey contextKey = "requestID"

// withRequestID gives every request an ID. It reuses the caller's X-Request-ID
// header when present and otherwise generates a new UUID. The ID is stored in
// the request context and echoed back in the X-Request-ID response header.
func withRequestID(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        id := r.Header.Get("X-Request-ID")
        if id == "" {
            id = uuid.New().String()
        }

        withID := context.WithValue(r.Context(), requestIDKey, id)
        w.Header().Set("X-Request-ID", id)

        next.ServeHTTP(w, r.WithContext(withID))
    })
}

// Access the ID in your service method. The two-value type assertion keeps
// this working, with an empty ID, when withRequestID is not in the handler
// chain (for example in unit tests). A bare .(string) would panic there.
func (s *Service) Create(ctx context.Context, in *CreateInput) (*Todo, error) {
    requestID, _ := ctx.Value(requestIDKey).(string)
    log.Printf("[%s] Creating todo: %s", requestID, in.Title)
    // ...
}

Panic Recovery

Catch panics and return proper errors:
// withRecovery converts a panic in next into a logged stack trace plus a
// 500 "internal server error" response, so one bad request cannot take the
// handler down.
//
// http.ErrAbortHandler is re-panicked rather than recovered. net/http uses that
// value to mean "abort this response"; swallowing it would turn a deliberate
// abort into a 500.
//
// Note: if next has already written headers before panicking, the 500 status
// can no longer be sent; only the body text is appended.
func withRecovery(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        defer func() {
            rec := recover()
            if rec == nil {
                return
            }
            if rec == http.ErrAbortHandler {
                panic(rec)
            }
            log.Printf("PANIC: %v\n%s", rec, debug.Stack())
            http.Error(w, "internal server error", http.StatusInternalServerError)
        }()

        next.ServeHTTP(w, r)
    })
}

Authentication (Setting User in Context)

Extract user from token and add to context:
// User is the authenticated caller attached to a request context.
type User struct {
    ID    string
    Email string
    Role  string
}

// userContextKey is the unexported key type for the user stored in a context.
// It is deliberately a different type from the request-ID example's contextKey.
// That way both snippets can live in one package, as the combined example
// does, without a "contextKey redeclared" compile error.
type userContextKey string

// userKey is the context key under which withUser stores the *User.
const userKey userContextKey = "user"

// withUser reads the Authorization header. When validateToken accepts it, the
// resulting user is stored in the request context under userKey. Requests with
// a missing or rejected token pass through with no user attached; rejecting
// them is left to invoker middleware such as AuthInvoker.
func withUser(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        token := r.Header.Get("Authorization")

        if token != "" {
            // Validate token and get user
            user, err := validateToken(token)
            if err == nil {
                // Add user to context
                ctx := context.WithValue(r.Context(), userKey, user)
                r = r.WithContext(ctx)
            }
        }

        next.ServeHTTP(w, r)
    })
}

// UserFromContext returns the *User stored by withUser, or nil if the context
// carries none. The two-value assertion makes a missing value safe.
func UserFromContext(ctx context.Context) *User {
    user, _ := ctx.Value(userKey).(*User)
    return user
}

Combining HTTP and Invoker Middleware

Use both together for complete coverage:
import "yourapp/todo"

func main() {
    impl := todo.NewService()
    svc := contract.Register[todo.API](impl,
        contract.WithDefaultResource("todos"),
    )

    // Create invoker middleware chain
    invoker := ChainInvokers(
        contract.DefaultInvoker(svc),
        withLogging,
        withMetrics,
        withAuth,
    )

    // Create HTTP handler with invoker
    handler := mcp.NewHandler(svc, mcp.WithInvoker(invoker))

    // Wrap with HTTP middleware
    handler = withRecovery(handler)
    handler = withRequestID(handler)
    handler = withUser(handler)
    handler = withCORS(handler)

    mux := http.NewServeMux()
    mux.Handle("/mcp", handler)

    http.ListenAndServe(":8080", mux)
}
Flow:
  1. CORS - Allow cross-origin requests
  2. User - Extract user from token
  3. Request ID - Add tracking ID
  4. Recovery - Catch panics
  5. Logging - Log method calls
  6. Metrics - Track performance
  7. Auth - Verify user is authorized
  8. Your Method - Actually do the work

Passing Data Through Context

Context is how middleware shares data:
// HTTP middleware sets the user: whatever extractUser returns for the request
// is stored under userKey for every handler and invoker downstream.
func withUser(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        u := extractUser(r)
        next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), userKey, u)))
    })
}

// Invoker middleware can access the user. Use the two-value type assertion.
// When no user was stored, ctx.Value returns nil, and a bare .(*User) would
// panic before the nil check below could ever run.
func (a *AuthInvoker) Invoke(ctx context.Context, method *contract.Method, input []byte) (any, error) {
    user, _ := ctx.Value(userKey).(*User)
    if user == nil && requiresAuth(method.Name) {
        return nil, contract.ErrUnauthenticated("please log in")
    }
    return a.inner.Invoke(ctx, method, input)
}

// Service method can also access the user. Check both results of the type
// assertion: with no user in ctx a bare .(*User) panics, and if a nil *User was
// stored (e.g. extractUser found nobody), user.Email would dereference nil.
func (s *Service) Create(ctx context.Context, in *CreateInput) (*Todo, error) {
    if user, ok := ctx.Value(userKey).(*User); ok && user != nil {
        log.Printf("User %s is creating a todo", user.Email)
    }
    // ...
}

Best Practices

1. Keep Middleware Focused

Each middleware should do one thing well:
// Good: each wrapper does one thing. They are nested logging → metrics → auth,
// matching the recommended order in "Order Matters!".
withLogging(withMetrics(withAuth(base)))

// Bad: One middleware trying to do everything
type EverythingMiddleware struct{} // Does logging, auth, metrics, etc.

2. Don’t Do Heavy Work

Middleware runs on every request. Keep it fast:
// Good: a quick cache lookup decides whether the call may proceed.
func (a *AuthInvoker) Invoke(ctx context.Context, method *contract.Method, input []byte) (any, error) {
    if a.tokenCache.IsValid(ctx) {
        return a.inner.Invoke(ctx, method, input)
    }
    return nil, contract.ErrUnauthenticated("invalid token")
}

// Bad: Database call on every request
// Every call, even to cheap methods, now waits on a storage round trip before
// the method runs. Prefer a fast in-memory check like the cache example above.
func (a *AuthInvoker) Invoke(ctx context.Context, method *contract.Method, input []byte) (any, error) {
    user, err := a.db.FindUser(ctx, tokenFromCtx(ctx))  // Slow!
    // ...
}

3. Handle Errors Properly

Use Contract’s error types for consistent handling:
// Invoke forwards authenticated calls. Others get contract.ErrUnauthenticated,
// so every transport reports the failure the same way.
func (a *AuthInvoker) Invoke(ctx context.Context, method *contract.Method, input []byte) (any, error) {
    if isAuthenticated(ctx) {
        return a.inner.Invoke(ctx, method, input)
    }
    // Use Contract error - maps correctly to HTTP 401
    return nil, contract.ErrUnauthenticated("please log in")
}

4. Make Middleware Configurable

// LoggingInvoker logs method calls through a configurable *slog.Logger.
type LoggingInvoker struct {
    inner    contract.TransportInvoker
    logger   *slog.Logger
    logInput bool // Option to log input (Invoke, not shown here, should honour it)
}

// NewLoggingInvoker wraps inner and logs through slog.Default() unless an
// option overrides it. Options are applied in order; nil options are skipped.
func NewLoggingInvoker(inner contract.TransportInvoker, opts ...LoggingOption) *LoggingInvoker {
    l := &LoggingInvoker{
        inner:  inner,
        logger: slog.Default(),
    }
    for _, opt := range opts {
        if opt != nil {
            opt(l)
        }
    }
    return l
}

// LoggingOption configures a LoggingInvoker built by NewLoggingInvoker.
type LoggingOption func(*LoggingInvoker)

// WithLogger sets the logger used for call records. A nil logger is ignored,
// so the default stays in place instead of causing a nil-pointer panic on the
// first logged call.
func WithLogger(logger *slog.Logger) LoggingOption {
    return func(l *LoggingInvoker) {
        if logger != nil {
            l.logger = logger
        }
    }
}

// WithInputLogging turns logging of each call's raw input on or off.
// Inputs may contain credentials or personal data, so enable this with care.
func WithInputLogging(enabled bool) LoggingOption {
    return func(l *LoggingInvoker) { l.logInput = enabled }
}

See Also