透過STS來對AWS資源進行更靈活的許可權控制

划水的猫發表於2024-04-24

一、前言

背景:一個S3 bucket,儲存使用者的檔案,每個使用者只允許上傳、下載自己目錄下的檔案。

如何讓Policy更靈活、更動態,可以讓獲取到的許可權憑證可以匹配到單個終端使用者的S3檔案目錄下。

本節主要介紹,以程式設計方式呼叫 AWS Security Token Service (AWS STS) 的API,獲取訪問AWS資源的臨時安全憑證,並將這些憑證下發給終端使用者,用於後續終端使用者發起訪問AWS資源請求時進行身份驗證。

二、整體圖

  1. 終端使用者的客戶端發出訪問請求到伺服器端。
  2. 伺服器端呼叫STS API,以獲取臨時安全憑證。
  3. STS服務返回臨時安全憑證。
  4. Server將獲取到的臨時安全憑證下發給客戶端。
  5. 客戶端使用臨時安全憑證來訪問AWS資源(本文中指S3 bucket)。

三、程式

package main

import (
    "context"
    "errors"
    "fmt"
    "github.com/aws/aws-sdk-go-v2/aws"
    "github.com/aws/aws-sdk-go-v2/config"
    "github.com/aws/aws-sdk-go-v2/credentials"
    "github.com/aws/aws-sdk-go-v2/service/sts"
    "strconv"
    "strings"
    "time"
)

func main() {
    var bucketName = "{bucket名稱}"
    var authPaths = []string{"uid"} //uid可替換使用者真實uid對應的目錄,可支援多層級目錄
    var expire int32 = 3600 //STS token有效期
    cfg := &StoreClientConf{
        RoleArn:         "{roleArn}",
        Region:          "{bucket region}",
        AccessKeyID:     "{bucket ak}",
        AccessKeySecret: "{bucket sk}",
    }
    client := NewAwsClient(cfg)
    stsInfo, err := client.GetStsCredentials(context.Background(), bucketName, authPaths, expire)
    if err != nil {
        fmt.Println("client.GetStsCredentials err: " + err.Error())
        return
    }
    fmt.Println("sts ak: " + stsInfo.AccessKeyId)
    fmt.Println("sts sk: " + stsInfo.AccessSecret)
    fmt.Println("sts token: " + stsInfo.SecurityToken)
}

type AwsClient struct {
    roleArn         string
    region          string
    accessKeyID     string
    accessKeySecret string
}

type StoreClientConf struct {
    RoleArn         string
    Region          string
    AccessKeyID     string
    AccessKeySecret string
}

type StsCredentials struct {
    AccessKeyId   string
    AccessSecret  string
    SecurityToken string
    ExpireTime    int64
}

func NewAwsClient(cfg *StoreClientConf) *AwsClient {
    return &AwsClient{
        roleArn:         cfg.RoleArn,
        region:          cfg.Region,
        accessKeyID:     cfg.AccessKeyID,
        accessKeySecret: cfg.AccessKeySecret,
    }
}

func (s *AwsClient) loadConfig(ctx context.Context) (aws.Config, error) {
    cfg, err := config.LoadDefaultConfig(ctx,
        config.WithRegion(s.region),
        config.WithCredentialsProvider(credentials.StaticCredentialsProvider{
            Value: aws.Credentials{
                AccessKeyID: s.accessKeyID, SecretAccessKey: s.accessKeySecret, SessionToken: "",
                Source: "",
            },
        }),
    )
    if err != nil {
        fmt.Println("awsClient LoadDefaultConfig err:" + err.Error())
        return aws.Config{}, errors.New("awsClient LoadDefaultConfig err")
    }

    return cfg, nil
}

// 定義自己想要的policy
func (s *AwsClient) authPolicy(ctx context.Context, bucket string, authPaths []string) string {
    var resource []string
    for _, v := range authPaths {
        path := strings.TrimRight(v, "/") //去除最後一個/
        resource = append(resource, `"arn:aws:s3:::` + bucket+ `/` + path + `/*"`)
    }
    policy := `{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Action": [
                "s3:GetObject",
                "s3:GetObjectAttributes",
                "s3:GetObjectTagging",
                "s3:PutObject",
                "s3:PutObjectTagging",
                "s3:UploadPart"
            ],
            "Effect": "Allow",
            "Resource": [` + strings.Join(resource, ",") + `]
        }
    ]
}`
    return policy
}

func (s *AwsClient) GetStsCredentials(ctx context.Context, bucket string, authPaths []string, expired int32) (*StsCredentials, error) {
    // 1.拼裝授權策略
    policy := s.authPolicy(ctx, bucket, authPaths)

    // 2.初始化client
    cfg, err := s.loadConfig(ctx)
    if err != nil {
        return nil, err
    }
    client := sts.NewFromConfig(cfg)

    // 3.呼叫s3介面,獲取sts token
    roleSessionName := "s3bucket" + strconv.FormatInt(time.Now().Unix(), 10) //需要按使用者的維度去修改
    input := &sts.AssumeRoleInput{
        RoleArn:         &s.roleArn,
        RoleSessionName: &roleSessionName,
        DurationSeconds: &expired,
        Policy: aws.String(policy),
    }
    resp, err := client.AssumeRole(ctx, input)
    if err != nil {
        fmt.Println("GetStsCredentials client.AssumeRole err:" + err.Error())
        return nil, err
    }
    if resp == nil {
        fmt.Println("GetStsCredentials response is nil")
        return nil, err
    }
    var expire int64
    if resp.Credentials != nil && resp.Credentials.Expiration != nil {
        expire = resp.Credentials.Expiration.Unix()
    }
    return &StsCredentials{
        AccessKeyId:   *resp.Credentials.AccessKeyId,
        AccessSecret:  *resp.Credentials.SecretAccessKey,
        SecurityToken: *resp.Credentials.SessionToken,
        ExpireTime: expire,
    }, nil
}

相關文章