mirror of
https://mirror.skon.top/github.com/langgenius/dify.git
synced 2026-04-20 15:20:15 +08:00
fix: scope plugin inner API end-user lookup by tenant (#35325)
This commit is contained in:
@@ -20,10 +20,13 @@ class TenantUserPayload(BaseModel):
|
||||
|
||||
def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
"""
|
||||
Get current user
|
||||
Get current user.
|
||||
|
||||
NOTE: user_id is not trusted, it could be maliciously set to any value.
|
||||
As a result, it could only be considered as an end user id.
|
||||
As a result, it could only be considered as an end user id. Even when a
|
||||
concrete end-user ID is supplied, lookups must stay tenant-scoped so one
|
||||
tenant cannot bind another tenant's user record into the plugin request
|
||||
context.
|
||||
"""
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
@@ -42,7 +45,14 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
.limit(1)
|
||||
)
|
||||
else:
|
||||
user_model = session.get(EndUser, user_id)
|
||||
user_model = session.scalar(
|
||||
select(EndUser)
|
||||
.where(
|
||||
EndUser.id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not user_model:
|
||||
user_model = EndUser(
|
||||
|
||||
@@ -41,17 +41,22 @@ class TestTenantUserPayload:
|
||||
class TestGetUser:
|
||||
"""Test get_user function"""
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.select")
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.sessionmaker")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_return_existing_user_by_id(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask):
|
||||
def test_should_return_existing_user_by_id(
|
||||
self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask
|
||||
):
|
||||
"""Test returning existing user when found by ID"""
|
||||
# Arrange
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "user123"
|
||||
mock_session = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_session.get.return_value = mock_user
|
||||
mock_session.scalar.return_value = mock_user
|
||||
mock_query = MagicMock()
|
||||
mock_select.return_value.where.return_value.limit.return_value = mock_query
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
@@ -59,13 +64,45 @@ class TestGetUser:
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
mock_session.get.assert_called_once()
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.select")
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.sessionmaker")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_not_resolve_non_anonymous_users_across_tenants(
|
||||
self,
|
||||
mock_db,
|
||||
mock_sessionmaker,
|
||||
mock_enduser_class,
|
||||
mock_select,
|
||||
app: Flask,
|
||||
):
|
||||
"""Test that explicit user IDs remain scoped to the current tenant."""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = None
|
||||
mock_new_user = MagicMock()
|
||||
mock_new_user.tenant_id = "tenant-current"
|
||||
mock_enduser_class.return_value = mock_new_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
result = get_user("tenant-current", "foreign-user-id")
|
||||
|
||||
# Assert
|
||||
assert result == mock_new_user
|
||||
mock_session.get.assert_not_called()
|
||||
mock_session.scalar.assert_called_once()
|
||||
mock_session.add.assert_called_once_with(mock_new_user)
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.select")
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.sessionmaker")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_return_existing_anonymous_user_by_session_id(
|
||||
self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask
|
||||
self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask
|
||||
):
|
||||
"""Test returning existing anonymous user by session_id"""
|
||||
# Arrange
|
||||
@@ -73,8 +110,9 @@ class TestGetUser:
|
||||
mock_user.session_id = "anonymous_session"
|
||||
mock_session = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
# non-anonymous path uses session.get(); anonymous uses session.scalar()
|
||||
mock_session.get.return_value = mock_user
|
||||
mock_session.scalar.return_value = mock_user
|
||||
mock_query = MagicMock()
|
||||
mock_select.return_value.where.return_value.limit.return_value = mock_query
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
@@ -83,17 +121,22 @@ class TestGetUser:
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.select")
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.sessionmaker")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_create_new_user_when_not_found(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask):
|
||||
def test_should_create_new_user_when_not_found(
|
||||
self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask
|
||||
):
|
||||
"""Test creating new user when not found in database"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_session.get.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
mock_new_user = MagicMock()
|
||||
mock_enduser_class.return_value = mock_new_user
|
||||
mock_query = MagicMock()
|
||||
mock_select.return_value.where.return_value.limit.return_value = mock_query
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
@@ -134,7 +177,7 @@ class TestGetUser:
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_session.get.side_effect = Exception("Database error")
|
||||
mock_session.scalar.side_effect = Exception("Database error")
|
||||
|
||||
# Act & Assert
|
||||
with app.app_context():
|
||||
|
||||
Reference in New Issue
Block a user