Skip to content

Commit

Permalink
[Network Drive] Add Dropdown for Drive type (#1842)
Browse files: browse the repository at this point in the history
  • Loading branch information
praveen-kukreja authored Oct 30, 2023
1 parent c4be35d commit c3a230e
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 45 deletions.
121 changes: 82 additions & 39 deletions connectors/sources/network_drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import fastjsonschema
import smbclient
import winrm
from requests.exceptions import ConnectionError
from smbprotocol.exceptions import SMBException, SMBOSError
from smbprotocol.file_info import (
InfoType,
Expand All @@ -38,7 +39,7 @@
AdvancedRulesValidator,
SyncRuleValidationResult,
)
from connectors.source import BaseDataSource
from connectors.source import BaseDataSource, ConfigurableFieldValueError
from connectors.utils import (
TIKA_SUPPORTED_FILETYPES,
RetryStrategy,
Expand All @@ -59,6 +60,9 @@
RETRIES = 3
RETRY_INTERVAL = 2

WINDOWS = "windows"
LINUX = "linux"


def _prefix_user(user):
    """Return ``user`` passed through ``prefix_identity`` with the "user" prefix.

    Args:
        user: The user value to prefix; forwarded unchanged to
            ``prefix_identity``.

    Returns:
        Whatever ``prefix_identity("user", user)`` returns.
    """
    prefixed_user = prefix_identity("user", user)
    return prefixed_user
Expand Down Expand Up @@ -247,6 +251,7 @@ def __init__(self, configuration):
self.server_ip = self.configuration["server_ip"]
self.port = self.configuration["server_port"]
self.drive_path = self.configuration["drive_path"]
self.drive_type = self.configuration["drive_type"]
self.identity_mappings = self.configuration["identity_mappings"]
self.session = None
self.security_info = SecurityInfo(self.username, self.password, self.server_ip)
Expand Down Expand Up @@ -297,10 +302,28 @@ def get_default_configuration(cls):
"type": "bool",
"value": False,
},
"drive_type": {
"display": "dropdown",
"label": "Drive type",
"depends_on": [
{"field": "use_document_level_security", "value": True},
],
"options": [
{"label": "Windows", "value": WINDOWS},
{"label": "Linux", "value": LINUX},
],
"order": 7,
"type": "str",
"ui_restrictions": ["advanced"],
"value": WINDOWS,
},
"identity_mappings": {
"label": "Path of CSV file containing users and groups SID (For Linux Network Drive)",
"depends_on": [{"field": "use_document_level_security", "value": True}],
"order": 7,
"depends_on": [
{"field": "use_document_level_security", "value": True},
{"field": "drive_type", "value": LINUX},
],
"order": 8,
"type": "str",
"required": False,
"ui_restrictions": ["advanced"],
Expand Down Expand Up @@ -445,6 +468,8 @@ async def get_content(self, file, timestamp=None, doit=None):
executor=None, func=partial(self.fetch_file_content, path=file["path"])
)

if not content:
return
attachment = content.read()
content.close()
return {
Expand All @@ -454,17 +479,22 @@ async def get_content(self, file, timestamp=None, doit=None):
}

def list_file_permission(self, file_path, file_type, mode, access):
with smbclient.open_file(
file_path,
mode=mode,
buffering=0,
file_type=file_type,
desired_access=access,
) as file:
descriptor = self.security_info.get_descriptor(
file_descriptor=file.fd, info=SECURITY_INFO_DACL
try:
with smbclient.open_file(
file_path,
mode=mode,
buffering=0,
file_type=file_type,
desired_access=access,
) as file:
descriptor = self.security_info.get_descriptor(
file_descriptor=file.fd, info=SECURITY_INFO_DACL
)
return descriptor.get_dacl()["aces"]
except SMBOSError as error:
self._logger.error(
f"Cannot read the contents of file on path:{file_path}. Error {error}"
)
return descriptor.get_dacl()["aces"]

def _dls_enabled(self):
if (
Expand Down Expand Up @@ -521,8 +551,8 @@ def read_user_info_csv(self):
for row in csv_reader:
user_info.append(
{
"username": row[0],
"user_id": row[1],
"name": row[0],
"user_sid": row[1],
"groups": row[2].split(",") if len(row[2]) > 0 else [],
}
)
Expand All @@ -538,36 +568,49 @@ async def get_access_control(self):
return

# This if block fetches users, groups via local csv file path
if self.identity_mappings:
self._logger.info(
f"Fetching all groups and users from configured file path '{self.identity_mappings}'"
)
if self.drive_type == LINUX:
if self.identity_mappings:
self._logger.info(
f"Fetching all groups and users from configured file path '{self.identity_mappings}'"
)

for user in self.read_user_info_csv():
yield await self._user_access_control_doc(
user["name"], user["sid"], user["groups"]
for user in self.read_user_info_csv():
yield await self._user_access_control_doc(
user=user["name"],
sid=user["user_sid"],
groups_info=user["groups"],
)
else:
raise ConfigurableFieldValueError(
"CSV file path cannot be empty. Please provide a valid csv file path."
)
else:
self._logger.info(
f"Fetching all groups and members for drive at path '{self.drive_path}'"
)
groups_info = await asyncio.to_thread(self.security_info.fetch_groups)

groups_members = {}
for group_name, _ in groups_info.items():
groups_members[group_name] = await asyncio.to_thread(
self.security_info.fetch_members, group_name
try:
self._logger.info(
f"Fetching all groups and members for drive at path '{self.drive_path}'"
)
groups_info = await asyncio.to_thread(self.security_info.fetch_groups)

self._logger.info(
f"Fetching all users for drive at path '{self.drive_path}'"
)
users_info = await asyncio.to_thread(self.security_info.fetch_users)
groups_members = {}
for group_name, _ in groups_info.items():
groups_members[group_name] = await asyncio.to_thread(
self.security_info.fetch_members, group_name
)

for user, sid in users_info.items():
yield await self._user_access_control_doc(
user, sid, groups_info, groups_members
self._logger.info(
f"Fetching all users for drive at path '{self.drive_path}'"
)
users_info = await asyncio.to_thread(self.security_info.fetch_users)

for user, sid in users_info.items():
yield await self._user_access_control_doc(
user=user,
sid=sid,
groups_info=groups_info,
groups_members=groups_members,
)
except ConnectionError as exception:
raise ConnectionError("Something went wrong") from exception

async def get_entity_permission(self, file_path, file_type):
if not self._dls_enabled():
Expand All @@ -590,7 +633,7 @@ async def get_entity_permission(self, file_path, file_type):
mode="br",
access=DirectoryAccessMask.READ_CONTROL,
)
for permission in list_permissions:
for permission in list_permissions or []:
if (
permission["ace_type"].value == ACCESS_ALLOWED_TYPE
or permission["mask"].value == ACCESS_MASK_DENIED_WRITE_PERMISSION
Expand Down
17 changes: 15 additions & 2 deletions tests/sources/fixtures/network_drive/connector.json
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,23 @@
"order": 6,
"ui_restrictions": []
},
"drive_type":{
"display": "dropdown",
"label": "Drive type",
"depends_on": [{"field": "use_document_level_security", "value": true}],
"options": [
{"label": "Windows", "value": "windows"},
{"label": "Linux", "value": "linux"}
],
"order":7,
"type": "str",
"ui_restrictions": ["advanced"],
"value":"windows"
},
"identity_mappings": {
"label": "Path of CSV file containing users and groups SID (For Linux Network Drive)",
"depends_on": [{"field": "use_document_level_security", "value": true}],
"order": 7,
"depends_on": [{"field": "use_document_level_security", "value": true}, {"field": "drive_type", "value": "linux"}],
"order": 8,
"type": "str",
"required": false,
"ui_restrictions": ["advanced"],
Expand Down
39 changes: 35 additions & 4 deletions tests/sources/test_network_drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from connectors.filtering.validation import SyncRuleValidationResult
from connectors.protocol import Filter
from connectors.source import ConfigurableFieldValueError
from connectors.sources.network_drive import (
NASDataSource,
NetworkDriveAdvancedRulesValidator,
Expand All @@ -30,6 +31,9 @@
MAX_CHUNK_SIZE = 65536
ADVANCED_SNIPPET = "advanced_snippet"

WINDOWS = "windows"
LINUX = "linux"


def mock_permission(sid):
mock_response = {}
Expand Down Expand Up @@ -702,6 +706,16 @@ async def test_get_access_control_dls_disabled():
assert len(acl) == 0


@pytest.mark.asyncio
async def test_get_access_control_linux_empty_csv_file_path():
    """A Linux drive with DLS enabled and an empty CSV path must be rejected.

    Requesting the first access-control document is expected to raise
    ``ConfigurableFieldValueError`` when ``identity_mappings`` is empty.
    """
    async with create_source(NASDataSource) as source:
        # Force the DLS branch and select the Linux (CSV-based) code path.
        source._dls_enabled = MagicMock(return_value=True)
        source.identity_mappings = ""
        source.drive_type = LINUX

        with pytest.raises(ConfigurableFieldValueError):
            # The generator body only runs once its first item is requested.
            access_control_docs = source.get_access_control()
            await anext(access_control_docs)


@pytest.mark.asyncio
async def test_get_access_control_dls_enabled():
expected_user_access_control = [
Expand All @@ -718,6 +732,7 @@ async def test_get_access_control_dls_enabled():

async with create_source(NASDataSource) as source:
source._dls_enabled = MagicMock(return_value=True)
source.drive_type = WINDOWS
mock_groups = {"Admins": "S-1-5-32-546"}
mock_group_members = {
"Administrator": "S-1-5-21-227823342-1368486282-703244805-500"
Expand Down Expand Up @@ -890,8 +905,8 @@ async def test_read_csv_with_valid_data():
):
user_info = source.read_user_info_csv()
expected_user_info = [
{"username": "user1", "user_id": "S-1", "groups": ["S-11", "S-22"]},
{"username": "user2", "user_id": "S-2", "groups": ["S-22"]},
{"name": "user1", "user_sid": "S-1", "groups": ["S-11", "S-22"]},
{"name": "user2", "user_sid": "S-2", "groups": ["S-22"]},
]
assert user_info == expected_user_info

Expand All @@ -913,8 +928,8 @@ async def test_read_csv_with_empty_groups():
):
user_info = source.read_user_info_csv()
expected_user_info = [
{"username": "user1", "user_id": "1", "groups": []},
{"username": "user2", "user_id": "2", "groups": []},
{"name": "user1", "user_sid": "1", "groups": []},
{"name": "user2", "user_sid": "2", "groups": []},
]
assert user_info == expected_user_info

Expand All @@ -938,3 +953,19 @@ async def test_list_file_permissions(mock_get_descriptor):
)

assert result == mock_dacl["aces"]


@pytest.mark.asyncio
async def test_list_file_permissions_with_inaccessible_file():
    """``list_file_permission`` returns ``None`` when opening the file fails.

    ``smbclient.open_file`` is patched to raise ``SMBOSError``; the test
    asserts that the call yields ``None`` instead of propagating the error.
    """
    open_failure = SMBOSError(ntstatus=0xC0000043, filename="file1.txt")

    # Every call to the patched ``open_file`` raises ``open_failure``.
    with mock.patch(
        "smbclient.open_file",
        return_value=MagicMock(),
        side_effect=open_failure,
    ):
        async with create_source(NASDataSource) as source:
            permissions = source.list_file_permission(
                file_path="/path/to/file.txt",
                file_type="file",
                mode="rb",
                access="read",
            )

            assert permissions is None

0 comments on commit c3a230e

Please sign in to comment.