Building Your Own Models
From Following Recipes to Creating Your Own
You’ve learned to use GenJAX through examples. Now it’s time to build your own probabilistic models!
This chapter shows you how to think about building generative models — turning real-world problems into code.
The Model-Building Process
Step 1: Understand the Problem
Before writing any code, answer:
- What am I trying to predict or understand? (The question)
- What do I observe? (The data/evidence)
- What’s hidden? (The unknown variables)
- How are they related? (The causal structure)
Example: Spam detection
- Question: Is this email spam?
- Observations: Email content, sender, time
- Hidden: True spam status
- Relationship: Spam emails have certain word patterns
Step 2: Sketch the Generative Story
Write out the process that generates the data:
“First, nature chooses…, then based on that, it generates…, which produces…”
Example: Coin flips
- First, the coin has a (hidden) bias parameter
- Based on that bias, each flip is heads or tails
- We observe a sequence of flips
This narrative becomes your code!
Step 3: Choose Distributions
For each random choice, pick a distribution:
| Type of Variable | Common Distributions |
|---|---|
| Binary (yes/no) | flip(p) |
| Categorical (A/B/C) | categorical(probs) |
| Count (0, 1, 2, …) | poisson(rate) |
| Continuous | normal(mean, std), uniform(low, high) |
Start simple! Use flip for most binary choices.
Step 4: Write the Code
Pattern:
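A minimal sketch of this pattern (the model name, the probabilities, and the use of jnp.where to branch on a sampled value are illustrative choices):

```python
import jax.numpy as jnp
from genjax import gen, flip

@gen
def my_model():
    # Hidden variable: a named random choice
    cause = flip(0.3) @ "cause"
    # Observed variable: its probability depends on the hidden cause
    # (jnp.where is a JAX-friendly way to branch on a sampled value)
    effect = flip(jnp.where(cause, 0.9, 0.1)) @ "effect"
    # Return what you want to infer
    return cause
```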
Key points:
- Use the @gen decorator
- Name all random choices with @ "name"
- Return what you want to infer
- Use if statements to model dependencies
Step 5: Test and Validate
- Generate samples — does the output look reasonable?
- Check extreme cases — what if parameters are 0 or 1?
- Verify inference — do posterior results make intuitive sense?
📐→💻 Math-to-Code Translation
How model-building concepts translate to GenJAX:
| Math Concept | Mathematical Notation | GenJAX Pattern |
|---|---|---|
| Joint Distribution | $P(X, Y)$ | Multiple flip() calls in @gen function |
| Conditional Distribution | $P(Y \mid X)$ | if X: Y = flip(p1) |
| Independence | $P(X, Y) = P(X) \cdot P(Y)$ | Separate random choices (no if statements) |
| Dependence | $P(Y \mid X) \neq P(Y)$ | Y’s distribution uses X in if statement |
| Hierarchical Model | $\theta \sim \text{Prior},\ X \sim P(X \mid \theta)$ | Parameter as random variable: theta = uniform() @ "theta" |
| Mixture Model | $\sum_k P(Z=k) P(X \mid Z=k)$ | if category == k: X = distribution_k() |
| Sequence Model | $P(X_t \mid X_{t-1})$ | Loop with prev_state dependency |
Common modeling patterns:
| Pattern | Probability Structure | Code Structure |
|---|---|---|
| Independent observations | $P(X_1, \ldots, X_n) = \prod P(X_i)$ | for i: X_i = flip() |
| Hierarchical | $P(\theta) P(X \mid \theta)$ | theta = uniform(); X = flip(theta) |
| Conditional | $P(Y \mid X)$ depends on X | if X: Y = flip(p1) else: Y = flip(p2) |
| Time series | $P(X_t \mid X_{t-1})$ | for t: X[t] = flip(f(X[t-1])) |
| Mixture | $\sum_k \pi_k P(X \mid k)$ | k = categorical(pi); if k==0: ... else: ... |
Key insights:
- @gen function = Joint distribution — Defines P(all variables)
- if statements = Conditional dependence — Y depends on X
- for loops = Repeated structure — Multiple observations or time steps
- Parameters as random variables = Hierarchical — Uncertainty at multiple levels
- Your generative story = The math — If you can describe how data is generated, you can code it
Example: Medical diagnosis
Math: P(Disease, Fever, Cough) = P(Disease) × P(Fever|Disease) × P(Cough|Disease)
Code:

```python
import jax.numpy as jnp
from genjax import gen, flip

@gen
def diagnosis():
    has_disease = flip(0.01) @ "disease"
    fever_prob = jnp.where(has_disease, 0.9, 0.1)
    cough_prob = jnp.where(has_disease, 0.8, 0.2)
    fever = flip(fever_prob) @ "fever"
    cough = flip(cough_prob) @ "cough"
    return has_disease
```

Common Patterns
Pattern 1: Independent Observations
Scenario: Multiple independent measurements
Example: Coin flips
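One way to sketch this (the bias of 0.7 and the ten flips are illustrative choices):

```python
import jax.numpy as jnp
from genjax import gen, flip

@gen
def coin_flips():
    # Ten independent flips of a coin with a fixed bias of 0.7
    flips = []
    for i in range(10):
        flips.append(flip(0.7) @ f"flip_{i}")
    return jnp.array(flips).astype(jnp.int32)
```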
Usage:
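A sketch of drawing one sample, assuming the simulate interface used in earlier chapters:

```python
import jax

key = jax.random.PRNGKey(0)
trace = coin_flips.simulate(key, ())
print("Flips:", trace.get_retval())
```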
Output (example):
Flips: [1 0 1 1 1 0 1 1 1 0]

Pattern 2: Hierarchical Structure
Scenario: Parameters have their own distributions
Example: Learning a coin’s bias from flips
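A sketch with a uniform prior on the bias (the uniform(low, high) signature follows the distribution table above; the ten flips are illustrative):

```python
from genjax import gen, flip, uniform

@gen
def coin_with_unknown_bias():
    # The bias itself is a random variable (hierarchical structure)
    bias = uniform(0.0, 1.0) @ "bias"
    # Each flip depends on the sampled bias
    for i in range(10):
        flip(bias) @ f"flip_{i}"
    return bias
```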
Inference:
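A rejection-sampling sketch: simulate many traces, keep those whose flips match an (illustrative) observed sequence, and average the accepted biases. It assumes simulate and get_choices can be vmapped as in earlier chapters:

```python
import jax
import jax.numpy as jnp

observed = jnp.array([1, 1, 0, 1, 1, 1, 0, 1, 1, 1])  # 7 heads out of 10

# Simulate many traces and keep only those matching the observations
keys = jax.random.split(jax.random.PRNGKey(0), 200_000)
traces = jax.vmap(lambda k: coin_with_unknown_bias.simulate(k, ()))(keys)
choices = traces.get_choices()
flips = jnp.stack(
    [choices[f"flip_{i}"].astype(jnp.int32) for i in range(10)], axis=-1
)
accept = jnp.all(flips == observed, axis=-1)
biases = traces.get_retval()
print("Estimated bias:", float(jnp.mean(biases[accept])))
```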
Output (example):
Estimated bias: 0.69

Pattern 3: Conditional Dependencies
Scenario: Observations depend on hidden state
Example: Weather affects mood
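A sketch of the model; the probabilities (70% sunny, 90% happy if sunny, 30% happy otherwise) are illustrative values chosen to match the example output below:

```python
import jax.numpy as jnp
from genjax import gen, flip

@gen
def weather_mood():
    # Hidden state: is it sunny?
    sunny = flip(0.7) @ "sunny"
    # Observation: Chibany's mood depends on the weather
    happy = flip(jnp.where(sunny, 0.9, 0.3)) @ "happy"
    return sunny
```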
Question: “Chibany is happy. What’s the probability it’s sunny?”
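A rejection-style estimate: among many simulated days where Chibany is happy, count how often it was sunny (same simulate/get_choices assumptions as above):

```python
import jax
import jax.numpy as jnp

keys = jax.random.split(jax.random.PRNGKey(0), 100_000)
traces = jax.vmap(lambda k: weather_mood.simulate(k, ()))(keys)
choices = traces.get_choices()
sunny = choices["sunny"].astype(bool)
happy = choices["happy"].astype(bool)
posterior = jnp.sum(sunny & happy) / jnp.sum(happy)
print("P(Sunny | Happy) ≈", float(posterior))
```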
Output (example):
P(Sunny | Happy) ≈ 0.875

Pattern 4: Sequences and Time Series
Scenario: Events unfold over time
Example: Chibany’s weekly meals
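A sketch where each day's meal depends on the previous day's (the one-week length and the repeat-avoidance probabilities are illustrative):

```python
import jax.numpy as jnp
from genjax import gen, flip

@gen
def weekly_meals():
    # Each day Chibany either has tonkatsu (1) or something else (0)
    meals = []
    prev = flip(0.5) @ "day_0"
    meals.append(prev)
    for t in range(1, 7):
        # Today's probability depends on yesterday's meal:
        # less likely to repeat tonkatsu two days in a row
        p_tonkatsu = jnp.where(prev, 0.2, 0.6)
        prev = flip(p_tonkatsu) @ f"day_{t}"
        meals.append(prev)
    return jnp.array(meals).astype(jnp.int32)
```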
This models dependence through time!
Pattern 5: Mixture Models
Scenario: Data comes from multiple sources, but which source generated each observation is not observed
Example: Two types of days (weekday vs weekend). Chibany doesn’t know what day it is. Bentos on the weekend are much more likely to have tonkatsu.
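A sketch of the mixture (the 80% and 10% tonkatsu probabilities are illustrative):

```python
import jax.numpy as jnp
from genjax import gen, flip

@gen
def mystery_day():
    # Hidden mixture component: weekend (2 days out of 7)
    weekend = flip(2.0 / 7.0) @ "weekend"
    # The observation's distribution depends on the hidden component
    tonkatsu = flip(jnp.where(weekend, 0.8, 0.1)) @ "tonkatsu"
    return weekend
```

The rejection-sampling recipe from Pattern 3 answers the question below: keep only samples where tonkatsu is true and look at the fraction that are weekends.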
Infer: “Given Chibany had tonkatsu, is it a weekend?”
Building a Complete Model: Medical Diagnosis
Let’s build a realistic example from scratch.
Scenario: Diagnosing a disease based on symptoms
Setup:
- Disease prevalence: 1% (rare)
- Symptom 1 (fever): 90% if diseased, 10% if healthy
- Symptom 2 (cough): 80% if diseased, 20% if healthy
Question: Patient has fever and cough. Probability of disease?
Step 1: Understand the Problem
- Question: Does patient have disease?
- Observations: Fever and cough
- Hidden: True disease status
- Relationships: Symptoms more likely if diseased
Step 2: Generative Story
- First, patient either has disease (1%) or not (99%)
- If diseased, fever is very likely (90%)
- If diseased, cough is very likely (80%)
- If healthy, both symptoms are rare (10%, 20%)
Step 3: Write the Model
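A sketch of the model, written in the same style as the translation-table example earlier in this chapter:

```python
import jax.numpy as jnp
from genjax import gen, flip

@gen
def disease_model():
    """P(disease, fever, cough): 1% prevalence; symptoms depend on disease."""
    has_disease = flip(0.01) @ "disease"
    fever = flip(jnp.where(has_disease, 0.9, 0.1)) @ "fever"
    cough = flip(jnp.where(has_disease, 0.8, 0.2)) @ "cough"
    return has_disease
```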
Step 4: Run Inference
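A rejection-sampling sketch (same simulate/get_choices assumptions as the patterns above): keep samples showing both symptoms and ask how often the disease is present.

```python
import jax
import jax.numpy as jnp

keys = jax.random.split(jax.random.PRNGKey(0), 500_000)
traces = jax.vmap(lambda k: disease_model.simulate(k, ()))(keys)
ch = traces.get_choices()
both = ch["fever"].astype(bool) & ch["cough"].astype(bool)
disease = ch["disease"].astype(bool)
posterior = jnp.sum(disease & both) / jnp.sum(both)
print("=== MEDICAL DIAGNOSIS ===")
print("Prevalence: 1%")
print("Symptoms: Fever + Cough")
print("P(Disease | Symptoms) ≈", float(posterior))
```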
Output (example):
=== MEDICAL DIAGNOSIS ===
Prevalence: 1%
Symptoms: Fever + Cough
P(Disease | Symptoms) ≈ 0.266

Expected: ≈ 0.267 (26.7%)
Interpretation: Even with both symptoms, there is only about a 27% chance of disease, because the disease is so rare!
Base Rate Neglect in Medicine!
This is why false positives are a problem in medical testing.
Even accurate tests produce many false positives for rare diseases because:
- True positives: $0.01 \times 0.9 \times 0.8 = 0.0072$ (0.72%)
- False positives: $0.99 \times 0.1 \times 0.2 = 0.0198$ (1.98%)
More false positives than true positives!
This is why doctors don't diagnose based on symptoms alone — they order confirmatory tests or factor in patient history (updating the prior).
Best Practices
✅ DO
1. Name everything clearly
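A small illustrative sketch (inside a @gen function body):

```python
# Good: the address says what the choice means
is_spam = flip(0.3) @ "is_spam"

# Bad: cryptic addresses make traces hard to read and condition on
x = flip(0.3) @ "x"
```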
2. Use meaningful parameters
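A sketch where the probabilities are passed in as named arguments instead of buried as magic numbers (the argument names are illustrative):

```python
import jax.numpy as jnp
from genjax import gen, flip

@gen
def spam_model(spam_rate, free_in_spam, free_in_ham):
    # Named parameters document what each number means
    is_spam = flip(spam_rate) @ "is_spam"
    contains_free = flip(jnp.where(is_spam, free_in_spam, free_in_ham)) @ "free"
    return is_spam
```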
3. Document your model
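A sketch using a docstring to record the generative story:

```python
import jax.numpy as jnp
from genjax import gen, flip

@gen
def weather_model():
    """Chibany's mood given the weather.

    Generative story:
      1. It is sunny with probability 0.7.
      2. Chibany is happy with probability 0.9 if sunny, 0.3 otherwise.
    """
    sunny = flip(0.7) @ "sunny"
    happy = flip(jnp.where(sunny, 0.9, 0.3)) @ "happy"
    return sunny
```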
4. Start simple, add complexity
- Build the simplest model first
- Verify it works
- Add features incrementally
5. Test edge cases
- What if parameters are 0? 1?
- What if all observations are the same?
- Does the posterior make intuitive sense?
❌ DON’T
1. Don’t forget to name random choices
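An illustrative sketch of the mistake (the exact behavior of an unaddressed call may vary by GenJAX version, but the choice will not be usable from the trace):

```python
# Bad: no @ "name", so the choice is not addressed in the trace
# and cannot be inspected or conditioned on later
fever = flip(0.9)

# Good: named choice, recorded in the trace
fever = flip(0.9) @ "fever"
```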
2. Don’t use the same name twice
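An illustrative sketch:

```python
# Bad: the same address is reused on every loop iteration
for i in range(3):
    flip(0.5) @ "flip"

# Good: give each choice a unique address
for i in range(3):
    flip(0.5) @ f"flip_{i}"
```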
3. Don’t overthink distributions
- flip covers most binary cases
- normal for continuous values
- categorical for multiple choices
- You don't need exotic distributions to start!
4. Don’t skip validation
- Always generate samples first
- Check if outputs look reasonable
- Verify extreme parameter values
Exercises
Exercise 1: Email Spam Filter
Build a simple spam filter model.
Scenario:
- 30% of emails are spam
- Spam emails contain “FREE” 80% of the time
- Legitimate emails contain “FREE” 10% of the time
Task: Calculate $P(\text{Spam} \mid \text{contains “FREE”})$
Exercise 2: Learning from Multiple Observations
Extend the coin flip model to infer bias from multiple observations.
Task: Given a sequence of 20 flips (e.g., [1,1,0,1,1,1,0,1,1,1,1,0,1,1,0,1,1,1,1,1]), infer the coin’s bias.
Exercise 3: Multi-Symptom Diagnosis
Extend the disease model to include 3 symptoms: fever, cough, fatigue.
Parameters:
- Disease: 2% prevalence
- If diseased: fever 90%, cough 80%, fatigue 95%
- If healthy: fever 10%, cough 20%, fatigue 30%
Task: Calculate posterior for:
- Fever only
- Fever + cough
- All three symptoms
What You’ve Learned
In this chapter, you learned:
✅ The model-building process — from problem to code
✅ Common patterns — independent, hierarchical, conditional, sequential, mixture
✅ Best practices — naming, documentation, testing
✅ Complete examples — medical diagnosis, spam filtering, coin flipping
✅ How to think generatively — "what generates the data?"
The key insight: Building models is about encoding your assumptions about how the world works, then letting GenJAX do the inference!
Next Steps
You’re Ready to Build!
You now have all the tools to:
- Build generative models for your problems
- Perform Bayesian inference automatically
- Understand uncertainty in your predictions
Where to go from here:
1. Explore More Distributions
GenJAX supports many distributions beyond flip:
- normal(mean, std) — Continuous values (heights, weights, temperatures)
- categorical(probs) — Multiple discrete choices (A, B, C, D)
- poisson(rate) — Count data (number of events)
- gamma, beta, exponential — Specialized continuous distributions
See the GenJAX documentation for complete reference.
2. Learn Advanced Inference
This tutorial covered:
- Filtering/rejection sampling
- Conditioning with generate()
Next level:
- Importance sampling (more efficient for rare events)
- Markov Chain Monte Carlo (MCMC) for complex models
- Variational inference (approximate but fast)
Check out: GenJAX advanced tutorials
3. Real-World Applications
Apply what you learned to:
- Science: Modeling experiments, analyzing data
- Medicine: Diagnosis, treatment optimization
- Engineering: Fault detection, quality control
- Social science: Understanding human behavior
- AI/ML: Building better models with uncertainty
The Journey
You started with: Sets, counting, basic probability
Now you can: Build probabilistic programs, perform Bayesian inference, reason under uncertainty
That’s a huge accomplishment!
Final Thoughts
Probabilistic programming is a superpower:
- Express uncertainty — the world is uncertain, our models should reflect that
- Automate inference — computers do the hard math
- Combine knowledge and data — use both domain expertise (priors) and observations (data)
- Make better decisions — understand risks and probabilities
Keep building, keep learning, keep questioning!
Chapter Complete!
You’ve learned how to build your own probabilistic models from scratch. This is the final chapter of the GenJAX programming tutorial.
What you accomplished in this tutorial:
- Set up your GenJAX environment
- Learned essential Python for probabilistic programming
- Built generative models with the @gen decorator
- Conditioned models on observations
- Performed inference to answer probabilistic questions
- Created complete models for real-world problems
You’re ready for the next step!
What’s Next: Continuous Probability & Bayesian Learning
So far, you’ve worked with discrete random variables (coin flips, categories, yes/no outcomes). But many real-world quantities are continuous — heights, temperatures, waiting times.
In Tutorial 3: Continuous Probability & Bayesian Learning, you’ll:
- Work with continuous distributions (normal, exponential, etc.)
- Learn about Bayesian updating with continuous parameters
- Build mixture models for clustering
- Explore the Dirichlet Process for infinite mixtures
The probabilistic programming skills you’ve learned here will transfer directly!
Continue to Tutorial 3: Continuous Probability →
| ← Previous: Inference in Action | Tutorial 3: Continuous Probability → |
|---|---|
