feat: jwt auth and api keys
Brijesh Wawdhane ops@brijesh.dev
Wed, 11 Dec 2024 23:33:35 +0530
12 files changed,
1162 insertions(+),
0 deletions(-)
A
.air.toml
@@ -0,0 +1,46 @@
+root = "." +testdata_dir = "testdata" +tmp_dir = "tmp" + +[build] + args_bin = [] + bin = "./main" + cmd = "make build" + delay = 1000 + exclude_dir = ["assets", "tmp", "vendor", "testdata"] + exclude_file = [] + exclude_regex = ["_test.go"] + exclude_unchanged = false + follow_symlink = false + full_bin = "" + include_dir = [] + include_ext = ["go", "tpl", "tmpl", "html"] + include_file = [] + kill_delay = "0s" + log = "build-errors.log" + poll = false + poll_interval = 0 + post_cmd = [] + pre_cmd = [] + rerun = false + rerun_delay = 500 + send_interrupt = false + stop_on_error = false + +[color] + app = "" + build = "yellow" + main = "magenta" + runner = "green" + watcher = "cyan" + +[log] + main_only = false + time = false + +[misc] + clean_on_exit = false + +[screen] + clear_on_rebuild = false + keep_scroll = true
A
.gitignore
@@ -0,0 +1,34 @@
+# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with "go test -c" +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work +tmp/ + +# IDE specific files +.vscode +.idea + +# .env file +.env + +# Project build +main +*templ.go + +# OS X generated file +.DS_Store +
A
Dockerfile
@@ -0,0 +1,16 @@
+FROM golang:1.23-alpine AS build + +WORKDIR /app + +COPY go.mod go.sum ./ +RUN go mod download + +COPY . . + +RUN go build -o main cmd/api/main.go + +FROM alpine:3.20.1 AS prod +WORKDIR /app +COPY --from=build /app/main /app/main +EXPOSE ${PORT} +CMD ["./main"]
A
Makefile
@@ -0,0 +1,48 @@
+build: + @echo "Building..." + + + @go build -o main cmd/api/main.go + +run: + @go run cmd/api/main.go + +docker-run: + @if docker compose up --build 2>/dev/null; then \ + : ; \ + else \ + echo "Falling back to Docker Compose V1"; \ + docker-compose up --build; \ + fi + + +docker-down: + @if docker compose down 2>/dev/null; then \ + : ; \ + else \ + echo "Falling back to Docker Compose V1"; \ + docker-compose down; \ + fi + + +clean: + @echo "Cleaning..." + @rm -f main + +watch: + @if command -v air > /dev/null; then \ + air; \ + echo "Watching...";\ + else \ + read -p "Go's 'air' is not installed on your machine. Do you want to install it? [Y/n] " choice; \ + if [ "$$choice" != "n" ] && [ "$$choice" != "N" ]; then \ + go install github.com/air-verse/air@latest; \ + air; \ + echo "Watching...";\ + else \ + echo "You chose not to install air. Exiting..."; \ + exit 1; \ + fi; \ + fi + +.PHONY: build run clean watch docker-run docker-down
A
cmd/api/main.go
@@ -0,0 +1,57 @@
+package main + +import ( + "context" + "fmt" + "log" + "net/http" + "os/signal" + "syscall" + "time" + + "argus-core/internal/server" +) + +func gracefulShutdown(apiServer *http.Server, done chan bool) { + // Create context that listens for the interrupt signal from the OS. + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + // Listen for the interrupt signal. + <-ctx.Done() + + log.Println("shutting down gracefully, press Ctrl+C again to force") + + // The context is used to inform the server it has 5 seconds to finish + // the request it is currently handling + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := apiServer.Shutdown(ctx); err != nil { + log.Printf("Server forced to shutdown with error: %v", err) + } + + log.Println("Server exiting") + + // Notify the main goroutine that the shutdown is complete + done <- true +} + +func main() { + + server := server.NewServer() + + // Create a done channel to signal when the shutdown is complete + done := make(chan bool, 1) + + // Run graceful shutdown in a separate goroutine + go gracefulShutdown(server, done) + + err := server.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + panic(fmt.Sprintf("http server error: %s", err)) + } + + // Wait for the graceful shutdown to complete + <-done + log.Println("Graceful shutdown complete.") +}
A
go.mod
@@ -0,0 +1,22 @@
+module argus-core + +go 1.23.3 + +require ( + github.com/go-chi/chi/v5 v5.1.0 + github.com/go-chi/cors v1.2.1 + github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/google/uuid v1.6.0 + github.com/jackc/pgx/v5 v5.7.1 + github.com/joho/godotenv v1.5.1 + golang.org/x/crypto v0.27.0 +) + +require ( + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/stretchr/testify v1.9.0 // indirect + golang.org/x/sync v0.8.0 // indirect + golang.org/x/text v0.18.0 // indirect +)
A
go.sum
@@ -0,0 +1,38 @@
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= +github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4= +github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs= +github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= +golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= +golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
A
internal/auth/auth.go
@@ -0,0 +1,214 @@
+package auth + +import ( + "errors" + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "golang.org/x/crypto/bcrypt" + + "argus-core/internal/database" +) + +// Request types +type RegisterRequest struct { + Email string `json:"email"` + Password string `json:"password"` +} + +type LoginRequest struct { + Email string `json:"email"` + Password string `json:"password"` +} + +type CreateAPIKeyRequest struct { + Name string `json:"name"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` +} + +var ( + ErrInvalidCredentials = errors.New("invalid credentials") + ErrUserExists = errors.New("user already exists") + ErrInvalidToken = errors.New("invalid token") + ErrAPIKeyNotFound = errors.New("API key not found") +) + +type Service interface { + // User authentication + Register(email, password string) (*database.User, error) + Login(email, password string) (string, *database.User, error) // Returns JWT token and user + ValidateToken(token string) (*database.User, error) + + // API Key management + CreateAPIKey(userID uuid.UUID, name string, expiresAt *time.Time) (*database.APIKey, string, error) // Returns APIKey and the actual key + ValidateAPIKey(key string) (*database.APIKey, error) + ListAPIKeys(userID uuid.UUID) ([]database.APIKey, error) + RevokeAPIKey(userID, keyID uuid.UUID) error + DeleteAPIKey(userID, keyID uuid.UUID) error +} + +type service struct { + db database.Service + jwtSecret []byte + tokenDuration time.Duration +} + +type Config struct { + JWTSecret string + TokenDuration time.Duration +} + +func NewService(db database.Service, config Config) Service { + return &service{ + db: db, + jwtSecret: []byte(config.JWTSecret), + tokenDuration: config.TokenDuration, + } +} + +func (s *service) Register(email, password string) (*database.User, error) { + // Check if user already exists + existingUser, err := s.db.GetUserByEmail(email) + if err == nil && existingUser != nil { + return nil, ErrUserExists + } + + // Hash password + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return nil, err + } + + // Create user + user, err := s.db.CreateUser(email, string(hashedPassword)) + if err != nil { + return nil, err + } + + return user, nil +} + +func (s *service) Login(email, password string) (string, *database.User, error) { + // Get user + user, err := s.db.GetUserByEmail(email) + if err != nil { + return "", nil, ErrInvalidCredentials + } + + // Check password + if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil { + return "", nil, ErrInvalidCredentials + } + + // Generate JWT token + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": user.ID.String(), + "exp": time.Now().Add(s.tokenDuration).Unix(), + }) + + tokenString, err := token.SignedString(s.jwtSecret) + if err != nil { + return "", nil, err + } + + return tokenString, user, nil +} + +func (s *service) ValidateToken(tokenString string) (*database.User, error) { + // Parse token + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, ErrInvalidToken + } + return s.jwtSecret, nil + }) + + if err != nil || !token.Valid { + return nil, ErrInvalidToken + } + + // Extract claims + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return nil, ErrInvalidToken + } + + // Parse user ID + userID, err := uuid.Parse(claims["sub"].(string)) + if err != nil { + return nil, ErrInvalidToken + } + + // Get user from database + user, err := s.db.GetUserByID(userID) + if err != nil { + return nil, err + } + + return user, nil +} + +func (s *service) CreateAPIKey(userID uuid.UUID, name string, expiresAt *time.Time) (*database.APIKey, string, error) { + // Generate random API key + apiKeyStr, err := generateAPIKey() + if err != nil { + return nil, "", fmt.Errorf("failed to generate API key: %w", err) + } + + // Hash the API key + keyHash := hashAPIKey(apiKeyStr) + + // Create API key in database + apiKey, err := s.db.CreateAPIKey(userID, name, keyHash, expiresAt) + if err != nil { + return nil, "", fmt.Errorf("failed to create API key: %w", err) + } + + return apiKey, apiKeyStr, nil +} + +func (s *service) ValidateAPIKey(key string) (*database.APIKey, error) { + // Validate key format + if !validateAPIKeyFormat(key) { + return nil, ErrAPIKeyNotFound + } + + keyHash := hashAPIKey(key) + + apiKey, err := s.db.GetAPIKeyByHash(keyHash) + if err != nil { + return nil, ErrAPIKeyNotFound + } + + // Check if key is expired + if apiKey.ExpiresAt != nil && time.Now().After(*apiKey.ExpiresAt) { + return nil, ErrAPIKeyNotFound + } + + // Check if key is active + if !apiKey.IsActive { + return nil, ErrAPIKeyNotFound + } + + // Update last used timestamp + if err := s.db.UpdateAPIKeyLastUsed(apiKey.ID); err != nil { + // Log error but don't fail the request + // log.Printf("Failed to update API key last used: %v", err) + } + + return apiKey, nil +} + +func (s *service) ListAPIKeys(userID uuid.UUID) ([]database.APIKey, error) { + return s.db.ListAPIKeys(userID) +} + +func (s *service) RevokeAPIKey(userID, keyID uuid.UUID) error { + return s.db.RevokeAPIKey(userID, keyID) +} + +func (s *service) DeleteAPIKey(userID, keyID uuid.UUID) error { + return s.db.DeleteAPIKey(userID, keyID) +}
A
internal/auth/utils.go
@@ -0,0 +1,65 @@
+package auth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" +) + +const ( + // APIKeyPrefix is the prefix for all API keys + APIKeyPrefix = "argus" + // APIKeyBytes is the number of random bytes to generate for the API key + APIKeyBytes = 32 +) + +// generateAPIKey generates a new API key with format: argus_<random-string> +// The random string is base64 encoded and URL safe +func generateAPIKey() (string, error) { + // Generate random bytes + randomBytes := make([]byte, APIKeyBytes) + _, err := rand.Read(randomBytes) + if err != nil { + return "", fmt.Errorf("failed to generate random bytes: %w", err) + } + + // Encode as base64 and make it URL safe + // Use RawURLEncoding to avoid special characters like '/' and '+' + randomString := base64.RawURLEncoding.EncodeToString(randomBytes) + + // Format: argus_<random-string> + return fmt.Sprintf("%s_%s", APIKeyPrefix, randomString), nil +} + +// hashAPIKey creates a SHA-256 hash of the API key +// This is what we'll store in the database +func hashAPIKey(key string) string { + // Create SHA-256 hash + hasher := sha256.New() + hasher.Write([]byte(key)) + + // Convert to hex string + return hex.EncodeToString(hasher.Sum(nil)) +} + +// validateAPIKeyFormat checks if the API key has the correct format +func validateAPIKeyFormat(key string) bool { + // Check if key starts with the correct prefix + if len(key) < len(APIKeyPrefix)+2 { // +2 for '_' and at least one character + return false + } + + prefix := key[:len(APIKeyPrefix)] + if prefix != APIKeyPrefix { + return false + } + + // Check if the next character is underscore + if key[len(APIKeyPrefix)] != '_' { + return false + } + + return true +}
A
internal/database/database.go
@@ -0,0 +1,348 @@
+package database + +import ( + "context" + "database/sql" + "fmt" + "log" + "os" + "strconv" + "time" + + "github.com/google/uuid" + _ "github.com/jackc/pgx/v5/stdlib" + _ "github.com/joho/godotenv/autoload" +) + +type User struct { + ID uuid.UUID `json:"id"` + Email string `json:"email"` + PasswordHash string `json:"-"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// APIKey represents an API key in the database +type APIKey struct { + ID uuid.UUID `json:"id"` + UserID uuid.UUID `json:"user_id"` + Name string `json:"name"` + KeyHash string `json:"-"` + CreatedAt time.Time `json:"created_at"` + LastUsedAt *time.Time `json:"last_used_at,omitempty"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + IsActive bool `json:"is_active"` +} + +// Service represents a service that interacts with a database. +type Service interface { + // Health returns a map of health status information. + Health() map[string]string + + // Close terminates the database connection. + Close() error + + // User-related queries + CreateUser(email, passwordHash string) (*User, error) + GetUserByEmail(email string) (*User, error) + GetUserByID(id uuid.UUID) (*User, error) + + // API Key-related queries + CreateAPIKey(userID uuid.UUID, name, keyHash string, expiresAt *time.Time) (*APIKey, error) + GetAPIKeyByHash(keyHash string) (*APIKey, error) + ListAPIKeys(userID uuid.UUID) ([]APIKey, error) + UpdateAPIKeyLastUsed(keyID uuid.UUID) error + RevokeAPIKey(userID, keyID uuid.UUID) error + DeleteAPIKey(userID, keyID uuid.UUID) error +} + +type service struct { + db *sql.DB +} + +var ( + database = os.Getenv("DB_DATABASE") + password = os.Getenv("DB_PASSWORD") + username = os.Getenv("DB_USERNAME") + port = os.Getenv("DB_PORT") + host = os.Getenv("DB_HOST") + schema = os.Getenv("DB_SCHEMA") + dbInstance *service +) + +func New() Service { + // Reuse Connection + if dbInstance != nil { + return dbInstance + } + connStr := fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable&search_path=%s", username, password, host, port, database, schema) + db, err := sql.Open("pgx", connStr) + if err != nil { + log.Fatal(err) + } + dbInstance = &service{ + db: db, + } + return dbInstance +} + +// Health checks the health of the database connection by pinging the database. +// It returns a map with keys indicating various health statistics. +func (s *service) Health() map[string]string { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + stats := make(map[string]string) + + // Ping the database + err := s.db.PingContext(ctx) + if err != nil { + stats["status"] = "down" + stats["error"] = fmt.Sprintf("db down: %v", err) + log.Fatalf("db down: %v", err) // Log the error and terminate the program + return stats + } + + // Database is up, add more statistics + stats["status"] = "up" + stats["message"] = "It's healthy" + + // Get database stats (like open connections, in use, idle, etc.) + dbStats := s.db.Stats() + stats["open_connections"] = strconv.Itoa(dbStats.OpenConnections) + stats["in_use"] = strconv.Itoa(dbStats.InUse) + stats["idle"] = strconv.Itoa(dbStats.Idle) + stats["wait_count"] = strconv.FormatInt(dbStats.WaitCount, 10) + stats["wait_duration"] = dbStats.WaitDuration.String() + stats["max_idle_closed"] = strconv.FormatInt(dbStats.MaxIdleClosed, 10) + stats["max_lifetime_closed"] = strconv.FormatInt(dbStats.MaxLifetimeClosed, 10) + + // Evaluate stats to provide a health message + if dbStats.OpenConnections > 40 { // Assuming 50 is the max for this example + stats["message"] = "The database is experiencing heavy load." + } + + if dbStats.WaitCount > 1000 { + stats["message"] = "The database has a high number of wait events, indicating potential bottlenecks." + } + + if dbStats.MaxIdleClosed > int64(dbStats.OpenConnections)/2 { + stats["message"] = "Many idle connections are being closed, consider revising the connection pool settings." + } + + if dbStats.MaxLifetimeClosed > int64(dbStats.OpenConnections)/2 { + stats["message"] = "Many connections are being closed due to max lifetime, consider increasing max lifetime or revising the connection usage pattern." + } + + return stats +} + +// Close closes the database connection. +// It logs a message indicating the disconnection from the specific database. +// If the connection is successfully closed, it returns nil. +// If an error occurs while closing the connection, it returns the error. +func (s *service) Close() error { + log.Printf("Disconnected from database: %s", database) + return s.db.Close() +} + +// auth queries + +func (s *service) CreateUser(email, passwordHash string) (*User, error) { + var user User + err := s.db.QueryRow(` + INSERT INTO users (email, password_hash) + VALUES ($1, $2) + RETURNING id, email, password_hash, created_at, updated_at + `, email, passwordHash).Scan( + &user.ID, + &user.Email, + &user.PasswordHash, + &user.CreatedAt, + &user.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("error creating user: %w", err) + } + return &user, nil +} + +func (s *service) GetUserByEmail(email string) (*User, error) { + var user User + err := s.db.QueryRow(` + SELECT id, email, password_hash, created_at, updated_at + FROM users + WHERE email = $1 + `, email).Scan( + &user.ID, + &user.Email, + &user.PasswordHash, + &user.CreatedAt, + &user.UpdatedAt, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("user not found") + } + return nil, fmt.Errorf("error getting user: %w", err) + } + return &user, nil +} + +func (s *service) GetUserByID(id uuid.UUID) (*User, error) { + var user User + err := s.db.QueryRow(` + SELECT id, email, password_hash, created_at, updated_at + FROM users + WHERE id = $1 + `, id).Scan( + &user.ID, + &user.Email, + &user.PasswordHash, + &user.CreatedAt, + &user.UpdatedAt, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("user not found") + } + return nil, fmt.Errorf("error getting user: %w", err) + } + return &user, nil +} + +// API Key-related query implementations +func (s *service) CreateAPIKey(userID uuid.UUID, name, keyHash string, expiresAt *time.Time) (*APIKey, error) { + var apiKey APIKey + err := s.db.QueryRow(` + INSERT INTO api_keys (user_id, name, key_hash, expires_at) + VALUES ($1, $2, $3, $4) + RETURNING id, user_id, name, key_hash, created_at, last_used_at, expires_at, is_active + `, userID, name, keyHash, expiresAt).Scan( + &apiKey.ID, + &apiKey.UserID, + &apiKey.Name, + &apiKey.KeyHash, + &apiKey.CreatedAt, + &apiKey.LastUsedAt, + &apiKey.ExpiresAt, + &apiKey.IsActive, + ) + if err != nil { + return nil, fmt.Errorf("error creating API key: %w", err) + } + return &apiKey, nil +} + +func (s *service) GetAPIKeyByHash(keyHash string) (*APIKey, error) { + var apiKey APIKey + err := s.db.QueryRow(` + SELECT id, user_id, name, key_hash, created_at, last_used_at, expires_at, is_active + FROM api_keys + WHERE key_hash = $1 + `, keyHash).Scan( + &apiKey.ID, + &apiKey.UserID, + &apiKey.Name, + &apiKey.KeyHash, + &apiKey.CreatedAt, + &apiKey.LastUsedAt, + &apiKey.ExpiresAt, + &apiKey.IsActive, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("API key not found") + } + return nil, fmt.Errorf("error getting API key: %w", err) + } + return &apiKey, nil +} + +func (s *service) ListAPIKeys(userID uuid.UUID) ([]APIKey, error) { + rows, err := s.db.Query(` + SELECT id, user_id, name, key_hash, created_at, last_used_at, expires_at, is_active + FROM api_keys + WHERE user_id = $1 + ORDER BY created_at DESC + `, userID) + if err != nil { + return nil, fmt.Errorf("error listing API keys: %w", err) + } + defer rows.Close() + + var apiKeys []APIKey + for rows.Next() { + var apiKey APIKey + err := rows.Scan( + &apiKey.ID, + &apiKey.UserID, + &apiKey.Name, + &apiKey.KeyHash, + &apiKey.CreatedAt, + &apiKey.LastUsedAt, + &apiKey.ExpiresAt, + &apiKey.IsActive, + ) + if err != nil { + return nil, fmt.Errorf("error scanning API key: %w", err) + } + apiKeys = append(apiKeys, apiKey) + } + return apiKeys, nil +} + +func (s *service) UpdateAPIKeyLastUsed(keyID uuid.UUID) error { + _, err := s.db.Exec(` + UPDATE api_keys + SET last_used_at = CURRENT_TIMESTAMP + WHERE id = $1 + `, keyID) + if err != nil { + return fmt.Errorf("error updating API key last used: %w", err) + } + return nil +} + +func (s *service) RevokeAPIKey(userID, keyID uuid.UUID) error { + result, err := s.db.Exec(` + UPDATE api_keys + SET is_active = false + WHERE id = $1 AND user_id = $2 + `, keyID, userID) + if err != nil { + return fmt.Errorf("error revoking API key: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("error getting rows affected: %w", err) + } + + if rowsAffected == 0 { + return fmt.Errorf("API key not found or not owned by user") + } + + return nil +} + +func (s *service) DeleteAPIKey(userID, keyID uuid.UUID) error { + result, err := s.db.Exec(` + DELETE FROM api_keys + WHERE id = $1 AND user_id = $2 + `, keyID, userID) + if err != nil { + return fmt.Errorf("error deleting API key: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("error getting rows affected: %w", err) + } + + if rowsAffected == 0 { + return fmt.Errorf("API key not found or not owned by user") + } + + return nil +}
A
internal/server/routes.go
@@ -0,0 +1,223 @@
+package server + +import ( + "context" + "encoding/json" + "log" + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/cors" + "github.com/google/uuid" + + "argus-core/internal/auth" +) + +func (s *Server) RegisterRoutes() http.Handler { + r := chi.NewRouter() + r.Use(middleware.Logger) + + r.Use(cors.Handler(cors.Options{ + AllowedOrigins: []string{"https://*", "http://*"}, + AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"}, + AllowedHeaders: []string{"Accept", "Authorization", "Content-Type"}, + AllowCredentials: false, + MaxAge: 300, + })) + + r.Get("/", s.HelloWorldHandler) + r.Get("/health", s.healthHandler) + + // Auth routes + r.Post("/auth/register", s.handleRegister) + r.Post("/auth/login", s.handleLogin) + + // Protected routes + r.Group(func(r chi.Router) { + r.Use(s.authMiddleware) + r.Get("/auth/me", s.handleGetCurrentUser) + r.Post("/api-keys", s.handleCreateAPIKey) + r.Get("/api-keys", s.handleListAPIKeys) + r.Delete("/api-keys/{keyID}", s.handleDeleteAPIKey) + r.Post("/api-keys/{keyID}/revoke", s.handleRevokeAPIKey) + }) + + return r +} + +func (s *Server) HelloWorldHandler(w http.ResponseWriter, r *http.Request) { + resp := make(map[string]string) + resp["message"] = "Hello World" + + jsonResp, err := json.Marshal(resp) + if err != nil { + log.Fatalf("error handling JSON marshal. Err: %v", err) + } + + _, _ = w.Write(jsonResp) +} + +func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) { + jsonResp, _ := json.Marshal(s.db.Health()) + _, _ = w.Write(jsonResp) +} + +func (s *Server) handleRegister(w http.ResponseWriter, r *http.Request) { + var req auth.RegisterRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + s.respondWithError(w, http.StatusBadRequest, "Invalid request payload") + return + } + + user, err := s.auth.Register(req.Email, req.Password) + if err != nil { + log.Println("Failed to register user:", err) + s.respondWithError(w, http.StatusInternalServerError, "Failed to register user") + return + } + + s.respondWithJSON(w, http.StatusCreated, user) +} + +func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) { + var req auth.LoginRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + s.respondWithError(w, http.StatusBadRequest, "Invalid request payload") + return + } + + token, user, err := s.auth.Login(req.Email, req.Password) + if err != nil { + s.respondWithError(w, http.StatusUnauthorized, "Invalid credentials") + return + } + + // Return both token and user info in the response body + s.respondWithJSON(w, http.StatusOK, map[string]interface{}{ + "token": token, + "user": user, + }) +} + +func (s *Server) handleGetCurrentUser(w http.ResponseWriter, r *http.Request) { + // Get user ID from context (set by authMiddleware) + userID := r.Context().Value("userID").(uuid.UUID) + + // Get user from database + user, err := s.db.GetUserByID(userID) + if err != nil { + s.respondWithError(w, http.StatusInternalServerError, "Failed to get user details") + return + } + + s.respondWithJSON(w, http.StatusOK, user) +} + +func (s *Server) handleCreateAPIKey(w http.ResponseWriter, r *http.Request) { + var req auth.CreateAPIKeyRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + s.respondWithError(w, http.StatusBadRequest, "Invalid request payload") + return + } + + userID := r.Context().Value("userID").(uuid.UUID) + apiKey, keyString, err := s.auth.CreateAPIKey(userID, req.Name, req.ExpiresAt) + if err != nil { + s.respondWithError(w, http.StatusInternalServerError, "Failed to create API key") + return + } + + s.respondWithJSON(w, http.StatusCreated, map[string]interface{}{ + "api_key": apiKey, + "key": keyString, + }) +} + +func (s *Server) handleListAPIKeys(w http.ResponseWriter, r *http.Request) { + userID := r.Context().Value("userID").(uuid.UUID) + apiKeys, err := s.auth.ListAPIKeys(userID) + if err != nil { + s.respondWithError(w, http.StatusInternalServerError, "Failed to list API keys") + return + } + + s.respondWithJSON(w, http.StatusOK, apiKeys) +} + +func (s *Server) handleRevokeAPIKey(w http.ResponseWriter, r *http.Request) { + userID := r.Context().Value("userID").(uuid.UUID) + keyID, err := uuid.Parse(chi.URLParam(r, "keyID")) + if err != nil { + s.respondWithError(w, http.StatusBadRequest, "Invalid key ID") + return + } + + if err := s.auth.RevokeAPIKey(userID, keyID); err != nil { + s.respondWithError(w, http.StatusInternalServerError, "Failed to revoke API key") + return + } + + w.WriteHeader(http.StatusNoContent) +} + +func (s *Server) handleDeleteAPIKey(w http.ResponseWriter, r *http.Request) { + userID := r.Context().Value("userID").(uuid.UUID) + keyID, err := uuid.Parse(chi.URLParam(r, "keyID")) + if err != nil { + s.respondWithError(w, http.StatusBadRequest, "Invalid key ID") + return + } + + if err := s.auth.DeleteAPIKey(userID, keyID); err != nil { + s.respondWithError(w, http.StatusInternalServerError, "Failed to delete API key") + return + } + + w.WriteHeader(http.StatusNoContent) +} + +func (s *Server) authMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Get token from Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + s.respondWithError(w, http.StatusUnauthorized, "No authorization header") + return + } + + // Remove "Bearer " prefix if present + token := authHeader + if len(authHeader) > 7 && authHeader[:7] == "Bearer " { + token = authHeader[7:] + } + + // Validate token + user, err := s.auth.ValidateToken(token) + if err != nil { + s.respondWithError(w, http.StatusUnauthorized, "Invalid token") + return + } + + // Add user ID to context + ctx := context.WithValue(r.Context(), "userID", user.ID) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func (s *Server) respondWithError(w http.ResponseWriter, code int, message string) { + s.respondWithJSON(w, code, map[string]string{"error": message}) +} + +func (s *Server) respondWithJSON(w http.ResponseWriter, code int, payload interface{}) { + response, err := json.Marshal(payload) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error":"Failed to marshal JSON response"}`)) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + w.Write(response) +}
A
internal/server/server.go
@@ -0,0 +1,51 @@
+package server + +import ( + "fmt" + "net/http" + "os" + "strconv" + "time" + + _ "github.com/joho/godotenv/autoload" + + "argus-core/internal/auth" + "argus-core/internal/database" +) + +type Server struct { + port int + + db database.Service + auth auth.Service +} + +func NewServer() *http.Server { + port, _ := strconv.Atoi(os.Getenv("PORT")) + + // Initialize database service + db := database.New() + + // Initialize auth service + authService := auth.NewService(db, auth.Config{ + JWTSecret: os.Getenv("JWT_SECRET"), + TokenDuration: 24 * time.Hour, + }) + + NewServer := &Server{ + port: port, + db: db, + auth: authService, + } + + // Declare Server config + server := &http.Server{ + Addr: fmt.Sprintf(":%d", NewServer.port), + Handler: NewServer.RegisterRoutes(), + IdleTimeout: time.Minute, + ReadTimeout: 10 * time.Second, + WriteTimeout: 30 * time.Second, + } + + return server +}