feat: Adds middleware to routes

This commit is contained in:
2025-01-19 14:20:51 -05:00
parent b23dc6bf07
commit 185ebbbc15
6 changed files with 158 additions and 37 deletions

View File

@@ -1,10 +1,18 @@
import DatabaseClient
import DatabaseClientLive
import Dependencies
import Fluent
import SharedModels
import Vapor
private let apiMiddleware: [any Middleware] = [
UserPasswordAuthenticator(),
UserTokenAuthenticator(),
User.guardMiddleware()
]
extension ApiRoute {
var middleware: [any Middleware]? { apiMiddleware }
func handle(request: Request) async throws -> any AsyncResponseEncodable {
switch self {
case let .employee(route):
@@ -22,6 +30,7 @@ extension ApiRoute {
}
extension ApiRoute.EmployeeApiRoute {
func handle(request: Request) async throws -> any AsyncResponseEncodable {
@Dependency(\.database) var database
switch self {
@@ -44,6 +53,7 @@ extension ApiRoute.EmployeeApiRoute {
}
extension ApiRoute.PurchaseOrderApiRoute {
func handle(request: Request) async throws -> any AsyncResponseEncodable {
@Dependency(\.database.purchaseOrders) var purchaseOrders
switch self {
@@ -67,6 +77,7 @@ extension ApiRoute.PurchaseOrderApiRoute {
// TODO: Add Login.
extension ApiRoute.UserApiRoute {
func handle(request: Request) async throws -> any AsyncResponseEncodable {
@Dependency(\.database.users) var users
switch self {

View File

@@ -1,10 +1,30 @@
import DatabaseClient
import DatabaseClientLive
import Dependencies
import Elementary
import SharedModels
import Vapor
private let viewProtectedMiddleware: [any Middleware] = [
UserPasswordAuthenticator(),
UserSessionAuthenticator(),
User.redirectMiddleware { req in
"/login?next=\(req.url)"
}
]
extension SharedModels.ViewRoute {
var middleware: [any Middleware]? {
switch self {
case let .employee(route): return route.middleware
case .login: return nil
case let .purchaseOrder(route): return route.middleware
case let .user(route): return route.middleware
case let .vendor(route): return route.middleware
case let .vendorBranch(route): return route.middleware
}
}
func handle(request: Request) async throws -> any AsyncResponseEncodable {
@Dependency(\.database.users) var users
switch self {
@@ -68,6 +88,8 @@ extension SharedModels.ViewRoute.EmployeeRoute {
}
}
var middleware: [any Middleware]? { viewProtectedMiddleware }
func handle(request: Request) async throws -> any AsyncResponseEncodable {
@Dependency(\.database.employees) var employees
@@ -133,6 +155,8 @@ extension SharedModels.ViewRoute.PurchaseOrderRoute {
try await mainPage(html, page: 1, limit: 25)
}
var middleware: [any Middleware]? { viewProtectedMiddleware }
func handle(request: Vapor.Request) async throws -> any AsyncResponseEncodable {
@Dependency(\.database.purchaseOrders) var purchaseOrders
switch self {
@@ -220,6 +244,10 @@ extension SharedModels.ViewRoute.UserRoute {
}
}
var middleware: [any Middleware]? {
viewProtectedMiddleware
}
func handle(request: Request) async throws -> any AsyncResponseEncodable {
@Dependency(\.database.users) var users
switch self {
@@ -271,6 +299,8 @@ extension SharedModels.ViewRoute.VendorRoute {
}
}
var middleware: [any Middleware]? { viewProtectedMiddleware }
func handle(request: Request) async throws -> any AsyncResponseEncodable {
@Dependency(\.database) var database
@@ -318,6 +348,8 @@ extension SharedModels.ViewRoute.VendorRoute {
extension SharedModels.ViewRoute.VendorBranchRoute {
var middleware: [any Middleware]? { viewProtectedMiddleware }
func handle(request: Request) async throws -> any AsyncResponseEncodable {
@Dependency(\.database) var database

View File

@@ -0,0 +1,96 @@
import URLRouting
import Vapor
import VaporRouting
// Taken from github.com/nevillco/vapor-routing
public extension Application {
/// Mounts a router to the Vapor application.
///
/// See ``VaporRouting`` for more information on usage.
///
/// - Parameters:
/// - router: A parser-printer that works on inputs of `URLRequestData`.
/// - middleware: A closure for providing any per-route migrations to be run before processing the request.
/// - closure: A closure that takes a `Request` and the router's output as arguments.
func mount<R: Parser>(
_ router: R,
middleware: @escaping @Sendable (R.Output) -> [any Middleware]? = { _ in nil },
use closure: @escaping @Sendable (Request, R.Output) async throws -> any AsyncResponseEncodable
) where R.Input == URLRequestData, R: Sendable, R.Output: Sendable {
self.middleware.use(AsyncRoutingMiddleware(router: router, middleware: middleware, respond: closure))
}
}
/// Serves requests using a router and response handler.
///
/// You will not typically need to interact with this type directly. Instead you should use the
/// `mount` method on your Vapor application.
///
/// See ``VaporRouting`` for more information on usage.
public struct AsyncRoutingMiddleware<Router: Parser>: AsyncMiddleware
where Router.Input == URLRequestData,
Router: Sendable,
Router.Output: Sendable
{
let router: Router
let middleware: @Sendable (Router.Output) -> [any Middleware]?
let respond: @Sendable (Request, Router.Output) async throws -> any AsyncResponseEncodable
public func respond(
to request: Request,
chainingTo next: any AsyncResponder
) async throws -> Response {
if request.body.data == nil {
try await _ = request.body.collect(max: request.application.routes.defaultMaxBodySize.value)
.get()
}
guard let requestData = URLRequestData(request: request)
else { return try await next.respond(to: request) }
let route: Router.Output
do {
route = try router.parse(requestData)
} catch let routingError {
do {
return try await next.respond(to: request)
} catch {
request.logger.info("\(routingError)")
guard request.application.environment == .development
else { throw error }
return Response(status: .notFound, body: .init(string: "Routing \(routingError)"))
}
}
if let middleware = middleware(route) {
return try await middleware.makeResponder(chainingTo: AsyncBasicResponder { request in
try await self.respond(request, route).encodeResponse(for: request)
}).respond(to: request).get()
// return try await middleware.respond(
// to: request,
// chainingTo: AsyncBasicResponder { request in
// try await self.respond(request, route).encodeResponse(for: request)
// }
// ).get()
} else {
return try await respond(request, route).encodeResponse(for: request)
}
}
}
// Usage:
// app.mount(
// router,
// middleware: { route in
// case .onboarding: return nil
// case .signIn: return BasicAuthMiddleware()
// default: return BearerAuthMiddleware()
// },
// use: { request, route in
// // route handline
// }
// )

View File

@@ -7,10 +7,23 @@ import Fluent
import SharedModels
import Vapor
import VaporElementary
import VaporRouting
@preconcurrency import VaporRouting
func routes(_ app: Application) throws {
app.mount(SiteRoute.router, use: siteHandler)
app.mount(
SiteRoute.router,
middleware: { route in
switch route {
case let .api(route):
return route.middleware
case .health:
return nil
case let .view(route):
return route.middleware
}
},
use: siteHandler
)
app.get { _ in
HTMLResponse {
@@ -21,44 +34,13 @@ func routes(_ app: Application) throws {
}
}
}
//
// app.get("login") { req in
// let context = try req.query.decode(LoginContext.self)
// return await req.render {
// MainPage(displayNav: false, route: .login) {
// UserForm(context: .login(next: context.next))
// }
// }
// }
//
// app.post("login") { req in
// @Dependency(\.database.users) var users
// let loginForm = try req.content.decode(User.Login.self)
// let token = try await users.login(loginForm)
// let user = try await users.get(token.userID)!
// req.session.authenticate(user)
// let context = try req.query.decode(LoginContext.self)
//
// return await req.render {
// MainPage(displayNav: true, route: .purchaseOrders) {
// div(
// .hx.get(context.next ?? "/purchase-orders"),
// .hx.pushURL(true),
// .hx.target("body"),
// .hx.trigger(.event(.revealed)),
// .hx.indicator(".hx-indicator")
// ) {
// Img.spinner().attributes(.class("hx-indicator"))
// }
// }
// }
// }
}
private struct LoginContext: Content {
let next: String?
}
@Sendable
func siteHandler(
request: Request,
route: SiteRoute

View File

@@ -147,7 +147,7 @@ public enum ApiRoute: Sendable {
}
public enum VendorApiRoute: Sendable {
case index(withBranches: Bool?)
case index(withBranches: Bool? = nil)
case create(Vendor.Create)
case delete(id: Vendor.ID)
case get(id: Vendor.ID)