diff --git a/api/crypto/frame_crypto_transformer.cc b/api/crypto/frame_crypto_transformer.cc index d82edfb465..3e888a2175 100644 --- a/api/crypto/frame_crypto_transformer.cc +++ b/api/crypto/frame_crypto_transformer.cc @@ -83,7 +83,7 @@ const EVP_CIPHER* GetAesCbcAlgorithmFromKeySize(size_t key_size_bytes) { } inline bool FrameIsH264(webrtc::TransformableFrameInterface* frame, - webrtc::FrameCryptorTransformer::MediaType type) { + webrtc::FrameCryptorTransformer::MediaType type) { switch (type) { case webrtc::FrameCryptorTransformer::MediaType::kVideoFrame: { auto videoFrame = @@ -314,11 +314,18 @@ FrameCryptorTransformer::FrameCryptorTransformer( Algorithm algorithm, rtc::scoped_refptr key_provider) : signaling_thread_(signaling_thread), + thread_(rtc::Thread::Create()), participant_id_(participant_id), type_(type), algorithm_(algorithm), key_provider_(key_provider) { RTC_DCHECK(key_provider_ != nullptr); + thread_->SetName("FrameCryptorTransformer", this); + thread_->Start(); +} + +FrameCryptorTransformer::~FrameCryptorTransformer() { + thread_->Stop(); } void FrameCryptorTransformer::Transform( @@ -333,10 +340,16 @@ void FrameCryptorTransformer::Transform( // do encrypt or decrypt here... switch (frame->GetDirection()) { case webrtc::TransformableFrameInterface::Direction::kSender: - encryptFrame(std::move(frame)); + RTC_DCHECK(thread_ != nullptr); + thread_->PostTask([frame = std::move(frame), this]() mutable { + encryptFrame(std::move(frame)); + }); break; case webrtc::TransformableFrameInterface::Direction::kReceiver: - decryptFrame(std::move(frame)); + RTC_DCHECK(thread_ != nullptr); + thread_->PostTask([frame = std::move(frame), this]() mutable { + decryptFrame(std::move(frame)); + }); break; case webrtc::TransformableFrameInterface::Direction::kUnknown: // do nothing @@ -371,6 +384,8 @@ void FrameCryptorTransformer::encryptFrame( rtc::ArrayView date_in = frame->GetData(); if (date_in.size() == 0 || !enabled_cryption) { + RTC_LOG(LS_WARNING) << "FrameCryptorTransformer::encryptFrame() " + "date_in.size() == 0 || enabled_cryption == false"; sink_callback->OnTransformedFrame(std::move(frame)); return; } @@ -425,7 +440,8 @@ void FrameCryptorTransformer::encryptFrame( data_out.AppendData(frame_header); if (FrameIsH264(frame.get(), type_)) { - H264::WriteRbsp(data_without_header.data(),data_without_header.size(), &data_out); + H264::WriteRbsp(data_without_header.data(), data_without_header.size(), + &data_out); } else { data_out.AppendData(data_without_header); RTC_CHECK_EQ(data_out.size(), frame_header.size() + @@ -490,34 +506,31 @@ void FrameCryptorTransformer::decryptFrame( rtc::ArrayView date_in = frame->GetData(); if (date_in.size() == 0 || !enabled_cryption) { + RTC_LOG(LS_WARNING) << "FrameCryptorTransformer::decryptFrame() " + "date_in.size() == 0 || enabled_cryption == false"; sink_callback->OnTransformedFrame(std::move(frame)); return; } auto uncrypted_magic_bytes = key_provider_->options().uncrypted_magic_bytes; if (uncrypted_magic_bytes.size() > 0 && - date_in.size() >= uncrypted_magic_bytes.size() + 1) { - auto tmp = - date_in.subview(date_in.size() - (uncrypted_magic_bytes.size() + 1), - uncrypted_magic_bytes.size()); - - if (uncrypted_magic_bytes == std::vector(tmp.begin(), tmp.end())) { + date_in.size() >= uncrypted_magic_bytes.size()) { + auto tmp = date_in.subview(date_in.size() - (uncrypted_magic_bytes.size()), + uncrypted_magic_bytes.size()); + auto data = std::vector(tmp.begin(), tmp.end()); + if (uncrypted_magic_bytes == data) { RTC_CHECK_EQ(tmp.size(), uncrypted_magic_bytes.size()); - auto frame_type = date_in.subview(date_in.size() - 1, 1); - RTC_CHECK_EQ(frame_type.size(), 1); - - RTC_LOG(LS_INFO) - << "FrameCryptorTransformer::uncrypted_magic_bytes( type " - << frame_type[0] << ", tmp " << to_hex(tmp.data(), tmp.size()) - << ", magic bytes " - << to_hex(uncrypted_magic_bytes.data(), uncrypted_magic_bytes.size()) - << ")"; + RTC_LOG(LS_INFO) << "FrameCryptorTransformer::uncrypted_magic_bytes( tmp " + << to_hex(tmp.data(), tmp.size()) << ", magic bytes " + << to_hex(uncrypted_magic_bytes.data(), + uncrypted_magic_bytes.size()) + << ")"; // magic bytes detected, this is a non-encrypted frame, skip frame // decryption. rtc::Buffer data_out; - data_out.AppendData(date_in.subview( - 0, date_in.size() - uncrypted_magic_bytes.size() - 1)); + data_out.AppendData( + date_in.subview(0, date_in.size() - uncrypted_magic_bytes.size())); frame->SetData(data_out); sink_callback->OnTransformedFrame(std::move(frame)); return; @@ -539,8 +552,8 @@ void FrameCryptorTransformer::decryptFrame( if (ivLength != getIvSize()) { RTC_LOG(LS_WARNING) << "FrameCryptorTransformer::decryptFrame() ivLength[" - << static_cast(ivLength) << "] != getIvSize()[" - << static_cast(getIvSize()) << "]"; + << static_cast(ivLength) << "] != getIvSize()[" + << static_cast(getIvSize()) << "]"; if (last_dec_error_ != FrameCryptionState::kDecryptionFailed) { last_dec_error_ = FrameCryptionState::kDecryptionFailed; onFrameCryptionStateChanged(last_dec_error_); @@ -585,7 +598,8 @@ void FrameCryptorTransformer::decryptFrame( if (FrameIsH264(frame.get(), type_) && NeedsRbspUnescaping(encrypted_buffer.data(), encrypted_buffer.size())) { - encrypted_buffer.SetData(H264::ParseRbsp(encrypted_buffer.data(), encrypted_buffer.size())); + encrypted_buffer.SetData( + H264::ParseRbsp(encrypted_buffer.data(), encrypted_buffer.size())); } rtc::Buffer encrypted_payload(encrypted_buffer.size() - ivLength - 2); @@ -665,10 +679,11 @@ void FrameCryptorTransformer::decryptFrame( } if (!decryption_success) { - if (last_dec_error_ != FrameCryptionState::kDecryptionFailed) { - last_dec_error_ = FrameCryptionState::kDecryptionFailed; - key_handler->DecryptionFailure(); - onFrameCryptionStateChanged(last_dec_error_); + if (key_handler->DecryptionFailure()) { + if (last_dec_error_ != FrameCryptionState::kDecryptionFailed) { + last_dec_error_ = FrameCryptionState::kDecryptionFailed; + onFrameCryptionStateChanged(last_dec_error_); + } } return; } @@ -686,15 +701,15 @@ void FrameCryptorTransformer::decryptFrame( sink_callback->OnTransformedFrame(std::move(frame)); } -void FrameCryptorTransformer::onFrameCryptionStateChanged(FrameCryptionState state) { +void FrameCryptorTransformer::onFrameCryptionStateChanged( + FrameCryptionState state) { webrtc::MutexLock lock(&mutex_); - if(observer_) { + if (observer_) { RTC_DCHECK(signaling_thread_ != nullptr); - signaling_thread_->PostTask( - [observer = observer_, state = state, participant_id = participant_id_]() mutable { - observer->OnFrameCryptionStateChanged(participant_id, state); - } - ); + signaling_thread_->PostTask([observer = observer_, state = state, + participant_id = participant_id_]() mutable { + observer->OnFrameCryptionStateChanged(participant_id, state); + }); } } diff --git a/api/crypto/frame_crypto_transformer.h b/api/crypto/frame_crypto_transformer.h index 3be78567c2..f9027d08f7 100644 --- a/api/crypto/frame_crypto_transformer.h +++ b/api/crypto/frame_crypto_transformer.h @@ -44,7 +44,8 @@ struct KeyProviderOptions { std::vector uncrypted_magic_bytes; int ratchet_window_size; int failure_tolerance; - KeyProviderOptions() : shared_key(false), ratchet_window_size(0), failure_tolerance(-1) {} + KeyProviderOptions() + : shared_key(false), ratchet_window_size(0), failure_tolerance(-1) {} KeyProviderOptions(KeyProviderOptions& copy) : shared_key(copy.shared_key), ratchet_salt(copy.ratchet_salt), @@ -55,10 +56,10 @@ struct KeyProviderOptions { class KeyProvider : public rtc::RefCountInterface { public: - virtual bool SetSharedKey(int key_index, std::vector key) = 0; - virtual const rtc::scoped_refptr GetSharedKey(const std::string participant_id) = 0; + virtual const rtc::scoped_refptr GetSharedKey( + const std::string participant_id) = 0; virtual const std::vector RatchetSharedKey(int key_index) = 0; @@ -94,8 +95,10 @@ class ParticipantKeyHandler : public rtc::RefCountInterface { KeySet(std::vector material, std::vector encryptionKey) : material(material), encryption_key(encryptionKey) {} }; + public: - ParticipantKeyHandler(KeyProvider* key_provider) : key_provider_(key_provider) { + ParticipantKeyHandler(KeyProvider* key_provider) + : key_provider_(key_provider) { crypto_key_ring_.resize(KEYRING_SIZE); } @@ -116,7 +119,8 @@ class ParticipantKeyHandler : public rtc::RefCountInterface { } auto current_material = key_set->material; std::vector new_material; - if (DerivePBKDF2KeyFromRawKey(current_material, key_provider_->options().ratchet_salt, 256, + if (DerivePBKDF2KeyFromRawKey(current_material, + key_provider_->options().ratchet_salt, 256, &new_material) != 0) { return std::vector(); } @@ -139,7 +143,8 @@ class ParticipantKeyHandler : public rtc::RefCountInterface { std::vector RatchetKeyMaterial( std::vector current_material) { std::vector new_material; - if (DerivePBKDF2KeyFromRawKey(current_material, key_provider_->options().ratchet_salt, 256, + if (DerivePBKDF2KeyFromRawKey(current_material, + key_provider_->options().ratchet_salt, 256, &new_material) != 0) { return std::vector(); } @@ -147,8 +152,8 @@ class ParticipantKeyHandler : public rtc::RefCountInterface { } rtc::scoped_refptr DeriveKeys(std::vector password, - std::vector ratchet_salt, - unsigned int optional_length_bits) { + std::vector ratchet_salt, + unsigned int optional_length_bits) { std::vector derived_key; if (DerivePBKDF2KeyFromRawKey(password, ratchet_salt, optional_length_bits, &derived_key) == 0) { @@ -177,16 +182,19 @@ class ParticipantKeyHandler : public rtc::RefCountInterface { DeriveKeys(password, key_provider_->options().ratchet_salt, 128); } - void DecryptionFailure() { + bool DecryptionFailure() { webrtc::MutexLock lock(&mutex_); if (key_provider_->options().failure_tolerance < 0) { - return; + return false; } decryption_failure_count_ += 1; - if (decryption_failure_count_ > key_provider_->options().failure_tolerance) { + if (decryption_failure_count_ > + key_provider_->options().failure_tolerance) { has_valid_key_ = false; + return true; } + return false; } private: @@ -206,7 +214,7 @@ class DefaultKeyProviderImpl : public KeyProvider { /// Set the shared key. bool SetSharedKey(int key_index, std::vector key) override { webrtc::MutexLock lock(&mutex_); - if(options_.shared_key) { + if (options_.shared_key) { if (keys_.find("shared") == keys_.end()) { keys_["shared"] = rtc::make_ref_counted(this); } @@ -214,8 +222,8 @@ class DefaultKeyProviderImpl : public KeyProvider { auto key_handler = keys_["shared"]; key_handler->SetKey(key, key_index); - for(auto& key_pair : keys_) { - if(key_pair.first != "shared") { + for (auto& key_pair : keys_) { + if (key_pair.first != "shared") { key_pair.second->SetKey(key, key_index); } } @@ -227,13 +235,13 @@ class DefaultKeyProviderImpl : public KeyProvider { const std::vector RatchetSharedKey(int key_index) override { webrtc::MutexLock lock(&mutex_); auto it = keys_.find("shared"); - if(it == keys_.end()) { + if (it == keys_.end()) { return std::vector(); } auto new_key = it->second->RatchetKey(key_index); - if(options_.shared_key) { - for(auto& key_pair : keys_) { - if(key_pair.first != "shared") { + if (options_.shared_key) { + for (auto& key_pair : keys_) { + if (key_pair.first != "shared") { key_pair.second->SetKey(new_key, key_index); } } @@ -244,19 +252,20 @@ class DefaultKeyProviderImpl : public KeyProvider { const std::vector ExportSharedKey(int key_index) const override { webrtc::MutexLock lock(&mutex_); auto it = keys_.find("shared"); - if(it == keys_.end()) { + if (it == keys_.end()) { return std::vector(); } auto key_set = it->second->GetKeySet(key_index); - if(key_set) { + if (key_set) { return key_set->material; } return std::vector(); } - const rtc::scoped_refptr GetSharedKey(const std::string participant_id) override { + const rtc::scoped_refptr GetSharedKey( + const std::string participant_id) override { webrtc::MutexLock lock(&mutex_); - if(options_.shared_key && keys_.find("shared") != keys_.end()) { + if (options_.shared_key && keys_.find("shared") != keys_.end()) { auto shared_key_handler = keys_["shared"]; if (keys_.find(participant_id) != keys_.end()) { return keys_[participant_id]; @@ -276,7 +285,8 @@ class DefaultKeyProviderImpl : public KeyProvider { webrtc::MutexLock lock(&mutex_); if (keys_.find(participant_id) == keys_.end()) { - keys_[participant_id] = rtc::make_ref_counted(this); + keys_[participant_id] = + rtc::make_ref_counted(this); } auto key_handler = keys_[participant_id]; @@ -326,7 +336,8 @@ class DefaultKeyProviderImpl : public KeyProvider { private: mutable webrtc::Mutex mutex_; KeyProviderOptions options_; - std::unordered_map> keys_; + std::unordered_map> + keys_; }; enum FrameCryptionState { @@ -361,19 +372,20 @@ class RTC_EXPORT FrameCryptorTransformer kAesCbc, }; - explicit FrameCryptorTransformer(rtc::Thread* signaling_thread, - const std::string participant_id, - MediaType type, - Algorithm algorithm, - rtc::scoped_refptr key_provider); - + explicit FrameCryptorTransformer( + rtc::Thread* signaling_thread, + const std::string participant_id, + MediaType type, + Algorithm algorithm, + rtc::scoped_refptr key_provider); + ~FrameCryptorTransformer(); virtual void RegisterFrameCryptorTransformerObserver( - rtc::scoped_refptr observer) { + rtc::scoped_refptr observer) { webrtc::MutexLock lock(&mutex_); observer_ = observer; } - virtual void UnRegisterFrameCryptorTransformerObserver() { + virtual void UnRegisterFrameCryptorTransformerObserver() { webrtc::MutexLock lock(&mutex_); observer_ = nullptr; } @@ -431,6 +443,7 @@ class RTC_EXPORT FrameCryptorTransformer private: TaskQueueBase* const signaling_thread_; + std::unique_ptr thread_; std::string participant_id_; mutable webrtc::Mutex mutex_; mutable webrtc::Mutex sink_mutex_; @@ -445,7 +458,7 @@ class RTC_EXPORT FrameCryptorTransformer rtc::scoped_refptr key_provider_; rtc::scoped_refptr observer_; FrameCryptionState last_enc_error_ = FrameCryptionState::kNew; - FrameCryptionState last_dec_error_ = FrameCryptionState::kNew; + FrameCryptionState last_dec_error_ = FrameCryptionState::kNew; }; } // namespace webrtc