Use elaborated cache key and use it for filelock semaphore #1644

Open · wants to merge 3 commits into main
86 changes: 75 additions & 11 deletions src/unitxt/api.py
@@ -1,15 +1,21 @@
import hashlib
import inspect
import json
import os
import random
import time
from datetime import datetime
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union

import filelock
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from datasets.exceptions import DatasetGenerationError
from huggingface_hub import constants as hf_constants

from .artifact import fetch_artifact
from .card import TaskCard
from .dataclass import to_dict
from .dataset_utils import get_dataset_artifact
from .error_utils import UnitxtError
from .inference import (
@@ -134,43 +140,101 @@ def create_dataset(
card = TaskCard(loader=LoadFromDictionary(data=data, data_classification_policy=data_classification_policy), task=task)
return load_dataset(card=card, split=split, **kwargs)

def object_to_str_without_addresses(obj):
"""Generates a string representation of a Python object while removing memory address references.

This function is useful for creating consistent and comparable string representations of objects
that would otherwise include memory addresses (e.g., `<object_name at 0x123abc>`), which can vary
between executions. By stripping the memory address, the function ensures that the representation
is stable and independent of the object's location in memory.

Args:
obj: Any Python object to be converted to a string representation.

Returns:
str: A string representation of the object with memory addresses removed if present.

Example:
```python
class MyClass:
pass

obj = MyClass()
print(str(obj)) # "<__main__.MyClass object at 0x7f8b9d4d6e20>"
    print(object_to_str_without_addresses(obj))  # "<__main__.MyClass object>"
```
"""
obj_str = str(obj)
if " at 0x" in obj_str:
obj_str = obj_str.split(" at 0x")[0] + ">"
return obj_str
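
For a quick illustration of why this matters for caching (not part of the PR), the helper can be exercised on any object whose `str()` embeds an address; here `json.dumps`'s `default` hook stands in for the `to_dict` traversal used below:

```python
import json

def object_to_str_without_addresses(obj):
    obj_str = str(obj)
    if " at 0x" in obj_str:
        obj_str = obj_str.split(" at 0x")[0] + ">"
    return obj_str

fn = lambda x: x  # any object whose str() embeds a memory address
print(str(fn))                              # "<function <lambda> at 0x7f...>", differs per run
print(object_to_str_without_addresses(fn))  # "<function <lambda>>", stable across runs

# With addresses stripped, the serialized form (and hence any hash of it)
# is identical across processes, making it usable as a cache key.
signature = json.dumps({"loader": fn}, default=object_to_str_without_addresses, sort_keys=True)
print(signature)  # {"loader": "<function <lambda>>"}
```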

def _source_to_dataset(
source: SourceOperator,
split=None,
use_cache=False,
streaming=False,
lock_timeout=60, # Timeout in seconds for acquiring the lock
):
from .dataset import Dataset as UnitxtDataset

# Generate a unique signature for the source
source_signature = json.dumps(to_dict(source, object_to_str_without_addresses), sort_keys=True)
config_name = "recipe-" + short_hex_hash(source_signature)
hf_cache_home = hf_constants.HF_HOME
lock_dir = os.path.join(hf_cache_home, "locks")
os.makedirs(lock_dir, exist_ok=True)

# Create a lock file path based on the dataset configuration
lock_file = os.path.join(lock_dir, f"unitxt_{config_name}.lock")

    # Retry parameters for acquiring the file lock
max_attempts = 5
base_wait = 5 # seconds

stream = source()

try:
ds_builder = UnitxtDataset(
dataset_name="unitxt",
config_name="recipe-" + short_hex_hash(repr(source)),
config_name=config_name,
version=constants.version,
)

if split is not None:
stream = {split: stream[split]}

ds_builder._generators = stream

for attempt in range(max_attempts):
            # Create a file lock honoring the configured timeout
            lock = filelock.FileLock(lock_file, timeout=lock_timeout)

try:
with lock:
ds_builder.download_and_prepare(
verification_mode="no_checks",
download_mode=None if use_cache else "force_redownload",
)

# If we reach here, the lock was successfully acquired and released
if streaming:
return ds_builder.as_streaming_dataset(split=split)
return ds_builder.as_dataset(
split=split, run_post_process=False, verification_mode="no_checks"
)

except filelock.Timeout:
if attempt < max_attempts - 1: # Not the last attempt
wait_time = base_wait * (2 ** attempt) + random.uniform(0, 1)
time.sleep(wait_time)
else:
raise TimeoutError(f"Could not acquire lock for {config_name} after {max_attempts} attempts")

except DatasetGenerationError as e:
raise e.__cause__
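
The acquire-and-retry pattern above can be read in isolation; here is a minimal self-contained sketch of the same idea, with illustrative names and timings rather than the PR's exact values:

```python
import random
import time

import filelock  # pip install filelock

def run_under_lock(lock_file, work, max_attempts=5, base_wait=5.0, lock_timeout=60.0):
    """Run work() while holding an exclusive file lock, retrying with
    exponential backoff plus jitter when acquisition times out."""
    for attempt in range(max_attempts):
        lock = filelock.FileLock(lock_file, timeout=lock_timeout)
        try:
            with lock:
                return work()  # only one process prepares the shared cache entry
        except filelock.Timeout:
            if attempt < max_attempts - 1:
                # Waits 5s, 10s, 20s, ... plus up to 1s of jitter, so that
                # contending processes do not retry in lockstep.
                time.sleep(base_wait * (2 ** attempt) + random.uniform(0, 1))
    raise TimeoutError(f"Could not acquire {lock_file} after {max_attempts} attempts")
```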


def load_dataset(
dataset_query: Optional[str] = None,
split: Optional[str] = None,
45 changes: 45 additions & 0 deletions src/unitxt/dataclass.py
@@ -296,6 +296,51 @@ def _asdict_inner(obj):

return copy.deepcopy(obj)

def to_dict(obj, func=copy.deepcopy, _visited=None):
"""Recursively converts an object into a dictionary representation while avoiding infinite recursion due to circular references.

Args:
obj: Any Python object to be converted into a dictionary-like structure.
func (Callable, optional): A function applied to non-iterable objects. Defaults to `copy.deepcopy`.
_visited (set, optional): A set of object IDs used to track visited objects and prevent infinite recursion.

Returns:
dict: A dictionary representation of the input object, with supported collections and dataclasses
recursively processed.

Notes:
- Supports dataclasses, named tuples, lists, tuples, and dictionaries.
- Circular references are detected using object IDs and replaced by `func(obj)`.
- Named tuples retain their original type instead of being converted to dictionaries.
"""
# Initialize visited set on first call
if _visited is None:
_visited = set()

# Get object ID to track visited objects
obj_id = id(obj)

# If we've seen this object before, return a placeholder to avoid infinite recursion
if obj_id in _visited:
return func(obj)

# For mutable objects, add to visited set before recursing
if isinstance(obj, (dict, list)) or is_dataclass(obj) or (isinstance(obj, tuple) and hasattr(obj, "_fields")):
_visited.add(obj_id)

if is_dataclass(obj):
return {field.name: to_dict(getattr(obj, field.name), func, _visited) for field in fields(obj)}

if isinstance(obj, tuple) and hasattr(obj, "_fields"): # named tuple
return type(obj)(*[to_dict(v, func, _visited) for v in obj])

if isinstance(obj, (list, tuple)):
return type(obj)([to_dict(v, func, _visited) for v in obj])

if isinstance(obj, dict):
return type(obj)({to_dict(k, func, _visited): to_dict(v, func, _visited) for k, v in obj.items()})

return func(obj)
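
A usage sketch (not from the PR) showing the cycle fallback, assuming the `to_dict` above is in scope; the `Node` dataclass is illustrative:

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Node:  # illustrative only
    name: str
    parent: Optional["Node"] = None
    children: List["Node"] = field(default_factory=list)

root = Node("root")
root.children.append(Node("child", parent=root))  # root -> child -> root: a cycle

# On the second visit to `root`, to_dict returns func(root) instead of
# recursing, so the traversal terminates with a string placeholder.
result = to_dict(root, func=str)
print(result["children"][0]["parent"])  # something like "Node(name='root', ...)"
```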

class DataclassMeta(ABCMeta):
"""Metaclass for Dataclass.