Skip to content

API Reference

FlashANSR

Flash Amortized Neural Symbolic Regressor.

PARAMETER DESCRIPTION
simplipy_engine

Engine responsible for manipulating and evaluating symbolic expressions.

TYPE: SimpliPyEngine

flash_ansr_model

Trained transformer backbone that proposes expression programs.

TYPE: FlashANSRModel

tokenizer

Tokenizer mapping model outputs to expression tokens.

TYPE: Tokenizer

generation_config

Configuration that controls candidate generation. If None a default SoftmaxSamplingConfig is created.

TYPE: GenerationConfig DEFAULT: None

n_restarts

Number of optimizer restarts used by the refiner when fitting constants.

TYPE: int DEFAULT: 8

refiner_method

Optimization routine employed by the refiner.

TYPE: (curve_fit_lm, minimize_bfgs, minimize_lbfgsb, minimize_neldermead, minimize_powell, least_squares_trf, least_squares_dogbox) DEFAULT: 'curve_fit_lm'

refiner_p0_noise

Distribution applied to perturb initial constant guesses. None disables perturbations.

TYPE: (uniform, normal) or None DEFAULT: 'normal'

refiner_p0_noise_kwargs

Keyword arguments forwarded to the noise sampler. 'default' yields {'loc': 0.0, 'scale': 5.0} for the normal distribution.

TYPE: dict or {default} or None DEFAULT: 'default'

numpy_errors

Desired NumPy error handling strategy applied during constant refinement.

TYPE: (ignore, warn, raise, call, print, log) or None DEFAULT: 'ignore'

parsimony

Penalty coefficient that discourages overly complex expressions.

TYPE: float DEFAULT: 0.05

refiner_workers

Number of worker processes to run during constant refinement. None (the default) uses all available CPU cores, while explicit integers select a fixed pool size. Set 0 to disable multiprocessing.

TYPE: int or None DEFAULT: None

Source code in src/flash_ansr/flash_ansr.py
def __init__(
        self,
        simplipy_engine: SimpliPyEngine,
        flash_ansr_model: FlashANSRModel,
        tokenizer: Tokenizer,
        generation_config: GenerationConfig | None = None,
        n_restarts: int = 8,
        refiner_method: Literal[
            'curve_fit_lm',
            'minimize_bfgs',
            'minimize_lbfgsb',
            'minimize_neldermead',
            'minimize_powell',
            'least_squares_trf',
            'least_squares_dogbox',
        ] = 'curve_fit_lm',
        refiner_p0_noise: Literal['uniform', 'normal'] | None = 'normal',
        refiner_p0_noise_kwargs: dict | None | Literal['default'] = 'default',
        numpy_errors: Literal['ignore', 'warn', 'raise', 'call', 'print', 'log'] | None = 'ignore',
        parsimony: float = 0.05,
        refiner_workers: int | None = None):
    """Store the collaborators and refinement settings of the regressor.

    The model is switched to evaluation mode, the ``'default'`` noise keyword
    arguments are expanded to ``{'loc': 0.0, 'scale': 5.0}``, a missing
    generation config is replaced by ``SoftmaxSamplingConfig()``, and
    ``refiner_workers`` is resolved to a concrete non-negative pool size.

    Raises
    ------
    TypeError
        If ``refiner_workers`` is neither an integer nor ``None``.
    """
    # Core collaborators; the backbone is only ever used for inference here.
    self.simplipy_engine = simplipy_engine
    self.flash_ansr_model = flash_ansr_model.eval()
    self.tokenizer = tokenizer

    # Expand the 'default' sentinel into the concrete sampler arguments.
    noise_kwargs = {'loc': 0.0, 'scale': 5.0} if refiner_p0_noise_kwargs == 'default' else refiner_p0_noise_kwargs

    self.generation_config = SoftmaxSamplingConfig() if generation_config is None else generation_config
    self.n_restarts = n_restarts
    self.refiner_method = refiner_method
    self.refiner_p0_noise = refiner_p0_noise
    # Keep a private copy so later caller-side mutations do not leak in.
    self.refiner_p0_noise_kwargs = None if noise_kwargs is None else copy.deepcopy(noise_kwargs)
    self.numpy_errors = numpy_errors
    self.parsimony = parsimony

    # Resolve the worker count: None -> all cores, integers are clamped at 0.
    if refiner_workers is not None and not isinstance(refiner_workers, numbers.Integral):
        raise TypeError("refiner_workers must be an integer or None.")

    available_cores = os.cpu_count() or 1
    self.refiner_workers = max(1, available_cores) if refiner_workers is None else max(0, int(refiner_workers))

    # Result containers, populated by ``fit`` / ``load_results``.
    self._results: list[Result] = []
    self.results: pd.DataFrame = pd.DataFrame()
    self._mcts_cache: dict[Tuple[int, ...], dict[str, Any]] = {}

    self._input_dim: int | None = None

    # Prompt and variable-naming state.
    self.variable_mapping: dict[str, str] = {}
    self._prompt_prefix: PromptPrefix | None = None
    self._prompt_metadata: dict[str, list[list[str]]] | None = None

load classmethod

load(directory: str, generation_config: GenerationConfig | None = None, n_restarts: int = 8, refiner_method: Literal['curve_fit_lm', 'minimize_bfgs', 'minimize_lbfgsb', 'minimize_neldermead', 'minimize_powell', 'least_squares_trf', 'least_squares_dogbox'] = 'curve_fit_lm', refiner_p0_noise: Literal['uniform', 'normal'] | None = 'normal', refiner_p0_noise_kwargs: dict | None | Literal['default'] = 'default', numpy_errors: Literal['ignore', 'warn', 'raise', 'call', 'print', 'log'] | None = 'ignore', parsimony: float = 0.05, device: str = 'cpu', refiner_workers: int | None = None) -> FlashANSR

Instantiate a FlashANSR model from a configuration directory.

PARAMETER DESCRIPTION
directory

Directory that contains model.yaml, tokenizer.yaml and state_dict.pt artifacts.

TYPE: str

generation_config

Generation parameters to override defaults during candidate search.

TYPE: GenerationConfig DEFAULT: None

n_restarts

Number of restarts passed to the refiner.

TYPE: int DEFAULT: 8

refiner_method

Optimization routine for constant fitting.

TYPE: (curve_fit_lm, minimize_bfgs, minimize_lbfgsb, minimize_neldermead, minimize_powell, least_squares_trf, least_squares_dogbox) DEFAULT: 'curve_fit_lm'

refiner_p0_noise

Distribution used to perturb initial constant guesses.

TYPE: (uniform, normal) or None DEFAULT: 'normal'

refiner_p0_noise_kwargs

Additional keyword arguments for the noise sampler. 'default' resolves to {'loc': 0.0, 'scale': 5.0}.

