Skip to content

Commit c246772

Browse files
Fix/whatsapp connect (#1254)
* fix: recover whatsapp qr after deleted device * fix: whatsapp connect * fix whatsapp instance device scoping --------- Co-authored-by: Duy /zuey/ <duy@wearetopgroup.com>
1 parent 0ae5599 commit c246772

11 files changed

Lines changed: 777 additions & 53 deletions

cmd/gateway_channels_setup.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ func registerConfigChannels(cfg *config.Config, channelMgr *channels.Manager, ms
7474
if strings.Contains(fmt.Sprintf("%T", pgStores.DB.Driver()), "sqlite") {
7575
waDialect = "sqlite3"
7676
}
77-
wa, err := whatsapp.New(cfg.Channels.WhatsApp, msgBus, pgStores.Pairing, pgStores.DB, pgStores.PendingMessages, waDialect, audioMgr, pgStores.BuiltinTools)
77+
wa, err := whatsapp.New(cfg.Channels.WhatsApp, msgBus, pgStores.Pairing, pgStores.DB, pgStores.PendingMessages, waDialect, audioMgr, pgStores.BuiltinTools, whatsapp.WithLegacyFirstDeviceFallback())
7878
if err != nil {
7979
channelMgr.RecordFailure(channels.TypeWhatsApp, "", err)
8080
slog.Error("failed to initialize whatsapp channel", "error", err)

internal/channels/instance_loader.go

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ type InstanceLoader struct {
4646
msgBus *bus.MessageBus
4747
pairingSvc store.PairingStore
4848
mu sync.Mutex
49-
loaded map[string]struct{} // channel names managed by this loader
49+
loaded map[string]struct{} // channel names managed by this loader
50+
loadedIDs map[string]uuid.UUID // channel name -> DB instance ID
5051
}
5152

5253
// NewInstanceLoader creates a new InstanceLoader.
@@ -65,6 +66,7 @@ func NewInstanceLoader(
6566
msgBus: msgBus,
6667
pairingSvc: pairingSvc,
6768
loaded: make(map[string]struct{}),
69+
loadedIDs: make(map[string]uuid.UUID),
6870
}
6971
}
7072

@@ -132,6 +134,7 @@ func (l *InstanceLoader) Reload(ctx context.Context) {
132134
l.manager.UnregisterChannel(name)
133135
}
134136
l.loaded = make(map[string]struct{})
137+
l.loadedIDs = make(map[string]uuid.UUID)
135138

136139
// Brief pause to let external APIs (e.g., Telegram getUpdates) release polling locks.
137140
time.Sleep(500 * time.Millisecond)
@@ -170,10 +173,19 @@ func (l *InstanceLoader) LoadInstanceByID(ctx context.Context, id uuid.UUID) err
170173
}
171174

172175
if _, ok := l.loaded[inst.Name]; ok {
173-
if _, exists := l.manager.GetChannel(inst.Name); exists {
174-
return nil
176+
if loadedID, hasLoadedID := l.loadedIDs[inst.Name]; hasLoadedID && loadedID == inst.ID {
177+
if _, exists := l.manager.GetChannel(inst.Name); exists {
178+
return nil
179+
}
180+
} else if ch, exists := l.manager.GetChannel(inst.Name); exists {
181+
if err := ch.Stop(ctx); err != nil {
182+
slog.Warn("failed to stop stale channel instance before targeted reload",
183+
"name", inst.Name, "old_id", loadedID, "new_id", inst.ID, "error", err)
184+
}
185+
l.manager.UnregisterChannel(inst.Name)
175186
}
176187
delete(l.loaded, inst.Name)
188+
delete(l.loadedIDs, inst.Name)
177189
}
178190

179191
if !inst.Enabled {
@@ -202,6 +214,7 @@ func (l *InstanceLoader) Stop(ctx context.Context) {
202214
l.manager.UnregisterChannel(name)
203215
}
204216
l.loaded = make(map[string]struct{})
217+
l.loadedIDs = make(map[string]uuid.UUID)
205218
}
206219

207220
// coerceStringBools converts string "true"/"false" values to JSON booleans
@@ -249,6 +262,7 @@ func (l *InstanceLoader) LoadedNames() map[string]struct{} {
249262
// If false, the caller is responsible for starting (used by LoadAll, where StartAll handles it).
250263
func (l *InstanceLoader) loadInstance(ctx context.Context, inst store.ChannelInstanceData, autoStart bool) error {
251264
l.loaded[inst.Name] = struct{}{}
265+
l.loadedIDs[inst.Name] = inst.ID
252266

253267
factory, ok := l.factories[inst.ChannelType]
254268
if !ok {

internal/channels/instance_loader_timeout_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,52 @@ func TestLoadInstanceByIDLoadsTargetWithoutReloadingExisting(t *testing.T) {
141141
}
142142
}
143143

144+
func TestLoadInstanceByIDRefreshesSameNameDifferentInstance(t *testing.T) {
145+
msgBus := bus.New()
146+
mgr := NewManager(msgBus)
147+
oldID := uuid.New()
148+
newID := uuid.New()
149+
150+
loader := NewInstanceLoader(&singleInstanceStore{inst: store.ChannelInstanceData{
151+
BaseModel: store.BaseModel{ID: newID},
152+
Name: "whatsapp-main",
153+
ChannelType: TypeWhatsApp,
154+
Enabled: true,
155+
}}, nil, mgr, msgBus, nil)
156+
157+
oldChannel := newTimeoutTestChannel("whatsapp-main", TypeWhatsApp, false)
158+
mgr.RegisterChannel("whatsapp-main", oldChannel)
159+
loader.loaded["whatsapp-main"] = struct{}{}
160+
loader.loadedIDs["whatsapp-main"] = oldID
161+
162+
var loadedChannel *timeoutTestChannel
163+
loader.RegisterFactory(TypeWhatsApp, func(name string, _ json.RawMessage, _ json.RawMessage, _ *bus.MessageBus, _ store.PairingStore) (Channel, error) {
164+
loadedChannel = newTimeoutTestChannel(name, TypeWhatsApp, false)
165+
return loadedChannel, nil
166+
})
167+
168+
if err := loader.LoadInstanceByID(context.Background(), newID); err != nil {
169+
t.Fatalf("LoadInstanceByID returned error: %v", err)
170+
}
171+
172+
if oldChannel.stopCalls.Load() == 0 {
173+
t.Fatal("expected stale same-name channel to be stopped before targeted reload")
174+
}
175+
got, ok := mgr.GetChannel("whatsapp-main")
176+
if !ok {
177+
t.Fatal("expected refreshed channel to be registered")
178+
}
179+
if got == oldChannel {
180+
t.Fatal("expected manager to replace stale same-name channel")
181+
}
182+
if got != loadedChannel {
183+
t.Fatal("manager registered unexpected channel after targeted reload")
184+
}
185+
if loader.loadedIDs["whatsapp-main"] != newID {
186+
t.Fatalf("loadedIDs = %s, want %s", loader.loadedIDs["whatsapp-main"], newID)
187+
}
188+
}
189+
144190
// TestLoadInstance_HungStartDoesNotBlock verifies that a Start() that never
145191
// returns is abandoned after reloadStartTimeout so Reload() can proceed to
146192
// other channels instead of deadlocking on the loader mutex.

internal/channels/whatsapp/auth.go

Lines changed: 91 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"log/slog"
77

88
"go.mau.fi/whatsmeow"
9+
wastore "go.mau.fi/whatsmeow/store"
10+
"go.mau.fi/whatsmeow/types"
911
)
1012

1113
// StartQRFlow initiates the QR authentication flow.
@@ -16,35 +18,26 @@ import (
1618
func (c *Channel) StartQRFlow(ctx context.Context) (<-chan whatsmeow.QRChannelItem, error) {
1719
c.reauthMu.Lock()
1820
defer c.reauthMu.Unlock()
19-
if c.client == nil {
20-
// Lazy init: wizard may request QR before Start() is called.
21-
c.mu.Lock()
22-
if c.client == nil {
23-
if c.ctx == nil {
24-
c.ctx, c.cancel = context.WithCancel(context.Background())
25-
}
26-
deviceStore, err := c.container.GetFirstDevice(ctx)
27-
if err != nil {
28-
c.mu.Unlock()
29-
return nil, fmt.Errorf("whatsapp get device: %w", err)
30-
}
31-
c.client = whatsmeow.NewClient(deviceStore, nil)
32-
c.client.AddEventHandler(c.handleEvent)
33-
}
21+
22+
c.mu.Lock()
23+
if err := c.ensureQRClientLocked(ctx); err != nil {
3424
c.mu.Unlock()
25+
return nil, fmt.Errorf("whatsapp get device: %w", err)
3526
}
27+
client := c.client
28+
c.mu.Unlock()
3629

3730
if c.IsAuthenticated() {
3831
return nil, nil // caller checks this
3932
}
4033

41-
qrChan, err := c.client.GetQRChannel(ctx)
34+
qrChan, err := client.GetQRChannel(ctx)
4235
if err != nil {
4336
return nil, fmt.Errorf("whatsapp get QR channel: %w", err)
4437
}
4538

46-
if !c.client.IsConnected() {
47-
if err := c.client.Connect(); err != nil {
39+
if !client.IsConnected() {
40+
if err := client.Connect(); err != nil {
4841
return nil, fmt.Errorf("whatsapp connect for QR: %w", err)
4942
}
5043
}
@@ -90,13 +83,89 @@ func (c *Channel) Reauth() error {
9083
}
9184
c.ctx, c.cancel = context.WithCancel(parent)
9285

93-
// Re-create client with fresh device store.
94-
deviceStore, err := c.container.GetFirstDevice(context.Background())
95-
if err != nil {
86+
if err := c.resetClientLocked(context.Background()); err != nil {
9687
return fmt.Errorf("whatsapp: get fresh device: %w", err)
9788
}
89+
90+
return nil
91+
}
92+
93+
// ensureQRClientLocked lazily creates or refreshes the client before QR login.
94+
// The caller must hold c.mu and c.reauthMu.
95+
func (c *Channel) ensureQRClientLocked(ctx context.Context) error {
96+
if c.client == nil {
97+
return c.resetClientLocked(ctx)
98+
}
99+
if !c.client.Store.Deleted {
100+
return nil
101+
}
102+
c.lastQRMu.Lock()
103+
c.waAuthenticated = false
104+
c.lastQRB64 = ""
105+
c.lastQRMu.Unlock()
106+
return c.resetClientLocked(ctx)
107+
}
108+
109+
// resetClientLocked replaces the whatsmeow client while preserving the channel lifecycle.
110+
// The caller must hold c.mu.
111+
func (c *Channel) resetClientLocked(ctx context.Context) error {
112+
if ctx == nil {
113+
ctx = context.Background()
114+
}
115+
if c.ctx == nil {
116+
parent := c.parentCtx
117+
if parent == nil {
118+
parent = context.Background()
119+
}
120+
c.ctx, c.cancel = context.WithCancel(parent)
121+
}
122+
deviceStore, err := c.resolveDeviceStoreLocked(ctx)
123+
if err != nil {
124+
return err
125+
}
98126
c.client = whatsmeow.NewClient(deviceStore, nil)
99127
c.client.AddEventHandler(c.handleEvent)
100-
101128
return nil
102129
}
130+
131+
func (c *Channel) resolveDeviceStoreLocked(ctx context.Context) (*wastore.Device, error) {
132+
if !c.deviceJID.IsEmpty() {
133+
deviceStore, err := c.container.GetDevice(ctx, c.deviceJID)
134+
if err != nil {
135+
return nil, err
136+
}
137+
if deviceStore != nil && !deviceStore.Deleted {
138+
return deviceStore, nil
139+
}
140+
slog.Info("whatsapp scoped device missing; creating fresh QR device",
141+
"channel", c.Name(), "device_hash", hashWhatsAppIdentifier(c.deviceJID.String()))
142+
return c.container.NewDevice(), nil
143+
}
144+
145+
if c.legacyFirstDeviceFallback {
146+
devices, err := c.container.GetAllDevices(ctx)
147+
if err != nil {
148+
return nil, err
149+
}
150+
if len(devices) > 0 {
151+
return devices[0], nil
152+
}
153+
}
154+
155+
return c.container.NewDevice(), nil
156+
}
157+
158+
func (c *Channel) currentDeviceJID() types.JID {
159+
c.mu.Lock()
160+
defer c.mu.Unlock()
161+
if c.client == nil || c.client.Store == nil {
162+
return types.EmptyJID
163+
}
164+
return c.client.Store.GetJID()
165+
}
166+
167+
func (c *Channel) setDeviceJID(jid types.JID) {
168+
c.mu.Lock()
169+
defer c.mu.Unlock()
170+
c.deviceJID = jid
171+
}

0 commit comments

Comments
 (0)