import { TRPCError } from "@trpc/server";
import { and, count, eq } from "drizzle-orm";
import invariant from "tiny-invariant";
import { z } from "zod";

import { SqliteError } from "@lifetracker/db";
import { users } from "@lifetracker/db/schema";
import serverConfig from "@lifetracker/shared/config";
import { zSignUpSchema } from "@lifetracker/shared/types/users";

import { hashPassword, validatePassword } from "../auth";
import {
  adminProcedure,
  authedProcedure,
  Context,
  publicProcedure,
  router,
} from "../index";

/**
 * Creates a new password-based user inside a single database transaction.
 *
 * When `role` is omitted, the first user ever created becomes an admin and
 * every subsequent user becomes a regular user.
 *
 * @param input - Validated sign-up payload (name, email, plaintext password).
 * @param ctx - Request context; only `ctx.db` is used.
 * @param role - Explicit role to assign; skips the first-user-is-admin rule.
 * @returns The inserted user's `id`, `name`, `email`, and `role`.
 * @throws TRPCError `BAD_REQUEST` on a SQLite unique-constraint violation
 *   (reported as a taken email), otherwise `INTERNAL_SERVER_ERROR` with the
 *   original error attached as `cause`.
 */
export async function createUser(
  input: z.infer<typeof zSignUpSchema>,
  ctx: Context,
  role?: "user" | "admin",
) {
  return ctx.db.transaction(async (trx) => {
    let userRole = role;
    if (!userRole) {
      // The count runs in the same transaction as the insert below.
      const [{ count: userCount }] = await trx
        .select({ count: count() })
        .from(users);
      userRole = userCount === 0 ? "admin" : "user";
    }

    try {
      const result = await trx
        .insert(users)
        .values({
          name: input.name,
          email: input.email,
          password: await hashPassword(input.password),
          role: userRole,
        })
        .returning({
          id: users.id,
          name: users.name,
          email: users.email,
          role: users.role,
        });
      return result[0];
    } catch (e) {
      if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") {
        throw new TRPCError({
          code: "BAD_REQUEST",
          message: "Email is already taken",
        });
      }
      // Keep the underlying error so server-side logs show the real failure.
      throw new TRPCError({
        code: "INTERNAL_SERVER_ERROR",
        message: "Something went wrong",
        cause: e,
      });
    }
  });
}

export const usersAppRouter = router({
  /**
   * Public sign-up. Rejected with FORBIDDEN when sign-ups or password auth
   * are disabled in the server config.
   */
  create: publicProcedure
    .input(zSignUpSchema)
    .output(
      z.object({
        id: z.string(),
        name: z.string(),
        email: z.string(),
        role: z.enum(["user", "admin"]).nullable(),
      }),
    )
    .mutation(async ({ input, ctx }) => {
      if (
        serverConfig.auth.disableSignups ||
        serverConfig.auth.disablePasswordAuth
      ) {
        const errorMessage = serverConfig.auth.disablePasswordAuth
          ? "Local Signups are disabled in the server config. Use OAuth instead!"
          : "Signups are disabled in server config";
        throw new TRPCError({
          code: "FORBIDDEN",
          message: errorMessage,
        });
      }
      return createUser(input, ctx);
    }),

  /**
   * Admin-only listing of all users. `localUser` is true when the account
   * has a stored password hash; the hash itself is never returned.
   */
  list: adminProcedure
    .output(
      z.object({
        users: z.array(
          z.object({
            id: z.string(),
            name: z.string(),
            email: z.string(),
            role: z.enum(["user", "admin"]).nullable(),
            localUser: z.boolean(),
          }),
        ),
      }),
    )
    .query(async ({ ctx }) => {
      const dbUsers = await ctx.db
        .select({
          id: users.id,
          name: users.name,
          email: users.email,
          role: users.role,
          password: users.password,
        })
        .from(users);

      return {
        users: dbUsers.map(({ password, ...user }) => ({
          ...user,
          localUser: password !== null,
        })),
      };
    }),

  /**
   * Changes the caller's password after re-validating the current one.
   *
   * @throws TRPCError `UNAUTHORIZED` when the current password is rejected.
   */
  changePassword: authedProcedure
    .input(
      z.object({
        currentPassword: z.string(),
        newPassword: z.string(),
      }),
    )
    .mutation(async ({ input, ctx }) => {
      invariant(ctx.user.email, "A user always has an email specified");
      let user;
      try {
        user = await validatePassword(ctx.user.email, input.currentPassword);
      } catch {
        throw new TRPCError({ code: "UNAUTHORIZED" });
      }
      // The validated account must be the caller's own account; the previous
      // `invariant(user.id, ctx.user.id)` only checked that an id existed.
      invariant(
        user.id === ctx.user.id,
        "Validated user does not match the authenticated user",
      );

      await ctx.db
        .update(users)
        .set({
          password: await hashPassword(input.newPassword),
        })
        .where(eq(users.id, ctx.user.id));
    }),

  /**
   * Returns the caller's profile as currently stored in the database.
   *
   * @throws TRPCError `UNAUTHORIZED` when the session has no email or no
   *   matching user row exists.
   */
  whoami: authedProcedure
    .output(
      z.object({
        id: z.string(),
        name: z.string().nullish(),
        email: z.string().nullish(),
        timezone: z.string().nullish(),
      }),
    )
    .query(async ({ ctx }) => {
      if (!ctx.user.email) {
        throw new TRPCError({ code: "UNAUTHORIZED" });
      }
      const userDb = await ctx.db.query.users.findFirst({
        where: and(eq(users.id, ctx.user.id), eq(users.email, ctx.user.email)),
      });
      if (!userDb) {
        throw new TRPCError({ code: "UNAUTHORIZED" });
      }
      // Return the freshly loaded row rather than session data, which can be
      // stale after e.g. `changeTimezone`.
      return {
        id: userDb.id,
        name: userDb.name,
        email: userDb.email,
        timezone: userDb.timezone,
      };
    }),

  /**
   * Returns the caller's stored timezone.
   *
   * @throws TRPCError `UNAUTHORIZED` when no user row exists for the caller.
   */
  getTimezone: authedProcedure.output(z.string()).query(async ({ ctx }) => {
    const [row] = await ctx.db
      .select({ timezone: users.timezone })
      .from(users)
      .where(eq(users.id, ctx.user.id));
    if (!row) {
      throw new TRPCError({ code: "UNAUTHORIZED" });
    }
    return row.timezone;
  }),

  /**
   * Stores a new timezone for the caller and echoes it back. The value is
   * saved as given; no timezone-name validation happens here.
   */
  changeTimezone: authedProcedure
    .input(
      z.object({
        newTimezone: z.string(),
      }),
    )
    .output(z.string())
    .mutation(async ({ input, ctx }) => {
      await ctx.db
        .update(users)
        .set({
          timezone: input.newTimezone,
        })
        .where(eq(users.id, ctx.user.id));
      return input.newTimezone;
    }),
});