diff --git a/cmd/cmd.go b/cmd/cmd.go index d5f8a157..e33f4ef3 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -11,7 +11,6 @@ import ( "io" "log" "net" - "net/http" "os" "os/exec" "path/filepath" @@ -108,35 +107,28 @@ func CreateHandler(cmd *cobra.Command, args []string) error { } func RunHandler(cmd *cobra.Command, args []string) error { - insecure, err := cmd.Flags().GetBool("insecure") + client, err := api.FromEnv() if err != nil { return err } - mp := server.ParseModelPath(args[0]) - if mp.ProtocolScheme == "http" && !insecure { - return fmt.Errorf("insecure protocol http") - } - - fp, err := mp.GetManifestPath(false) + models, err := client.List(context.Background()) if err != nil { return err } - _, err = os.Stat(fp) - switch { - case errors.Is(err, os.ErrNotExist): - if err := pull(args[0], insecure); err != nil { - var apiStatusError api.StatusError - if !errors.As(err, &apiStatusError) { - return err - } + modelName, modelTag, ok := strings.Cut(args[0], ":") + if !ok { + modelTag = "latest" + } - if apiStatusError.StatusCode != http.StatusBadGateway { - return err - } + for _, model := range models.Models { + if model.Name == strings.Join([]string{modelName, modelTag}, ":") { + return RunGenerate(cmd, args) } - case err != nil: + } + + if err := PullHandler(cmd, args); err != nil { return err } diff --git a/server/routes.go b/server/routes.go index 79d2ee72..c463a1af 100644 --- a/server/routes.go +++ b/server/routes.go @@ -499,23 +499,25 @@ func CopyModelHandler(c *gin.Context) { } } -func Serve(ln net.Listener, origins []string) error { +var defaultAllowOrigins = []string{ + "localhost", + "127.0.0.1", + "0.0.0.0", +} + +func Serve(ln net.Listener, allowOrigins []string) error { config := cors.DefaultConfig() config.AllowWildcard = true - config.AllowOrigins = append(origins, []string{ - "http://localhost", - "http://localhost:*", - "https://localhost", - "https://localhost:*", - "http://127.0.0.1", - "http://127.0.0.1:*", - "https://127.0.0.1", - "https://127.0.0.1:*", - "http://0.0.0.0", - "http://0.0.0.0:*", - "https://0.0.0.0", - "https://0.0.0.0:*", - }...) + + config.AllowOrigins = allowOrigins + for _, allowOrigin := range defaultAllowOrigins { + config.AllowOrigins = append(config.AllowOrigins, + fmt.Sprintf("http://%s", allowOrigin), + fmt.Sprintf("https://%s", allowOrigin), + fmt.Sprintf("http://%s:*", allowOrigin), + fmt.Sprintf("https://%s:*", allowOrigin), + ) + } r := gin.Default() r.Use(cors.New(config))