Getting Started
Download a checkpoint
By default, models are cached under ./models/ relative to the package root and can be uninstalled with flash_ansr remove <repo>.
Models can also be managed with the Python API via flash_ansr.model.manage.install_model and flash_ansr.model.manage.remove_model.
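To avoid re-downloading a checkpoint, you may want to check whether it is already cached before calling install_model. The sketch below assumes the cache layout is <package root>/models/<owner>/<name>, mirroring the get_path('models', MODEL) call in the inference example. The helpers model_cache_dir and is_cached are hypothetical and not part of flash_ansr.

```python
from __future__ import annotations

from pathlib import Path


def model_cache_dir(package_root: str | Path, repo_id: str) -> Path:
    """Hypothetical helper: assumed cache location for a Hugging Face repo id.

    Assumes <package_root>/models/<owner>/<name>, mirroring
    get_path('models', repo_id); this is NOT a flash_ansr API.
    """
    # pathlib splits "owner/name" into two path components
    return Path(package_root) / "models" / repo_id


def is_cached(package_root: str | Path, repo_id: str) -> bool:
    """Return True if the assumed cache directory exists and is non-empty."""
    path = model_cache_dir(package_root, repo_id)
    return path.is_dir() and any(path.iterdir())
```

If is_cached returns False, call install_model(MODEL) as in the example; otherwise load directly from the cached directory.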
See all available models on Hugging Face:
Minimal inference example
```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Import flash_ansr
from flash_ansr import (
    FlashANSR,
    SoftmaxSamplingConfig,
    install_model,
    get_path,
)

# Select a model from Hugging Face
# https://huggingface.co/models?search=flash-ansr-v23.0
MODEL = "psaegert/flash-ansr-v23.0-120M"

# Download the latest snapshot of the model
# By default, the model is downloaded to the directory `./models/` in the package root
install_model(MODEL)

# Load the model
model = FlashANSR.load(
    directory=get_path('models', MODEL),
    generation_config=SoftmaxSamplingConfig(choices=32),  # or BeamSearchConfig / MCTSGenerationConfig
    n_restarts=8,
).to(device)

# Define data
X = ...  # placeholder: replace with your input data
y = ...  # placeholder: replace with your target values

# Fit the model to the data
model.fit(X, y, verbose=True)

# Show the best expression
print(model.get_expression())

# Predict with the best expression
y_pred = model.predict(X)
```
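For a quick smoke test, the X and y placeholders can be filled with synthetic data drawn from a known expression, and the predictions scored with the coefficient of determination R². This sketch assumes X is a 2-D array of shape (n_samples, n_features) and y a 1-D array, as in scikit-learn-style fit/predict interfaces; the ground-truth formula and the helper names make_toy_data and r2_score are purely illustrative.

```python
from __future__ import annotations

import numpy as np


def make_toy_data(n_samples: int = 256, seed: int = 0) -> tuple[np.ndarray, np.ndarray]:
    """Illustrative dataset: y = x0**2 + sin(x1) with inputs uniform in [-3, 3]."""
    rng = np.random.default_rng(seed)
    X = rng.uniform(-3.0, 3.0, size=(n_samples, 2))
    y = X[:, 0] ** 2 + np.sin(X[:, 1])
    return X, y


def r2_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    # Flatten so that (n, 1) predictions compare cleanly with (n,) targets
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)


X, y = make_toy_data()
print(X.shape, y.shape)
# After fitting as in the example above (assumption: predict returns an array):
# print(r2_score(y, model.predict(X)))
```

An R² close to 1 indicates that the recovered expression reproduces the data well; a perfect recovery of the toy formula gives exactly 1.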
Find more details in the API Reference.
One-command evaluation
The evaluation produces a pickle under results/evaluation/... with entries such as:
- predicted_expression
- predicted_log_prob
- y_pred
- ...
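The results file can be inspected with Python's standard pickle module. This sketch assumes the pickle holds a dict mapping each field name (such as predicted_expression or y_pred) to a per-problem list; the actual container layout is not specified here, so check it against your own file. The helpers load_results and summarize are hypothetical.

```python
from __future__ import annotations

import pickle
from pathlib import Path
from typing import Any


def load_results(path: str | Path) -> dict[str, list[Any]]:
    """Load an evaluation pickle (assumed: dict of field name -> per-problem list)."""
    with open(path, "rb") as f:
        results = pickle.load(f)
    if not isinstance(results, dict):
        raise TypeError(f"expected a dict, got {type(results).__name__}")
    return results


def summarize(results: dict[str, list[Any]]) -> dict[str, int]:
    """Count entries per field and how many predicted expressions are missing."""
    counts = {key: len(values) for key, values in results.items()}
    exprs = results.get("predicted_expression", [])
    counts["n_missing_expression"] = sum(e is None for e in exprs)
    return counts
```

Only unpickle files you trust: loading a pickle can execute arbitrary code.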
For more details, see Evaluation.
Next steps
- See Concepts & Architecture for how the pieces fit together.
- For training your own checkpoints, jump to Training.
- For baseline comparisons and sweeps, read Evaluation.