[Misc] Refactor: a better storage folder structure#93
Conversation
Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
|
Important Review skippedAuto reviews are disabled on base/target branches other than the default branch. Please check the settings in the CodeRabbit UI or the You can disable this status message by setting the Note Other AI code review bot(s) detectedCodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review. WalkthroughThis PR refactors the storage layer from a monolithic Changes
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes
Possibly related PRs
Poem
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Comment |
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
There was a problem hiding this comment.
Pull Request Overview
This PR refactors the storage module by splitting a monolithic storage.py file (1304 lines) into a well-organized package structure with separate modules for storage units, managers, clients, and factories. Additionally, it standardizes UUID generation across the codebase to use 8-character hex strings for consistency.
Key changes:
- Refactored monolithic storage file into a modular package structure
- Standardized UUID format to
uuid4().hex[:8]across all components for consistent 8-character IDs - Updated import paths in tests to reflect the new module organization
Reviewed Changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| transfer_queue/utils/zmq_utils.py | Updated import and UUID format for request IDs to 8-character hex |
| transfer_queue/storage/simple_backend.py | New file containing StorageUnitData, SimpleStorageUnit, and StorageMetaGroup classes |
| transfer_queue/storage/managers/base.py | New file with TransferQueueStorageManager base class |
| transfer_queue/storage/managers/simple_backend_manager.py | New file with AsyncSimpleStorageManager and helper functions |
| transfer_queue/storage/managers/factory.py | New file with TransferQueueStorageManagerFactory |
| transfer_queue/storage/managers/init.py | Module exports for managers package |
| transfer_queue/storage/clients/factory.py | Empty factory file for future client implementations |
| transfer_queue/storage/clients/init.py | Module exports with placeholder comment |
| transfer_queue/storage/init.py | Main storage package exports |
| transfer_queue/storage.py | Deleted monolithic file (refactored into package) |
| transfer_queue/controller.py | Updated UUID format for controller IDs |
| transfer_queue/client.py | Updated UUID format for socket identities |
| tests/test_client.py | Updated import path for AsyncSimpleStorageManager |
| tests/test_async_storage_manager.py | Updated import paths and fixed variable name typo |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
|
||
| context = zmq.asyncio.Context() | ||
| address = f"tcp://{server_info.ip}:{server_info.ports.get(socket_name)}" | ||
| identity = f"{self.storage_manager_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() |
There was a problem hiding this comment.
The socket identity construction is inconsistent with the pattern used in client.py line 118. Both create dynamic identities for ZMQ DEALER sockets communicating with the same type of servers. Consider extracting this logic into a shared utility function to avoid duplication and ensure consistency.
| except Exception: | ||
| pass |
There was a problem hiding this comment.
'except' clause does nothing but pass and there is no explanatory comment.
| except Exception: | |
| pass | |
| except Exception as e: | |
| logger.error(f"[{self.storage_manager_id}]: Error closing socket {sock}: {str(e)}") |
| except Exception: | ||
| pass | ||
| try: | ||
| if hasattr(self, "zmq_context") and self.zmq_context: | ||
| self.zmq_context.term() | ||
| except Exception: | ||
| pass |
There was a problem hiding this comment.
'except' clause does nothing but pass and there is no explanatory comment.
| except Exception: | |
| pass | |
| try: | |
| if hasattr(self, "zmq_context") and self.zmq_context: | |
| self.zmq_context.term() | |
| except Exception: | |
| pass | |
| except Exception as e: | |
| logger.error(f"[{self.storage_manager_id}]: Error closing socket: {str(e)}") | |
| try: | |
| if hasattr(self, "zmq_context") and self.zmq_context: | |
| self.zmq_context.term() | |
| except Exception as e: | |
| logger.error(f"[{self.storage_manager_id}]: Error terminating zmq_context: {str(e)}") |
There was a problem hiding this comment.
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
transfer_queue/client.py (1)
134-141: Fix socket teardown: double-close and conflicting LINGER.Currently closes the socket twice and sets LINGER to -1 then 0. Use a single close with linger=0.
- if not sock.closed: - sock.setsockopt(zmq.LINGER, -1) - sock.close() - sock.close(linger=0) + if not sock.closed: + sock.close(linger=0)tests/test_async_simple_storage_manager.py (1)
68-88: Typo: assigns_connect_to_controllersinstead of_connect_to_controller.This shadow attribute won’t be used and may mislead future readers. Rename to match the actual method.
- manager._connect_to_controllers = mock_connect + manager._connect_to_controller = mock_connect
🧹 Nitpick comments (14)
transfer_queue/storage/simple_backend.py (1)
205-247: Add graceful shutdown hooks for the ZMQ loop (optional).The infinite poll loop has no exit path; consider adding a stop flag and closing sockets/context to avoid leaks on actor teardown. Also set LINGER=0 on sockets during shutdown to prevent hangs.
tests/test_client.py (2)
270-277: Use a single, consistent patch target for AsyncSimpleStorageManager.You already patched the new path below. Align this earlier patch to the same module to avoid breakage if re-exports change.
- with patch("transfer_queue.storage.AsyncSimpleStorageManager._connect_to_controller"): + with patch("transfer_queue.storage.managers.simple_backend_manager.AsyncSimpleStorageManager._connect_to_controller"):
164-171: Nit: prefer TransferQueueRole enum for role.
role="TransferQueueStorage"is a bare string while the rest of the code usesTransferQueueRole. Using the enum improves consistency and type safety.- self.zmq_server_info = ZMQServerInfo( - role="TransferQueueStorage", + self.zmq_server_info = ZMQServerInfo( + role=TransferQueueRole.STORAGE,transfer_queue/client.py (2)
116-121: Optional: reuse asyncio ZMQ Context to reduce overhead.Per-call
zmq.asyncio.Context()allocation is heavier than necessary. Considerzmq.asyncio.Context.instance().- context = zmq.asyncio.Context() + context = zmq.asyncio.Context.instance()
80-88: Docstring config keys don’t match actual usage.The docstring lists
data_system_controller_infoanddata_system_storage_unit_infos, but code/tests usecontroller_infoandstorage_unit_infos. Update to avoid confusion.- config: Configuration dictionary for the storage manager. Must contain the - following required keys: - - data_system_controller_info: ZMQ server information about the controller - - data_system_storage_unit_infos: ZMQ server information about the storage units + config: Configuration dictionary for the storage manager. Must contain: + - controller_info: ZMQServerInfo for the controller + - storage_unit_infos: dict[str, ZMQServerInfo] for storage unitstransfer_queue/storage/managers/__init__.py (1)
15-23: Re‑exports aligned with new layout.Public surface looks clean and centralized. Consider making all a tuple to avoid accidental mutation.
transfer_queue/storage/__init__.py (1)
15-25: Top‑level exports look good.Consolidated API is clear. Optional: use a tuple for all.
transfer_queue/storage/managers/base.py (1)
52-54: Add explicit teardown to avoid ZMQ resource leaks.Context and sockets are never closed. Provide a
close()that closes sockets and terms the context; optionally call it from__del__or expose context manager semantics.class TransferQueueStorageManager(ABC): @@ self._connect_to_controller() + def close(self) -> None: + try: + if self.controller_handshake_socket is not None: + self.controller_handshake_socket.close(0) + self.controller_handshake_socket = None + if self.data_status_update_socket is not None: + self.data_status_update_socket.close(0) + self.data_status_update_socket = None + if self.zmq_context is not None: + self.zmq_context.term() + self.zmq_context = None + except Exception: + logger.exception("Error while closing storage manager resources")transfer_queue/utils/zmq_utils.py (1)
107-114: Remove redundant str() and consider 16-char request_id for better collision resistance.The
str()wrapper is unnecessary sinceuuid4().hexalready returns a string. At scale, 8 hex chars (~32 bits) risks collisions; 16 chars (~64 bits) is safer while remaining practical. Since no tests or code validate the format, this is safe to change.Apply:
- request_id=str(uuid4().hex[:8]), + request_id=uuid4().hex[:16],transfer_queue/storage/managers/simple_backend_manager.py (5)
52-61: Guard zero-length and freeze IDs for mapping lambdasPrecompute IDs and length; avoid repeated list(self.storage_unit_infos.keys()) and potential ZeroDivisionError.
- def _build_storage_mapping_functions(self): + def _build_storage_mapping_functions(self): """Build mapping functions for global index to storage unit and local index. @@ - self.global_index_storage_unit_mapping = lambda x: list(self.storage_unit_infos.keys())[ - x % len(self.storage_unit_infos) - ] - self.global_index_local_index_mapping = lambda x: x // len(self.storage_unit_infos) + storage_ids = tuple(self.storage_unit_infos.keys()) + n = len(storage_ids) + if n == 0: + raise ValueError("No storage units registered.") + self.global_index_storage_unit_mapping = lambda x, _ids=storage_ids, _n=n: _ids[x % _n] + self.global_index_local_index_mapping = lambda x, _n=n: x // _n
120-153: Reuse a shared zmq Context instead of per-call ContextsPer-call Context creation/termination is expensive; prefer self.zmq_context = zmq.asyncio.Context.instance() in init and reuse. Terminate in close().
Would you like me to patch init to create self.zmq_context and update the decorator to use it (and skip context.term() in the wrapper)?
282-296: Simplify stacking logic; avoid nested.as_nested_tensor(...).unbind() + stackUse torch.stack when shapes match; fallback to nested tensor or non-tensor stack. Clearer and faster.
- with limit_pytorch_auto_parallel_threads(): - tensor_data = { - field: ( - torch.stack(torch.nested.as_nested_tensor(v).unbind()) - if v - and all(isinstance(item, torch.Tensor) for item in v) - and all(item.shape == v[0].shape for item in v) - else ( - torch.nested.as_nested_tensor(v) - if v and all(isinstance(item, torch.Tensor) for item in v) - else NonTensorStack(*v) - ) - ) - for field, v in ordered_data.items() - } + with limit_pytorch_auto_parallel_threads(): + tensor_data = {} + for field, v in ordered_data.items(): + if not v: + tensor_data[field] = NonTensorStack(*v) + continue + if all(isinstance(item, torch.Tensor) for item in v): + if all(item.shape == v[0].shape for item in v): + tensor_data[field] = torch.stack(v) + else: + tensor_data[field] = torch.nested.as_nested_tensor(v) + else: + tensor_data[field] = NonTensorStack(*v)
317-321: Avoid logging full GET responses (may include payload/PII, large logs)Log only counts/fields at INFO; move verbose to DEBUG.
- logger.info( - f"[{self.storage_manager_id}]: get data response from storage unit " - f"{target_storage_unit}: {response_msg}" - ) + logger.info( + f"[{self.storage_manager_id}]: GET_DATA_RESPONSE from {target_storage_unit} " + f"(fields={fields}, items={len(local_indexes)})" + ) + logger.debug( + f"[{self.storage_manager_id}]: full GET response from {target_storage_unit}: {response_msg}" + )
32-34: Harden logging level parsing from envsetLevel may receive a string; normalize safely to avoid invalid levels.
-logger = logging.getLogger(__name__) -logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) +logger = logging.getLogger(__name__) +_lvl = os.getenv("TQ_LOGGING_LEVEL", "WARNING") +logger.setLevel(getattr(logging, str(_lvl).upper(), logging.WARNING))
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (14)
tests/test_async_simple_storage_manager.py(3 hunks)tests/test_client.py(1 hunks)transfer_queue/client.py(1 hunks)transfer_queue/controller.py(1 hunks)transfer_queue/storage.py(0 hunks)transfer_queue/storage/__init__.py(1 hunks)transfer_queue/storage/clients/__init__.py(1 hunks)transfer_queue/storage/clients/factory.py(1 hunks)transfer_queue/storage/managers/__init__.py(1 hunks)transfer_queue/storage/managers/base.py(1 hunks)transfer_queue/storage/managers/factory.py(1 hunks)transfer_queue/storage/managers/simple_backend_manager.py(1 hunks)transfer_queue/storage/simple_backend.py(1 hunks)transfer_queue/utils/zmq_utils.py(2 hunks)
💤 Files with no reviewable changes (1)
- transfer_queue/storage.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: pre-commit (3.10)
- GitHub Check: build (3.11)
🔇 Additional comments (7)
transfer_queue/storage/clients/__init__.py (1)
1-15: Package initializer looks good.This file appropriately serves as a scaffolding placeholder for the clients subpackage. The Apache 2.0 license header and explanatory comment are consistent with the PR's modular storage restructuring effort.
transfer_queue/controller.py (1)
69-71: LGTM: controller_id shortened to 8-hex chars.Matches the new identity convention and reduces frame sizes without changing behavior.
transfer_queue/client.py (1)
118-121: Verification confirms repo-wide UUID truncation policy is consistent.The grep results show all identity-related socket creations across the codebase consistently use
uuid4().hex[:8]. No full-length UUID assumptions detected in downstream code.transfer_queue/storage/clients/factory.py (1)
1-13: The original review concern is not valid; no imports reference this file.Verification found no references to
transfer_queue.storage.clients.factoryorTransferQueueStorageClientFactoryanywhere in the codebase. The clients module's__init__.pyexplicitly documents itself as "reserved for future client implementations," making this an intentional placeholder. No dead import paths exist, and there is zero risk of runtime ImportError from this file being empty.Likely an incorrect or invalid review comment.
tests/test_async_simple_storage_manager.py (2)
274-277: Same here — correct patch path for error‑handling test.Accurately targets
create_zmq_socketinmanagers.base. LGTM.
199-203: Patch target migration verified—all instances correct.Verification confirms all
create_zmq_socketpatches now point totransfer_queue.storage.managers.base.create_zmq_socket(lines 200, 275). No remaining patches at old module paths. Changes are consistent and accurate.transfer_queue/storage/managers/simple_backend_manager.py (1)
206-216: No action needed — NonTensorStack is properly available in tensordict 0.10.0NonTensorStack is available at the top level in tensordict 0.10.0 (tensordict.NonTensorStack), and it is the recommended container for sequences of non-tensor objects. The code correctly handles non-tensor fields using this approach with no fallback required.
| socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000)) | ||
|
|
||
| if self.data_status_update_socket in socks: | ||
| response_msg = ZMQMessage.deserialize(self.data_status_update_socket.recv()) | ||
|
|
||
| if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK: | ||
| response_received = True | ||
| logger.debug( |
There was a problem hiding this comment.
Async method performs blocking ZMQ I/O — will stall the event loop.
notify_data_update is async but uses blocking poller.poll(...) and socket.recv(). Under load this can freeze the loop for up to TQ_DATA_UPDATE_RESPONSE_TIMEOUT seconds.
Minimal fix without switching to zmq.asyncio: offload blocking calls to a thread executor.
+import asyncio
@@
- while (
+ while (
not response_received # Only one controller to get response from
and time.time() - start_time < TQ_DATA_UPDATE_RESPONSE_TIMEOUT
):
- socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000))
+ loop = asyncio.get_running_loop()
+ socks = dict(
+ await loop.run_in_executor(None, poller.poll, TQ_STORAGE_POLLER_TIMEOUT * 1000)
+ )
@@
- if self.data_status_update_socket in socks:
- response_msg = ZMQMessage.deserialize(self.data_status_update_socket.recv())
+ if self.data_status_update_socket in socks:
+ raw = await loop.run_in_executor(None, self.data_status_update_socket.recv)
+ response_msg = ZMQMessage.deserialize(raw)Alternative (preferred long‑term): migrate to zmq.asyncio.Context and await socket.recv().
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In transfer_queue/storage/managers/base.py around lines 239 to 246, the async
method performs blocking ZMQ I/O (poller.poll and socket.recv) which can block
the event loop; change the implementation to run those blocking calls on a
thread executor (e.g., asyncio.get_running_loop().run_in_executor or
loop.run_in_executor) so the async function stays non-blocking: call
poller.poll(...) and socket.recv() inside a synchronous helper function or
lambda and await run_in_executor for the result, then deserialize and handle the
response on the event loop; alternatively, for a longer-term refactor, replace
the ZMQ Context/socket with zmq.asyncio and use await socket.recv() instead.
| context = zmq.asyncio.Context() | ||
| address = f"tcp://{server_info.ip}:{server_info.ports.get(socket_name)}" | ||
| identity = f"{self.storage_manager_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() | ||
| sock = create_zmq_socket(context, zmq.DEALER, identity=identity) | ||
|
|
There was a problem hiding this comment.
Validate port before connect; clearer error if missing
Without this, address becomes tcp://:None and connect fails cryptically.
- context = zmq.asyncio.Context()
- address = f"tcp://{server_info.ip}:{server_info.ports.get(socket_name)}"
+ context = zmq.asyncio.Context()
+ port = server_info.ports.get(socket_name)
+ if port is None:
+ raise RuntimeError(
+ f"StorageUnit {server_info.id} missing port '{socket_name}'. "
+ f"Available: {list(server_info.ports.keys())}"
+ )
+ address = f"tcp://{server_info.ip}:{port}"
identity = f"{self.storage_manager_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode()
sock = create_zmq_socket(context, zmq.DEALER, identity=identity)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| context = zmq.asyncio.Context() | |
| address = f"tcp://{server_info.ip}:{server_info.ports.get(socket_name)}" | |
| identity = f"{self.storage_manager_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() | |
| sock = create_zmq_socket(context, zmq.DEALER, identity=identity) | |
| context = zmq.asyncio.Context() | |
| port = server_info.ports.get(socket_name) | |
| if port is None: | |
| raise RuntimeError( | |
| f"StorageUnit {server_info.id} missing port '{socket_name}'. " | |
| f"Available: {list(server_info.ports.keys())}" | |
| ) | |
| address = f"tcp://{server_info.ip}:{port}" | |
| identity = f"{self.storage_manager_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() | |
| sock = create_zmq_socket(context, zmq.DEALER, identity=identity) |
🤖 Prompt for AI Agents
In transfer_queue/storage/managers/simple_backend_manager.py around lines 120 to
124, the code builds the address using server_info.ports.get(socket_name) which
can be None causing a cryptic connect failure; before formatting the address,
fetch the port into a variable and validate it is not None (and optionally is an
int > 0), and if missing raise a clear ValueError or log an explicit error
including storage_manager_id, server_info.id, socket_name and server_info.ip,
then skip creating the socket/connecting for that server (or propagate the
explicit error) so you never call create_zmq_socket/connect with a None port.
| def get_data(self, fields: list[str], local_indexes: list[int]) -> TensorDict[str, list]: | ||
| """ | ||
| Get data from storage unit according to given fields and local_indexes. | ||
|
|
||
| Args: | ||
| fields: Field names used for getting data. | ||
| local_indexes: Local indexes used for getting data. | ||
|
|
||
| Returns: | ||
| TensorDict with field names as keys, corresponding data list as values. | ||
| """ | ||
| result: dict[str, list] = {} | ||
|
|
||
| for field in fields: | ||
| # Validate field name | ||
| if field not in self.field_data: | ||
| raise ValueError( | ||
| f"StorageUnitData get_data operation receive invalid field: {field} beyond {self.field_data.keys()}" | ||
| ) | ||
|
|
||
| if len(local_indexes) == 1: | ||
| # The unsqueeze op make the shape from n to (1, n) | ||
| gathered_item = self.field_data[field][local_indexes[0]] | ||
| if not isinstance(gathered_item, torch.Tensor): | ||
| result[field] = NonTensorStack(gathered_item) | ||
| else: | ||
| result[field] = gathered_item.unsqueeze(0) | ||
| else: | ||
| gathered_items = list(itemgetter(*local_indexes)(self.field_data[field])) | ||
|
|
||
| if gathered_items: | ||
| all_tensors = all(isinstance(x, torch.Tensor) for x in gathered_items) | ||
| if all_tensors: | ||
| result[field] = torch.nested.as_nested_tensor(gathered_items) | ||
| else: | ||
| result[field] = NonTensorStack(*gathered_items) | ||
|
|
||
| return TensorDict(result) | ||
|
|
There was a problem hiding this comment.
🛠️ Refactor suggestion | 🟠 Major
Harden reads: guard against uninitialized slots and set batch_size.
When fetching multiple indexes, if any slot is None, current code tries to stack None values (NonTensorStack/torch.nested) and will fail. Also, returning TensorDict without batch_size can be brittle.
Apply:
- for field in fields:
+ for field in fields:
# Validate field name
if field not in self.field_data:
raise ValueError(
f"StorageUnitData get_data operation receive invalid field: {field} beyond {self.field_data.keys()}"
)
if len(local_indexes) == 1:
# The unsqueeze op make the shape from n to (1, n)
gathered_item = self.field_data[field][local_indexes[0]]
+ if gathered_item is None:
+ raise ValueError(f"Missing data for field '{field}' at index {local_indexes[0]}")
if not isinstance(gathered_item, torch.Tensor):
result[field] = NonTensorStack(gathered_item)
else:
result[field] = gathered_item.unsqueeze(0)
else:
gathered_items = list(itemgetter(*local_indexes)(self.field_data[field]))
- if gathered_items:
+ if gathered_items:
+ if any(x is None for x in gathered_items):
+ missing = [i for i, x in zip(local_indexes, gathered_items) if x is None]
+ raise ValueError(f"Missing data for field '{field}' at indexes {missing}")
all_tensors = all(isinstance(x, torch.Tensor) for x in gathered_items)
if all_tensors:
result[field] = torch.nested.as_nested_tensor(gathered_items)
else:
result[field] = NonTensorStack(*gathered_items)
- return TensorDict(result)
+ # Explicit batch size for stability
+ bs = 0 if not fields or not local_indexes else len(local_indexes)
+ return TensorDict(result, batch_size=bs)🤖 Prompt for AI Agents
In transfer_queue/storage/simple_backend.py around lines 66-104, guard reads
against uninitialized slots and ensure returned TensorDict has batch_size: when
fetching a single index, check if gathered_item is None and raise a clear
ValueError (field + index) or otherwise wrap non-tensor values in NonTensorStack
as a sequence; when fetching multiple indexes, detect any None entries in
gathered_items and raise a ValueError listing the missing local indexes (do not
attempt to stack None), filter/validate that gathered_items is non-empty before
stacking, call NonTensorStack with the list (not via splat), and when
constructing the final TensorDict set its batch_size to len(local_indexes).
There was a problem hiding this comment.
Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!
| def put_data(self, field_data: TensorDict[str, list], local_indexes: list[int]) -> None: | ||
| """ | ||
| Put or update data into storage unit according to given field_data and local_indexes. | ||
|
|
||
| Args: | ||
| field_data: Dict with field names as keys, corresponding data in the field as values. | ||
| local_indexes: Local indexes used for putting data. | ||
| """ | ||
| extracted_data = dict(field_data) | ||
|
|
||
| for f, values in extracted_data.items(): | ||
| if f not in self.field_data: | ||
| self.field_data[f] = [None] * self.storage_size | ||
|
|
||
| for i, idx in enumerate(local_indexes): | ||
| if idx < 0 or idx >= self.storage_size: | ||
| raise ValueError( | ||
| f"StorageUnitData put_data operation receive invalid local_index: {idx} beyond " | ||
| f"storage_size: {self.storage_size}" | ||
| ) | ||
|
|
||
| self.field_data[f][idx] = values[i] | ||
|
|
There was a problem hiding this comment.
Critical: dict(TensorDict) misuse; values likely wrong or exception.
extracted_data = dict(field_data) will not reliably produce {str: list} from a TensorDict and can raise or drop values. Use to_dict() (or .items()) and validate lengths to match local_indexes.
Apply:
- extracted_data = dict(field_data)
+ # Accept TensorDict or plain dict[str, list-like]
+ if isinstance(field_data, TensorDict):
+ extracted_data = field_data.to_dict()
+ elif isinstance(field_data, dict):
+ extracted_data = field_data
+ else:
+ raise TypeError(
+ f"field_data must be a TensorDict or dict[str, list-like], got {type(field_data)}"
+ )
- for f, values in extracted_data.items():
+ for f, values in extracted_data.items():
if f not in self.field_data:
self.field_data[f] = [None] * self.storage_size
- for i, idx in enumerate(local_indexes):
+ # Ensure provided values align with target indexes
+ if len(values) != len(local_indexes):
+ raise ValueError(
+ f"Mismatch: {f} has {len(values)} values for {len(local_indexes)} local_indexes"
+ )
+ for i, idx in enumerate(local_indexes):
if idx < 0 or idx >= self.storage_size:
raise ValueError(
f"StorageUnitData put_data operation receive invalid local_index: {idx} beyond "
f"storage_size: {self.storage_size}"
)
self.field_data[f][idx] = values[i]🤖 Prompt for AI Agents
In transfer_queue/storage/simple_backend.py around lines 105 to 127, replace the
unsafe dict(field_data) usage: call the TensorDict's proper exporter (e.g.,
field_data.to_dict() or iterate field_data.items()) to get reliable str->list
pairs, validate each value is a sequence and that len(value) ==
len(local_indexes) before writing, and raise a clear ValueError if a field's
length mismatches or a value is not indexable; keep existing bounds checks for
local_indexes and then assign values[i] to self.field_data[f][idx].
There was a problem hiding this comment.
Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!
uuid4().hex[:8]Summary by CodeRabbit
Refactor
Chores