TYPE: dict or {default} or None DEFAULT: 'default'

numpy_errors

NumPy floating-point error policy applied during refinement.

TYPE: (ignore, warn, raise, call, print, log) or None DEFAULT: 'ignore'

parsimony

Parsimony coefficient used when compiling results.

TYPE: float DEFAULT: 0.05

device

Torch device where the model weights will be loaded.

TYPE: str DEFAULT: 'cpu'

refiner_workers

Desired worker-pool size for constant refinement. None uses the number of available CPU cores, integers select an explicit pool size, and 0 disables multiprocessing. Mirrors the constructor parameter.

TYPE: int or None DEFAULT: None

RETURNS DESCRIPTION
model

Fully initialized regressor ready for inference.

TYPE: FlashANSR

Source code in src/flash_ansr/flash_ansr.py
@classmethod
def load(
        cls,
        directory: str,
        generation_config: GenerationConfig | None = None,
        n_restarts: int = 8,
        refiner_method: Literal[
            'curve_fit_lm',
            'minimize_bfgs',
            'minimize_lbfgsb',
            'minimize_neldermead',
            'minimize_powell',
            'least_squares_trf',
            'least_squares_dogbox',
        ] = 'curve_fit_lm',
        refiner_p0_noise: Literal['uniform', 'normal'] | None = 'normal',
        refiner_p0_noise_kwargs: dict | None | Literal['default'] = 'default',
        numpy_errors: Literal['ignore', 'warn', 'raise', 'call', 'print', 'log'] | None = 'ignore',
        parsimony: float = 0.05,
        device: str = 'cpu',
        refiner_workers: int | None = None) -> "FlashANSR":
    """Build a `FlashANSR` regressor from the artifacts stored in a directory.

    Parameters
    ----------
    directory : str
        Directory holding ``model.yaml``, ``tokenizer.yaml`` and
        ``state_dict.pt``. The path is passed through
        ``substitute_root_path`` before use.
    generation_config : GenerationConfig, optional
        Generation parameters forwarded to the constructor.
    n_restarts : int, optional
        Number of restarts passed to the refiner.
    refiner_method : {'curve_fit_lm', 'minimize_bfgs', 'minimize_lbfgsb', 'minimize_neldermead', 'minimize_powell', 'least_squares_trf', 'least_squares_dogbox'}
        Optimization routine for constant fitting.
    refiner_p0_noise : {'uniform', 'normal'} or None, optional
        Distribution used to perturb initial constant guesses.
    refiner_p0_noise_kwargs : dict or {'default'} or None, optional
        Keyword arguments for the noise sampler. ``'default'`` resolves to
        ``{'loc': 0.0, 'scale': 5.0}`` in the constructor.
    numpy_errors : {'ignore', 'warn', 'raise', 'call', 'print', 'log'} or None, optional
        NumPy floating-point error policy applied during refinement.
    parsimony : float, optional
        Parsimony coefficient used when compiling results.
    device : str, optional
        Torch device onto which the weights are mapped and the model moved.
    refiner_workers : int or None, optional
        Worker-pool size for constant refinement; forwarded unchanged to the
        constructor.

    Returns
    -------
    model : FlashANSR
        Regressor wrapping the loaded model and tokenizer.
    """
    root = substitute_root_path(directory)

    model_config_file = os.path.join(root, 'model.yaml')
    tokenizer_config_file = os.path.join(root, 'tokenizer.yaml')

    # Rebuild the architecture, then restore the trained weights on `device`.
    model = FlashANSRModel.from_config(model_config_file)
    weights = torch.load(os.path.join(root, "state_dict.pt"), weights_only=True, map_location=device)
    model.load_state_dict(weights)
    model.eval().to(device)

    tokenizer = Tokenizer.from_config(tokenizer_config_file)

    # Everything that is passed straight through to the constructor.
    refiner_options = {
        'n_restarts': n_restarts,
        'refiner_method': refiner_method,
        'refiner_p0_noise': refiner_p0_noise,
        'refiner_p0_noise_kwargs': refiner_p0_noise_kwargs,
        'numpy_errors': numpy_errors,
        'parsimony': parsimony,
        'refiner_workers': refiner_workers,
    }

    return cls(
        simplipy_engine=model.simplipy_engine,
        flash_ansr_model=model,
        tokenizer=tokenizer,
        generation_config=generation_config,
        **refiner_options)

fit

fit(X: ndarray | Tensor | DataFrame, y: ndarray | Tensor | DataFrame | Series, variable_names: list[str] | dict[str, str] | Literal['auto'] | None = 'auto', converge_error: Literal['raise', 'ignore', 'print'] = 'ignore', verbose: bool = False, *, complexity: int | float | None = None, allowed_terms: Iterable[Sequence[Any]] | None = None, include_terms: Iterable[Sequence[Any]] | None = None, exclude_terms: Iterable[Sequence[Any]] | None = None) -> None

Perform symbolic regression on (X, y) and refine candidate expressions.

PARAMETER DESCRIPTION
X

Feature matrix where rows index observations and columns variables.

TYPE: ndarray or Tensor or DataFrame

y

Target values. Multi-output targets are unsupported.

TYPE: ndarray or Tensor or DataFrame or Series

variable_names

Mapping from internal variable tokens to descriptive names.

TYPE: list[str] or dict[str, str] or {auto} or None DEFAULT: 'auto'

converge_error

Handling strategy when the refiner fails to converge.

TYPE: (raise, ignore, print) DEFAULT: 'ignore'

verbose

If True progress bars and diagnostic output are displayed.

TYPE: bool DEFAULT: False

allowed_terms

Keyword-only list of term token sequences that may appear in the generated expression.

TYPE: Iterable[Sequence[str]] or None DEFAULT: None

include_terms

Keyword-only subset of allowed terms that the expression should prioritise using.

TYPE: Iterable[Sequence[str]] or None DEFAULT: None

exclude_terms

Keyword-only list of term token sequences that should be discouraged during generation.

TYPE: Iterable[Sequence[str]] or None DEFAULT: None

RAISES DESCRIPTION
ValueError

If y has more than one output dimension or cannot be reshaped.

