rsa_ext.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. package public
  2. import (
  3. "bytes"
  4. "crypto/rand"
  5. "crypto/rsa"
  6. "crypto/x509"
  7. "encoding/pem"
  8. "errors"
  9. "io"
  10. "math/big"
  11. )
  12. var (
  13. ErrDataToLarge = errors.New("message too long for RSA public key size")
  14. ErrDataLen = errors.New("data length error")
  15. ErrDataBroken = errors.New("data broken, first byte is not zero")
  16. ErrKeyPairDismatch = errors.New("data is not encrypted by the private key")
  17. ErrDecryption = errors.New("decryption error")
  18. ErrPublicKey = errors.New("get public key error")
  19. ErrPrivateKey = errors.New("get private key error")
  20. )
  21. // 设置公钥
  22. func getPubKey(publickey []byte) (*rsa.PublicKey, error) {
  23. // decode public key
  24. block, _ := pem.Decode(publickey)
  25. if block == nil {
  26. return nil, errors.New("get public key error")
  27. }
  28. // x509 parse public key
  29. pub, err := x509.ParsePKIXPublicKey(block.Bytes)
  30. if err != nil {
  31. return nil, err
  32. }
  33. return pub.(*rsa.PublicKey), err
  34. }
  35. // 设置私钥
  36. func getPriKey(privatekey []byte) (*rsa.PrivateKey, error) {
  37. block, _ := pem.Decode(privatekey)
  38. if block == nil {
  39. return nil, errors.New("get private key error")
  40. }
  41. pri, err := x509.ParsePKCS1PrivateKey(block.Bytes)
  42. if err == nil {
  43. return pri, nil
  44. }
  45. pri2, err := x509.ParsePKCS8PrivateKey(block.Bytes)
  46. if err != nil {
  47. return nil, err
  48. }
  49. return pri2.(*rsa.PrivateKey), nil
  50. }
  51. // 公钥加密或解密byte
  52. func pubKeyByte(pub *rsa.PublicKey, in []byte, isEncrytp bool) ([]byte, error) {
  53. k := (pub.N.BitLen() + 7) / 8
  54. if isEncrytp {
  55. k = k - 11
  56. }
  57. if len(in) <= k {
  58. if isEncrytp {
  59. return rsa.EncryptPKCS1v15(rand.Reader, pub, in)
  60. } else {
  61. return pubKeyDecrypt(pub, in)
  62. }
  63. } else {
  64. iv := make([]byte, k)
  65. out := bytes.NewBuffer(iv)
  66. if err := pubKeyIO(pub, bytes.NewReader(in), out, isEncrytp); err != nil {
  67. return nil, err
  68. }
  69. return io.ReadAll(out)
  70. }
  71. }
  72. // 私钥加密或解密byte
  73. func priKeyByte(pri *rsa.PrivateKey, in []byte, isEncrytp bool) ([]byte, error) {
  74. k := (pri.N.BitLen() + 7) / 8
  75. if isEncrytp {
  76. k = k - 11
  77. }
  78. if len(in) <= k {
  79. if isEncrytp {
  80. return priKeyEncrypt(rand.Reader, pri, in)
  81. } else {
  82. return rsa.DecryptPKCS1v15(rand.Reader, pri, in)
  83. }
  84. } else {
  85. iv := make([]byte, k)
  86. out := bytes.NewBuffer(iv)
  87. if err := priKeyIO(pri, bytes.NewReader(in), out, isEncrytp); err != nil {
  88. return nil, err
  89. }
  90. return io.ReadAll(out)
  91. }
  92. }
  93. // 公钥加密或解密Reader
  94. func pubKeyIO(pub *rsa.PublicKey, in io.Reader, out io.Writer, isEncrytp bool) (err error) {
  95. k := (pub.N.BitLen() + 7) / 8
  96. if isEncrytp {
  97. k = k - 11
  98. }
  99. buf := make([]byte, k)
  100. var b []byte
  101. size := 0
  102. for {
  103. size, err = in.Read(buf)
  104. if err != nil {
  105. if err == io.EOF {
  106. return nil
  107. }
  108. return err
  109. }
  110. if size < k {
  111. b = buf[:size]
  112. } else {
  113. b = buf
  114. }
  115. if isEncrytp {
  116. b, err = rsa.EncryptPKCS1v15(rand.Reader, pub, b)
  117. } else {
  118. b, err = pubKeyDecrypt(pub, b)
  119. }
  120. if err != nil {
  121. return err
  122. }
  123. if _, err = out.Write(b); err != nil {
  124. return err
  125. }
  126. }
  127. //return nil
  128. }
  129. // 私钥加密或解密Reader
  130. func priKeyIO(pri *rsa.PrivateKey, r io.Reader, w io.Writer, isEncrytp bool) (err error) {
  131. k := (pri.N.BitLen() + 7) / 8
  132. if isEncrytp {
  133. k = k - 11
  134. }
  135. buf := make([]byte, k)
  136. var b []byte
  137. size := 0
  138. for {
  139. size, err = r.Read(buf)
  140. if err != nil {
  141. if err == io.EOF {
  142. return nil
  143. }
  144. return err
  145. }
  146. if size < k {
  147. b = buf[:size]
  148. } else {
  149. b = buf
  150. }
  151. if isEncrytp {
  152. b, err = priKeyEncrypt(rand.Reader, pri, b)
  153. } else {
  154. b, err = rsa.DecryptPKCS1v15(rand.Reader, pri, b)
  155. }
  156. if err != nil {
  157. return err
  158. }
  159. if _, err = w.Write(b); err != nil {
  160. return err
  161. }
  162. }
  163. //return nil
  164. }
  165. // 公钥解密
  166. func pubKeyDecrypt(pub *rsa.PublicKey, data []byte) ([]byte, error) {
  167. k := (pub.N.BitLen() + 7) / 8
  168. if k != len(data) {
  169. return nil, ErrDataLen
  170. }
  171. m := new(big.Int).SetBytes(data)
  172. if m.Cmp(pub.N) > 0 {
  173. return nil, ErrDataToLarge
  174. }
  175. m.Exp(m, big.NewInt(int64(pub.E)), pub.N)
  176. d := leftPad(m.Bytes(), k)
  177. if d[0] != 0 {
  178. return nil, ErrDataBroken
  179. }
  180. if d[1] != 0 && d[1] != 1 {
  181. return nil, ErrKeyPairDismatch
  182. }
  183. var i = 2
  184. for ; i < len(d); i++ {
  185. if d[i] == 0 {
  186. break
  187. }
  188. }
  189. i++
  190. if i == len(d) {
  191. return nil, nil
  192. }
  193. return d[i:], nil
  194. }
  195. // 私钥加密
  196. func priKeyEncrypt(rand io.Reader, priv *rsa.PrivateKey, hashed []byte) ([]byte, error) {
  197. tLen := len(hashed)
  198. k := (priv.N.BitLen() + 7) / 8
  199. if k < tLen+11 {
  200. return nil, ErrDataLen
  201. }
  202. em := make([]byte, k)
  203. em[1] = 1
  204. for i := 2; i < k-tLen-1; i++ {
  205. em[i] = 0xff
  206. }
  207. copy(em[k-tLen:k], hashed)
  208. m := new(big.Int).SetBytes(em)
  209. c, err := decrypt(rand, priv, m)
  210. if err != nil {
  211. return nil, err
  212. }
  213. copyWithLeftPad(em, c.Bytes())
  214. return em, nil
  215. }
  216. // 从crypto/rsa复制
  217. var bigZero = big.NewInt(0)
  218. var bigOne = big.NewInt(1)
  219. // 从crypto/rsa复制
  220. func encrypt(c *big.Int, pub *rsa.PublicKey, m *big.Int) *big.Int {
  221. e := big.NewInt(int64(pub.E))
  222. c.Exp(m, e, pub.N)
  223. return c
  224. }
