@@ -121,10 +121,10 @@ def check_rodin_status(self, response: Rodin3DCheckStatusResponse) -> str:
121121        else :
122122            return  "Generating" 
123123
124-     async  def  create_generate_task (self , images = None , seed = 1 , material = "PBR" , quality = "medium" , tier = "Regular" , mesh_mode = "Quad" , ** kwargs ):
124+     async  def  create_generate_task (self , images = None , seed = 1 , material = "PBR" , quality_override = 18000 , tier = "Regular" , mesh_mode = "Quad" ,  TAPose   =   False , ** kwargs ):
125125        if  images  is  None :
126126            raise  Exception ("Rodin 3D generate requires at least 1 image." )
127-         if  len (images ) >=   5 :
127+         if  len (images ) >  5 :
128128            raise  Exception ("Rodin 3D generate requires up to 5 image." )
129129
130130        path  =  "/proxy/rodin/api/v2/rodin" 
@@ -139,8 +139,9 @@ async def create_generate_task(self, images=None, seed=1, material="PBR", qualit
139139                seed = seed ,
140140                tier = tier ,
141141                material = material ,
142-                 quality = quality ,
143-                 mesh_mode = mesh_mode 
142+                 quality_override = quality_override ,
143+                 mesh_mode = mesh_mode ,
144+                 TAPose = TAPose ,
144145            ),
145146            files = [
146147                (
@@ -211,23 +212,36 @@ async def get_rodin_download_list(self, uuid, **kwargs) -> Rodin3DDownloadRespon
211212        return  await  operation .execute ()
212213
213214    def  get_quality_mode (self , poly_count ):
214-         if  poly_count  ==  "200K-Triangle" :
215+         polycount  =  poly_count .split ("-" )
216+         poly  =  polycount [1 ]
217+         count  =  polycount [0 ]
218+         if  poly  ==  "Triangle" :
215219            mesh_mode  =  "Raw" 
216-             quality  =  "medium" 
220+         elif  poly  ==  "Quad" :
221+             mesh_mode  =  "Quad" 
217222        else :
218223            mesh_mode  =  "Quad" 
219-             if  poly_count  ==  "4K-Quad" :
220-                 quality  =  "extra-low" 
221-             elif  poly_count  ==  "8K-Quad" :
222-                 quality  =  "low" 
223-             elif  poly_count  ==  "18K-Quad" :
224-                 quality  =  "medium" 
225-             elif  poly_count  ==  "50K-Quad" :
226-                 quality  =  "high" 
227-             else :
228-                 quality  =  "medium" 
229- 
230-         return  mesh_mode , quality 
224+ 
225+         if  count  ==  "4K" :
226+             quality_override  =  4000 
227+         elif  count  ==  "8K" :
228+             quality_override  =  8000 
229+         elif  count  ==  "18K" :
230+             quality_override  =  18000 
231+         elif  count  ==  "50K" :
232+             quality_override  =  50000 
233+         elif  count  ==  "2K" :
234+             quality_override  =  2000 
235+         elif  count  ==  "20K" :
236+             quality_override  =  20000 
237+         elif  count  ==  "150K" :
238+             quality_override  =  150000 
239+         elif  count  ==  "500K" :
240+             quality_override  =  500000 
241+         else :
242+             quality_override  =  18000 
243+ 
244+         return  mesh_mode , quality_override 
231245
232246    async  def  download_files (self , url_list ):
233247        save_path  =  os .path .join (comfy_paths .get_output_directory (), "Rodin3D" , datetime .datetime .now ().strftime ("%Y-%m-%d_%H-%M-%S" ))
@@ -300,9 +314,9 @@ async def api_call(
300314        m_images  =  []
301315        for  i  in  range (num_images ):
302316            m_images .append (Images [i ])
303-         mesh_mode , quality  =  self .get_quality_mode (Polygon_count )
317+         mesh_mode , quality_override  =  self .get_quality_mode (Polygon_count )
304318        task_uuid , subscription_key  =  await  self .create_generate_task (images = m_images , seed = Seed , material = Material_Type ,
305-                                                                 quality = quality , tier = tier , mesh_mode = mesh_mode ,
319+                                                                 quality_override = quality_override , tier = tier , mesh_mode = mesh_mode ,
306320                                                                ** kwargs )
307321        await  self .poll_for_task_status (subscription_key , ** kwargs )
308322        download_list  =  await  self .get_rodin_download_list (task_uuid , ** kwargs )
@@ -346,9 +360,9 @@ async def api_call(
346360        m_images  =  []
347361        for  i  in  range (num_images ):
348362            m_images .append (Images [i ])
349-         mesh_mode , quality  =  self .get_quality_mode (Polygon_count )
363+         mesh_mode , quality_override  =  self .get_quality_mode (Polygon_count )
350364        task_uuid , subscription_key  =  await  self .create_generate_task (images = m_images , seed = Seed , material = Material_Type ,
351-                                                                 quality = quality , tier = tier , mesh_mode = mesh_mode ,
365+                                                                 quality_override = quality_override , tier = tier , mesh_mode = mesh_mode ,
352366                                                                ** kwargs )
353367        await  self .poll_for_task_status (subscription_key , ** kwargs )
354368        download_list  =  await  self .get_rodin_download_list (task_uuid , ** kwargs )
@@ -392,9 +406,9 @@ async def api_call(
392406        m_images  =  []
393407        for  i  in  range (num_images ):
394408            m_images .append (Images [i ])
395-         mesh_mode , quality  =  self .get_quality_mode (Polygon_count )
409+         mesh_mode , quality_override  =  self .get_quality_mode (Polygon_count )
396410        task_uuid , subscription_key  =  await  self .create_generate_task (images = m_images , seed = Seed , material = Material_Type ,
397-                                                                 quality = quality , tier = tier , mesh_mode = mesh_mode ,
411+                                                                 quality_override = quality_override , tier = tier , mesh_mode = mesh_mode ,
398412                                                                ** kwargs )
399413        await  self .poll_for_task_status (subscription_key , ** kwargs )
400414        download_list  =  await  self .get_rodin_download_list (task_uuid , ** kwargs )
@@ -446,24 +460,99 @@ async def api_call(
446460        for  i  in  range (num_images ):
447461            m_images .append (Images [i ])
448462        material_type  =  "PBR" 
449-         quality  =  "medium" 
463+         quality_override  =  18000 
450464        mesh_mode  =  "Quad" 
451465        task_uuid , subscription_key  =  await  self .create_generate_task (
452-             images = m_images , seed = Seed , material = material_type , quality = quality , tier = tier , mesh_mode = mesh_mode , ** kwargs 
466+             images = m_images , seed = Seed , material = material_type , quality_override = quality_override , tier = tier , mesh_mode = mesh_mode , ** kwargs 
453467        )
454468        await  self .poll_for_task_status (subscription_key , ** kwargs )
455469        download_list  =  await  self .get_rodin_download_list (task_uuid , ** kwargs )
456470        model  =  await  self .download_files (download_list )
457471
458472        return  (model ,)
459473
474+ class  Rodin3D_Gen2 (Rodin3DAPI ):
475+     @classmethod  
476+     def  INPUT_TYPES (s ):
477+         return  {
478+             "required" : {
479+                 "Images" :
480+                 (
481+                     IO .IMAGE ,
482+                     {
483+                         "forceInput" :True ,
484+                     }
485+                 )
486+             },
487+             "optional" : {
488+                 "Seed" : (
489+                     IO .INT ,
490+                     {
491+                         "default" :0 ,
492+                         "min" :0 ,
493+                         "max" :65535 ,
494+                         "display" :"number" 
495+                     }
496+                 ),
497+                 "Material_Type" : (
498+                     IO .COMBO ,
499+                     {
500+                         "options" : ["PBR" , "Shaded" ],
501+                         "default" : "PBR" 
502+                     }
503+                 ),
504+                 "Polygon_count" : (
505+                     IO .COMBO ,
506+                     {
507+                         "options" : ["4K-Quad" , "8K-Quad" , "18K-Quad" , "50K-Quad" , "2K-Triangle" , "20K-Triangle" , "150K-Triangle" , "500K-Triangle" ],
508+                         "default" : "500K-Triangle" 
509+                     }
510+                 ),
511+                 "TAPose" : (
512+                     IO .BOOLEAN ,
513+                     {
514+                         "default" : False ,
515+                     }
516+                 )
517+             },
518+             "hidden" : {
519+                 "auth_token" : "AUTH_TOKEN_COMFY_ORG" ,
520+                 "comfy_api_key" : "API_KEY_COMFY_ORG" ,
521+             },
522+         }
523+ 
524+     async  def  api_call (
525+         self ,
526+         Images ,
527+         Seed ,
528+         Material_Type ,
529+         Polygon_count ,
530+         TAPose ,
531+         ** kwargs 
532+     ):
533+         tier  =  "Gen-2" 
534+         num_images  =  Images .shape [0 ]
535+         m_images  =  []
536+         for  i  in  range (num_images ):
537+             m_images .append (Images [i ])
538+         mesh_mode , quality_override  =  self .get_quality_mode (Polygon_count )
539+         task_uuid , subscription_key  =  await  self .create_generate_task (images = m_images , seed = Seed , material = Material_Type ,
540+                                                                 quality_override = quality_override , tier = tier , mesh_mode = mesh_mode , TAPose = TAPose ,
541+                                                                 ** kwargs )
542+         await  self .poll_for_task_status (subscription_key , ** kwargs )
543+         download_list  =  await  self .get_rodin_download_list (task_uuid , ** kwargs )
544+         model  =  await  self .download_files (download_list )
545+ 
546+         return  (model ,)
547+ 
460548# A dictionary that contains all nodes you want to export with their names 
461549# NOTE: names should be globally unique 
462550NODE_CLASS_MAPPINGS  =  {
463551    "Rodin3D_Regular" : Rodin3D_Regular ,
464552    "Rodin3D_Detail" : Rodin3D_Detail ,
465553    "Rodin3D_Smooth" : Rodin3D_Smooth ,
466554    "Rodin3D_Sketch" : Rodin3D_Sketch ,
555+     "Rodin3D_Gen2" : Rodin3D_Gen2 ,
467556}
468557
469558# A dictionary that contains the friendly/humanly readable titles for the nodes 
@@ -472,4 +561,5 @@ async def api_call(
472561    "Rodin3D_Detail" : "Rodin 3D Generate - Detail Generate" ,
473562    "Rodin3D_Smooth" : "Rodin 3D Generate - Smooth Generate" ,
474563    "Rodin3D_Sketch" : "Rodin 3D Generate - Sketch Generate" ,
564+     "Rodin3D_Gen2" : "Rodin 3D Generate - Gen-2 Generate" ,
475565}
0 commit comments