Source code in src/flash_ansr/flash_ansr.py
def fit(
        self,
        X: np.ndarray | torch.Tensor | pd.DataFrame,
        y: np.ndarray | torch.Tensor | pd.DataFrame | pd.Series,
        variable_names: list[str] | dict[str, str] | Literal['auto'] | None = 'auto',
        converge_error: Literal['raise', 'ignore', 'print'] = 'ignore',
        verbose: bool = False,
        *,
        complexity: int | float | None = None,
        allowed_terms: Iterable[Sequence[Any]] | None = None,
        include_terms: Iterable[Sequence[Any]] | None = None,
        exclude_terms: Iterable[Sequence[Any]] | None = None) -> None:
    """Perform symbolic regression on ``(X, y)`` and refine candidate expressions.

    Candidate expressions are generated from the data, invalid candidates are
    dropped, constants of the remaining ones are refined (in parallel when
    more than one worker is configured and the ``'fork'`` start method is
    available), and the outcome is compiled via ``compile_results``.

    Parameters
    ----------
    X : ndarray or Tensor or DataFrame
        Feature matrix where rows index observations and columns variables.
        One-dimensional input is reshaped to a single column.
    y : ndarray or Tensor or DataFrame or Series
        Target values. Multi-output targets are unsupported. One-dimensional
        input (including a ``Series``) is reshaped to a single column.
    variable_names : list[str] or dict[str, str] or {'auto'} or None, optional
        Mapping from internal variable tokens to descriptive names.
        ``'auto'`` uses the column names when ``X`` is a ``DataFrame``.
    converge_error : {'raise', 'ignore', 'print'}, optional
        Handling strategy when the refiner fails to converge. With
        ``'print'`` the warning messages returned by the refinement worker
        are printed.
    verbose : bool, optional
        If ``True`` progress bars and diagnostic output are displayed.
    complexity : int or float or None, optional
        Keyword-only real scalar forwarded to prompt construction, candidate
        generation and the refinement jobs.
    allowed_terms : Iterable[Sequence[str]] or None, optional
        Keyword-only list of term token sequences that may appear in the
        generated expression.
    include_terms : Iterable[Sequence[str]] or None, optional
        Keyword-only subset of allowed terms that the expression should
        prioritise using.
    exclude_terms : Iterable[Sequence[str]] or None, optional
        Keyword-only list of term token sequences that should be discouraged
        during generation.

    Raises
    ------
    ValueError
        If ``y`` has more than one output dimension or cannot be reshaped.
    TypeError
        If ``complexity`` is provided but is not a real scalar.
    ConvergenceError
        Propagated from ``compile_results`` when no candidate produced a
        result.
    """
    # TODO: Support lists
    # TODO: Support 0-d and 1-d tensors

    if isinstance(y, pd.Series):
        # ``Series`` offers no ``reshape``; continue with the underlying array
        # so that 1-d targets can be turned into a column below.
        y = y.to_numpy()

    if len(X.shape) == 1:
        X = X.reshape(-1, 1)
    if len(y.shape) == 1:
        y = y.reshape(-1, 1)
    elif y.shape[-1] != 1:
        raise ValueError("The target data must have a single output dimension")

    X = self._truncate_input(X)

    # Default: No mapping
    self.variable_mapping = {}

    if isinstance(variable_names, list):
        # column i -> variable_names[i]
        self.variable_mapping = {f"x{i + 1}": name for i, name in enumerate(variable_names)}

    elif isinstance(variable_names, dict):
        if isinstance(X, pd.DataFrame):
            # column i -> variable_names[column i]
            self.variable_mapping = {f"x{i + 1}": variable_names[c] for i, c in enumerate(X.columns)}
        else:
            # custom mapping
            self.variable_mapping = variable_names

    elif variable_names == 'auto':
        if isinstance(X, pd.DataFrame):
            # column i -> column name
            self.variable_mapping = {f"x{i + 1}": name for i, name in enumerate(X.columns)}

    if complexity is not None and not isinstance(complexity, numbers.Real):
        raise TypeError("complexity must be a real scalar when provided")

    with torch.no_grad():
        # Convert the input data to a tensor
        if not isinstance(X, torch.Tensor):
            if isinstance(X, pd.DataFrame):
                X = torch.tensor(X.values, dtype=torch.float32, device=self.flash_ansr_model.device)
            else:
                X = torch.tensor(X, dtype=torch.float32, device=self.flash_ansr_model.device)
        else:
            X = X.to(self.flash_ansr_model.device)

        if not isinstance(y, torch.Tensor):
            if isinstance(y, (pd.DataFrame, pd.Series)):
                y = torch.tensor(y.values, dtype=torch.float32, device=self.flash_ansr_model.device)
            else:
                y = torch.tensor(y, dtype=torch.float32, device=self.flash_ansr_model.device)
        else:
            y = y.to(self.flash_ansr_model.device)

        if y.dim() == 1:
            y = y.unsqueeze(-1)

        sample_count = y.shape[0]
        if sample_count <= 1:
            # Torch warns when computing an unbiased variance with a single sample.
            # Skip the reduction entirely so downstream scoring quietly falls back
            # to the residual loss via ``_compute_fvu``.
            y_variance = float('nan')
        else:
            y_variance = y.var(dim=0).item()

        X = pad_input_set(X, self.n_variables)

        # Concatenate x and y along the feature dimension
        data_tensor = torch.cat([X, y], dim=-1)

        self._results = []

        # Temporarily adopt the configured floating-point error policy for refinement.
        numpy_errors_before = np.geterr()
        np.seterr(all=self.numpy_errors)

        try:
            prompt_prefix = self._prepare_prompt_prefix(
                complexity=complexity,
                allowed_terms=allowed_terms,
                include_terms=include_terms,
                exclude_terms=exclude_terms,
            )

            metadata_snapshot: dict[str, list[list[str]]] | None
            if prompt_prefix is not None:
                metadata_snapshot = copy.deepcopy(prompt_prefix.metadata)
            else:
                metadata_snapshot = None

            self._prompt_metadata = copy.deepcopy(metadata_snapshot) if metadata_snapshot is not None else None

            raw_beams, log_probs, _completed_flags, _rewards = self.generate(
                data_tensor,
                prompt_prefix=prompt_prefix,
                complexity=complexity,
                verbose=verbose,
            )

            beams = [self.flash_ansr_model.tokenizer.extract_expression_from_beam(raw_beam)[0] for raw_beam in raw_beams]

            raw_beams_decoded = [self.tokenizer.decode(raw_beam, special_tokens='<constant>') for raw_beam in raw_beams]
            beams_decoded = [self.tokenizer.decode(beam, special_tokens='<constant>') for beam in beams]

            X_np = X.cpu().numpy()
            y_np = y.cpu().numpy()

            refinement_jobs: list[dict[str, Any]] = []
            beam_iterator = zip(raw_beams, raw_beams_decoded, beams, beams_decoded, log_probs)
            for raw_beam, raw_beam_decoded, beam, beam_decoded, log_prob in beam_iterator:
                if not self.simplipy_engine.is_valid(beam_decoded):
                    continue

                job: dict[str, Any] = {
                    'raw_beam': raw_beam,
                    'raw_beam_decoded': raw_beam_decoded,
                    'beam': beam,
                    'expression': beam_decoded,
                    'log_prob': log_prob,
                    'n_variables': self.n_variables,
                    'n_restarts': self.n_restarts,
                    'method': self.refiner_method,
                    'p0_noise': self.refiner_p0_noise,
                    'p0_noise_kwargs': copy.deepcopy(self.refiner_p0_noise_kwargs) if self.refiner_p0_noise_kwargs is not None else None,
                    'converge_error': converge_error,
                    'numpy_errors': self.numpy_errors,
                    'y_variance': y_variance,
                    'parsimony': self.parsimony,
                    'complexity': complexity,
                    'metadata_snapshot': metadata_snapshot,
                }
                refinement_jobs.append(job)

            if refinement_jobs:
                available_methods = mp.get_all_start_methods()
                max_workers = min(self.refiner_workers, len(refinement_jobs))
                use_parallel = max_workers > 1 and 'fork' in available_methods

                input_dim = X_np.shape[1]
                self._input_dim = input_dim

                if max_workers > 1 and not use_parallel:
                    warnings.warn("Parallel refinement requires the 'fork' start method; falling back to serial execution.")

                def _collect(result: Any, warning_msg: Any) -> None:
                    # Shared post-processing for the parallel and the serial path.
                    if warning_msg and converge_error == 'print':
                        print(warning_msg)
                    if result is not None:
                        entry = self._create_result_entry(payload=result, input_dim=input_dim)
                        if entry is not None:
                            self._results.append(entry)

                with _RefinementContext(self.simplipy_engine, X_np, y_np):
                    if use_parallel:
                        ctx = mp.get_context('fork')
                        # Give every job its own seed derived from one seed sequence.
                        seed_sequence = np.random.SeedSequence()
                        spawned = seed_sequence.spawn(len(refinement_jobs))
                        for job, seq in zip(refinement_jobs, spawned):
                            job['seed'] = int(seq.generate_state(1, dtype=np.uint32)[0])

                        with ProcessPoolExecutor(max_workers=max_workers, mp_context=ctx) as executor:
                            futures = [executor.submit(_refine_candidate_worker, job) for job in refinement_jobs]
                            for future in _iterate_with_progress(
                                as_completed(futures),
                                total=len(futures),
                                verbose=verbose,
                                desc="Fitting Constants",
                            ):
                                _collect(*future.result())
                    else:
                        for job in _iterate_with_progress(
                            refinement_jobs,
                            total=len(refinement_jobs),
                            verbose=verbose,
                            desc="Fitting Constants",
                        ):
                            # The serial worker receives data and engine in the payload itself.
                            serial_payload = job.copy()
                            serial_payload.update({'X': X_np, 'y': y_np, 'simplipy_engine': self.simplipy_engine})
                            _collect(*_refine_candidate_worker(serial_payload))

            self.compile_results(self.parsimony)
        finally:
            # Restore the caller's NumPy error policy even when generation,
            # refinement or result compilation raises.
            np.seterr(**numpy_errors_before)

