diff --git a/crypto/README.md b/crypto/README.md index 7250a88..98c1e7f 100644 --- a/crypto/README.md +++ b/crypto/README.md @@ -3,6 +3,10 @@ ## Functions - `md5(string)` - return md5 checksum from string. - `sha256(string)` - return sha256 checksum from string. +- `aes_encrypt(string, string, string, string)` - return AES encrypted hex-encoded ciphertext +- `aes_decrypt(string, string, string, string)` - return AES decrypted hex-encoded plain text + +AES support 3 modes: GCM, CBC, and CTR - first parameter is mode, second is hex-encoded key, third is hex-encoded initialization vector or nonce - depending on the mode, and forth is hex-encoded plain text or ciphertext. ## Examples @@ -18,5 +22,19 @@ end if not(crypto.sha256("1\n") == "4355a46b19d348dc2f57c046f8ef63d4538ebb936000f3c9ee954a27460dd865") then error("sha256") end -``` +--- aes encrypt in GCM mode +s, err = crypto.aes_encrypt(1, "86e15cbc1cbf510d8f2e51d4b63a2144", "b6b86d581a991a652158bd10", "48656c6c6f20776f726c64") +if not(s == "7ec4e38508a26abf7b46e8dc90a7299d5144bcf045e460c3ef6b3e") then + error("encrypt AES") +end +assert(not err, err) + +--- aes decrypt in GCM mode +s, err = crypto.aes_decrypt(1, "86e15cbc1cbf510d8f2e51d4b63a2144", "b6b86d581a991a652158bd10", "7ec4e38508a26abf7b46e8dc90a7299d5144bcf045e460c3ef6b3e") +if not(s == "48656c6c6f20776f726c64") then + error("decrypt AES) +end +assert(not err, err) + +``` \ No newline at end of file diff --git a/crypto/aes.go b/crypto/aes.go new file mode 100644 index 0000000..9e548ea --- /dev/null +++ b/crypto/aes.go @@ -0,0 +1,163 @@ +package crypto + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "encoding/hex" + "fmt" + "strings" + + lua "github.com/yuin/gopher-lua" +) + +type mode uint + +const ( + GCM mode = iota + 1 + CBC + CTR +) + +var modeNames = map[string]mode{ + "GCM": GCM, + "CBC": CBC, + "CTR": CTR, +} + +func (m mode) String() string { + switch m { + case GCM: + return "GCM" + case CBC: + return "CBC" + case CTR: + return "CTR" + default: + return "unknown" + } +} + +func parseString(s string) (mode, error) { + ret, ok := modeNames[strings.ToUpper(s)] + if !ok { + return 0, fmt.Errorf("invalid mode: %s", s) + } + return ret, nil +} + +func decodeParams(l *lua.LState) (m mode, key, iv, data []byte, err error) { + modeString := l.ToString(1) + m, err = parseString(modeString) + if err != nil { + return 0, nil, nil, nil, err + } + + keyStr := l.ToString(2) + key, err = hex.DecodeString(keyStr) + if err != nil { + return 0, nil, nil, nil, fmt.Errorf("failed to decode key: %v", err) + } + + ivStr := l.ToString(3) + iv, err = hex.DecodeString(ivStr) + if err != nil { + return 0, nil, nil, nil, fmt.Errorf("failed to decode IV: %v", err) + } + + dataStr := l.ToString(4) + data, err = hex.DecodeString(dataStr) + if err != nil { + return 0, nil, nil, nil, fmt.Errorf("failed to decode data: %v", err) + } + return m, key, iv, data, nil +} + +// encryptAES implements AES encryption given mode, key, plaintext, and init value. +// Init value is either initialization vector or nonce, depending on the mode. +func encryptAES(m mode, key, init, plaintext []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + switch m { + case GCM: + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + if len(init) != aesGCM.NonceSize() { + return nil, fmt.Errorf("incorrect GCM nonce size: %d, expected: %d", len(init), aesGCM.NonceSize()) + } + ciphertext := aesGCM.Seal(nil, init, plaintext, nil) + return ciphertext, nil + case CBC: + if len(init) != block.BlockSize() { + return nil, fmt.Errorf("invalid IV size: %d, expected: %d", len(init), block.BlockSize()) + } + padded := pad(plaintext, aes.BlockSize) + mode := cipher.NewCBCEncrypter(block, init) + ciphertext := make([]byte, len(padded)) + mode.CryptBlocks(ciphertext, padded) + return ciphertext, nil + case CTR: + if len(init) != block.BlockSize() { + return nil, fmt.Errorf("invalid IV size: %d, expected: %d", len(init), block.BlockSize()) + } + stream := cipher.NewCTR(block, init) + ciphertext := make([]byte, len(plaintext)) + stream.XORKeyStream(ciphertext, plaintext) + return ciphertext, nil + default: + return nil, fmt.Errorf("unsupported mode: %d", m) + } +} + +// decryptAES implements AES decryption given mode, key, ciphertext, and init value. +// Init value is either initialization vector or nonce, depending on the mode. +func decryptAES(m mode, key, init, ciphertext []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + switch m { + case GCM: + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + l := len(init) + if l != aesGCM.NonceSize() { + return nil, fmt.Errorf("incorrect GCM nonce size: %d, expected: %d", len(init), aesGCM.NonceSize()) + } + plaintext, err := aesGCM.Open(nil, init, ciphertext, nil) + if err != nil { + return nil, err + } + return plaintext, nil + case CBC: + if len(ciphertext)%aes.BlockSize != 0 { + return nil, fmt.Errorf("ciphertext is not a multiple of block size") + } + mode := cipher.NewCBCDecrypter(block, init) + plaintext := make([]byte, len(ciphertext)) + mode.CryptBlocks(plaintext, ciphertext) + // Padding reversal is intentionally delegated to the application layer. + // On constrained devices with fixed-length payloads, padding is sometimes omitted + // to avoid unnecessary processing load and data overhead. + return plaintext, nil + case CTR: + stream := cipher.NewCTR(block, init) + plaintext := make([]byte, len(ciphertext)) + stream.XORKeyStream(plaintext, ciphertext) + return plaintext, nil + default: + return nil, fmt.Errorf("unsupported mode: %s", m) + } +} + +func pad(data []byte, blockSize int) []byte { + padLen := blockSize - len(data)%blockSize + padding := bytes.Repeat([]byte{byte(padLen)}, padLen) + return append(data, padding...) +} diff --git a/crypto/api.go b/crypto/api.go index de7cc95..24af0de 100644 --- a/crypto/api.go +++ b/crypto/api.go @@ -4,6 +4,7 @@ package crypto import ( "crypto/md5" "crypto/sha256" + "encoding/hex" "fmt" lua "github.com/yuin/gopher-lua" @@ -24,3 +25,42 @@ func SHA256(L *lua.LState) int { L.Push(lua.LString(fmt.Sprintf("%x", hash))) return 1 } + +// AESEncrypt implements AES encryption in Lua. +func AESEncrypt(l *lua.LState) int { + m, key, iv, data, err := decodeParams(l) + if err != nil { + l.Push(lua.LNil) + l.Push(lua.LString(fmt.Sprintf("failed to decode params: %v", err))) + return 2 + } + + enc, err := encryptAES(m, key, iv, data) + if err != nil { + l.Push(lua.LNil) + l.Push(lua.LString(fmt.Sprintf("failed to encrypt: %v", err))) + return 2 + } + l.Push(lua.LString(hex.EncodeToString(enc))) + return 1 +} + +// AESDecrypt implement AES decryption in Lua. +func AESDecrypt(l *lua.LState) int { + m, key, iv, data, err := decodeParams(l) + if err != nil { + l.Push(lua.LNil) + l.Push(lua.LString(fmt.Sprintf("failed to decode params: %v", err))) + return 2 + } + + dec, err := decryptAES(mode(m), key, iv, data) + if err != nil { + l.Push(lua.LNil) + l.Push(lua.LString(fmt.Sprintf("failed to decrypt: %v", err))) + return 2 + } + + l.Push(lua.LString(hex.EncodeToString(dec))) + return 1 +} diff --git a/crypto/api_test.go b/crypto/api_test.go index 6360bc4..cade66f 100644 --- a/crypto/api_test.go +++ b/crypto/api_test.go @@ -1,11 +1,13 @@ package crypto import ( + "testing" + "github.com/stretchr/testify/assert" "github.com/vadv/gopher-lua-libs/tests" - "testing" ) func TestApi(t *testing.T) { - assert.NotZero(t, tests.RunLuaTestFile(t, Preload, "./test/test_api.lua")) + preload := tests.SeveralPreloadFuncs(Preload) + assert.NotZero(t, tests.RunLuaTestFile(t, preload, "./test/test_api.lua")) } diff --git a/crypto/loader.go b/crypto/loader.go index 8781076..a7dd662 100644 --- a/crypto/loader.go +++ b/crypto/loader.go @@ -1,13 +1,11 @@ package crypto -import ( - lua "github.com/yuin/gopher-lua" -) +import lua "github.com/yuin/gopher-lua" // Preload adds crypto to the given Lua state's package.preload table. After it // has been preloaded, it can be loaded using require: // -// local crypto = require("crypto") +// local crypto = require("crypto") func Preload(L *lua.LState) { L.PreloadModule("crypto", Loader) } @@ -21,6 +19,8 @@ func Loader(L *lua.LState) int { } var api = map[string]lua.LGFunction{ - "md5": MD5, - "sha256": SHA256, + "md5": MD5, + "sha256": SHA256, + "aes_encrypt": AESEncrypt, + "aes_decrypt": AESDecrypt, } diff --git a/crypto/test/test_api.lua b/crypto/test/test_api.lua index 9f685d4..4de4183 100644 --- a/crypto/test/test_api.lua +++ b/crypto/test/test_api.lua @@ -1,11 +1,184 @@ local crypto = require("crypto") +local assert = require 'assert' -function Test_crypto(t) - t:Run("md5", function(t) - assert(crypto.md5("1\n") == "b026324c6904b2a9cb4b88d6d61c81d1") - end) +function TestMD5(t) + local tests = { + { + input = "1\n", + expected = "b026324c6904b2a9cb4b88d6d61c81d1", + }, + { + input = "test", + expected = "098f6bcd4621d373cade4e832627b4f6" + } + } + for _, tt in ipairs(tests) do + t:Run("md5(" .. tostring(tt.input) .. ")", function(t) + local got = crypto.md5(tt.input) + assert:Equal(t, tt.expected, got) + end) + end +end + +function TestSha256(t) + local tests = { + { + input = "1\n", + expected = "4355a46b19d348dc2f57c046f8ef63d4538ebb936000f3c9ee954a27460dd865", + }, + { + input = "test", + expected = "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08" + } + } + for _, tt in ipairs(tests) do + t:Run("sha256(" .. tostring(tt.input) .. ")", function(t) + local got = crypto.sha256(tt.input) + assert:Equal(t, tt.expected, got) + end) + end +end + +function TestAESEncrypt(t) + local tests = { + { + data = "48656c6c6f207w76f726c64", -- "Hello world" in hex + mode = "GCM", + key = "86e15cbc1cbf510d8f2e51d4b63a2144", + init = "b6b86d581a991a652158bd10", + expected = nil, + err = "failed to decode params: failed to decode data: encoding/hex: invalid byte: U+0077 'w'", + }, + { + data = "48656c6c6f20776f726c64", -- "Hello world" in hex + mode = "GCM", + key = "86e15cbc1cbf51d8f2e51d4b63a2144", + init = "b6b86d581a991a652158bd10", + expected = nil, + err = "failed to decode params: failed to decode key: encoding/hex: odd length hex string", + }, + { + data = "48656c6c6f20776f726c64", -- "Hello world" in hex + mode = "GCM", + key = "86e15cbc1cbf510d8f2e51d4b63a2144", + init = "b6b86d581a991a652158bd10", + expected = "7ec4e38508a26abf7b46e8dc90a7299d5144bcf045e460c3ef6b3e", + err = nil, + }, + { + data = "48656c6c6f20776f726c64", -- "Hello world" in hex + mode = "GCM", + key = "86e15cbc1cbf510d8f2e51d4b63a2144", + init = "b6b86d581a991a652158bd010211", + expected = nil, + err = "failed to encrypt: incorrect GCM nonce size: 14, expected: 12", + }, + { + data = "48656c6c6f20776f726c64", -- "Hello world" in hex + mode = "GCM", + key = "86e15cbc1cbf510d8f2e51d4b63a2144", + init = "b6b86d581a991a652158bd010211", + expected = nil, + err = "failed to encrypt: incorrect GCM nonce size: 14, expected: 12", + }, + { + data = "48656c6c6f20776f726c64", -- "Hello world" in hex + mode = "cbc", + key = "86e15cbc1cbf510d8f2e51d4b63a2144", + init = "068bb92e032884ba8b260fa7d3a80005", + expected = "dfba6f71cce4d4b76be301b577d9f095", + err = nil, + }, + { + data = "48656c6c6f20776f726c64", -- "Hello world" in hex + mode = "CBC", + key = "86e15cbc1cbf510d8f2e51d4b63a2144", + init = "068bb92e03288884ba8b260fa7d3a80005", + expected = nil, + err = "failed to encrypt: invalid IV size: 17, expected: 16", + }, + { + data = "48656c6c6f20776f726c64", -- "Hello world" in hex + mode = "CTR", + key = "86e15cbc1cbf510d8f2e51d4b63a2144", + init = "e3057fc2bf103a09a1b2c3d4e5f60718", + expected = "138434a80bd7dcd9ee8adc", + err = nil, + }, + { + data = "48656c6c6f20776f726c64", -- "Hello world" in hex + mode = "CTR", + key = "86e15cbc1cbf510d8f2e51d4b63a2144", + init = "e3057fc2b9f103a909a1b2c3d4e5f60718", + expected = nil, + err = "failed to encrypt: invalid IV size: 17, expected: 16", + }, + } + for _, tt in ipairs(tests) do + t:Run("aes_encrypt in " .. tostring(tt.mode) .. " mode", function(t) + local got, err = crypto.aes_encrypt(tt.mode, tt.key, tt.init, tt.data) + assert:Equal(t, tt.expected, got) + assert:Equal(t, tt.err, err) + end) + end +end - t:Run("sha256", function(t) - assert(crypto.sha256("1\n") == "4355a46b19d348dc2f57c046f8ef63d4538ebb936000f3c9ee954a27460dd865") - end) +function TestAESDecrypt(t) + local tests = { + { + data = "7ec4e38508a26abf7b46e8dc90a7299d5144bcf045e460c3efwb3e", + mode = "GCM", + key = "86e15cbc1cbf510d8f2e51d4b63a2144", + init = "b6b86d581a991a652158bd10", + expected = nil, + err = "failed to decode params: failed to decode data: encoding/hex: invalid byte: U+0077 'w'", + }, + { + data = "7ec4e38508a26abf7b46e8dc90a7299d5144bcf045e460c3ef6b3e", + mode = "GCM", + key = "86e15cbc1cbf51d8f2e51d4b63a2144", + init = "b6b86d581a991a652158bd10", + expected = nil, + err = "failed to decode params: failed to decode key: encoding/hex: odd length hex string", + }, + { + data = "7ec4e38508a26abf7b46e8dc90a7299d5144bcf045e460c3ef6b3e", + mode = "GCM", + key = "86e15cbc1cbf510d8f2e51d4b63a2144", + init = "b6b86d581a991a652158bd10", + expected = "48656c6c6f20776f726c64", -- "Hello world" in hex + err = nil, + }, + { + data = "7ec4e38508a26abf7b46e8dc90a7299d5144bcf045e460c3ef6b3e", + mode = "GCM", + key = "86e15cbc1cbf510d8f2e51d4b63a2144", + init = "b6b86d581a991a652158bd010211", + expected = nil, + err = "failed to decrypt: incorrect GCM nonce size: 14, expected: 12", + }, + { + data = "dfba6f71cce4d4b76be301b577d9f095", + mode = "cbc", + key = "86e15cbc1cbf510d8f2e51d4b63a2144", + init = "068bb92e032884ba8b260fa7d3a80005", + expected = "48656c6c6f20776f726c640505050505", -- "Hello world" + padding in hex + err = nil, + }, + { + data = "138434a80bd7dcd9ee8adc", + mode = "CTR", + key = "86e15cbc1cbf510d8f2e51d4b63a2144", + init = "e3057fc2bf103a09a1b2c3d4e5f60718", + expected = "48656c6c6f20776f726c64", -- "Hello world" in hex + err = nil, + }, + } + for _, tt in ipairs(tests) do + t:Run("aes_decrypt in " .. tostring(tt.mode) .. " mode", function(t) + local got, err = crypto.aes_decrypt(tt.mode, tt.key, tt.init, tt.data) + assert:Equal(t, tt.expected, got) + assert:Equal(t, tt.err, err) + end) + end end