package router import ( "errors" "fmt" "github.com/franckcuny/web-request" "net/http" "regexp" "strings" ) type Router struct { routes []*Route knownPaths map[string]map[string]bool withOptions bool notAllowed bool } var defaultHTTPMethods = []string{"GET", "HEAD", "PUT", "POST", "PATCH", "OPTIONS"} var reMessyPath = regexp.MustCompile("/{2,}") // Create a new router. func BuildRouter() *Router { router := Router{} router.knownPaths = map[string]map[string]bool{} return &router } // For each path, add a new route to respond to OPTIONS. func (self *Router) AddOptions() *Router { self.withOptions = true routes := self.GetRouteList() for _, r := range routes { m := self.GetMethodsForPath(r) allowed := strings.Join(m, ", ") defaultFn := func(req *request.Request, resp *request.Response) error { resp.Status = 204 resp.AddHeader("Allow", allowed) return nil } self.AddRoute(&Route{Method: "OPTIONS", Path: r, Code: defaultFn}) } return self } // For each path, add a route to respond 405 if the method is not implemented. func (self *Router) AddNotAllowed() *Router { self.notAllowed = true routes := self.GetRouteList() for _, r := range routes { methods := self.GetMethodsForPath(r) supportedMethods := map[string]bool{} for _, m := range methods { supportedMethods[m] = true } allowed := strings.Join(methods, ", ") defaultFn := func(req *request.Request, resp *request.Response) error { resp.Status = 405 resp.AddHeader("Allow", allowed) return nil } for _, dm := range defaultHTTPMethods { if supportedMethods[dm] == false { self.AddRoute(&Route{Method: dm, Path: r, Code: defaultFn}) } } } return self } func (self *Router) routeIsKnown(route *Route) bool { if self.knownPaths[route.Path] == nil { self.knownPaths[route.Path] = map[string]bool{} return false } else if self.knownPaths[route.Path][route.Method] == false { return false } return true } // Add a route to the router. func (self *Router) AddRoute(route *Route) error { if self.routeIsKnown(route) == true { return errors.New(fmt.Sprintf("The route %s with the method %s already exist.", route.Path, route.Method)) } route.init() self.routes = append(self.routes, route) self.knownPaths[route.Path][route.Method] = true return nil } func (self *Router) canonpath(url string) string { url = reMessyPath.ReplaceAllString(url, "/") return url } // Will try to find a route that match the HTTP Request. func (self *Router) Match(request *http.Request) (*Match, error) { matches := []*Match{} method := request.Method url := self.canonpath(request.URL.Path) components := strings.Split(url, "/") for _, r := range self.routes { match := r.Match(method, components) if match != nil { matches = append(matches, match) } } if len(matches) == 0 { return nil, nil } else if len(matches) == 1 { return matches[0], nil } else { return self.disambiguateMatches(request.URL.Path, matches) } return nil, nil } func (self *Router) disambiguateMatches(path string, matches []*Match) (*Match, error) { min := -1 found := []*Match{} for _, m := range matches { req := m.Route.requiredNamedComponents vars := len(req) if min == -1 || vars < min { found = append(found, m) min = vars } else if vars == min { found = append(found, m) } } if len(found) > 1 { msg := fmt.Sprintf("Ambiguous match: path %s could match any of:", path) for _, f := range found { msg = fmt.Sprintf("%s %s", msg, f.Route.Path) } err := errors.New(msg) return nil, err } return found[0], nil } // Get a list of routes from the router. func (self *Router) GetRouteList() []string { routes := make([]string, len(self.knownPaths)) i := 0 for path, _ := range self.knownPaths { routes[i] = path i = i + 1 } return routes } // Check that the router knows a given path. func (self *Router) HasPath(path string) bool { if self.knownPaths[path] != nil { return true } return false } // Remove a path (and all associated routes) from the router. func (self *Router) RemovePath(path string) error { p := self.HasPath(path) if p == false { return errors.New("foo") } delete(self.knownPaths, path) newRoutes := []*Route{} for _, p := range self.routes { if p.Path != path { newRoutes = append(newRoutes, p) } } self.routes = newRoutes return nil } // Return all the routes known by the router. func (self *Router) GetAllRoutes() []*Route { return self.routes } // Return all the routes that implement the given HTTP method. func (self *Router) GetAllRoutesByMethods(method string) []*Route { routes := []*Route{} for _, r := range self.routes { if r.Method == method { routes = append(routes, r) } } return routes } // Return a list of HTTP methods implemented for a given path. func (self *Router) GetMethodsForPath(path string) []string { p := self.knownPaths[path] m := make([]string, len(p)) i := 0 for k, _ := range p { m[i] = k i = i + 1 } return m }