99from database import get_db
1010from schemas import PostCreate , PostResponse , PostUpdate
1111
12+ from auth import CurrentUser
13+
1214router = APIRouter ()
1315
1416@router .get ("" , response_model = list [PostResponse ])
@@ -27,21 +29,12 @@ async def get_posts(db: Annotated[AsyncSession, Depends(get_db)]):
2729 response_model = PostResponse ,
2830 status_code = status .HTTP_201_CREATED ,
2931)
30- async def create_post (post : PostCreate , db : Annotated [AsyncSession , Depends (get_db )]):
31- result = await db .execute (
32- select (models .User ).where (models .User .id == post .user_id )
33- )
34- user = result .scalars ().first ()
35- if not user :
36- raise HTTPException (
37- status_code = status .HTTP_404_NOT_FOUND ,
38- detail = "User not found" ,
39- )
32+ async def create_post (post : PostCreate , current_user : CurrentUser , db : Annotated [AsyncSession , Depends (get_db )]):
4033
4134 new_post = models .Post (
4235 title = post .title ,
4336 content = post .content ,
44- user_id = post . user_id ,
37+ user_id = current_user . id ,
4538 )
4639 db .add (new_post )
4740 await db .commit ()
@@ -66,6 +59,7 @@ async def get_post(post_id: int, db: Annotated[AsyncSession, Depends(get_db)]):
6659async def update_post_full (
6760 post_id : int ,
6861 post_data : PostCreate ,
62+ current_user : CurrentUser ,
6963 db : Annotated [AsyncSession , Depends (get_db )],
7064):
7165 result = await db .execute (select (models .Post ).where (models .Post .id == post_id ))
@@ -75,20 +69,15 @@ async def update_post_full(
7569 status_code = status .HTTP_404_NOT_FOUND ,
7670 detail = "Post not found" ,
7771 )
78- if post_data .user_id != post .user_id :
79- result = await db .execute (
80- select (models .User ).where (models .User .id == post_data .user_id ),
72+
73+ if post .user_id != current_user .id :
74+ raise HTTPException (
75+ status_code = status .HTTP_403_FORBIDDEN ,
76+ detail = "Not authorized to update this post" ,
8177 )
82- user = result .scalars ().first ()
83- if not user :
84- raise HTTPException (
85- status_code = status .HTTP_404_NOT_FOUND ,
86- detail = "User not found" ,
87- )
8878
8979 post .title = post_data .title
9080 post .content = post_data .content
91- post .user_id = post_data .user_id
9281
9382 await db .commit ()
9483 await db .refresh (post , attribute_names = ["author" ])
@@ -99,6 +88,7 @@ async def update_post_full(
9988async def update_post_partial (
10089 post_id : int ,
10190 post_data : PostUpdate ,
91+ current_user : CurrentUser ,
10292 db : Annotated [AsyncSession , Depends (get_db )],
10393):
10494 result = await db .execute (select (models .Post ).where (models .Post .id == post_id ))
@@ -109,6 +99,12 @@ async def update_post_partial(
10999 detail = "Post not found" ,
110100 )
111101
102+ if post .user_id != current_user .id :
103+ raise HTTPException (
104+ status_code = status .HTTP_403_FORBIDDEN ,
105+ detail = "Not authorized to update this post" ,
106+ )
107+
112108 update_data = post_data .model_dump (exclude_unset = True )
113109 for field , value in update_data .items ():
114110 setattr (post , field , value )
@@ -119,7 +115,10 @@ async def update_post_partial(
119115
120116
121117@router .delete ("/{post_id}" , status_code = status .HTTP_204_NO_CONTENT )
122- async def delete_post (post_id : int , db : Annotated [AsyncSession , Depends (get_db )]):
118+ async def delete_post (
119+ post_id : int ,
120+ current_user : CurrentUser ,
121+ db : Annotated [AsyncSession , Depends (get_db )]):
123122 result = await db .execute (select (models .Post ).where (models .Post .id == post_id ))
124123 post = result .scalars ().first ()
125124 if not post :
@@ -128,5 +127,11 @@ async def delete_post(post_id: int, db: Annotated[AsyncSession, Depends(get_db)]
128127 detail = "Post not found" ,
129128 )
130129
130+ if post .user_id != current_user .id :
131+ raise HTTPException (
132+ status_code = status .HTTP_403_FORBIDDEN ,
133+ detail = "Not authorized to delete this post" ,
134+ )
135+
131136 await db .delete (post )
132137 await db .commit ()
0 commit comments