predict

predict(X: ndarray | Tensor | DataFrame, nth_best_beam: int = 0, nth_best_constants: int = 0) -> np.ndarray

Evaluate a fitted expression on new data.

PARAMETER DESCRIPTION
X

Feature matrix to evaluate.

TYPE: ndarray or Tensor or DataFrame

nth_best_beam

Beam index to select from the ranked results.

TYPE: int DEFAULT: 0

nth_best_constants

Index of the constant fit to choose for the selected beam.

TYPE: int DEFAULT: 0

RETURNS DESCRIPTION
y_pred

Predicted targets with the same leading dimension as X.

TYPE: ndarray

RAISES DESCRIPTION
ValueError

If the model has not been fitted before prediction.

Source code in src/flash_ansr/flash_ansr.py
def predict(self, X: np.ndarray | torch.Tensor | pd.DataFrame, nth_best_beam: int = 0, nth_best_constants: int = 0) -> np.ndarray:
    """Evaluate a fitted expression on new data.

    Parameters
    ----------
    X : ndarray or Tensor or DataFrame
        Feature matrix to evaluate.
    nth_best_beam : int, optional
        Beam index to select from the ranked results.
    nth_best_constants : int, optional
        Index of the constant fit to choose for the selected beam.

    Returns
    -------
    y_pred : ndarray
        Predicted targets with the same leading dimension as ``X``.

    Raises
    ------
    ValueError
        If the model has not been fitted before prediction.
    """
    # TODO: Support lists
    # TODO: Support 0-d and 1-d tensors

    X = self._truncate_input(X)

    if isinstance(X, pd.DataFrame):
        X = X.values

    X = pad_input_set(X, self.n_variables)

    if len(self._results) == 0:
        raise ValueError("The model has not been fitted yet. Please call the fit method first.")

    return self._results[nth_best_beam]['refiner'].predict(X, nth_best_constants=nth_best_constants)

get_expression

get_expression(nth_best_beam: int = 0, nth_best_constants: int = 0, return_prefix: bool = False, precision: int = 2, map_variables: bool = True, **kwargs: Any) -> list[str] | str

Retrieve a formatted expression from the compiled results.

PARAMETER DESCRIPTION
nth_best_beam

Beam index to extract from self._results.

TYPE: int DEFAULT: 0

nth_best_constants

Constant fit index for the selected beam.

TYPE: int DEFAULT: 0

return_prefix

If True return the prefix notation instead of infix string.

TYPE: bool DEFAULT: False

precision

Number of decimal places used when rendering constants.

TYPE: int DEFAULT: 2

map_variables

When True apply self.variable_mapping to humanise variables.

TYPE: bool DEFAULT: True

**kwargs

Extra keyword arguments forwarded to :meth:Refiner.transform.

TYPE: Any DEFAULT: {}

RETURNS DESCRIPTION
expression

Expression either as a token list or human-readable string.

TYPE: list[str] or str

Source code in src/flash_ansr/flash_ansr.py
def get_expression(self, nth_best_beam: int = 0, nth_best_constants: int = 0, return_prefix: bool = False, precision: int = 2, map_variables: bool = True, **kwargs: Any) -> list[str] | str:
    """Retrieve a formatted expression from the compiled results.

    Parameters
    ----------
    nth_best_beam : int, optional
        Beam index to extract from ``self._results``.
    nth_best_constants : int, optional
        Constant fit index for the selected beam.
    return_prefix : bool, optional
        If ``True`` return the prefix notation instead of infix string.
    precision : int, optional
        Number of decimal places used when rendering constants.
    map_variables : bool, optional
        When ``True`` apply ``self.variable_mapping`` to humanise variables.
    **kwargs : Any
        Extra keyword arguments forwarded to :meth:`Refiner.transform`.

    Returns
    -------
    expression : list[str] or str
        Expression either as a token list or human-readable string.
    """
    return self._results[nth_best_beam]['refiner'].transform(
        expression=self._results[nth_best_beam]['expression'],
        nth_best_constants=nth_best_constants,
        return_prefix=return_prefix,
        precision=precision,
        variable_mapping=self.variable_mapping if map_variables else None,
        **kwargs)

