Skip to content

Commit 3bb06b8

Browse files
authored
Add user avatar uploads (#687)
* Add user avatar uploads * Update empire/test/test_user_api.py
1 parent c98d30c commit 3bb06b8

File tree

13 files changed

+191
-38
lines changed

13 files changed

+191
-38
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
- Add avatars to users (@Vinnybod)
1011
- Update plugin documentation, update embedded plugins to not abuse notifications (@Vinnybod)
1112
- Add additional pre-commit hooks for code cleanup (@Vinnybod)
1213
- Report test coverage on pull requests (@Vinnybod)

empire/server/api/v2/user/user_api.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from datetime import timedelta
22

3-
from fastapi import Depends, HTTPException
3+
from fastapi import Depends, File, HTTPException, UploadFile
44
from fastapi.security import OAuth2PasswordRequestForm
55
from sqlalchemy.orm import Session
66
from starlette import status
@@ -160,3 +160,21 @@ async def update_user_password(
160160
raise HTTPException(status_code=400, detail=err)
161161

162162
return domain_to_dto_user(resp)
163+
164+
165+
@router.post("/api/v2/users/{uid}/avatar", status_code=201)
166+
async def create_avatar(
167+
uid: int,
168+
db: Session = Depends(get_db),
169+
user: models.User = Depends(get_current_active_user),
170+
file: UploadFile = File(...),
171+
):
172+
if not user.id == uid:
173+
raise HTTPException(
174+
status_code=403, detail="User does not have access to update this resource."
175+
)
176+
177+
if not file.content_type.startswith("image/"):
178+
raise HTTPException(status_code=400, detail="File must be an image.")
179+
180+
user_service.update_user_avatar(db, user, file)

empire/server/api/v2/user/user_dto.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
11
from datetime import datetime
2-
from typing import List
2+
from typing import List, Optional
33

44
from pydantic import BaseModel
55

6+
from empire.server.api.v2.shared_dto import (
7+
DownloadDescription,
8+
domain_to_dto_download_description,
9+
)
10+
611

712
def domain_to_dto_user(user):
13+
if user.avatar:
14+
download_description = domain_to_dto_download_description(user.avatar)
15+
else:
16+
download_description = None
817
return User(
918
id=user.id,
1019
username=user.username,
1120
enabled=user.enabled,
1221
is_admin=user.admin,
1322
created_at=user.created_at,
1423
updated_at=user.updated_at,
24+
avatar=download_description,
1525
)
1626

1727

@@ -20,6 +30,7 @@ class User(BaseModel):
2030
username: str
2131
enabled: bool
2232
is_admin: bool
33+
avatar: Optional[DownloadDescription]
2334
created_at: datetime
2435
updated_at: datetime
2536

empire/server/common/empire.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,14 @@ def __init__(self, args=None):
7575

7676
self.listenertemplatesv2 = ListenerTemplateService(self)
7777
self.stagertemplatesv2 = StagerTemplateService(self)
78-
self.usersv2 = UserService(self)
7978
self.bypassesv2 = BypassService(self)
8079
self.obfuscationv2 = ObfuscationService(self)
8180
self.profilesv2 = ProfileService(self)
8281
self.credentialsv2 = CredentialService(self)
8382
self.hostsv2 = HostService(self)
8483
self.processesv2 = HostProcessService(self)
8584
self.downloadsv2 = DownloadService(self)
85+
self.usersv2 = UserService(self)
8686
self.listenersv2 = ListenerService(self)
8787
self.stagersv2 = StagerService(self)
8888
self.modulesv2 = ModuleService(self)

empire/server/core/db/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ class User(Base):
141141
updated_at = Column(
142142
UtcDateTime, default=utcnow(), onupdate=utcnow(), nullable=False
143143
)
144+
avatar = relationship("Download")
145+
avatar_id = Column(Integer, ForeignKey("downloads.id"), nullable=True)
144146

145147
def __repr__(self):
146148
return "<User(username='%s')>" % (self.username)

empire/server/core/user_service.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1+
from fastapi import UploadFile
12
from sqlalchemy.orm import Session
23

34
from empire.server.core.db import models
5+
from empire.server.core.download_service import DownloadService
46

57

68
class UserService(object):
79
def __init__(self, main_menu):
810
self.main_menu = main_menu
11+
self.download_service: DownloadService = main_menu.downloadsv2
912

1013
@staticmethod
1114
def get_all(db: Session):
@@ -57,3 +60,8 @@ def update_user_password(db: Session, db_user: models.User, hashed_password: str
5760
db.flush()
5861

5962
return db_user, None
63+
64+
def update_user_avatar(self, db: Session, db_user: models.User, file: UploadFile):
65+
download = self.download_service.create_download(db, db_user, file)
66+
67+
db_user.avatar = download

empire/test/avatar.png

856 Bytes
Loading

empire/test/avatar2.png

856 Bytes
Loading

empire/test/conftest.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from fastapi import FastAPI
1010
from starlette.testclient import TestClient
1111

12+
from empire.client.src.utils.data_util import get_random_string
13+
1214
SERVER_CONFIG_LOC = "empire/test/test_server_config.yaml"
1315
CLIENT_CONFIG_LOC = "empire/test/test_client_config.yaml"
1416
DEFAULT_ARGV = ["", "server", "--config", SERVER_CONFIG_LOC]
@@ -431,15 +433,18 @@ def credential(client, admin_auth_header):
431433
json={
432434
"credtype": "hash",
433435
"domain": "the-domain",
434-
"username": "user",
435-
"password": "hunter2",
436+
"username": get_random_string(8),
437+
"password": get_random_string(8),
436438
"host": "host1",
437439
},
438440
)
439441

440442
yield resp.json()["id"]
441443

442-
client.delete(f"/api/v2/credentials/{resp.json()['id']}", headers=admin_auth_header)
444+
with suppress(Exception):
445+
client.delete(
446+
f"/api/v2/credentials/{resp.json()['id']}", headers=admin_auth_header
447+
)
443448

444449

445450
@pytest.fixture(scope="function")

empire/test/test_credential_api.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test_create_credential(client, admin_auth_header, base_credential):
2020
)
2121

2222
assert response.status_code == 201
23-
assert response.json()["id"] == 1
23+
assert response.json()["id"] > 0
2424
assert response.json()["credtype"] == "hash"
2525
assert response.json()["domain"] == "the-domain"
2626
assert response.json()["username"] == "user"
@@ -49,7 +49,7 @@ def test_update_credential_not_found(client, admin_auth_header, base_credential)
4949

5050

5151
def test_update_credential_unique_constraint_failure(
52-
client, admin_auth_header, base_credential
52+
client, admin_auth_header, base_credential, credential
5353
):
5454
credential_2 = copy.deepcopy(base_credential)
5555
credential_2["domain"] = "the-domain-2"
@@ -59,19 +59,26 @@ def test_update_credential_unique_constraint_failure(
5959
assert response.status_code == 201
6060

6161
response = client.put(
62-
"/api/v2/credentials/2", headers=admin_auth_header, json=base_credential
62+
f"/api/v2/credentials/{credential}",
63+
headers=admin_auth_header,
64+
json=base_credential,
6365
)
6466

6567
assert response.status_code == 400
6668
assert response.json()["detail"] == "Credential not updated. Duplicate detected."
6769

6870

69-
def test_update_credential(client, admin_auth_header, base_credential):
70-
updated_credential = base_credential
71+
def test_update_credential(client, admin_auth_header, credential):
72+
response = client.get(
73+
f"/api/v2/credentials/{credential}", headers=admin_auth_header
74+
)
75+
updated_credential = response.json()
7176
updated_credential["domain"] = "new-domain"
7277
updated_credential["password"] = "password3"
7378
response = client.put(
74-
"/api/v2/credentials/1", headers=admin_auth_header, json=updated_credential
79+
f"/api/v2/credentials/{updated_credential['id']}",
80+
headers=admin_auth_header,
81+
json=updated_credential,
7582
)
7683

7784
assert response.status_code == 200
@@ -86,11 +93,13 @@ def test_get_credential_not_found(client, admin_auth_header):
8693
assert response.json()["detail"] == "Credential not found for id 9999"
8794

8895

89-
def test_get_credential(client, admin_auth_header):
90-
response = client.get("/api/v2/credentials/1", headers=admin_auth_header)
96+
def test_get_credential(client, admin_auth_header, credential):
97+
response = client.get(
98+
f"/api/v2/credentials/{credential}", headers=admin_auth_header
99+
)
91100

92101
assert response.status_code == 200
93-
assert response.json()["id"] == 1
102+
assert response.json()["id"] > 0
94103

95104

96105
def test_get_credentials(client, admin_auth_header):
@@ -100,12 +109,18 @@ def test_get_credentials(client, admin_auth_header):
100109
assert len(response.json()["records"]) > 0
101110

102111

103-
def test_get_credentials_search(client, admin_auth_header):
104-
response = client.get("/api/v2/credentials?search=hunt", headers=admin_auth_header)
112+
def test_get_credentials_search(client, admin_auth_header, credential):
113+
response = client.get(
114+
f"/api/v2/credentials/{credential}", headers=admin_auth_header
115+
)
116+
password = response.json()["password"]
117+
response = client.get(
118+
f"/api/v2/credentials?search={password[:3]}", headers=admin_auth_header
119+
)
105120

106121
assert response.status_code == 200
107122
assert len(response.json()["records"]) == 1
108-
assert response.json()["records"][0]["password"] == "hunter2"
123+
assert response.json()["records"][0]["password"] == password
109124

110125
response = client.get(
111126
"/api/v2/credentials?search=qwerty", headers=admin_auth_header
@@ -115,11 +130,15 @@ def test_get_credentials_search(client, admin_auth_header):
115130
assert len(response.json()["records"]) == 0
116131

117132

118-
def test_delete_credential(client, admin_auth_header):
119-
response = client.delete("/api/v2/credentials/1", headers=admin_auth_header)
133+
def test_delete_credential(client, admin_auth_header, credential):
134+
response = client.delete(
135+
f"/api/v2/credentials/{credential}", headers=admin_auth_header
136+
)
120137

121138
assert response.status_code == 204
122139

123-
response = client.get("/api/v2/credentials/1", headers=admin_auth_header)
140+
response = client.get(
141+
f"/api/v2/credentials/{credential}", headers=admin_auth_header
142+
)
124143

125144
assert response.status_code == 404

empire/test/test_download_api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,10 @@ def test_download_download(client, admin_auth_header):
7373
)
7474

7575
assert response.status_code == 200
76-
assert response.headers.get("content-disposition").startswith(
76+
assert response.headers.get("content-disposition").lower().startswith(
7777
'attachment; filename="test-upload-2'
78+
) or response.headers.get("content-disposition").lower().startswith(
79+
"attachment; filename*=utf-8''test-upload-2"
7880
)
7981

8082

empire/test/test_plugin_task_api.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,23 @@
22

33

44
@pytest.fixture(scope="module", autouse=True)
5-
def plugin_task_1(main, db, models, plugin_name):
6-
db.add(
7-
models.PluginTask(
5+
def plugin_task_1(main, session_local, models, plugin_name):
6+
with session_local.begin() as db:
7+
task = models.PluginTask(
88
plugin_id=plugin_name,
99
input="This is the trimmed input for the task.",
1010
input_full="This is the full input for the task.",
1111
user_id=1,
1212
)
13-
)
14-
db.commit()
15-
yield
13+
db.add(task)
14+
db.flush()
15+
16+
task_id = task.id
17+
18+
yield task_id
1619

17-
db.query(models.PluginTask).delete()
18-
db.commit()
20+
with session_local.begin() as db:
21+
db.query(models.PluginTask).delete()
1922

2023

2124
def test_get_tasks_for_plugin_not_found(client, admin_auth_header):
@@ -60,10 +63,11 @@ def test_get_task_for_plugin_not_found(client, admin_auth_header, plugin_name):
6063
)
6164

6265

63-
def test_get_task_for_plugin(client, admin_auth_header, plugin_name, db):
66+
def test_get_task_for_plugin(client, admin_auth_header, plugin_name, db, plugin_task_1):
6467
response = client.get(
65-
f"/api/v2/plugins/{plugin_name}/tasks/1", headers=admin_auth_header
68+
f"/api/v2/plugins/{plugin_name}/tasks/{plugin_task_1}",
69+
headers=admin_auth_header,
6670
)
6771
assert response.status_code == 200
68-
assert response.json()["id"] == 1
72+
assert response.json()["id"] == plugin_task_1
6973
assert response.json()["plugin_id"] == plugin_name

0 commit comments

Comments
 (0)