feat(auth): add OAuth2 login (#276)

* feat(auth): add OAuth2 login with GitHub and Google

* chore(translations): add files for Japanese

* fix(auth): fix link function for GitHub

* feat(oauth): basic oidc implementation

* feat(oauth): oauth guard

* fix: disable image optimizations for logo to prevent caching issues with custom logos

* fix: memory leak while downloading large files

* chore(translations): update translations via Crowdin (#278)

* New translations en-us.ts (Japanese)

* New translations en-us.ts (Japanese)

* New translations en-us.ts (Japanese)

* release: 0.18.2

* doc(translations): Add Japanese README (#279)

* Added Japanese README.

* Added JAPANESE README link to README.md.

* Updated Japanese README.

* Updated Environment Variable Table.

* updated zh-cn README.

* feat(oauth): unlink account

* refactor(oauth): make providers extensible

* fix(oauth): fix discoveryUri error when toggle google-enabled

* feat(oauth): add microsoft and discord as oauth provider

* docs(oauth): update README.md

* docs(oauth): update oauth2-guide.md

* set password to null for new oauth users

* New translations en-us.ts (Japanese) (#281)

* chore(translations): add Polish files

* fix(oauth): fix random username and password

* feat(oauth): add totp

* fix(oauth): fix totp throttle

* fix(oauth): fix qrcode and remove comment

* feat(oauth): add error page

* fix(oauth): i18n of error page

* feat(auth): add OAuth2 login

* fix(auth): fix link function for GitHub

* feat(oauth): basic oidc implementation

* feat(oauth): oauth guard

* feat(oauth): unlink account

* refactor(oauth): make providers extensible

* fix(oauth): fix discoveryUri error when toggle google-enabled

* feat(oauth): add microsoft and discord as oauth provider

* docs(oauth): update README.md

* docs(oauth): update oauth2-guide.md

* set password to null for new oauth users

* fix(oauth): fix random username and password

* feat(oauth): add totp

* fix(oauth): fix totp throttle

* fix(oauth): fix qrcode and remove comment

* feat(oauth): add error page

* fix(oauth): i18n of error page

* refactor: return null instead of `false` in `getIdOfCurrentUser` functiom

* feat: show original oauth error if available

* refactor: run formatter

* refactor(oauth): error message i18n

* refactor(oauth): make OAuth token available
someone may use it (to revoke token or get other info etc.)
also improved the i18n message

* chore(oauth): remove unused import

* chore: add database migration

* fix: missing python installation for nanoid

---------

Co-authored-by: Elias Schneider <login@eliasschneider.com>
Co-authored-by: ふうせん <10260662+fusengum@users.noreply.github.com>
This commit is contained in:
Qing Fu
2023-10-22 22:09:53 +08:00
committed by GitHub
parent d327bc355c
commit 02cd98fa9c
52 changed files with 1983 additions and 161 deletions

View File

@@ -15,6 +15,8 @@ import { UserModule } from "./user/user.module";
import { ClamScanModule } from "./clamscan/clamscan.module";
import { ReverseShareModule } from "./reverseShare/reverseShare.module";
import { AppController } from "./app.controller";
import { OAuthModule } from "./oauth/oauth.module";
import { CacheModule } from "@nestjs/cache-manager";
@Module({
imports: [
@@ -33,10 +35,12 @@ import { AppController } from "./app.controller";
ScheduleModule.forRoot(),
ClamScanModule,
ReverseShareModule,
OAuthModule,
CacheModule.register({
isGlobal: true,
}),
],
controllers:[
AppController,
],
controllers: [AppController],
providers: [
{
provide: APP_GUARD,

View File

@@ -47,7 +47,7 @@ export class AuthController {
const result = await this.authService.signUp(dto);
response = this.addTokensToResponse(
this.authService.addTokensToResponse(
response,
result.refreshToken,
result.accessToken,
@@ -66,7 +66,7 @@ export class AuthController {
const result = await this.authService.signIn(dto);
if (result.accessToken && result.refreshToken) {
response = this.addTokensToResponse(
this.authService.addTokensToResponse(
response,
result.refreshToken,
result.accessToken,
@@ -85,7 +85,7 @@ export class AuthController {
) {
const result = await this.authTotpService.signInTotp(dto);
response = this.addTokensToResponse(
this.authService.addTokensToResponse(
response,
result.refreshToken,
result.accessToken,
@@ -117,11 +117,11 @@ export class AuthController {
) {
const result = await this.authService.updatePassword(
user,
dto.oldPassword,
dto.password,
dto.oldPassword,
);
response = this.addTokensToResponse(response, result.refreshToken);
this.authService.addTokensToResponse(response, result.refreshToken);
return new TokenDTO().from(result);
}
@@ -136,7 +136,7 @@ export class AuthController {
const accessToken = await this.authService.refreshAccessToken(
request.cookies.refresh_token,
);
response = this.addTokensToResponse(response, undefined, accessToken);
this.authService.addTokensToResponse(response, undefined, accessToken);
return new TokenDTO().from({ accessToken });
}
@@ -172,22 +172,4 @@ export class AuthController {
// Note: We use VerifyTotpDTO here because it has both fields we need: password and totp code
return this.authTotpService.disableTotp(user, body.password, body.code);
}
private addTokensToResponse(
response: Response,
refreshToken?: string,
accessToken?: string,
) {
if (accessToken)
response.cookie("access_token", accessToken, { sameSite: "lax" });
if (refreshToken)
response.cookie("refresh_token", refreshToken, {
path: "/api/auth/token",
httpOnly: true,
sameSite: "strict",
maxAge: 1000 * 60 * 60 * 24 * 30 * 3,
});
return response;
}
}

View File

@@ -7,7 +7,12 @@ import { AuthTotpService } from "./authTotp.service";
import { JwtStrategy } from "./strategy/jwt.strategy";
@Module({
imports: [JwtModule.register({}), EmailModule],
imports: [
JwtModule.register({
global: true,
}),
EmailModule,
],
controllers: [AuthController],
providers: [AuthService, AuthTotpService, JwtStrategy],
exports: [AuthService],

View File

@@ -8,6 +8,7 @@ import { JwtService } from "@nestjs/jwt";
import { User } from "@prisma/client";
import { PrismaClientKnownRequestError } from "@prisma/client/runtime/library";
import * as argon from "argon2";
import { Request, Response } from "express";
import * as moment from "moment";
import { ConfigService } from "src/config/config.service";
import { EmailService } from "src/email/email.service";
@@ -27,7 +28,7 @@ export class AuthService {
async signUp(dto: AuthRegisterDTO) {
const isFirstUser = (await this.prisma.user.count()) == 0;
const hash = await argon.hash(dto.password);
const hash = dto.password ? await argon.hash(dto.password) : null;
try {
const user = await this.prisma.user.create({
data: {
@@ -43,7 +44,7 @@ export class AuthService {
);
const accessToken = await this.createAccessToken(user, refreshTokenId);
return { accessToken, refreshToken };
return { accessToken, refreshToken, user };
} catch (e) {
if (e instanceof PrismaClientKnownRequestError) {
if (e.code == "P2002") {
@@ -69,9 +70,16 @@ export class AuthService {
if (!user || !(await argon.verify(user.password, dto.password)))
throw new UnauthorizedException("Wrong email or password");
return this.generateToken(user);
}
async generateToken(user: User, isOAuth = false) {
// TODO: Make all old loginTokens invalid when a new one is created
// Check if the user has TOTP enabled
if (user.totpVerified) {
if (
user.totpVerified &&
!(isOAuth && this.config.get("oauth.ignoreTotp"))
) {
const loginToken = await this.createLoginToken(user.id);
return { loginToken };
@@ -129,9 +137,11 @@ export class AuthService {
});
}
async updatePassword(user: User, oldPassword: string, newPassword: string) {
if (!(await argon.verify(user.password, oldPassword)))
throw new ForbiddenException("Invalid password");
async updatePassword(user: User, newPassword: string, oldPassword?: string) {
const isPasswordValid =
!user.password || !(await argon.verify(user.password, oldPassword));
if (!isPasswordValid) throw new ForbiddenException("Invalid password");
const hash = await argon.hash(newPassword);
@@ -210,4 +220,38 @@ export class AuthService {
return loginToken;
}
addTokensToResponse(
response: Response,
refreshToken?: string,
accessToken?: string,
) {
if (accessToken)
response.cookie("access_token", accessToken, { sameSite: "lax" });
if (refreshToken)
response.cookie("refresh_token", refreshToken, {
path: "/api/auth/token",
httpOnly: true,
sameSite: "strict",
maxAge: 1000 * 60 * 60 * 24 * 30 * 3,
});
}
/**
* Returns the user id if the user is logged in, null otherwise
*/
async getIdOfCurrentUser(request: Request): Promise<string | null> {
if (!request.cookies.access_token) return null;
try {
const payload = await this.jwtService.verifyAsync(
request.cookies.access_token,
{
secret: this.config.get("internal.jwtSecret"),
},
);
return payload.sub;
} catch {
return null;
}
}
}

View File

@@ -22,43 +22,29 @@ export class AuthTotpService {
) {}
async signInTotp(dto: AuthSignInTotpDTO) {
if (!dto.email && !dto.username)
throw new BadRequestException("Email or username is required");
const user = await this.prisma.user.findFirst({
where: {
OR: [{ email: dto.email }, { username: dto.username }],
},
});
if (!user || !(await argon.verify(user.password, dto.password)))
throw new UnauthorizedException("Wrong email or password");
const token = await this.prisma.loginToken.findFirst({
where: {
token: dto.loginToken,
},
include: {
user: true,
},
});
if (!token || token.userId != user.id || token.used)
if (!token || token.used)
throw new UnauthorizedException("Invalid login token");
if (token.expiresAt < new Date())
throw new UnauthorizedException("Login token expired", "token_expired");
// Check the TOTP code
const { totpSecret } = await this.prisma.user.findUnique({
where: { id: user.id },
select: { totpSecret: true },
});
const { totpSecret } = token.user;
if (!totpSecret) {
throw new BadRequestException("TOTP is not enabled");
}
const expected = authenticator.generate(totpSecret);
if (dto.totp !== expected) {
if (!authenticator.check(dto.totp, totpSecret)) {
throw new BadRequestException("Invalid code");
}
@@ -69,9 +55,9 @@ export class AuthTotpService {
});
const { refreshToken, refreshTokenId } =
await this.authService.createRefreshToken(user.id);
await this.authService.createRefreshToken(token.user.id);
const accessToken = await this.authService.createAccessToken(
user,
token.user,
refreshTokenId,
);

View File

@@ -1,7 +1,7 @@
import { IsString } from "class-validator";
import { AuthSignInDTO } from "./authSignIn.dto";
export class AuthSignInTotpDTO extends AuthSignInDTO {
export class AuthSignInTotpDTO {
@IsString()
totp: string;

View File

@@ -1,8 +1,9 @@
import { PickType } from "@nestjs/swagger";
import { IsString } from "class-validator";
import { IsOptional, IsString } from "class-validator";
import { UserDTO } from "src/user/dto/user.dto";
export class UpdatePasswordDTO extends PickType(UserDTO, ["password"]) {
@IsString()
oldPassword: string;
@IsOptional()
oldPassword?: string;
}

View File

@@ -6,13 +6,20 @@ import {
} from "@nestjs/common";
import { Config } from "@prisma/client";
import { PrismaService } from "src/prisma/prisma.service";
import { EventEmitter } from "events";
/**
* ConfigService extends EventEmitter to allow listening for config updates,
* now only `update` event will be emitted.
*/
@Injectable()
export class ConfigService {
export class ConfigService extends EventEmitter {
constructor(
@Inject("CONFIG_VARIABLES") private configVariables: Config[],
private prisma: PrismaService,
) {}
) {
super();
}
get(key: `${string}.${string}`): any {
const configVariable = this.configVariables.filter(
@@ -105,6 +112,8 @@ export class ConfigService {
this.configVariables = await this.prisma.config.findMany();
this.emit("update", key, value);
return updatedVariable;
}
}

View File

@@ -26,7 +26,7 @@ export class LogoService {
fs.promises.writeFile(
`${IMAGES_PATH}/icons/icon-${size}x${size}.png`,
resized,
"binary"
"binary",
);
}
}

View File

@@ -0,0 +1,9 @@
import { IsString } from "class-validator";
export class OAuthCallbackDto {
@IsString()
code: string;
@IsString()
state: string;
}

View File

@@ -0,0 +1,6 @@
export interface OAuthSignInDto {
provider: "github" | "google" | "microsoft" | "discord" | "oidc";
providerId: string;
providerUsername: string;
email: string;
}

View File

@@ -0,0 +1,15 @@
export class ErrorPageException extends Error {
/**
* Exception for redirecting to error page (all i18n key should omit `error.msg` and `error.param` prefix)
* @param key i18n key of message
* @param redirect redirect url
* @param params message params (key)
*/
constructor(
public readonly key: string = "default",
public readonly redirect: string = "/",
public readonly params?: string[],
) {
super("error");
}
}

View File

@@ -0,0 +1,22 @@
import { ArgumentsHost, Catch, ExceptionFilter } from "@nestjs/common";
import { ConfigService } from "../../config/config.service";
import { ErrorPageException } from "../exceptions/errorPage.exception";
@Catch(ErrorPageException)
export class ErrorPageExceptionFilter implements ExceptionFilter {
constructor(private config: ConfigService) {}
catch(exception: ErrorPageException, host: ArgumentsHost) {
const ctx = host.switchToHttp();
const response = ctx.getResponse();
const url = new URL(`${this.config.get("general.appUrl")}/error`);
url.searchParams.set("redirect", exception.redirect);
url.searchParams.set("error", exception.key);
if (exception.params) {
url.searchParams.set("params", exception.params.join(","));
}
response.redirect(url.toString());
}
}

View File

@@ -0,0 +1,31 @@
import {
ArgumentsHost,
Catch,
ExceptionFilter,
HttpException,
} from "@nestjs/common";
import { ConfigService } from "../../config/config.service";
@Catch(HttpException)
export class OAuthExceptionFilter implements ExceptionFilter {
private errorKeys: Record<string, string> = {
access_denied: "access_denied",
expired_token: "expired_token",
};
constructor(private config: ConfigService) {}
catch(_exception: HttpException, host: ArgumentsHost) {
const ctx = host.switchToHttp();
const response = ctx.getResponse();
const request = ctx.getRequest();
const key = this.errorKeys[request.query.error] || "default";
const url = new URL(`${this.config.get("general.appUrl")}/error`);
url.searchParams.set("redirect", "/account");
url.searchParams.set("error", key);
response.redirect(url.toString());
}
}

View File

@@ -0,0 +1,12 @@
import { CanActivate, ExecutionContext, Injectable } from "@nestjs/common";
@Injectable()
export class OAuthGuard implements CanActivate {
constructor() {}
canActivate(context: ExecutionContext): boolean {
const request = context.switchToHttp().getRequest();
const provider = request.params.provider;
return request.query.state === request.cookies[`oauth_${provider}_state`];
}
}

View File

@@ -0,0 +1,24 @@
import {
CanActivate,
ExecutionContext,
Inject,
Injectable,
} from "@nestjs/common";
import { ConfigService } from "../../config/config.service";
@Injectable()
export class ProviderGuard implements CanActivate {
constructor(
private config: ConfigService,
@Inject("OAUTH_PLATFORMS") private platforms: string[],
) {}
canActivate(context: ExecutionContext): boolean {
const request = context.switchToHttp().getRequest();
const provider = request.params.provider;
return (
this.platforms.includes(provider) &&
this.config.get(`oauth.${provider}-enabled`)
);
}
}

View File

@@ -0,0 +1,110 @@
import {
Controller,
Get,
Inject,
Param,
Post,
Query,
Req,
Res,
UseFilters,
UseGuards,
} from "@nestjs/common";
import { User } from "@prisma/client";
import { Request, Response } from "express";
import { nanoid } from "nanoid";
import { AuthService } from "../auth/auth.service";
import { GetUser } from "../auth/decorator/getUser.decorator";
import { JwtGuard } from "../auth/guard/jwt.guard";
import { ConfigService } from "../config/config.service";
import { OAuthCallbackDto } from "./dto/oauthCallback.dto";
import { ErrorPageExceptionFilter } from "./filter/errorPageException.filter";
import { OAuthGuard } from "./guard/oauth.guard";
import { ProviderGuard } from "./guard/provider.guard";
import { OAuthService } from "./oauth.service";
import { OAuthProvider } from "./provider/oauthProvider.interface";
import { OAuthExceptionFilter } from "./filter/oauthException.filter";
@Controller("oauth")
export class OAuthController {
constructor(
private authService: AuthService,
private oauthService: OAuthService,
private config: ConfigService,
@Inject("OAUTH_PROVIDERS")
private providers: Record<string, OAuthProvider<unknown>>,
) {}
@Get("available")
available() {
return this.oauthService.available();
}
@Get("status")
@UseGuards(JwtGuard)
async status(@GetUser() user: User) {
return this.oauthService.status(user);
}
@Get("auth/:provider")
@UseGuards(ProviderGuard)
@UseFilters(ErrorPageExceptionFilter)
async auth(
@Param("provider") provider: string,
@Res({ passthrough: true }) response: Response,
) {
const state = nanoid(16);
const url = await this.providers[provider].getAuthEndpoint(state);
response.cookie(`oauth_${provider}_state`, state, { sameSite: "lax" });
response.redirect(url);
}
@Get("callback/:provider")
@UseGuards(ProviderGuard, OAuthGuard)
@UseFilters(ErrorPageExceptionFilter, OAuthExceptionFilter)
async callback(
@Param("provider") provider: string,
@Query() query: OAuthCallbackDto,
@Req() request: Request,
@Res({ passthrough: true }) response: Response,
) {
const oauthToken = await this.providers[provider].getToken(query);
const user = await this.providers[provider].getUserInfo(oauthToken, query);
const id = await this.authService.getIdOfCurrentUser(request);
if (id) {
await this.oauthService.link(
id,
provider,
user.providerId,
user.providerUsername,
);
response.redirect(this.config.get("general.appUrl") + "/account");
} else {
const token: {
accessToken?: string;
refreshToken?: string;
loginToken?: string;
} = await this.oauthService.signIn(user);
if (token.accessToken) {
this.authService.addTokensToResponse(
response,
token.refreshToken,
token.accessToken,
);
response.redirect(this.config.get("general.appUrl"));
} else {
response.redirect(
this.config.get("general.appUrl") + `/auth/totp/${token.loginToken}`,
);
}
}
}
@Post("unlink/:provider")
@UseGuards(JwtGuard, ProviderGuard)
@UseFilters(ErrorPageExceptionFilter)
unlink(@GetUser() user: User, @Param("provider") provider: string) {
return this.oauthService.unlink(user, provider);
}
}

View File

@@ -0,0 +1,56 @@
import { Module } from "@nestjs/common";
import { OAuthController } from "./oauth.controller";
import { OAuthService } from "./oauth.service";
import { AuthModule } from "../auth/auth.module";
import { GitHubProvider } from "./provider/github.provider";
import { GoogleProvider } from "./provider/google.provider";
import { OAuthProvider } from "./provider/oauthProvider.interface";
import { OidcProvider } from "./provider/oidc.provider";
import { DiscordProvider } from "./provider/discord.provider";
import { MicrosoftProvider } from "./provider/microsoft.provider";
@Module({
controllers: [OAuthController],
providers: [
OAuthService,
GitHubProvider,
GoogleProvider,
MicrosoftProvider,
DiscordProvider,
OidcProvider,
{
provide: "OAUTH_PROVIDERS",
useFactory(
github: GitHubProvider,
google: GoogleProvider,
microsoft: MicrosoftProvider,
discord: DiscordProvider,
oidc: OidcProvider,
): Record<string, OAuthProvider<unknown>> {
return {
github,
google,
microsoft,
discord,
oidc,
};
},
inject: [
GitHubProvider,
GoogleProvider,
MicrosoftProvider,
DiscordProvider,
OidcProvider,
],
},
{
provide: "OAUTH_PLATFORMS",
useFactory(providers: Record<string, OAuthProvider<unknown>>): string[] {
return Object.keys(providers);
},
inject: ["OAUTH_PROVIDERS"],
},
],
imports: [AuthModule],
})
export class OAuthModule {}

View File

@@ -0,0 +1,171 @@
import { Inject, Injectable } from "@nestjs/common";
import { User } from "@prisma/client";
import { nanoid } from "nanoid";
import { AuthService } from "../auth/auth.service";
import { ConfigService } from "../config/config.service";
import { PrismaService } from "../prisma/prisma.service";
import { OAuthSignInDto } from "./dto/oauthSignIn.dto";
import { ErrorPageException } from "./exceptions/errorPage.exception";
@Injectable()
export class OAuthService {
constructor(
private prisma: PrismaService,
private config: ConfigService,
private auth: AuthService,
@Inject("OAUTH_PLATFORMS") private platforms: string[],
) {}
available(): string[] {
return this.platforms
.map((platform) => [
platform,
this.config.get(`oauth.${platform}-enabled`),
])
.filter(([_, enabled]) => enabled)
.map(([platform, _]) => platform);
}
async status(user: User) {
const oauthUsers = await this.prisma.oAuthUser.findMany({
select: {
provider: true,
providerUsername: true,
},
where: {
userId: user.id,
},
});
return Object.fromEntries(oauthUsers.map((u) => [u.provider, u]));
}
async signIn(user: OAuthSignInDto) {
const oauthUser = await this.prisma.oAuthUser.findFirst({
where: {
provider: user.provider,
providerUserId: user.providerId,
},
include: {
user: true,
},
});
if (oauthUser) {
return this.auth.generateToken(oauthUser.user, true);
}
return this.signUp(user);
}
async link(
userId: string,
provider: string,
providerUserId: string,
providerUsername: string,
) {
const oauthUser = await this.prisma.oAuthUser.findFirst({
where: {
provider,
providerUserId,
},
});
if (oauthUser) {
throw new ErrorPageException("already_linked", "/account", [
`provider_${provider}`,
]);
}
await this.prisma.oAuthUser.create({
data: {
userId,
provider,
providerUsername,
providerUserId,
},
});
}
async unlink(user: User, provider: string) {
const oauthUser = await this.prisma.oAuthUser.findFirst({
where: {
userId: user.id,
provider,
},
});
if (oauthUser) {
await this.prisma.oAuthUser.delete({
where: {
id: oauthUser.id,
},
});
} else {
throw new ErrorPageException("not_linked", "/account", [provider]);
}
}
private async getAvailableUsername(email: string) {
// only remove + and - from email for now (maybe not enough)
let username = email.split("@")[0].replace(/[+-]/g, "").substring(0, 20);
while (true) {
const user = await this.prisma.user.findFirst({
where: {
username: username,
},
});
if (user) {
username = username + "_" + nanoid(10).replaceAll("-", "");
} else {
return username;
}
}
}
private async signUp(user: OAuthSignInDto) {
// register
if (!this.config.get("oauth.allowRegistration")) {
throw new ErrorPageException("no_user", "/auth/signIn", [
`provider_${user.provider}`,
]);
}
if (!user.email) {
throw new ErrorPageException("no_email", "/auth/signIn", [
`provider_${user.provider}`,
]);
}
const existingUser: User = await this.prisma.user.findFirst({
where: {
email: user.email,
},
});
if (existingUser) {
await this.prisma.oAuthUser.create({
data: {
provider: user.provider,
providerUserId: user.providerId.toString(),
providerUsername: user.providerUsername,
userId: existingUser.id,
},
});
return this.auth.generateToken(existingUser, true);
}
const result = await this.auth.signUp({
email: user.email,
username: await this.getAvailableUsername(user.email),
password: null,
});
await this.prisma.oAuthUser.create({
data: {
provider: user.provider,
providerUserId: user.providerId.toString(),
providerUsername: user.providerUsername,
userId: result.user.id,
},
});
return result;
}
}

View File

@@ -0,0 +1,98 @@
import { OAuthProvider, OAuthToken } from "./oauthProvider.interface";
import { OAuthCallbackDto } from "../dto/oauthCallback.dto";
import { OAuthSignInDto } from "../dto/oauthSignIn.dto";
import { ConfigService } from "../../config/config.service";
import { BadRequestException, Injectable } from "@nestjs/common";
import fetch from "node-fetch";
@Injectable()
export class DiscordProvider implements OAuthProvider<DiscordToken> {
constructor(private config: ConfigService) {}
getAuthEndpoint(state: string): Promise<string> {
return Promise.resolve(
"https://discord.com/api/oauth2/authorize?" +
new URLSearchParams({
client_id: this.config.get("oauth.discord-clientId"),
redirect_uri:
this.config.get("general.appUrl") + "/api/oauth/callback/discord",
response_type: "code",
state: state,
scope: "identify email",
}).toString(),
);
}
private getAuthorizationHeader() {
return (
"Basic " +
Buffer.from(
this.config.get("oauth.discord-clientId") +
":" +
this.config.get("oauth.discord-clientSecret"),
).toString("base64")
);
}
async getToken(query: OAuthCallbackDto): Promise<OAuthToken<DiscordToken>> {
const res = await fetch("https://discord.com/api/v10/oauth2/token", {
method: "post",
headers: {
"Content-Type": "application/x-www-form-urlencoded",
Authorization: this.getAuthorizationHeader(),
},
body: new URLSearchParams({
code: query.code,
grant_type: "authorization_code",
redirect_uri:
this.config.get("general.appUrl") + "/api/oauth/callback/discord",
}),
});
const token: DiscordToken = await res.json();
return {
accessToken: token.access_token,
refreshToken: token.refresh_token,
expiresIn: token.expires_in,
scope: token.scope,
tokenType: token.token_type,
rawToken: token,
};
}
async getUserInfo(token: OAuthToken<DiscordToken>): Promise<OAuthSignInDto> {
const res = await fetch("https://discord.com/api/v10/user/@me", {
method: "post",
headers: {
Accept: "application/json",
Authorization: `${token.tokenType || "Bearer"} ${token.accessToken}`,
},
});
const user = (await res.json()) as DiscordUser;
if (user.verified === false) {
throw new BadRequestException("Unverified account.");
}
return {
provider: "discord",
providerId: user.id,
providerUsername: user.global_name ?? user.username,
email: user.email,
};
}
}
export interface DiscordToken {
access_token: string;
token_type: string;
expires_in: number;
refresh_token: string;
scope: string;
}
export interface DiscordUser {
id: string;
username: string;
global_name: string;
email: string;
verified: boolean;
}

View File

@@ -0,0 +1,206 @@
import { BadRequestException } from "@nestjs/common";
import fetch from "node-fetch";
import { ConfigService } from "../../config/config.service";
import { JwtService } from "@nestjs/jwt";
import { Cache } from "cache-manager";
import { nanoid } from "nanoid";
import { OAuthCallbackDto } from "../dto/oauthCallback.dto";
import { OAuthProvider, OAuthToken } from "./oauthProvider.interface";
import { OAuthSignInDto } from "../dto/oauthSignIn.dto";
export abstract class GenericOidcProvider implements OAuthProvider<OidcToken> {
protected redirectUri: string;
protected discoveryUri: string;
private configuration: OidcConfigurationCache;
private jwk: OidcJwkCache;
protected constructor(
protected name: string,
protected keyOfConfigUpdateEvents: string[],
protected config: ConfigService,
protected jwtService: JwtService,
protected cache: Cache,
) {
this.discoveryUri = this.getDiscoveryUri();
this.redirectUri = `${this.config.get(
"general.appUrl",
)}/api/oauth/callback/${this.name}`;
this.config.addListener("update", (key: string, _: unknown) => {
if (this.keyOfConfigUpdateEvents.includes(key)) {
this.deinit();
this.discoveryUri = this.getDiscoveryUri();
}
});
}
async getConfiguration(): Promise<OidcConfiguration> {
if (!this.configuration || this.configuration.expires < Date.now()) {
await this.fetchConfiguration();
}
return this.configuration.data;
}
async getJwk(): Promise<OidcJwk[]> {
if (!this.jwk || this.jwk.expires < Date.now()) {
await this.fetchJwk();
}
return this.jwk.data;
}
async getAuthEndpoint(state: string) {
const configuration = await this.getConfiguration();
const endpoint = configuration.authorization_endpoint;
const nonce = nanoid();
await this.cache.set(
`oauth-${this.name}-nonce-${state}`,
nonce,
1000 * 60 * 5,
);
return (
endpoint +
"?" +
new URLSearchParams({
client_id: this.config.get(`oauth.${this.name}-clientId`),
response_type: "code",
scope: "openid profile email",
redirect_uri: this.redirectUri,
state,
nonce,
}).toString()
);
}
async getToken(query: OAuthCallbackDto): Promise<OAuthToken<OidcToken>> {
const configuration = await this.getConfiguration();
const endpoint = configuration.token_endpoint;
const res = await fetch(endpoint, {
method: "POST",
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
body: new URLSearchParams({
client_id: this.config.get(`oauth.${this.name}-clientId`),
client_secret: this.config.get(`oauth.${this.name}-clientSecret`),
grant_type: "authorization_code",
code: query.code,
redirect_uri: this.redirectUri,
}).toString(),
});
const token: OidcToken = await res.json();
return {
accessToken: token.access_token,
expiresIn: token.expires_in,
idToken: token.id_token,
refreshToken: token.refresh_token,
tokenType: token.token_type,
rawToken: token,
};
}
async getUserInfo(
token: OAuthToken<OidcToken>,
query: OAuthCallbackDto,
): Promise<OAuthSignInDto> {
const idTokenData = this.decodeIdToken(token.idToken);
// maybe it's not necessary to verify the id token since it's directly obtained from the provider
const key = `oauth-${this.name}-nonce-${query.state}`;
const nonce = await this.cache.get(key);
await this.cache.del(key);
if (nonce !== idTokenData.nonce) {
throw new BadRequestException("Invalid token");
}
return {
provider: this.name as any,
email: idTokenData.email,
providerId: idTokenData.sub,
providerUsername: idTokenData.name,
};
}
protected abstract getDiscoveryUri(): string;
private async fetchConfiguration(): Promise<void> {
const res = await fetch(this.discoveryUri);
const expires = res.headers.has("expires")
? new Date(res.headers.get("expires")).getTime()
: Date.now() + 1000 * 60 * 60 * 24;
this.configuration = {
expires,
data: await res.json(),
};
}
private async fetchJwk(): Promise<void> {
const configuration = await this.getConfiguration();
const res = await fetch(configuration.jwks_uri);
const expires = res.headers.has("expires")
? new Date(res.headers.get("expires")).getTime()
: Date.now() + 1000 * 60 * 60 * 24;
this.jwk = {
expires,
data: (await res.json())["keys"],
};
}
private deinit() {
this.discoveryUri = undefined;
this.configuration = undefined;
this.jwk = undefined;
}
private decodeIdToken(idToken: string): OidcIdToken {
return this.jwtService.decode(idToken) as OidcIdToken;
}
}
export interface OidcCache<T> {
expires: number;
data: T;
}
export interface OidcConfiguration {
issuer: string;
authorization_endpoint: string;
token_endpoint: string;
userinfo_endpoint?: string;
jwks_uri: string;
response_types_supported: string[];
id_token_signing_alg_values_supported: string[];
scopes_supported?: string[];
claims_supported?: string[];
}
export interface OidcJwk {
e: string;
alg: string;
kid: string;
use: string;
kty: string;
n: string;
}
export type OidcConfigurationCache = OidcCache<OidcConfiguration>;
export type OidcJwkCache = OidcCache<OidcJwk[]>;
export interface OidcToken {
access_token: string;
refresh_token: string;
token_type: string;
expires_in: number;
id_token: string;
}
export interface OidcIdToken {
iss: string;
sub: string;
exp: number;
iat: number;
email: string;
name: string;
nonce: string;
}

View File

@@ -0,0 +1,110 @@
import { OAuthProvider, OAuthToken } from "./oauthProvider.interface";
import { OAuthCallbackDto } from "../dto/oauthCallback.dto";
import { OAuthSignInDto } from "../dto/oauthSignIn.dto";
import { ConfigService } from "../../config/config.service";
import fetch from "node-fetch";
import { BadRequestException, Injectable } from "@nestjs/common";
@Injectable()
export class GitHubProvider implements OAuthProvider<GitHubToken> {
constructor(private config: ConfigService) {}
getAuthEndpoint(state: string): Promise<string> {
return Promise.resolve(
"https://github.com/login/oauth/authorize?" +
new URLSearchParams({
client_id: this.config.get("oauth.github-clientId"),
redirect_uri:
this.config.get("general.appUrl") + "/api/oauth/callback/github",
state: state,
scope: "user:email",
}).toString(),
);
}
async getToken(query: OAuthCallbackDto): Promise<OAuthToken<GitHubToken>> {
const res = await fetch(
"https://github.com/login/oauth/access_token?" +
new URLSearchParams({
client_id: this.config.get("oauth.github-clientId"),
client_secret: this.config.get("oauth.github-clientSecret"),
code: query.code,
}).toString(),
{
method: "post",
headers: {
Accept: "application/json",
},
},
);
const token: GitHubToken = await res.json();
return {
accessToken: token.access_token,
tokenType: token.token_type,
rawToken: token,
};
}
async getUserInfo(token: OAuthToken<GitHubToken>): Promise<OAuthSignInDto> {
const user = await this.getGitHubUser(token);
if (!token.scope.includes("user:email")) {
throw new BadRequestException("No email permission granted");
}
const email = await this.getGitHubEmail(token);
if (!email) {
throw new BadRequestException("No email found");
}
return {
provider: "github",
providerId: user.id.toString(),
providerUsername: user.name ?? user.login,
email,
};
}
private async getGitHubUser(
token: OAuthToken<GitHubToken>,
): Promise<GitHubUser> {
const res = await fetch("https://api.github.com/user", {
headers: {
Accept: "application/vnd.github+json",
Authorization: `${token.tokenType ?? "Bearer"} ${token.accessToken}`,
},
});
return (await res.json()) as GitHubUser;
}
private async getGitHubEmail(
token: OAuthToken<GitHubToken>,
): Promise<string | undefined> {
const res = await fetch("https://api.github.com/user/public_emails", {
headers: {
Accept: "application/vnd.github+json",
Authorization: `${token.tokenType ?? "Bearer"} ${token.accessToken}`,
},
});
const emails = (await res.json()) as GitHubEmail[];
return emails.find((e) => e.primary && e.verified)?.email;
}
}
export interface GitHubToken {
access_token: string;
token_type: string;
scope: string;
}
export interface GitHubUser {
login: string;
id: number;
name?: string;
email?: string; // this filed seems only return null
}
export interface GitHubEmail {
email: string;
primary: boolean;
verified: boolean;
visibility: string | null;
}

View File

@@ -0,0 +1,21 @@
import { GenericOidcProvider } from "./genericOidc.provider";
import { ConfigService } from "../../config/config.service";
import { JwtService } from "@nestjs/jwt";
import { Inject, Injectable } from "@nestjs/common";
import { CACHE_MANAGER } from "@nestjs/cache-manager";
import { Cache } from "cache-manager";
@Injectable()
export class GoogleProvider extends GenericOidcProvider {
constructor(
config: ConfigService,
jwtService: JwtService,
@Inject(CACHE_MANAGER) cache: Cache,
) {
super("google", ["oauth.google-enabled"], config, jwtService, cache);
}
protected getDiscoveryUri(): string {
return "https://accounts.google.com/.well-known/openid-configuration";
}
}

View File

@@ -0,0 +1,29 @@
import { GenericOidcProvider } from "./genericOidc.provider";
import { ConfigService } from "../../config/config.service";
import { JwtService } from "@nestjs/jwt";
import { Inject, Injectable } from "@nestjs/common";
import { CACHE_MANAGER } from "@nestjs/cache-manager";
import { Cache } from "cache-manager";
@Injectable()
export class MicrosoftProvider extends GenericOidcProvider {
constructor(
config: ConfigService,
jwtService: JwtService,
@Inject(CACHE_MANAGER) cache: Cache,
) {
super(
"microsoft",
["oauth.microsoft-enabled", "oauth.microsoft-tenant"],
config,
jwtService,
cache,
);
}
protected getDiscoveryUri(): string {
return `https://login.microsoftonline.com/${this.config.get(
"oauth.microsoft-tenant",
)}/v2.0/.well-known/openid-configuration`;
}
}

View File

@@ -0,0 +1,24 @@
import { OAuthCallbackDto } from "../dto/oauthCallback.dto";
import { OAuthSignInDto } from "../dto/oauthSignIn.dto";
/**
* @typeParam T - type of token
* @typeParam C - type of callback query
*/
export interface OAuthProvider<T, C = OAuthCallbackDto> {
getAuthEndpoint(state: string): Promise<string>;
getToken(query: C): Promise<OAuthToken<T>>;
getUserInfo(token: OAuthToken<T>, query: C): Promise<OAuthSignInDto>;
}
export interface OAuthToken<T> {
accessToken: string;
expiresIn?: number;
refreshToken?: string;
tokenType?: string;
scope?: string;
idToken?: string;
rawToken: T;
}

View File

@@ -0,0 +1,27 @@
import { GenericOidcProvider } from "./genericOidc.provider";
import { Inject, Injectable } from "@nestjs/common";
import { ConfigService } from "../../config/config.service";
import { JwtService } from "@nestjs/jwt";
import { CACHE_MANAGER } from "@nestjs/cache-manager";
import { Cache } from "cache-manager";
@Injectable()
export class OidcProvider extends GenericOidcProvider {
constructor(
config: ConfigService,
jwtService: JwtService,
@Inject(CACHE_MANAGER) protected cache: Cache,
) {
super(
"oidc",
["oauth.oidc-enabled", "oauth.oidc-discoveryUri"],
config,
jwtService,
cache,
);
}
protected getDiscoveryUri(): string {
return this.config.get("oauth.oidc-discoveryUri");
}
}

View File

@@ -16,6 +16,9 @@ export class UserDTO {
@IsEmail()
email: string;
@Expose()
hasPassword: boolean;
@MinLength(8)
password: string;

View File

@@ -28,7 +28,9 @@ export class UserController {
@Get("me")
@UseGuards(JwtGuard)
async getCurrentUser(@GetUser() user: User) {
return new UserDTO().from(user);
const userDTO = new UserDTO().from(user);
userDTO.hasPassword = !!user.password;
return userDTO;
}
@Patch("me")