Skip to content

API Reference

This page is auto-generated from the public API exported by slurptuna.

Public API

slurptuna

Functions

execution_mode(mode)

Validate and normalize an execution mode.

Parameters:

Name Type Description Default
mode str

One of the supported strings: "single" or "distributed".

required

Returns:

Type Description
ExecutionMode

The corresponding ExecutionMode enum value.

Raises:

Type Description
ValueError

If mode is not one of the supported values.

Source code in src/slurptuna/api.py
def execution_mode(mode: str) -> ExecutionMode:
    """Validate and normalize an execution mode.

    The value is converted with ``str()`` and lower-cased before matching, so
    spellings such as ``"SINGLE"`` or ``"Distributed"`` are accepted too.
    Surrounding whitespace is not stripped.

    Args:
        mode: One of the supported strings: `"single"` or `"distributed"`.

    Returns:
        The corresponding `ExecutionMode` enum value.

    Raises:
        ValueError: If `mode` is not one of the supported values.
    """
    key = str(mode).lower()
    if key not in ("single", "distributed"):
        raise ValueError("mode must be one of: single, distributed")
    return ExecutionMode.SINGLE if key == "single" else ExecutionMode.DISTRIBUTED

loss(*, name=None, description=None, parameter_space=None, seed_start=0)

Decorator to define a loss function for hyperparameter optimization.

The decorated function should accept params (a dict of hyperparameters), seed (an int), and optionally a context dict. It should return either a scalar loss value (float) or a dict of scalar losses, which are averaged across entries for shared optimization.

Parameters:

Name Type Description Default
name str | None

Unique identifier for this loss. Required.

None
description str | None

Human-readable description of what this loss represents. Required.

None
parameter_space dict[str, ParamSpec] | None

Dict mapping parameter names to search specs. Required. Specs can be tuples (min, max) for continuous ranges or SearchParam objects for more control over type (int/categorical) and bounds. Example: {"alpha": (0.0, 1.0), "lr": search_param(range=(1e-4, 1e-2), dtype="log")}

None
seed_start int

Starting seed number for this loss. Default 0.

0

Returns:

Type Description

The decorated function as a LossDefinition ready for optimize_run() or optimize_entries().

Example

@loss(
    name="my_model_loss",
    description="Fit model parameters to training data",
    parameter_space={"learning_rate": (1e-4, 1e-2), "batch_size": search_param(range=(8, 256), dtype="int")}
)
def my_model_loss(params, seed, context):
    # ... model training logic ...
    return mean_squared_error

Source code in src/slurptuna/registry.py
def loss(
    *,
    name: str | None = None,
    description: str | None = None,
    parameter_space: dict[str, ParamSpec] | None = None,
    seed_start: int = 0,
):
    """Decorator to define a loss function for hyperparameter optimization.

    The decorated function should accept `params` (dict of hyperparameters), `seed` (int),
    and optionally a `context` dict. It should return either:
    - A scalar loss value (float)
    - A dict of scalar losses (will be averaged across entries for shared optimization)

    The metadata is checked as soon as `loss(...)` is called, before any function is
    decorated.

    Args:
        name: Unique identifier for this loss. Required.
        description: Human-readable description of what this loss represents. Required.
        parameter_space: Dict mapping parameter names to search specs. Required.
            Specs can be tuples (min, max) for continuous ranges or SearchParam objects
            for more control over type (int/categorical) and bounds. The mapping is
            normalized with `normalize_parameter_space` when the function is decorated.
            Example: {"alpha": (0.0, 1.0), "lr": search_param(range=(1e-4, 1e-2), dtype="log")}
        seed_start: Starting seed number for this loss. Default 0.

    Returns:
        A decorator. Applied to a seed-loss function, it builds a LossDefinition (also
        recording the function's source file and call mode), registers it with
        `register_loss`, and returns what `register_loss` returns. That object is what
        optimize_run() and optimize_entries() accept.

    Raises:
        ValueError: If `name`, `description` or `parameter_space` is missing or empty.

    Example:
        @loss(
            name="my_model_loss",
            description="Fit model parameters to training data",
            parameter_space={"learning_rate": (1e-4, 1e-2), "batch_size": search_param(range=(8, 256), dtype="int")}
        )
        def my_model_loss(params, seed, context):
            # ... model training logic ...
            return mean_squared_error
    """
    required_metadata = (
        ("name", name),
        ("description", description),
        ("parameter_space", parameter_space),
    )
    missing_fields = [field for field, value in required_metadata if not value]
    if missing_fields:
        raise ValueError(
            "@loss requires non-empty metadata fields: "
            f"{', '.join(missing_fields)}. Example: "
            "@loss(name='my_loss', description='...', parameter_space={'x': (0.0, 1.0)})"
        )

    def _decorate(fn: SeedLossFn) -> LossDefinition:
        # Callables without a backing file (builtins, some dynamically created
        # objects) make inspect.getfile raise; record no source file for them.
        try:
            source_file = inspect.getfile(fn)
        except (TypeError, OSError):
            source_file = None
        seed_loss_call_mode = _resolve_seed_loss_call_mode(fn)
        definition = LossDefinition(
            name=name,
            description=description,
            parameter_space=normalize_parameter_space(parameter_space),
            seed_loss_fn=fn,
            seed_loss_call_mode=seed_loss_call_mode,
            seed_start=seed_start,
            source_file=source_file,
        )
        return register_loss(definition)

    return _decorate

optimize_entries(loss, *, entry_ids, n_trials=10, seeds=None, n_seeds=None, chunk_size=None, num_chunks=None, random_seed=123, direction='minimize', mode=ExecutionMode.SINGLE, run_root='runs', run_name_prefix=None, loss_module=None, slurm_poll_seconds=15, slurm_timeout_minutes=120, cpus_per_task=1, max_concurrent_trials=1, array_parallelism_limit=None, worker_parallelism=1, max_concurrent_entries=None, worker_time_limit=timedelta(hours=2), slurm_qos='short', trial_retry_attempts=1, fail_on_chunk_error=True, use_processes=False, mem_per_cpu='2G', forward_sys_argv_to_workers=True)

Optimize independent fits for each entry, returning per-entry best parameters.

Use this for participant-wise, condition-wise, or other per-entry fitting where each entry gets its own separate optimization study. The loss function receives the current entry_id in its context dict, allowing entry-specific behavior.

Parameters:

Name Type Description Default
loss LossDefinition

LossDefinition created with the @loss decorator.

required
entry_ids Iterable[str]

Iterable of entry identifiers (e.g., participant IDs, condition names). Each entry gets its own optimization study.

required
n_trials int

Number of optimization trials per entry. Default 10.

10
seeds Iterable[int] | None

Explicit iterable of seed IDs (takes priority over n_seeds).

None
n_seeds int | None

Number of contiguous seeds per entry. Default derived from chunk_size/num_chunks.

None
chunk_size int | None

Seeds per distributed task (for mode=DISTRIBUTED). If not provided, 100 is passed to the seed-layout resolver.

None
num_chunks int | None

Number of chunks per trial (for mode=DISTRIBUTED). Default auto from n_seeds/chunk_size.

None
random_seed int

Base seed for TPESampler; each entry gets random_seed + entry_index. Default 123.

123
direction str

Optimization direction: "minimize" or "maximize". Default "minimize".

'minimize'
mode ExecutionMode

ExecutionMode.SINGLE (in-process) or ExecutionMode.DISTRIBUTED (Slurm arrays). Default SINGLE.

SINGLE
run_root str | Path

Root directory for all output. Default "runs".

'runs'
run_name_prefix str | None

Prefix for the parent directory containing all entry runs. Auto-generated if not provided.

None
loss_module str | None

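Module path that DISTRIBUTED workers use to import the loss. Defaults to the loss function's module. If that is `__main__` or unavailable, the file in which the loss was defined is used instead. It is required only when neither can be determined.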

None
slurm_poll_seconds int

Polling interval for Slurm job status. Default 15.

15
slurm_timeout_minutes int

Maximum wait time for distributed jobs. Default 120.

120
cpus_per_task int

CPUs per Slurm task (DISTRIBUTED only). Default 1.

1
max_concurrent_trials int

Number of trials per entry to run in parallel. Default 1.

1
array_parallelism_limit int | None

Max concurrent Slurm array jobs across all entries (DISTRIBUTED). Default unlimited.

None
worker_parallelism int

Number of seeds in parallel per worker. Default 1.

1
max_concurrent_entries int | None

Number of entries to optimize in parallel. Default all entries.

None
worker_time_limit timedelta

Max wall time per distributed task. Default 2 hours.

timedelta(hours=2)
slurm_qos str | None

Optional Slurm QoS name passed to sbatch --qos. Default "short". Set to None to omit QoS entirely.

'short'
trial_retry_attempts int

Number of times a failed distributed trial is resubmitted. Default 1 (one retry, so up to two attempts per trial). Only used in DISTRIBUTED mode.

1
fail_on_chunk_error bool

Whether to fail immediately if any chunk fails in distributed mode. Default True.

True
use_processes bool

Use ProcessPoolExecutor instead of ThreadPoolExecutor for seed parallelism. Default False (threads). See optimize_run() for full details on when to prefer processes over threads.

False
mem_per_cpu str

Memory per CPU for Slurm chunk tasks (DISTRIBUTED only). Default "2G".

'2G'
forward_sys_argv_to_workers bool

Forward the launcher's sys.argv into distributed worker module import context. Default True.

True

Returns:

Type Description
MultiOptimizeResult

MultiOptimizeResult containing:

  • best_params_by_entry: Dict mapping entry_id -> best parameters
  • results_by_entry: Dict mapping entry_id -> OptimizeResult for each entry
  • entries: List of all entry IDs
  • n_entries: Total number of entries
  • mode: ExecutionMode used
Source code in src/slurptuna/api.py
def optimize_entries(
    loss: LossDefinition,
    *,
    entry_ids: Iterable[str],
    n_trials: int = 10,
    seeds: Iterable[int] | None = None,
    n_seeds: int | None = None,
    chunk_size: int | None = None,
    num_chunks: int | None = None,
    random_seed: int = 123,
    direction: str = "minimize",
    mode: ExecutionMode = ExecutionMode.SINGLE,
    run_root: str | Path = "runs",
    run_name_prefix: str | None = None,
    loss_module: str | None = None,
    slurm_poll_seconds: int = 15,
    slurm_timeout_minutes: int = 120,
    cpus_per_task: int = 1,
    max_concurrent_trials: int = 1,
    array_parallelism_limit: int | None = None,
    worker_parallelism: int = 1,
    max_concurrent_entries: int | None = None,
    worker_time_limit: timedelta = timedelta(hours=2),
    slurm_qos: str | None = "short",
    trial_retry_attempts: int = 1,
    fail_on_chunk_error: bool = True,
    use_processes: bool = False,
    mem_per_cpu: str = "2G",
    forward_sys_argv_to_workers: bool = True,
) -> MultiOptimizeResult:
    """Optimize independent fits for each entry, returning per-entry best parameters.

    Use this for participant-wise, condition-wise, or other per-entry fitting where each
    entry gets its own separate optimization study. The loss function receives the current
    entry_id in its context dict, allowing entry-specific behavior.

    Every entry is run through optimize_run() in its own sub-directory of
    ``run_root / run_name_prefix``. Sub-directories are named after the sanitized entry
    label. Entries run concurrently on a thread pool. ``summary.json`` in the parent
    directory is rewritten as each entry finishes. All per-entry mappings (in the summary
    and in the returned result) follow the order of ``entry_ids``, not completion order.
    If an entry raises, entries that have not started yet are cancelled, entries already
    running are allowed to finish, and the exception is re-raised.

    Args:
        loss: LossDefinition created with the @loss decorator.
        entry_ids: Iterable of entry identifiers (e.g., participant IDs, condition names).
            Each entry gets its own optimization study. Values are converted with str().
        n_trials: Number of optimization trials per entry. Default 10.
        seeds: Explicit iterable of seed IDs (takes priority over n_seeds). It is
            materialized once, so every entry sees the same seeds even when a one-shot
            iterator such as a generator is passed.
        n_seeds: Number of contiguous seeds per entry. Default derived from chunk_size/num_chunks.
        chunk_size: Seeds per distributed task (for mode=DISTRIBUTED). If None (or 0),
            optimize_run() passes 100 to its seed-layout resolver.
        num_chunks: Number of chunks per trial (for mode=DISTRIBUTED). If None (or 0),
            optimize_run() passes 10 to its seed-layout resolver.
        random_seed: Base seed for TPESampler; each entry gets random_seed + entry_index. Default 123.
        direction: Optimization direction: "minimize" or "maximize". Default "minimize".
        mode: ExecutionMode.SINGLE (in-process) or ExecutionMode.DISTRIBUTED (Slurm arrays). Default SINGLE.
        run_root: Root directory for all output. Default "runs".
        run_name_prefix: Name of the parent directory containing all entry runs.
            Auto-generated (versioned from ``<loss.name>_entries``) if None.
        loss_module: Module path that DISTRIBUTED workers use to import the loss. Defaults to
            the loss function's module. If that is ``__main__`` or unavailable, the file the
            loss was defined in is used instead.
        slurm_poll_seconds: Polling interval for Slurm job status. Default 15.
        slurm_timeout_minutes: Maximum wait time for distributed jobs. Default 120.
        cpus_per_task: CPUs per Slurm task (DISTRIBUTED only). Default 1.
        max_concurrent_trials: Number of trials per entry to run in parallel. Default 1.
        array_parallelism_limit: Max concurrent Slurm array jobs (DISTRIBUTED). Forwarded
            unchanged to each entry's optimize_run(). Default unlimited.
        worker_parallelism: Number of seeds in parallel per worker. Default 1.
        max_concurrent_entries: Number of entries to optimize in parallel (threads in this
            process). Must be > 0 when given. Default all entries.
        worker_time_limit: Max wall time per distributed task. Default 2 hours.
        slurm_qos: Optional Slurm QoS name passed to `sbatch --qos`. Default "short".
            Set to None to omit QoS entirely.
        trial_retry_attempts: Number of times a failed distributed trial is resubmitted.
            Each trial gets up to ``trial_retry_attempts + 1`` attempts, so the default
            of 1 allows one retry. Only used in DISTRIBUTED mode.
        fail_on_chunk_error: Whether to fail immediately if any chunk fails in distributed mode. Default True.
        use_processes: Use ProcessPoolExecutor instead of ThreadPoolExecutor for seed parallelism.
            Default False (threads). See optimize_run() for full details on when to prefer processes
            over threads.
        mem_per_cpu: Memory per CPU for Slurm chunk tasks (DISTRIBUTED only). Default "2G".
        forward_sys_argv_to_workers: Forward the launcher's `sys.argv` into distributed worker
            module import context. Default True.

    Returns:
        MultiOptimizeResult containing:
        - best_params_by_entry: Dict mapping entry_id -> best parameters
        - best_values_by_entry: Dict mapping entry_id -> best objective value
        - results_by_entry: Dict mapping entry_id -> OptimizeResult for each entry
        - entries: List of all entry IDs
        - n_entries: Total number of entries
        - mode: ExecutionMode used
        - run_dir: Path of the parent directory

    Raises:
        ValueError: If entry_ids is empty, two entry_ids sanitize to the same run label,
            or max_concurrent_entries is <= 0. Any exception raised by an entry's
            optimize_run() is propagated.
    """

    entry_list = [str(x) for x in entry_ids]
    if not entry_list:
        raise ValueError("entry_ids must not be empty")

    # Each entry's run directory is named after its sanitized label, so two IDs
    # that sanitize identically would write into the same directory.
    sanitized = [_sanitize_entry_label(entry) for entry in entry_list]
    if len(set(sanitized)) != len(sanitized):
        raise ValueError("entry_ids collapse to duplicate run labels after sanitization")
    if max_concurrent_entries is not None and max_concurrent_entries <= 0:
        raise ValueError("max_concurrent_entries must be > 0 when provided")

    resolved_mode = _coerce_mode(mode)

    # Materialize once: every entry (possibly on a different thread) must see the
    # same seeds, which a shared one-shot iterator would not provide.
    shared_seeds = None if seeds is None else list(seeds)

    run_root_path = Path(run_root)
    if run_name_prefix is not None:
        parent_name = run_name_prefix
    else:
        parent_name = _next_versioned_run_name(run_root_path, f"{loss.name}_entries")
    parent_dir = run_root_path / parent_name
    parent_dir.mkdir(parents=True, exist_ok=True)

    def _run_one(idx: int, entry: str) -> tuple[str, OptimizeResult]:
        # Run one entry's study; each entry gets a distinct sampler seed.
        result = optimize_run(
            loss,
            n_trials=n_trials,
            seeds=shared_seeds,
            n_seeds=n_seeds,
            chunk_size=chunk_size,
            num_chunks=num_chunks,
            random_seed=random_seed + idx,
            entry_id=entry,
            direction=direction,
            mode=resolved_mode,
            run_root=parent_dir,
            run_name=sanitized[idx],
            loss_module=loss_module,
            slurm_poll_seconds=slurm_poll_seconds,
            slurm_timeout_minutes=slurm_timeout_minutes,
            cpus_per_task=cpus_per_task,
            max_concurrent_trials=max_concurrent_trials,
            array_parallelism_limit=array_parallelism_limit,
            worker_parallelism=worker_parallelism,
            worker_time_limit=worker_time_limit,
            slurm_qos=slurm_qos,
            trial_retry_attempts=trial_retry_attempts,
            fail_on_chunk_error=fail_on_chunk_error,
            use_processes=use_processes,
            mem_per_cpu=mem_per_cpu,
            forward_sys_argv_to_workers=forward_sys_argv_to_workers,
        )
        return entry, result

    results_by_entry: dict[str, OptimizeResult] = {}

    def _best_by_entry() -> tuple[dict[str, dict], dict[str, float]]:
        # Iterate entry_list (input order) rather than results_by_entry (completion
        # order) so summaries and returned mappings are deterministic.
        finished = [entry for entry in entry_list if entry in results_by_entry]
        best_params = {entry: dict(results_by_entry[entry].best_params) for entry in finished}
        best_values = {entry: results_by_entry[entry].best_value for entry in finished}
        return best_params, best_values

    def _write_summary() -> None:
        # Rewrite summary.json with the entries finished so far. A temp file plus
        # rename means readers polling the directory never see a truncated file.
        best_params_by_entry, best_values_by_entry = _best_by_entry()
        summary_path = parent_dir / "summary.json"
        tmp_path = summary_path.with_name(summary_path.name + ".tmp")
        tmp_path.write_text(
            json.dumps(
                {
                    "entries": entry_list,
                    "best_params_by_entry": best_params_by_entry,
                    "best_values_by_entry": best_values_by_entry,
                },
                indent=2,
            ),
            encoding="utf-8",
        )
        tmp_path.replace(summary_path)

    workers = min(len(entry_list), max_concurrent_entries or len(entry_list))
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(_run_one, idx, entry) for idx, entry in enumerate(entry_list)]
        try:
            for fut in as_completed(futures):
                entry, result = fut.result()
                results_by_entry[entry] = result
                _write_summary()
        except BaseException:
            # Leaving the `with` block waits for every outstanding future. Cancel
            # the ones that have not started so a failure (or Ctrl-C) surfaces
            # without first running every queued entry.
            for pending in futures:
                pending.cancel()
            raise

    best_params_by_entry, best_values_by_entry = _best_by_entry()

    return MultiOptimizeResult(
        loss_name=loss.name,
        n_entries=len(entry_list),
        entries=entry_list,
        results_by_entry={entry: results_by_entry[entry] for entry in entry_list},
        best_values_by_entry=best_values_by_entry,
        best_params_by_entry=best_params_by_entry,
        mode=resolved_mode,
        run_dir=str(parent_dir),
    )

optimize_run(loss, *, n_trials=10, seeds=None, n_seeds=None, chunk_size=None, num_chunks=None, random_seed=123, entry_id=None, direction='minimize', mode=ExecutionMode.SINGLE, run_root='runs', run_name=None, loss_module=None, slurm_poll_seconds=15, slurm_timeout_minutes=120, cpus_per_task=1, mem_per_cpu='2G', max_concurrent_trials=1, array_parallelism_limit=None, worker_parallelism=1, worker_time_limit=timedelta(hours=2), slurm_qos='short', trial_retry_attempts=1, fail_on_chunk_error=True, use_processes=False, forward_sys_argv_to_workers=True)

Optimize hyperparameters for a single shared fit.

Runs Bayesian optimization on a loss function. For averaging across multiple entries (participants, conditions, etc.), return a dict from your loss function. For per-entry optimization, use optimize_entries() instead.

Parameters:

Name Type Description Default
loss LossDefinition

LossDefinition created with the @loss decorator.

required
n_trials int

Number of optimization trials to run. Default 10.

10
seeds Iterable[int] | None

Explicit iterable of seed integer IDs (takes priority over n_seeds).

None
n_seeds int | None

Number of contiguous seeds starting from loss.seed_start. Default derived from chunk_size/num_chunks.

None
chunk_size int | None

Seeds per distributed task (for mode=DISTRIBUTED). If not provided, 100 is passed to the seed-layout resolver.

None
num_chunks int | None

Number of chunks per trial (for mode=DISTRIBUTED). Default auto from n_seeds/chunk_size.

None
random_seed int

Seed for TPESampler (Optuna). Default 123.

123
entry_id str | None

Optional entry/participant ID passed to loss context. Only used in optimize_entries().

None
direction str

Optimization direction: "minimize" or "maximize". Default "minimize".

'minimize'
mode ExecutionMode

ExecutionMode.SINGLE (in-process) or ExecutionMode.DISTRIBUTED (Slurm arrays). Default SINGLE.

SINGLE
run_root str | Path

Root directory for output runs. Default "runs".

'runs'
run_name str | None

Name of this run directory. Auto-generated if not provided.

None
loss_module str | None

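Module path that DISTRIBUTED workers use to import the loss. Defaults to the loss function's module. If that is `__main__` or unavailable, the file in which the loss was defined is used instead. It is required only when neither can be determined.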

None
slurm_poll_seconds int

Polling interval for Slurm job status. Default 15.

15
slurm_timeout_minutes int

Maximum wait time for distributed job. Default 120.

120
cpus_per_task int

CPUs per Slurm task (DISTRIBUTED only). Default 1.

1
max_concurrent_trials int

Number of trials to run in parallel. Default 1.

1
array_parallelism_limit int | None

Max concurrent Slurm array jobs (DISTRIBUTED). Default unlimited.

None
worker_parallelism int

Number of seeds in parallel per worker. Default 1.

1
worker_time_limit timedelta

Max wall time per distributed task. Default 2 hours.

timedelta(hours=2)
slurm_qos str | None

Optional Slurm QoS name passed to sbatch --qos. Default "short". Set to None to omit QoS entirely.

'short'
trial_retry_attempts int

Number of times a failed distributed trial is resubmitted. Default 1 (one retry, so up to two attempts per trial). Only used in DISTRIBUTED mode.

1
fail_on_chunk_error bool

Whether to fail immediately if any chunk fails in distributed mode. Default True.

True
use_processes bool

Use ProcessPoolExecutor instead of ThreadPoolExecutor for seed parallelism. Default False (threads). Processes bypass Python's GIL, giving true parallelism even for pure-Python loss functions. The trade-off is higher overhead per seed evaluation due to inter-process pickling of the loss function, params, and results — so this is only worth enabling when individual seed evaluations are slow enough that the pickling cost is negligible (roughly >10 ms per seed). For numpy/scipy-heavy losses, threads are usually sufficient because numpy already releases the GIL. For DISTRIBUTED mode, each Slurm task is already a separate process, so this controls parallelism within each task.

False
forward_sys_argv_to_workers bool

Forward the launcher's sys.argv into distributed worker module import context so loss modules that read argv at import time behave consistently. Default True.

True

Returns:

Type Description
OptimizeResult

OptimizeResult with best_value, best_params, study metadata, and run_dir path.

Source code in src/slurptuna/api.py
def optimize_run(
    loss: LossDefinition,
    *,
    n_trials: int = 10,
    seeds: Iterable[int] | None = None,
    n_seeds: int | None = None,
    chunk_size: int | None = None,
    num_chunks: int | None = None,
    random_seed: int = 123,
    entry_id: str | None = None,
    direction: str = "minimize",
    mode: ExecutionMode = ExecutionMode.SINGLE,
    run_root: str | Path = "runs",
    run_name: str | None = None,
    loss_module: str | None = None,
    slurm_poll_seconds: int = 15,
    slurm_timeout_minutes: int = 120,
    cpus_per_task: int = 1,
    mem_per_cpu: str = "2G",
    max_concurrent_trials: int = 1,
    array_parallelism_limit: int | None = None,
    worker_parallelism: int = 1,
    worker_time_limit: timedelta = timedelta(hours=2),
    slurm_qos: str | None = "short",
    trial_retry_attempts: int = 1,
    fail_on_chunk_error: bool = True,
    use_processes: bool = False,
    forward_sys_argv_to_workers: bool = True,
) -> OptimizeResult:
    """Optimize hyperparameters for a single shared fit.

    Runs Bayesian optimization (an Optuna study driven by ``TPESampler``) on a loss
    function. For averaging across multiple entries (participants, conditions, etc.),
    return a dict from your loss function. For per-entry optimization, use
    optimize_entries() instead.

    Every argument is validated before anything is written to disk, so an invalid
    configuration does not leave a half-initialized run directory behind. The run
    directory ``run_root / run_name`` receives:

    - ``meta.json``: written at start and rewritten with timing and the best result
      at the end.
    - ``summary.json``: rewritten after each finished trial once at least one trial
      has completed.
    - ``optuna.db``: the Optuna SQLite database.

    If that database already holds a study named ``loss.name``, the study is resumed.

    Args:
        loss: LossDefinition created with the @loss decorator. It is registered with
            ``overwrite=True`` before the run starts.
        n_trials: Number of optimization trials to run. Default 10.
        seeds: Explicit iterable of seed integer IDs (takes priority over n_seeds).
            DISTRIBUTED workers receive only the first seed plus the chunk layout, so
            in that mode the seeds must form a contiguous range of
            ``num_chunks * chunk_size`` values.
        n_seeds: Number of contiguous seeds starting from loss.seed_start. Default derived from chunk_size/num_chunks.
        chunk_size: Seeds per distributed task (for mode=DISTRIBUTED). If None (or 0),
            100 is passed to the seed-layout resolver.
        num_chunks: Number of chunks per trial (for mode=DISTRIBUTED). If None (or 0),
            10 is passed to the seed-layout resolver.
        random_seed: Seed for TPESampler (Optuna). Default 123.
        entry_id: Optional entry/participant ID passed to loss context. optimize_entries()
            sets it for each entry.
        direction: Optimization direction: "minimize" or "maximize". Default "minimize".
        mode: ExecutionMode.SINGLE (in-process) or ExecutionMode.DISTRIBUTED (Slurm arrays). Default SINGLE.
        run_root: Root directory for output runs. Default "runs".
        run_name: Name of this run directory. Auto-generated (versioned from the loss
            name) if None or empty.
        loss_module: Module path that DISTRIBUTED workers use to import the loss. Defaults
            to the loss function's ``__module__``. If that is empty or ``"__main__"``,
            the file the loss was defined in is used instead.
        slurm_poll_seconds: Polling interval for Slurm job status. Default 15.
        slurm_timeout_minutes: Maximum wait time for distributed job. Default 120.
        cpus_per_task: CPUs per Slurm task (DISTRIBUTED only). Default 1.
        mem_per_cpu: Memory per CPU for Slurm chunk tasks (DISTRIBUTED only). Default "2G".
        max_concurrent_trials: Number of trials to run in parallel (Optuna ``n_jobs``). Default 1.
        array_parallelism_limit: Max concurrent Slurm array jobs (DISTRIBUTED). Default unlimited.
        worker_parallelism: Number of seeds in parallel per worker. Default 1.
        worker_time_limit: Max wall time per distributed task. Default 2 hours.
        slurm_qos: Optional Slurm QoS name passed to `sbatch --qos`. Default "short".
            Set to None to omit QoS entirely.
        trial_retry_attempts: Number of times a DISTRIBUTED trial is resubmitted after it
            fails with ChunkExecutionError or TimeoutError while waiting for its summary.
            Each trial gets up to ``trial_retry_attempts + 1`` attempts, so the default
            of 1 allows one retry. Must be >= 0. Unused in SINGLE mode.
        fail_on_chunk_error: Whether to fail immediately if any chunk fails in distributed
            mode (DISTRIBUTED only). Default True.
        use_processes: Use ProcessPoolExecutor instead of ThreadPoolExecutor for seed parallelism.
            Default False (threads). Processes bypass Python's GIL, giving true parallelism even for
            pure-Python loss functions. The trade-off is higher overhead per seed evaluation due to
            inter-process pickling of the loss function, params, and results — so this is only worth
            enabling when individual seed evaluations are slow enough that the pickling cost is
            negligible (roughly >10 ms per seed). For numpy/scipy-heavy losses, threads are usually
            sufficient because numpy already releases the GIL. For DISTRIBUTED mode, each Slurm task
            is already a separate process, so this controls parallelism within each task. In SINGLE
            mode it requires max_concurrent_trials=1.
        forward_sys_argv_to_workers: Forward the launcher's `sys.argv` into distributed worker
            module import context so loss modules that read argv at import time behave consistently.
            Default True.

    Returns:
        OptimizeResult with best_value, best_params, study metadata, and run_dir path.

    Raises:
        ValueError: If an argument is invalid. This covers:
            - non-positive max_concurrent_trials, worker_parallelism,
              array_parallelism_limit or worker_time_limit;
            - a blank slurm_qos or a negative trial_retry_attempts;
            - use_processes with max_concurrent_trials > 1 in SINGLE mode;
            - an unresolvable loss module or seed layout in DISTRIBUTED mode.
            Optuna also raises ValueError at the end if no trial completed.
        ChunkExecutionError: DISTRIBUTED only; the last attempt of a trial failed with a
            chunk error.
        TimeoutError: DISTRIBUTED only; the last attempt of a trial timed out.
    """
    import threading  # only needed for the summary-file lock below

    register_loss(loss, overwrite=True)
    resolved_mode = _coerce_mode(mode)

    # ---- Validation: everything below runs before the filesystem is touched. ----
    if max_concurrent_trials <= 0:
        raise ValueError("max_concurrent_trials must be > 0")
    if worker_parallelism <= 0:
        raise ValueError("worker_parallelism must be > 0")
    if array_parallelism_limit is not None and array_parallelism_limit <= 0:
        raise ValueError("array_parallelism_limit must be > 0 when provided")
    if worker_time_limit.total_seconds() <= 0:
        raise ValueError("worker_time_limit must be > 0 seconds")
    if slurm_qos is not None and not str(slurm_qos).strip():
        raise ValueError("slurm_qos must be a non-empty string or None")
    if trial_retry_attempts < 0:
        # A negative count would make the distributed retry loop run zero times and
        # fail every trial with a misleading TimeoutError.
        raise ValueError("trial_retry_attempts must be >= 0")

    active_seeds, resolved_chunk_size, resolved_num_chunks = _resolve_seed_layout(
        loss=loss,
        seeds=seeds,
        n_seeds=n_seeds,
        chunk_size=chunk_size or 100,
        num_chunks=num_chunks or 10,
    )

    resolved_loss_module = ""
    if resolved_mode == ExecutionMode.SINGLE:
        if use_processes and max_concurrent_trials > 1:
            raise ValueError("use_processes=True requires max_concurrent_trials=1 in mode=single")
    elif resolved_mode == ExecutionMode.DISTRIBUTED:
        resolved_loss_module = loss_module or getattr(loss.seed_loss_fn, "__module__", "")
        if (not resolved_loss_module) or resolved_loss_module == "__main__":
            # Workers cannot import the launcher's __main__; fall back to the file
            # the loss was defined in.
            resolved_loss_module = loss.source_file or ""
        if not resolved_loss_module:
            raise ValueError(
                "For mode=ExecutionMode.DISTRIBUTED, provide loss_module or define loss in an importable module"
            )

        if len(active_seeds) % resolved_chunk_size != 0:
            raise ValueError("For mode=ExecutionMode.DISTRIBUTED, number of seeds must be divisible by chunk_size")
        # submit_trial() is given only seed_start, num_chunks and chunk_size (never the
        # seed list), so workers can presumably only reproduce this contiguous range.
        # Reject anything else instead of silently evaluating different seeds.
        first_seed = active_seeds[0]
        expected_seeds = list(range(first_seed, first_seed + resolved_num_chunks * resolved_chunk_size))
        if list(active_seeds) != expected_seeds:
            raise ValueError(
                "For mode=ExecutionMode.DISTRIBUTED, seeds must be a contiguous range of "
                "num_chunks * chunk_size values"
            )
    else:
        raise ValueError("mode must be one of: single, distributed")

    # ---- Run directory and metadata. ----
    run_root_path = Path(run_root)
    chosen_run_name = run_name or _next_versioned_run_name(run_root_path, loss.name)
    run_dir = run_root_path / chosen_run_name
    run_dir.mkdir(parents=True, exist_ok=True)

    started_at_utc = datetime.now(timezone.utc)

    # meta.json is written at start and at the end. Both versions share these two
    # groups; the final one adds timing between them and the best result at the end.
    meta_identity = {
        "loss_name": loss.name,
        "run_name": chosen_run_name,
        "run_dir": str(run_dir),
        "mode": resolved_mode.value,
        "n_trials": n_trials,
        "entry_id": entry_id,
        "started_at_utc": started_at_utc.isoformat(),
    }
    meta_layout = {
        "seed_start": active_seeds[0],
        "seed_end": active_seeds[-1],
        "n_seeds": len(active_seeds),
        "chunk_size": resolved_chunk_size,
        "num_chunks": resolved_num_chunks,
        "max_concurrent_trials": max_concurrent_trials,
        "array_parallelism_limit": array_parallelism_limit,
        "worker_parallelism": worker_parallelism,
        "worker_time_limit_seconds": int(worker_time_limit.total_seconds()),
        "slurm_qos": slurm_qos,
    }

    def _write_json(path: Path, payload: dict) -> None:
        # Write a sibling temp file and rename it over the target so anyone polling
        # the run directory never reads a half-written JSON document.
        tmp_path = path.with_name(path.name + ".tmp")
        tmp_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
        tmp_path.replace(path)

    _write_json(run_dir / "meta.json", {**meta_identity, **meta_layout})

    # ---- Study. ----
    storage = f"sqlite:///{(run_dir / 'optuna.db').resolve()}"
    sampler = optuna.samplers.TPESampler(seed=random_seed)
    study = optuna.create_study(
        direction=direction,
        sampler=sampler,
        storage=storage,
        load_if_exists=True,
        study_name=loss.name,
    )

    summary_lock = threading.Lock()

    def _write_summary_callback(study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
        # Optuna invokes callbacks from its worker threads when n_jobs > 1. Holding
        # the lock across read-and-write prevents interleaved writes and prevents a
        # slower thread from overwriting a newer best with an older one.
        with summary_lock:
            try:
                # Read the best trial once so value and params always belong together.
                best = study.best_trial
            except ValueError:
                # No COMPLETE trial yet (e.g. the only finished trial returned NaN and
                # was recorded as failed), so there is nothing to summarize.
                return
            _write_json(
                run_dir / "summary.json",
                {"best_value": float(best.value), "best_params": dict(best.params)},
            )

    # ---- Objective for the selected mode (validated above). ----
    if resolved_mode == ExecutionMode.SINGLE:

        def objective_local(trial: optuna.trial.Trial) -> float:
            return _objective_local(
                trial=trial,
                loss=loss,
                active_seeds=active_seeds,
                entry_id=entry_id,
                worker_parallelism=worker_parallelism,
                use_processes=use_processes,
            )

        objective = objective_local
    else:
        slurm_cfg = SlurmConfig(
            poll_seconds=slurm_poll_seconds,
            timeout_minutes=slurm_timeout_minutes,
            cpus_per_task=cpus_per_task,
            mem_per_cpu=mem_per_cpu,
            array_parallelism_limit=array_parallelism_limit,
            worker_time_limit=worker_time_limit,
            qos=slurm_qos,
            fail_on_chunk_error=fail_on_chunk_error,
        )

        project_root = Path.cwd()
        python_executable = sys.executable
        worker_module_argv = list(sys.argv) if forward_sys_argv_to_workers else None

        def objective_slurm(trial: optuna.trial.Trial) -> float:
            params = {
                name: suggest_param(trial, name, spec)
                for name, spec in loss.parameter_space.items()
            }

            last_error: Exception | None = None
            for attempt in range(trial_retry_attempts + 1):
                # Errors raised while submitting are not retried; only failures reported
                # while waiting for the trial summary are.
                submitted_trial = submit_trial(
                    project_root=project_root,
                    run_dir=run_dir,
                    trial_number=trial.number,
                    loss_module=resolved_loss_module,
                    loss_name=loss.name,
                    params=params,
                    entry_id=entry_id,
                    seed_start=active_seeds[0],
                    num_chunks=resolved_num_chunks,
                    chunk_size=resolved_chunk_size,
                    worker_parallelism=worker_parallelism,
                    use_processes=use_processes,
                    config=slurm_cfg,
                    python_executable=python_executable,
                    module_argv=worker_module_argv,
                )

                try:
                    summary = wait_for_summary(
                        submitted_trial.summary_path,
                        config=slurm_cfg,
                        chunk_job_id=submitted_trial.chunk_job_id,
                        reduce_job_id=submitted_trial.reduce_job_id,
                    )
                except (ChunkExecutionError, TimeoutError) as exc:
                    last_error = exc
                    trial.set_user_attr("last_failed_attempt", attempt)
                    trial.set_user_attr("last_failure_type", type(exc).__name__)
                    continue

                trial.set_user_attr("n_seeds", summary["n_seeds"])
                trial.set_user_attr("retry_attempts_used", attempt)
                return float(summary["mean_total_loss"])

            # All attempts failed: re-raise as the type of the last failure.
            if isinstance(last_error, ChunkExecutionError):
                raise ChunkExecutionError(
                    f"Trial {trial.number} failed after {trial_retry_attempts + 1} attempts"
                ) from last_error
            raise TimeoutError(
                f"Trial {trial.number} failed after {trial_retry_attempts + 1} attempts"
            ) from last_error

        objective = objective_slurm

    study.optimize(objective, n_trials=n_trials, n_jobs=max_concurrent_trials, callbacks=[_write_summary_callback])

    # ---- Final metadata and result. ----
    completed_at_utc = datetime.now(timezone.utc)
    # Raises ValueError (from Optuna) if no trial ever completed.
    best_trial = study.best_trial
    best_value = float(best_trial.value)
    best_params = dict(best_trial.params)

    _write_json(
        run_dir / "meta.json",
        {
            **meta_identity,
            "completed_at_utc": completed_at_utc.isoformat(),
            "duration_seconds": (completed_at_utc - started_at_utc).total_seconds(),
            **meta_layout,
            "best_value": best_value,
            "best_params": best_params,
        },
    )

    return OptimizeResult(
        loss_name=loss.name,
        best_value=best_value,
        best_params=dict(best_params),
        n_trials=n_trials,
        study_name=study.study_name,
        run_dir=str(run_dir),
        mode=resolved_mode,
    )

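search_param(*, range=None, allowed=None, dtype=None)

Build a SearchParam search specification. range gives optional (low, high) bounds, allowed gives an optional collection of permitted values (copied into a tuple), and dtype is an optional type hint such as "int" or "log".

Source code in src/slurptuna/params.py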
def search_param(
    *,
    range: tuple[float, float] | None = None,
    allowed: list[ParamValue] | tuple[ParamValue, ...] | None = None,
    dtype: ParamDType | None = None,
) -> SearchParam:
    """Build a ``SearchParam`` search specification.

    Args:
        range: Optional ``(low, high)`` bounds, forwarded unchanged.
        allowed: Optional collection of permitted values. It is copied into a tuple,
            so the spec does not alias a list the caller may later mutate.
        dtype: Optional dtype hint forwarded unchanged (the ``loss`` docs use
            ``"int"`` and ``"log"``; see ``ParamDType`` for the full set).

    Returns:
        A ``SearchParam`` built from the given arguments.
    """
    if allowed is None:
        allowed_values = None
    else:
        allowed_values = tuple(allowed)
    return SearchParam(range=range, allowed=allowed_values, dtype=dtype)