aither-computer/aither-sdk

aither

Python SDK for the aither platform - contextual intelligence and model observability.

Features

  • Zero-code logging: Wrap any model and predictions are logged automatically
  • Framework support: sklearn, pytorch, tensorflow, tinygrad, transformers
  • Smart sampling: Logs a small sample of the input features plus metadata (type, shape, columns) instead of the full input, and never blocks inference
  • Label correlation: Attach ground truth to a prediction later using its trace ID
  • Non-blocking: Background worker handles all API communication

Installation

pip install aither

Quick Start

import aither
from sklearn.ensemble import RandomForestClassifier

# Initialize
aither.init()  # Uses AITHER_API_KEY env var

# Train your model
model = RandomForestClassifier().fit(X_train, y_train)

# Wrap it - predictions are now logged automatically
tracked = aither.wrap(model, name="fraud_detector")

# Use normally
predictions = tracked.predict(X_test)

# Get trace_id for label correlation
trace_id = tracked.last_trace_id

Supported Frameworks

Framework     Example
------------  ------------------------------------------------------------------
sklearn       aither.wrap(RandomForestClassifier(), name="clf")
pytorch       aither.wrap(MyNet(), name="net")
tensorflow    aither.wrap(keras_model, name="tf_model")
tinygrad      aither.wrap(tinygrad_model, name="tiny")
transformers  aither.wrap(pipeline("sentiment-analysis"), name="sentiment")

Also works with sklearn-compatible libraries: xgboost, lightgbm, catboost.

Configuration

Environment Variables

export AITHER_API_KEY="aith_your_api_key"
export AITHER_BASE_URL="https://aither.computer"  # optional

Explicit Initialization

aither.init(
    api_key="aith_your_api_key",
    base_url="https://aither.computer",
    flush_interval=1.0,
    batch_size=100,
)
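
The flush_interval and batch_size options suggest a worker that sends a batch when it is full or when the interval elapses, whichever comes first. The following is a minimal sketch of that idea using only the standard library. BatchingWorker, submit, and the send callback are hypothetical names for illustration, not the SDK's actual implementation.

```python
import queue
import threading
import time


class BatchingWorker:
    """Illustrative only: buffers records and hands them to `send` in batches."""

    def __init__(self, send, flush_interval=1.0, batch_size=100):
        self._send = send                      # callable taking a list of records
        self._flush_interval = flush_interval
        self._batch_size = batch_size
        self._queue = queue.Queue()
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def submit(self, record):
        # Called on the inference path: enqueue and return immediately.
        self._queue.put(record)

    def _run(self):
        batch = []
        deadline = time.monotonic() + self._flush_interval
        while not (self._stop.is_set() and self._queue.empty()):
            remaining = max(0.0, deadline - time.monotonic())
            try:
                batch.append(self._queue.get(timeout=min(remaining, 0.05)))
            except queue.Empty:
                pass
            # Flush when the batch is full or the interval has elapsed.
            if len(batch) >= self._batch_size or time.monotonic() >= deadline:
                if batch:
                    self._send(batch)
                    batch = []
                deadline = time.monotonic() + self._flush_interval
        if batch:
            self._send(batch)                  # final flush on shutdown

    def close(self):
        # Drain whatever is still queued, then stop the thread.
        self._stop.set()
        self._thread.join()


sent = []
worker = BatchingWorker(sent.append, flush_interval=0.1, batch_size=3)
for i in range(7):
    worker.submit({"i": i})
worker.close()
```

Under this design, a smaller flush_interval lowers the delay before records are sent, and a larger batch_size reduces the number of requests under load.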

API Reference

aither.wrap(model, name, **options) - Recommended

Wrap a model for automatic prediction logging.

tracked = aither.wrap(
    model,                          # Any ML model
    name="fraud_detector",          # Required: model identifier
    version="1.2.3",                # Optional: model version
    environment="production",       # Optional: deployment env
    sample_rows=5,                  # Max rows to sample (default: 5)
    sample_columns=10,              # Max columns to sample (default: 10)
    features_fn=custom_extractor,   # Custom feature extraction
)

# Use the wrapped model normally
result = tracked.predict(X)
result = tracked(X)                 # Also works
probs = tracked.predict_proba(X)    # Also tracked

# Access trace_id for label correlation
trace_id = tracked.last_trace_id
all_trace_ids = tracked.trace_ids   # All predictions

# Access underlying model
tracked.model                       # Original model
tracked.classes_                    # Passthrough to model attributes
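
To make the passthrough behavior concrete, here is a rough sketch of how a wrapper could expose predict, last_trace_id, trace_ids, and the original model's attributes. TrackedModel and ThresholdModel are hypothetical stand-ins written for this example, and recording one trace ID per call is an assumption; the SDK may do this differently.

```python
import uuid


class TrackedModel:
    """Illustrative stand-in for a wrapped model; not the SDK's implementation."""

    def __init__(self, model, name):
        self.model = model          # original model stays accessible
        self.name = name
        self.trace_ids = []         # one entry per call in this sketch

    @property
    def last_trace_id(self):
        return self.trace_ids[-1] if self.trace_ids else None

    def predict(self, X):
        result = self.model.predict(X)
        # A real wrapper would also enqueue features and result for logging.
        self.trace_ids.append(uuid.uuid4().hex)
        return result

    def __getattr__(self, attr):
        # Invoked only when normal lookup fails, so wrapper attributes win
        # and everything else (e.g. classes_) is read from the model.
        model = self.__dict__.get("model")
        if model is None:
            raise AttributeError(attr)
        return getattr(model, attr)


class ThresholdModel:
    """Tiny fake model with an sklearn-like surface."""
    classes_ = [0, 1]

    def predict(self, X):
        return [int(x > 0.5) for x in X]


tracked = TrackedModel(ThresholdModel(), name="demo")
preds = tracked.predict([0.2, 0.9])
```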

aither.log_prediction(...) - Manual Control

For custom pipelines or when you need full control:

trace_id = aither.log_prediction(
    model_name="my_pipeline",
    features={"amount": 150.0, "country": "US"},
    prediction=0.87,
    version="1.2.3",
    environment="production",
)

aither.log_label(trace_id, label)

Log ground truth for a prediction:

aither.log_label(trace_id=trace_id, label=1)

@aither.track(name) - Decorator

For functions instead of model objects:

@aither.track("my_function")
def predict(features):
    return model.predict(features)

result = predict(X)
trace_id = aither.last_trace_id()

Feature Extraction

Wrapped models automatically extract features with smart sampling:

# Input: pandas DataFrame with 50,000 rows, 100 columns
# What gets logged:
{
    "sample": [
        {"col1": 1.5, "col2": "A", ...},  # 5 rows
        ...
    ],
    "_meta": {
        "type": "pandas.DataFrame",
        "shape": [50000, 100],
        "columns": ["col1", "col2", ...],
        "dtypes": {"col1": "float64", ...},
        "truncated": True
    }
}
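
As a rough illustration of the sampling idea, the sketch below bounds a list of dict rows to a few rows and columns and records the original shape. The function name sample_rows and the exact contents of "_meta" are assumptions that mirror the sample_rows and sample_columns options; the SDK's real extractor may differ.

```python
def sample_rows(rows, max_rows=5, max_columns=10):
    """Return a bounded sample plus metadata for a list of dict rows."""
    columns = list(rows[0].keys()) if rows else []
    kept = columns[:max_columns]
    # Copy only the first max_rows rows and the first max_columns columns.
    sample = [{c: row.get(c) for c in kept} for row in rows[:max_rows]]
    return {
        "sample": sample,
        "_meta": {
            "type": "list[dict]",
            "shape": [len(rows), len(columns)],
            "columns": kept,
            "truncated": len(rows) > max_rows or len(columns) > max_columns,
        },
    }


# 1,000 rows by 20 columns of synthetic data
rows = [{f"col{j}": i * j for j in range(20)} for i in range(1000)]
logged = sample_rows(rows)
```

The work done per call depends on max_rows and max_columns rather than on the input size, which is what keeps the logging cost small for large inputs.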

Supported types:

  • numpy arrays
  • pandas DataFrames/Series
  • polars DataFrames/Series
  • torch Tensors
  • tensorflow Tensors
  • dicts (common in transformers)
  • lists

Custom extraction:

def my_extractor(X):
    # Return plain Python types so the result can be JSON-encoded
    return {"shape": list(X.shape), "mean": float(X.mean())}

tracked = aither.wrap(model, name="m", features_fn=my_extractor)

Usage Patterns

Basic Model Wrapping

import aither
from sklearn.ensemble import RandomForestClassifier

aither.init()

model = RandomForestClassifier().fit(X_train, y_train)
tracked = aither.wrap(model, name="churn_predictor", version="1.0")

predictions = tracked.predict(X_test)

PyTorch Model

import aither
import torch

aither.init()

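class MyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Example architecture; replace with your own layers
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(10, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1),
        )

    def forward(self, x):
        return self.layers(x)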

model = MyNet()
tracked = aither.wrap(model, name="my_net")

# Both work:
output = tracked(input_tensor)
output = tracked.predict(input_tensor)

Ground Truth Correlation

# At prediction time
predictions = tracked.predict(X)
trace_ids = tracked.trace_ids  # List of all trace IDs

# Store trace_ids with your predictions
# (save_to_db, load_from_db, and get_actual_outcome stand in for your own storage code)
save_to_db(prediction_ids, trace_ids)

# Later, when ground truth is known
for pred_id, trace_id in load_from_db():
    actual = get_actual_outcome(pred_id)
    aither.log_label(trace_id, actual)
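
The storage helpers in the pattern above are left to the application. One possible way to fill them in, sketched with an in-memory SQLite table, is shown below. The schema, the outcomes dict, and the local log_label stand-in (used here in place of aither.log_label so the example runs without the SDK) are all assumptions made for illustration.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE predictions ("
    "pred_id INTEGER PRIMARY KEY, trace_id TEXT NOT NULL, actual INTEGER)"
)


def save_to_db(prediction_ids, trace_ids):
    # Persist the prediction-to-trace mapping at prediction time.
    conn.executemany(
        "INSERT INTO predictions (pred_id, trace_id) VALUES (?, ?)",
        zip(prediction_ids, trace_ids),
    )
    conn.commit()


def load_from_db():
    # Only rows that have not been labeled yet.
    return conn.execute(
        "SELECT pred_id, trace_id FROM predictions "
        "WHERE actual IS NULL ORDER BY pred_id"
    ).fetchall()


def mark_labeled(pred_id, actual):
    conn.execute(
        "UPDATE predictions SET actual = ? WHERE pred_id = ?", (actual, pred_id)
    )
    conn.commit()


logged_labels = []


def log_label(trace_id, label):
    # Stand-in for aither.log_label in this sketch.
    logged_labels.append((trace_id, label))


save_to_db([1, 2], ["t-1", "t-2"])
outcomes = {1: 1, 2: 0}          # hypothetical ground truth arriving later

for pred_id, trace_id in load_from_db():
    actual = outcomes[pred_id]
    log_label(trace_id, actual)
    mark_labeled(pred_id, actual)
```

Recording the label locally after sending it means a rerun of the loop skips predictions that were already labeled.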

FastAPI Integration

import aither
from fastapi import FastAPI

aither.init()
model = aither.wrap(load_model(), name="api_model")
app = FastAPI()

@app.post("/predict")
async def predict(data: dict):
    prediction = model.predict(data)
    return {
        "prediction": prediction,
        "trace_id": model.last_trace_id
    }

@app.post("/label")
async def label(trace_id: str, actual: int):
    aither.log_label(trace_id, actual)
    return {"status": "ok"}

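# Note: on_event is deprecated in recent FastAPI releases; a lifespan handler is the newer equivalent
@app.on_event("shutdown")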
async def shutdown():
    aither.close()

Management API

The SDK provides namespaces for managing your organization, API keys, and user account.

import aither

aither.init()

# Organization info
org = aither.org.get()
print(org.name, org.plan)

# Usage stats for current billing period
usage = aither.org.usage()
print(f"API calls: {usage.api_calls}")

# Current user
me = aither.user.me()
print(me.email)

# API key management (requires admin scope)
keys = aither.api_keys.list()
new_key = aither.api_keys.create(name="Production", scopes=["read", "write"])
aither.api_keys.revoke(key_id="...")

Why .get() instead of direct attributes?

You might wonder why the SDK uses aither.org.get().name instead of just aither.org.name. This is intentional:

  1. Network calls are explicit - .get() makes it clear you're making an HTTP request. Hidden network calls on attribute access would be surprising and expensive.

  2. Caching semantics are clear - The returned Organization is a point-in-time snapshot. You control when to refresh by calling .get() again.

  3. Error handling is predictable - Exceptions from HTTP failures occur at the .get() call site, not on attribute access.

# Recommended: fetch once, use the snapshot
org = aither.org.get()
print(f"{org.name} on {org.plan} plan")

# Compare states over time
org_before = aither.org.get()
# ... make changes ...
org_after = aither.org.get()
if org_before.plan != org_after.plan:
    print("Plan changed!")
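
The three points above can be sketched with a frozen dataclass and a namespace object whose only network access is in get(). OrgNamespace, the fetch callback, and the fake backend are hypothetical names used for illustration; only the name and plan fields come from the examples in this README.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Organization:
    """Immutable point-in-time snapshot."""
    name: str
    plan: str


class OrgNamespace:
    def __init__(self, fetch):
        self._fetch = fetch      # callable that performs the HTTP request

    def get(self):
        # The only place a request happens, so failures surface here.
        return Organization(**self._fetch())


# Fake backend standing in for the HTTP API
backend = {"name": "Acme", "plan": "free"}
calls = []


def fake_fetch():
    calls.append(1)
    return dict(backend)


org = OrgNamespace(fake_fetch)
org_before = org.get()
backend["plan"] = "pro"          # state changes on the server side
org_after = org.get()
```

Here org_before keeps reporting the old plan after the backend changes, and each request corresponds to exactly one visible get() call.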

Data Format

The SDK uses OTLP (OpenTelemetry Protocol) to send predictions as spans with ml.* attributes:

Attribute         Description
----------------  --------------------------------------
ml.model.name     Model identifier
ml.model.version  Model version
ml.features       JSON-encoded feature sample + metadata
ml.prediction     JSON-encoded prediction
ml.label          Ground truth value
ml.environment    Deployment environment
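
To show what such a span's attributes might look like, the sketch below builds a flat attribute dict from the arguments used in the log_prediction example. build_span_attributes is a hypothetical helper, not an SDK function, and it leaves out ml.label because this README does not say how labels are attached to spans.

```python
import json


def build_span_attributes(model_name, features, prediction,
                          version=None, environment=None):
    """Map a prediction record onto ml.* attribute keys (illustrative only)."""
    attrs = {
        "ml.model.name": model_name,
        # Structured values are JSON-encoded into string attributes.
        "ml.features": json.dumps(features),
        "ml.prediction": json.dumps(prediction),
    }
    optional = {"ml.model.version": version, "ml.environment": environment}
    # Unset optional fields are omitted instead of being sent as null.
    attrs.update({k: v for k, v in optional.items() if v is not None})
    return attrs


attrs = build_span_attributes(
    model_name="my_pipeline",
    features={"amount": 150.0, "country": "US"},
    prediction=0.87,
    version="1.2.3",
    environment="production",
)
```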

License

MIT
