Rework project architecture with extensive bug fixes and security enhancements.

This commit is contained in:
Michael Reber 2025-02-11 13:01:01 +01:00
parent b390b87fef
commit 6796dae280
25 changed files with 1584 additions and 2600 deletions

View File

@ -1,138 +1,147 @@
/**
* Renders the list of clients.
* @param {Array} data - Array of client objects.
*/
function renderClientList(data) {
$.each(data, function(index, obj) {
// render client status css tag style
let clientStatusHtml = '>'
if (obj.Client.enabled) {
clientStatusHtml = `style="visibility: hidden;">`
}
// render client allocated ip addresses
let allocatedIpsHtml = "";
$.each(obj.Client.allocated_ips, function(index, obj) {
allocatedIpsHtml += `<small class="badge badge-secondary">${escapeHtml(obj)}</small>&nbsp;`;
})
// render client allowed ip addresses
let allowedIpsHtml = "";
$.each(obj.Client.allowed_ips, function(index, obj) {
allowedIpsHtml += `<small class="badge badge-secondary">${escapeHtml(obj)}</small>&nbsp;`;
})
let subnetRangesString = "";
if (obj.Client.subnet_ranges && obj.Client.subnet_ranges.length > 0) {
subnetRangesString = obj.Client.subnet_ranges.join(',')
}
let additionalNotesHtml = "";
if (obj.Client.additional_notes && obj.Client.additional_notes.length > 0) {
additionalNotesHtml = `<div style="display: none"><i class="fas fa-additional_notes"></i>${escapeHtml(obj.Client.additional_notes.toUpperCase())}</div>`
}
// render client html content
let html = `<div class="col-sm-6 col-md-6 col-lg-4" id="client_${obj.Client.id}">
<div class="card">
<div class="overlay" id="paused_${obj.Client.id}"` + clientStatusHtml
+ `<i class="paused-client fas fa-3x fa-play" onclick="resumeClient('${obj.Client.id}')"></i>
</div>
<div class="card-header">
<div class="btn-group">
<a href="download?clientid=${obj.Client.id}" class="btn btn-outline-primary btn-sm">Download</a>
</div>
<div class="btn-group">
<button type="button" class="btn btn-outline-primary btn-sm" data-toggle="modal"
data-target="#modal_qr_client" data-clientid="${obj.Client.id}"
data-clientname="${escapeHtml(obj.Client.name)}" ${obj.QRCode != "" ? '' : ' disabled'}>QR code</button>
</div>
<div class="btn-group">
<button type="button" class="btn btn-outline-primary btn-sm" data-toggle="modal"
data-target="#modal_email_client" data-clientid="${obj.Client.id}"
data-clientname="${escapeHtml(obj.Client.name)}">Email</button>
</div>
<div class="btn-group">
<button type="button" class="btn btn-outline-danger btn-sm">More</button>
<button type="button" class="btn btn-outline-danger btn-sm dropdown-toggle dropdown-icon"
data-toggle="dropdown">
</button>
<div class="dropdown-menu" role="menu">
<a class="dropdown-item" href="#" data-toggle="modal"
data-target="#modal_edit_client" data-clientid="${obj.Client.id}"
data-clientname="${escapeHtml(obj.Client.name)}">Edit</a>
<a class="dropdown-item" href="#" data-toggle="modal"
data-target="#modal_pause_client" data-clientid="${obj.Client.id}"
data-clientname="${escapeHtml(obj.Client.name)}">Disable</a>
<a class="dropdown-item" href="#" data-toggle="modal"
data-target="#modal_remove_client" data-clientid="${obj.Client.id}"
data-clientname="${escapeHtml(obj.Client.name)}">Delete</a>
</div>
</div>
</div>
<div class="card-body">
<div class="info-box-text"><i class="fas fa-user"></i> ${escapeHtml(obj.Client.name)}</div>
<div style="display: none"><i class="fas fa-key"></i> ${escapeHtml(obj.Client.public_key)}</div>
<div style="display: none"><i class="fas fa-subnetrange"></i>${escapeHtml(subnetRangesString)}</div>
${additionalNotesHtml}
<div class="info-box-text"><i class="fas fa-envelope"></i> ${escapeHtml(obj.Client.email)}</div>
<div class="info-box-text"><i class="fas fa-clock"></i>
${prettyDateTime(obj.Client.created_at)}</div>
<div class="info-box-text"><i class="fas fa-history"></i>
${prettyDateTime(obj.Client.updated_at)}</div>
<div class="info-box-text"><i class="fas fa-server" style="${obj.Client.use_server_dns ? "opacity: 1.0" : "opacity: 0.5"}"></i>
${obj.Client.use_server_dns ? 'DNS enabled' : 'DNS disabled'}</div>
<div class="info-box-text"><i class="fas fa-file"></i>
${escapeHtml(obj.Client.additional_notes)}</div>
<div class="info-box-text"><strong>IP Allocation</strong></div>`
+ allocatedIpsHtml
+ `<div class="info-box-text"><strong>Allowed IPs</strong></div>`
+ allowedIpsHtml
+`</div>
</div>
</div>`
// add the client html elements to the list
$('#client-list').append(html);
data.forEach(function(obj) {
// Determine the CSS style for the client overlay based on its enabled status.
const clientStatusHtml = obj.Client.enabled
? 'style="visibility: hidden;">'
: '>';
// Render allocated IP addresses as badges.
const allocatedIpsHtml = obj.Client.allocated_ips
.map(ip => `<small class="badge badge-secondary">${escapeHtml(ip)}</small>&nbsp;`)
.join('');
// Render allowed IP addresses as badges.
const allowedIpsHtml = obj.Client.allowed_ips
.map(ip => `<small class="badge badge-secondary">${escapeHtml(ip)}</small>&nbsp;`)
.join('');
// Join subnet ranges, if any.
const subnetRangesString = (obj.Client.subnet_ranges && obj.Client.subnet_ranges.length > 0)
? obj.Client.subnet_ranges.join(',')
: '';
// Render additional notes (hidden by default).
const additionalNotesHtml = (obj.Client.additional_notes && obj.Client.additional_notes.length > 0)
? `<div style="display: none"><i class="fas fa-additional_notes"></i>${escapeHtml(obj.Client.additional_notes.toUpperCase())}</div>`
: '';
// Build the client card HTML.
const html = `
<div class="col-sm-6 col-md-6 col-lg-4" id="client_${obj.Client.id}">
<div class="card">
<div class="overlay" id="paused_${obj.Client.id}" ${clientStatusHtml}
<i class="paused-client fas fa-3x fa-play" onclick="resumeClient('${obj.Client.id}')"></i>
</div>
<div class="card-header">
<div class="btn-group">
<a href="download?clientid=${obj.Client.id}" class="btn btn-outline-primary btn-sm">Download</a>
</div>
<div class="btn-group">
<button type="button" class="btn btn-outline-primary btn-sm" data-toggle="modal"
data-target="#modal_qr_client" data-clientid="${obj.Client.id}"
data-clientname="${escapeHtml(obj.Client.name)}" ${obj.QRCode !== "" ? '' : 'disabled'}>QR code</button>
</div>
<div class="btn-group">
<button type="button" class="btn btn-outline-primary btn-sm" data-toggle="modal"
data-target="#modal_email_client" data-clientid="${obj.Client.id}"
data-clientname="${escapeHtml(obj.Client.name)}">Email</button>
</div>
<div class="btn-group">
<button type="button" class="btn btn-outline-danger btn-sm">More</button>
<button type="button" class="btn btn-outline-danger btn-sm dropdown-toggle dropdown-icon" data-toggle="dropdown"></button>
<div class="dropdown-menu" role="menu">
<a class="dropdown-item" href="#" data-toggle="modal"
data-target="#modal_edit_client" data-clientid="${obj.Client.id}"
data-clientname="${escapeHtml(obj.Client.name)}">Edit</a>
<a class="dropdown-item" href="#" data-toggle="modal"
data-target="#modal_pause_client" data-clientid="${obj.Client.id}"
data-clientname="${escapeHtml(obj.Client.name)}">Disable</a>
<a class="dropdown-item" href="#" data-toggle="modal"
data-target="#modal_remove_client" data-clientid="${obj.Client.id}"
data-clientname="${escapeHtml(obj.Client.name)}">Delete</a>
</div>
</div>
</div>
<div class="card-body">
<div class="info-box-text"><i class="fas fa-user"></i> ${escapeHtml(obj.Client.name)}</div>
<div style="display: none"><i class="fas fa-key"></i> ${escapeHtml(obj.Client.public_key)}</div>
<div style="display: none"><i class="fas fa-subnetrange"></i> ${escapeHtml(subnetRangesString)}</div>
${additionalNotesHtml}
<div class="info-box-text"><i class="fas fa-envelope"></i> ${escapeHtml(obj.Client.email)}</div>
<div class="info-box-text"><i class="fas fa-clock"></i> ${prettyDateTime(obj.Client.created_at)}</div>
<div class="info-box-text"><i class="fas fa-history"></i> ${prettyDateTime(obj.Client.updated_at)}</div>
<div class="info-box-text"><i class="fas fa-server" style="${obj.Client.use_server_dns ? 'opacity: 1.0' : 'opacity: 0.5'}"></i> ${obj.Client.use_server_dns ? 'DNS enabled' : 'DNS disabled'}</div>
<div class="info-box-text"><i class="fas fa-file"></i> ${escapeHtml(obj.Client.additional_notes)}</div>
<div class="info-box-text"><strong>IP Allocation</strong></div>
${allocatedIpsHtml}
<div class="info-box-text"><strong>Allowed IPs</strong></div>
${allowedIpsHtml}
</div>
</div>
</div>
`;
// Append the generated client card HTML to the client list container.
$('#client-list').append(html);
});
}
function renderUserList(data) {
$.each(data, function(index, obj) {
let clientStatusHtml = '>'
// render user html content
let html = `<div class="col-sm-6 col-md-6 col-lg-4" id="user_${obj.username}">
<div class="card">
<div class="card-header">
<div class="btn-group">
<button type="button" class="btn btn-outline-primary btn-sm" data-toggle="modal" data-target="#modal_edit_user" data-username="${obj.username}">Edit</button>
</div>
<div class="btn-group">
<button type="button" class="btn btn-outline-danger btn-sm" data-toggle="modal"
data-target="#modal_remove_user" data-username="${obj.username}">Delete</button>
</div>
</div>
<div class="card-body">
<div class="info-box-text"><i class="fas fa-user"></i> ${obj.username}</div>
<div class="info-box-text"><i class="fas fa-terminal"></i> ${obj.admin? 'Administrator':'Manager'}</div>
</div>
</div>
</div>`
// add the user html elements to the list
$('#users-list').append(html);
}
/**
* Renders the list of users.
* @param {Array} data - Array of user objects.
*/
function renderUserList(data) {
data.forEach(function(obj) {
const html = `
<div class="col-sm-6 col-md-6 col-lg-4" id="user_${obj.username}">
<div class="card">
<div class="card-header">
<div class="btn-group">
<button type="button" class="btn btn-outline-primary btn-sm" data-toggle="modal" data-target="#modal_edit_user" data-username="${obj.username}">Edit</button>
</div>
<div class="btn-group">
<button type="button" class="btn btn-outline-danger btn-sm" data-toggle="modal" data-target="#modal_remove_user" data-username="${obj.username}">Delete</button>
</div>
</div>
<div class="card-body">
<div class="info-box-text"><i class="fas fa-user"></i> ${obj.username}</div>
<div class="info-box-text"><i class="fas fa-terminal"></i> ${obj.admin ? 'Administrator' : 'Manager'}</div>
</div>
</div>
</div>
`;
$('#users-list').append(html);
});
}
function escapeHtml(unsafe) {
}
/**
* Escapes HTML characters in a string to prevent XSS.
* @param {string} unsafe - The string to escape.
* @returns {string} - The escaped string.
*/
function escapeHtml(unsafe) {
if (typeof unsafe !== "string") return unsafe;
return unsafe
.replace(/&/g, "&amp;")
.replace(/</g, "&lt;")
.replace(/>/g, "&gt;")
.replace(/"/g, "&quot;")
.replace(/'/g, "&#039;");
}
function prettyDateTime(timeStr) {
.replace(/&/g, "&amp;")
.replace(/</g, "&lt;")
.replace(/>/g, "&gt;")
.replace(/"/g, "&quot;")
.replace(/'/g, "&#039;");
}
/**
* Formats a date/time string into a human-readable format.
* @param {string} timeStr - The time string to format.
* @returns {string} - The formatted date/time.
*/
function prettyDateTime(timeStr) {
const dt = new Date(timeStr);
const offsetMs = dt.getTimezoneOffset() * 60 * 1000;
const dateLocal = new Date(dt.getTime() - offsetMs);
return dateLocal.toISOString().slice(0, 19).replace(/-/g, "/").replace("T", " ");
}
}

View File

@ -1,10 +1,26 @@
package emailer
// Attachment represents a file attachment to be sent with an email.
type Attachment struct {
// Name is the filename of the attachment.
Name string
// Data holds the binary content of the attachment.
Data []byte
}
// Emailer defines an interface for sending emails.
// Implementations should handle constructing and sending emails
// with optional attachments.
type Emailer interface {
// Send sends an email with the provided details.
//
// Parameters:
// toName: The recipient's name.
// to: The recipient's email address.
// subject: The email subject.
// content: The email body content (can include HTML).
// attachments: A slice of Attachment objects to include with the email.
//
// Returns an error if sending fails.
Send(toName string, to string, subject string, content string, attachments []Attachment) error
}

View File

@ -7,46 +7,54 @@ import (
"github.com/sendgrid/sendgrid-go/helpers/mail"
)
// SendgridApiMail implements the Emailer interface using SendGrid's API.
type SendgridApiMail struct {
apiKey string
fromName string
from string
}
// NewSendgridApiMail creates a new SendgridApiMail instance with the provided API key and sender information.
func NewSendgridApiMail(apiKey, fromName, from string) *SendgridApiMail {
ans := SendgridApiMail{apiKey: apiKey, fromName: fromName, from: from}
return &ans
return &SendgridApiMail{
apiKey: apiKey,
fromName: fromName,
from: from,
}
}
func (o *SendgridApiMail) Send(toName string, to string, subject string, content string, attachments []Attachment) error {
// Send sends an email using the SendGrid API.
// It builds a V3Mail object with the given recipient details, subject, content, and attachments.
func (s *SendgridApiMail) Send(toName, to, subject, content string, attachments []Attachment) error {
m := mail.NewV3Mail()
mailFrom := mail.NewEmail(o.fromName, o.from)
mailContent := mail.NewContent("text/html", content)
// Set sender, recipient, content, and subject.
mailFrom := mail.NewEmail(s.fromName, s.from)
mailTo := mail.NewEmail(toName, to)
m.SetFrom(mailFrom)
m.AddContent(mailContent)
m.AddContent(mail.NewContent("text/html", content))
personalization := mail.NewPersonalization()
personalization.AddTos(mailTo)
personalization.Subject = subject
m.AddPersonalizations(personalization)
toAdd := make([]*mail.Attachment, 0, len(attachments))
for i := range attachments {
var att mail.Attachment
encoded := base64.StdEncoding.EncodeToString(attachments[i].Data)
att.SetContent(encoded)
att.SetType("text/plain")
att.SetFilename(attachments[i].Name)
att.SetDisposition("attachment")
toAdd = append(toAdd, &att)
// Process attachments.
var sgAttachments []*mail.Attachment
for _, a := range attachments {
encoded := base64.StdEncoding.EncodeToString(a.Data)
sgAtt := mail.NewAttachment()
sgAtt.SetContent(encoded)
// Set a default content type. Adjust if you need to support other file types.
sgAtt.SetType("text/plain")
sgAtt.SetFilename(a.Name)
sgAtt.SetDisposition("attachment")
sgAttachments = append(sgAttachments, sgAtt)
}
m.AddAttachment(sgAttachments...)
m.AddAttachment(toAdd...)
request := sendgrid.GetRequest(o.apiKey, "/v3/mail/send", "https://api.sendgrid.com")
// Build and send the request.
request := sendgrid.GetRequest(s.apiKey, "/v3/mail/send", "https://api.sendgrid.com")
request.Method = "POST"
request.Body = mail.GetRequestBody(m)
_, err := sendgrid.API(request)

View File

@ -9,6 +9,7 @@ import (
mail "github.com/xhit/go-simple-mail/v2"
)
// SmtpMail implements the Emailer interface using an SMTP server.
type SmtpMail struct {
hostname string
port int
@ -22,8 +23,9 @@ type SmtpMail struct {
from string
}
func authType(authType string) mail.AuthType {
switch strings.ToUpper(authType) {
// authType converts a string to the corresponding mail.AuthType.
func authType(authTypeStr string) mail.AuthType {
switch strings.ToUpper(authTypeStr) {
case "PLAIN":
return mail.AuthPlain
case "LOGIN":
@ -33,8 +35,9 @@ func authType(authType string) mail.AuthType {
}
}
func encryptionType(encryptionType string) mail.Encryption {
switch strings.ToUpper(encryptionType) {
// encryptionType converts a string to the corresponding mail.Encryption.
func encryptionType(encryptionTypeStr string) mail.Encryption {
switch strings.ToUpper(encryptionTypeStr) {
case "NONE":
return mail.EncryptionNone
case "SSL":
@ -48,53 +51,71 @@ func encryptionType(encryptionType string) mail.Encryption {
}
}
func NewSmtpMail(hostname string, port int, username string, password string, SmtpHelo string, noTLSCheck bool, auth string, fromName, from string, encryption string) *SmtpMail {
ans := SmtpMail{hostname: hostname, port: port, username: username, password: password, smtpHelo: SmtpHelo, noTLSCheck: noTLSCheck, fromName: fromName, from: from, authType: authType(auth), encryption: encryptionType(encryption)}
return &ans
// NewSmtpMail returns a new instance of SmtpMail configured with the provided parameters.
func NewSmtpMail(hostname string, port int, username string, password string, smtpHelo string, noTLSCheck bool, auth string, fromName, from string, encryption string) *SmtpMail {
return &SmtpMail{
hostname: hostname,
port: port,
username: username,
password: password,
smtpHelo: smtpHelo,
noTLSCheck: noTLSCheck,
fromName: fromName,
from: from,
authType: authType(auth),
encryption: encryptionType(encryption),
}
}
func addressField(address string, name string) string {
// addressField formats an email address with an optional display name.
func addressField(address, name string) string {
if name == "" {
return address
}
return fmt.Sprintf("%s <%s>", name, address)
}
func (o *SmtpMail) Send(toName string, to string, subject string, content string, attachments []Attachment) error {
// Send sends an email with the specified details and attachments via SMTP.
func (s *SmtpMail) Send(toName, to, subject, content string, attachments []Attachment) error {
server := mail.NewSMTPClient()
server.Host = o.hostname
server.Port = o.port
server.Authentication = o.authType
server.Username = o.username
server.Password = o.password
server.Helo = o.smtpHelo
server.Encryption = o.encryption
server.Host = s.hostname
server.Port = s.port
server.Authentication = s.authType
server.Username = s.username
server.Password = s.password
server.Helo = s.smtpHelo
server.Encryption = s.encryption
server.KeepAlive = false
server.ConnectTimeout = 10 * time.Second
server.SendTimeout = 10 * time.Second
if o.noTLSCheck {
// If noTLSCheck is true, skip TLS certificate verification.
if s.noTLSCheck {
server.TLSConfig = &tls.Config{InsecureSkipVerify: true}
}
smtpClient, err := server.Connect()
if err != nil {
return err
return fmt.Errorf("failed to connect to SMTP server: %w", err)
}
email := mail.NewMSG()
email.SetFrom(addressField(o.from, o.fromName)).
email.SetFrom(addressField(s.from, s.fromName)).
AddTo(addressField(to, toName)).
SetSubject(subject).
SetBody(mail.TextHTML, content)
for _, v := range attachments {
email.Attach(&mail.File{Name: v.Name, Data: v.Data})
// Attach files, if any.
for _, att := range attachments {
email.Attach(&mail.File{
Name: att.Name,
Data: att.Data,
})
}
err = email.Send(smtpClient)
if err := email.Send(smtpClient); err != nil {
return fmt.Errorf("failed to send email: %w", err)
}
return err
return nil
}

View File

@ -2,19 +2,23 @@ package handler
import (
"net/http"
"strings"
"github.com/labstack/echo/v4"
)
// ContentTypeJson checks that the requests have the Content-Type header set to "application/json".
// This helps against CSRF attacks.
// ContentTypeJson is middleware that ensures the request's Content-Type header
// starts with "application/json". This helps mitigate CSRF attacks by rejecting
// requests that do not explicitly signal a JSON payload.
func ContentTypeJson(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
contentType := c.Request().Header.Get("Content-Type")
if contentType != "application/json" {
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "Only JSON allowed"})
if !strings.HasPrefix(contentType, "application/json") {
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{
Success: false,
Message: "Only JSON allowed",
})
}
return next(c)
}
}

View File

@ -1,6 +1,6 @@
package handler
type jsonHTTPResponse struct {
Status bool `json:"status"`
Success bool `json:"success"`
Message string `json:"message"`
}

File diff suppressed because it is too large Load Diff

View File

@ -11,22 +11,24 @@ import (
"github.com/swissmakers/wireguard-manager/util"
)
// ValidSession is middleware that checks for a valid session.
// If the session is invalid, it redirects the user to the login page.
func ValidSession(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if !isValidSession(c) {
nextURL := c.Request().URL
if nextURL != nil && c.Request().Method == http.MethodGet {
return c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf(util.BasePath+"/login?next=%s", c.Request().URL))
} else {
return c.Redirect(http.StatusTemporaryRedirect, util.BasePath+"/login")
// If the request is a GET, append the current URL as a query parameter "next"
if c.Request().Method == http.MethodGet {
return c.Redirect(http.StatusTemporaryRedirect,
fmt.Sprintf("%s/login?next=%s", util.BasePath, c.Request().URL.String()))
}
return c.Redirect(http.StatusTemporaryRedirect, util.BasePath+"/login?next="+util.BasePath)
}
return next(c)
}
}
// RefreshSession must only be used after ValidSession middleware
// RefreshSession checks if the session is eligible for the refresh, but doesn't check if it's fully valid
// RefreshSession middleware refreshes a "remember me" session.
// This should be used after ValidSession has verified the session.
func RefreshSession(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
doRefreshSession(c)
@ -34,6 +36,7 @@ func RefreshSession(next echo.HandlerFunc) echo.HandlerFunc {
}
}
// NeedsAdmin middleware ensures that only admin users proceed.
func NeedsAdmin(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if !isAdmin(c) {
@ -43,23 +46,31 @@ func NeedsAdmin(next echo.HandlerFunc) echo.HandlerFunc {
}
}
// isValidSession checks whether the session is valid.
func isValidSession(c echo.Context) bool {
// If login is disabled, always return true.
if util.DisableLogin {
return true
}
sess, _ := session.Get("session", c)
// Retrieve session; if an error occurs, consider the session invalid.
sess, err := session.Get("session", c)
if err != nil {
return false
}
// Check for a valid session token in both session and cookie.
cookie, err := c.Cookie("session_token")
if err != nil || sess.Values["session_token"] != cookie.Value {
return false
}
// Check time bounds
// Check time bounds.
createdAt := getCreatedAt(sess)
updatedAt := getUpdatedAt(sess)
maxAge := getMaxAge(sess)
// Temporary session is considered valid within 24h if browser is not closed before
// This value is not saved and is used as virtual expiration
if maxAge == 0 {
// Default temporary session duration (24h) when not set.
maxAge = 86400
}
expiration := updatedAt + int64(maxAge)
@ -68,7 +79,7 @@ func isValidSession(c echo.Context) bool {
return false
}
// Check if user still exists and unchanged
// Check if user still exists and has not changed.
username := fmt.Sprintf("%s", sess.Values["username"])
userHash := getUserHash(sess)
if uHash, ok := util.DBUsersToCRC32[username]; !ok || userHash != uHash {
@ -78,15 +89,19 @@ func isValidSession(c echo.Context) bool {
return true
}
// Refreshes a "remember me" session when the user visits web pages (not API)
// Session must be valid before calling this function
// Refresh is performed at most once per 24h
// doRefreshSession refreshes the session data if the session is eligible.
// The session must already be valid before calling this function.
func doRefreshSession(c echo.Context) {
if util.DisableLogin {
return
}
sess, _ := session.Get("session", c)
sess, err := session.Get("session", c)
if err != nil {
// Cannot retrieve session; nothing to do.
return
}
maxAge := getMaxAge(sess)
if maxAge <= 0 {
return
@ -97,17 +112,20 @@ func doRefreshSession(c echo.Context) {
return
}
// Refresh no sooner than 24h
// Determine if a refresh is due.
createdAt := getCreatedAt(sess)
updatedAt := getUpdatedAt(sess)
expiration := updatedAt + int64(getMaxAge(sess))
expiration := updatedAt + int64(maxAge)
now := time.Now().UTC().Unix()
// Only refresh if at least 24h have passed since last update
// and the session has not yet reached its maximum duration.
if updatedAt > now || expiration < now || now-updatedAt < 86_400 || createdAt+util.SessionMaxDuration < now {
return
}
cookiePath := util.GetCookiePath()
// Update the session timestamp.
sess.Values["updated_at"] = now
sess.Options = &sessions.Options{
Path: cookiePath,
@ -115,131 +133,148 @@ func doRefreshSession(c echo.Context) {
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
sess.Save(c.Request(), c.Response())
if err := sess.Save(c.Request(), c.Response()); err != nil {
// Log error if needed.
return
}
cookie := new(http.Cookie)
cookie.Name = "session_token"
cookie.Path = cookiePath
cookie.Value = oldCookie.Value
cookie.MaxAge = maxAge
cookie.HttpOnly = true
cookie.SameSite = http.SameSiteLaxMode
// Reset the session cookie.
cookie := &http.Cookie{
Name: "session_token",
Path: cookiePath,
Value: oldCookie.Value,
MaxAge: maxAge,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
c.SetCookie(cookie)
}
// Get time in seconds this session is valid without updating
// getMaxAge returns the session's maximum age (in seconds).
func getMaxAge(sess *sessions.Session) int {
if util.DisableLogin {
return 0
}
maxAge := sess.Values["max_age"]
switch typedMaxAge := maxAge.(type) {
maxAgeVal := sess.Values["max_age"]
switch v := maxAgeVal.(type) {
case int:
return typedMaxAge
return v
case int64:
return int(v)
default:
return 0
}
}
// Get a timestamp in seconds of the time the session was created
// getCreatedAt returns the timestamp when the session was created.
func getCreatedAt(sess *sessions.Session) int64 {
if util.DisableLogin {
return 0
}
createdAt := sess.Values["created_at"]
switch typedCreatedAt := createdAt.(type) {
createdAtVal := sess.Values["created_at"]
switch v := createdAtVal.(type) {
case int:
return int64(v)
case int64:
return typedCreatedAt
return v
default:
return 0
}
}
// Get a timestamp in seconds of the last session update
// getUpdatedAt returns the timestamp of the last session update.
func getUpdatedAt(sess *sessions.Session) int64 {
if util.DisableLogin {
return 0
}
lastUpdate := sess.Values["updated_at"]
switch typedLastUpdate := lastUpdate.(type) {
updatedAtVal := sess.Values["updated_at"]
switch v := updatedAtVal.(type) {
case int:
return int64(v)
case int64:
return typedLastUpdate
return v
default:
return 0
}
}
// Get CRC32 of a user at the moment of log in
// Any changes to user will result in logout of other (not updated) sessions
// getUserHash returns the CRC32 hash of the user at the time of login.
func getUserHash(sess *sessions.Session) uint32 {
if util.DisableLogin {
return 0
}
userHash := sess.Values["user_hash"]
switch typedUserHash := userHash.(type) {
userHashVal := sess.Values["user_hash"]
switch v := userHashVal.(type) {
case uint32:
return typedUserHash
return v
// In case the hash was stored as an int, convert it.
case int:
return uint32(v)
case int64:
return uint32(v)
default:
return 0
}
}
// currentUser to get username of logged in user
// currentUser retrieves the username of the logged-in user.
func currentUser(c echo.Context) string {
if util.DisableLogin {
return ""
}
sess, _ := session.Get("session", c)
username := fmt.Sprintf("%s", sess.Values["username"])
return username
sess, err := session.Get("session", c)
if err != nil {
return ""
}
return fmt.Sprintf("%s", sess.Values["username"])
}
// isAdmin to get user type: admin or manager
// isAdmin checks whether the logged-in user is an admin.
func isAdmin(c echo.Context) bool {
if util.DisableLogin {
return true
}
sess, _ := session.Get("session", c)
admin := fmt.Sprintf("%t", sess.Values["admin"])
return admin == "true"
sess, err := session.Get("session", c)
if err != nil {
return false
}
// Use type assertion for a boolean.
if admin, ok := sess.Values["admin"].(bool); ok {
return admin
}
return false
}
// setUser updates the session with new user information.
func setUser(c echo.Context, username string, admin bool, userCRC32 uint32) {
sess, _ := session.Get("session", c)
sess, err := session.Get("session", c)
if err != nil {
return
}
sess.Values["username"] = username
sess.Values["user_hash"] = userCRC32
sess.Values["admin"] = admin
sess.Save(c.Request(), c.Response())
_ = sess.Save(c.Request(), c.Response())
}
// clearSession to remove current session
// clearSession removes the current session data and invalidates the session cookie.
func clearSession(c echo.Context) {
sess, _ := session.Get("session", c)
sess.Values["username"] = ""
sess.Values["user_hash"] = 0
sess.Values["admin"] = false
sess.Values["session_token"] = ""
sess.Values["max_age"] = -1
sess.Options.MaxAge = -1
sess.Save(c.Request(), c.Response())
cookiePath := util.GetCookiePath()
cookie, err := c.Cookie("session_token")
if err != nil {
cookie = new(http.Cookie)
sess, err := session.Get("session", c)
if err == nil {
sess.Values["username"] = ""
sess.Values["user_hash"] = 0
sess.Values["admin"] = false
sess.Values["session_token"] = ""
sess.Values["max_age"] = -1
sess.Options.MaxAge = -1
_ = sess.Save(c.Request(), c.Response())
}
cookiePath := util.GetCookiePath()
cookie, err := c.Cookie("session_token")
if err != nil {
cookie = &http.Cookie{}
}
cookie.Name = "session_token"
cookie.Path = cookiePath
cookie.MaxAge = -1

186
main.go
View File

@ -13,39 +13,42 @@ import (
"syscall"
"time"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/log"
"github.com/swissmakers/wireguard-manager/store"
"github.com/swissmakers/wireguard-manager/emailer"
"github.com/swissmakers/wireguard-manager/handler"
"github.com/swissmakers/wireguard-manager/router"
"github.com/swissmakers/wireguard-manager/store"
"github.com/swissmakers/wireguard-manager/store/jsondb"
"github.com/swissmakers/wireguard-manager/util"
)
var (
// command-line banner information
// App version information.
appVersion = "stable"
gitCommit = "N/A"
gitRef = "N/A"
buildTime = time.Now().UTC().Format("01-02-2006 15:04:05")
// configuration variables
flagDisableLogin = false
flagProxy = false
flagBindAddress = "0.0.0.0:5000"
flagSmtpHostname = "127.0.0.1"
flagSmtpPort = 25
flagSmtpUsername string
flagSmtpPassword string
flagSmtpAuthType = "NONE"
flagSmtpNoTLSCheck = false
flagSmtpEncryption = "STARTTLS"
flagSmtpHelo = "localhost"
flagSendgridApiKey string
flagEmailFrom string
flagEmailFromName = "WireGuard Manager"
// Configuration variables with defaults.
flagDisableLogin = false
flagProxy = false
flagBindAddress = "0.0.0.0:5000"
flagSmtpHostname = "127.0.0.1"
flagSmtpPort = 25
flagSmtpUsername string
flagSmtpPassword string
flagSmtpAuthType = "NONE"
flagSmtpNoTLSCheck = false
flagSmtpEncryption = "STARTTLS"
flagSmtpHelo = "localhost"
flagSendgridApiKey string
flagEmailFrom string
flagEmailFromName = "WireGuard Manager"
// IMPORTANT: If no SESSION_SECRET is provided via environment or file,
// a random secret is generated which will change on every restart.
// For production, be sure to supply a fixed value.
flagSessionSecret = util.RandomString(32)
flagSessionMaxDuration = 90
flagWgConfTemplate string
@ -57,23 +60,18 @@ const (
defaultEmailSubject = "Your wireguard configuration"
defaultEmailContent = `Hi,</br>
<p>In this email you can find your personal configuration for our wireguard server.</p>
<p>Best</p>
`
)
// embed the "templates" directory
//
//go:embed templates/*
var embeddedTemplates embed.FS
// embed the "assets" directory
//
//go:embed assets/*
var embeddedAssets embed.FS
func init() {
// command-line flags and env variables
// Bind command-line flags and environment variables.
flag.BoolVar(&flagDisableLogin, "disable-login", util.LookupEnvOrBool("DISABLE_LOGIN", flagDisableLogin), "Disable authentication on the app. This is potentially dangerous.")
flag.BoolVar(&flagProxy, "proxy", util.LookupEnvOrBool("PROXY", flagProxy), "Behind a proxy. Use X-FORWARDED-FOR for failed login logging")
flag.StringVar(&flagBindAddress, "bind-address", util.LookupEnvOrString("BIND_ADDRESS", flagBindAddress), "Address:Port to which the app will be bound.")
@ -82,8 +80,8 @@ func init() {
flag.StringVar(&flagSmtpHelo, "smtp-helo", util.LookupEnvOrString("SMTP_HELO", flagSmtpHelo), "SMTP HELO Hostname")
flag.StringVar(&flagSmtpUsername, "smtp-username", util.LookupEnvOrString("SMTP_USERNAME", flagSmtpUsername), "SMTP Username")
flag.BoolVar(&flagSmtpNoTLSCheck, "smtp-no-tls-check", util.LookupEnvOrBool("SMTP_NO_TLS_CHECK", flagSmtpNoTLSCheck), "Disable TLS verification for SMTP. This is potentially dangerous.")
flag.StringVar(&flagSmtpEncryption, "smtp-encryption", util.LookupEnvOrString("SMTP_ENCRYPTION", flagSmtpEncryption), "SMTP Encryption : NONE, SSL, SSLTLS, TLS or STARTTLS (by default)")
flag.StringVar(&flagSmtpAuthType, "smtp-auth-type", util.LookupEnvOrString("SMTP_AUTH_TYPE", flagSmtpAuthType), "SMTP Auth Type : PLAIN, LOGIN or NONE.")
flag.StringVar(&flagSmtpEncryption, "smtp-encryption", util.LookupEnvOrString("SMTP_ENCRYPTION", flagSmtpEncryption), "SMTP Encryption: NONE, SSL, SSLTLS, TLS or STARTTLS (by default)")
flag.StringVar(&flagSmtpAuthType, "smtp-auth-type", util.LookupEnvOrString("SMTP_AUTH_TYPE", flagSmtpAuthType), "SMTP Auth Type: PLAIN, LOGIN or NONE.")
flag.StringVar(&flagEmailFrom, "email-from", util.LookupEnvOrString("EMAIL_FROM_ADDRESS", flagEmailFrom), "'From' email address.")
flag.StringVar(&flagEmailFromName, "email-from-name", util.LookupEnvOrString("EMAIL_FROM_NAME", flagEmailFromName), "'From' email name.")
flag.StringVar(&flagWgConfTemplate, "wg-conf-template", util.LookupEnvOrString("WG_CONF_TEMPLATE", flagWgConfTemplate), "Path to custom wg.conf template.")
@ -91,27 +89,25 @@ func init() {
flag.StringVar(&flagSubnetRanges, "subnet-ranges", util.LookupEnvOrString("SUBNET_RANGES", flagSubnetRanges), "IP ranges to choose from when assigning an IP for a client.")
flag.IntVar(&flagSessionMaxDuration, "session-max-duration", util.LookupEnvOrInt("SESSION_MAX_DURATION", flagSessionMaxDuration), "Max time in days a remembered session is refreshed and valid.")
// Handle SMTP password, Sendgrid API key and session secret.
var (
smtpPasswordLookup = util.LookupEnvOrString("SMTP_PASSWORD", flagSmtpPassword)
sendgridApiKeyLookup = util.LookupEnvOrString("SENDGRID_API_KEY", flagSendgridApiKey)
sessionSecretLookup = util.LookupEnvOrString("SESSION_SECRET", flagSessionSecret)
)
// check empty smtpPassword env var
if smtpPasswordLookup != "" {
flag.StringVar(&flagSmtpPassword, "smtp-password", smtpPasswordLookup, "SMTP Password")
} else {
flag.StringVar(&flagSmtpPassword, "smtp-password", util.LookupEnvOrFile("SMTP_PASSWORD_FILE", flagSmtpPassword), "SMTP Password File")
}
// check empty sendgridApiKey env var
if sendgridApiKeyLookup != "" {
flag.StringVar(&flagSendgridApiKey, "sendgrid-api-key", sendgridApiKeyLookup, "Your sendgrid api key.")
} else {
flag.StringVar(&flagSendgridApiKey, "sendgrid-api-key", util.LookupEnvOrFile("SENDGRID_API_KEY_FILE", flagSendgridApiKey), "File containing your sendgrid api key.")
}
// check empty sessionSecret env var
if sessionSecretLookup != "" {
flag.StringVar(&flagSessionSecret, "session-secret", sessionSecretLookup, "The key used to encrypt session cookies.")
} else {
@ -120,7 +116,7 @@ func init() {
flag.Parse()
// update runtime config
// Update runtime config in util package.
util.DisableLogin = flagDisableLogin
util.Proxy = flagProxy
util.BindAddress = flagBindAddress
@ -135,17 +131,21 @@ func init() {
util.SendgridApiKey = flagSendgridApiKey
util.EmailFrom = flagEmailFrom
util.EmailFromName = flagEmailFromName
// Use a stable session secret if provided; otherwise a new random value is generated each run.
util.SessionSecret = sha512.Sum512([]byte(flagSessionSecret))
util.SessionMaxDuration = int64(flagSessionMaxDuration) * 86_400 // Store in seconds
// DEBUG: Log the session secret hash for verification (remove in production)
log.Debugf("Using session secret (SHA512 hash): %x", util.SessionSecret)
util.SessionMaxDuration = int64(flagSessionMaxDuration) * 86_400 // store in seconds
util.WgConfTemplate = flagWgConfTemplate
util.BasePath = util.ParseBasePath(flagBasePath)
util.SubnetRanges = util.ParseSubnetRanges(flagSubnetRanges)
// Set log level.
lvl, _ := util.ParseLogLevel(util.LookupEnvOrString(util.LogLevel, "INFO"))
log.SetLevel(lvl)
// print only if log level is INFO or lower
// Print app information if log level is INFO or lower.
if lvl <= log.INFO {
// print app information
fmt.Println("WireGuard Manager")
fmt.Println("App Version\t:", appVersion)
fmt.Println("Git Commit\t:", gitCommit)
@ -154,10 +154,8 @@ func init() {
fmt.Println("Git Repo\t:", "https://github.com/swissmakers/wireguard-manager")
fmt.Println("Authentication\t:", !util.DisableLogin)
fmt.Println("Bind address\t:", util.BindAddress)
//fmt.Println("Sendgrid key\t:", util.SendgridApiKey)
fmt.Println("Email from\t:", util.EmailFrom)
fmt.Println("Email from name\t:", util.EmailFromName)
//fmt.Println("Session secret\t:", util.SessionSecret)
fmt.Println("Custom wg.conf\t:", util.WgConfTemplate)
fmt.Println("Base path\t:", util.BasePath+"/")
fmt.Println("Subnet ranges\t:", util.GetSubnetRangesString())
@ -165,47 +163,67 @@ func init() {
}
func main() {
// Initialize the JSON DB store.
db, err := jsondb.New("./db")
if err != nil {
panic(err)
log.Fatalf("Error initializing database: %v", err)
}
if err := db.Init(); err != nil {
panic(err)
log.Fatalf("Error initializing database: %v", err)
}
// set app extra data
extraData := make(map[string]interface{})
extraData["appVersion"] = appVersion
extraData["gitCommit"] = gitCommit
extraData["basePath"] = util.BasePath
extraData["loginDisabled"] = flagDisableLogin
// strip the "templates/" prefix from the embedded directory so files can be read by their direct name (e.g.
// "base.html" instead of "templates/base.html")
tmplDir, _ := fs.Sub(fs.FS(embeddedTemplates), "templates")
// Extra app data for templates.
extraData := map[string]interface{}{
"appVersion": appVersion,
"gitCommit": gitCommit,
"basePath": util.BasePath,
"loginDisabled": flagDisableLogin,
}
// create the wireguard config on start, if it doesn't exist
// Strip the "templates/" prefix from the embedded templates directory.
tmplDir, err := fs.Sub(embeddedTemplates, "templates")
if err != nil {
log.Fatalf("Error processing templates: %v", err)
}
// Create the WireGuard server configuration if it doesn't exist.
initServerConfig(db, tmplDir)
// Check if subnet ranges are valid for the server configuration
// Remove any non-valid CIDRs
// Validate and fix subnet ranges.
if err := util.ValidateAndFixSubnetRanges(db); err != nil {
panic(err)
log.Fatalf("Invalid subnet ranges: %v", err)
}
// Print valid ranges
if lvl, _ := util.ParseLogLevel(util.LookupEnvOrString(util.LogLevel, "INFO")); lvl <= log.INFO {
fmt.Println("Valid subnet ranges:", util.GetSubnetRangesString())
}
// register routes
// Initialize the Echo router using our optimized router.New.
app := router.New(tmplDir, extraData, util.SessionSecret)
// Additional middleware: Clear invalid session cookies from both response and request.
app.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if _, err := session.Get("session", c); err != nil {
log.Debugf("session.Get failed: %v", err)
// Clear invalid cookie in response.
cookie := &http.Cookie{
Name: "session_token",
Value: "",
Path: util.GetCookiePath(),
MaxAge: -1,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
c.SetCookie(cookie)
// Also remove the invalid cookie from the request header.
c.Request().Header.Del("Cookie")
}
return next(c)
}
})
// Register routes. (Note: The order of middleware matters.)
app.GET(util.BasePath, handler.WireGuardClients(db), handler.ValidSession, handler.RefreshSession)
// Important: Make sure that all non-GET routes check the request content type using handler.ContentTypeJson to
// mitigate CSRF attacks. This is effective, because browsers don't allow setting the Content-Type header on
// cross-origin requests.
if !util.DisableLogin {
app.GET(util.BasePath+"/login", handler.LoginPage())
app.POST(util.BasePath+"/login", handler.Login(db), handler.ContentTypeJson)
@ -219,61 +237,71 @@ func main() {
app.GET(util.BasePath+"/api/user/:username", handler.GetUser(db), handler.ValidSession)
}
// Initialize the email sender.
var sendmail emailer.Emailer
if util.SendgridApiKey != "" {
sendmail = emailer.NewSendgridApiMail(util.SendgridApiKey, util.EmailFromName, util.EmailFrom)
} else {
sendmail = emailer.NewSmtpMail(util.SmtpHostname, util.SmtpPort, util.SmtpUsername, util.SmtpPassword, util.SmtpHelo, util.SmtpNoTLSCheck, util.SmtpAuthType, util.EmailFromName, util.EmailFrom, util.SmtpEncryption)
sendmail = emailer.NewSmtpMail(util.SmtpHostname, util.SmtpPort, util.SmtpUsername, util.SmtpPassword,
util.SmtpHelo, util.SmtpNoTLSCheck, util.SmtpAuthType, util.EmailFromName, util.EmailFrom, util.SmtpEncryption)
}
// Additional API and page routes.
app.GET(util.BasePath+"/test-hash", handler.GetHashesChanges(db), handler.ValidSession)
app.GET(util.BasePath+"/_health", handler.Health())
app.GET(util.BasePath+"/favicon", handler.Favicon())
app.POST(util.BasePath+"/new-client", handler.NewClient(db), handler.ValidSession, handler.ContentTypeJson)
app.POST(util.BasePath+"/update-client", handler.UpdateClient(db), handler.ValidSession, handler.ContentTypeJson)
app.POST(util.BasePath+"/email-client", handler.EmailClient(db, sendmail, defaultEmailSubject, defaultEmailContent), handler.ValidSession, handler.ContentTypeJson)
app.POST(util.BasePath+"/email-client", handler.EmailClient(db, sendmail, defaultEmailSubject, defaultEmailContent),
handler.ValidSession, handler.ContentTypeJson)
app.POST(util.BasePath+"/client/set-status", handler.SetClientStatus(db), handler.ValidSession, handler.ContentTypeJson)
app.POST(util.BasePath+"/remove-client", handler.RemoveClient(db), handler.ValidSession, handler.ContentTypeJson)
app.GET(util.BasePath+"/download", handler.DownloadClient(db), handler.ValidSession)
app.GET(util.BasePath+"/wg-server", handler.WireGuardServer(db), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin)
app.POST(util.BasePath+"/wg-server/interfaces", handler.WireGuardServerInterfaces(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.POST(util.BasePath+"/wg-server/keypair", handler.WireGuardServerKeyPair(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.GET(util.BasePath+"/global-settings", handler.GlobalSettings(db), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin)
app.POST(util.BasePath+"/global-settings", handler.GlobalSettingSubmit(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.POST(util.BasePath+"/wg-server/interfaces", handler.WireGuardServerInterfaces(db),
handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.POST(util.BasePath+"/wg-server/keypair", handler.WireGuardServerKeyPair(db),
handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.GET(util.BasePath+"/global-settings", handler.GlobalSettings(db),
handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin)
app.POST(util.BasePath+"/global-settings", handler.GlobalSettingSubmit(db),
handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.GET(util.BasePath+"/status", handler.Status(db), handler.ValidSession, handler.RefreshSession)
app.GET(util.BasePath+"/api/clients", handler.GetClients(db), handler.ValidSession)
app.GET(util.BasePath+"/api/client/:id", handler.GetClient(db), handler.ValidSession)
app.GET(util.BasePath+"/api/machine-ips", handler.MachineIPAddresses(), handler.ValidSession)
app.GET(util.BasePath+"/api/subnet-ranges", handler.GetOrderedSubnetRanges(), handler.ValidSession)
app.GET(util.BasePath+"/api/suggest-client-ips", handler.SuggestIPAllocation(db), handler.ValidSession)
app.POST(util.BasePath+"/api/apply-wg-config", handler.ApplyServerConfig(db, tmplDir), handler.ValidSession, handler.ContentTypeJson)
app.POST(util.BasePath+"/api/apply-wg-config", handler.ApplyServerConfig(db, tmplDir),
handler.ValidSession, handler.ContentTypeJson)
// strip the "assets/" prefix from the embedded directory so files can be called directly without the "assets/"
// prefix
assetsDir, _ := fs.Sub(fs.FS(embeddedAssets), "assets")
// Serve static files from the embedded assets.
assetsDir, err := fs.Sub(embeddedAssets, "assets")
if err != nil {
log.Fatalf("Error processing assets: %v", err)
}
assetHandler := http.FileServer(http.FS(assetsDir))
// serves other static files
app.GET(util.BasePath+"/static/*", echo.WrapHandler(http.StripPrefix(util.BasePath+"/static/", assetHandler)))
// Listen on the appropriate socket.
if strings.HasPrefix(util.BindAddress, "unix://") {
// Listen on unix domain socket.
// https://github.com/labstack/echo/issues/830
err := syscall.Unlink(util.BindAddress[6:])
if err != nil {
app.Logger.Fatalf("Cannot unlink unix socket: Error: %v", err)
// For Unix domain sockets.
if err := syscall.Unlink(util.BindAddress[6:]); err != nil {
app.Logger.Fatalf("Cannot unlink unix socket: %v", err)
}
l, err := net.Listen("unix", util.BindAddress[6:])
if err != nil {
app.Logger.Fatalf("Cannot create unix socket. Error: %v", err)
app.Logger.Fatalf("Cannot create unix socket: %v", err)
}
app.Listener = l
app.Logger.Fatal(app.Start(""))
} else {
// Listen on TCP socket
// For TCP sockets.
app.Logger.Fatal(app.Start(util.BindAddress))
}
}
// initServerConfig creates the WireGuard config file if it doesn't exist.
func initServerConfig(db store.IStore, tmplDir fs.FS) {
settings, err := db.GetGlobalSettings()
if err != nil {
@ -281,7 +309,7 @@ func initServerConfig(db store.IStore, tmplDir fs.FS) {
}
if _, err := os.Stat(settings.ConfigFilePath); err == nil {
// file exists, don't overwrite it implicitly
// Config file exists; do not overwrite.
return
}
@ -300,9 +328,7 @@ func initServerConfig(db store.IStore, tmplDir fs.FS) {
log.Fatalf("Cannot get user config: %v", err)
}
// write config file
err = util.WriteWireGuardServerConfig(tmplDir, server, clients, users, settings)
if err != nil {
if err := util.WriteWireGuardServerConfig(tmplDir, server, clients, users, settings); err != nil {
log.Fatalf("Cannot create server config: %v", err)
}
}

View File

@ -4,34 +4,76 @@ import (
"time"
)
// Client model
// Client represents a WireGuard client configuration.
type Client struct {
ID string `json:"id"`
PrivateKey string `json:"private_key"`
PublicKey string `json:"public_key"`
PresharedKey string `json:"preshared_key"`
Name string `json:"name"`
Email string `json:"email"`
SubnetRanges []string `json:"subnet_ranges,omitempty"`
AllocatedIPs []string `json:"allocated_ips"`
AllowedIPs []string `json:"allowed_ips"`
ExtraAllowedIPs []string `json:"extra_allowed_ips"`
Endpoint string `json:"endpoint"`
AdditionalNotes string `json:"additional_notes"`
UseServerDNS bool `json:"use_server_dns"`
Enabled bool `json:"enabled"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
// ID is a unique identifier for the client.
ID string `json:"id"`
// PrivateKey is the client's private key used for encryption.
// It may be empty if the public key was provided externally.
PrivateKey string `json:"private_key"`
// PublicKey is the client's public key.
PublicKey string `json:"public_key"`
// PresharedKey is an optional key used to enhance security.
PresharedKey string `json:"preshared_key"`
// Name is the friendly name assigned to the client.
Name string `json:"name"`
// Email is the email address associated with the client.
Email string `json:"email"`
// SubnetRanges holds the names of subnet ranges from which the clients IPs were allocated.
// This field is omitted from JSON output if empty.
SubnetRanges []string `json:"subnet_ranges,omitempty"`
// AllocatedIPs is the list of IP addresses allocated to the client.
AllocatedIPs []string `json:"allocated_ips"`
// AllowedIPs defines the CIDR ranges that are allowed to route traffic.
AllowedIPs []string `json:"allowed_ips"`
// ExtraAllowedIPs defines additional CIDR ranges allowed for routing.
ExtraAllowedIPs []string `json:"extra_allowed_ips"`
// Endpoint specifies the client's endpoint configuration.
Endpoint string `json:"endpoint"`
// AdditionalNotes are optional notes or comments about the client.
AdditionalNotes string `json:"additional_notes"`
// UseServerDNS indicates whether the client should use the server's DNS settings.
UseServerDNS bool `json:"use_server_dns"`
// Enabled indicates if the client is currently active.
Enabled bool `json:"enabled"`
// CreatedAt is the timestamp when the client was created.
CreatedAt time.Time `json:"created_at"`
// UpdatedAt is the timestamp of the client's last update.
UpdatedAt time.Time `json:"updated_at"`
}
// ClientData includes the Client and extra data
// ClientData wraps a Client with additional related data.
type ClientData struct {
// Client holds the client's configuration.
Client *Client
// QRCode is a base64-encoded representation of the client's configuration QR code.
QRCode string
}
// QRCodeSettings defines options for generating a QR code for a client.
type QRCodeSettings struct {
Enabled bool
// Enabled indicates whether QR code generation is enabled.
Enabled bool
// IncludeDNS specifies whether DNS settings should be included in the QR code.
IncludeDNS bool
// IncludeMTU specifies whether MTU settings should be included in the QR code.
IncludeMTU bool
}

View File

@ -1,9 +1,10 @@
package model
// ClientDefaults Defaults for creation of new clients used in the templates
// ClientDefaults holds the default settings for creating new clients.
// These defaults are used in the templates when rendering client creation forms.
type ClientDefaults struct {
AllowedIps []string
ExtraAllowedIps []string
UseServerDNS bool
EnableAfterCreation bool
AllowedIPs []string `json:"allowed_ips"` // Default allowed IP ranges.
ExtraAllowedIPs []string `json:"extra_allowed_ips"` // Additional allowed IP ranges.
UseServerDNS bool `json:"use_server_dns"` // Whether to use the server's DNS settings.
EnableAfterCreation bool `json:"enable_after_creation"` // Whether the client is enabled immediately after creation.
}

View File

@ -1,20 +1,22 @@
package model
// Interface model
// Interface represents a network interface with its name and IP address.
type Interface struct {
Name string `json:"name"`
IPAddress string `json:"ip_address"`
Name string `json:"name"` // Name of the interface (e.g., "eth0").
IPAddress string `json:"ip_address"` // IP address assigned to the interface.
}
// BaseData struct to pass value to the base template
// BaseData contains common data to be passed to templates.
// This includes the current active page, the current user's name, and whether they have admin privileges.
type BaseData struct {
Active string
CurrentUser string
Admin bool
Active string // The currently active page or section.
CurrentUser string // The username of the currently logged-in user.
Admin bool // Flag indicating if the current user has admin privileges.
}
// ClientServerHashes struct, to save hashes to detect changes
// ClientServerHashes holds hash values for client and server configurations.
// These hashes are used to detect changes in the configuration data.
type ClientServerHashes struct {
Client string `json:"client"`
Server string `json:"server"`
Client string `json:"client"` // Hash for the client configuration.
Server string `json:"server"` // Hash for the server configuration.
}

View File

@ -1,28 +1,27 @@
package model
import (
"time"
)
import "time"
// Server model
// Server represents the overall WireGuard server configuration,
// containing both the key pair and the network interface settings.
type Server struct {
KeyPair *ServerKeypair
Interface *ServerInterface
KeyPair *ServerKeypair `json:"keypair"` // The server's key pair used for encryption.
Interface *ServerInterface `json:"interface"` // The server's network interface configuration.
}
// ServerKeypair model
// ServerKeypair holds the cryptographic keys for the server.
type ServerKeypair struct {
PrivateKey string `json:"private_key"`
PublicKey string `json:"public_key"`
UpdatedAt time.Time `json:"updated_at"`
PrivateKey string `json:"private_key"` // The server's private key (should be kept secret).
PublicKey string `json:"public_key"` // The corresponding public key.
UpdatedAt time.Time `json:"updated_at"` // Timestamp of the last key update.
}
// ServerInterface model
// ServerInterface contains the network interface configuration for the server.
type ServerInterface struct {
Addresses []string `json:"addresses"`
ListenPort int `json:"listen_port,string"` // ,string to get listen_port string input as int
UpdatedAt time.Time `json:"updated_at"`
PostUp string `json:"post_up"`
PreDown string `json:"pre_down"`
PostDown string `json:"post_down"`
Addresses []string `json:"addresses"` // CIDR addresses assigned to the interface.
ListenPort int `json:"listen_port,string"` // Port on which the server listens (input as string in JSON, converted to int).
UpdatedAt time.Time `json:"updated_at"` // Timestamp of the last update to the interface configuration.
PostUp string `json:"post_up"` // Command to run after the interface is brought up.
PreDown string `json:"pre_down"` // Command to run before the interface is brought down.
PostDown string `json:"post_down"` // Command to run after the interface is brought down.
}

View File

@ -1,17 +1,16 @@
package model
import (
"time"
)
import "time"
// GlobalSetting model
// GlobalSetting represents the global configuration settings for the WireGuard server.
// Note: Some numeric values (e.g., MTU, PersistentKeepalive) are expected as strings in JSON.
type GlobalSetting struct {
EndpointAddress string `json:"endpoint_address"`
DNSServers []string `json:"dns_servers"`
MTU int `json:"mtu,string"`
PersistentKeepalive int `json:"persistent_keepalive,string"`
FirewallMark string `json:"firewall_mark"`
Table string `json:"table"`
ConfigFilePath string `json:"config_file_path"`
UpdatedAt time.Time `json:"updated_at"`
EndpointAddress string `json:"endpoint_address"` // The external endpoint address of the WireGuard server.
DNSServers []string `json:"dns_servers"` // List of DNS servers for client configuration.
MTU int `json:"mtu,string"` // Maximum Transmission Unit; JSON provides this value as a string.
PersistentKeepalive int `json:"persistent_keepalive,string"` // Keepalive interval (seconds); provided as a string in JSON.
FirewallMark string `json:"firewall_mark"` // Firewall mark used for routing.
Table string `json:"table"` // Routing table identifier.
ConfigFilePath string `json:"config_file_path"` // File path where the WireGuard config is generated.
UpdatedAt time.Time `json:"updated_at"` // Timestamp of the last update to the settings.
}

View File

@ -1,10 +1,29 @@
package model
// User model
import (
"encoding/json"
)
// User represents a user in the system.
// Note: The PasswordHash field takes precedence over Password.
type User struct {
Username string `json:"username"`
Password string `json:"password"`
// PasswordHash takes precedence over Password.
PasswordHash string `json:"password_hash"`
Username string `json:"username"`
Password string `json:"password"` // Used for binding input only.
PasswordHash string `json:"password_hash"` // Preferred field for authentication.
Admin bool `json:"admin"`
}
// MarshalJSON customizes the JSON encoding for User.
// It omits the plain-text Password field when marshalling, so that sensitive data is not leaked.
func (u User) MarshalJSON() ([]byte, error) {
// Define an alias to avoid infinite recursion.
type Alias User
return json.Marshal(&struct {
Password string `json:"password,omitempty"`
*Alias
}{
// Always output an empty string for Password.
Password: "",
Alias: (*Alias)(&u),
})
}

View File

@ -4,7 +4,7 @@ import (
"errors"
"io"
"io/fs"
"reflect"
"net/http"
"strings"
"text/template"
@ -16,30 +16,29 @@ import (
"github.com/swissmakers/wireguard-manager/util"
)
// TemplateRegistry is a custom html/template renderer for Echo framework
// TemplateRegistry is a custom html/template renderer for the Echo framework.
type TemplateRegistry struct {
templates map[string]*template.Template
extraData map[string]interface{}
}
// Render e.Renderer interface
// Render implements the e.Renderer interface.
// It injects extra data into the template if data is a map.
func (t *TemplateRegistry) Render(w io.Writer, name string, data interface{}, c echo.Context) error {
tmpl, ok := t.templates[name]
if !ok {
err := errors.New("Template not found -> " + name)
return err
return errors.New("Template not found -> " + name)
}
// inject more app data information. E.g. appVersion
if reflect.TypeOf(data).Kind() == reflect.Map {
// Inject extra app data if data is a map.
if m, ok := data.(map[string]interface{}); ok {
for k, v := range t.extraData {
data.(map[string]interface{})[k] = v
m[k] = v
}
data.(map[string]interface{})["client_defaults"] = util.ClientDefaultsFromEnv()
m["client_defaults"] = util.ClientDefaultsFromEnv()
}
// login page does not need the base layout
// For the login page, no base layout is needed.
if name == "login.html" {
return tmpl.Execute(w, data)
}
@ -47,85 +46,115 @@ func (t *TemplateRegistry) Render(w io.Writer, name string, data interface{}, c
return tmpl.ExecuteTemplate(w, "base.html", data)
}
// New function
// New creates and configures an Echo router.
// It initializes the session store, loads templates from the provided fs.FS,
// sets up logging and validation, and returns the Echo instance.
func New(tmplDir fs.FS, extraData map[string]interface{}, secret [64]byte) *echo.Echo {
e := echo.New()
cookiePath := util.GetCookiePath()
//cookiePath := util.GetCookiePath()
//cookieStore := sessions.NewCookieStore(secret[:32], secret[32:])
//cookieStore.Options.Path = cookiePath
//cookieStore.Options.HttpOnly = true
//cookieStore.MaxAge(86400 * 7)
cookieStore := sessions.NewCookieStore(secret[:32], secret[32:])
cookieStore.Options.Path = cookiePath
cookieStore.Options.Path = util.GetCookiePath()
cookieStore.Options.HttpOnly = true
cookieStore.MaxAge(86400 * 7)
e.Use(session.Middleware(cookieStore))
// read html template file to string
// --- New middleware: Clear invalid session cookies ---
// If session.Get fails (e.g. due to securecookie errors),
// we clear the "session_token" cookie so that new sessions can be generated.
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if _, err := session.Get("session", c); err != nil {
log.Debugf("session.Get failed: %v", err)
// Clear the invalid session cookie.
cookie := &http.Cookie{
Name: "session_token",
Value: "",
Path: util.GetCookiePath(),
MaxAge: -1,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
c.SetCookie(cookie)
} else {
log.Debug("Session retrieved successfully")
}
return next(c)
}
})
// --- End new middleware ---
// Load HTML template files as strings.
tmplBaseString, err := util.StringFromEmbedFile(tmplDir, "base.html")
if err != nil {
log.Fatal(err)
}
tmplLoginString, err := util.StringFromEmbedFile(tmplDir, "login.html")
if err != nil {
log.Fatal(err)
}
tmplProfileString, err := util.StringFromEmbedFile(tmplDir, "profile.html")
if err != nil {
log.Fatal(err)
}
tmplClientsString, err := util.StringFromEmbedFile(tmplDir, "clients.html")
if err != nil {
log.Fatal(err)
}
tmplServerString, err := util.StringFromEmbedFile(tmplDir, "server.html")
if err != nil {
log.Fatal(err)
}
tmplGlobalSettingsString, err := util.StringFromEmbedFile(tmplDir, "global_settings.html")
if err != nil {
log.Fatal(err)
}
tmplUsersSettingsString, err := util.StringFromEmbedFile(tmplDir, "users_settings.html")
if err != nil {
log.Fatal(err)
}
tmplStatusString, err := util.StringFromEmbedFile(tmplDir, "status.html")
if err != nil {
log.Fatal(err)
}
// create template list
// Create a function map for templates.
funcs := template.FuncMap{
"StringsJoin": strings.Join,
}
templates := make(map[string]*template.Template)
templates["login.html"] = template.Must(template.New("login").Funcs(funcs).Parse(tmplLoginString))
templates["profile.html"] = template.Must(template.New("profile").Funcs(funcs).Parse(tmplBaseString + tmplProfileString))
templates["clients.html"] = template.Must(template.New("clients").Funcs(funcs).Parse(tmplBaseString + tmplClientsString))
templates["server.html"] = template.Must(template.New("server").Funcs(funcs).Parse(tmplBaseString + tmplServerString))
templates["global_settings.html"] = template.Must(template.New("global_settings").Funcs(funcs).Parse(tmplBaseString + tmplGlobalSettingsString))
templates["users_settings.html"] = template.Must(template.New("users_settings").Funcs(funcs).Parse(tmplBaseString + tmplUsersSettingsString))
templates["status.html"] = template.Must(template.New("status").Funcs(funcs).Parse(tmplBaseString + tmplStatusString))
// Build the map of templates.
templates := map[string]*template.Template{
"login.html": template.Must(template.New("login").Funcs(funcs).Parse(tmplLoginString)),
"profile.html": template.Must(template.New("profile").Funcs(funcs).Parse(tmplBaseString + tmplProfileString)),
"clients.html": template.Must(template.New("clients").Funcs(funcs).Parse(tmplBaseString + tmplClientsString)),
"server.html": template.Must(template.New("server").Funcs(funcs).Parse(tmplBaseString + tmplServerString)),
"global_settings.html": template.Must(template.New("global_settings").Funcs(funcs).Parse(tmplBaseString + tmplGlobalSettingsString)),
"users_settings.html": template.Must(template.New("users_settings").Funcs(funcs).Parse(tmplBaseString + tmplUsersSettingsString)),
"status.html": template.Must(template.New("status").Funcs(funcs).Parse(tmplBaseString + tmplStatusString)),
}
// Parse the log level from environment (default INFO).
lvl, err := util.ParseLogLevel(util.LookupEnvOrString(util.LogLevel, "INFO"))
if err != nil {
log.Fatal(err)
}
// Configure the logger middleware.
logConfig := middleware.DefaultLoggerConfig
logConfig.Skipper = func(c echo.Context) bool {
resp := c.Response()
if resp.Status >= 500 && lvl > log.ERROR { // do not log if response is 5XX but log level is higher than ERROR
if resp.Status >= 500 && lvl > log.ERROR {
return true
} else if resp.Status >= 400 && lvl > log.WARN { // do not log if response is 4XX but log level is higher than WARN
} else if resp.Status >= 400 && lvl > log.WARN {
return true
} else if lvl > log.DEBUG { // do not log if log level is higher than DEBUG
} else if lvl > log.DEBUG {
return true
}
return false
@ -135,8 +164,8 @@ func New(tmplDir fs.FS, extraData map[string]interface{}, secret [64]byte) *echo
e.Pre(middleware.RemoveTrailingSlash())
e.Use(middleware.LoggerWithConfig(logConfig))
e.HideBanner = true
e.HidePort = lvl > log.INFO // hide the port output if the log level is higher than INFO
e.Validator = NewValidator()
e.HidePort = lvl > log.INFO
e.Validator = NewValidator() // Assume NewValidator is defined elsewhere.
e.Renderer = &TemplateRegistry{
templates: templates,
extraData: extraData,

View File

@ -1,20 +1,22 @@
package router
import "gopkg.in/go-playground/validator.v9"
import (
"gopkg.in/go-playground/validator.v9"
)
// NewValidator func
// Validator is a custom validator that wraps the go-playground validator.
type Validator struct {
validator *validator.Validate
}
// Validate validates the given struct and returns an error if any validation constraints fail.
func (v *Validator) Validate(i interface{}) error {
return v.validator.Struct(i)
}
// NewValidator creates and returns a new instance of Validator.
func NewValidator() *Validator {
return &Validator{
validator: validator.New(),
}
}
// Validator struct
type Validator struct {
validator *validator.Validate
}
// Validate func
func (v *Validator) Validate(i interface{}) error {
return v.validator.Struct(i)
}

View File

@ -4,22 +4,36 @@ import (
"github.com/swissmakers/wireguard-manager/model"
)
// IStore defines the interface for data storage used in the application.
// It abstracts the methods for user management, server configuration,
// client management, and hash tracking.
type IStore interface {
// Initialization
Init() error
// User Management
GetUsers() ([]model.User, error)
GetUserByName(username string) (model.User, error)
SaveUser(user model.User) error
DeleteUser(username string) error
// Global Settings and Server Configuration
GetGlobalSettings() (model.GlobalSetting, error)
GetServer() (model.Server, error)
SaveServerInterface(serverInterface model.ServerInterface) error
SaveServerKeyPair(serverKeyPair model.ServerKeypair) error
SaveGlobalSettings(globalSettings model.GlobalSetting) error
// Client Management
GetClients(hasQRCode bool) ([]model.ClientData, error)
GetClientByID(clientID string, qrCode model.QRCodeSettings) (model.ClientData, error)
SaveClient(client model.Client) error
DeleteClient(clientID string) error
SaveServerInterface(serverInterface model.ServerInterface) error
SaveServerKeyPair(serverKeyPair model.ServerKeypair) error
SaveGlobalSettings(globalSettings model.GlobalSetting) error
// File Storage Path
GetPath() string
// Hash Management for Config Change Detection
SaveHashes(hashes model.ClientServerHashes) error
GetHashes() (model.ClientServerHashes, error)
}

View File

@ -27,7 +27,7 @@
<!-- START: On page css -->
{{template "top_css" .}}
<!-- END: On page css -->
<style>
<style>
/* Base Dark Mode Styles */
body, .content-wrapper {
background-color: #121212;
@ -78,9 +78,9 @@
}
/* Modify inputs and form elements */
input, select, textarea, .form-control, .form-control:disabled, div.tagsinput {
background-color: #333333;
background-color: #333333 !important;
color: #e0e0e0 !important;
border: 1px solid #555;
border: 1px solid #555 !important;
}
input::placeholder, select::placeholder, textarea::placeholder {
color: #b0b0b0;
@ -123,7 +123,7 @@
color: #e0e0e0;
border: 1px solid #444;
}
</style>
</style>
</head>
<body class="hold-transition sidebar-mini">
@ -165,11 +165,9 @@
<!-- Right navbar links -->
<div class="navbar-nav ml-auto">
<button style="margin-left: 0.5em;" type="button" class="btn btn-outline-primary btn-sm" data-toggle="modal"
data-target="#modal_new_client"><i class="nav-icon fas fa-plus"></i> New
Client</button>
data-target="#modal_new_client"><i class="nav-icon fas fa-plus"></i> New Client</button>
<button id="apply-config-button" style="margin-left: 0.5em; display: none;" type="button" class="btn btn-outline-danger btn-sm" data-toggle="modal"
data-target="#modal_apply_config"><i class="nav-icon fas fa-check"></i> Apply
Config</button>
data-target="#modal_apply_config"><i class="nav-icon fas fa-check"></i> Apply Config</button>
{{if .baseData.CurrentUser}}
<button onclick="location.href='{{.basePath}}/logout';" style="margin-left: 0.5em;" type="button"
class="btn btn-outline-danger btn-sm"><i class="nav-icon fas fa-sign-out-alt"></i> Logout</button>
@ -194,13 +192,7 @@
</div>
<div class="info">
{{if .baseData.CurrentUser}}
{{if .baseData.Admin}}
<a href="{{.basePath}}/profile" class="d-block">My Account: {{.baseData.CurrentUser}}</a>
{{else}}
<a href="{{.basePath}}/profile" class="d-block">My Account: {{.baseData.CurrentUser}}</a>
{{end}}
{{else}}
<a href="#" class="d-block">My Account</a>
{{end}}
@ -214,17 +206,13 @@
<li class="nav-item">
<a href="{{.basePath}}/" class="nav-link {{if eq .baseData.Active ""}}active{{end}}">
<i class="nav-icon fas fa-user-secret"></i>
<p>
VPN Clients
</p>
<p>VPN Clients</p>
</a>
</li>
<li class="nav-item">
<a href="{{.basePath}}/status" class="nav-link {{if eq .baseData.Active "status" }}active{{end}}">
<i class="nav-icon fas fa-signal"></i>
<p>
VPN Connected
</p>
<p>VPN Connected</p>
</a>
</li>
{{if .baseData.Admin}}
@ -232,26 +220,20 @@
<li class="nav-item">
<a href="{{.basePath}}/wg-server" class="nav-link {{if eq .baseData.Active "wg-server" }}active{{end}}">
<i class="nav-icon fas fa-server"></i>
<p>
WireGuard Server
</p>
<p>WireGuard Server</p>
</a>
</li>
<li class="nav-item">
<a href="{{.basePath}}/global-settings" class="nav-link {{if eq .baseData.Active "global-settings" }}active{{end}}">
<i class="nav-icon fas fa-cog"></i>
<p>
Client Config Settings
</p>
<p>Client Config Settings</p>
</a>
</li>
{{if not .loginDisabled}}
<li class="nav-item">
<a href="{{.basePath}}/users-settings" class="nav-link {{if eq .baseData.Active "users-settings" }}active{{end}}">
<i class="nav-icon fas fa-cog"></i>
<p>
WGM User Accounts
</p>
<i class="nav-icon fas fa-cog"></i>
<p>WGM User Accounts</p>
</a>
</li>
{{end}}
@ -274,6 +256,7 @@
</div>
<form name="frm_new_client" id="frm_new_client">
<div class="modal-body">
<!-- Form fields for new client go here -->
<div class="form-group">
<label for="client_name" class="control-label">Name</label>
<input type="text" class="form-control" id="client_name" name="client_name">
@ -284,8 +267,7 @@
</div>
<div class="form-group">
<label for="subnet_ranges" class="control-label">Subnet range</label>
<select id="subnet_ranges" class="select2"
data-placeholder="Select a subnet range" style="width: 100%;">
<select id="subnet_ranges" class="select2" data-placeholder="Select a subnet range" style="width: 100%;">
</select>
</div>
<div class="form-group">
@ -294,22 +276,17 @@
</div>
<div class="form-group">
<label for="client_allowed_ips" class="control-label">Allowed IPs
<i class="fas fa-info-circle" data-toggle="tooltip"
data-original-title="Specify a list of addresses that will get routed to the
server. These addresses will be included in 'AllowedIPs' of client config">
<i class="fas fa-info-circle" data-toggle="tooltip" data-original-title="Specify a list of addresses that will get routed to the server. These addresses will be included in 'AllowedIPs' of client config">
</i>
</label>
<input type="text" data-role="tagsinput" class="form-control" id="client_allowed_ips"
value="{{ StringsJoin .client_defaults.AllowedIps "," }}">
<input type="text" data-role="tagsinput" class="form-control" id="client_allowed_ips" value="{{ StringsJoin .client_defaults.AllowedIPs "," }}">
</div>
<div class="form-group">
<label for="client_extra_allowed_ips" class="control-label">Extra Allowed IPs
<i class="fas fa-info-circle" data-toggle="tooltip"
data-original-title="Specify a list of addresses that will get routed to the
client. These addresses will be included in 'AllowedIPs' of WG server config">
<i class="fas fa-info-circle" data-toggle="tooltip" data-original-title="Specify a list of addresses that will get routed to the client. These addresses will be included in 'AllowedIPs' of WG server config">
</i>
</label>
<input type="text" data-role="tagsinput" class="form-control" id="client_extra_allowed_ips" value="{{ StringsJoin .client_defaults.ExtraAllowedIps "," }}">
<input type="text" data-role="tagsinput" class="form-control" id="client_extra_allowed_ips" value="{{ StringsJoin .client_defaults.ExtraAllowedIPs "," }}">
</div>
<div class="form-group">
<label for="client_endpoint" class="control-label">Endpoint</label>
@ -318,43 +295,32 @@
<div class="form-group">
<div class="icheck-primary d-inline">
<input type="checkbox" id="use_server_dns" {{ if .client_defaults.UseServerDNS }}checked{{ end }}>
<label for="use_server_dns">
Use server DNS
</label>
<label for="use_server_dns">Use server DNS</label>
</div>
</div>
<div class="form-group">
<div class="icheck-primary d-inline">
<input type="checkbox" id="enabled" {{ if .client_defaults.EnableAfterCreation }}checked{{ end }}>
<label for="enabled">
Enable after creation
</label>
<label for="enabled">Enable after creation</label>
</div>
</div>
<details>
<summary><strong>Public and Preshared Keys</strong>
<i class="fas fa-info-circle" data-toggle="tooltip"
data-original-title="If you don't want to let the server generate and store the
client's private key, you can manually specify its public and preshared key here
. Note: QR code will not be generated">
<summary>
<strong>Public and Preshared Keys</strong>
<i class="fas fa-info-circle" data-toggle="tooltip" data-original-title="If you don't want the server to generate and store the client's private key, you can manually specify its public and preshared key here. Note: QR code will not be generated">
</i>
</summary>
<div class="form-group" style="margin-top: 1rem">
<label for="client_public_key" class="control-label">
Public Key
</label>
<label for="client_public_key" class="control-label">Public Key</label>
<input type="text" class="form-control" id="client_public_key" name="client_public_key" placeholder="Autogenerated" aria-invalid="false">
</div>
<div class="form-group">
<label for="client_preshared_key" class="control-label">
Preshared Key
</label>
<label for="client_preshared_key" class="control-label">Preshared Key</label>
<input type="text" class="form-control" id="client_preshared_key" name="client_preshared_key" placeholder="Autogenerated - enter &quot;-&quot; to skip generation">
</div>
</details>
<details style="margin-top: 0.5rem;">
<summary><strong>Additional configuration</strong>
</summary>
<summary><strong>Additional configuration</strong></summary>
<div class="form-group">
<label for="additional_notes" class="control-label">Notes</label>
<textarea class="form-control" style="min-height: 6rem;" id="additional_notes" name="additional_notes" placeholder="Additional notes about this client"></textarea>
@ -419,8 +385,7 @@
<div class="float-right d-none d-sm-block">
<b>Version</b> {{ .appVersion }}
</div>
<strong>Copyright &copy; <script>document.write(new Date().getFullYear())</script> <a href="https://github.com/swissmakers/wireguard-manager">WireGuard Manager</a>.</strong> All rights
reserved.
<strong>Copyright &copy; <script>document.write(new Date().getFullYear())</script> <a href="https://github.com/swissmakers/wireguard-manager">WireGuard Manager</a>.</strong> All rights reserved.
</footer>
<!-- Control Sidebar -->
@ -463,50 +428,54 @@
`, 'toastrToastStyleFix')
toastr.options.closeDuration = 100;
// toastr.options.timeOut = 10000;
toastr.options.positionClass = 'toast-top-right-fix';
updateApplyConfigVisibility()
// Initial call, and then poll every 5 seconds for config changes
updateApplyConfigVisibility();
setInterval(updateApplyConfigVisibility, 5000);
});
function addGlobalStyle(css, id) {
if (!document.querySelector('#' + id)) {
let head = document.head
if (!head) { return }
let style = document.createElement('style')
style.type = 'text/css'
style.id = id
style.innerHTML = css
head.appendChild(style)
let head = document.head;
if (!head) { return; }
let style = document.createElement('style');
style.type = 'text/css';
style.id = id;
style.innerHTML = css;
head.appendChild(style);
}
}
function updateApplyConfigVisibility() {
$.ajax({
cache: false,
method: 'GET',
url: '{{.basePath}}/test-hash',
dataType: 'json',
contentType: "application/json",
success: function(data) {
if (data.status) {
$("#apply-config-button").show()
}
else
{
$("#apply-config-button").hide()
}
},
error: function(jqXHR, exception) {
const responseJson = jQuery.parseJSON(jqXHR.responseText);
toastr.error(responseJson['message']);
$.ajax({
cache: false,
method: 'GET',
url: '{{.basePath}}/test-hash',
dataType: 'json',
contentType: "application/json",
success: function(data) {
console.log("Config check response:", data);
// Check the 'success' property returned by the endpoint.
if (data.success) {
$("#apply-config-button").show();
} else {
$("#apply-config-button").hide();
}
});
},
error: function(jqXHR, exception) {
try {
const responseJson = JSON.parse(jqXHR.responseText);
toastr.error(responseJson['message']);
} catch (e) {
toastr.error("Error checking config changes.");
}
}
});
}
// populateClient function for render new client info
// on the client page.
// populateClient function for render new client info on the client page.
function populateClient(client_id) {
$.ajax({
cache: false,
@ -518,44 +487,42 @@
renderClientList([resp]);
},
error: function (jqXHR, exception) {
const responseJson = jQuery.parseJSON(jqXHR.responseText);
toastr.error(responseJson['message']);
try {
const responseJson = JSON.parse(jqXHR.responseText);
toastr.error(responseJson['message']);
} catch (e) {
toastr.error("Error loading client data.");
}
}
});
}
// submitNewClient function for new client form submission
// submitNewClient function for new client form submission.
function submitNewClient() {
const name = $("#client_name").val();
const email = $("#client_email").val();
const allocated_ips = $("#client_allocated_ips").val().split(",");
const allowed_ips = $("#client_allowed_ips").val().split(",");
const endpoint = $("#client_endpoint").val();
let use_server_dns = false;
let extra_allowed_ips = [];
if ($("#client_extra_allowed_ips").val() !== "") {
extra_allowed_ips = $("#client_extra_allowed_ips").val().split(",");
}
if ($("#use_server_dns").is(':checked')){
use_server_dns = true;
}
let enabled = false;
if ($("#enabled").is(':checked')){
enabled = true;
}
const use_server_dns = $("#use_server_dns").is(':checked');
const enabled = $("#enabled").is(':checked');
const public_key = $("#client_public_key").val();
const preshared_key = $("#client_preshared_key").val();
const additional_notes = $("#additional_notes").val();
const data = {"name": name, "email": email, "allocated_ips": allocated_ips, "allowed_ips": allowed_ips,
"extra_allowed_ips": extra_allowed_ips, "endpoint": endpoint, "use_server_dns": use_server_dns, "enabled": enabled,
"public_key": public_key, "preshared_key": preshared_key, "additional_notes": additional_notes};
const data = {
"name": name,
"email": email,
"allocated_ips": allocated_ips,
"allowed_ips": allowed_ips,
"extra_allowed_ips": $("#client_extra_allowed_ips").val().split(","),
"endpoint": endpoint,
"use_server_dns": use_server_dns,
"enabled": enabled,
"public_key": public_key,
"preshared_key": preshared_key,
"additional_notes": additional_notes
};
$.ajax({
cache: false,
@ -567,26 +534,30 @@
success: function(resp) {
$("#modal_new_client").modal('hide');
toastr.success('Created new client successfully');
// Update the home page (clients page) after adding successfully
// Update the home page (clients page) after adding successfully.
if (window.location.pathname === "{{.basePath}}/") {
populateClient(resp.id);
}
updateApplyConfigVisibility()
updateApplyConfigVisibility();
},
error: function(jqXHR, exception) {
const responseJson = jQuery.parseJSON(jqXHR.responseText);
toastr.error(responseJson['message']);
try {
const responseJson = JSON.parse(jqXHR.responseText);
toastr.error(responseJson['message']);
} catch (e) {
toastr.error("Error creating client.");
}
}
});
}
// updateIPAllocationSuggestion function for automatically fill
// the IP Allocation input with suggested ip addresses
// updateIPAllocationSuggestion function for automatically filling
// the IP Allocation input with suggested IP addresses.
function updateIPAllocationSuggestion(forceDefault = false) {
let subnetRange = $("#subnet_ranges").select2('val');
if (forceDefault || !subnetRange || subnetRange.length === 0) {
subnetRange = '__default_any__'
subnetRange = '__default_any__';
}
$.ajax({
cache: false,
@ -596,29 +567,33 @@
contentType: "application/json",
success: function(data) {
const allocated_ips = $("#client_allocated_ips").val().split(",");
allocated_ips.forEach(function (item, index) {
$('#client_allocated_ips').removeTag(escape(item));
})
data.forEach(function (item, index) {
allocated_ips.forEach(function (item) {
$('#client_allocated_ips').removeTag(item);
});
data.forEach(function (item) {
$('#client_allocated_ips').addTag(item);
})
});
},
error: function(jqXHR, exception) {
const allocated_ips = $("#client_allocated_ips").val().split(",");
allocated_ips.forEach(function (item, index) {
$('#client_allocated_ips').removeTag(escape(item));
})
const responseJson = jQuery.parseJSON(jqXHR.responseText);
toastr.error(responseJson['message']);
allocated_ips.forEach(function (item) {
$('#client_allocated_ips').removeTag(item);
});
try {
const responseJson = JSON.parse(jqXHR.responseText);
toastr.error(responseJson['message']);
} catch (e) {
toastr.error("Error suggesting IP allocation.");
}
}
});
}
</script>
<script>
//Initialize Select2 Elements
$(".select2").select2()
// Initialize Select2 Elements.
$(".select2").select2();
// IP Allocation tag input
// IP Allocation tag input.
$("#client_allocated_ips").tagsInput({
'width': '100%',
'height': '75%',
@ -630,7 +605,7 @@
'placeholderColor': '#666666'
});
// AllowedIPs tag input
// AllowedIPs tag input.
$("#client_allowed_ips").tagsInput({
'width': '100%',
'height': '75%',
@ -653,7 +628,7 @@
'placeholderColor': '#666666'
});
// New client form validation
// New client form validation.
$(document).ready(function () {
$.validator.setDefaults({
submitHandler: function () {
@ -676,18 +651,18 @@
error.addClass('invalid-feedback');
element.closest('.form-group').append(error);
},
highlight: function (element, errorClass, validClass) {
highlight: function (element) {
$(element).addClass('is-invalid');
},
unhighlight: function (element, errorClass, validClass) {
unhighlight: function (element) {
$(element).removeClass('is-invalid');
}
});
});
// New Client modal event
// New Client modal event.
$(document).ready(function () {
$("#modal_new_client").on('shown.bs.modal', function (e) {
$("#modal_new_client").on('shown.bs.modal', function () {
$("#client_name").val("");
$("#client_email").val("");
$("#client_public_key").val("");
@ -701,13 +676,12 @@
});
});
// handle subnet range select
$('#subnet_ranges').on('select2:select', function (e) {
// console.log('Selected Option: ', $("#subnet_ranges").select2('val'));
// Handle subnet range select.
$('#subnet_ranges').on('select2:select', function () {
updateIPAllocationSuggestion();
});
// apply_config_confirm button event
// apply_config_confirm button event.
$(document).ready(function () {
$("#apply_config_confirm").click(function () {
$.ajax({
@ -717,19 +691,22 @@
dataType: 'json',
contentType: "application/json",
success: function(data) {
updateApplyConfigVisibility()
updateApplyConfigVisibility();
$("#modal_apply_config").modal('hide');
toastr.success('Applied config successfully');
},
error: function(jqXHR, exception) {
const responseJson = jQuery.parseJSON(jqXHR.responseText);
toastr.error(responseJson['message']);
error: function(jqXHR) {
try {
const responseJson = JSON.parse(jqXHR.responseText);
toastr.error(responseJson['message']);
} catch (e) {
toastr.error("Error applying config.");
}
}
});
});
});
</script>
<!-- START: On page script -->
{{template "bottom_js" .}}
<!-- END: On page script -->

View File

@ -1,189 +1,146 @@
<!DOCTYPE html>
<html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<title>Login</title>
<!-- Tell the browser to be responsive to screen width -->
<meta name="viewport" content="width=device-width, initial-scale=1">
<!-- Favicon -->
<link rel="icon" href="{{.basePath}}/favicon">
<!-- Font Awesome -->
<link rel="stylesheet" href="{{.basePath}}/static/plugins/fontawesome-free/css/all.min.css">
<!-- icheck bootstrap -->
<link rel="stylesheet" href="{{.basePath}}/static/plugins/icheck-bootstrap/icheck-bootstrap.min.css">
<!-- Theme style -->
<link rel="stylesheet" href="{{.basePath}}/static/dist/css/adminlte.min.css">
<style>
/* Base Dark Mode Styles */
body, .content-wrapper, .login-page {
background-color: #121212;
color: #e0e0e0;
}
.main-footer {
background-color: #1c1c1c;
color: #e0e0e0;
}
.card {
background-color: #2a2a2a;
color: #e0e0e0;
}
/* Dark mode for buttons */
.btn-outline-primary {
border-color: #4e73df;
color: #4e73df;
}
.btn-outline-primary:hover {
background-color: #4e73df;
color: #ffffff;
}
.btn-outline-danger {
border-color: #e74a3b;
color: #e74a3b;
}
.btn-outline-danger:hover {
background-color: #e74a3b;
color: #ffffff;
}
/* Modify inputs and form elements */
input, select, textarea, .form-control, .form-control:disabled, div.tagsinput {
background-color: #333333;
color: #e0e0e0;
border: 1px solid #555;
}
input::placeholder, select::placeholder, textarea::placeholder {
color: #b0b0b0;
}
input[type="checkbox"], input[type="radio"] {
background-color: #444;
}
/* Modal dark mode */
.modal-content {
background-color: #2a2a2a;
color: #e0e0e0;
}
.modal-header {
border-bottom: 1px solid #555;
}
.modal-footer {
border-top: 1px solid #555;
}
/* Dark mode for the sidebar active state */
.nav-sidebar .nav-link.active {
background-color: #444;
}
/* Table dark mode */
table {
background-color: #2a2a2a;
}
table th, table td {
color: #e0e0e0;
border: 1px solid #444;
}
</style>
</head>
<body class="hold-transition login-page">
<div class="login-box">
<div class="card">
<div class="card-body login-card-body">
<p class="login-box-msg">Sign in to start your session</p>
<form action="" method="post">
<div class="input-group mb-3">
<input id="username" type="text" class="form-control" placeholder="Username">
<div class="input-group-append">
<div class="input-group-text">
<span class="fas fa-envelope"></span>
</div>
</div>
</div>
<div class="input-group mb-3">
<input id="password" type="password" class="form-control" placeholder="Password">
<div class="input-group-append">
<div class="input-group-text">
<span class="fas fa-lock"></span>
</div>
</div>
</div>
<div class="row">
<div class="col-8">
<div class="icheck-primary">
<input type="checkbox" id="remember">
<label for="remember">
Remember Me
</label>
</div>
</div>
<!-- /.col -->
<div class="col-4">
<button id="btn_login" type="submit" class="btn btn-primary btn-block">Sign In</button>
</div>
<!-- /.col -->
</div>
</form>
<div class="text-center mb-3">
<p id="message"></p>
</div>
</div>
<!-- /.login-card-body -->
</div>
</div>
<!-- /.login-box -->
<!-- jQuery -->
<script src="{{.basePath}}/static/plugins/jquery/jquery.min.js"></script>
<!-- Bootstrap 4 -->
<script src="{{.basePath}}/static/plugins/bootstrap/js/bootstrap.bundle.min.js"></script>
<!-- AdminLTE App -->
<script src="{{.basePath}}/static/dist/js/adminlte.min.js"></script>
</body>
<script>
function redirectNext() {
const urlParams = new URLSearchParams(window.location.search);
const nextURL = urlParams.get('next');
if (nextURL && /(?:^\/[a-zA-Z_])|(?:^\/$)/.test(nextURL.trim())) {
window.location.href = nextURL;
} else {
window.location.href = '/{{.basePath}}';
}
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<title>Login</title>
<!-- Responsive -->
<meta name="viewport" content="width=device-width, initial-scale=1">
<!-- Favicon -->
<link rel="icon" href="{{.basePath}}/favicon">
<!-- Font Awesome -->
<link rel="stylesheet" href="{{.basePath}}/static/plugins/fontawesome-free/css/all.min.css">
<!-- icheck bootstrap -->
<link rel="stylesheet" href="{{.basePath}}/static/plugins/icheck-bootstrap/icheck-bootstrap.min.css">
<!-- Theme style -->
<link rel="stylesheet" href="{{.basePath}}/static/dist/css/adminlte.min.css">
<style>
/* Dark Mode Styles */
.login-card-body, .register-card-body {
background-color: #2b2b2b;
}
</script>
<script>
.login-card-body, .register-card-body {
color: #cfcfcf;
}
body, .content-wrapper, .login-page {
background-color: #121212;
color: #e0e0e0;
}
/* Buttons */
.btn-outline-primary {
border-color: #4e73df;
color: #4e73df;
}
.btn-outline-primary:hover {
background-color: #4e73df;
color: #ffffff;
}
/* Form elements */
input, select, textarea, .form-control, .form-control:disabled, div.tagsinput {
background-color: #333333 !important;
color: #e0e0e0 !important;
}
input::placeholder, select::placeholder, textarea::placeholder {
color: #b0b0b0;
}
input[type="checkbox"], input[type="radio"] {
background-color: #444;
}
</style>
</head>
<body class="hold-transition login-page">
<div class="login-box">
<div class="card">
<div class="card-body login-card-body">
<p class="login-box-msg">Sign in to start your session</p>
<form id="loginForm" method="post" novalidate>
<div class="input-group mb-3">
<input id="username" type="text" class="form-control" placeholder="Username" required>
<div class="input-group-append">
<div class="input-group-text">
<span class="fas fa-envelope"></span>
</div>
</div>
</div>
<div class="input-group mb-3">
<input id="password" type="password" class="form-control" placeholder="Password" required>
<div class="input-group-append">
<div class="input-group-text">
<span class="fas fa-lock"></span>
</div>
</div>
</div>
<div class="row">
<div class="col-8">
<div class="icheck-primary">
<input type="checkbox" id="remember">
<label for="remember">Remember Me</label>
</div>
</div>
<div class="col-4">
<button id="btn_login" type="submit" class="btn btn-primary btn-block">Sign In</button>
</div>
</div>
</form>
<div class="text-center mb-3">
<p id="message"></p>
</div>
</div>
</div>
</div>
<!-- jQuery -->
<script src="{{.basePath}}/static/plugins/jquery/jquery.min.js"></script>
<!-- Bootstrap 4 -->
<script src="{{.basePath}}/static/plugins/bootstrap/js/bootstrap.bundle.min.js"></script>
<!-- AdminLTE App -->
<script src="{{.basePath}}/static/dist/js/adminlte.min.js"></script>
<script>
// Redirect based on 'next' URL parameter; default to basePath.
function redirectNext() {
const urlParams = new URLSearchParams(window.location.search);
const nextURL = urlParams.get('next');
if (nextURL && nextURL.trim().startsWith("/")) {
window.location.href = nextURL.trim();
} else {
window.location.href = '{{.basePath}}';
}
}
</script>
<script>
$(document).ready(function () {
$('form').on('submit', function(e) {
e.preventDefault();
$("#btn_login").trigger('click');
});
$("#btn_login").click(function () {
const username = $("#username").val();
const password = $("#password").val();
let rememberMe = false;
if ($("#remember").is(':checked')){
rememberMe = true;
// Override default form submission.
$("#loginForm").on('submit', function(e) {
e.preventDefault();
$("#btn_login").trigger('click');
});
$("#btn_login").click(function () {
const username = $("#username").val().trim();
const password = $("#password").val();
const rememberMe = $("#remember").is(':checked');
const data = { "username": username, "password": password, "rememberMe": rememberMe };
$.ajax({
cache: false,
method: 'POST',
url: '{{.basePath}}/login',
dataType: 'json',
contentType: "application/json",
data: JSON.stringify(data),
success: function(response) {
$("#message").html(`<p style="color:green">${response.message}</p>`);
redirectNext();
},
error: function(jqXHR) {
let response;
try {
response = JSON.parse(jqXHR.responseText);
} catch (error) {
response = { message: "An unexpected error occurred." };
}
const data = {"username": username, "password": password, "rememberMe": rememberMe}
$.ajax({
cache: false,
method: 'POST',
url: '{{.basePath}}/login',
dataType: 'json',
contentType: "application/json",
data: JSON.stringify(data),
success: function(data) {
document.getElementById("message").innerHTML = `<p style="color:green">${data['message']}</p>`;
// redirect after logging in successfully
redirectNext();
},
error: function(jqXHR, exception) {
const responseJson = jQuery.parseJSON(jqXHR.responseText);
document.getElementById("message").innerHTML = `<p style="color:#ff0000">${responseJson['message']}</p>`;
}
});
$("#message").html(`<p style="color:#ff0000">${response.message}</p>`);
}
});
});
});
</script>
</script>
</body>
</html>

View File

@ -1,4 +1,19 @@
package util
var IPToSubnetRange = map[string]uint16{}
var DBUsersToCRC32 = map[string]uint32{}
import "sync"
// IPToSubnetRange caches a mapping from an IP address (as a string) to a subnet range index.
// Note: This global map is not thread-safe by default. Use ipToSubnetRangeMutex for concurrent access.
var IPToSubnetRange = make(map[string]uint16)
// DBUsersToCRC32 caches a mapping from a username to its corresponding CRC32 hash value.
// Note: This global map is not thread-safe by default. Use dbUsersToCRC32Mutex for concurrent access.
var DBUsersToCRC32 = make(map[string]uint32)
// Mutexes to protect concurrent access to the caches.
// Use ipToSubnetRangeMutex when reading from or writing to IPToSubnetRange,
// and use dbUsersToCRC32Mutex for DBUsersToCRC32.
var (
ipToSubnetRangeMutex sync.RWMutex
//dbUsersToCRC32Mutex sync.RWMutex
)

View File

@ -7,7 +7,7 @@ import (
"github.com/labstack/gommon/log"
)
// Runtime config
// Global runtime configuration variables.
var (
DisableLogin bool
Proxy bool
@ -27,10 +27,11 @@ var (
SessionMaxDuration int64
WgConfTemplate string
BasePath string
SubnetRanges map[string]([]*net.IPNet)
SubnetRangesOrder []string
SubnetRanges map[string][]*net.IPNet // Mapping of range name to slice of *net.IPNet
SubnetRangesOrder []string // Order of subnet range names
)
// Default values and environment variable names.
const (
DefaultUsername = "admin"
DefaultPassword = "admin"
@ -40,7 +41,7 @@ const (
DefaultDNS = "1.1.1.1"
DefaultMTU = 1450
DefaultPersistentKeepalive = 15
DefaultFirewallMark = "0xca6c" // i.e. 51820
DefaultFirewallMark = "0xca6c" // e.g. 51820
DefaultTable = "auto"
DefaultConfigFilePath = "/etc/wireguard/wg0.conf"
UsernameEnvVar = "WGM_USERNAME"
@ -67,51 +68,69 @@ const (
DefaultClientEnableAfterCreationEnvVar = "WGM_DEFAULT_CLIENT_ENABLE_AFTER_CREATION"
)
// ParseBasePath ensures that the base path starts with a slash and does not end with one.
func ParseBasePath(basePath string) string {
if !strings.HasPrefix(basePath, "/") {
basePath = "/" + basePath
}
basePath = strings.TrimSuffix(basePath, "/")
return basePath
return strings.TrimSuffix(basePath, "/")
}
func ParseSubnetRanges(subnetRangesStr string) map[string]([]*net.IPNet) {
subnetRanges := map[string]([]*net.IPNet){}
// ParseSubnetRanges parses a string containing subnet ranges into a map of subnet ranges.
// The expected format is:
//
// rangeName:CIDR1,CIDR2;rangeName2:CIDR3,CIDR4
//
// It returns a map from the range name to a slice of *net.IPNet and populates SubnetRangesOrder.
func ParseSubnetRanges(subnetRangesStr string) map[string][]*net.IPNet {
subnetRanges := make(map[string][]*net.IPNet)
// Reset the global order.
SubnetRangesOrder = []string{}
if subnetRangesStr == "" {
return subnetRanges
}
cidrSet := map[string]bool{}
// Clean the input string.
subnetRangesStr = strings.TrimSpace(subnetRangesStr)
subnetRangesStr = strings.Trim(subnetRangesStr, ";:,")
ranges := strings.Split(subnetRangesStr, ";")
// Use a set to track duplicate CIDRs.
cidrSet := make(map[string]bool)
for _, rng := range ranges {
rng = strings.TrimSpace(rng)
rngSpl := strings.Split(rng, ":")
if len(rngSpl) != 2 {
parts := strings.Split(rng, ":")
if len(parts) != 2 {
log.Warnf("Unable to parse subnet range: %v. Skipped.", rng)
continue
}
rngName := strings.TrimSpace(rngSpl[0])
subnetRanges[rngName] = make([]*net.IPNet, 0)
cidrs := strings.Split(rngSpl[1], ",")
rangeName := strings.TrimSpace(parts[0])
subnetRanges[rangeName] = []*net.IPNet{}
// Split the CIDRs by comma.
cidrs := strings.Split(parts[1], ",")
for _, cidr := range cidrs {
cidr = strings.TrimSpace(cidr)
_, net, err := net.ParseCIDR(cidr)
_, ipnet, err := net.ParseCIDR(cidr)
if err != nil {
log.Warnf("[%v] Unable to parse CIDR: %v. Skipped.", rngName, cidr)
log.Warnf("[%v] Unable to parse CIDR: %v. Skipped.", rangeName, cidr)
continue
}
if cidrSet[net.String()] {
log.Warnf("[%v] CIDR already exists: %v. Skipped.", rngName, net.String())
if cidrSet[ipnet.String()] {
log.Warnf("[%v] CIDR already exists: %v. Skipped.", rangeName, ipnet.String())
continue
}
cidrSet[net.String()] = true
subnetRanges[rngName] = append(subnetRanges[rngName], net)
cidrSet[ipnet.String()] = true
subnetRanges[rangeName] = append(subnetRanges[rangeName], ipnet)
}
if len(subnetRanges[rngName]) == 0 {
delete(subnetRanges, rngName)
// Remove the range if no valid CIDRs were found.
if len(subnetRanges[rangeName]) == 0 {
delete(subnetRanges, rangeName)
} else {
SubnetRangesOrder = append(SubnetRangesOrder, rngName)
SubnetRangesOrder = append(SubnetRangesOrder, rangeName)
}
}
return subnetRanges

View File

@ -2,27 +2,37 @@ package util
import (
"encoding/base64"
"errors"
"fmt"
"golang.org/x/crypto/bcrypt"
)
const BcryptCost = 14 // Bcrypt cost factor (adjust as needed)
// HashPassword hashes the provided plaintext password using bcrypt and returns
// a base64-encoded hash. Returns an error if hashing fails or if the password is empty.
func HashPassword(plaintext string) (string, error) {
bytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), 14)
if plaintext == "" {
return "", fmt.Errorf("password cannot be empty")
}
hashed, err := bcrypt.GenerateFromPassword([]byte(plaintext), BcryptCost)
if err != nil {
return "", fmt.Errorf("cannot hash password: %w", err)
}
return base64.StdEncoding.EncodeToString(bytes), nil
return base64.StdEncoding.EncodeToString(hashed), nil
}
func VerifyHash(base64Hash string, plaintext string) (bool, error) {
// VerifyHash compares a plaintext password with a base64-encoded bcrypt hash.
// It returns true if the password matches the hash. If the password does not match,
// it returns false with no error.
func VerifyHash(base64Hash, plaintext string) (bool, error) {
hash, err := base64.StdEncoding.DecodeString(base64Hash)
if err != nil {
return false, fmt.Errorf("cannot decode base64 hash: %w", err)
}
err = bcrypt.CompareHashAndPassword(hash, []byte(plaintext))
if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
if err == bcrypt.ErrMismatchedHashAndPassword {
return false, nil
}
if err != nil {

View File

@ -21,18 +21,22 @@ import (
"time"
"github.com/chmike/domain"
"github.com/swissmakers/wireguard-manager/model"
"github.com/swissmakers/wireguard-manager/store"
"golang.org/x/mod/sumdb/dirhash"
externalip "github.com/glendc/go-external-ip"
"github.com/labstack/gommon/log"
"github.com/sdomino/scribble"
"github.com/swissmakers/wireguard-manager/model"
)
// BuildClientConfig to create wireguard client config string
//
// Client Configuration Building
//
// BuildClientConfig creates the WireGuard client configuration as a string.
func BuildClientConfig(client model.Client, server model.Server, setting model.GlobalSetting) string {
// Interface section
// [Interface] section
clientAddress := fmt.Sprintf("Address = %s\n", strings.Join(client.AllocatedIPs, ","))
clientPrivateKey := fmt.Sprintf("PrivateKey = %s\n", client.PrivateKey)
clientDNS := ""
@ -44,13 +48,12 @@ func BuildClientConfig(client model.Client, server model.Server, setting model.G
clientMTU = fmt.Sprintf("MTU = %d\n", setting.MTU)
}
// Peer section
// [Peer] section
peerPublicKey := fmt.Sprintf("PublicKey = %s\n", server.KeyPair.PublicKey)
peerPresharedKey := ""
if client.PresharedKey != "" {
peerPresharedKey = fmt.Sprintf("PresharedKey = %s\n", client.PresharedKey)
}
peerAllowedIPs := fmt.Sprintf("AllowedIPs = %s\n", strings.Join(client.AllowedIPs, ","))
desiredHost := setting.EndpointAddress
@ -65,13 +68,11 @@ func BuildClientConfig(client model.Client, server model.Server, setting model.G
}
}
peerEndpoint := fmt.Sprintf("Endpoint = %s:%d\n", desiredHost, desiredPort)
peerPersistentKeepalive := ""
if setting.PersistentKeepalive > 0 {
peerPersistentKeepalive = fmt.Sprintf("PersistentKeepalive = %d\n", setting.PersistentKeepalive)
}
// build the config as string
strConfig := "[Interface]\n" +
clientAddress +
clientPrivateKey +
@ -87,88 +88,82 @@ func BuildClientConfig(client model.Client, server model.Server, setting model.G
return strConfig
}
// ClientDefaultsFromEnv to read the default values for creating a new client from the environment or use sane defaults
// ClientDefaultsFromEnv returns default client creation values from environment variables or sane defaults.
func ClientDefaultsFromEnv() model.ClientDefaults {
clientDefaults := model.ClientDefaults{}
clientDefaults.AllowedIps = LookupEnvOrStrings(DefaultClientAllowedIpsEnvVar, []string{"0.0.0.0/0"})
clientDefaults.ExtraAllowedIps = LookupEnvOrStrings(DefaultClientExtraAllowedIpsEnvVar, []string{})
clientDefaults.UseServerDNS = LookupEnvOrBool(DefaultClientUseServerDNSEnvVar, true)
clientDefaults.EnableAfterCreation = LookupEnvOrBool(DefaultClientEnableAfterCreationEnvVar, true)
return clientDefaults
return model.ClientDefaults{
AllowedIPs: LookupEnvOrStrings(DefaultClientAllowedIpsEnvVar, []string{"0.0.0.0/0"}),
ExtraAllowedIPs: LookupEnvOrStrings(DefaultClientExtraAllowedIpsEnvVar, []string{}),
UseServerDNS: LookupEnvOrBool(DefaultClientUseServerDNSEnvVar, true),
EnableAfterCreation: LookupEnvOrBool(DefaultClientEnableAfterCreationEnvVar, true),
}
}
// ContainsCIDR to check if ipnet1 contains ipnet2
// https://stackoverflow.com/a/40406619/6111641
// https://go.dev/play/p/Q4J-JEN3sF
//
// CIDR and IP Validation
//
// ContainsCIDR returns true if ipnet1 completely contains ipnet2.
func ContainsCIDR(ipnet1, ipnet2 *net.IPNet) bool {
ones1, _ := ipnet1.Mask.Size()
ones2, _ := ipnet2.Mask.Size()
return ones1 <= ones2 && ipnet1.Contains(ipnet2.IP)
}
// ValidateCIDR to validate a network CIDR
// ValidateCIDR returns true if the given CIDR is valid.
func ValidateCIDR(cidr string) bool {
_, _, err := net.ParseCIDR(cidr)
return err == nil
}
// ValidateCIDRList to validate a list of network CIDR
// ValidateCIDRList validates a slice of CIDRs.
// If allowEmpty is true, empty strings are allowed.
func ValidateCIDRList(cidrs []string, allowEmpty bool) bool {
for _, cidr := range cidrs {
if allowEmpty {
if len(cidr) > 0 {
if !ValidateCIDR(cidr) {
return false
}
}
} else {
if !ValidateCIDR(cidr) {
return false
}
if allowEmpty && len(cidr) == 0 {
continue
}
if !ValidateCIDR(cidr) {
return false
}
}
return true
}
// ValidateAllowedIPs to validate allowed ip addresses in CIDR format
// ValidateAllowedIPs validates a list of allowed IP addresses in CIDR format.
func ValidateAllowedIPs(cidrs []string) bool {
return ValidateCIDRList(cidrs, false)
}
// ValidateExtraAllowedIPs to validate extra Allowed ip addresses, allowing empty strings
// ValidateExtraAllowedIPs validates extra allowed IPs, allowing empty strings.
func ValidateExtraAllowedIPs(cidrs []string) bool {
return ValidateCIDRList(cidrs, true)
}
// ValidateServerAddresses to validate allowed ip addresses in CIDR format
// ValidateServerAddresses validates server interface addresses in CIDR format.
func ValidateServerAddresses(cidrs []string) bool {
return ValidateCIDRList(cidrs, false)
}
// ValidateIPAddress to validate the IPv4 and IPv6 address
// ValidateIPAddress checks whether a given string is a valid IPv4 or IPv6 address.
func ValidateIPAddress(ip string) bool {
return net.ParseIP(ip) != nil
}
// ValidateDomainName to validate domain name
// ValidateDomainName checks whether a domain name is valid.
func ValidateDomainName(name string) bool {
return domain.Check(name) == nil
}
// ValidateIPAndSearchDomainAddressList to validate a list of IPv4 and IPv6 addresses plus added search domains
// ValidateIPAndSearchDomainAddressList validates a list of IP addresses followed by search domains.
func ValidateIPAndSearchDomainAddressList(entries []string) bool {
ip := false
domain := false
var ipFound, domainFound bool
for _, entry := range entries {
// ip but not after domain
if ValidateIPAddress(entry) && !domain {
ip = true
if ValidateIPAddress(entry) && !domainFound {
ipFound = true
continue
}
// domain and after ip
if ValidateDomainName(entry) && ip {
domain = true
if ValidateDomainName(entry) && ipFound {
domainFound = true
continue
}
return false
@ -176,19 +171,20 @@ func ValidateIPAndSearchDomainAddressList(entries []string) bool {
return true
}
// GetInterfaceIPs to get local machine's interface ip addresses
//
// Local and Public IP Retrieval
//
// GetInterfaceIPs returns the list of local interface IP addresses (IPv4 only).
func GetInterfaceIPs() ([]model.Interface, error) {
// get machine's interfaces
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var interfaceList []model.Interface
// get interface's ip addresses
for _, i := range ifaces {
addrs, err := i.Addrs()
for _, iface := range ifaces {
addrs, err := iface.Addrs()
if err != nil {
return nil, err
}
@ -207,43 +203,38 @@ func GetInterfaceIPs() ([]model.Interface, error) {
if ip == nil {
continue
}
iface := model.Interface{}
iface.Name = i.Name
iface.IPAddress = ip.String()
interfaceList = append(interfaceList, iface)
interfaceList = append(interfaceList, model.Interface{
Name: iface.Name,
IPAddress: ip.String(),
})
}
}
return interfaceList, err
return interfaceList, nil
}
// GetPublicIP to get machine's public ip address
// GetPublicIP returns the public IP address of the machine using an external consensus.
func GetPublicIP() (model.Interface, error) {
// set time out to 5 seconds
cfg := externalip.ConsensusConfig{}
cfg.Timeout = time.Second * 5
cfg := externalip.ConsensusConfig{Timeout: 5 * time.Second}
consensus := externalip.NewConsensus(&cfg, nil)
// add trusted voters
consensus.AddVoter(externalip.NewHTTPSource("https://checkip.amazonaws.com/"), 1)
consensus.AddVoter(externalip.NewHTTPSource("http://whatismyip.akamai.com"), 1)
consensus.AddVoter(externalip.NewHTTPSource("https://ifconfig.top"), 1)
publicInterface := model.Interface{}
publicInterface.Name = "Public Address"
publicInterface := model.Interface{Name: "Public Address"}
ip, err := consensus.ExternalIP()
if err != nil {
publicInterface.IPAddress = "N/A"
} else {
publicInterface.IPAddress = ip.String()
}
// error handling happened above, no need to pass it through
return publicInterface, nil
}
// GetIPFromCIDR get ip from CIDR
//
// IP Extraction and Allocation
//
// GetIPFromCIDR extracts the IP portion from a CIDR notation.
func GetIPFromCIDR(cidr string) (string, error) {
ip, _, err := net.ParseCIDR(cidr)
if err != nil {
@ -252,24 +243,22 @@ func GetIPFromCIDR(cidr string) (string, error) {
return ip.String(), nil
}
// GetAllocatedIPs to get all ip addresses allocated to clients and server
// GetAllocatedIPs returns all IP addresses allocated to clients and the server.
// The ignoreClientID parameter can be used to exclude a specific client.
func GetAllocatedIPs(ignoreClientID string) ([]string, error) {
allocatedIPs := make([]string, 0)
var allocatedIPs []string
// initialize database directory
dir := "./db"
db, err := scribble.New(dir, nil)
// Initialize the scribble DB.
db, err := scribble.New("./db", nil)
if err != nil {
return nil, err
}
// read server information
serverInterface := model.ServerInterface{}
// Read server interface addresses.
var serverInterface model.ServerInterface
if err := db.Read("server", "interfaces", &serverInterface); err != nil {
return nil, err
}
// append server's addresses to the result
for _, cidr := range serverInterface.Addresses {
ip, err := GetIPFromCIDR(cidr)
if err != nil {
@ -278,19 +267,16 @@ func GetAllocatedIPs(ignoreClientID string) ([]string, error) {
allocatedIPs = append(allocatedIPs, ip)
}
// read client information
// Read clients.
records, err := db.ReadAll("clients")
if err != nil {
return nil, err
}
// append client's addresses to the result
for _, f := range records {
client := model.Client{}
if err := json.Unmarshal(f, &client); err != nil {
for _, record := range records {
var client model.Client
if err := json.Unmarshal(record, &client); err != nil {
return nil, err
}
if client.ID != ignoreClientID {
for _, cidr := range client.AllocatedIPs {
ip, err := GetIPFromCIDR(cidr)
@ -305,7 +291,7 @@ func GetAllocatedIPs(ignoreClientID string) ([]string, error) {
return allocatedIPs, nil
}
// inc from https://play.golang.org/p/m8TNTtygK0
// inc increments an IP address by one.
func inc(ip net.IP) {
for j := len(ip) - 1; j >= 0; j-- {
ip[j]++
@ -315,21 +301,17 @@ func inc(ip net.IP) {
}
}
// GetBroadcastIP func to get the broadcast ip address of a network
// GetBroadcastIP computes the broadcast address of a given network.
func GetBroadcastIP(n *net.IPNet) net.IP {
var broadcast net.IP
if len(n.IP) == 4 {
broadcast = net.ParseIP("0.0.0.0").To4()
} else {
broadcast = net.ParseIP("::")
}
for i := 0; i < len(n.IP); i++ {
broadcast := make(net.IP, len(n.IP))
for i := range n.IP {
broadcast[i] = n.IP[i] | ^n.Mask[i]
}
return broadcast
}
// GetBroadcastAndNetworkAddrsLookup get the ip address that can't be used with current server interfaces
// GetBroadcastAndNetworkAddrsLookup returns a map of addresses (broadcast and network addresses)
// for the given interface addresses (CIDRs).
func GetBroadcastAndNetworkAddrsLookup(interfaceAddresses []string) map[string]bool {
list := make(map[string]bool)
for _, ifa := range interfaceAddresses {
@ -337,30 +319,24 @@ func GetBroadcastAndNetworkAddrsLookup(interfaceAddresses []string) map[string]b
if err != nil {
continue
}
broadcastAddr := GetBroadcastIP(netAddr).String()
networkAddr := netAddr.IP.String()
list[broadcastAddr] = true
list[networkAddr] = true
list[GetBroadcastIP(netAddr).String()] = true
list[netAddr.IP.String()] = true
}
return list
}
// GetAvailableIP get the ip address that can be allocated from an CIDR
// We need interfaceAddresses to find real broadcast and network addresses
// GetAvailableIP returns an available IP from the given CIDR that is not allocated and is not a network/broadcast address.
func GetAvailableIP(cidr string, allocatedList, interfaceAddresses []string) (string, error) {
ip, netAddr, err := net.ParseCIDR(cidr)
if err != nil {
return "", err
}
unavailableIPs := GetBroadcastAndNetworkAddrsLookup(interfaceAddresses)
for ip := ip.Mask(netAddr.Mask); netAddr.Contains(ip); inc(ip) {
available := true
suggestedAddr := ip.String()
for _, allocatedAddr := range allocatedList {
if suggestedAddr == allocatedAddr {
available := true
for _, allocated := range allocatedList {
if suggestedAddr == allocated {
available = false
break
}
@ -369,31 +345,22 @@ func GetAvailableIP(cidr string, allocatedList, interfaceAddresses []string) (st
return suggestedAddr, nil
}
}
return "", errors.New("no more available ip address")
}
// ValidateIPAllocation to validate the list of client's ip allocation
// They must have a correct format and available in serverAddresses space
// ValidateIPAllocation validates the client's requested IP allocation.
func ValidateIPAllocation(serverAddresses []string, ipAllocatedList []string, ipAllocationList []string) (bool, error) {
for _, clientCIDR := range ipAllocationList {
ip, _, _ := net.ParseCIDR(clientCIDR)
// clientCIDR must be in CIDR format
if ip == nil {
return false, fmt.Errorf("invalid ip allocation input %s. Must be in CIDR format", clientCIDR)
}
// return false immediately if the ip is already in use (in ipAllocatedList)
for _, item := range ipAllocatedList {
if item == ip.String() {
for _, allocated := range ipAllocatedList {
if allocated == ip.String() {
return false, fmt.Errorf("IP %s already allocated", ip)
}
}
// even if it is not in use, we still need to check if it
// belongs to a network of the server.
var isValid = false
var isValid bool
for _, serverCIDR := range serverAddresses {
_, serverNet, _ := net.ParseCIDR(serverCIDR)
if serverNet.Contains(ip) {
@ -401,49 +368,62 @@ func ValidateIPAllocation(serverAddresses []string, ipAllocatedList []string, ip
break
}
}
// current ip allocation is valid, check the next one
if isValid {
continue
} else {
if !isValid {
return false, fmt.Errorf("IP %s does not belong to any network addresses of WireGuard server", ip)
}
}
return true, nil
}
// findSubnetRangeForIP to find first SR for IP, and cache the match
//
// Subnet Ranges and Client Data Helpers
//
// findSubnetRangeForIP finds the subnet range for a given CIDR.
// It uses a cache (IPToSubnetRange) and the global SubnetRanges and SubnetRangesOrder.
func findSubnetRangeForIP(cidr string) (uint16, error) {
// Parse the provided CIDR.
ip, _, err := net.ParseCIDR(cidr)
if err != nil {
return 0, err
}
ipStr := ip.String()
if srName, ok := IPToSubnetRange[ip.String()]; ok {
return srName, nil
// Check the cache first using a read lock.
ipToSubnetRangeMutex.RLock()
if sr, ok := IPToSubnetRange[ipStr]; ok {
ipToSubnetRangeMutex.RUnlock()
return sr, nil
}
ipToSubnetRangeMutex.RUnlock()
for srIndex, sr := range SubnetRangesOrder {
for _, srCIDR := range SubnetRanges[sr] {
if srCIDR.Contains(ip) {
IPToSubnetRange[ip.String()] = uint16(srIndex)
return uint16(srIndex), nil
// Iterate over the global SubnetRangesOrder to compute the subnet range index.
for index, srName := range SubnetRangesOrder {
cidrList, ok := SubnetRanges[srName]
if !ok {
continue
}
// For each CIDR in the current subnet range, check if it contains the IP.
for _, ipnet := range cidrList {
if ipnet.Contains(ip) {
// Lock for writing and store the computed index in the cache.
ipToSubnetRangeMutex.Lock()
IPToSubnetRange[ipStr] = uint16(index)
ipToSubnetRangeMutex.Unlock()
return uint16(index), nil
}
}
}
return 0, fmt.Errorf("subnet range not found for this IP")
return 0, fmt.Errorf("subnet range not found for IP %s", ipStr)
}
// FillClientSubnetRange to fill subnet ranges client belongs to, does nothing if SRs are not found
// FillClientSubnetRange appends the subnet range names to the client data.
func FillClientSubnetRange(client model.ClientData) model.ClientData {
cl := *client.Client
for _, ip := range cl.AllocatedIPs {
sr, err := findSubnetRangeForIP(ip)
if err != nil {
continue
if sr, err := findSubnetRangeForIP(ip); err == nil {
cl.SubnetRanges = append(cl.SubnetRanges, SubnetRangesOrder[sr])
}
cl.SubnetRanges = append(cl.SubnetRanges, SubnetRangesOrder[sr])
}
return model.ClientData{
Client: &cl,
@ -451,13 +431,11 @@ func FillClientSubnetRange(client model.ClientData) model.ClientData {
}
}
// ValidateAndFixSubnetRanges to check if subnet ranges are valid for the server configuration
// Removes all non-valid CIDRs
// ValidateAndFixSubnetRanges checks and removes non-valid CIDRs from the global SubnetRanges.
func ValidateAndFixSubnetRanges(db store.IStore) error {
if len(SubnetRangesOrder) == 0 {
return nil
}
server, err := db.GetServer()
if err != nil {
return err
@ -471,28 +449,24 @@ func ValidateAndFixSubnetRanges(db store.IStore) error {
}
serverSubnets = append(serverSubnets, netAddr)
}
for _, rng := range SubnetRangesOrder {
cidrs := SubnetRanges[rng]
if len(cidrs) > 0 {
newCIDRs := make([]*net.IPNet, 0)
for _, cidr := range cidrs {
valid := false
for _, serverSubnet := range serverSubnets {
if ContainsCIDR(serverSubnet, cidr) {
valid = true
break
}
}
if valid {
newCIDRs = append(newCIDRs, cidr)
} else {
log.Warnf("[%v] CIDR is outside of all server subnets: %v. Removed.", rng, cidr)
}
}
if len(newCIDRs) > 0 {
SubnetRanges[rng] = newCIDRs
} else {
@ -501,95 +475,83 @@ func ValidateAndFixSubnetRanges(db store.IStore) error {
}
}
}
return nil
}
// GetSubnetRangesString to get a formatted string, representing active subnet ranges
// GetSubnetRangesString returns a formatted string representing active subnet ranges.
func GetSubnetRangesString() string {
if len(SubnetRangesOrder) == 0 {
return ""
}
strB := strings.Builder{}
var sb strings.Builder
for _, rng := range SubnetRangesOrder {
cidrs := SubnetRanges[rng]
if len(cidrs) > 0 {
strB.WriteString(rng)
strB.WriteString(":[")
first := true
for _, cidr := range cidrs {
if !first {
strB.WriteString(", ")
sb.WriteString(rng)
sb.WriteString(":[")
for i, cidr := range cidrs {
if i > 0 {
sb.WriteString(", ")
}
strB.WriteString(cidr.String())
first = false
sb.WriteString(cidr.String())
}
strB.WriteString("] ")
sb.WriteString("] ")
}
}
return strings.TrimSpace(strB.String())
return strings.TrimSpace(sb.String())
}
// WriteWireGuardServerConfig to write WireGuard server config. e.g. wg0.conf
//
// WireGuard Server Configuration File
//
// WriteWireGuardServerConfig writes the WireGuard server configuration (wg.conf) using a template.
// If WgConfTemplate is set, it is used; otherwise, a default embedded template is read.
func WriteWireGuardServerConfig(tmplDir fs.FS, serverConfig model.Server, clientDataList []model.ClientData, usersList []model.User, globalSettings model.GlobalSetting) error {
var tmplWireGuardConf string
// if set, read wg.conf template from WgConfTemplate
if len(WgConfTemplate) > 0 {
fileContentBytes, err := os.ReadFile(WgConfTemplate)
data, err := os.ReadFile(WgConfTemplate)
if err != nil {
return err
}
tmplWireGuardConf = string(fileContentBytes)
tmplWireGuardConf = string(data)
} else {
// read default wg.conf template file to string
fileContent, err := StringFromEmbedFile(tmplDir, "wg.conf")
if err != nil {
return err
}
tmplWireGuardConf = fileContent
}
// escape multiline notes
escapedClientDataList := []model.ClientData{}
// Escape multiline notes.
var escapedClientDataList []model.ClientData
for _, cd := range clientDataList {
if cd.Client.AdditionalNotes != "" {
cd.Client.AdditionalNotes = strings.ReplaceAll(cd.Client.AdditionalNotes, "\n", "\n# ")
}
escapedClientDataList = append(escapedClientDataList, cd)
}
// parse the template
t, err := template.New("wg_config").Parse(tmplWireGuardConf)
tmplParsed, err := template.New("wg_config").Parse(tmplWireGuardConf)
if err != nil {
return err
}
// write config file to disk
f, err := os.Create(globalSettings.ConfigFilePath)
if err != nil {
return err
}
defer f.Close()
config := map[string]interface{}{
"serverConfig": serverConfig,
"clientDataList": escapedClientDataList,
"globalSettings": globalSettings,
"usersList": usersList,
}
err = t.Execute(f, config)
if err != nil {
return err
}
f.Close()
return nil
return tmplParsed.Execute(f, config)
}
//
// Environment Variable Helpers
//
func LookupEnvOrString(key string, defaultVal string) string {
if val, ok := os.LookupEnv(key); ok {
return val
@ -599,22 +561,22 @@ func LookupEnvOrString(key string, defaultVal string) string {
func LookupEnvOrBool(key string, defaultVal bool) bool {
if val, ok := os.LookupEnv(key); ok {
v, err := strconv.ParseBool(val)
if err != nil {
if parsed, err := strconv.ParseBool(val); err == nil {
return parsed
} else {
fmt.Fprintf(os.Stderr, "LookupEnvOrBool[%s]: %v\n", key, err)
}
return v
}
return defaultVal
}
func LookupEnvOrInt(key string, defaultVal int) int {
if val, ok := os.LookupEnv(key); ok {
v, err := strconv.Atoi(val)
if err != nil {
if parsed, err := strconv.Atoi(val); err == nil {
return parsed
} else {
fmt.Fprintf(os.Stderr, "LookupEnvOrInt[%s]: %v\n", key, err)
}
return v
}
return defaultVal
}
@ -626,32 +588,39 @@ func LookupEnvOrStrings(key string, defaultVal []string) []string {
return defaultVal
}
// LookupEnvOrFile reads the content of a file whose path is stored in the environment variable.
func LookupEnvOrFile(key string, defaultVal string) string {
if val, ok := os.LookupEnv(key); ok {
if file, err := os.Open(val); err == nil {
var content string
scanner := bufio.NewScanner(file)
for scanner.Scan() {
content += scanner.Text()
}
return content
f, err := os.Open(val)
if err != nil {
return defaultVal
}
defer f.Close()
var content strings.Builder
scanner := bufio.NewScanner(f)
for scanner.Scan() {
content.WriteString(scanner.Text())
}
return content.String()
}
return defaultVal
}
func StringFromEmbedFile(embed fs.FS, filename string) (string, error) {
file, err := embed.Open(filename)
// StringFromEmbedFile reads a file from an embedded filesystem and returns its content as a string.
func StringFromEmbedFile(efs fs.FS, filename string) (string, error) {
file, err := efs.Open(filename)
if err != nil {
return "", err
}
content, err := io.ReadAll(file)
defer file.Close()
data, err := io.ReadAll(file)
if err != nil {
return "", err
}
return string(content), nil
return string(data), nil
}
// ParseLogLevel converts a log level string to log.Lvl.
func ParseLogLevel(lvl string) (log.Lvl, error) {
switch strings.ToLower(lvl) {
case "debug":
@ -669,44 +638,42 @@ func ParseLogLevel(lvl string) (log.Lvl, error) {
}
}
// GetCurrentHash returns current hashes
//
// Hashing and Database Helpers
//
// GetCurrentHash returns current hashes for clients and server configuration.
func GetCurrentHash(db store.IStore) (string, string) {
hashClients, _ := dirhash.HashDir(path.Join(db.GetPath(), "clients"), "prefix", dirhash.Hash1)
files := append([]string(nil), "prefix/global_settings.json", "prefix/interfaces.json", "prefix/keypair.json")
osOpen := func(name string) (io.ReadCloser, error) {
return os.Open(filepath.Join(path.Join(db.GetPath(), "server"), strings.TrimPrefix(name, "prefix")))
}
hashServer, _ := dirhash.Hash1(files, osOpen)
return hashClients, hashServer
}
// HashesChanged returns true if the current hashes differ from those stored in the database.
func HashesChanged(db store.IStore) bool {
old, _ := db.GetHashes()
oldClient := old.Client
oldServer := old.Server
newClient, newServer := GetCurrentHash(db)
if oldClient != newClient {
//fmt.Println("Hash for client differs")
return true
}
if oldServer != newServer {
//fmt.Println("Hash for server differs")
return true
}
return false
return old.Client != newClient || old.Server != newServer
}
// UpdateHashes updates the stored hashes in the database.
func UpdateHashes(db store.IStore) error {
var clientServerHashes model.ClientServerHashes
clientServerHashes.Client, clientServerHashes.Server = GetCurrentHash(db)
return db.SaveHashes(clientServerHashes)
}
//
// Miscellaneous Helpers
//
// RandomString returns a random string of the given length.
func RandomString(length int) string {
var seededRand = rand.New(rand.NewSource(time.Now().UnixNano()))
seededRand := rand.New(rand.NewSource(time.Now().UnixNano()))
charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, length)
for i := range b {
@ -715,11 +682,13 @@ func RandomString(length int) string {
return string(b)
}
// ManagePerms sets file permissions to 0600.
func ManagePerms(path string) error {
err := os.Chmod(path, 0600)
return err
return os.Chmod(path, 0600)
}
// GetDBUserCRC32 returns a CRC32 checksum of the given user.
// This is used for session verification.
func GetDBUserCRC32(dbuser model.User) uint32 {
buf := new(bytes.Buffer)
enc := gob.NewEncoder(buf)
@ -729,28 +698,24 @@ func GetDBUserCRC32(dbuser model.User) uint32 {
return crc32.ChecksumIEEE(buf.Bytes())
}
// ConcatMultipleSlices concatenates multiple byte slices.
func ConcatMultipleSlices(slices ...[]byte) []byte {
var totalLen int
totalLen := 0
for _, s := range slices {
totalLen += len(s)
}
result := make([]byte, totalLen)
var i int
for _, s := range slices {
i += copy(result[i:], s)
}
return result
}
// GetCookiePath returns the cookie path based on BasePath.
func GetCookiePath() string {
cookiePath := BasePath
if cookiePath == "" {
cookiePath = "/"
if BasePath == "" {
return "/"
}
return cookiePath
return BasePath
}

1143
yarn.lock

File diff suppressed because it is too large Load Diff