save_results

save_results(path: str) -> None

Persist fitted results (minus lambdas) for later reuse.

Source code in src/flash_ansr/flash_ansr.py
def save_results(self, path: str) -> None:
    """Serialize the current results together with their metadata to ``path``.

    Raises
    ------
    ValueError
        If there are no results to save.
    """
    if not self._results:
        raise ValueError("No results available to save. Run `fit` first.")

    # Fall back to the model's variable count when no input width was recorded.
    recorded_dim = self.n_variables if self._input_dim is None else self._input_dim

    payload = serialize_results_payload(
        self._results,
        metadata={
            "format_version": RESULTS_FORMAT_VERSION,
            "parsimony": self.parsimony,
            "n_variables": self.n_variables,
            "input_dim": recorded_dim,
            "variable_mapping": copy.deepcopy(self.variable_mapping),
        },
    )
    save_results_payload(payload, path)

load_results

load_results(path: str, *, rebuild_refiners: bool = True) -> None

Load previously saved results and rebuild refiners if requested.

Source code in src/flash_ansr/flash_ansr.py
def load_results(self, path: str, *, rebuild_refiners: bool = True) -> None:
    """Restore results saved earlier and recompile the results table.

    A warning is emitted when the payload's ``"version"`` differs from
    ``RESULTS_FORMAT_VERSION``; loading continues regardless. Metadata values
    missing from the payload fall back to the instance's current settings.
    """
    payload = load_results_payload(path)
    meta = payload.get("metadata", {})

    stored_version = int(payload.get("version", 0))
    if stored_version != RESULTS_FORMAT_VERSION:
        warnings.warn(
            f"Results payload version {stored_version} does not match expected {RESULTS_FORMAT_VERSION}; attempting to proceed anyway."
        )

    stored_parsimony = float(meta.get("parsimony", self.parsimony))
    stored_n_variables = int(meta.get("n_variables", self.n_variables))
    # The input width defaults to the variable count when it was not recorded.
    stored_input_dim = int(meta.get("input_dim", stored_n_variables))

    self._input_dim = stored_input_dim
    self.variable_mapping = meta.get("variable_mapping", self.variable_mapping)

    self._results = deserialize_results_payload(
        payload,
        simplipy_engine=self.simplipy_engine,
        n_variables=stored_n_variables,
        input_dim=stored_input_dim,
        rebuild_refiners=rebuild_refiners,
    )
    self.compile_results(stored_parsimony)

compile_results

compile_results(parsimony: float) -> None

Aggregate refiner outputs into a tidy pandas.DataFrame.

PARAMETER DESCRIPTION
parsimony

Parsimony coefficient used to recompute scores before ranking.

TYPE: float

RAISES DESCRIPTION
ConvergenceError

If no beams converged during refinement.

Source code in src/flash_ansr/flash_ansr.py
def compile_results(self, parsimony: float) -> None:
    """Re-score, rank and tabulate the refined beams.

    Parameters
    ----------
    parsimony : float
        Parsimony coefficient used to recompute scores before ranking. It
        replaces ``self.parsimony``; the previous value is kept in
        ``self.initial_parsimony``.

    Raises
    ------
    ConvergenceError
        If no beams converged during refinement.
    """
    if not self._results:
        raise ConvergenceError("The optimization did not converge for any beam")

    self.initial_parsimony = self.parsimony
    self.parsimony = parsimony

    # Refresh every existing score under the new parsimony coefficient.
    for entry in self._results:
        if 'score' not in entry:
            continue
        fvu = entry.get('fvu', np.nan)
        entry['score'] = (
            self._score_from_fvu(float(fvu), len(entry['expression']), self.parsimony)
            if np.isfinite(fvu)
            else np.nan
        )

    def _rank(entry):
        # NaN scores are ranked as +inf so they end up after all finite ones.
        score = entry['score']
        missing = np.isnan(score)
        return (float('inf') if missing else score, missing)

    self._results = sorted(self._results, key=_rank)

    # One row per (beam, fit); the beam's rank is kept in ``beam_id``.
    self.results = pd.DataFrame(self._results)
    self.results = self.results.explode('fits')
    self.results['beam_id'] = self.results.index
    self.results = self.results.reset_index(drop=True)

    # Unpack each fit tuple into dedicated columns and drop the packed one.
    fit_columns = ['fit_constants', 'fit_covariances', 'fit_loss']
    unpacked = pd.DataFrame(self.results['fits'].tolist(), columns=fit_columns)
    self.results = pd.concat([self.results, unpacked], axis=1).drop(columns=['fits'])

FlashANSRDataset

Dataset wrapper for amortized neural symbolic regression training.

Manages skeleton sampling, support point generation, optional prompt preprocessing, and collation into model-ready batches. Can also compile streaming output into an on-disk datasets.Dataset for deterministic iteration.

PARAMETER DESCRIPTION
skeleton_pool

Source of operator-only expression templates and sampling logic.

TYPE: SkeletonPool

tokenizer

Tokenizer used for expression serialization and padding.

TYPE: Tokenizer

padding

Strategy for padding numeric support points.

TYPE: (random, zero) (required; the constructor signature declares no default)

preprocessor

Prompt-aware preprocessor; when provided, prompt metadata can be injected during sampling or in worker processes.

TYPE: FlashANSRPreprocessor DEFAULT: None

Source code in src/flash_ansr/data/data.py
def __init__(
    self,
    skeleton_pool: SkeletonPool,
    tokenizer: Tokenizer,
    padding: Literal["random", "zero"],
    preprocessor: FlashANSRPreprocessor | None = None,
) -> None:
    """Wire up the dataset's collaborators, batch formatter and worker pool."""
    self.skeleton_pool = skeleton_pool
    self.tokenizer = tokenizer
    self.padding = padding
    self.preprocessor = preprocessor
    # No compiled data yet.
    self.data = None

    self._collator = BatchFormatter(tokenizer=tokenizer)

    stream_options = {
        "skeleton_pool": skeleton_pool,
        "tokenizer": tokenizer,
        "padding": padding,
    }
    self._stream = SharedMemoryWorkerPool(**stream_options)

    # Keep an independent copy of the prompt config when a preprocessor exists.
    if preprocessor is None:
        self._preprocessor_prompt_config = None
    else:
        self._preprocessor_prompt_config = copy.deepcopy(preprocessor.prompt_config)

from_config classmethod

from_config(config: dict[str, Any] | str) -> FlashANSRDataset

Instantiate from a YAML/dict config.

Paths are normalized via load_config and substitute_root_path. The config may embed a skeleton pool definition or point to a directory containing one.

PARAMETER DESCRIPTION
config

