A Theory of Mind-Based Cognitive Framework with Adaptive Neuro-Fuzzy Inference for Explainable Pedestrian Crossing Intention Prediction
ToM-ANFIS/
├── app/
│ ├── artifacts/ # Saved model weights and scaler
│ │ ├── anfis_model_optimized.pth
│ │ └── scaler_optimized.save
│ ├── core/ # ANFIS model class and ToM feature extractor class
│ │ ├── anfis.py
│ │ └── extractor.py
│ ├── routes/ # FastAPI route handlers
│ │ └── predict.py
│ ├── schemas/ # Pydantic input/output schemas
│ │ ├── RawAnnotationInput.py
│ │ ├── ToMFeaturesInput.py
│ │ └── PredictionResponse.py
│ ├── services/ # Business logic layer
│ │ ├── prediction_service.py
│ │ └── feature_extraction_service.py
│ ├── config.py # App settings (paths, thresholds)
│ └── main.py # FastAPI app entry point
├── notebooks/ # Training notebooks
├── pyproject.toml
├── uv.lock
├── .gitignore
└── README.md
Install uv:
On Windows:
powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
On macOS/Linux:
curl -LsSf https://astral.sh/uv/install.sh | sh
Clone the repository and install dependencies:
git clone https://github.com/roshana1s/ToM-ANFIS.git
cd ToM-ANFIS
uv sync
Copy the saved model and scaler files into app/artifacts/:
app/artifacts/anfis_model_optimized.pth
app/artifacts/scaler_optimized.save
Start the API:
uv run python -m app.main
The API will be available at http://localhost:8000, with interactive docs at http://localhost:8000/docs.
To add a dependency:
uv add <package-name>
This updates pyproject.toml and uv.lock automatically.
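Because the API loads both artifacts from disk at startup, a quick pre-flight check can catch a missing file before launch. The sketch below uses only `pathlib`; `missing_artifacts` and `REQUIRED_ARTIFACTS` are hypothetical helpers, not part of the repository, and the paths are taken from the layout above.

```python
from pathlib import Path

# Artifact locations from the repository layout, relative to the project root.
REQUIRED_ARTIFACTS = (
    Path("app/artifacts/anfis_model_optimized.pth"),
    Path("app/artifacts/scaler_optimized.save"),
)


def missing_artifacts(project_root: Path) -> list[Path]:
    """Return the required artifact paths that are not files under project_root."""
    return [p for p in REQUIRED_ARTIFACTS if not (project_root / p).is_file()]
```

Run from the project root, `missing_artifacts(Path.cwd())` should return an empty list before you start the API.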
Observation Window (default: 15 frames)
- Number of consecutive frames aggregated to compute the temporal ToM features (B3, B4, G2, A1, A2, C1, C2)
- Each feature value is the mean or ratio over these 15 frames
- The training notebook and the current model use OBSERVATION_WINDOW=15
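The two aggregation patterns above (a mean of a continuous per-frame value, and a ratio of frames where a binary flag is set) can be sketched as follows. This is an illustration under assumptions: which per-frame signal feeds which ToM feature is defined in app/core/extractor.py and is not reproduced here, and the helper names are hypothetical. The window is taken as the 15 frames ending at the prediction frame, matching the frame timeline described below.

```python
from statistics import fmean
from typing import Sequence

OBSERVATION_WINDOW = 15


def observation_slice(values: Sequence[float], prediction_index: int,
                      window: int = OBSERVATION_WINDOW) -> list[float]:
    """Per-frame values for the `window` frames ending at prediction_index (inclusive)."""
    start = prediction_index - window + 1
    if start < 0 or prediction_index >= len(values):
        raise ValueError(f"cannot take {window} frames ending at index {prediction_index}")
    return list(values[start:prediction_index + 1])


def window_mean(values: Sequence[float], prediction_index: int) -> float:
    """Mean of a continuous per-frame quantity over the observation window."""
    return fmean(observation_slice(values, prediction_index))


def window_ratio(flags: Sequence[int], prediction_index: int) -> float:
    """Fraction of window frames in which a binary per-frame flag is non-zero."""
    frames = observation_slice(flags, prediction_index)
    return sum(1 for f in frames if f) / len(frames)
```

For a `look` list of 15 zeros followed by 30 ones, `window_ratio(look, 14)` is 0.0 and `window_ratio(look, 29)` is 1.0.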
Prediction Horizon (default: 30 frames)
- Number of frames into the future for which the model predicts the crossing state
- prediction_horizon=30 means: given 15 frames of observation, predict the crossing state 30 frames ahead
- Total frames needed: 15 + 30 = 45 frames minimum
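Frame Timeline Visualization
[Observation: 15 frames] [Future Horizon: 30 frames] [Crossing Event]
Frames 0-14 Frames 15-44 Frame 44 (last frame of the horizon)
The model observes frames 0-14 and predicts: "30 frames from now (at frame 44), will the pedestrian cross?" In a longer video, the window shifts so that it ends 30 frames before the event, as in the 60-frame example below.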
Real-world Example: 60-frame video
Video: frame_0 ... frame_29 ... frame_59 (crossing occurs)
With obs=15, horizon=30:
Prediction frame = 59 - 30 = 29
Observation window = frames [15..29] (the 15 frames ending at the prediction frame)
Model predicts: "At frame 29, will crossing occur 30 frames from now?"
Answer: Yes, at frame 59 (the observation window and horizon together span the 45 frames [15..59])
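The frame arithmetic in the 60-frame example above can be written as a small helper that works backwards from a crossing event to the observation window. It is a sketch of the relationship stated in this README (prediction frame = event frame - horizon; window = the 15 frames ending at the prediction frame); `Timeline` and `timeline_for_event` are hypothetical names, not the extractor's API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Timeline:
    prediction_frame: int     # last observed frame; the prediction is made here
    window: tuple[int, int]   # inclusive (first, last) observed frame
    target_frame: int         # frame whose crossing state is predicted


def timeline_for_event(event_frame: int, obs: int = 15, horizon: int = 30) -> Timeline:
    """Work backwards from a crossing event at event_frame to the observation window."""
    prediction_frame = event_frame - horizon
    first = prediction_frame - obs + 1
    if first < 0:
        raise ValueError(
            f"event at frame {event_frame} needs at least {obs + horizon} frames of history"
        )
    return Timeline(prediction_frame, (first, prediction_frame), event_frame)
```

`timeline_for_event(59)` gives prediction frame 29 and window (15, 29), matching the example above, while `timeline_for_event(44)` gives window (0, 14), the minimal 45-frame case (frames 0-44).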
Frame Requirement
- The API requires at least 45 frames (observation_window + prediction_horizon) to avoid truncation
- Requests with fewer frames fail validation with a clear error message
Endpoint 1: raw annotations
Purpose: Accept raw PIE-style annotations, extract ToM features, then run inference.
Required Fields:
- age (int, 0-3): pedestrian age category
- intersection (int, 0-4): intersection type
- frames (list of ints): frame IDs in order
- bbox (list of [x1, y1, x2, y2]): pedestrian bounding box per frame
- occlusion (list of ints): occlusion level per frame (0 = visible, 1 = partial, 2 = hidden)
- behavior (object): action, gesture, and look lists (one value per frame)
- traffic_objects (list): vehicles, crosswalks, and traffic lights, each with obj_class, frames, and optional bbox/state
- obd_speed (dict): mapping frame_id → speed (m/s)
Optional Fields:
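- observation_window (int): must equal the configured default (15)
- prediction_horizon (int): overrides the default (30) to predict a different number of frames ahead; the minimum frame count becomes observation_window + prediction_horizon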
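Typing 45-element lists by hand is error-prone, so the synthetic request in the example below can also be generated programmatically. This is a sketch for testing only: `build_minimal_payload` is a hypothetical helper, and the values (a box drifting 1 px right per frame, action and look switching to 1 at frame 15, speed dropping 0.1 per frame from 6.0) simply reproduce the example rather than real PIE data.

```python
def build_minimal_payload(n_frames: int = 45, switch_at: int = 15) -> dict:
    """Build a synthetic raw-annotation request shaped like the example payload."""
    frames = list(range(n_frames))
    return {
        "age": 2,
        "intersection": 0,
        "frames": frames,
        # Fixed 50x60 box that moves 1 px to the right each frame.
        "bbox": [[100 + i, 200, 150 + i, 260] for i in frames],
        "occlusion": [0] * n_frames,
        "behavior": {
            "action": [0 if i < switch_at else 1 for i in frames],
            "gesture": [0] * n_frames,
            "look": [0 if i < switch_at else 1 for i in frames],
        },
        "traffic_objects": [],
        # JSON object keys are strings; speed decreases by 0.1 per frame from 6.0.
        "obd_speed": {str(i): round(6.0 - 0.1 * i, 1) for i in frames},
        "observation_window": 15,
        "prediction_horizon": 30,
    }
```

Serialize it with `json.dumps(build_minimal_payload())` and send it as the request body to the raw-annotation endpoint listed in the interactive docs.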
Example Payload (minimum valid request with 45 frames):
{
"age": 2,
"intersection": 0,
"frames": [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44],
"bbox": [[100,200,150,260], [101,200,151,260], [102,200,152,260], [103,200,153,260], [104,200,154,260], [105,200,155,260], [106,200,156,260], [107,200,157,260], [108,200,158,260], [109,200,159,260], [110,200,160,260], [111,200,161,260], [112,200,162,260], [113,200,163,260], [114,200,164,260], [115,200,165,260], [116,200,166,260], [117,200,167,260], [118,200,168,260], [119,200,169,260], [120,200,170,260], [121,200,171,260], [122,200,172,260], [123,200,173,260], [124,200,174,260], [125,200,175,260], [126,200,176,260], [127,200,177,260], [128,200,178,260], [129,200,179,260], [130,200,180,260], [131,200,181,260], [132,200,182,260], [133,200,183,260], [134,200,184,260], [135,200,185,260], [136,200,186,260], [137,200,187,260], [138,200,188,260], [139,200,189,260], [140,200,190,260], [141,200,191,260], [142,200,192,260], [143,200,193,260], [144,200,194,260]],
"occlusion": [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
"behavior": {
"action": [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],
"gesture": [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
"look": [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]
},
"traffic_objects": [],
"obd_speed": {"0": 6.0, "1": 5.9, "2": 5.8, "3": 5.7, "4": 5.6, "5": 5.5, "6": 5.4, "7": 5.3, "8": 5.2, "9": 5.1, "10": 5.0, "11": 4.9, "12": 4.8, "13": 4.7, "14": 4.6, "15": 4.5, "16": 4.4, "17": 4.3, "18": 4.2, "19": 4.1, "20": 4.0, "21": 3.9, "22": 3.8, "23": 3.7, "24": 3.6, "25": 3.5, "26": 3.4, "27": 3.3, "28": 3.2, "29": 3.1, "30": 3.0, "31": 2.9, "32": 2.8, "33": 2.7, "34": 2.6, "35": 2.5, "36": 2.4, "37": 2.3, "38": 2.2, "39": 2.1, "40": 2.0, "41": 1.9, "42": 1.8, "43": 1.7, "44": 1.6},
"observation_window": 15,
"prediction_horizon": 30
}
Response:
{
"predicted_crossing_probability": 0.87,
"intention_level": "HIGH",
"top_fired_rules": [
{
"rule_id": 1234,
"firing_strength": 0.456,
"rule": "IF A1 is High AND A2 is High THEN crossing intention is HIGH",
"consequent_linguistic": "HIGH"
}
]
}
Endpoint 2: pre-computed ToM features
Purpose: Accept a pre-computed 9-dimensional ToM feature vector and run ANFIS inference directly. Use this endpoint when the features have already been extracted or computed externally.
Required Fields:
- B1_age_vulnerability_awareness (float, [0, 1]): age-group vulnerability factor
- B2_intersection_complexity (float, [0, 1]): intersection layout complexity
- B3_pedestrian_position_proximity (float, [0, 1]): proximity to vehicle/curb
- B4_vehicle_density (float, [0, 1]): density of surrounding vehicles
- G2_urgency_rushing (float, [0, 1]): behavioral rushing indicator
- A1_vehicle_awareness (float, [0, 1]): attention to vehicles
- A2_traffic_signal_awareness (float, [-1, 1]): attention to traffic signals (-1 = looking away, +1 = looking towards)
- C1_relative_vehicle_speed (float, [0, 1]): normalized vehicle speed
- C2_occlusion_factor (float, [0, 1]): visibility/occlusion level
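When these features come from an external pipeline, it helps to turn the named fields into a fixed-order vector and fail loudly on missing or misspelled names. The sketch below assumes the order in which the fields are listed above; the order the trained model and scaler actually expect is fixed in the training notebook and app/core/anfis.py, so confirm it there. `FEATURE_ORDER` and `to_vector` are hypothetical names.

```python
# Assumed order: as listed in this README (verify against the training notebook).
FEATURE_ORDER = (
    "B1_age_vulnerability_awareness",
    "B2_intersection_complexity",
    "B3_pedestrian_position_proximity",
    "B4_vehicle_density",
    "G2_urgency_rushing",
    "A1_vehicle_awareness",
    "A2_traffic_signal_awareness",
    "C1_relative_vehicle_speed",
    "C2_occlusion_factor",
)


def to_vector(features: dict[str, float]) -> list[float]:
    """Order a feature dict into a 9-element vector, rejecting missing or unknown names."""
    missing = [name for name in FEATURE_ORDER if name not in features]
    unknown = sorted(set(features) - set(FEATURE_ORDER))
    if missing or unknown:
        raise ValueError(f"missing={missing} unknown={unknown}")
    return [float(features[name]) for name in FEATURE_ORDER]
```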
Example Payload:
{
"B1_age_vulnerability_awareness": 0.3,
"B2_intersection_complexity": 0.5,
"B3_pedestrian_position_proximity": 0.8,
"B4_vehicle_density": 0.1,
"G2_urgency_rushing": 0.4,
"A1_vehicle_awareness": 0.9,
"A2_traffic_signal_awareness": 1.0,
"C1_relative_vehicle_speed": 0.2,
"C2_occlusion_factor": 0.0
}
Response: Same as Endpoint 1 (crossing probability, intention level, top fired rules)
- Swagger UI: http://localhost:8000/docs
- ReDoc: http://localhost:8000/redoc
Swagger UI supports live request/response testing of both endpoints; ReDoc renders the same schema as read-only reference documentation.
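FastAPI serves the machine-readable schema behind these pages at /openapi.json by default, so the exact endpoint paths and methods can be listed without reading the route code. A minimal standard-library sketch; `fetch_openapi` and `list_operations` are hypothetical helpers, and the base URL assumes the default local setup above.

```python
import json
from urllib.request import urlopen

HTTP_METHODS = {"get", "put", "post", "delete", "patch", "options", "head", "trace"}


def fetch_openapi(base_url: str = "http://localhost:8000") -> dict:
    """Download the OpenAPI schema from a running FastAPI app."""
    with urlopen(f"{base_url}/openapi.json") as resp:
        return json.load(resp)


def list_operations(spec: dict) -> list[tuple[str, str]]:
    """Return sorted (METHOD, path) pairs declared in an OpenAPI schema."""
    ops = []
    for path, item in spec.get("paths", {}).items():
        # Path items may also hold non-method keys such as "parameters".
        for method in item:
            if method in HTTP_METHODS:
                ops.append((method.upper(), path))
    return sorted(ops)
```

With the API running, `list_operations(fetch_openapi())` prints the routes registered in app/routes/predict.py.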
Raw annotation input (Endpoint 1):
- Minimum frames required: 45 (observation_window + prediction_horizon = 15 + 30)
- Per-frame list alignment: all per-frame lists (bbox, occlusion, behavior.*) and the number of obd_speed keys must match len(frames)
- Frame order: frames must be in ascending, sequential order
ToM feature input (Endpoint 2):
- Normal range [0, 1]: B1, B2, B3, B4, G2, A1, C1, C2
- Extended range [-1, 1]: A2 (signal awareness can be negative)
- Out-of-range values: the API returns a validation error that identifies the offending field
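A client can run the same checks before sending a request and get all problems at once. This is a pre-flight sketch that mirrors the rules above; the server-side Pydantic schemas in app/schemas/ remain authoritative. The helper names are hypothetical, and the frame-order check only enforces strictly ascending IDs (whether gaps are allowed is left to the API).

```python
from typing import Any

MIN_FRAMES = 15 + 30  # observation_window + prediction_horizon
SIGNED_FEATURES = {"A2_traffic_signal_awareness"}  # the only feature allowed in [-1, 1]


def validate_raw(payload: dict[str, Any], min_frames: int = MIN_FRAMES) -> list[str]:
    """Return human-readable problems with a raw-annotation payload (empty if none)."""
    errors: list[str] = []
    frames = payload.get("frames", [])
    n = len(frames)
    if n < min_frames:
        errors.append(f"frames: need at least {min_frames}, got {n}")
    # Every per-frame field needs one entry per frame (obd_speed: one key per frame).
    per_frame = {
        "bbox": payload.get("bbox", []),
        "occlusion": payload.get("occlusion", []),
        "obd_speed": payload.get("obd_speed", {}),
    }
    for key, values in payload.get("behavior", {}).items():
        per_frame[f"behavior.{key}"] = values
    for name, values in per_frame.items():
        if len(values) != n:
            errors.append(f"{name}: {len(values)} entries, expected {n}")
    if any(later <= earlier for earlier, later in zip(frames, frames[1:])):
        errors.append("frames: must be in ascending order")
    return errors


def validate_features(features: dict[str, float]) -> list[str]:
    """Return range violations for a ToM feature dict (empty if none)."""
    errors: list[str] = []
    for name, value in features.items():
        low = -1.0 if name in SIGNED_FEATURES else 0.0
        if not low <= value <= 1.0:
            errors.append(f"{name}: {value} outside [{low:g}, 1]")
    return errors
```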
| Error | Cause | Fix |
|---|---|---|
| Insufficient frames | len(frames) < 45 | Provide at least 45 frames |
| List length mismatch | len(bbox) ≠ len(frames) | Ensure all per-frame lists match the frame count |
| Value out of range | Feature outside valid bounds | Clamp the feature to [0, 1] ([-1, 1] for A2) |
| Invalid frame order | Frames not in ascending order | Sort frames before submission |
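The two fixes in the last rows of the table can be applied mechanically. A sketch with hypothetical helper names: `sort_by_frame` reorders every per-frame list together so entries stay aligned with their frame IDs (obd_speed is keyed by frame ID, so it needs no reordering), and `clamp_features` pins each feature to its allowed range. Note that clamping silently changes inputs, so it can hide an upstream feature-extraction bug.

```python
SIGNED_FEATURES = {"A2_traffic_signal_awareness"}  # allowed range [-1, 1]


def sort_by_frame(payload: dict) -> dict:
    """Return a copy of a raw payload with all per-frame lists reordered by frame ID."""
    order = sorted(range(len(payload["frames"])), key=lambda i: payload["frames"][i])
    fixed = dict(payload)  # shallow copy; reordered lists are new objects
    fixed["frames"] = [payload["frames"][i] for i in order]
    for key in ("bbox", "occlusion"):
        if key in payload:
            fixed[key] = [payload[key][i] for i in order]
    if "behavior" in payload:
        fixed["behavior"] = {k: [v[i] for i in order] for k, v in payload["behavior"].items()}
    return fixed


def clamp_features(features: dict[str, float]) -> dict[str, float]:
    """Clamp each feature to [0, 1], or to [-1, 1] for A2."""
    return {
        name: min(1.0, max(-1.0 if name in SIGNED_FEATURES else 0.0, value))
        for name, value in features.items()
    }
```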
The ANFIS model uses a 5-layer neuro-fuzzy architecture:
- Input Layer: 9 normalized ToM features
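- Membership Functions: Gaussian membership functions (three per feature) fuzzify each input
- Rule Layer: 19,683 fuzzy IF-THEN rules, one per combination of membership functions (3⁹ = 19,683)
- Normalization & Inference: rule firing strengths are normalized and used to weight the rule outputs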
- Output Layer: Sigmoid-activated crossing probability ∈ [0, 1]
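The five layers can be traced in a small pure-Python forward pass. This is an illustrative sketch, not the model in app/core/anfis.py: it assumes zero-order Sugeno consequents (one learned constant per rule), whereas the real model may use a different consequent form, and all parameter values in the usage note are made up.

```python
import itertools
import math
from typing import Sequence


def gaussian(x: float, center: float, sigma: float) -> float:
    """Gaussian membership degree of x for one fuzzy set."""
    return math.exp(-0.5 * ((x - center) / sigma) ** 2)


def anfis_forward(x: Sequence[float],
                  centers: Sequence[Sequence[float]],
                  sigmas: Sequence[Sequence[float]],
                  consequents: Sequence[float]) -> float:
    """Crossing probability for one input vector (zero-order Sugeno consequents)."""
    # Membership functions: degree of each input in each of its fuzzy sets.
    mu = [[gaussian(xi, c, s) for c, s in zip(cs, ss)]
          for xi, cs, ss in zip(x, centers, sigmas)]
    # Rule layer: one rule per combination of fuzzy sets; AND is the product.
    strengths = [math.prod(combo) for combo in itertools.product(*mu)]
    if len(consequents) != len(strengths):
        raise ValueError(f"expected {len(strengths)} consequents, got {len(consequents)}")
    total = sum(strengths)
    if total == 0.0:
        raise ValueError("no rule fired (all firing strengths underflowed to 0)")
    # Normalization & inference: normalized strengths weight each rule's constant output.
    z = sum(w / total * f for w, f in zip(strengths, consequents))
    # Output layer: the sigmoid maps the aggregate to a probability.
    return 1.0 / (1.0 + math.exp(-z))
```

With two inputs and three membership functions each, `itertools.product` yields 3² = 9 rules. With the nine ToM features it yields 3⁹ = 19,683, which is why the rule count grows so quickly with the number of inputs.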
The predict_explain() method returns:
- predicted_crossing_probability (float): sigmoid output indicating likelihood of crossing
- intention_level (str): "LOW", "MEDIUM", or "HIGH", assigned from probability threshold bands (see app/config.py for the configured thresholds)
- top_fired_rules (list): the 3 rules with the highest firing strengths, with their linguistic descriptions
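The banding and rule-ranking steps can be sketched as below. The thresholds 0.4 and 0.7 are hypothetical placeholders, not the values in app/config.py, and the sketch returns only rule indices and strengths; building the linguistic description (for example "IF A1 is High AND ...") requires the membership-function labels stored with the model and is omitted.

```python
import heapq
from typing import Sequence

# Hypothetical band edges; the deployed thresholds are configured in app/config.py.
LOW_MAX = 0.4
HIGH_MIN = 0.7


def intention_level(p: float, low_max: float = LOW_MAX, high_min: float = HIGH_MIN) -> str:
    """Map a crossing probability to a LOW / MEDIUM / HIGH band."""
    if p < low_max:
        return "LOW"
    if p < high_min:
        return "MEDIUM"
    return "HIGH"


def top_fired_rules(strengths: Sequence[float], k: int = 3) -> list[dict]:
    """Indices and strengths of the k rules with the largest firing strengths."""
    best = heapq.nlargest(k, range(len(strengths)), key=strengths.__getitem__)
    return [{"rule_id": i, "firing_strength": strengths[i]} for i in best]
```

`heapq.nlargest` avoids sorting all 19,683 strengths when only the top three are needed.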
- Training Notebook: notebooks/ToM_ANFIS_training_pipeline.ipynb
- Core Model: app/core/anfis.py
- Feature Extractor: app/core/extractor.py
- PIE Dataset: https://github.com/aras62/PIE
- Off-by-one behavior: The current extractor uses prediction_horizon=30 to observe features at frame position 29 when a crossing event occurs at frame 59 (i.e., 30 frames in the future). This is by design, to avoid boundary truncation with the trained model. See "Frame Timeline Visualization" above for examples.
- Model serialization: The trained ANFIS model and MinMaxScaler are loaded from disk at startup (app/artifacts/). Ensure both files exist before running the API.
- Feature normalization: All 9 features are normalized to [0, 1] using the fitted MinMaxScaler before inference.
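For reference, a scikit-learn MinMaxScaler with the default feature_range=(0, 1) maps each feature as (x - data_min) / (data_max - data_min), using the per-feature minimum and maximum seen during fitting. The pure-Python sketch below reproduces that arithmetic so it can be checked by hand; it assumes the saved scaler is such a MinMaxScaler (its fitted `data_min_` and `data_max_` would supply the bounds), and `minmax_transform` is a hypothetical helper. By default MinMaxScaler does not clip, so inputs outside the training range map outside [0, 1].

```python
from typing import Sequence


def minmax_transform(x: Sequence[float], data_min: Sequence[float],
                     data_max: Sequence[float]) -> list[float]:
    """Per-feature min-max scaling to [0, 1] over the fitted range (no clipping)."""
    if not len(x) == len(data_min) == len(data_max):
        raise ValueError("x, data_min and data_max must have the same length")
    out = []
    for value, lo, hi in zip(x, data_min, data_max):
        span = hi - lo
        # Like scikit-learn, treat a zero range as 1 to avoid dividing by zero.
        out.append((value - lo) / (span if span else 1.0))
    return out
```

If the artifact was written with `joblib.dump` (an assumption; the training notebook is authoritative), `joblib.load("app/artifacts/scaler_optimized.save")` returns the fitted scaler, and its `transform` performs this mapping on a 2-D array of feature rows.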