# Middleware

> Add logging, authentication, rate limiting, and other cross-cutting concerns to your API

Middleware lets you add functionality that runs before or after your service methods - like logging, authentication, or rate limiting. You write it once, and it applies to all (or some) of your methods automatically.

## What Is Middleware?

Think of middleware like security guards at a building entrance. Every visitor (request) passes through them before reaching their destination (your method). The guard can:

1. **Check credentials** - Is this person allowed in?
2. **Log the visit** - Record who came and when
3. **Limit access** - Only let 10 people in per minute
4. **Add information** - Give them a visitor badge (add data to context)

```
Request → Middleware → Your Method → Middleware → Response
              ↓                          ↓
         (before)                    (after)
```

## Two Types of Middleware

Contract supports two levels of middleware:

| Type                | Level         | Use For                           |
| ------------------- | ------------- | --------------------------------- |
| **Custom Invokers** | Method calls  | Logging, auth checks, metrics     |
| **HTTP Middleware** | HTTP requests | CORS, request IDs, panic recovery |

## Custom Invokers - Method-Level Middleware

The most common way to add middleware is by creating a custom invoker that wraps method calls.

### Basic Structure

```go theme={null}
import (
    "context"
    "log"
    "time"

    contract "github.com/go-mizu/mizu/contract/v2"
)

// A custom invoker wraps the default invoker
type LoggingInvoker struct {
    inner contract.TransportInvoker  // The original invoker
}

// Implement the Invoke method
func (l *LoggingInvoker) Invoke(ctx context.Context, method *contract.Method, input []byte) (any, error) {
    // BEFORE: runs before your method
    start := time.Now()
    log.Printf("Starting: %s", method.Name)

    // Call the actual method
    result, err := l.inner.Invoke(ctx, method, input)

    // AFTER: runs after your method
    log.Printf("Finished: %s (took %v)", method.Name, time.Since(start))

    return result, err
}
```
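An invoker only talks to its inner invoker through `Invoke`, so you can unit-test one without mounting a transport. The self-contained sketch below assumes the interface has only the `Invoke` signature shown above. `Method` and `Invoker` are local stand-ins for `contract.Method` and `contract.TransportInvoker`, `InvokerFunc` is a hypothetical test adapter (not part of Contract), and the logger is injected so the output can be captured.

```go theme={null}
package main

import (
    "bytes"
    "context"
    "fmt"
    "log"
    "time"
)

// Method and Invoker are local stand-ins for contract.Method and
// contract.TransportInvoker; only the Invoke signature shown above is assumed.
type Method struct{ Name string }

type Invoker interface {
    Invoke(ctx context.Context, method *Method, input []byte) (any, error)
}

// InvokerFunc adapts a plain function to Invoker (hypothetical test helper).
type InvokerFunc func(ctx context.Context, method *Method, input []byte) (any, error)

func (f InvokerFunc) Invoke(ctx context.Context, method *Method, input []byte) (any, error) {
    return f(ctx, method, input)
}

// LoggingInvoker matches the example above, except the logger is injectable.
type LoggingInvoker struct {
    inner  Invoker
    logger *log.Logger
}

func (l *LoggingInvoker) Invoke(ctx context.Context, method *Method, input []byte) (any, error) {
    start := time.Now()
    l.logger.Printf("Starting: %s", method.Name)
    result, err := l.inner.Invoke(ctx, method, input)
    l.logger.Printf("Finished: %s (took %v)", method.Name, time.Since(start))
    return result, err
}

// runLogged sends one "Echo" call through LoggingInvoker to a fake inner
// invoker and returns the result plus everything that was logged.
func runLogged() (any, string) {
    var buf bytes.Buffer
    echo := InvokerFunc(func(ctx context.Context, m *Method, in []byte) (any, error) {
        return string(in), nil // pretend the method echoes its raw input
    })
    inv := &LoggingInvoker{inner: echo, logger: log.New(&buf, "", 0)}
    result, _ := inv.Invoke(context.Background(), &Method{Name: "Echo"}, []byte(`{"msg":"hi"}`))
    return result, buf.String()
}

func main() {
    result, logs := runLogged()
    fmt.Println(result)
    fmt.Print(logs)
}
```

The same pattern (a function-backed fake as the inner invoker) works for testing any of the middleware on this page.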

### Using Your Custom Invoker

Pass your invoker when mounting transports:

```go theme={null}
import (
    contract "github.com/go-mizu/mizu/contract/v2"
    "github.com/go-mizu/mizu/contract/v2/transport/mcp"
    "github.com/go-mizu/mizu/contract/v2/transport/jsonrpc"
    "github.com/go-mizu/mizu/contract/v2/transport/rest"
)

// Create the custom invoker
loggingInvoker := &LoggingInvoker{
    inner: contract.DefaultInvoker(svc),  // Wrap the default
}

// Use with REST
rest.Mount(app.Router, svc, rest.WithInvoker(loggingInvoker))

// Use with MCP
mcp.Mount(app.Router, "/mcp", svc, mcp.WithInvoker(loggingInvoker))

// Use with JSON-RPC
jsonrpc.Mount(app.Router, "/rpc", svc, jsonrpc.WithInvoker(loggingInvoker))
```

## Common Middleware Examples

### Logging

Log every method call with timing:

```go theme={null}
type LoggingInvoker struct {
    inner  contract.TransportInvoker
    logger *slog.Logger
}

func (l *LoggingInvoker) Invoke(ctx context.Context, method *contract.Method, input []byte) (any, error) {
    start := time.Now()

    // Call the method
    result, err := l.inner.Invoke(ctx, method, input)

    // Log the result
    l.logger.Info("method called",
        "method", method.Name,
        "duration_ms", time.Since(start).Milliseconds(),
        "success", err == nil,
    )

    return result, err
}

// Usage
loggingInvoker := &LoggingInvoker{
    inner:  contract.DefaultInvoker(svc),
    logger: slog.Default(),
}
```

### Authentication

Reject calls from anonymous users, except for a small set of public methods:

```go theme={null}
type AuthInvoker struct {
    inner contract.TransportInvoker
}

func (a *AuthInvoker) Invoke(ctx context.Context, method *contract.Method, input []byte) (any, error) {
    // Get user from context (set by HTTP middleware earlier)
    user := UserFromContext(ctx)

    // Check if this method requires authentication
    if requiresAuth(method.Name) && user == nil {
        return nil, contract.ErrUnauthenticated("please log in first")
    }

    // User is authenticated, proceed
    return a.inner.Invoke(ctx, method, input)
}

// Helper to check which methods need auth
func requiresAuth(methodName string) bool {
    // Public methods that don't need auth
    publicMethods := map[string]bool{
        "Health": true,
        "Login":  true,
        "Signup": true,
    }
    return !publicMethods[methodName]
}
```
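A runnable sketch of the same check. It uses local stand-ins for `contract.Method` and `contract.TransportInvoker` (assumed shape) and a plain sentinel error in place of `contract.ErrUnauthenticated`. `userCtxKey` and `call` are hypothetical names used only for this illustration.

```go theme={null}
package main

import (
    "context"
    "errors"
    "fmt"
)

// Method and Invoker stand in for contract.Method and contract.TransportInvoker.
type Method struct{ Name string }

type Invoker interface {
    Invoke(ctx context.Context, method *Method, input []byte) (any, error)
}

// InvokerFunc adapts a plain function to Invoker (hypothetical test helper).
type InvokerFunc func(ctx context.Context, method *Method, input []byte) (any, error)

func (f InvokerFunc) Invoke(ctx context.Context, method *Method, input []byte) (any, error) {
    return f(ctx, method, input)
}

type User struct{ ID string }

type userCtxKey struct{}

// UserFromContext returns nil when no user is stored, rather than panicking.
func UserFromContext(ctx context.Context) *User {
    user, _ := ctx.Value(userCtxKey{}).(*User)
    return user
}

// errUnauthenticated stands in for contract.ErrUnauthenticated.
var errUnauthenticated = errors.New("please log in first")

// publicMethods is built once instead of on every call.
var publicMethods = map[string]bool{"Health": true, "Login": true, "Signup": true}

type AuthInvoker struct{ inner Invoker }

func (a *AuthInvoker) Invoke(ctx context.Context, method *Method, input []byte) (any, error) {
    if !publicMethods[method.Name] && UserFromContext(ctx) == nil {
        return nil, errUnauthenticated
    }
    return a.inner.Invoke(ctx, method, input)
}

// call runs the named method through AuthInvoker, optionally as a logged-in user.
func call(name string, user *User) error {
    ok := InvokerFunc(func(context.Context, *Method, []byte) (any, error) { return "ok", nil })
    ctx := context.Background()
    if user != nil {
        ctx = context.WithValue(ctx, userCtxKey{}, user)
    }
    _, err := (&AuthInvoker{inner: ok}).Invoke(ctx, &Method{Name: name}, nil)
    return err
}

func main() {
    fmt.Println(call("Health", nil))             // public method, anonymous caller
    fmt.Println(call("Create", nil))             // protected method, anonymous caller
    fmt.Println(call("Create", &User{ID: "u1"})) // protected method, logged-in caller
}
```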

### Rate Limiting

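Reject calls once the service exceeds a fixed request rate. This uses one token bucket shared by all callers (refilling at `requestsPerSecond` tokens per second, with a burst of the same size), so it caps total throughput rather than limiting each client separately: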

```go theme={null}
import "golang.org/x/time/rate"

type RateLimitInvoker struct {
    inner   contract.TransportInvoker
    limiter *rate.Limiter
}

func NewRateLimitInvoker(inner contract.TransportInvoker, requestsPerSecond int) *RateLimitInvoker {
    return &RateLimitInvoker{
        inner:   inner,
        limiter: rate.NewLimiter(rate.Limit(requestsPerSecond), requestsPerSecond),
    }
}

func (r *RateLimitInvoker) Invoke(ctx context.Context, method *contract.Method, input []byte) (any, error) {
    // Check if we're over the limit
    if !r.limiter.Allow() {
        return nil, contract.ErrResourceExhausted("too many requests, please slow down")
    }

    // Under the limit, proceed
    return r.inner.Invoke(ctx, method, input)
}
```

### Metrics (Prometheus)

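Track call counts and latency per method. Note that `prometheus.MustRegister` panics if a collector with the same name is already registered, so call `NewMetricsInvoker` once and reuse the result: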

```go theme={null}
import "github.com/prometheus/client_golang/prometheus"

type MetricsInvoker struct {
    inner    contract.TransportInvoker
    requests *prometheus.CounterVec
    duration *prometheus.HistogramVec
}

func NewMetricsInvoker(inner contract.TransportInvoker) *MetricsInvoker {
    requests := prometheus.NewCounterVec(
        prometheus.CounterOpts{
            Name: "api_requests_total",
            Help: "Total API requests",
        },
        []string{"method", "status"},
    )

    duration := prometheus.NewHistogramVec(
        prometheus.HistogramOpts{
            Name:    "api_request_duration_seconds",
            Help:    "API request duration",
            Buckets: prometheus.DefBuckets,
        },
        []string{"method"},
    )

    prometheus.MustRegister(requests, duration)

    return &MetricsInvoker{
        inner:    inner,
        requests: requests,
        duration: duration,
    }
}

func (m *MetricsInvoker) Invoke(ctx context.Context, method *contract.Method, input []byte) (any, error) {
    start := time.Now()

    result, err := m.inner.Invoke(ctx, method, input)

    // Record metrics
    status := "success"
    if err != nil {
        status = "error"
    }
    m.requests.WithLabelValues(method.Name, status).Inc()
    m.duration.WithLabelValues(method.Name).Observe(time.Since(start).Seconds())

    return result, err
}
```

### Distributed Tracing

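Wrap each call in an OpenTelemetry span, recording any error on it, so you can follow a request through your system when debugging: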

```go theme={null}
import (
    "context"

    contract "github.com/go-mizu/mizu/contract/v2"
    "go.opentelemetry.io/otel/codes"
    "go.opentelemetry.io/otel/trace"
)

type TracingInvoker struct {
    inner  contract.TransportInvoker
    tracer trace.Tracer
}

func (t *TracingInvoker) Invoke(ctx context.Context, method *contract.Method, input []byte) (any, error) {
    // Start a new span for this method call
    ctx, span := t.tracer.Start(ctx, method.Name)
    defer span.End()

    // Call the method
    result, err := t.inner.Invoke(ctx, method, input)

    // Record error if any
    if err != nil {
        span.RecordError(err)
        span.SetStatus(codes.Error, err.Error())
    }

    return result, err
}
```

## Chaining Multiple Middleware

You often want multiple middleware together. Chain them by wrapping one inside another:

```go theme={null}
// Helper function to chain middleware
func ChainInvokers(
    base contract.TransportInvoker,
    wrappers ...func(contract.TransportInvoker) contract.TransportInvoker,
) contract.TransportInvoker {
    // Apply wrappers from last to first
    // So the first wrapper runs first
    for i := len(wrappers) - 1; i >= 0; i-- {
        base = wrappers[i](base)
    }
    return base
}

// Define wrapper functions
func withLogging(inner contract.TransportInvoker) contract.TransportInvoker {
    return &LoggingInvoker{inner: inner}
}

func withMetrics(inner contract.TransportInvoker) contract.TransportInvoker {
    return NewMetricsInvoker(inner)
}

func withAuth(inner contract.TransportInvoker) contract.TransportInvoker {
    return &AuthInvoker{inner: inner}
}

func withRateLimit(inner contract.TransportInvoker) contract.TransportInvoker {
    return NewRateLimitInvoker(inner, 100)  // 100 req/sec
}

// Chain them together
invoker := ChainInvokers(
    contract.DefaultInvoker(svc),
    withLogging,     // Runs first (outermost)
    withMetrics,     // Runs second
    withRateLimit,   // Runs third
    withAuth,        // Runs fourth (innermost before method)
)

// Use the chained invoker
mcp.Mount(mux, "/mcp", svc, mcp.WithInvoker(invoker))
```

### Order Matters!

Order matters because each middleware wraps everything inside it, like the layers of an onion: its "before" code runs on the way in and its "after" code runs on the way out.

```
Request enters →
    [1. Logging starts]
        [2. Metrics starts]
            [3. Rate limit check]
                [4. Auth check]
                    [Your Method]
                [4. Auth done]
            [3. Rate limit done]
        [2. Metrics records]
    [1. Logging finishes]
← Response exits
```

A sensible default ordering:

1. **Logging** first - logs everything, including requests rejected by inner layers
2. **Metrics** second - counts every request, including rejected ones
3. **Rate limiting** third - rejects excess traffic before the auth check runs
4. **Authentication** fourth - the last gate before your business logic
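The onion order can be checked directly. This stdlib-only sketch rewrites `ChainInvokers` against local stand-in types (assumed to mirror `contract.TransportInvoker`) and replaces each real middleware with a hypothetical `tracing` wrapper that records when it is entered and exited.

```go theme={null}
package main

import (
    "context"
    "fmt"
)

// Method and Invoker stand in for contract.Method and contract.TransportInvoker.
type Method struct{ Name string }

type Invoker interface {
    Invoke(ctx context.Context, method *Method, input []byte) (any, error)
}

// InvokerFunc adapts a plain function to Invoker (hypothetical helper).
type InvokerFunc func(ctx context.Context, method *Method, input []byte) (any, error)

func (f InvokerFunc) Invoke(ctx context.Context, method *Method, input []byte) (any, error) {
    return f(ctx, method, input)
}

// ChainInvokers is the helper from above, rewritten against the stand-in types.
func ChainInvokers(base Invoker, wrappers ...func(Invoker) Invoker) Invoker {
    for i := len(wrappers) - 1; i >= 0; i-- {
        base = wrappers[i](base)
    }
    return base
}

// tracing returns a wrapper that records when it is entered and exited.
func tracing(name string, trace *[]string) func(Invoker) Invoker {
    return func(inner Invoker) Invoker {
        return InvokerFunc(func(ctx context.Context, m *Method, in []byte) (any, error) {
            *trace = append(*trace, "enter "+name)
            defer func() { *trace = append(*trace, "exit "+name) }()
            return inner.Invoke(ctx, m, in)
        })
    }
}

// runChain makes one call through four tracing layers and returns the trace.
func runChain() []string {
    var trace []string
    base := InvokerFunc(func(context.Context, *Method, []byte) (any, error) {
        trace = append(trace, "method")
        return nil, nil
    })
    inv := ChainInvokers(base,
        tracing("logging", &trace),
        tracing("metrics", &trace),
        tracing("ratelimit", &trace),
        tracing("auth", &trace),
    )
    _, _ = inv.Invoke(context.Background(), &Method{Name: "Create"}, nil)
    return trace
}

func main() {
    for _, step := range runChain() {
        fmt.Println(step)
    }
}
```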

## HTTP Middleware

For HTTP-level concerns such as CORS or request IDs, use standard Go HTTP middleware: functions that take an `http.Handler` and return a new one wrapping it.

### CORS (Cross-Origin Requests)

Allow browsers from other domains to call your API:

```go theme={null}
func withCORS(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        // Set CORS headers. "*" lets any website call the API; list specific
        // origins instead if the API is not meant to be public.
        w.Header().Set("Access-Control-Allow-Origin", "*")
        w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
        w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")

        // Answer preflight requests here, without reaching the API handler
        if r.Method == http.MethodOptions {
            w.WriteHeader(http.StatusOK)
            return
        }

        next.ServeHTTP(w, r)
    })
}

// Usage
handler := mcp.NewHandler(svc)
mux.Handle("/mcp", withCORS(handler))
```

### Request ID

Add a unique ID to every request for debugging:

```go theme={null}
import "github.com/google/uuid"

// Key for storing request ID in context
type contextKey string
const requestIDKey contextKey = "requestID"

func withRequestID(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        // Get existing ID or create new one
        requestID := r.Header.Get("X-Request-ID")
        if requestID == "" {
            requestID = uuid.New().String()
        }

        // Add to context
        ctx := context.WithValue(r.Context(), requestIDKey, requestID)

        // Add to response headers
        w.Header().Set("X-Request-ID", requestID)

        // Continue with updated context
        next.ServeHTTP(w, r.WithContext(ctx))
    })
}

// Access the ID in your service method
func (s *Service) Create(ctx context.Context, in *CreateInput) (*Todo, error) {
    // Comma-ok form: yields "" instead of panicking if withRequestID didn't run
    requestID, _ := ctx.Value(requestIDKey).(string)
    log.Printf("[%s] Creating todo: %s", requestID, in.Title)
    // ...
}
```

### Panic Recovery

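Turn a panic into a 500 response. Without this, `net/http` still recovers handler panics, but it does so by closing the connection (or resetting the stream on HTTP/2), so the client gets no response at all: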

```go theme={null}
import (
    "log"
    "net/http"
    "runtime/debug"
)

func withRecovery(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        defer func() {
            if err := recover(); err != nil {
                log.Printf("PANIC: %v\n%s", err, debug.Stack())
                http.Error(w, "internal server error", http.StatusInternalServerError)
            }
        }()

        next.ServeHTTP(w, r)
    })
}
```

### Authentication (Setting User in Context)

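Validate the token from the `Authorization` header and store the user in the request context, where `AuthInvoker` and your service methods can read it: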

```go theme={null}
type User struct {
    ID    string
    Email string
    Role  string
}

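// contextKey is already declared in the Request ID example; reuse it here
// (declare `type contextKey string` only if this file stands alone).
const userKey contextKey = "user"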

func withUser(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        token := r.Header.Get("Authorization")

        if token != "" {
            // Validate token and get user
            user, err := validateToken(token)
            if err == nil {
                // Add user to context
                ctx := context.WithValue(r.Context(), userKey, user)
                r = r.WithContext(ctx)
            }
        }

        next.ServeHTTP(w, r)
    })
}

// Helper to get user from context
func UserFromContext(ctx context.Context) *User {
    user, _ := ctx.Value(userKey).(*User)
    return user
}
```
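In practice the `Authorization` header usually carries a scheme, as in `Bearer <token>`, so the raw header value is not the token itself. A sketch that strips the `Bearer ` prefix before validating (it requires Go 1.20+ for `strings.CutPrefix`). `validateToken` here is a hypothetical stub that accepts one hard-coded token, and `whoAmI` exists only to exercise the middleware.

```go theme={null}
package main

import (
    "context"
    "errors"
    "fmt"
    "net/http"
    "net/http/httptest"
    "strings"
)

type User struct {
    ID    string
    Email string
    Role  string
}

type contextKey string

const userKey contextKey = "user"

// validateToken is a hypothetical stub; a real one would verify a signature
// or look the token up in a session store.
func validateToken(token string) (*User, error) {
    if token == "secret-token" {
        return &User{ID: "u1", Email: "ada@example.com", Role: "admin"}, nil
    }
    return nil, errors.New("invalid token")
}

func withUser(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        // Expect "Authorization: Bearer <token>" and strip the scheme.
        if token, ok := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer "); ok && token != "" {
            if user, err := validateToken(token); err == nil {
                r = r.WithContext(context.WithValue(r.Context(), userKey, user))
            }
        }
        next.ServeHTTP(w, r)
    })
}

func UserFromContext(ctx context.Context) *User {
    user, _ := ctx.Value(userKey).(*User)
    return user
}

// whoAmI sends one request with the given Authorization header and returns
// the email of the user the handler saw ("" when anonymous).
func whoAmI(authHeader string) string {
    var email string
    h := withUser(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        if u := UserFromContext(r.Context()); u != nil {
            email = u.Email
        }
    }))
    req := httptest.NewRequest(http.MethodPost, "/mcp", nil)
    if authHeader != "" {
        req.Header.Set("Authorization", authHeader)
    }
    h.ServeHTTP(httptest.NewRecorder(), req)
    return email
}

func main() {
    fmt.Printf("%q\n", whoAmI("Bearer secret-token"))
    fmt.Printf("%q\n", whoAmI("Bearer wrong"))
    fmt.Printf("%q\n", whoAmI(""))
}
```

The prefix match is case-sensitive, while HTTP treats the scheme name as case-insensitive, so a production version may want to compare the scheme with `strings.EqualFold`.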

## Combining HTTP and Invoker Middleware

Use both together: HTTP middleware for transport-level concerns, invokers for per-method logic:

```go theme={null}
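import (
    "log"
    "net/http"

    contract "github.com/go-mizu/mizu/contract/v2"
    "github.com/go-mizu/mizu/contract/v2/transport/mcp"
    "yourapp/todo"
)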

func main() {
    impl := todo.NewService()
    svc := contract.Register[todo.API](impl,
        contract.WithDefaultResource("todos"),
    )

    // Create invoker middleware chain
    invoker := ChainInvokers(
        contract.DefaultInvoker(svc),
        withLogging,
        withMetrics,
        withAuth,
    )

    // Create the HTTP handler with the invoker. Declaring it as http.Handler
    // lets the wrappers below reassign it.
    var handler http.Handler = mcp.NewHandler(svc, mcp.WithInvoker(invoker))

    // Wrap with HTTP middleware
    handler = withRecovery(handler)
    handler = withRequestID(handler)
    handler = withUser(handler)
    handler = withCORS(handler)

    mux := http.NewServeMux()
    mux.Handle("/mcp", handler)

    log.Fatal(http.ListenAndServe(":8080", mux))
}
```

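Request flow, from outermost to innermost. Each `handler = withX(handler)` line adds a new outer layer, so the last wrapper applied (`withCORS`) sees the request first:

1. **CORS** - Allow cross-origin requests
2. **User** - Extract user from token
3. **Request ID** - Add tracking ID
4. **Recovery** - Catch panics in the MCP handler and invokers (not in the three HTTP wrappers above it; apply `withRecovery` last if you want it outermost)
5. **Logging** - Log method calls
6. **Metrics** - Track performance
7. **Auth** - Reject unauthenticated calls to protected methods
8. **Your Method** - Actually do the work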
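The wrap order can be confirmed with placeholder middleware that only records its name. This stdlib sketch applies four such layers in the same sequence as `main()` above; the names are labels, not the real middleware, and `named` and `wrapOrder` are hypothetical helpers.

```go theme={null}
package main

import (
    "fmt"
    "net/http"
    "net/http/httptest"
)

// named returns middleware that records its name, then calls next.
func named(name string, trace *[]string) func(http.Handler) http.Handler {
    return func(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            *trace = append(*trace, name)
            next.ServeHTTP(w, r)
        })
    }
}

// wrapOrder wraps a handler in the same sequence as main() and returns
// the order in which the layers saw the request.
func wrapOrder() []string {
    var trace []string
    var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        trace = append(trace, "mcp handler")
    })
    // Each assignment adds a new outermost layer.
    handler = named("recovery", &trace)(handler)
    handler = named("request id", &trace)(handler)
    handler = named("user", &trace)(handler)
    handler = named("cors", &trace)(handler)

    handler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodPost, "/mcp", nil))
    return trace
}

func main() {
    fmt.Println(wrapOrder())
}
```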

## Passing Data Through Context

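Middleware shares data through `context.Context`: an outer layer stores a value, and every inner layer, including your service method, can read it: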

```go theme={null}
// HTTP middleware sets the user (a simplified version of withUser above)
func withUser(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        user := extractUser(r) // may return nil for anonymous requests
        ctx := context.WithValue(r.Context(), userKey, user)
        next.ServeHTTP(w, r.WithContext(ctx))
    })
}

// Invoker middleware can read the user
func (a *AuthInvoker) Invoke(ctx context.Context, method *contract.Method, input []byte) (any, error) {
    // UserFromContext uses the comma-ok form, so a missing value yields nil, not a panic
    user := UserFromContext(ctx)
    if user == nil && requiresAuth(method.Name) {
        return nil, contract.ErrUnauthenticated("please log in")
    }
    return a.inner.Invoke(ctx, method, input)
}

// Service methods can read the user too
func (s *Service) Create(ctx context.Context, in *CreateInput) (*Todo, error) {
    user := UserFromContext(ctx)
    if user == nil {
        return nil, contract.ErrUnauthenticated("please log in")
    }
    log.Printf("User %s is creating a todo", user.Email)
    // ...
}
```

## Best Practices

### 1. Keep Middleware Focused

Each middleware should do one thing well:

```go theme={null}
// Good: Each does one thing
withLogging(withMetrics(withAuth(base)))

// Bad: One middleware trying to do everything
type EverythingMiddleware struct{} // Does logging, auth, metrics, etc.
```

### 2. Don't Do Heavy Work

Middleware runs on every request, so anything slow it does is added to the latency of every call. Keep it fast:

```go theme={null}
// Good: Quick cache lookup
func (a *AuthInvoker) Invoke(ctx context.Context, method *contract.Method, input []byte) (any, error) {
    if !a.tokenCache.IsValid(ctx) {
        return nil, contract.ErrUnauthenticated("invalid token")
    }
    return a.inner.Invoke(ctx, method, input)
}

// Bad: Database call on every request
func (a *AuthInvoker) Invoke(ctx context.Context, method *contract.Method, input []byte) (any, error) {
    user, err := a.db.FindUser(ctx, tokenFromCtx(ctx))  // Slow!
    // ...
}
```

### 3. Handle Errors Properly

Use Contract's error types for consistent handling:

```go theme={null}
func (a *AuthInvoker) Invoke(ctx context.Context, method *contract.Method, input []byte) (any, error) {
    if !isAuthenticated(ctx) {
        // Use Contract error - maps correctly to HTTP 401
        return nil, contract.ErrUnauthenticated("please log in")
    }
    return a.inner.Invoke(ctx, method, input)
}
```

### 4. Make Middleware Configurable

```go theme={null}
type LoggingInvoker struct {
    inner    contract.TransportInvoker
    logger   *slog.Logger
    logInput bool  // Option to log input
}

func NewLoggingInvoker(inner contract.TransportInvoker, opts ...LoggingOption) *LoggingInvoker {
    l := &LoggingInvoker{
        inner:  inner,
        logger: slog.Default(),
    }
    for _, opt := range opts {
        opt(l)
    }
    return l
}

type LoggingOption func(*LoggingInvoker)

func WithLogger(logger *slog.Logger) LoggingOption {
    return func(l *LoggingInvoker) { l.logger = logger }
}

func WithInputLogging(enabled bool) LoggingOption {
    return func(l *LoggingInvoker) { l.logInput = enabled }
}
```

## See Also

* [Service Definition](/contract/service) - Writing services
* [Error Handling](/contract/errors) - Return errors from middleware
* [Architecture](/contract/architecture) - How transports work
* [Invokers](/contract/invoker) - Deep dive into invokers