Dataset config or path to a YAML file.

TYPE: dict or str

RETURNS DESCRIPTION
FlashANSRDataset

Dataset wrapper with tokenizer and optional preprocessor wired.

Source code in src/flash_ansr/data/data.py
@classmethod
def from_config(cls, config: dict[str, Any] | str) -> "FlashANSRDataset":
    """Instantiate from a YAML/dict config.

    Paths are normalized via `load_config` and `substitute_root_path`. The
    config may embed a skeleton pool definition or point to a directory
    containing one.

    Parameters
    ----------
    config : dict or str
        Dataset config or path to a YAML file.

    Returns
    -------
    FlashANSRDataset
        Dataset wrapper with tokenizer and optional preprocessor wired.

    Raises
    ------
    ValueError
        If the ``skeleton_pool`` entry is neither an embedded definition,
        an existing file, nor an existing directory.
    """
    config_ = load_config(config)

    # Accept configs that nest the dataset section under a "dataset" key.
    if "dataset" in config_:
        config_ = config_["dataset"]

    pool_spec = config_["skeleton_pool"]

    if isinstance(config, str) and isinstance(pool_spec, str):
        if pool_spec.startswith('.'):  # pragma: no cover - config guard
            # Relative pool paths are interpreted relative to the config file.
            pool_spec = os.path.join(os.path.dirname(config), pool_spec)
        pool_spec = substitute_root_path(pool_spec)

    # Test for an embedded definition *before* touching the filesystem:
    # os.path.isfile() raises TypeError when handed a dict on POSIX.
    if isinstance(pool_spec, dict) or os.path.isfile(pool_spec):
        skeleton_pool = SkeletonPool.from_config(pool_spec)
    elif os.path.isdir(pool_spec):
        # Only the second element of the loaded result is used as the pool.
        skeleton_pool = SkeletonPool.load(pool_spec)[1]
    else:
        raise ValueError(f"Invalid skeleton pool configuration: {pool_spec}")

    tokenizer = Tokenizer.from_config(config_["tokenizer"])

    preprocessor_cfg = config_.get("preprocessor") if isinstance(config_, dict) else None
    preprocessor: FlashANSRPreprocessor | None = None
    if preprocessor_cfg is not None:
        preprocessor = FlashANSRPreprocessor.from_config(
            preprocessor_cfg,
            simplipy_engine=skeleton_pool.simplipy_engine,
            tokenizer=tokenizer,
            skeleton_pool=skeleton_pool,
        )

    return cls(
        skeleton_pool=skeleton_pool,
        tokenizer=tokenizer,
        padding=config_["padding"],
        preprocessor=preprocessor,
    )

iterate

iterate(size: int | None = None, steps: int | None = None, batch_size: int | None = None, n_support: int | None = None, max_seq_len: int = 512, max_n_support: int | None = None, n_per_equation: int = 1, preprocess: bool = False, preprocess_in_worker: bool | None = None, tokenizer_oov: Literal['unk', 'raise'] = 'raise', num_workers: int | None = None, prefetch_factor: int = 2, persistent: bool = False, tqdm_kwargs: dict[str, Any] | None = None, verbose: bool = False) -> Generator[dict[str, Any], None, None]

Stream batches of synthetic data.

PARAMETER DESCRIPTION
size

Total number of samples to generate (used if steps is None).

TYPE: int DEFAULT: None

steps

Number of generation steps; overrides size when set.

TYPE: int DEFAULT: None

batch_size

Samples per step; defaults to 1.

TYPE: int DEFAULT: None

n_support

Support points per equation; pool default when None.

TYPE: int DEFAULT: None

max_seq_len

Maximum prefix length for generated expressions.

TYPE: int DEFAULT: 512

max_n_support

Upper bound for support points; used for padding.

TYPE: int DEFAULT: None

n_per_equation

Number of datasets to draw per skeleton before moving on.

TYPE: int DEFAULT: 1

preprocess

Whether to run the preprocessor on generated batches.

TYPE: bool DEFAULT: False

preprocess_in_worker

Force preprocessing inside workers (True), main process (False), or auto-select (None).

TYPE: bool DEFAULT: None

tokenizer_oov

How to handle tokens missing from the tokenizer.

TYPE: (unk, 'raise') DEFAULT: "raise"

num_workers

Worker count for multiprocessing; defaults to CPU count when None.

TYPE: int DEFAULT: None

prefetch_factor

Jobs per worker to pre-schedule.

TYPE: int DEFAULT: 2

persistent

Clone tensors to detach from shared memory buffers.

TYPE: bool DEFAULT: False

tqdm_kwargs

Additional arguments forwarded to tqdm progress bars.

TYPE: dict DEFAULT: None

verbose

Enable progress reporting.

TYPE: bool DEFAULT: False

YIELDS DESCRIPTION
dict

Model-ready batch with tensors and optional prompt metadata.

