diff --git a/README.md b/README.md index 55538a5..b4c7916 100644 --- a/README.md +++ b/README.md @@ -9,3 +9,8 @@ Environmental variables: - `SHORTEN_BIND` - bind address (default: `127.0.0.1:4488`) - `SHORTEN_MAIL` - optional email for support/abuse reports - `POSTGRES_URI` - lib/pq connection string (see [here](https://pkg.go.dev/github.com/lib/pq#section-documentation)) + +### Maintenance + +shorten supports a domain blacklist to ban certain domains (e.g. spam, malware, etc.); +you can use it by connecting to the database and inserting rows into the `blacklist` table. diff --git a/main.go b/main.go index 2c645c8..256fc55 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "log" "math/rand" "net/http" + neturl "net/url" "os" "strings" "time" @@ -71,6 +72,13 @@ func main() { panic(err) } + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS blacklist ( + domain text not null + );`) + if err != nil { + panic(err) + } + bind := os.Getenv("SHORTEN_BIND") if bind == "" { bind = "127.0.0.1:4488" @@ -167,8 +175,25 @@ func (h *Handler) GetCode(url string, ip string) (string, error) { return "", fmt.Errorf("invalid URL") } + parsed, err := neturl.Parse(url) + if err != nil { + return "", fmt.Errorf("invalid URL") + } + hostname := parsed.Hostname() + + // tbh we don't care about the value here but why not + var res string + err = h.db.QueryRow(`SELECT domain FROM blacklist WHERE domain = $1`, hostname).Scan(&res) + if err != sql.ErrNoRows { + if err != nil { + log.Println("sql error:", err) + return "", fmt.Errorf("query: %w", err) + } + return "", fmt.Errorf("invalid URL") + } + var code string - err := h.db.QueryRow(`SELECT code FROM urls WHERE url = $1 LIMIT 1`, url).Scan(&code) + err = h.db.QueryRow(`SELECT code FROM urls WHERE url = $1 LIMIT 1`, url).Scan(&code) if err == nil { return code, nil }