Skip to content

Commit 6af5bfd

Browse files
authored
Fix/send frame cryptor events from signaling thread (#95)
* fix: send framecyrotor events from signaling thread. (WIP) * replace std::shared_ptr to rtc::scoped_refptr. * create framecryptor with signaling_thread. * null check. * fix. * fix.
1 parent 13fe8b2 commit 6af5bfd

File tree

7 files changed

+108
-91
lines changed

7 files changed

+108
-91
lines changed

api/crypto/frame_crypto_transformer.cc

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -305,15 +305,16 @@ int AesEncryptDecrypt(EncryptOrDecrypt mode,
305305
return AesCbcEncryptDecrypt(mode, raw_key, iv, data, buffer);
306306
}
307307
}
308-
309308
namespace webrtc {
310309

311310
FrameCryptorTransformer::FrameCryptorTransformer(
311+
rtc::Thread* signaling_thread,
312312
const std::string participant_id,
313313
MediaType type,
314314
Algorithm algorithm,
315315
rtc::scoped_refptr<KeyProvider> key_provider)
316-
: participant_id_(participant_id),
316+
: signaling_thread_(signaling_thread),
317+
participant_id_(participant_id),
317318
type_(type),
318319
algorithm_(algorithm),
319320
key_provider_(key_provider) {
@@ -363,9 +364,7 @@ void FrameCryptorTransformer::encryptFrame(
363364
<< "FrameCryptorTransformer::encryptFrame() sink_callback is NULL";
364365
if (last_enc_error_ != FrameCryptionState::kInternalError) {
365366
last_enc_error_ = FrameCryptionState::kInternalError;
366-
if (observer_)
367-
observer_->OnFrameCryptionStateChanged(participant_id_,
368-
last_enc_error_);
367+
onFrameCryptionStateChanged(last_enc_error_);
369368
}
370369
return;
371370
}
@@ -387,9 +386,7 @@ void FrameCryptorTransformer::encryptFrame(
387386
<< participant_id_;
388387
if (last_enc_error_ != FrameCryptionState::kMissingKey) {
389388
last_enc_error_ = FrameCryptionState::kMissingKey;
390-
if (observer_)
391-
observer_->OnFrameCryptionStateChanged(participant_id_,
392-
last_enc_error_);
389+
onFrameCryptionStateChanged(last_enc_error_);
393390
}
394391
return;
395392
}
@@ -454,17 +451,13 @@ void FrameCryptorTransformer::encryptFrame(
454451

455452
if (last_enc_error_ != FrameCryptionState::kOk) {
456453
last_enc_error_ = FrameCryptionState::kOk;
457-
if (observer_)
458-
observer_->OnFrameCryptionStateChanged(participant_id_,
459-
last_enc_error_);
454+
onFrameCryptionStateChanged(last_enc_error_);
460455
}
461456
sink_callback->OnTransformedFrame(std::move(frame));
462457
} else {
463458
if (last_enc_error_ != FrameCryptionState::kEncryptionFailed) {
464459
last_enc_error_ = FrameCryptionState::kEncryptionFailed;
465-
if (observer_)
466-
observer_->OnFrameCryptionStateChanged(participant_id_,
467-
last_enc_error_);
460+
onFrameCryptionStateChanged(last_enc_error_);
468461
}
469462
RTC_LOG(LS_ERROR) << "FrameCryptorTransformer::encryptFrame() failed";
470463
}
@@ -489,9 +482,7 @@ void FrameCryptorTransformer::decryptFrame(
489482
<< "FrameCryptorTransformer::decryptFrame() sink_callback is NULL";
490483
if (last_dec_error_ != FrameCryptionState::kInternalError) {
491484
last_dec_error_ = FrameCryptionState::kInternalError;
492-
if (observer_)
493-
observer_->OnFrameCryptionStateChanged(participant_id_,
494-
last_dec_error_);
485+
onFrameCryptionStateChanged(last_dec_error_);
495486
}
496487
return;
497488
}
@@ -552,9 +543,7 @@ void FrameCryptorTransformer::decryptFrame(
552543
<< static_cast<int>(getIvSize()) << "]";
553544
if (last_dec_error_ != FrameCryptionState::kDecryptionFailed) {
554545
last_dec_error_ = FrameCryptionState::kDecryptionFailed;
555-
if (observer_)
556-
observer_->OnFrameCryptionStateChanged(participant_id_,
557-
last_dec_error_);
546+
onFrameCryptionStateChanged(last_dec_error_);
558547
}
559548
return;
560549
}
@@ -571,9 +560,7 @@ void FrameCryptorTransformer::decryptFrame(
571560
<< participant_id_;
572561
if (last_dec_error_ != FrameCryptionState::kMissingKey) {
573562
last_dec_error_ = FrameCryptionState::kMissingKey;
574-
if (observer_)
575-
observer_->OnFrameCryptionStateChanged(participant_id_,
576-
last_dec_error_);
563+
onFrameCryptionStateChanged(last_dec_error_);
577564
}
578565
return;
579566
}
@@ -630,7 +617,7 @@ void FrameCryptorTransformer::decryptFrame(
630617
decryption_success = true;
631618
} else {
632619
RTC_LOG(LS_WARNING) << "FrameCryptorTransformer::decryptFrame() failed";
633-
std::shared_ptr<ParticipantKeyHandler::KeySet> ratcheted_key_set;
620+
rtc::scoped_refptr<ParticipantKeyHandler::KeySet> ratcheted_key_set;
634621
auto currentKeyMaterial = key_set->material;
635622
if (key_provider_->options().ratchet_window_size > 0) {
636623
while (ratchet_count < key_provider_->options().ratchet_window_size) {
@@ -656,9 +643,7 @@ void FrameCryptorTransformer::decryptFrame(
656643
key_handler->SetHasValidKey();
657644
if (last_dec_error_ != FrameCryptionState::kKeyRatcheted) {
658645
last_dec_error_ = FrameCryptionState::kKeyRatcheted;
659-
if (observer_)
660-
observer_->OnFrameCryptionStateChanged(participant_id_,
661-
last_dec_error_);
646+
onFrameCryptionStateChanged(last_dec_error_);
662647
}
663648
break;
664649
}
@@ -683,9 +668,7 @@ void FrameCryptorTransformer::decryptFrame(
683668
if (last_dec_error_ != FrameCryptionState::kDecryptionFailed) {
684669
last_dec_error_ = FrameCryptionState::kDecryptionFailed;
685670
key_handler->DecryptionFailure();
686-
if (observer_)
687-
observer_->OnFrameCryptionStateChanged(participant_id_,
688-
last_dec_error_);
671+
onFrameCryptionStateChanged(last_dec_error_);
689672
}
690673
return;
691674
}
@@ -698,12 +681,23 @@ void FrameCryptorTransformer::decryptFrame(
698681

699682
if (last_dec_error_ != FrameCryptionState::kOk) {
700683
last_dec_error_ = FrameCryptionState::kOk;
701-
if (observer_)
702-
observer_->OnFrameCryptionStateChanged(participant_id_, last_dec_error_);
684+
onFrameCryptionStateChanged(last_dec_error_);
703685
}
704686
sink_callback->OnTransformedFrame(std::move(frame));
705687
}
706688

689+
void FrameCryptorTransformer::onFrameCryptionStateChanged(FrameCryptionState state) {
690+
webrtc::MutexLock lock(&mutex_);
691+
if(observer_) {
692+
RTC_DCHECK(signaling_thread_ != nullptr);
693+
signaling_thread_->PostTask(
694+
[observer = observer_, state = state, participant_id = participant_id_]() mutable {
695+
observer->OnFrameCryptionStateChanged(participant_id, state);
696+
}
697+
);
698+
}
699+
}
700+
707701
rtc::Buffer FrameCryptorTransformer::makeIv(uint32_t ssrc, uint32_t timestamp) {
708702
uint32_t send_count = 0;
709703
if (send_counts_.find(ssrc) == send_counts_.end()) {

api/crypto/frame_crypto_transformer.h

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
#include <unordered_map>
2121

2222
#include "api/frame_transformer_interface.h"
23+
#include "api/task_queue/pending_task_safety_flag.h"
24+
#include "api/task_queue/task_queue_base.h"
2325
#include "rtc_base/buffer.h"
2426
#include "rtc_base/synchronization/mutex.h"
2527
#include "rtc_base/system/rtc_export.h"
@@ -56,7 +58,7 @@ class KeyProvider : public rtc::RefCountInterface {
5658

5759
virtual bool SetSharedKey(int key_index, std::vector<uint8_t> key) = 0;
5860

59-
virtual const std::shared_ptr<ParticipantKeyHandler> GetSharedKey(const std::string participant_id) = 0;
61+
virtual const rtc::scoped_refptr<ParticipantKeyHandler> GetSharedKey(const std::string participant_id) = 0;
6062

6163
virtual const std::vector<uint8_t> RatchetSharedKey(int key_index) = 0;
6264

@@ -66,7 +68,7 @@ class KeyProvider : public rtc::RefCountInterface {
6668
int key_index,
6769
std::vector<uint8_t> key) = 0;
6870

69-
virtual const std::shared_ptr<ParticipantKeyHandler> GetKey(
71+
virtual const rtc::scoped_refptr<ParticipantKeyHandler> GetKey(
7072
const std::string participant_id) const = 0;
7173

7274
virtual const std::vector<uint8_t> RatchetKey(
@@ -84,9 +86,9 @@ class KeyProvider : public rtc::RefCountInterface {
8486
virtual ~KeyProvider() {}
8587
};
8688

87-
class ParticipantKeyHandler {
89+
class ParticipantKeyHandler : public rtc::RefCountInterface {
8890
public:
89-
struct KeySet {
91+
struct KeySet : public rtc::RefCountInterface {
9092
std::vector<uint8_t> material;
9193
std::vector<uint8_t> encryption_key;
9294
KeySet(std::vector<uint8_t> material, std::vector<uint8_t> encryptionKey)
@@ -99,8 +101,8 @@ class ParticipantKeyHandler {
99101

100102
virtual ~ParticipantKeyHandler() = default;
101103

102-
std::shared_ptr<ParticipantKeyHandler> Clone() {
103-
auto clone = std::make_shared<ParticipantKeyHandler>(key_provider_);
104+
rtc::scoped_refptr<ParticipantKeyHandler> Clone() {
105+
auto clone = rtc::make_ref_counted<ParticipantKeyHandler>(key_provider_);
104106
clone->crypto_key_ring_ = crypto_key_ring_;
105107
clone->current_key_index_ = current_key_index_;
106108
clone->has_valid_key_ = has_valid_key_;
@@ -124,7 +126,7 @@ class ParticipantKeyHandler {
124126
return new_material;
125127
}
126128

127-
virtual std::shared_ptr<KeySet> GetKeySet(int key_index) {
129+
virtual rtc::scoped_refptr<KeySet> GetKeySet(int key_index) {
128130
webrtc::MutexLock lock(&mutex_);
129131
return crypto_key_ring_[key_index != -1 ? key_index : current_key_index_];
130132
}
@@ -144,13 +146,13 @@ class ParticipantKeyHandler {
144146
return new_material;
145147
}
146148

147-
std::shared_ptr<KeySet> DeriveKeys(std::vector<uint8_t> password,
149+
rtc::scoped_refptr<KeySet> DeriveKeys(std::vector<uint8_t> password,
148150
std::vector<uint8_t> ratchet_salt,
149151
unsigned int optional_length_bits) {
150152
std::vector<uint8_t> derived_key;
151153
if (DerivePBKDF2KeyFromRawKey(password, ratchet_salt, optional_length_bits,
152154
&derived_key) == 0) {
153-
return std::make_shared<KeySet>(password, derived_key);
155+
return rtc::make_ref_counted<KeySet>(password, derived_key);
154156
}
155157
return nullptr;
156158
}
@@ -193,7 +195,7 @@ class ParticipantKeyHandler {
193195
mutable webrtc::Mutex mutex_;
194196
int current_key_index_ = 0;
195197
KeyProvider* key_provider_;
196-
std::vector<std::shared_ptr<KeySet>> crypto_key_ring_;
198+
std::vector<rtc::scoped_refptr<KeySet>> crypto_key_ring_;
197199
};
198200

199201
class DefaultKeyProviderImpl : public KeyProvider {
@@ -206,7 +208,7 @@ class DefaultKeyProviderImpl : public KeyProvider {
206208
webrtc::MutexLock lock(&mutex_);
207209
if(options_.shared_key) {
208210
if (keys_.find("shared") == keys_.end()) {
209-
keys_["shared"] = std::make_shared<ParticipantKeyHandler>(this);
211+
keys_["shared"] = rtc::make_ref_counted<ParticipantKeyHandler>(this);
210212
}
211213

212214
auto key_handler = keys_["shared"];
@@ -252,7 +254,7 @@ class DefaultKeyProviderImpl : public KeyProvider {
252254
return std::vector<uint8_t>();
253255
}
254256

255-
const std::shared_ptr<ParticipantKeyHandler> GetSharedKey(const std::string participant_id) override {
257+
const rtc::scoped_refptr<ParticipantKeyHandler> GetSharedKey(const std::string participant_id) override {
256258
webrtc::MutexLock lock(&mutex_);
257259
if(options_.shared_key && keys_.find("shared") != keys_.end()) {
258260
auto shared_key_handler = keys_["shared"];
@@ -274,15 +276,15 @@ class DefaultKeyProviderImpl : public KeyProvider {
274276
webrtc::MutexLock lock(&mutex_);
275277

276278
if (keys_.find(participant_id) == keys_.end()) {
277-
keys_[participant_id] = std::make_shared<ParticipantKeyHandler>(this);
279+
keys_[participant_id] = rtc::make_ref_counted<ParticipantKeyHandler>(this);
278280
}
279281

280282
auto key_handler = keys_[participant_id];
281283
key_handler->SetKey(key, index);
282284
return true;
283285
}
284286

285-
const std::shared_ptr<ParticipantKeyHandler> GetKey(
287+
const rtc::scoped_refptr<ParticipantKeyHandler> GetKey(
286288
const std::string participant_id) const override {
287289
webrtc::MutexLock lock(&mutex_);
288290

@@ -324,7 +326,7 @@ class DefaultKeyProviderImpl : public KeyProvider {
324326
private:
325327
mutable webrtc::Mutex mutex_;
326328
KeyProviderOptions options_;
327-
std::unordered_map<std::string, std::shared_ptr<ParticipantKeyHandler>> keys_;
329+
std::unordered_map<std::string, rtc::scoped_refptr<ParticipantKeyHandler>> keys_;
328330
};
329331

330332
enum FrameCryptionState {
@@ -337,7 +339,7 @@ enum FrameCryptionState {
337339
kInternalError,
338340
};
339341

340-
class FrameCryptorTransformerObserver {
342+
class FrameCryptorTransformerObserver : public rtc::RefCountInterface {
341343
public:
342344
virtual void OnFrameCryptionStateChanged(const std::string participant_id,
343345
FrameCryptionState error) = 0;
@@ -359,17 +361,23 @@ class RTC_EXPORT FrameCryptorTransformer
359361
kAesCbc,
360362
};
361363

362-
explicit FrameCryptorTransformer(const std::string participant_id,
364+
explicit FrameCryptorTransformer(rtc::Thread* signaling_thread,
365+
const std::string participant_id,
363366
MediaType type,
364367
Algorithm algorithm,
365368
rtc::scoped_refptr<KeyProvider> key_provider);
366369

367-
virtual void SetFrameCryptorTransformerObserver(
368-
FrameCryptorTransformerObserver* observer) {
370+
virtual void RegisterFrameCryptorTransformerObserver(
371+
rtc::scoped_refptr<FrameCryptorTransformerObserver> observer) {
369372
webrtc::MutexLock lock(&mutex_);
370373
observer_ = observer;
371374
}
372375

376+
virtual void UnRegisterFrameCryptorTransformerObserver() {
377+
webrtc::MutexLock lock(&mutex_);
378+
observer_ = nullptr;
379+
}
380+
373381
virtual void SetKeyIndex(int index) {
374382
webrtc::MutexLock lock(&mutex_);
375383
key_index_ = index;
@@ -417,10 +425,12 @@ class RTC_EXPORT FrameCryptorTransformer
417425
private:
418426
void encryptFrame(std::unique_ptr<webrtc::TransformableFrameInterface> frame);
419427
void decryptFrame(std::unique_ptr<webrtc::TransformableFrameInterface> frame);
428+
void onFrameCryptionStateChanged(FrameCryptionState error);
420429
rtc::Buffer makeIv(uint32_t ssrc, uint32_t timestamp);
421430
uint8_t getIvSize();
422431

423432
private:
433+
TaskQueueBase* const signaling_thread_;
424434
std::string participant_id_;
425435
mutable webrtc::Mutex mutex_;
426436
mutable webrtc::Mutex sink_mutex_;
@@ -433,10 +443,9 @@ class RTC_EXPORT FrameCryptorTransformer
433443
int key_index_ = 0;
434444
std::map<uint32_t, uint32_t> send_counts_;
435445
rtc::scoped_refptr<KeyProvider> key_provider_;
436-
FrameCryptorTransformerObserver* observer_ = nullptr;
437-
std::unique_ptr<rtc::Thread> thread_;
446+
rtc::scoped_refptr<FrameCryptorTransformerObserver> observer_;
438447
FrameCryptionState last_enc_error_ = FrameCryptionState::kNew;
439-
FrameCryptionState last_dec_error_ = FrameCryptionState::kNew;
448+
FrameCryptionState last_dec_error_ = FrameCryptionState::kNew;
440449
};
441450

442451
} // namespace webrtc

sdk/android/api/org/webrtc/FrameCryptorFactory.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,21 @@ public static FrameCryptorKeyProvider createFrameCryptorKeyProvider(
2222
return nativeCreateFrameCryptorKeyProvider(sharedKey, ratchetSalt, ratchetWindowSize, uncryptedMagicBytes, failureTolerance);
2323
}
2424

25-
public static FrameCryptor createFrameCryptorForRtpSender(RtpSender rtpSender,
25+
public static FrameCryptor createFrameCryptorForRtpSender(PeerConnectionFactory factory, RtpSender rtpSender,
2626
String participantId, FrameCryptorAlgorithm algorithm, FrameCryptorKeyProvider keyProvider) {
27-
return nativeCreateFrameCryptorForRtpSender(rtpSender.getNativeRtpSender(), participantId,
27+
return nativeCreateFrameCryptorForRtpSender(factory.getNativeOwnedFactoryAndThreads(),rtpSender.getNativeRtpSender(), participantId,
2828
algorithm.ordinal(), keyProvider.getNativeKeyProvider());
2929
}
3030

31-
public static FrameCryptor createFrameCryptorForRtpReceiver(RtpReceiver rtpReceiver,
31+
public static FrameCryptor createFrameCryptorForRtpReceiver(PeerConnectionFactory factory, RtpReceiver rtpReceiver,
3232
String participantId, FrameCryptorAlgorithm algorithm, FrameCryptorKeyProvider keyProvider) {
33-
return nativeCreateFrameCryptorForRtpReceiver(rtpReceiver.getNativeRtpReceiver(), participantId,
33+
return nativeCreateFrameCryptorForRtpReceiver(factory.getNativeOwnedFactoryAndThreads(), rtpReceiver.getNativeRtpReceiver(), participantId,
3434
algorithm.ordinal(), keyProvider.getNativeKeyProvider());
3535
}
3636

37-
private static native FrameCryptor nativeCreateFrameCryptorForRtpSender(
37+
private static native FrameCryptor nativeCreateFrameCryptorForRtpSender(long factory,
3838
long rtpSender, String participantId, int algorithm, long nativeFrameCryptorKeyProvider);
39-
private static native FrameCryptor nativeCreateFrameCryptorForRtpReceiver(
39+
private static native FrameCryptor nativeCreateFrameCryptorForRtpReceiver(long factory,
4040
long rtpReceiver, String participantId, int algorithm, long nativeFrameCryptorKeyProvider);
4141

4242
private static native FrameCryptorKeyProvider nativeCreateFrameCryptorKeyProvider(

0 commit comments

Comments
 (0)