Source code in src/flash_ansr/data/data.py
def iterate(
    self,
    size: int | None = None,
    steps: int | None = None,
    batch_size: int | None = None,
    n_support: int | None = None,
    max_seq_len: int = 512,
    max_n_support: int | None = None,
    n_per_equation: int = 1,
    preprocess: bool = False,
    preprocess_in_worker: bool | None = None,
    tokenizer_oov: Literal["unk", "raise"] = "raise",
    num_workers: int | None = None,
    prefetch_factor: int = 2,
    persistent: bool = False,
    tqdm_kwargs: dict[str, Any] | None = None,
    verbose: bool = False,
) -> Generator[dict[str, Any], None, None]:
    """Stream batches of synthetic data.

    When a pre-compiled dataset is attached (``self.data`` is not None) its
    entries are yielded as-is and no batches are generated. Otherwise batches
    are produced by the worker stream, which is shut down when the generator
    finishes, fails, or is closed.

    Parameters
    ----------
    size : int, optional
        Total number of samples to generate (used if `steps` is None).
    steps : int, optional
        Number of generation steps; overrides `size` when set.
    batch_size : int, optional
        Samples per step; defaults to 1.
    n_support : int, optional
        Support points per equation; pool default when None.
    max_seq_len : int, default 512
        Maximum prefix length for generated expressions.
    max_n_support : int, optional
        Upper bound for support points; used for padding.
    n_per_equation : int, default 1
        Number of datasets to draw per skeleton before moving on.
    preprocess : bool, default False
        Whether to run the preprocessor on generated batches.
    preprocess_in_worker : bool, optional
        Force preprocessing inside workers (True), main process (False), or auto-select (None).
    tokenizer_oov : {"unk", "raise"}, default "raise"
        How to handle tokens missing from the tokenizer.
    num_workers : int, optional
        Worker count for multiprocessing; defaults to CPU count when None.
    prefetch_factor : int, default 2
        Jobs per worker to pre-schedule.
    persistent : bool, default False
        Clone tensors to detach from shared memory buffers.
    tqdm_kwargs : dict, optional
        Additional arguments forwarded to tqdm progress bars.
    verbose : bool, default False
        Enable progress reporting.

    Yields
    ------
    dict
        Model-ready batch with tensors and optional prompt metadata.

    Raises
    ------
    ValueError
        If neither `size` nor `steps` is given when generating batches.
    RuntimeError
        If the worker preprocessing mode differs from that of an already
        initialized stream, if the stream resources are missing after
        initialization, or if a worker returns an empty payload.
    """
    if batch_size is None:
        batch_size = 1

    # Normalize to a private dict; the per-use copies below receive defaults
    # via setdefault, so the caller's mapping is never modified.
    tqdm_kwargs = dict(tqdm_kwargs) if tqdm_kwargs else {}

    # Decide where preprocessing runs. Worker-side preprocessing is only
    # possible with a configured preprocessor and is the auto-selected mode
    # when `preprocess_in_worker` is None.
    use_worker_preprocess = False
    if preprocess:
        if self.preprocessor is None:
            if preprocess_in_worker:
                warnings.warn(
                    "worker preprocessing requested but no preprocessor configured; falling back to main process.",
                    RuntimeWarning,
                    stacklevel=2,
                )
        else:
            if preprocess_in_worker is None:
                use_worker_preprocess = True
            else:
                use_worker_preprocess = bool(preprocess_in_worker)

    # An already initialized stream is tied to one preprocessing mode.
    if self._stream.is_initialized and self._stream.worker_preprocess_enabled != use_worker_preprocess:
        raise RuntimeError(
            "Cannot switch worker preprocessing mode while workers are active. "
            "Call `dataset.shutdown()` before iterating with a new mode."
        )

    # Pre-compiled data short-circuits generation: entries are replayed
    # unchanged and `size`/`steps` are not consulted.
    if self.data is not None:
        precompiled_kwargs = tqdm_kwargs.copy()
        precompiled_kwargs.setdefault("desc", "Iterating over pre-compiled dataset")
        precompiled_kwargs.setdefault("disable", not verbose)
        precompiled_kwargs.setdefault("smoothing", 0.0)
        yield from tqdm(self.data, **precompiled_kwargs)
        return

    if steps is None and size is None:
        raise ValueError("Either size or steps must be specified.")

    if steps is None:
        assert size is not None
        # Ceiling division: enough steps to cover `size` samples.
        steps = (size + batch_size - 1) // batch_size

    self._initialize_stream(
        prefetch_factor=prefetch_factor,
        batch_size=batch_size,
        n_per_equation=n_per_equation,
        max_seq_len=max_seq_len,
        max_n_support=max_n_support,
        num_workers=num_workers,
        tokenizer_oov=tokenizer_oov,
        worker_preprocess=use_worker_preprocess,
    )

    if self._stream.metadata_pool is None or not self._stream.buffers:
        raise RuntimeError("Multiprocessing resources are not properly initialized.")

    pool_size = self._stream.pool_size

    progress_kwargs = tqdm_kwargs.copy()
    progress_kwargs.setdefault("total", steps)
    progress_kwargs.setdefault("desc", "Generating Batches")
    progress_kwargs.setdefault("disable", not verbose)
    progress_kwargs.setdefault("smoothing", 0.0)
    pbar = tqdm(**progress_kwargs)

    try:
        # Prime the pipeline: submit one job per slot, but never more jobs
        # than there are steps to produce.
        for _ in range(min(pool_size, steps)):
            slot_idx = self._stream.acquire_slot()
            self._stream.submit_job(slot_idx, n_support)

        for step_id in range(steps):
            completed_slot_idx = self._stream.get_completed_slot()
            metadata_and_constants = self._stream.metadata_pool[completed_slot_idx]
            if metadata_and_constants is None:
                raise RuntimeError("Worker returned empty payload.")

            # Turn the list of per-sample metadata dicts into one list per
            # key; the key set is taken from the first entry.
            metadata_batch = metadata_and_constants["metadata"]
            metadata_fields: dict[str, list[Any]] = {}
            if metadata_batch:
                for key in metadata_batch[0]:
                    metadata_fields[key] = [entry[key] for entry in metadata_batch]

            # torch.from_numpy wraps the slot's arrays without copying, so
            # these tensors alias the stream buffers unless cloned below.
            batch_dict = {
                "x_tensors": torch.from_numpy(self._stream.buffers["x_tensors"][completed_slot_idx]),
                "y_tensors": torch.from_numpy(self._stream.buffers["y_tensors"][completed_slot_idx]),
                "data_attn_mask": torch.from_numpy(self._stream.buffers["data_attn_mask"][completed_slot_idx]).to(torch.bool),
                "input_ids": torch.from_numpy(self._stream.buffers["input_ids"][completed_slot_idx]),
                "constants": [
                    torch.from_numpy(c)
                    for c in metadata_and_constants["constants"]
                ],
            }
            batch_dict.update(metadata_fields)

            # Prefer the worker's preprocessed payload when worker mode is
            # active; fall back to formatting in this process if the worker
            # did not provide one or worker mode is off.
            preprocessed_batch = metadata_and_constants.get("preprocessed")
            if preprocess:
                if use_worker_preprocess:
                    if preprocessed_batch is not None:
                        self._inject_preprocessed_fields(batch_dict, preprocessed_batch)
                    elif self.preprocessor:
                        batch_dict = self.preprocessor.format(batch_dict)
                elif self.preprocessor:
                    batch_dict = self.preprocessor.format(batch_dict)

            self._collator.ensure_numeric_channel(batch_dict)

            if persistent:
                # Clone every tensor (and each tensor in a "constants" list)
                # so the yielded batch no longer aliases the slot buffers;
                # all other values are passed through unchanged.
                cloned_batch: dict[str, Any] = {}
                for key, value in batch_dict.items():
                    if isinstance(value, torch.Tensor):
                        cloned_batch[key] = value.clone()
                    elif key == "constants" and isinstance(value, list):
                        cloned_batch[key] = [tensor.clone() for tensor in value]
                    elif key == "constants":
                        cloned_batch[key] = value
                    else:
                        cloned_batch[key] = value
                batch_dict = cloned_batch

            yield batch_dict

            pbar.update(1)

            # The slot is released only after the consumer resumes the
            # generator, i.e. once it has asked for the next batch.
            self._stream.release_slot(completed_slot_idx)
            # Refill only while jobs remain unscheduled, so that exactly
            # `steps` jobs are submitted in total.
            if step_id + pool_size < steps:
                slot_to_refill = self._stream.acquire_slot()
                self._stream.submit_job(slot_to_refill, n_support)
    finally:
        # Also runs when the consumer stops early and the generator is closed.
        pbar.close()
        self.shutdown()