// decrypt performs the raw RSA private-key operation on c and returns the
// result in m. Copied from crypto/rsa.
//
// When random is non-nil, c is blinded with a random factor before the
// exponentiation and the result is unblinded afterwards. It returns
// ErrDecryption when c is greater than the modulus, and the error from
// rand.Int when drawing the blinding factor fails.
func decrypt(random io.Reader, priv *rsa.PrivateKey, c *big.Int) (m *big.Int, err error) {
	// Reject inputs that are larger than the modulus.
	if c.Cmp(priv.N) > 0 {
		err = ErrDecryption
		return
	}
	// ir holds the modular inverse of the blinding factor; it stays nil when
	// no random source was supplied.
	var ir *big.Int
	if random != nil {
		var r *big.Int
		// Draw random values until one is invertible modulo N.
		for {
			r, err = rand.Int(random, priv.N)
			if err != nil {
				return
			}
			if r.Cmp(bigZero) == 0 {
				r = bigOne
			}
			var ok bool
			ir, ok = modInverse(r, priv.N)
			if ok {
				break
			}
		}
		// Blind the input: c = c * r^E mod N. The caller's c is not modified
		// because the product is computed on a copy.
		bigE := big.NewInt(int64(priv.E))
		rpowe := new(big.Int).Exp(r, bigE, priv.N)
		cCopy := new(big.Int).Set(c)
		cCopy.Mul(cCopy, rpowe)
		cCopy.Mod(cCopy, priv.N)
		c = cCopy
	}
	if priv.Precomputed.Dp == nil {
		// No precomputed values: plain exponentiation m = c^D mod N.
		m = new(big.Int).Exp(c, priv.D, priv.N)
	} else {
		// Use the precomputed values: exponentiate modulo the first two primes
		// and combine the partial results.
		m = new(big.Int).Exp(c, priv.Precomputed.Dp, priv.Primes[0])
		m2 := new(big.Int).Exp(c, priv.Precomputed.Dq, priv.Primes[1])
		m.Sub(m, m2)
		if m.Sign() < 0 {
			m.Add(m, priv.Primes[0])
		}
		m.Mul(m, priv.Precomputed.Qinv)
		m.Mod(m, priv.Primes[0])
		m.Mul(m, priv.Primes[1])
		m.Add(m, m2)
		// Fold in every additional prime (third and later) using its CRT values.
		for i, values := range priv.Precomputed.CRTValues {
			prime := priv.Primes[2+i]
			m2.Exp(c, values.Exp, prime)
			m2.Sub(m2, m)
			m2.Mul(m2, values.Coeff)
			m2.Mod(m2, prime)
			if m2.Sign() < 0 {
				m2.Add(m2, prime)
			}
			m2.Mul(m2, values.R)
			m.Add(m, m2)
		}
	}
	if ir != nil {
		// Unblind: m = m * r^-1 mod N.
		m.Mul(m, ir)
		m.Mod(m, priv.N)
	}
	return
}
  287. // 从crypto/rsa复制
  288. func copyWithLeftPad(dest, src []byte) {
  289. numPaddingBytes := len(dest) - len(src)
  290. for i := 0; i < numPaddingBytes; i++ {
  291. dest[i] = 0
  292. }
  293. copy(dest[numPaddingBytes:], src)
  294. }
  295. // 从crypto/rsa复制
  296. func nonZeroRandomBytes(s []byte, rand io.Reader) (err error) {
  297. _, err = io.ReadFull(rand, s)
  298. if err != nil {
  299. return
  300. }
  301. for i := 0; i < len(s); i++ {
  302. for s[i] == 0 {
  303. _, err = io.ReadFull(rand, s[i:i+1])
  304. if err != nil {
  305. return
  306. }
  307. s[i] ^= 0x42
  308. }
  309. }
  310. return
  311. }
  312. // 从crypto/rsa复制
  313. func leftPad(input []byte, size int) (out []byte) {
  314. n := len(input)
  315. if n > size {
  316. n = size
  317. }
  318. out = make([]byte, size)
  319. copy(out[len(out)-n:], input)
  320. return
  321. }
  322. // 从crypto/rsa复制
  323. func modInverse(a, n *big.Int) (ia *big.Int, ok bool) {
  324. g := new(big.Int)
  325. x := new(big.Int)
  326. y := new(big.Int)
  327. g.GCD(x, y, a, n)
  328. if g.Cmp(bigOne) != 0 {
  329. return
  330. }
  331. if x.Cmp(bigOne) < 0 {
  332. x.Add(x, n)
  333. }
  334. return x, true
  335. }