shutdown

shutdown() -> None

Release multiprocessing workers and shared buffers.

Source code in src/flash_ansr/data/data.py
def shutdown(self) -> None:
    """Shut down the worker stream backing this dataset.

    Delegates to the stream's own ``shutdown`` and returns nothing.
    """
    stream = self._stream
    stream.shutdown()

FlashANSRPreprocessor

Format batch inputs and optionally enrich them with prompt metadata.

Source code in src/flash_ansr/preprocessing/pipeline.py
def __init__(
    self,
    simplipy_engine: SimpliPyEngine,
    tokenizer: Tokenizer,
    skeleton_pool: SkeletonPool | None = None,
    *,
    prompt_config: FlashASNRPreprocessorConfig | dict[str, Any] | None = None,
) -> None:
    """Set up the preprocessor and, if prompts are enabled, its feature extractor.

    Prompt features are enabled only when a skeleton pool is supplied and the
    configured prompt probability is greater than zero. A prompt serializer
    bound to the tokenizer is always created.
    """
    self.simplipy_engine = simplipy_engine
    self.tokenizer = tokenizer
    self.skeleton_pool = skeleton_pool

    self.prompt_config = FlashASNRPreprocessorConfig.from_dict(prompt_config)

    # Without a skeleton pool the prompt probability is never consulted.
    if skeleton_pool is None:
        self._prompt_enabled = False
    else:
        self._prompt_enabled = self.prompt_config.prompt_feature.prompt_probability > 0

    self._feature_extractor: PromptFeatureExtractor | None = None
    if self._prompt_enabled:
        extractor = PromptFeatureExtractor(
            simplipy_engine=simplipy_engine,
            tokenizer=tokenizer,
            config=self.prompt_config.prompt_feature,
            skeleton_pool=skeleton_pool,
        )
        self._feature_extractor = extractor

    self._serializer = PromptSerializer(tokenizer)

Generation configurations

BeamSearchConfig

Configuration for beam-search based generation.

Source code in src/flash_ansr/utils/generation.py
def __init__(
    self,
    *,
    beam_width: int = 32,
    max_len: int = 32,
    batch_size: int = 128,
    unique: bool = True,
    limit_expansions: bool = True,
) -> None:
    """Record the beam-search generation settings on the instance.

    All arguments are keyword-only and are stored unchanged as attributes of
    the same name; ``method`` is fixed to ``'beam_search'``.
    """
    # Identifies the generation method this configuration belongs to.
    self.method = 'beam_search'
    self.beam_width, self.max_len, self.batch_size = beam_width, max_len, batch_size
    self.unique, self.limit_expansions = unique, limit_expansions

SoftmaxSamplingConfig

Configuration for softmax sampling generation.

Source code in src/flash_ansr/utils/generation.py
def __init__(
    self,
    *,
    choices: int = 32,
    top_k: int = 0,
    top_p: float = 1.0,
    max_len: int = 64,
    batch_size: int = 128,
    temperature: float = 1.0,
    valid_only: bool = True,
    simplify: bool = True,
    unique: bool = True,
) -> None:
    """Record the softmax-sampling generation settings on the instance.

    All arguments are keyword-only and are stored unchanged as attributes of
    the same name; ``method`` is fixed to ``'softmax_sampling'``.
    """
    # Identifies the generation method this configuration belongs to.
    self.method = 'softmax_sampling'
    self.choices, self.top_k, self.top_p = choices, top_k, top_p
    self.max_len, self.batch_size = max_len, batch_size
    self.temperature = temperature
    self.valid_only, self.simplify, self.unique = valid_only, simplify, unique

MCTSGenerationConfig

Configuration for Monte Carlo tree search generation.

Source code in src/flash_ansr/utils/generation.py
def __init__(
    self,
    *,
    beam_width: int = 16,
    simulations: int = 256,
    uct_c: float = 1.4,
    expansion_top_k: int = 32,
    max_depth: int = 64,
    rollout_max_len: int | None = None,
    rollout_policy: str = 'sample',
    temperature: float = 1.0,
    dirichlet_alpha: float | None = None,
    dirichlet_epsilon: float = 0.25,
    invalid_penalty: float = 1e6,
    min_visits_before_expansion: int = 1,
    reward_transform: Callable[[float], float] | None = None,
    completion_sort: str = 'reward',
) -> None:
    self.method = 'mcts'
    self.beam_width = beam_width
    self.simulations = simulations
    self.uct_c = uct_c
    self.expansion_top_k = expansion_top_k
    self.max_depth = max_depth
    self.rollout_max_len = rollout_max_len
    self.rollout_policy = rollout_policy
    self.temperature = temperature
    self.dirichlet_alpha = dirichlet_alpha
    self.dirichlet_epsilon = dirichlet_epsilon
    self.invalid_penalty = invalid_penalty
    self.min_visits_before_expansion = min_visits_before_expansion
    self.reward_transform = reward_transform
    self.completion_sort = completion_sort

Utilities

Resolve a path relative to the repository root (optionally creating directories).

Source code in src/flash_ansr/utils/paths.py
def get_path(*args: str, filename: str | None = None, create: bool = False) -> str:
    """Resolve a path relative to the repository root (optionally creating directories)."""
    if any(not isinstance(arg, str) for arg in args):
        raise TypeError("All arguments must be strings.")

    path = normalize_path_preserve_leading_dot(
        os.path.join(os.path.dirname(__file__), '..', '..', '..', *args, filename or '')
    )

    if create:
        if filename is not None:
            os.makedirs(os.path.dirname(path), exist_ok=True)
        else:
            os.makedirs(path, exist_ok=True)

    return os.path.abspath(path)

Load a YAML config (optionally resolving nested relative paths).

Source code in src/flash_ansr/utils/config_io.py
def load_config(config: dict[str, Any] | str, resolve_paths: bool = True) -> dict[str, Any]:
    """Load a YAML config (optionally resolving nested relative paths)."""
    if isinstance(config, str):
        config_path = substitute_root_path(config)
        config_base_path = os.path.dirname(config_path)

        if not os.path.exists(config_path):
            raise FileNotFoundError(f'Config file {config_path} not found.')
        if os.path.isfile(config_path):
            with open(config_path, 'r') as config_file:
                config_ = yaml.safe_load(config_file)
        else:
            raise ValueError(f'Config file {config_path} is not a valid file.')

        def resolve_path(value: Any) -> Any:
            if (
                isinstance(value, str)
                and (value.endswith('.yaml') or value.endswith('.json'))
                and value.startswith('.')
            ):
                return normalize_path_preserve_leading_dot(os.path.join(config_base_path, value))
            return value

        if resolve_paths:
            config_ = apply_on_nested(config_, resolve_path)
    else:
        config_ = config

    return config_